-
-
Notifications
You must be signed in to change notification settings - Fork 561
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
The non handshake initiator peer start the handshake after timeout
- Loading branch information
Showing
3 changed files
with
284 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
package peer | ||
|
||
import ( | ||
"context" | ||
"net" | ||
"sync" | ||
"time" | ||
|
||
"github.com/sirupsen/logrus" | ||
) | ||
|
||
// fallbackDelay could be const but because of testing it is a var | ||
var fallbackDelay = 5 * time.Second | ||
|
||
type endpointUpdater struct { | ||
log *logrus.Entry | ||
wgConfig WgConfig | ||
initiator bool | ||
|
||
cancelFunc func() | ||
configUpdateMutex sync.Mutex | ||
} | ||
|
||
// configureWGEndpoint sets up the WireGuard endpoint configuration. | ||
// The initiator immediately configures the endpoint, while the non-initiator | ||
// waits for a fallback period before configuring to avoid handshake congestion. | ||
func (e *endpointUpdater) configureWGEndpoint(addr *net.UDPAddr) error { | ||
if e.initiator { | ||
return e.updateWireGuardPeer(addr) | ||
} | ||
|
||
// prevent to run new update while cancel the previous update | ||
e.configUpdateMutex.Lock() | ||
if e.cancelFunc != nil { | ||
e.cancelFunc() | ||
} | ||
e.configUpdateMutex.Unlock() | ||
|
||
var ctx context.Context | ||
ctx, e.cancelFunc = context.WithCancel(context.Background()) | ||
go e.scheduleDelayedUpdate(ctx, addr) | ||
|
||
return e.updateWireGuardPeer(nil) | ||
} | ||
|
||
func (e *endpointUpdater) removeWgPeer() error { | ||
e.configUpdateMutex.Lock() | ||
defer e.configUpdateMutex.Unlock() | ||
|
||
if e.cancelFunc != nil { | ||
e.cancelFunc() | ||
} | ||
|
||
return e.wgConfig.WgInterface.RemovePeer(e.wgConfig.RemoteKey) | ||
} | ||
|
||
// scheduleDelayedUpdate waits for the fallback period before updating the endpoint | ||
func (e *endpointUpdater) scheduleDelayedUpdate(ctx context.Context, addr *net.UDPAddr) { | ||
t := time.NewTimer(fallbackDelay) | ||
select { | ||
case <-ctx.Done(): | ||
t.Stop() | ||
return | ||
case <-t.C: | ||
e.configUpdateMutex.Lock() | ||
defer e.configUpdateMutex.Unlock() | ||
|
||
if ctx.Err() != nil { | ||
return | ||
} | ||
|
||
if err := e.updateWireGuardPeer(addr); err != nil { | ||
e.log.Errorf("failed to update WireGuard peer, address: %s, error: %v", addr, err) | ||
} | ||
} | ||
} | ||
|
||
func (e *endpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr) error { | ||
return e.wgConfig.WgInterface.UpdatePeer( | ||
e.wgConfig.RemoteKey, | ||
e.wgConfig.AllowedIps, | ||
defaultWgKeepAlive, | ||
endpoint, | ||
e.wgConfig.PreSharedKey, | ||
) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
package peer | ||
|
||
import ( | ||
"net" | ||
"testing" | ||
"time" | ||
|
||
log "github.com/sirupsen/logrus" | ||
"github.com/stretchr/testify/mock" | ||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" | ||
|
||
"github.com/netbirdio/netbird/client/iface/configurer" | ||
"github.com/netbirdio/netbird/client/iface/wgproxy" | ||
) | ||
|
||
type MockWgInterface struct { | ||
mock.Mock | ||
|
||
lastSetAddr *net.UDPAddr | ||
} | ||
|
||
func (m *MockWgInterface) GetStats(peerKey string) (configurer.WGStats, error) { | ||
panic("implement me") | ||
} | ||
|
||
func (m *MockWgInterface) GetProxy() wgproxy.Proxy { | ||
panic("implement me") | ||
} | ||
|
||
func (m *MockWgInterface) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error { | ||
args := m.Called(peerKey, allowedIps, keepAlive, endpoint, preSharedKey) | ||
m.lastSetAddr = endpoint | ||
return args.Error(0) | ||
} | ||
|
||
func (m *MockWgInterface) RemovePeer(publicKey string) error { | ||
args := m.Called(publicKey) | ||
return args.Error(0) | ||
} | ||
|
||
func Test_endpointUpdater_initiator(t *testing.T) { | ||
mockWgInterface := &MockWgInterface{} | ||
e := &endpointUpdater{ | ||
log: log.WithField("peer", "my-peer-key"), | ||
wgConfig: WgConfig{ | ||
WgListenPort: 51820, | ||
RemoteKey: "secret-remote-key", | ||
WgInterface: mockWgInterface, | ||
AllowedIps: "172.16.254.1", | ||
}, | ||
initiator: true, | ||
} | ||
addr := &net.UDPAddr{ | ||
IP: net.ParseIP("127.0.0.1"), | ||
Port: 1234, | ||
} | ||
|
||
mockWgInterface.On( | ||
"UpdatePeer", | ||
e.wgConfig.RemoteKey, | ||
e.wgConfig.AllowedIps, | ||
defaultWgKeepAlive, | ||
addr, | ||
(*wgtypes.Key)(nil), | ||
).Return(nil) | ||
|
||
if err := e.configureWGEndpoint(addr); err != nil { | ||
t.Fatalf("updateWireGuardPeer() failed: %v", err) | ||
} | ||
|
||
mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr, (*wgtypes.Key)(nil)) | ||
} | ||
|
||
func Test_endpointUpdater_nonInitiator(t *testing.T) { | ||
fallbackDelay = 1 * time.Second | ||
mockWgInterface := &MockWgInterface{} | ||
e := &endpointUpdater{ | ||
log: log.WithField("peer", "my-peer-key"), | ||
wgConfig: WgConfig{ | ||
WgListenPort: 51820, | ||
RemoteKey: "secret-remote-key", | ||
WgInterface: mockWgInterface, | ||
AllowedIps: "172.16.254.1", | ||
}, | ||
initiator: false, | ||
} | ||
addr := &net.UDPAddr{ | ||
IP: net.ParseIP("127.0.0.1"), | ||
Port: 1234, | ||
} | ||
|
||
mockWgInterface.On( | ||
"UpdatePeer", | ||
e.wgConfig.RemoteKey, | ||
e.wgConfig.AllowedIps, | ||
defaultWgKeepAlive, | ||
(*net.UDPAddr)(nil), | ||
(*wgtypes.Key)(nil), | ||
).Return(nil) | ||
|
||
mockWgInterface.On( | ||
"UpdatePeer", | ||
e.wgConfig.RemoteKey, | ||
e.wgConfig.AllowedIps, | ||
defaultWgKeepAlive, | ||
addr, | ||
(*wgtypes.Key)(nil), | ||
).Return(nil) | ||
|
||
err := e.configureWGEndpoint(addr) | ||
if err != nil { | ||
t.Fatalf("updateWireGuardPeer() failed: %v", err) | ||
} | ||
mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, (*net.UDPAddr)(nil), (*wgtypes.Key)(nil)) | ||
|
||
time.Sleep(fallbackDelay + time.Second) | ||
|
||
mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr, (*wgtypes.Key)(nil)) | ||
} | ||
|
||
func Test_endpointUpdater_overRule(t *testing.T) { | ||
fallbackDelay = 1 * time.Second | ||
mockWgInterface := &MockWgInterface{} | ||
e := &endpointUpdater{ | ||
log: log.WithField("peer", "my-peer-key"), | ||
wgConfig: WgConfig{ | ||
WgListenPort: 51820, | ||
RemoteKey: "secret-remote-key", | ||
WgInterface: mockWgInterface, | ||
AllowedIps: "172.16.254.1", | ||
}, | ||
initiator: false, | ||
} | ||
addr1 := &net.UDPAddr{ | ||
IP: net.ParseIP("127.0.0.1"), | ||
Port: 1000, | ||
} | ||
|
||
addr2 := &net.UDPAddr{ | ||
IP: net.ParseIP("127.0.0.1"), | ||
Port: 1001, | ||
} | ||
|
||
mockWgInterface.On( | ||
"UpdatePeer", | ||
e.wgConfig.RemoteKey, | ||
e.wgConfig.AllowedIps, | ||
defaultWgKeepAlive, | ||
(*net.UDPAddr)(nil), | ||
(*wgtypes.Key)(nil), | ||
).Return(nil) | ||
|
||
mockWgInterface.On( | ||
"UpdatePeer", | ||
e.wgConfig.RemoteKey, | ||
e.wgConfig.AllowedIps, | ||
defaultWgKeepAlive, | ||
addr2, | ||
(*wgtypes.Key)(nil), | ||
).Return(nil) | ||
|
||
if err := e.configureWGEndpoint(addr1); err != nil { | ||
t.Fatalf("updateWireGuardPeer() failed: %v", err) | ||
} | ||
mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, (*net.UDPAddr)(nil), (*wgtypes.Key)(nil)) | ||
|
||
if err := e.configureWGEndpoint(addr2); err != nil { | ||
t.Fatalf("updateWireGuardPeer() failed: %v", err) | ||
} | ||
|
||
time.Sleep(fallbackDelay + time.Second) | ||
|
||
mockWgInterface.AssertCalled(t, "UpdatePeer", e.wgConfig.RemoteKey, e.wgConfig.AllowedIps, defaultWgKeepAlive, addr2, (*wgtypes.Key)(nil)) | ||
|
||
if mockWgInterface.lastSetAddr != addr2 { | ||
t.Fatalf("lastSetAddr is not equal to addr2") | ||
} | ||
} |