Add wireguard peer configuration

This commit is contained in:
Daniel Lundin
2019-05-11 20:03:22 +02:00
parent 6930e9418a
commit 679a101d4f
+54 -5
View File
@@ -20,6 +20,7 @@ import (
"github.com/vishvananda/netlink" "github.com/vishvananda/netlink"
"github.com/vishvananda/netns" "github.com/vishvananda/netns"
"gopkg.in/alecthomas/kingpin.v2" "gopkg.in/alecthomas/kingpin.v2"
"net"
) )
var ( var (
@@ -32,12 +33,14 @@ var (
wgLinkAddr = kingpin.Flag("wg-link-addr", "Wireguard interface address").Default("172.72.72.1/32").String() wgLinkAddr = kingpin.Flag("wg-link-addr", "Wireguard interface address").Default("172.72.72.1/32").String()
wgListenPort = kingpin.Flag("wg-listen-port", "Wireguard UDP port to listen to").Default("51820").Int() wgListenPort = kingpin.Flag("wg-listen-port", "Wireguard UDP port to listen to").Default("51820").Int()
wgEndpoint = kingpin.Flag("wg-endpoint", "Wireguard endpoint address").Default("127.0.0.1").String() wgEndpoint = kingpin.Flag("wg-endpoint", "Wireguard endpoint address").Default("127.0.0.1").String()
wgAllowedIPs = kingpin.Flag("wg-client-allowed-ips", "Wireguard allowed ips for ").Default("0.0.0.0/0").Strings()
) )
type Server struct { type Server struct {
serverConfigPath string serverConfigPath string
mutex sync.RWMutex mutex sync.RWMutex
Config *ServerConfig Config *ServerConfig
allowedIPs []net.IPNet
} }
type WgLink struct { type WgLink struct {
@@ -62,6 +65,15 @@ func ifname(n string) []byte {
} }
func NewServer() *Server { func NewServer() *Server {
allowedIPs := make([]net.IPNet, 0)
for _, ip := range *wgAllowedIPs {
_, ipnet, err := net.ParseCIDR(ip)
if err != nil {
log.Fatal(err)
}
allowedIPs = append(allowedIPs, *ipnet)
}
err := os.MkdirAll(*dataDir, 0700) err := os.MkdirAll(*dataDir, 0700)
if err != nil { if err != nil {
log.WithError(err).Fatalf("Error initializing data directory: %s", *dataDir) log.WithError(err).Fatalf("Error initializing data directory: %s", *dataDir)
@@ -75,6 +87,7 @@ func NewServer() *Server {
s := Server{ s := Server{
serverConfigPath: cfgPath, serverConfigPath: cfgPath,
Config: config, Config: config,
allowedIPs: allowedIPs,
} }
log.Debug("Server initialized: ", *dataDir) log.Debug("Server initialized: ", *dataDir)
@@ -155,8 +168,11 @@ func (s *Server) initInterface() error {
}, },
}) })
conn.Flush() return conn.Flush()
}
func (s *Server) configureWireguard() error {
log.Debugf("Reconfiguring wireguard interface %s", *wgLinkName)
wg, err := wireguardctrl.New() wg, err := wireguardctrl.New()
if err != nil { if err != nil {
return err return err
@@ -168,9 +184,31 @@ func (s *Server) initInterface() error {
return err return err
} }
peers := make([]wgtypes.PeerConfig, 0)
for user, cfg := range s.Config.Users {
for id, dev := range cfg.Devices {
pubKey, err := wgtypes.ParseKey(dev.PublicKey)
if err != nil {
return err
}
peer := wgtypes.PeerConfig{
PublicKey: pubKey,
ReplaceAllowedIPs: true,
AllowedIPs: s.allowedIPs,
}
log.WithFields(log.Fields{"user": user, "device": id, "key": dev.PublicKey}).Debug("Adding wireguard peer")
peers = append(peers, peer)
}
}
cfg := wgtypes.Config{ cfg := wgtypes.Config{
PrivateKey: &key, PrivateKey: &key,
ListenPort: wgListenPort, ListenPort: wgListenPort,
ReplacePeers: true,
Peers: peers,
} }
wg.ConfigureDevice(*wgLinkName, cfg) wg.ConfigureDevice(*wgLinkName, cfg)
@@ -183,6 +221,11 @@ func (s *Server) Start() error {
return err return err
} }
err = s.configureWireguard()
if err != nil {
return err
}
router := httprouter.New() router := httprouter.New()
router.GET("/", s.Index) router.GET("/", s.Index)
router.GET("/api/v1/users/:user/devices", s.withAuth(s.GetDevices)) router.GET("/api/v1/users/:user/devices", s.withAuth(s.GetDevices))
@@ -286,8 +329,8 @@ func (s *Server) CreateDevice(w http.ResponseWriter, r *http.Request, ps httprou
} }
i = i + 1 i = i + 1
c.Devices[strconv.Itoa(i)] = NewDeviceConfig() device := NewDeviceConfig()
c.Devices[strconv.Itoa(i)] = device
err := s.Config.Write() err := s.Config.Write()
if err != nil { if err != nil {
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
@@ -299,4 +342,10 @@ func (s *Server) CreateDevice(w http.ResponseWriter, r *http.Request, ps httprou
log.Error(err) log.Error(err)
w.WriteHeader(http.StatusInternalServerError) w.WriteHeader(http.StatusInternalServerError)
} }
err = s.configureWireguard()
if err != nil {
log.Error(err)
w.WriteHeader(http.StatusInternalServerError)
}
} }