Skip to content

Commit

Permalink
The non handshake initiator peer start the handshake after timeout
Browse files Browse the repository at this point in the history
  • Loading branch information
pappz committed Jan 20, 2025
1 parent 795dca5 commit a8c4556
Show file tree
Hide file tree
Showing 3 changed files with 284 additions and 28 deletions.
48 changes: 20 additions & 28 deletions client/internal/peer/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy"
"github.com/netbirdio/netbird/client/internal/peer/guard"
Expand All @@ -36,10 +35,17 @@ const (
connPriorityICEP2P ConnPriority = 2
)

type WgInterface interface {
UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
RemovePeer(publicKey string) error
GetProxy() wgproxy.Proxy
GetStats(peerKey string) (configurer.WGStats, error)
}

type WgConfig struct {
WgListenPort int
RemoteKey string
WgInterface iface.IWGIface
WgInterface WgInterface
AllowedIps string
PreSharedKey *wgtypes.Key
}
Expand Down Expand Up @@ -107,6 +113,8 @@ type Conn struct {

guard *guard.Guard
semaphore *semaphoregroup.SemaphoreGroup

endpointUpdater *endpointUpdater
}

// NewConn creates a new not opened Conn to the remote peer.
Expand All @@ -133,6 +141,11 @@ func NewConn(engineCtx context.Context, config ConnConfig, statusRecorder *Statu
statusRelay: NewAtomicConnStatus(),
statusICE: NewAtomicConnStatus(),
semaphore: semaphore,
endpointUpdater: &endpointUpdater{
log: connLog,
wgConfig: config.WgConfig,
initiator: isWireGuardInitiator(config),
},
}

rFns := WorkerRelayCallbacks{
Expand Down Expand Up @@ -240,7 +253,7 @@ func (conn *Conn) Close() {
conn.wgProxyICE = nil
}

if err := conn.removeWgPeer(); err != nil {
if err := conn.endpointUpdater.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
}

Expand Down Expand Up @@ -364,7 +377,7 @@ func (conn *Conn) iCEConnectionIsReady(priority ConnPriority, iceConnInfo ICECon
wgProxy.Work()
}

if err = conn.configureWGEndpoint(ep); err != nil {
if err = conn.endpointUpdater.configureWGEndpoint(ep); err != nil {
conn.handleConfigurationFailure(err, wgProxy)
return
}
Expand Down Expand Up @@ -397,7 +410,7 @@ func (conn *Conn) onWorkerICEStateDisconnected(newState ConnStatus) {
conn.log.Debugf("ICE disconnected, set Relay to active connection")
conn.wgProxyRelay.Work()

if err := conn.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
if err := conn.endpointUpdater.configureWGEndpoint(conn.wgProxyRelay.EndpointAddr()); err != nil {
conn.log.Errorf("failed to switch to relay conn: %v", err)
}
conn.workerRelay.EnableWgWatcher(conn.ctx)
Expand Down Expand Up @@ -456,7 +469,7 @@ func (conn *Conn) relayConnectionIsReady(rci RelayConnInfo) {
}

wgProxy.Work()
if err := conn.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
if err := conn.endpointUpdater.configureWGEndpoint(wgProxy.EndpointAddr()); err != nil {
if err := wgProxy.CloseConn(); err != nil {
conn.log.Warnf("Failed to close relay connection: %v", err)
}
Expand Down Expand Up @@ -486,7 +499,7 @@ func (conn *Conn) onWorkerRelayStateDisconnected() {

if conn.currentConnPriority == connPriorityRelay {
conn.log.Debugf("clean up WireGuard config")
if err := conn.removeWgPeer(); err != nil {
if err := conn.endpointUpdater.removeWgPeer(); err != nil {
conn.log.Errorf("failed to remove wg endpoint: %v", err)
}
}
Expand Down Expand Up @@ -525,23 +538,6 @@ func (conn *Conn) listenGuardEvent(ctx context.Context) {
}
}

func (conn *Conn) configureWGEndpoint(addr *net.UDPAddr) error {
var endpoint *net.UDPAddr

// Force to only one side send handshake request to avoid the handshake congestion in WireGuard connection.
// Configure up the WireGuard endpoint only on the initiator side.
if isWireGuardInitiator(conn.config) {
endpoint = addr
}
return conn.config.WgConfig.WgInterface.UpdatePeer(
conn.config.WgConfig.RemoteKey,
conn.config.WgConfig.AllowedIps,
defaultWgKeepAlive,
endpoint,
conn.config.WgConfig.PreSharedKey,
)
}

func (conn *Conn) updateRelayStatus(relayServerAddr string, rosenpassPubKey []byte) {
peerState := State{
PubKey: conn.config.Key,
Expand Down Expand Up @@ -721,10 +717,6 @@ func (conn *Conn) iceP2PIsActive() bool {
return conn.currentConnPriority == connPriorityICEP2P && conn.statusICE.Get() == StatusConnected
}

func (conn *Conn) removeWgPeer() error {
return conn.config.WgConfig.WgInterface.RemovePeer(conn.config.WgConfig.RemoteKey)
}

func (conn *Conn) handleConfigurationFailure(err error, wgProxy wgproxy.Proxy) {
conn.log.Warnf("Failed to update wg peer configuration: %v", err)
if wgProxy != nil {
Expand Down
86 changes: 86 additions & 0 deletions client/internal/peer/endpoint.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package peer

import (
"context"
"net"
"sync"
"time"

"github.com/sirupsen/logrus"
)

// fallbackDelay could be const but because of testing it is a var
var fallbackDelay = 5 * time.Second

type endpointUpdater struct {
log *logrus.Entry
wgConfig WgConfig
initiator bool

cancelFunc func()
configUpdateMutex sync.Mutex
}

// configureWGEndpoint sets up the WireGuard endpoint configuration.
// The initiator immediately configures the endpoint, while the non-initiator
// waits for a fallback period before configuring to avoid handshake congestion.
func (e *endpointUpdater) configureWGEndpoint(addr *net.UDPAddr) error {
if e.initiator {
return e.updateWireGuardPeer(addr)
}

// prevent to run new update while cancel the previous update
e.configUpdateMutex.Lock()
if e.cancelFunc != nil {
e.cancelFunc()
}
e.configUpdateMutex.Unlock()

var ctx context.Context
ctx, e.cancelFunc = context.WithCancel(context.Background())
go e.scheduleDelayedUpdate(ctx, addr)

return e.updateWireGuardPeer(nil)
}

func (e *endpointUpdater) removeWgPeer() error {
e.configUpdateMutex.Lock()
defer e.configUpdateMutex.Unlock()

if e.cancelFunc != nil {
e.cancelFunc()
}

return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey)
}

// scheduleDelayedUpdate waits for the fallback period before updating the endpoint
func (e *endpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr) {
t := time.NewTimer(fallbackDelay)
select {
case <-ctx.Done():
t.Stop()
return
case <-t.C:
e.configUpdateMutex.Lock()
defer e.configUpdateMutex.Unlock()

if ctx.Err() != nil {
return
}

if err := e.updateWireGuardPeer(addr); err != nil {
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err)
}
}
}

func (e *endpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr) error {
return e.wgConfig.WgInterface.UpdatePeer(
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
endpoint,
e.wgConfig.PreSharedKey,
)
}
178 changes: 178 additions & 0 deletions client/internal/peer/endpoint_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
package peer

import (
"net"
"testing"
"time"

log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/mock"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"

"github.com/netbirdio/netbird/client/iface/configurer"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)

type MockWgInterface struct {
mock.Mock

lastSetAddr *net.UDPAddr
}

func (m *MockWgInterface) GetStats(peerKey string) (configurer.WGStats, error) {
panic("implement me")
}

func (m *MockWgInterface) GetProxy() wgproxy.Proxy {
panic("implement me")
}

func (m *MockWgInterface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
args := m.Called(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
m.lastSetAddr = endpoint
return args.Error(0)
}

func (m *MockWgInterface) RemovePeer(publicKey string) error {
args := m.Called(publicKey)
return args.Error(0)
}

func Test_endpointUpdater_initiator(t *testing.T) {
mockWgInterface := &MockWgInterface{}
e := &endpointUpdater{
log: log.WithField("peer", "my-peer-key"),
wgConfig: WgConfig{
WgListenPort: 51820,
RemoteKey: "secret-remote-key",
WgInterface: mockWgInterface,
AllowedIps: "172.16.254.1",
},
initiator: true,
}
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1234,
}

mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
addr,
(*wgtypes.Key)(nil),
).Return(nil)

if err := e.configureWGEndpoint(addr); err != nil {
t.Fatalf("updateWireGuardPeer() failed: %v", err)
}

mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr, (*wgtypes.Key)(nil))
}

func Test_endpointUpdater_nonInitiator(t *testing.T) {
fallbackDelay = 1 * time.Second
mockWgInterface := &MockWgInterface{}
e := &endpointUpdater{
log: log.WithField("peer", "my-peer-key"),
wgConfig: WgConfig{
WgListenPort: 51820,
RemoteKey: "secret-remote-key",
WgInterface: mockWgInterface,
AllowedIps: "172.16.254.1",
},
initiator: false,
}
addr := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1234,
}

mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
(*net.UDPAddr)(nil),
(*wgtypes.Key)(nil),
).Return(nil)

mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
addr,
(*wgtypes.Key)(nil),
).Return(nil)

err := e.configureWGEndpoint(addr)
if err != nil {
t.Fatalf("updateWireGuardPeer() failed: %v", err)
}
mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, (*net.UDPAddr)(nil), (*wgtypes.Key)(nil))

time.Sleep(fallbackDelay + time.Second)

mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr, (*wgtypes.Key)(nil))
}

func Test_endpointUpdater_overRule(t *testing.T) {
fallbackDelay = 1 * time.Second
mockWgInterface := &MockWgInterface{}
e := &endpointUpdater{
log: log.WithField("peer", "my-peer-key"),
wgConfig: WgConfig{
WgListenPort: 51820,
RemoteKey: "secret-remote-key",
WgInterface: mockWgInterface,
AllowedIps: "172.16.254.1",
},
initiator: false,
}
addr1 := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1000,
}

addr2 := &net.UDPAddr{
IP: net.ParseIP("127.0.0.1"),
Port: 1001,
}

mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
(*net.UDPAddr)(nil),
(*wgtypes.Key)(nil),
).Return(nil)

mockWgInterface.On(
"UpdatePeer",
e.wgConfig.RemoteKey,
e.wgConfig.AllowedIps,
defaultWgKeepAlive,
addr2,
(*wgtypes.Key)(nil),
).Return(nil)

if err := e.configureWGEndpoint(addr1); err != nil {
t.Fatalf("updateWireGuardPeer() failed: %v", err)
}
mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, (*net.UDPAddr)(nil), (*wgtypes.Key)(nil))

if err := e.configureWGEndpoint(addr2); err != nil {
t.Fatalf("updateWireGuardPeer() failed: %v", err)
}

time.Sleep(fallbackDelay + time.Second)

mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr2, (*wgtypes.Key)(nil))

if mockWgInterface.lastSetAddr != addr2 {
t.Fatalf("lastSetAddr is not equal to addr2")
}
}

0 comments on commit a8c4556

Please sign in to comment.