mirror of
https://github.com/Threnklyn/wg-ui.git
synced 2026-05-21 22:33:29 +02:00
246 lines
5.4 KiB
Go
246 lines
5.4 KiB
Go
package main
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/dgrijalva/jwt-go"
|
|
"github.com/google/nftables"
|
|
"github.com/google/nftables/expr"
|
|
"github.com/julienschmidt/httprouter"
|
|
"github.com/mdlayher/wireguardctrl"
|
|
"github.com/mdlayher/wireguardctrl/wgtypes"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/vishvananda/netlink"
|
|
"github.com/vishvananda/netns"
|
|
"gopkg.in/alecthomas/kingpin.v2"
|
|
)
|
|
|
|
var (
|
|
listenAddr = kingpin.Flag("listen-address", "Address to listen to").Default(":8080").String()
|
|
natLink = kingpin.Flag("nat-device", "Network interface to masquerade").Default("wlp2s0").String()
|
|
|
|
wgLinkName = kingpin.Flag("wg-device-name", "Wireguard network device name").Default("wg0").String()
|
|
wgLinkAddr = kingpin.Flag("wg-link-addr", "Wireguard interface address").Default("172.72.72.1/32").String()
|
|
wgListenPort = kingpin.Flag("wg-listen-port", "Wireguard UDP port to listen to").Default("51820").Int()
|
|
wgEndpoint = kingpin.Flag("wg-endpoint", "Wireguard endpoint address").Default("127.0.0.1").String()
|
|
)
|
|
|
|
type Server struct {
|
|
storage *Storage
|
|
config *ServerConfig
|
|
}
|
|
|
|
type WgLink struct {
|
|
attrs *netlink.LinkAttrs
|
|
}
|
|
|
|
// jwtClaims is an empty placeholder type; it declares no fields.
type jwtClaims struct{}
|
|
|
|
func (w *WgLink) Attrs() *netlink.LinkAttrs {
|
|
return w.attrs
|
|
}
|
|
|
|
func (w *WgLink) Type() string {
|
|
return "wireguard"
|
|
}
|
|
|
|
// ifname encodes an interface name as a fixed 16-byte, zero-padded buffer, the
// form used as comparison data for the nftables output-interface match.
//
// Names shorter than 16 bytes are followed by NUL padding; names of 16 bytes
// or more are cut to their first 16 bytes with no terminating NUL.
func ifname(n string) []byte {
	var padded [16]byte
	// The array starts zeroed, so copying the name alone leaves the same
	// trailing NUL bytes as copying name+"\x00" would.
	copy(padded[:], n)
	return padded[:]
}
|
|
|
|
func NewServer() *Server {
|
|
storage := NewStorage()
|
|
|
|
server := &Server{
|
|
storage: storage,
|
|
config: storage.GetServerConfig(),
|
|
}
|
|
|
|
return server
|
|
}
|
|
|
|
func (s *Server) initInterface() error {
|
|
attrs := netlink.NewLinkAttrs()
|
|
attrs.Name = *wgLinkName
|
|
|
|
link := WgLink{
|
|
attrs: &attrs,
|
|
}
|
|
|
|
log.Debug("Adding wireguard device: ", *wgLinkName)
|
|
err := netlink.LinkAdd(&link)
|
|
if os.IsExist(err) {
|
|
log.Infof("Wireguard interface %s already exists. Reusing.", *wgLinkName)
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Debug("Adding ip address to wireguard device: ", *wgLinkAddr)
|
|
addr, _ := netlink.ParseAddr(*wgLinkAddr)
|
|
err = netlink.AddrAdd(&link, addr)
|
|
if os.IsExist(err) {
|
|
log.Infof("Wireguard interface %s already has the requested address: ", *wgLinkAddr)
|
|
} else if err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Debug("Adding NAT / IP masquerading using nftables")
|
|
|
|
ns, err := netns.Get()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
conn := nftables.Conn{NetNS: int(ns)}
|
|
|
|
log.Debug("Flushing nftable rulesets")
|
|
// conn.FlushRuleset()
|
|
|
|
log.Debug("Setting up nftable rules for ip masquerading")
|
|
|
|
nat := conn.AddTable(&nftables.Table{
|
|
Family: nftables.TableFamilyIPv4,
|
|
Name: "nat",
|
|
})
|
|
|
|
conn.AddChain(&nftables.Chain{
|
|
Name: "prerouting",
|
|
Table: nat,
|
|
Type: nftables.ChainTypeNAT,
|
|
Hooknum: nftables.ChainHookPrerouting,
|
|
Priority: nftables.ChainPriorityFilter,
|
|
})
|
|
|
|
post := conn.AddChain(&nftables.Chain{
|
|
Name: "postrouting",
|
|
Table: nat,
|
|
Type: nftables.ChainTypeNAT,
|
|
Hooknum: nftables.ChainHookPostrouting,
|
|
Priority: nftables.ChainPriorityNATSource,
|
|
})
|
|
|
|
conn.AddRule(&nftables.Rule{
|
|
Table: nat,
|
|
Chain: post,
|
|
Exprs: []expr.Any{
|
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
|
&expr.Cmp{
|
|
Op: expr.CmpOpEq,
|
|
Register: 1,
|
|
Data: ifname(*natLink),
|
|
},
|
|
&expr.Masq{},
|
|
},
|
|
})
|
|
|
|
conn.Flush()
|
|
|
|
wg, err := wireguardctrl.New()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
log.Debug("Adding wireguard private key")
|
|
key, err := wgtypes.ParseKey(s.config.PrivateKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
cfg := wgtypes.Config{
|
|
PrivateKey: &key,
|
|
ListenPort: wgListenPort,
|
|
}
|
|
wg.ConfigureDevice(*wgLinkName, cfg)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) Start() error {
|
|
err := s.initInterface()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
router := httprouter.New()
|
|
router.GET("/", s.Index)
|
|
router.GET("/api/v1/users/:user/devices", s.withAuth(s.GetDevices))
|
|
|
|
log.WithField("listenAddr", *listenAddr).Info("Starting server")
|
|
return http.ListenAndServe(*listenAddr, router)
|
|
}
|
|
|
|
func userFromJwtToken(r *http.Request) string {
|
|
authHeader := r.Header.Get("authorization")
|
|
if authHeader == "" {
|
|
log.Debug("No Authorization header")
|
|
return ""
|
|
}
|
|
|
|
if !strings.HasPrefix(authHeader, "Bearer ") {
|
|
log.Debug("Incorrect Authorization header: ", authHeader)
|
|
return ""
|
|
}
|
|
|
|
claims := jwt.MapClaims{}
|
|
token, err := jwt.ParseWithClaims(authHeader[7:], &claims, func(token *jwt.Token) (interface{}, error) {
|
|
return []byte(""), nil
|
|
})
|
|
|
|
if token == nil {
|
|
log.Debug("Error parsing JWT token: ", err)
|
|
return ""
|
|
}
|
|
|
|
user, ok := claims["email"]
|
|
if ok {
|
|
return user.(string)
|
|
}
|
|
|
|
user, ok = claims["sub"]
|
|
if ok {
|
|
return user.(string)
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
func (s *Server) withAuth(handler httprouter.Handle) httprouter.Handle {
|
|
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
|
log.Debug("Auth required")
|
|
|
|
user := userFromJwtToken(r)
|
|
if user == "" {
|
|
user = "anonymous"
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), "user", user)
|
|
handler(w, r.WithContext(ctx), ps)
|
|
}
|
|
}
|
|
|
|
func (s *Server) Index(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
|
|
log.Debug("Index")
|
|
w.Write([]byte("Hello World"))
|
|
}
|
|
|
|
func (s *Server) GetDevices(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
|
|
user := r.Context().Value("user")
|
|
if user != ps.ByName("user") {
|
|
log.WithField("user", user).WithField("path", r.URL.Path).Warn("Unauthorized access")
|
|
w.WriteHeader(http.StatusUnauthorized)
|
|
return
|
|
}
|
|
err := json.NewEncoder(w).Encode(s.config)
|
|
if err != nil {
|
|
log.Error(err)
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
}
|
|
}
|