Skip to content

Commit

Permalink
Improve routing decision logic
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal committed Jan 9, 2025
1 parent 6335ef8 commit 706f98c
Showing 1 changed file with 61 additions and 53 deletions.
114 changes: 61 additions & 53 deletions client/firewall/uspfilter/uspfilter.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package uspfilter

import (
"errors"
"fmt"
"net"
"net/netip"
Expand Down Expand Up @@ -31,11 +32,6 @@ const (
// EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped.
EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING"

// EnvForceNativeRouter forces forwarding to the native stack (even if doesn't support routing).
// This is useful when routing/firewall setup is done manually instead of by netbird.
// This setting always disables userspace routing and filtering of routed traffic.
EnvForceNativeRouter = "NB_FORCE_NATIVE_ROUTER"

// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
)
Expand Down Expand Up @@ -88,49 +84,23 @@ type decoder struct {

// Create userspace firewall manager constructor
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
return create(iface, disableServerRoutes)
return create(iface, nil, disableServerRoutes)
}

func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
mgr, err := create(iface, disableServerRoutes)
if err != nil {
return nil, err
if nativeFirewall == nil {
return nil, errors.New("native firewall is nil")
}

mgr.nativeFirewall = nativeFirewall

if disableServerRoutes {
// skip native vs userspace router logic altogether
return mgr, nil
}

if forceUserspaceRouter, _ := strconv.ParseBool(os.Getenv(EnvForceUserspaceRouter)); forceUserspaceRouter {
log.Info("userspace routing is forced")
return mgr, nil
}

forceNativeRouter, _ := strconv.ParseBool(EnvForceNativeRouter)

// if the OS supports routing natively, or it is explicitly requested, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface
if forceNativeRouter ||
!netstack.IsEnabled() && mgr.nativeFirewall != nil && mgr.nativeFirewall.IsServerRouteSupported() {

mgr.nativeRouter = true
mgr.routingEnabled = true
if mgr.forwarder != nil {
mgr.forwarder.Stop()
}

log.Info("native routing is enabled")
return mgr, nil
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
if err != nil {
return nil, err
}

log.Info("userspace routing is enabled")
return mgr, nil
}

func create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack))

m := &Manager{
Expand All @@ -147,16 +117,16 @@ func create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error
return d
},
},
nativeFirewall: nativeFirewall,
outgoingRules: make(map[string]RuleSet),
incomingRules: make(map[string]RuleSet),
routeRules: make(map[string]RouteRule),
wgIface: iface,
localipmanager: newLocalIPManager(),
routingEnabled: false,
stateful: !disableConntrack,
// TODO: support changing log level from logrus
logger: nblog.NewFromLogrus(log.StandardLogger()),
netstack: netstack.IsEnabled(),
logger: nblog.NewFromLogrus(log.StandardLogger()),
netstack: netstack.IsEnabled(),
}

if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
Expand All @@ -172,24 +142,62 @@ func create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
}

m.determineRouting(iface, disableServerRoutes)

if err := iface.SetFilter(m); err != nil {
return nil, fmt.Errorf("set filter: %w", err)
}
return m, nil
}

func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes bool) {
disableUspRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting))
if disableUspRouting || disableServerRoutes {
forceUserspaceRouter, _ := strconv.ParseBool(os.Getenv(EnvForceUserspaceRouter))

switch {
case disableUspRouting:
m.routingEnabled = false
m.nativeRouter = false
log.Info("userspace routing is disabled")
} else {

case disableServerRoutes:
// if server routes are disabled we will let packets pass to the native stack
m.routingEnabled = true
m.nativeRouter = true

log.Info("server routes are disabled")

case forceUserspaceRouter:
m.routingEnabled = true
m.nativeRouter = false

log.Info("userspace routing is forced")

case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
// if the OS supports routing natively, then we don't need to filter/route ourselves
// netstack mode won't support native routing as there is no interface

m.routingEnabled = true
m.nativeRouter = true

log.Info("native routing is enabled")

default:
m.routingEnabled = true
m.nativeRouter = false

log.Info("userspace routing enabled by default")
}

// netstack needs the forwarder for local traffic
if m.netstack || m.routingEnabled {
m.initForwarder(iface)
}
if m.netstack ||
m.routingEnabled && !m.nativeRouter {

if err := iface.SetFilter(m); err != nil {
return nil, fmt.Errorf("set filter: %w", err)
m.initForwarder(iface)
}
return m, nil
}

// initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder(iface common.IFaceMapper) {
// Only supported in userspace mode as we need to inject packets back into wireguard directly
intf := iface.GetWGDevice()
Expand Down Expand Up @@ -218,7 +226,7 @@ func (m *Manager) IsServerRouteSupported() bool {
}

func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
if m.nativeRouter {
if m.nativeRouter && m.nativeFirewall != nil {
return m.nativeFirewall.AddNatRule(pair)
}

Expand All @@ -229,7 +237,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {

// RemoveNatRule removes a routing firewall rule
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
if m.nativeRouter {
if m.nativeRouter && m.nativeFirewall != nil {
return m.nativeFirewall.RemoveNatRule(pair)
}
return nil
Expand Down Expand Up @@ -313,7 +321,7 @@ func (m *Manager) AddRouteFiltering(
dPort *firewall.Port,
action firewall.Action,
) (firewall.Rule, error) {
if m.nativeRouter {
if m.nativeRouter && m.nativeFirewall != nil {
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
}

Expand All @@ -337,7 +345,7 @@ func (m *Manager) AddRouteFiltering(
}

func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
if m.nativeRouter {
if m.nativeRouter && m.nativeFirewall != nil {
return m.nativeFirewall.DeleteRouteRule(rule)
}

Expand Down

0 comments on commit 706f98c

Please sign in to comment.