diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index a5c29e7abdc..81b70940174 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -14,7 +14,6 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" - "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/wgproxy" "github.com/netbirdio/netbird/client/internal/peer/guard" @@ -36,10 +35,17 @@ const ( connPriorityICEP2P ConnPriority = 2 ) +type WgInterface interface { + UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error + RemovePeer(publicKey string) error + GetProxy() wgproxy.Proxy + GetStats(peerKey string) (configurer.WGStats, error) +} + type WgConfig struct { WgListenPort int RemoteKey string - WgInterface iface.IWGIface + WgInterface WgInterface AllowedIps string PreSharedKey *wgtypes.Key } @@ -107,6 +113,8 @@ type Conn struct { guard *guard.Guard semaphore *semaphoregroup.SemaphoreGroup + + endpointUpdater *endpointUpdater } // NewConn creates a new not opened Conn to the remote peer. @@ -133,6 +141,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu statusRelay: NewAtomicConnStatus(), statusICE: NewAtomicConnStatus(), semaphore: semaphore, + endpointUpdater: &endpointUpdater{ + log: connLog, + wgConfig: config.WgConfig, + initiator: isWireGuardInitiator(config), + }, } rFns := WorkerRelayCallbacks{ @@ -240,7 +253,7 @@ func (conn *Conn) Close() { conn.wgProxyICE = nil } - if err := conn.removeWgPeer(); err != nil { + if err := conn.endpointUpdater.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } @@ -364,7 +377,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon wgProxy.Work() } - if err = conn.configureWGEndpoint(ep); err != nil { + if err = conn.endpointUpdater.configureWGEndpoint(ep); err != nil { conn.handleConfigurationFailure(err, wgProxy) return } @@ -397,7 +410,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) { conn.log.Debugf("ICE disconnected, set Relay to active connection") conn.wgProxyRelay.Work() - if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { + if err := conn.endpointUpdater.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil { conn.log.Errorf("failed to switch to relay conn: %v", err) } conn.workerRelay.EnableWgWatcher(conn.ctx) @@ -456,7 +469,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) { } wgProxy.Work() - if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil { + if err := conn.endpointUpdater.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil { if err := wgProxy.CloseConn(); err != nil { conn.log.Warnf("Failed to close relay connection: %v", err) } @@ -486,7 +499,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() { if conn.currentConnPriority == connPriorityRelay { conn.log.Debugf("clean up WireGuard config") - if err := conn.removeWgPeer(); err != nil { + if err := conn.endpointUpdater.removeWgPeer(); err != nil { conn.log.Errorf("failed to remove wg endpoint: %v", err) } } @@ -525,23 +538,6 @@ func (conn *Conn) listenGuardEvent(ctx context.Context) { } } -func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error { - var endpoint *net.UDPAddr - - // Force to only one side send handshake request to avoid the handshake congestion in WireGuard connection. - // Configure up the WireGuard endpoint only on the initiator side. - if isWireGuardInitiator(conn.config) { - endpoint = addr - } - return conn.config.WgConfig.WgInterface.UpdatePeer( - conn.config.WgConfig.RemoteKey, - conn.config.WgConfig.AllowedIps, - defaultWgKeepAlive, - endpoint, - conn.config.WgConfig.PreSharedKey, - ) -} - func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) { peerState := State{ PubKey: conn.config.Key, @@ -721,10 +717,6 @@ func (conn *Conn) iceP2PIsActive() bool { return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected } -func (conn *Conn) removeWgPeer() error { - return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey) -} - func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) { conn.log.Warnf("Failed to update wg peer configuration: %v", err) if wgProxy != nil { diff --git a/client/internal/peer/endpoint.go b/client/internal/peer/endpoint.go new file mode 100644 index 00000000000..c2392bf59c9 --- /dev/null +++ b/client/internal/peer/endpoint.go @@ -0,0 +1,86 @@ +package peer + +import ( + "context" + "net" + "sync" + "time" + + "github.com/sirupsen/logrus" +) + +// fallbackDelay could be const but because of testing it is a var +var fallbackDelay = 5 * time.Second + +type endpointUpdater struct { + log *logrus.Entry + wgConfig WgConfig + initiator bool + + cancelFunc func() + configUpdateMutex sync.Mutex +} + +// configureWGEndpoint sets up the WireGuard endpoint configuration. +// The initiator immediately configures the endpoint, while the non-initiator +// waits for a fallback period before configuring to avoid handshake congestion. +func (e *endpointUpdater) configureWGEndpoint(addr *net.UDPAddr) error { + if e.initiator { + return e.updateWireGuardPeer(addr) + } + + // prevent to run new update while cancel the previous update + e.configUpdateMutex.Lock() + if e.cancelFunc != nil { + e.cancelFunc() + } + e.configUpdateMutex.Unlock() + + var ctx context.Context + ctx, e.cancelFunc = context.WithCancel(context.Background()) + go e.scheduleDelayedUpdate(ctx, addr) + + return e.updateWireGuardPeer(nil) +} + +func (e *endpointUpdater) removeWgPeer() error { + e.configUpdateMutex.Lock() + defer e.configUpdateMutex.Unlock() + + if e.cancelFunc != nil { + e.cancelFunc() + } + + return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey) +} + +// scheduleDelayedUpdate waits for the fallback period before updating the endpoint +func (e *endpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr) { + t := time.NewTimer(fallbackDelay) + select { + case <-ctx.Done(): + t.Stop() + return + case <-t.C: + e.configUpdateMutex.Lock() + defer e.configUpdateMutex.Unlock() + + if ctx.Err() != nil { + return + } + + if err := e.updateWireGuardPeer(addr); err != nil { + e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err) + } + } +} + +func (e *endpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr) error { + return e.wgConfig.WgInterface.UpdatePeer( + e.wgConfig.RemoteKey, + e.wgConfig.AllowedIps, + defaultWgKeepAlive, + endpoint, + e.wgConfig.PreSharedKey, + ) +} diff --git a/client/internal/peer/endpoint_test.go b/client/internal/peer/endpoint_test.go new file mode 100644 index 00000000000..ec980b7d7c8 --- /dev/null +++ b/client/internal/peer/endpoint_test.go @@ -0,0 +1,178 @@ +package peer + +import ( + "net" + "testing" + "time" + + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/mock" + "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + + "github.com/netbirdio/netbird/client/iface/configurer" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +type MockWgInterface struct { + mock.Mock + + lastSetAddr *net.UDPAddr +} + +func (m *MockWgInterface) GetStats(peerKey string) (configurer.WGStats, error) { + panic("implement me") +} + +func (m *MockWgInterface) GetProxy() wgproxy.Proxy { + panic("implement me") +} + +func (m *MockWgInterface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { + args := m.Called(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) + m.lastSetAddr = endpoint + return args.Error(0) +} + +func (m *MockWgInterface) RemovePeer(publicKey string) error { + args := m.Called(publicKey) + return args.Error(0) +} + +func Test_endpointUpdater_initiator(t *testing.T) { + mockWgInterface := &MockWgInterface{} + e := &endpointUpdater{ + log: log.WithField("peer", "my-peer-key"), + wgConfig: WgConfig{ + WgListenPort: 51820, + RemoteKey: "secret-remote-key", + WgInterface: mockWgInterface, + AllowedIps: "172.16.254.1", + }, + initiator: true, + } + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 1234, + } + + mockWgInterface.On( + "UpdatePeer", + e.wgConfig.RemoteKey, + e.wgConfig.AllowedIps, + defaultWgKeepAlive, + addr, + (*wgtypes.Key)(nil), + ).Return(nil) + + if err := e.configureWGEndpoint(addr); err != nil { + t.Fatalf("updateWireGuardPeer() failed: %v", err) + } + + mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr, (*wgtypes.Key)(nil)) +} + +func Test_endpointUpdater_nonInitiator(t *testing.T) { + fallbackDelay = 1 * time.Second + mockWgInterface := &MockWgInterface{} + e := &endpointUpdater{ + log: log.WithField("peer", "my-peer-key"), + wgConfig: WgConfig{ + WgListenPort: 51820, + RemoteKey: "secret-remote-key", + WgInterface: mockWgInterface, + AllowedIps: "172.16.254.1", + }, + initiator: false, + } + addr := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 1234, + } + + mockWgInterface.On( + "UpdatePeer", + e.wgConfig.RemoteKey, + e.wgConfig.AllowedIps, + defaultWgKeepAlive, + (*net.UDPAddr)(nil), + (*wgtypes.Key)(nil), + ).Return(nil) + + mockWgInterface.On( + "UpdatePeer", + e.wgConfig.RemoteKey, + e.wgConfig.AllowedIps, + defaultWgKeepAlive, + addr, + (*wgtypes.Key)(nil), + ).Return(nil) + + err := e.configureWGEndpoint(addr) + if err != nil { + t.Fatalf("updateWireGuardPeer() failed: %v", err) + } + mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, (*net.UDPAddr)(nil), (*wgtypes.Key)(nil)) + + time.Sleep(fallbackDelay + time.Second) + + mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr, (*wgtypes.Key)(nil)) +} + +func Test_endpointUpdater_overRule(t *testing.T) { + fallbackDelay = 1 * time.Second + mockWgInterface := &MockWgInterface{} + e := &endpointUpdater{ + log: log.WithField("peer", "my-peer-key"), + wgConfig: WgConfig{ + WgListenPort: 51820, + RemoteKey: "secret-remote-key", + WgInterface: mockWgInterface, + AllowedIps: "172.16.254.1", + }, + initiator: false, + } + addr1 := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 1000, + } + + addr2 := &net.UDPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: 1001, + } + + mockWgInterface.On( + "UpdatePeer", + e.wgConfig.RemoteKey, + e.wgConfig.AllowedIps, + defaultWgKeepAlive, + (*net.UDPAddr)(nil), + (*wgtypes.Key)(nil), + ).Return(nil) + + mockWgInterface.On( + "UpdatePeer", + e.wgConfig.RemoteKey, + e.wgConfig.AllowedIps, + defaultWgKeepAlive, + addr2, + (*wgtypes.Key)(nil), + ).Return(nil) + + if err := e.configureWGEndpoint(addr1); err != nil { + t.Fatalf("updateWireGuardPeer() failed: %v", err) + } + mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, (*net.UDPAddr)(nil), (*wgtypes.Key)(nil)) + + if err := e.configureWGEndpoint(addr2); err != nil { + t.Fatalf("updateWireGuardPeer() failed: %v", err) + } + + time.Sleep(fallbackDelay + time.Second) + + mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr2, (*wgtypes.Key)(nil)) + + if mockWgInterface.lastSetAddr != addr2 { + t.Fatalf("lastSetAddr is not equal to addr2") + } +}