Files
wg-ui/server.go
T
2019-05-03 20:06:03 +02:00

246 lines
5.4 KiB
Go

package main
import (
	"context"
	"encoding/json"
	"net/http"
	"os"
	"strings"
	"time"

	"github.com/dgrijalva/jwt-go"
	"github.com/google/nftables"
	"github.com/google/nftables/expr"
	"github.com/julienschmidt/httprouter"
	"github.com/mdlayher/wireguardctrl"
	"github.com/mdlayher/wireguardctrl/wgtypes"
	log "github.com/sirupsen/logrus"
	"github.com/vishvananda/netlink"
	"github.com/vishvananda/netns"
	"gopkg.in/alecthomas/kingpin.v2"
)
// Command-line flags, registered on kingpin's default application in the order
// below. Each value is a pointer that only holds the user's setting after the
// command line has been parsed (presumably by kingpin.Parse in main — confirm).

// listenAddr is the host:port the HTTP API binds to.
var listenAddr = kingpin.Flag("listen-address", "Address to listen to").
	Default(":8080").
	String()

// natLink names the interface whose outgoing traffic is masqueraded.
// NOTE(review): the default looks host-specific; most deployments will need to
// pass --nat-device explicitly.
var natLink = kingpin.Flag("nat-device", "Network interface to masquerade").
	Default("wlp2s0").
	String()

// wgLinkName is the name of the WireGuard link that is created or reused.
var wgLinkName = kingpin.Flag("wg-device-name", "Wireguard network device name").
	Default("wg0").
	String()

// wgLinkAddr is the CIDR address assigned to the WireGuard link.
var wgLinkAddr = kingpin.Flag("wg-link-addr", "Wireguard interface address").
	Default("172.72.72.1/32").
	String()

// wgListenPort is the UDP port the WireGuard device listens on.
var wgListenPort = kingpin.Flag("wg-listen-port", "Wireguard UDP port to listen to").
	Default("51820").
	Int()

// wgEndpoint is the public endpoint address for the WireGuard server; it is not
// read anywhere in this file.
var wgEndpoint = kingpin.Flag("wg-endpoint", "Wireguard endpoint address").
	Default("127.0.0.1").
	String()
// Server is the HTTP API server. It holds the storage layer and the server
// configuration read from it when the Server is constructed (see NewServer).
type Server struct {
// storage is the backing store returned by NewStorage.
storage *Storage
// config is the value returned by storage.GetServerConfig() in NewServer. It
// supplies the WireGuard private key used by initInterface and is read by
// GetDevices when building its response.
config *ServerConfig
}
// WgLink is a minimal netlink.Link implementation (via its Attrs and Type
// methods). It lets initInterface ask netlink.LinkAdd to create a link of type
// "wireguard". It carries only the generic link attributes.
type WgLink struct {
// attrs holds the generic link attributes; initInterface sets Name from
// --wg-device-name.
attrs *netlink.LinkAttrs
}
// jwtClaims is an empty placeholder type.
// NOTE(review): it is not referenced in this file (userFromJwtToken decodes
// into jwt.MapClaims). Either give it fields and use it, or remove it.
type jwtClaims struct {
}
// Attrs returns the generic link attributes. Together with Type, this makes
// *WgLink satisfy netlink.Link, which netlink.LinkAdd and netlink.AddrAdd accept.
func (l *WgLink) Attrs() *netlink.LinkAttrs {
	attrs := l.attrs
	return attrs
}

// Type reports the link kind that netlink.LinkAdd requests when creating the
// device; here it is always "wireguard".
func (l *WgLink) Type() string {
	const kind = "wireguard"
	return kind
}
// ifname encodes an interface name for comparison against the nftables
// "oifname" meta key. The result is a fixed 16-byte (IFNAMSIZ) buffer holding
// the name followed by NUL padding.
//
// Names of 16 bytes or more are truncated to exactly 16 bytes with no
// terminating NUL. Linux caps real interface names at 15 bytes, so such a
// value presumably never matches any interface.
func ifname(n string) []byte {
	const ifNameSize = 16 // IFNAMSIZ on Linux, including the trailing NUL
	buf := make([]byte, ifNameSize)
	// make() zero-fills buf, so every byte after the copied name is already NUL.
	copy(buf, n)
	return buf
}
// NewServer creates the storage layer and returns a Server that holds it
// together with the server configuration read from that storage.
func NewServer() *Server {
	st := NewStorage()
	return &Server{storage: st, config: st.GetServerConfig()}
}
// initInterface prepares the host for serving WireGuard peers. In order, it:
//  1. creates the WireGuard link named by --wg-device-name, reusing an existing one;
//  2. assigns the --wg-link-addr address to it, tolerating an existing assignment;
//  3. installs nftables rules that masquerade traffic leaving --nat-device;
//  4. sets the device's private key (from s.config) and its UDP listen port.
//
// It returns the first error encountered. Steps that already succeeded are not
// undone.
func (s *Server) initInterface() error {
	if err := createWgLink(); err != nil {
		return err
	}
	if err := setupMasquerade(); err != nil {
		return err
	}
	return s.configureWgDevice()
}

// createWgLink adds the WireGuard link and its address. An "already exists"
// result is treated as success in both steps.
func createWgLink() error {
	attrs := netlink.NewLinkAttrs()
	attrs.Name = *wgLinkName
	link := &WgLink{attrs: &attrs}

	log.Debug("Adding wireguard device: ", *wgLinkName)
	err := netlink.LinkAdd(link)
	switch {
	case os.IsExist(err):
		log.Infof("Wireguard interface %s already exists. Reusing.", *wgLinkName)
	case err != nil:
		return err
	}

	log.Debug("Adding ip address to wireguard device: ", *wgLinkAddr)
	// A malformed --wg-link-addr must fail here. Ignoring the parse error would
	// hand a nil address to AddrAdd.
	addr, err := netlink.ParseAddr(*wgLinkAddr)
	if err != nil {
		return err
	}
	err = netlink.AddrAdd(link, addr)
	switch {
	case os.IsExist(err):
		log.Infof("Wireguard interface %s already has the requested address %s", *wgLinkName, *wgLinkAddr)
	case err != nil:
		return err
	}
	return nil
}

// setupMasquerade installs an IPv4 "nat" table. Its postrouting chain
// masquerades every packet whose output interface is --nat-device.
//
// NOTE(review): nothing is deleted first, so each start presumably appends
// another identical masquerade rule. Confirm, and consider deleting the "nat"
// table before re-adding it.
func setupMasquerade() error {
	log.Debug("Adding NAT / IP masquerading using nftables")
	ns, err := netns.Get()
	if err != nil {
		return err
	}
	// netns.Get opens a handle to the current network namespace; release it once
	// the batch below has been sent.
	defer ns.Close()
	conn := nftables.Conn{NetNS: int(ns)}

	// Flushing the whole ruleset (conn.FlushRuleset) stays disabled. It would
	// presumably also wipe tables owned by other software on the host.
	log.Debug("Setting up nftable rules for ip masquerading")
	nat := conn.AddTable(&nftables.Table{
		Family: nftables.TableFamilyIPv4,
		Name:   "nat",
	})
	conn.AddChain(&nftables.Chain{
		Name:     "prerouting",
		Table:    nat,
		Type:     nftables.ChainTypeNAT,
		Hooknum:  nftables.ChainHookPrerouting,
		Priority: nftables.ChainPriorityFilter,
	})
	post := conn.AddChain(&nftables.Chain{
		Name:     "postrouting",
		Table:    nat,
		Type:     nftables.ChainTypeNAT,
		Hooknum:  nftables.ChainHookPostrouting,
		Priority: nftables.ChainPriorityNATSource,
	})
	// oifname == --nat-device  =>  masquerade
	conn.AddRule(&nftables.Rule{
		Table: nat,
		Chain: post,
		Exprs: []expr.Any{
			&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
			&expr.Cmp{
				Op:       expr.CmpOpEq,
				Register: 1,
				Data:     ifname(*natLink),
			},
			&expr.Masq{},
		},
	})
	// Flush sends the queued batch to the kernel. Without checking its error, a
	// rejected batch would silently leave the host without NAT.
	return conn.Flush()
}

// configureWgDevice sets the private key (parsed from s.config.PrivateKey) and
// the --wg-listen-port on the --wg-device-name device.
func (s *Server) configureWgDevice() error {
	wg, err := wireguardctrl.New()
	if err != nil {
		return err
	}
	// Release the control client's resources once configuration is applied.
	defer wg.Close()

	log.Debug("Adding wireguard private key")
	key, err := wgtypes.ParseKey(s.config.PrivateKey)
	if err != nil {
		return err
	}
	cfg := wgtypes.Config{
		PrivateKey: &key,
		ListenPort: wgListenPort,
	}
	// Report configuration failures instead of starting with an unconfigured device.
	return wg.ConfigureDevice(*wgLinkName, cfg)
}
// Start sets up the WireGuard interface and NAT rules (initInterface), then
// serves the HTTP API on --listen-address.
//
// It returns immediately if interface setup fails. Otherwise it blocks until
// the HTTP server stops and returns the server's error.
func (s *Server) Start() error {
	if err := s.initInterface(); err != nil {
		return err
	}

	router := httprouter.New()
	router.GET("/", s.Index)
	router.GET("/api/v1/users/:user/devices", s.withAuth(s.GetDevices))

	// http.ListenAndServe applies no timeouts. A slow or stalled client could
	// then pin a goroutine and a file descriptor indefinitely, so bound every
	// phase of the connection.
	srv := &http.Server{
		Addr:              *listenAddr,
		Handler:           router,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
		WriteTimeout:      30 * time.Second,
		IdleTimeout:       120 * time.Second,
	}
	log.WithField("listenAddr", *listenAddr).Info("Starting server")
	return srv.ListenAndServe()
}
// userFromJwtToken extracts the caller's identity from an
// "Authorization: Bearer <JWT>" header.
//
// It returns the token's "email" claim if that is a non-empty string. Failing
// that, it returns the "sub" claim if that is a non-empty string. Otherwise it
// returns "". This covers a missing header, a non-Bearer scheme, an
// unparseable token, and a token with no usable claim.
//
// SECURITY NOTE(review): the token's signature is not effectively verified.
// The key function returns an empty HMAC secret, and only a nil token is
// rejected; the parse/validation error is otherwise ignored. jwt-go still
// returns the token, with claims filled in, when signature or expiry checks
// fail. As a result, forged and expired tokens are accepted and their claims
// trusted. This is only acceptable if a trusted reverse proxy validates tokens
// before they reach this server. Confirm that, or configure a real key and
// require err == nil && token.Valid.
func userFromJwtToken(r *http.Request) string {
	const bearerPrefix = "Bearer "

	authHeader := r.Header.Get("authorization")
	if authHeader == "" {
		log.Debug("No Authorization header")
		return ""
	}
	// Auth schemes are case-insensitive (RFC 7235 §2.1), so "bearer " is accepted too.
	if len(authHeader) < len(bearerPrefix) || !strings.EqualFold(authHeader[:len(bearerPrefix)], bearerPrefix) {
		// Do not log the header value: it may carry credentials (e.g. Basic auth).
		log.Debug("Incorrect Authorization header: not a Bearer token")
		return ""
	}

	claims := jwt.MapClaims{}
	token, err := jwt.ParseWithClaims(authHeader[len(bearerPrefix):], &claims, func(token *jwt.Token) (interface{}, error) {
		return []byte(""), nil
	})
	// Only unparseable input yields a nil token; see the security note above.
	if token == nil {
		log.Debug("Error parsing JWT token: ", err)
		return ""
	}

	// Prefer "email", then fall back to "sub". Use checked type assertions: a
	// crafted token with a non-string claim would otherwise panic the handler.
	for _, key := range []string{"email", "sub"} {
		if v, ok := claims[key].(string); ok && v != "" {
			return v
		}
	}
	return ""
}
// withAuth wraps handler so it always runs with the caller's identity stored
// in the request context under the string key "user".
//
// The identity comes from userFromJwtToken. When that yields "", the request
// proceeds as "anonymous". This middleware never rejects a request itself;
// authorization is left to the wrapped handler.
func (s *Server) withAuth(handler httprouter.Handle) httprouter.Handle {
	wrapped := func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
		log.Debug("Auth required")
		identity := userFromJwtToken(r)
		if len(identity) == 0 {
			identity = "anonymous"
		}
		authed := r.WithContext(context.WithValue(r.Context(), "user", identity))
		handler(w, authed, ps)
	}
	return wrapped
}
// Index serves GET / with the static plain-text body "Hello World".
func (s *Server) Index(rw http.ResponseWriter, req *http.Request, _ httprouter.Params) {
	const greeting = "Hello World"
	log.Debug("Index")
	rw.Write([]byte(greeting))
}
// GetDevices handles GET /api/v1/users/:user/devices (registered behind
// withAuth in Start).
//
// Callers may only read their own path. When the identity stored under the
// "user" context key differs from the :user parameter, it responds
// 401 Unauthorized. Otherwise it responds with the server configuration as
// JSON, with the WireGuard private key blanked out.
//
// NOTE(review): despite the route name, this returns the server config rather
// than a per-user device list. Confirm the intended payload.
func (s *Server) GetDevices(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
	user, ok := r.Context().Value("user").(string)
	if !ok || user != ps.ByName("user") {
		log.WithField("user", user).WithField("path", r.URL.Path).Warn("Unauthorized access")
		w.WriteHeader(http.StatusUnauthorized)
		return
	}

	// The server's private key must never be sent to a client, so serve a
	// shallow copy with the key cleared.
	// NOTE(review): this copy assumes ServerConfig has no mutex or other no-copy
	// field; switch to a dedicated response type if that changes.
	cfg := *s.config
	cfg.PrivateKey = ""

	// Marshal before writing, so an encoding failure still yields a clean 500
	// instead of a partial body followed by a superfluous WriteHeader.
	body, err := json.Marshal(&cfg)
	if err != nil {
		log.Error(err)
		w.WriteHeader(http.StatusInternalServerError)
		return
	}
	w.Header().Set("Content-Type", "application/json")
	if _, err := w.Write(body); err != nil {
		log.Error(err)
	}
}