diff --git a/client/Dockerfile-rootless b/client/Dockerfile-rootless index 62bcaf964bd..78314ba121c 100644 --- a/client/Dockerfile-rootless +++ b/client/Dockerfile-rootless @@ -9,6 +9,7 @@ USER netbird:netbird ENV NB_FOREGROUND_MODE=true ENV NB_USE_NETSTACK_MODE=true +ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true ENV NB_CONFIG=config.json ENV NB_DAEMON_ADDR=unix://netbird.sock ENV NB_DISABLE_DNS=true diff --git a/client/cmd/trace.go b/client/cmd/trace.go new file mode 100644 index 00000000000..b2ff1f1b54e --- /dev/null +++ b/client/cmd/trace.go @@ -0,0 +1,137 @@ +package cmd + +import ( + "fmt" + "math/rand" + "strings" + + "github.com/spf13/cobra" + "google.golang.org/grpc/status" + + "github.com/netbirdio/netbird/client/proto" +) + +var traceCmd = &cobra.Command{ + Use: "trace ", + Short: "Trace a packet through the firewall", + Example: ` + netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack + netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53 + netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0 + netbird debug trace in 100.64.1.1 self -p tcp --dport 80`, + Args: cobra.ExactArgs(3), + RunE: tracePacket, +} + +func init() { + debugCmd.AddCommand(traceCmd) + + traceCmd.Flags().StringP("protocol", "p", "tcp", "Protocol (tcp/udp/icmp)") + traceCmd.Flags().Uint16("sport", 0, "Source port") + traceCmd.Flags().Uint16("dport", 0, "Destination port") + traceCmd.Flags().Uint8("icmp-type", 0, "ICMP type") + traceCmd.Flags().Uint8("icmp-code", 0, "ICMP code") + traceCmd.Flags().Bool("syn", false, "TCP SYN flag") + traceCmd.Flags().Bool("ack", false, "TCP ACK flag") + traceCmd.Flags().Bool("fin", false, "TCP FIN flag") + traceCmd.Flags().Bool("rst", false, "TCP RST flag") + traceCmd.Flags().Bool("psh", false, "TCP PSH flag") + traceCmd.Flags().Bool("urg", false, "TCP URG flag") +} + +func tracePacket(cmd *cobra.Command, args []string) error { + direction := strings.ToLower(args[0]) + if direction != "in" && direction != "out" { + return fmt.Errorf("invalid direction: use 'in' or 'out'") + } + + protocol := cmd.Flag("protocol").Value.String() + if protocol != "tcp" && protocol != "udp" && protocol != "icmp" { + return fmt.Errorf("invalid protocol: use tcp/udp/icmp") + } + + sport, err := cmd.Flags().GetUint16("sport") + if err != nil { + return fmt.Errorf("invalid source port: %v", err) + } + dport, err := cmd.Flags().GetUint16("dport") + if err != nil { + return fmt.Errorf("invalid destination port: %v", err) + } + + // For TCP/UDP, generate random ephemeral port (49152-65535) if not specified + if protocol != "icmp" { + if sport == 0 { + sport = uint16(rand.Intn(16383) + 49152) + } + if dport == 0 { + dport = uint16(rand.Intn(16383) + 49152) + } + } + + var tcpFlags *proto.TCPFlags + if protocol == "tcp" { + syn, _ := cmd.Flags().GetBool("syn") + ack, _ := cmd.Flags().GetBool("ack") + fin, _ := cmd.Flags().GetBool("fin") + rst, _ := cmd.Flags().GetBool("rst") + psh, _ := cmd.Flags().GetBool("psh") + urg, _ := cmd.Flags().GetBool("urg") + + tcpFlags = &proto.TCPFlags{ + Syn: syn, + Ack: ack, + Fin: fin, + Rst: rst, + Psh: psh, + Urg: urg, + } + } + + icmpType, _ := cmd.Flags().GetUint32("icmp-type") + icmpCode, _ := cmd.Flags().GetUint32("icmp-code") + + conn, err := getClient(cmd) + if err != nil { + return err + } + defer conn.Close() + + client := proto.NewDaemonServiceClient(conn) + resp, err := client.TracePacket(cmd.Context(), &proto.TracePacketRequest{ + SourceIp: args[1], + DestinationIp: args[2], + Protocol: protocol, + SourcePort: uint32(sport), + DestinationPort: uint32(dport), + Direction: direction, + TcpFlags: tcpFlags, + IcmpType: &icmpType, + IcmpCode: &icmpCode, + }) + if err != nil { + return fmt.Errorf("trace failed: %v", status.Convert(err).Message()) + } + + printTrace(cmd, args[1], args[2], protocol, sport, dport, resp) + return nil +} + +func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) { + cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto)) + + for _, stage := range resp.Stages { + if stage.ForwardingDetails != nil { + cmd.Printf("%s: %s [%s]\n", stage.Name, stage.Message, *stage.ForwardingDetails) + } else { + cmd.Printf("%s: %s\n", stage.Name, stage.Message) + } + } + + disposition := map[bool]string{ + true: "\033[32mALLOWED\033[0m", // Green + false: "\033[31mDENIED\033[0m", // Red + }[resp.FinalDisposition] + + cmd.Printf("\nFinal disposition: %s\n", disposition) +} diff --git a/client/firewall/create.go b/client/firewall/create.go index 9466f4b4d6b..37ea5ceb3fa 100644 --- a/client/firewall/create.go +++ b/client/firewall/create.go @@ -14,13 +14,13 @@ import ( ) // NewFirewall creates a firewall manager instance -func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) { if !iface.IsUserspaceBind() { return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS) } // use userspace packet filtering firewall - fm, err := uspfilter.Create(iface) + fm, err := uspfilter.Create(iface, disableServerRoutes) if err != nil { return nil, err } diff --git a/client/firewall/create_linux.go b/client/firewall/create_linux.go index 076d08ec27b..be1b37916bb 100644 --- a/client/firewall/create_linux.go +++ b/client/firewall/create_linux.go @@ -33,12 +33,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK" // FWType is the type for the firewall type type FWType int -func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { +func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) { // on the linux system we try to user nftables or iptables // in any case, because we need to allow netbird interface traffic // so we use AllowNetbird traffic from these firewall managers // for the userspace packet filtering firewall - fm, err := createNativeFirewall(iface, stateManager) + fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes) if !iface.IsUserspaceBind() { return fm, err @@ -47,10 +47,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal if err != nil { log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err) } - return createUserspaceFirewall(iface, fm) + return createUserspaceFirewall(iface, fm, disableServerRoutes) } -func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) { +func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) { fm, err := createFW(iface) if err != nil { return nil, fmt.Errorf("create firewall: %s", err) @@ -77,12 +77,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) { } } -func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) { +func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) { var errUsp error if fm != nil { - fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm) + fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes) } else { - fm, errUsp = uspfilter.Create(iface) + fm, errUsp = uspfilter.Create(iface, disableServerRoutes) } if errUsp != nil { diff --git a/client/firewall/iface.go b/client/firewall/iface.go index f349f9210a6..d842abaa124 100644 --- a/client/firewall/iface.go +++ b/client/firewall/iface.go @@ -1,6 +1,8 @@ package firewall import ( + wgdevice "golang.zx2c4.com/wireguard/device" + "github.com/netbirdio/netbird/client/iface/device" ) @@ -10,4 +12,6 @@ type IFaceMapper interface { Address() device.WGAddress IsUserspaceBind() bool SetFilter(device.PacketFilter) error + GetDevice() *device.FilteredDevice + GetWGDevice() *wgdevice.Device } diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index 75f082fc4c0..679f288e32a 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -213,6 +213,11 @@ func (m *Manager) AllowNetbird() error { // Flush doesn't need to be implemented for this manager func (m *Manager) Flush() error { return nil } +// SetLogLevel sets the log level for the firewall manager +func (m *Manager) SetLogLevel(log.Level) { + // not supported +} + func getConntrackEstablished() []string { return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"} } diff --git a/client/firewall/iptables/router_linux.go b/client/firewall/iptables/router_linux.go index a47d3ffe698..6522daa3f41 100644 --- a/client/firewall/iptables/router_linux.go +++ b/client/firewall/iptables/router_linux.go @@ -135,7 +135,16 @@ func (r *router) AddRouteFiltering( } rule := genRouteFilteringRuleSpec(params) - if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil { + // Insert DROP rules at the beginning, append ACCEPT rules at the end + var err error + if action == firewall.ActionDrop { + // after the established rule + err = r.iptablesClient.Insert(tableFilter, chainRTFWD, 2, rule...) + } else { + err = r.iptablesClient.Append(tableFilter, chainRTFWD, rule...) + } + + if err != nil { return nil, fmt.Errorf("add route rule: %v", err) } diff --git a/client/firewall/manager/firewall.go b/client/firewall/manager/firewall.go index f46e5eb5d57..de25ff1f11c 100644 --- a/client/firewall/manager/firewall.go +++ b/client/firewall/manager/firewall.go @@ -99,6 +99,8 @@ type Manager interface { // Flush the changes to firewall controller Flush() error + + SetLogLevel(log.Level) } func GenKey(format string, pair RouterPair) string { diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index a78626dbcd5..4fe52bd5361 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -318,6 +318,11 @@ func (m *Manager) cleanupNetbirdTables() error { return nil } +// SetLogLevel sets the log level for the firewall manager +func (m *Manager) SetLogLevel(log.Level) { + // not supported +} + // Flush rule/chain/set operations from the buffer // // Method also get all rules after flush and refreshes handle values in the rulesets diff --git a/client/firewall/nftables/manager_linux_test.go b/client/firewall/nftables/manager_linux_test.go index 8d693725a6d..eaa8ef1f571 100644 --- a/client/firewall/nftables/manager_linux_test.go +++ b/client/firewall/nftables/manager_linux_test.go @@ -107,7 +107,7 @@ func TestNftablesManager(t *testing.T) { Kind: expr.VerdictAccept, }, } - require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions") + compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1) ipToAdd, _ := netip.AddrFromSlice(ip) add := ipToAdd.Unmap() @@ -307,3 +307,18 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) { stdout, stderr = runIptablesSave(t) verifyIptablesOutput(t, stdout, stderr) } + +func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) { + t.Helper() + require.Equal(t, len(got), len(want), "expression count mismatch") + + for i := range got { + if _, isCounter := got[i].(*expr.Counter); isCounter { + _, wantIsCounter := want[i].(*expr.Counter) + require.True(t, wantIsCounter, "expected Counter at index %d", i) + continue + } + + require.Equal(t, got[i], want[i], "expression mismatch at index %d", i) + } +} diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 19734673b72..92f81f39cfa 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -233,7 +233,13 @@ func (r *router) AddRouteFiltering( UserData: []byte(ruleKey), } - rule = r.conn.AddRule(rule) + // Insert DROP rules at the beginning, append ACCEPT rules at the end + if action == firewall.ActionDrop { + // TODO: Insert after the established rule + rule = r.conn.InsertRule(rule) + } else { + rule = r.conn.AddRule(rule) + } log.Tracef("Adding route rule %s", spew.Sdump(rule)) if err := r.conn.Flush(); err != nil { diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index cc07922559d..03f23f5e622 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -3,6 +3,11 @@ package uspfilter import ( + "context" + "time" + + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -17,17 +22,29 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) + } + + if m.forwarder != nil { + m.forwarder.Stop() + } + + if m.logger != nil { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := m.logger.Stop(ctx); err != nil { + log.Errorf("failed to shutdown logger: %v", err) + } } if m.nativeFirewall != nil { diff --git a/client/firewall/uspfilter/allow_netbird_windows.go b/client/firewall/uspfilter/allow_netbird_windows.go index 0d55d62689c..37958597826 100644 --- a/client/firewall/uspfilter/allow_netbird_windows.go +++ b/client/firewall/uspfilter/allow_netbird_windows.go @@ -1,9 +1,11 @@ package uspfilter import ( + "context" "fmt" "os/exec" "syscall" + "time" log "github.com/sirupsen/logrus" @@ -29,17 +31,29 @@ func (m *Manager) Reset(*statemanager.Manager) error { if m.udpTracker != nil { m.udpTracker.Close() - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) } if m.icmpTracker != nil { m.icmpTracker.Close() - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) } if m.tcpTracker != nil { m.tcpTracker.Close() - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) + } + + if m.forwarder != nil { + m.forwarder.Stop() + } + + if m.logger != nil { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := m.logger.Stop(ctx); err != nil { + log.Errorf("failed to shutdown logger: %v", err) + } } if !isWindowsFirewallReachable() { diff --git a/client/firewall/uspfilter/common/iface.go b/client/firewall/uspfilter/common/iface.go new file mode 100644 index 00000000000..d44e7950936 --- /dev/null +++ b/client/firewall/uspfilter/common/iface.go @@ -0,0 +1,16 @@ +package common + +import ( + wgdevice "golang.zx2c4.com/wireguard/device" + + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" +) + +// IFaceMapper defines subset methods of interface required for manager +type IFaceMapper interface { + SetFilter(device.PacketFilter) error + Address() iface.WGAddress + GetWGDevice() *wgdevice.Device + GetDevice() *device.FilteredDevice +} diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go index e459bc75ae1..f5f5025400f 100644 --- a/client/firewall/uspfilter/conntrack/common.go +++ b/client/firewall/uspfilter/conntrack/common.go @@ -10,12 +10,11 @@ import ( // BaseConnTrack provides common fields and locking for all connection types type BaseConnTrack struct { - SourceIP net.IP - DestIP net.IP - SourcePort uint16 - DestPort uint16 - lastSeen atomic.Int64 // Unix nano for atomic access - established atomic.Bool + SourceIP net.IP + DestIP net.IP + SourcePort uint16 + DestPort uint16 + lastSeen atomic.Int64 // Unix nano for atomic access } // these small methods will be inlined by the compiler @@ -25,16 +24,6 @@ func (b *BaseConnTrack) UpdateLastSeen() { b.lastSeen.Store(time.Now().UnixNano()) } -// IsEstablished safely checks if connection is established -func (b *BaseConnTrack) IsEstablished() bool { - return b.established.Load() -} - -// SetEstablished safely sets the established state -func (b *BaseConnTrack) SetEstablished(state bool) { - b.established.Store(state) -} - // GetLastSeen safely gets the last seen timestamp func (b *BaseConnTrack) GetLastSeen() time.Time { return time.Unix(0, b.lastSeen.Load()) diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go index 72d006def57..81fa64b19d7 100644 --- a/client/firewall/uspfilter/conntrack/common_test.go +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -3,8 +3,14 @@ package conntrack import ( "net" "testing" + + "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) +var logger = log.NewFromLogrus(logrus.StandardLogger()) + func BenchmarkIPOperations(b *testing.B) { b.Run("MakeIPAddr", func(b *testing.B) { ip := net.ParseIP("192.168.1.1") @@ -34,37 +40,11 @@ func BenchmarkIPOperations(b *testing.B) { }) } -func BenchmarkAtomicOperations(b *testing.B) { - conn := &BaseConnTrack{} - b.Run("UpdateLastSeen", func(b *testing.B) { - for i := 0; i < b.N; i++ { - conn.UpdateLastSeen() - } - }) - - b.Run("IsEstablished", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = conn.IsEstablished() - } - }) - - b.Run("SetEstablished", func(b *testing.B) { - for i := 0; i < b.N; i++ { - conn.SetEstablished(i%2 == 0) - } - }) - - b.Run("GetLastSeen", func(b *testing.B) { - for i := 0; i < b.N; i++ { - _ = conn.GetLastSeen() - } - }) -} // Memory pressure tests func BenchmarkMemoryPressure(b *testing.B) { b.Run("TCPHighLoad", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() // Generate different IPs @@ -89,7 +69,7 @@ func BenchmarkMemoryPressure(b *testing.B) { }) b.Run("UDPHighLoad", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, logger) defer tracker.Close() // Generate different IPs diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index e0a971678f1..25cd9e87d72 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -6,6 +6,8 @@ import ( "time" "github.com/google/gopacket/layers" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) const ( @@ -33,6 +35,7 @@ type ICMPConnTrack struct { // ICMPTracker manages ICMP connection states type ICMPTracker struct { + logger *nblog.Logger connections map[ICMPConnKey]*ICMPConnTrack timeout time.Duration cleanupTicker *time.Ticker @@ -42,12 +45,13 @@ type ICMPTracker struct { } // NewICMPTracker creates a new ICMP connection tracker -func NewICMPTracker(timeout time.Duration) *ICMPTracker { +func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker { if timeout == 0 { timeout = DefaultICMPTimeout } tracker := &ICMPTracker{ + logger: logger, connections: make(map[ICMPConnKey]*ICMPConnTrack), timeout: timeout, cleanupTicker: time.NewTicker(ICMPCleanupInterval), @@ -62,7 +66,6 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker { // TrackOutbound records an outbound ICMP Echo Request func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { key := makeICMPKey(srcIP, dstIP, id, seq) - now := time.Now().UnixNano() t.mutex.Lock() conn, exists := t.connections[key] @@ -80,24 +83,19 @@ func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq u ID: id, Sequence: seq, } - conn.lastSeen.Store(now) - conn.established.Store(true) + conn.UpdateLastSeen() t.connections[key] = conn + + t.logger.Trace("New ICMP connection %v", key) } t.mutex.Unlock() - conn.lastSeen.Store(now) + conn.UpdateLastSeen() } // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { - switch icmpType { - case uint8(layers.ICMPv4TypeDestinationUnreachable), - uint8(layers.ICMPv4TypeTimeExceeded): - return true - case uint8(layers.ICMPv4TypeEchoReply): - // continue processing - default: + if icmpType != uint8(layers.ICMPv4TypeEchoReply) { return false } @@ -115,8 +113,7 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq return false } - return conn.IsEstablished() && - ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && conn.ID == id && conn.Sequence == seq @@ -141,6 +138,8 @@ func (t *ICMPTracker) cleanup() { t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.DestIP) delete(t.connections, key) + + t.logger.Debug("Removed ICMP connection %v (timeout)", key) } } } diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go index 21176e719d4..32553c8360f 100644 --- a/client/firewall/uspfilter/conntrack/icmp_test.go +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -7,7 +7,7 @@ import ( func BenchmarkICMPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewICMPTracker(DefaultICMPTimeout) + tracker := NewICMPTracker(DefaultICMPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -20,7 +20,7 @@ func BenchmarkICMPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewICMPTracker(DefaultICMPTimeout) + tracker := NewICMPTracker(DefaultICMPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index a7968dc7375..7c12e8ad01f 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -5,7 +5,10 @@ package conntrack import ( "net" "sync" + "sync/atomic" "time" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) const ( @@ -61,12 +64,24 @@ type TCPConnKey struct { // TCPConnTrack represents a TCP connection state type TCPConnTrack struct { BaseConnTrack - State TCPState + State TCPState + established atomic.Bool sync.RWMutex } +// IsEstablished safely checks if connection is established +func (t *TCPConnTrack) IsEstablished() bool { + return t.established.Load() +} + +// SetEstablished safely sets the established state +func (t *TCPConnTrack) SetEstablished(state bool) { + t.established.Store(state) +} + // TCPTracker manages TCP connection states type TCPTracker struct { + logger *nblog.Logger connections map[ConnKey]*TCPConnTrack mutex sync.RWMutex cleanupTicker *time.Ticker @@ -76,8 +91,9 @@ type TCPTracker struct { } // NewTCPTracker creates a new TCP connection tracker -func NewTCPTracker(timeout time.Duration) *TCPTracker { +func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker { tracker := &TCPTracker{ + logger: logger, connections: make(map[ConnKey]*TCPConnTrack), cleanupTicker: time.NewTicker(TCPCleanupInterval), done: make(chan struct{}), @@ -93,7 +109,6 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker { func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { // Create key before lock key := makeConnKey(srcIP, dstIP, srcPort, dstPort) - now := time.Now().UnixNano() t.mutex.Lock() conn, exists := t.connections[key] @@ -113,9 +128,11 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d }, State: TCPStateNew, } - conn.lastSeen.Store(now) + conn.UpdateLastSeen() conn.established.Store(false) t.connections[key] = conn + + t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort) } t.mutex.Unlock() @@ -123,7 +140,7 @@ func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d conn.Lock() t.updateState(conn, flags, true) conn.Unlock() - conn.lastSeen.Store(now) + conn.UpdateLastSeen() } // IsValidInbound checks if an inbound TCP packet matches a tracked connection @@ -171,6 +188,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo if flags&TCPRst != 0 { conn.State = TCPStateClosed conn.SetEstablished(false) + + t.logger.Trace("TCP connection reset: %s:%d -> %s:%d", + conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) return } @@ -227,6 +247,9 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo if flags&TCPAck != 0 { conn.State = TCPStateTimeWait // Keep established = false from previous state + + t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d", + conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) } case TCPStateCloseWait: @@ -237,11 +260,17 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo case TCPStateLastAck: if flags&TCPAck != 0 { conn.State = TCPStateClosed + + t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d", + conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) } case TCPStateTimeWait: // Stay in TIME-WAIT for 2MSL before transitioning to closed // This is handled by the cleanup routine + + t.logger.Trace("TCP connection completed - %s:%d -> %s:%d", + conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) } } @@ -318,6 +347,8 @@ func (t *TCPTracker) cleanup() { t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.DestIP) delete(t.connections, key) + + t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort) } } } diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index 6c8f82423bd..5f4c43915fb 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -9,7 +9,7 @@ import ( ) func TestTCPStateMachine(t *testing.T) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("100.64.0.1") @@ -154,7 +154,7 @@ func TestTCPStateMachine(t *testing.T) { t.Run(tt.name, func(t *testing.T) { t.Helper() - tracker = NewTCPTracker(DefaultTCPTimeout) + tracker = NewTCPTracker(DefaultTCPTimeout, logger) tt.test(t) }) } @@ -162,7 +162,7 @@ func TestTCPStateMachine(t *testing.T) { } func TestRSTHandling(t *testing.T) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("100.64.0.1") @@ -233,7 +233,7 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, func BenchmarkTCPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -246,7 +246,7 @@ func BenchmarkTCPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -264,7 +264,7 @@ func BenchmarkTCPTracker(b *testing.B) { }) b.Run("ConcurrentAccess", func(b *testing.B) { - tracker := NewTCPTracker(DefaultTCPTimeout) + tracker := NewTCPTracker(DefaultTCPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -287,7 +287,7 @@ func BenchmarkTCPTracker(b *testing.B) { // Benchmark connection cleanup func BenchmarkCleanup(b *testing.B) { b.Run("TCPCleanup", func(b *testing.B) { - tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing + tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing defer tracker.Close() // Pre-populate with expired connections diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index a969a4e8425..e73465e3195 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -4,6 +4,8 @@ import ( "net" "sync" "time" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" ) const ( @@ -20,6 +22,7 @@ type UDPConnTrack struct { // UDPTracker manages UDP connection states type UDPTracker struct { + logger *nblog.Logger connections map[ConnKey]*UDPConnTrack timeout time.Duration cleanupTicker *time.Ticker @@ -29,12 +32,13 @@ type UDPTracker struct { } // NewUDPTracker creates a new UDP connection tracker -func NewUDPTracker(timeout time.Duration) *UDPTracker { +func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker { if timeout == 0 { timeout = DefaultUDPTimeout } tracker := &UDPTracker{ + logger: logger, connections: make(map[ConnKey]*UDPConnTrack), timeout: timeout, cleanupTicker: time.NewTicker(UDPCleanupInterval), @@ -49,7 +53,6 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker { // TrackOutbound records an outbound UDP connection func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { key := makeConnKey(srcIP, dstIP, srcPort, dstPort) - now := time.Now().UnixNano() t.mutex.Lock() conn, exists := t.connections[key] @@ -67,13 +70,14 @@ func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, d DestPort: dstPort, }, } - conn.lastSeen.Store(now) - conn.established.Store(true) + conn.UpdateLastSeen() t.connections[key] = conn + + t.logger.Trace("New UDP connection: %v", conn) } t.mutex.Unlock() - conn.lastSeen.Store(now) + conn.UpdateLastSeen() } // IsValidInbound checks if an inbound packet matches a tracked connection @@ -92,8 +96,7 @@ func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, return false } - return conn.IsEstablished() && - ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && + return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) && ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) && conn.DestPort == srcPort && conn.SourcePort == dstPort @@ -120,6 +123,8 @@ func (t *UDPTracker) cleanup() { t.ipPool.Put(conn.SourceIP) t.ipPool.Put(conn.DestIP) delete(t.connections, key) + + t.logger.Trace("Removed UDP connection %v (timeout)", conn) } } } diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index 67172189069..fa83ee356a3 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -29,7 +29,7 @@ func TestNewUDPTracker(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - tracker := NewUDPTracker(tt.timeout) + tracker := NewUDPTracker(tt.timeout, logger) assert.NotNil(t, tracker) assert.Equal(t, tt.wantTimeout, tracker.timeout) assert.NotNil(t, tracker.connections) @@ -40,7 +40,7 @@ func TestNewUDPTracker(t *testing.T) { } func TestUDPTracker_TrackOutbound(t *testing.T) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.2") @@ -58,12 +58,11 @@ func TestUDPTracker_TrackOutbound(t *testing.T) { assert.True(t, conn.DestIP.Equal(dstIP)) assert.Equal(t, srcPort, conn.SourcePort) assert.Equal(t, dstPort, conn.DestPort) - assert.True(t, conn.IsEstablished()) assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second) } func TestUDPTracker_IsValidInbound(t *testing.T) { - tracker := NewUDPTracker(1 * time.Second) + tracker := NewUDPTracker(1*time.Second, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.2") @@ -162,6 +161,7 @@ func TestUDPTracker_Cleanup(t *testing.T) { cleanupTicker: time.NewTicker(cleanupInterval), done: make(chan struct{}), ipPool: NewPreallocatedIPs(), + logger: logger, } // Start cleanup routine @@ -211,7 +211,7 @@ func TestUDPTracker_Cleanup(t *testing.T) { func BenchmarkUDPTracker(b *testing.B) { b.Run("TrackOutbound", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") @@ -224,7 +224,7 @@ func BenchmarkUDPTracker(b *testing.B) { }) b.Run("IsValidInbound", func(b *testing.B) { - tracker := NewUDPTracker(DefaultUDPTimeout) + tracker := NewUDPTracker(DefaultUDPTimeout, logger) defer tracker.Close() srcIP := net.ParseIP("192.168.1.1") diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go new file mode 100644 index 00000000000..e8a265c94d5 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -0,0 +1,81 @@ +package forwarder + +import ( + wgdevice "golang.zx2c4.com/wireguard/device" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" +) + +// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device +type endpoint struct { + logger *nblog.Logger + dispatcher stack.NetworkDispatcher + device *wgdevice.Device + mtu uint32 +} + +func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + e.dispatcher = dispatcher +} + +func (e *endpoint) IsAttached() bool { + return e.dispatcher != nil +} + +func (e *endpoint) MTU() uint32 { + return e.mtu +} + +func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities { + return stack.CapabilityNone +} + +func (e *endpoint) MaxHeaderLength() uint16 { + return 0 +} + +func (e *endpoint) LinkAddress() tcpip.LinkAddress { + return "" +} + +func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) { + var written int + for _, pkt := range pkts.AsSlice() { + netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice()) + + data := stack.PayloadSince(pkt.NetworkHeader()) + if data == nil { + continue + } + + // Send the packet through WireGuard + address := netHeader.DestinationAddress() + err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice()) + if err != nil { + e.logger.Error("CreateOutboundPacket: %v", err) + continue + } + written++ + } + + return written, nil +} + +func (e *endpoint) Wait() { + // not required +} + +func (e *endpoint) ARPHardwareType() header.ARPHardwareType { + return header.ARPHardwareNone +} + +func (e *endpoint) AddHeader(*stack.PacketBuffer) { + // not required +} + +func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool { + return true +} diff --git a/client/firewall/uspfilter/forwarder/forwarder.go b/client/firewall/uspfilter/forwarder/forwarder.go new file mode 100644 index 00000000000..4ed152b79c9 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/forwarder.go @@ -0,0 +1,166 @@ +package forwarder + +import ( + "context" + "fmt" + "net" + "runtime" + + log "github.com/sirupsen/logrus" + "gvisor.dev/gvisor/pkg/buffer" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + + "github.com/netbirdio/netbird/client/firewall/uspfilter/common" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" +) + +const ( + defaultReceiveWindow = 32768 + defaultMaxInFlight = 1024 + iosReceiveWindow = 16384 + iosMaxInFlight = 256 +) + +type Forwarder struct { + logger *nblog.Logger + stack *stack.Stack + endpoint *endpoint + udpForwarder *udpForwarder + ctx context.Context + cancel context.CancelFunc + ip net.IP + netstack bool +} + +func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{ + tcp.NewProtocol, + udp.NewProtocol, + icmp.NewProtocol4, + }, + HandleLocal: false, + }) + + mtu, err := iface.GetDevice().MTU() + if err != nil { + return nil, fmt.Errorf("get MTU: %w", err) + } + nicID := tcpip.NICID(1) + endpoint := &endpoint{ + logger: logger, + device: iface.GetWGDevice(), + mtu: uint32(mtu), + } + + if err := s.CreateNIC(nicID, endpoint); err != nil { + return nil, fmt.Errorf("failed to create NIC: %v", err) + } + + ones, _ := iface.Address().Network.Mask.Size() + protoAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: tcpip.AddrFromSlice(iface.Address().IP.To4()), + PrefixLen: ones, + }, + } + + if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil { + return nil, fmt.Errorf("failed to add protocol address: %s", err) + } + + defaultSubnet, err := tcpip.NewSubnet( + tcpip.AddrFrom4([4]byte{0, 0, 0, 0}), + tcpip.MaskFromBytes([]byte{0, 0, 0, 0}), + ) + if err != nil { + return nil, fmt.Errorf("creating default subnet: %w", err) + } + + if err := s.SetPromiscuousMode(nicID, true); err != nil { + return nil, fmt.Errorf("set promiscuous mode: %s", err) + } + if err := s.SetSpoofing(nicID, true); err != nil { + return nil, fmt.Errorf("set spoofing: %s", err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: defaultSubnet, + NIC: nicID, + }, + }) + + ctx, cancel := context.WithCancel(context.Background()) + f := &Forwarder{ + logger: logger, + stack: s, + endpoint: endpoint, + udpForwarder: newUDPForwarder(mtu, logger), + ctx: ctx, + cancel: cancel, + netstack: netstack, + ip: iface.Address().IP, + } + + receiveWindow := defaultReceiveWindow + maxInFlight := defaultMaxInFlight + if runtime.GOOS == "ios" { + receiveWindow = iosReceiveWindow + maxInFlight = iosMaxInFlight + } + + tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP) + s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket) + + udpForwarder := udp.NewForwarder(s, f.handleUDP) + s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket) + + s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP) + + log.Debugf("forwarder: Initialization complete with NIC %d", nicID) + return f, nil +} + +func (f *Forwarder) InjectIncomingPacket(payload []byte) error { + if len(payload) < header.IPv4MinimumSize { + return fmt.Errorf("packet too small: %d bytes", len(payload)) + } + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithData(payload), + }) + defer pkt.DecRef() + + if f.endpoint.dispatcher != nil { + f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt) + } + return nil +} + +// Stop gracefully shuts down the forwarder +func (f *Forwarder) Stop() { + f.cancel() + + if f.udpForwarder != nil { + f.udpForwarder.Stop() + } + + f.stack.Close() + f.stack.Wait() +} + +func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP { + if f.netstack && f.ip.Equal(addr.AsSlice()) { + return net.IPv4(127, 0, 0, 1) + } + return addr.AsSlice() +} diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go new file mode 100644 index 00000000000..14cdc37be85 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -0,0 +1,109 @@ +package forwarder + +import ( + "context" + "net" + "time" + + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// handleICMP handles ICMP packets from the network stack +func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool { + ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second) + defer cancel() + + lc := net.ListenConfig{} + // TODO: support non-root + conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0") + if err != nil { + f.logger.Error("Failed to create ICMP socket for %v: %v", id, err) + + // This will make netstack reply on behalf of the original destination, that's ok for now + return false + } + defer func() { + if err := conn.Close(); err != nil { + f.logger.Debug("Failed to close ICMP socket: %v", err) + } + }() + + dstIP := f.determineDialAddr(id.LocalAddress) + dst := &net.IPAddr{IP: dstIP} + + // Get the complete ICMP message (header + data) + fullPacket := stack.PayloadSince(pkt.TransportHeader()) + payload := fullPacket.AsSlice() + + icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice()) + + // For Echo Requests, send and handle response + switch icmpHdr.Type() { + case header.ICMPv4Echo: + return f.handleEchoResponse(icmpHdr, payload, dst, conn, id) + case header.ICMPv4EchoReply: + // dont process our own replies + return true + default: + } + + // For other ICMP types (Time Exceeded, Destination Unreachable, etc) + _, err = conn.WriteTo(payload, dst) + if err != nil { + f.logger.Error("Failed to write ICMP packet for %v: %v", id, err) + return true + } + + f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v", + id, icmpHdr.Type(), icmpHdr.Code()) + + return true +} + +func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool { + if _, err := conn.WriteTo(payload, dst); err != nil { + f.logger.Error("Failed to write ICMP packet for %v: %v", id, err) + return true + } + + f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v", + id, icmpHdr.Type(), icmpHdr.Code()) + + if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { + f.logger.Error("Failed to set read deadline for ICMP response: %v", err) + return true + } + + response := make([]byte, f.endpoint.mtu) + n, _, err := conn.ReadFrom(response) + if err != nil { + if !isTimeout(err) { + f.logger.Error("Failed to read ICMP response: %v", err) + } + return true + } + + ipHdr := make([]byte, header.IPv4MinimumSize) + ip := header.IPv4(ipHdr) + ip.Encode(&header.IPv4Fields{ + TotalLength: uint16(header.IPv4MinimumSize + n), + TTL: 64, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: id.LocalAddress, + DstAddr: id.RemoteAddress, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + fullPacket := make([]byte, 0, len(ipHdr)+n) + fullPacket = append(fullPacket, ipHdr...) + fullPacket = append(fullPacket, response[:n]...) + + if err := f.InjectIncomingPacket(fullPacket); err != nil { + f.logger.Error("Failed to inject ICMP response: %v", err) + return true + } + + f.logger.Trace("Forwarded ICMP echo reply for %v", id) + return true +} diff --git a/client/firewall/uspfilter/forwarder/tcp.go b/client/firewall/uspfilter/forwarder/tcp.go new file mode 100644 index 00000000000..6d7cf3b6a70 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/tcp.go @@ -0,0 +1,90 @@ +package forwarder + +import ( + "context" + "fmt" + "io" + "net" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// handleTCP is called by the TCP forwarder for new connections. +func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) { + id := r.ID() + + dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) + + outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr) + if err != nil { + r.Complete(true) + f.logger.Trace("forwarder: dial error for %v: %v", id, err) + return + } + + // Create wait queue for blocking syscalls + wq := waiter.Queue{} + + ep, epErr := r.CreateEndpoint(&wq) + if epErr != nil { + f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr) + if err := outConn.Close(); err != nil { + f.logger.Debug("forwarder: outConn close error: %v", err) + } + r.Complete(true) + return + } + + // Complete the handshake + r.Complete(false) + + inConn := gonet.NewTCPConn(&wq, ep) + + f.logger.Trace("forwarder: established TCP connection %v", id) + + go f.proxyTCP(id, inConn, outConn, ep) +} + +func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) { + defer func() { + if err := inConn.Close(); err != nil { + f.logger.Debug("forwarder: inConn close error: %v", err) + } + if err := outConn.Close(); err != nil { + f.logger.Debug("forwarder: outConn close error: %v", err) + } + ep.Close() + }() + + // Create context for managing the proxy goroutines + ctx, cancel := context.WithCancel(f.ctx) + defer cancel() + + errChan := make(chan error, 2) + + go func() { + _, err := io.Copy(outConn, inConn) + errChan <- err + }() + + go func() { + _, err := io.Copy(inConn, outConn) + errChan <- err + }() + + select { + case <-ctx.Done(): + f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id) + return + case err := <-errChan: + if err != nil && !isClosedError(err) { + f.logger.Error("proxyTCP: copy error: %v", err) + } + f.logger.Trace("forwarder: tearing down TCP connection %v", id) + return + } +} diff --git a/client/firewall/uspfilter/forwarder/udp.go b/client/firewall/uspfilter/forwarder/udp.go new file mode 100644 index 00000000000..97e4662fd39 --- /dev/null +++ b/client/firewall/uspfilter/forwarder/udp.go @@ -0,0 +1,288 @@ +package forwarder + +import ( + "context" + "errors" + "fmt" + "net" + "sync" + "sync/atomic" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" + + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" +) + +const ( + udpTimeout = 30 * time.Second +) + +type udpPacketConn struct { + conn *gonet.UDPConn + outConn net.Conn + lastSeen atomic.Int64 + cancel context.CancelFunc + ep tcpip.Endpoint +} + +type udpForwarder struct { + sync.RWMutex + logger *nblog.Logger + conns map[stack.TransportEndpointID]*udpPacketConn + bufPool sync.Pool + ctx context.Context + cancel context.CancelFunc +} + +type idleConn struct { + id stack.TransportEndpointID + conn *udpPacketConn +} + +func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder { + ctx, cancel := context.WithCancel(context.Background()) + f := &udpForwarder{ + logger: logger, + conns: make(map[stack.TransportEndpointID]*udpPacketConn), + ctx: ctx, + cancel: cancel, + bufPool: sync.Pool{ + New: func() any { + b := make([]byte, mtu) + return &b + }, + }, + } + go f.cleanup() + return f +} + +// Stop stops the UDP forwarder and all active connections +func (f *udpForwarder) Stop() { + f.cancel() + + f.Lock() + defer f.Unlock() + + for id, conn := range f.conns { + conn.cancel() + if err := conn.conn.Close(); err != nil { + f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err) + } + if err := conn.outConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) + } + + conn.ep.Close() + delete(f.conns, id) + } +} + +// cleanup periodically removes idle UDP connections +func (f *udpForwarder) cleanup() { + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for { + select { + case <-f.ctx.Done(): + return + case <-ticker.C: + var idleConns []idleConn + + f.RLock() + for id, conn := range f.conns { + if conn.getIdleDuration() > udpTimeout { + idleConns = append(idleConns, idleConn{id, conn}) + } + } + f.RUnlock() + + for _, idle := range idleConns { + idle.conn.cancel() + if err := idle.conn.conn.Close(); err != nil { + f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err) + } + if err := idle.conn.outConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err) + } + + idle.conn.ep.Close() + + f.Lock() + delete(f.conns, idle.id) + f.Unlock() + + f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id) + } + } + } +} + +// handleUDP is called by the UDP forwarder for new packets +func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) { + if f.ctx.Err() != nil { + f.logger.Trace("forwarder: context done, dropping UDP packet") + return + } + + id := r.ID() + + f.udpForwarder.RLock() + _, exists := f.udpForwarder.conns[id] + f.udpForwarder.RUnlock() + if exists { + f.logger.Trace("forwarder: existing UDP connection for %v", id) + return + } + + dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort) + outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr) + if err != nil { + f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err) + // TODO: Send ICMP error message + return + } + + // Create wait queue for blocking syscalls + wq := waiter.Queue{} + ep, epErr := r.CreateEndpoint(&wq) + if epErr != nil { + f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr) + if err := outConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) + } + return + } + + inConn := gonet.NewUDPConn(f.stack, &wq, ep) + connCtx, connCancel := context.WithCancel(f.ctx) + + pConn := &udpPacketConn{ + conn: inConn, + outConn: outConn, + cancel: connCancel, + ep: ep, + } + pConn.updateLastSeen() + + f.udpForwarder.Lock() + // Double-check no connection was created while we were setting up + if _, exists := f.udpForwarder.conns[id]; exists { + f.udpForwarder.Unlock() + pConn.cancel() + if err := inConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err) + } + if err := outConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) + } + return + } + f.udpForwarder.conns[id] = pConn + f.udpForwarder.Unlock() + + f.logger.Trace("forwarder: established UDP connection to %v", id) + go f.proxyUDP(connCtx, pConn, id, ep) +} + +func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) { + defer func() { + pConn.cancel() + if err := pConn.conn.Close(); err != nil { + f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err) + } + if err := pConn.outConn.Close(); err != nil { + f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err) + } + + ep.Close() + + f.udpForwarder.Lock() + delete(f.udpForwarder.conns, id) + f.udpForwarder.Unlock() + }() + + errChan := make(chan error, 2) + + go func() { + errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound") + }() + + go func() { + errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound") + }() + + select { + case <-ctx.Done(): + f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id) + return + case err := <-errChan: + if err != nil && !isClosedError(err) { + f.logger.Error("proxyUDP: copy error: %v", err) + } + f.logger.Trace("forwarder: tearing down UDP connection %v", id) + return + } +} + +func (c *udpPacketConn) updateLastSeen() { + c.lastSeen.Store(time.Now().UnixNano()) +} + +func (c *udpPacketConn) getIdleDuration() time.Duration { + lastSeen := time.Unix(0, c.lastSeen.Load()) + return time.Since(lastSeen) +} + +func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error { + bufp := bufPool.Get().(*[]byte) + defer bufPool.Put(bufp) + buffer := *bufp + + if err := src.SetReadDeadline(time.Now().Add(udpTimeout)); err != nil { + return fmt.Errorf("set read deadline: %w", err) + } + if err := src.SetWriteDeadline(time.Now().Add(udpTimeout)); err != nil { + return fmt.Errorf("set write deadline: %w", err) + } + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + n, err := src.Read(buffer) + if err != nil { + if isTimeout(err) { + continue + } + return fmt.Errorf("read from %s: %w", direction, err) + } + + _, err = dst.Write(buffer[:n]) + if err != nil { + return fmt.Errorf("write to %s: %w", direction, err) + } + + c.updateLastSeen() + } + } +} + +func isClosedError(err error) bool { + return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) +} + +func isTimeout(err error) bool { + var netErr net.Error + if errors.As(err, &netErr) { + return netErr.Timeout() + } + return false +} diff --git a/client/firewall/uspfilter/localip.go b/client/firewall/uspfilter/localip.go new file mode 100644 index 00000000000..7664b65d554 --- /dev/null +++ b/client/firewall/uspfilter/localip.go @@ -0,0 +1,134 @@ +package uspfilter + +import ( + "fmt" + "net" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/firewall/uspfilter/common" +) + +type localIPManager struct { + mu sync.RWMutex + + // Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory) + ipv4Bitmap [1 << 16]uint32 +} + +func newLocalIPManager() *localIPManager { + return &localIPManager{} +} + +func (m *localIPManager) setBitmapBit(ip net.IP) { + ipv4 := ip.To4() + if ipv4 == nil { + return + } + high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) + low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) + m.ipv4Bitmap[high] |= 1 << (low % 32) +} + +func (m *localIPManager) checkBitmapBit(ip net.IP) bool { + ipv4 := ip.To4() + if ipv4 == nil { + return false + } + high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) + low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) + return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0 +} + +func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error { + if ipv4 := ip.To4(); ipv4 != nil { + high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1]) + low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3]) + if int(high) >= len(*newIPv4Bitmap) { + return fmt.Errorf("invalid IPv4 address: %s", ip) + } + ipStr := ip.String() + if _, exists := ipv4Set[ipStr]; !exists { + ipv4Set[ipStr] = struct{}{} + *ipv4Addresses = append(*ipv4Addresses, ipStr) + newIPv4Bitmap[high] |= 1 << (low % 32) + } + } + return nil +} + +func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) { + addrs, err := iface.Addrs() + if err != nil { + log.Debugf("get addresses for interface %s failed: %v", iface.Name, err) + return + } + + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + default: + continue + } + + if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil { + log.Debugf("process IP failed: %v", err) + } + } +} + +func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) { + defer func() { + if r := recover(); r != nil { + err = fmt.Errorf("panic: %v", r) + } + }() + + var newIPv4Bitmap [1 << 16]uint32 + ipv4Set := make(map[string]struct{}) + var ipv4Addresses []string + + // 127.0.0.0/8 + high := uint16(127) << 8 + for i := uint16(0); i < 256; i++ { + newIPv4Bitmap[high|i] = 0xffffffff + } + + if iface != nil { + if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil { + return err + } + } + + interfaces, err := net.Interfaces() + if err != nil { + log.Warnf("failed to get interfaces: %v", err) + } else { + for _, intf := range interfaces { + m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses) + } + } + + m.mu.Lock() + m.ipv4Bitmap = newIPv4Bitmap + m.mu.Unlock() + + log.Debugf("Local IPv4 addresses: %v", ipv4Addresses) + return nil +} + +func (m *localIPManager) IsLocalIP(ip net.IP) bool { + m.mu.RLock() + defer m.mu.RUnlock() + + if ipv4 := ip.To4(); ipv4 != nil { + return m.checkBitmapBit(ipv4) + } + + return false +} diff --git a/client/firewall/uspfilter/localip_test.go b/client/firewall/uspfilter/localip_test.go new file mode 100644 index 00000000000..02f41bf4f61 --- /dev/null +++ b/client/firewall/uspfilter/localip_test.go @@ -0,0 +1,270 @@ +package uspfilter + +import ( + "net" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/iface" +) + +func TestLocalIPManager(t *testing.T) { + tests := []struct { + name string + setupAddr iface.WGAddress + testIP net.IP + expected bool + }{ + { + name: "Localhost range", + setupAddr: iface.WGAddress{ + IP: net.ParseIP("192.168.1.1"), + Network: &net.IPNet{ + IP: net.ParseIP("192.168.1.0"), + Mask: net.CIDRMask(24, 32), + }, + }, + testIP: net.ParseIP("127.0.0.2"), + expected: true, + }, + { + name: "Localhost standard address", + setupAddr: iface.WGAddress{ + IP: net.ParseIP("192.168.1.1"), + Network: &net.IPNet{ + IP: net.ParseIP("192.168.1.0"), + Mask: net.CIDRMask(24, 32), + }, + }, + testIP: net.ParseIP("127.0.0.1"), + expected: true, + }, + { + name: "Localhost range edge", + setupAddr: iface.WGAddress{ + IP: net.ParseIP("192.168.1.1"), + Network: &net.IPNet{ + IP: net.ParseIP("192.168.1.0"), + Mask: net.CIDRMask(24, 32), + }, + }, + testIP: net.ParseIP("127.255.255.255"), + expected: true, + }, + { + name: "Local IP matches", + setupAddr: iface.WGAddress{ + IP: net.ParseIP("192.168.1.1"), + Network: &net.IPNet{ + IP: net.ParseIP("192.168.1.0"), + Mask: net.CIDRMask(24, 32), + }, + }, + testIP: net.ParseIP("192.168.1.1"), + expected: true, + }, + { + name: "Local IP doesn't match", + setupAddr: iface.WGAddress{ + IP: net.ParseIP("192.168.1.1"), + Network: &net.IPNet{ + IP: net.ParseIP("192.168.1.0"), + Mask: net.CIDRMask(24, 32), + }, + }, + testIP: net.ParseIP("192.168.1.2"), + expected: false, + }, + { + name: "IPv6 address", + setupAddr: iface.WGAddress{ + IP: net.ParseIP("fe80::1"), + Network: &net.IPNet{ + IP: net.ParseIP("fe80::"), + Mask: net.CIDRMask(64, 128), + }, + }, + testIP: net.ParseIP("fe80::1"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + manager := newLocalIPManager() + + mock := &IFaceMock{ + AddressFunc: func() iface.WGAddress { + return tt.setupAddr + }, + } + + err := manager.UpdateLocalIPs(mock) + require.NoError(t, err) + + result := manager.IsLocalIP(tt.testIP) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestLocalIPManager_AllInterfaces(t *testing.T) { + manager := newLocalIPManager() + mock := &IFaceMock{} + + // Get actual local interfaces + interfaces, err := net.Interfaces() + require.NoError(t, err) + + var tests []struct { + ip string + expected bool + } + + // Add all local interface IPs to test cases + for _, iface := range interfaces { + addrs, err := iface.Addrs() + require.NoError(t, err) + + for _, addr := range addrs { + var ip net.IP + switch v := addr.(type) { + case *net.IPNet: + ip = v.IP + case *net.IPAddr: + ip = v.IP + default: + continue + } + + if ip4 := ip.To4(); ip4 != nil { + tests = append(tests, struct { + ip string + expected bool + }{ + ip: ip4.String(), + expected: true, + }) + } + } + } + + // Add some external IPs as negative test cases + externalIPs := []string{ + "8.8.8.8", + "1.1.1.1", + "208.67.222.222", + } + for _, ip := range externalIPs { + tests = append(tests, struct { + ip string + expected bool + }{ + ip: ip, + expected: false, + }) + } + + require.NotEmpty(t, tests, "No test cases generated") + + err = manager.UpdateLocalIPs(mock) + require.NoError(t, err) + + t.Logf("Testing %d IPs", len(tests)) + for _, tt := range tests { + t.Run(tt.ip, func(t *testing.T) { + result := manager.IsLocalIP(net.ParseIP(tt.ip)) + require.Equal(t, tt.expected, result, "IP: %s", tt.ip) + }) + } +} + +// MapImplementation is a version using map[string]struct{} +type MapImplementation struct { + localIPs map[string]struct{} +} + +func BenchmarkIPChecks(b *testing.B) { + interfaces := make([]net.IP, 16) + for i := range interfaces { + interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i)) + } + + // Setup bitmap version + bitmapManager := &localIPManager{ + ipv4Bitmap: [1 << 16]uint32{}, + } + for _, ip := range interfaces[:8] { // Add half of IPs + bitmapManager.setBitmapBit(ip) + } + + // Setup map version + mapManager := &MapImplementation{ + localIPs: make(map[string]struct{}), + } + for _, ip := range interfaces[:8] { + mapManager.localIPs[ip.String()] = struct{}{} + } + + b.Run("Bitmap_Hit", func(b *testing.B) { + ip := interfaces[4] + b.ResetTimer() + for i := 0; i < b.N; i++ { + bitmapManager.checkBitmapBit(ip) + } + }) + + b.Run("Bitmap_Miss", func(b *testing.B) { + ip := interfaces[12] + b.ResetTimer() + for i := 0; i < b.N; i++ { + bitmapManager.checkBitmapBit(ip) + } + }) + + b.Run("Map_Hit", func(b *testing.B) { + ip := interfaces[4] + b.ResetTimer() + for i := 0; i < b.N; i++ { + // nolint:gosimple + _, _ = mapManager.localIPs[ip.String()] + } + }) + + b.Run("Map_Miss", func(b *testing.B) { + ip := interfaces[12] + b.ResetTimer() + for i := 0; i < b.N; i++ { + // nolint:gosimple + _, _ = mapManager.localIPs[ip.String()] + } + }) +} + +func BenchmarkWGPosition(b *testing.B) { + wgIP := net.ParseIP("10.10.0.1") + + // Create two managers - one checks WG IP first, other checks it last + b.Run("WG_First", func(b *testing.B) { + bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}} + bm.setBitmapBit(wgIP) + b.ResetTimer() + for i := 0; i < b.N; i++ { + bm.checkBitmapBit(wgIP) + } + }) + + b.Run("WG_Last", func(b *testing.B) { + bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}} + // Fill with other IPs first + for i := 0; i < 15; i++ { + bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i))) + } + bm.setBitmapBit(wgIP) // Add WG IP last + b.ResetTimer() + for i := 0; i < b.N; i++ { + bm.checkBitmapBit(wgIP) + } + }) +} diff --git a/client/firewall/uspfilter/log/log.go b/client/firewall/uspfilter/log/log.go new file mode 100644 index 00000000000..984b6ad08e1 --- /dev/null +++ b/client/firewall/uspfilter/log/log.go @@ -0,0 +1,196 @@ +// Package logger provides a high-performance, non-blocking logger for userspace networking +package log + +import ( + "context" + "fmt" + "io" + "sync" + "sync/atomic" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + maxBatchSize = 1024 * 16 // 16KB max batch size + maxMessageSize = 1024 * 2 // 2KB per message + bufferSize = 1024 * 256 // 256KB ring buffer + defaultFlushInterval = 2 * time.Second +) + +// Level represents log severity +type Level uint32 + +const ( + LevelPanic Level = iota + LevelFatal + LevelError + LevelWarn + LevelInfo + LevelDebug + LevelTrace +) + +var levelStrings = map[Level]string{ + LevelPanic: "PANC", + LevelFatal: "FATL", + LevelError: "ERRO", + LevelWarn: "WARN", + LevelInfo: "INFO", + LevelDebug: "DEBG", + LevelTrace: "TRAC", +} + +// Logger is a high-performance, non-blocking logger +type Logger struct { + output io.Writer + level atomic.Uint32 + buffer *ringBuffer + shutdown chan struct{} + closeOnce sync.Once + wg sync.WaitGroup + + // Reusable buffer pool for formatting messages + bufPool sync.Pool +} + +func NewFromLogrus(logrusLogger *log.Logger) *Logger { + l := &Logger{ + output: logrusLogger.Out, + buffer: newRingBuffer(bufferSize), + shutdown: make(chan struct{}), + bufPool: sync.Pool{ + New: func() interface{} { + // Pre-allocate buffer for message formatting + b := make([]byte, 0, maxMessageSize) + return &b + }, + }, + } + logrusLevel := logrusLogger.GetLevel() + l.level.Store(uint32(logrusLevel)) + level := levelStrings[Level(logrusLevel)] + log.Debugf("New uspfilter logger created with loglevel %v", level) + + l.wg.Add(1) + go l.worker() + + return l +} + +func (l *Logger) SetLevel(level Level) { + l.level.Store(uint32(level)) + + log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level]) +} + +func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) { + *buf = (*buf)[:0] + + // Timestamp + *buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00") + *buf = append(*buf, ' ') + + // Level + *buf = append(*buf, levelStrings[level]...) + *buf = append(*buf, ' ') + + // Message + if len(args) > 0 { + *buf = append(*buf, fmt.Sprintf(format, args...)...) + } else { + *buf = append(*buf, format...) + } + + *buf = append(*buf, '\n') +} + +func (l *Logger) log(level Level, format string, args ...interface{}) { + bufp := l.bufPool.Get().(*[]byte) + l.formatMessage(bufp, level, format, args...) + + if len(*bufp) > maxMessageSize { + *bufp = (*bufp)[:maxMessageSize] + } + _, _ = l.buffer.Write(*bufp) + + l.bufPool.Put(bufp) +} + +func (l *Logger) Error(format string, args ...interface{}) { + if l.level.Load() >= uint32(LevelError) { + l.log(LevelError, format, args...) + } +} + +func (l *Logger) Warn(format string, args ...interface{}) { + if l.level.Load() >= uint32(LevelWarn) { + l.log(LevelWarn, format, args...) + } +} + +func (l *Logger) Info(format string, args ...interface{}) { + if l.level.Load() >= uint32(LevelInfo) { + l.log(LevelInfo, format, args...) + } +} + +func (l *Logger) Debug(format string, args ...interface{}) { + if l.level.Load() >= uint32(LevelDebug) { + l.log(LevelDebug, format, args...) + } +} + +func (l *Logger) Trace(format string, args ...interface{}) { + if l.level.Load() >= uint32(LevelTrace) { + l.log(LevelTrace, format, args...) + } +} + +// worker periodically flushes the buffer +func (l *Logger) worker() { + defer l.wg.Done() + + ticker := time.NewTicker(defaultFlushInterval) + defer ticker.Stop() + + buf := make([]byte, 0, maxBatchSize) + + for { + select { + case <-l.shutdown: + return + case <-ticker.C: + // Read accumulated messages + n, _ := l.buffer.Read(buf[:cap(buf)]) + if n == 0 { + continue + } + + // Write batch + _, _ = l.output.Write(buf[:n]) + } + } +} + +// Stop gracefully shuts down the logger +func (l *Logger) Stop(ctx context.Context) error { + done := make(chan struct{}) + + l.closeOnce.Do(func() { + close(l.shutdown) + }) + + go func() { + l.wg.Wait() + close(done) + }() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-done: + return nil + } +} diff --git a/client/firewall/uspfilter/log/ringbuffer.go b/client/firewall/uspfilter/log/ringbuffer.go new file mode 100644 index 00000000000..dbc8f1289a7 --- /dev/null +++ b/client/firewall/uspfilter/log/ringbuffer.go @@ -0,0 +1,85 @@ +package log + +import "sync" + +// ringBuffer is a simple ring buffer implementation +type ringBuffer struct { + buf []byte + size int + r, w int64 // Read and write positions + mu sync.Mutex +} + +func newRingBuffer(size int) *ringBuffer { + return &ringBuffer{ + buf: make([]byte, size), + size: size, + } +} + +func (r *ringBuffer) Write(p []byte) (n int, err error) { + if len(p) == 0 { + return 0, nil + } + + r.mu.Lock() + defer r.mu.Unlock() + + if len(p) > r.size { + p = p[:r.size] + } + + n = len(p) + + // Write data, handling wrap-around + pos := int(r.w % int64(r.size)) + writeLen := min(len(p), r.size-pos) + copy(r.buf[pos:], p[:writeLen]) + + // If we have more data and need to wrap around + if writeLen < len(p) { + copy(r.buf, p[writeLen:]) + } + + // Update write position + r.w += int64(n) + + return n, nil +} + +func (r *ringBuffer) Read(p []byte) (n int, err error) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.w == r.r { + return 0, nil + } + + // Calculate available data accounting for wraparound + available := int(r.w - r.r) + if available < 0 { + available += r.size + } + available = min(available, r.size) + + // Limit read to buffer size + toRead := min(available, len(p)) + if toRead == 0 { + return 0, nil + } + + // Read data, handling wrap-around + pos := int(r.r % int64(r.size)) + readLen := min(toRead, r.size-pos) + n = copy(p, r.buf[pos:pos+readLen]) + + // If we need more data and need to wrap around + if readLen < toRead { + n += copy(p[readLen:toRead], r.buf[:toRead-readLen]) + } + + // Update read position + r.r += int64(n) + + return n, nil +} diff --git a/client/firewall/uspfilter/rule.go b/client/firewall/uspfilter/rule.go index c59d4b264ce..6a4415f7315 100644 --- a/client/firewall/uspfilter/rule.go +++ b/client/firewall/uspfilter/rule.go @@ -2,14 +2,15 @@ package uspfilter import ( "net" + "net/netip" "github.com/google/gopacket" firewall "github.com/netbirdio/netbird/client/firewall/manager" ) -// Rule to handle management of rules -type Rule struct { +// PeerRule to handle management of rules +type PeerRule struct { id string ip net.IP ipLayer gopacket.LayerType @@ -24,6 +25,21 @@ type Rule struct { } // GetRuleID returns the rule id -func (r *Rule) GetRuleID() string { +func (r *PeerRule) GetRuleID() string { + return r.id +} + +type RouteRule struct { + id string + sources []netip.Prefix + destination netip.Prefix + proto firewall.Protocol + srcPort *firewall.Port + dstPort *firewall.Port + action firewall.Action +} + +// GetRuleID returns the rule id +func (r *RouteRule) GetRuleID() string { return r.id } diff --git a/client/firewall/uspfilter/tracer.go b/client/firewall/uspfilter/tracer.go new file mode 100644 index 00000000000..a4c653b3b4b --- /dev/null +++ b/client/firewall/uspfilter/tracer.go @@ -0,0 +1,390 @@ +package uspfilter + +import ( + "fmt" + "net" + "time" + + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + + fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" +) + +type PacketStage int + +const ( + StageReceived PacketStage = iota + StageConntrack + StagePeerACL + StageRouting + StageRouteACL + StageForwarding + StageCompleted +) + +const msgProcessingCompleted = "Processing completed" + +func (s PacketStage) String() string { + return map[PacketStage]string{ + StageReceived: "Received", + StageConntrack: "Connection Tracking", + StagePeerACL: "Peer ACL", + StageRouting: "Routing", + StageRouteACL: "Route ACL", + StageForwarding: "Forwarding", + StageCompleted: "Completed", + }[s] +} + +type ForwarderAction struct { + Action string + RemoteAddr string + Error error +} + +type TraceResult struct { + Timestamp time.Time + Stage PacketStage + Message string + Allowed bool + ForwarderAction *ForwarderAction +} + +type PacketTrace struct { + SourceIP net.IP + DestinationIP net.IP + Protocol string + SourcePort uint16 + DestinationPort uint16 + Direction fw.RuleDirection + Results []TraceResult +} + +type TCPState struct { + SYN bool + ACK bool + FIN bool + RST bool + PSH bool + URG bool +} + +type PacketBuilder struct { + SrcIP net.IP + DstIP net.IP + Protocol fw.Protocol + SrcPort uint16 + DstPort uint16 + ICMPType uint8 + ICMPCode uint8 + Direction fw.RuleDirection + PayloadSize int + TCPState *TCPState +} + +func (t *PacketTrace) AddResult(stage PacketStage, message string, allowed bool) { + t.Results = append(t.Results, TraceResult{ + Timestamp: time.Now(), + Stage: stage, + Message: message, + Allowed: allowed, + }) +} + +func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, allowed bool, action *ForwarderAction) { + t.Results = append(t.Results, TraceResult{ + Timestamp: time.Now(), + Stage: stage, + Message: message, + Allowed: allowed, + ForwarderAction: action, + }) +} + +func (p *PacketBuilder) Build() ([]byte, error) { + ip := p.buildIPLayer() + pktLayers := []gopacket.SerializableLayer{ip} + + transportLayer, err := p.buildTransportLayer(ip) + if err != nil { + return nil, err + } + pktLayers = append(pktLayers, transportLayer...) + + if p.PayloadSize > 0 { + payload := make([]byte, p.PayloadSize) + pktLayers = append(pktLayers, gopacket.Payload(payload)) + } + + return serializePacket(pktLayers) +} + +func (p *PacketBuilder) buildIPLayer() *layers.IPv4 { + return &layers.IPv4{ + Version: 4, + TTL: 64, + Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)), + SrcIP: p.SrcIP, + DstIP: p.DstIP, + } +} + +func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { + switch p.Protocol { + case "tcp": + return p.buildTCPLayer(ip) + case "udp": + return p.buildUDPLayer(ip) + case "icmp": + return p.buildICMPLayer() + default: + return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol) + } +} + +func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(p.SrcPort), + DstPort: layers.TCPPort(p.DstPort), + Window: 65535, + SYN: p.TCPState != nil && p.TCPState.SYN, + ACK: p.TCPState != nil && p.TCPState.ACK, + FIN: p.TCPState != nil && p.TCPState.FIN, + RST: p.TCPState != nil && p.TCPState.RST, + PSH: p.TCPState != nil && p.TCPState.PSH, + URG: p.TCPState != nil && p.TCPState.URG, + } + if err := tcp.SetNetworkLayerForChecksum(ip); err != nil { + return nil, fmt.Errorf("set network layer for TCP checksum: %w", err) + } + return []gopacket.SerializableLayer{tcp}, nil +} + +func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) { + udp := &layers.UDP{ + SrcPort: layers.UDPPort(p.SrcPort), + DstPort: layers.UDPPort(p.DstPort), + } + if err := udp.SetNetworkLayerForChecksum(ip); err != nil { + return nil, fmt.Errorf("set network layer for UDP checksum: %w", err) + } + return []gopacket.SerializableLayer{udp}, nil +} + +func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) { + icmp := &layers.ICMPv4{ + TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode), + } + if p.ICMPType == layers.ICMPv4TypeEchoRequest || p.ICMPType == layers.ICMPv4TypeEchoReply { + icmp.Id = uint16(1) + icmp.Seq = uint16(1) + } + return []gopacket.SerializableLayer{icmp}, nil +} + +func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) { + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + if err := gopacket.SerializeLayers(buf, opts, layers...); err != nil { + return nil, fmt.Errorf("serialize packet: %w", err) + } + return buf.Bytes(), nil +} + +func getIPProtocolNumber(protocol fw.Protocol) int { + switch protocol { + case fw.ProtocolTCP: + return int(layers.IPProtocolTCP) + case fw.ProtocolUDP: + return int(layers.IPProtocolUDP) + case fw.ProtocolICMP: + return int(layers.IPProtocolICMPv4) + default: + return 0 + } +} + +func (m *Manager) TracePacketFromBuilder(builder *PacketBuilder) (*PacketTrace, error) { + packetData, err := builder.Build() + if err != nil { + return nil, fmt.Errorf("build packet: %w", err) + } + + return m.TracePacket(packetData, builder.Direction), nil +} + +func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *PacketTrace { + + d := m.decoders.Get().(*decoder) + defer m.decoders.Put(d) + + trace := &PacketTrace{Direction: direction} + + // Initial packet decoding + if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil { + trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false) + return trace + } + + // Extract base packet info + srcIP, dstIP := m.extractIPs(d) + trace.SourceIP = srcIP + trace.DestinationIP = dstIP + + // Determine protocol and ports + switch d.decoded[1] { + case layers.LayerTypeTCP: + trace.Protocol = "TCP" + trace.SourcePort = uint16(d.tcp.SrcPort) + trace.DestinationPort = uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + trace.Protocol = "UDP" + trace.SourcePort = uint16(d.udp.SrcPort) + trace.DestinationPort = uint16(d.udp.DstPort) + case layers.LayerTypeICMPv4: + trace.Protocol = "ICMP" + } + + trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d", + trace.Protocol, srcIP, trace.SourcePort, dstIP, trace.DestinationPort), true) + + if direction == fw.RuleDirectionOUT { + return m.traceOutbound(packetData, trace) + } + + return m.traceInbound(packetData, trace, d, srcIP, dstIP) +} + +func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace { + if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) { + return trace + } + + if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) { + return trace + } + + if !m.handleRouting(trace) { + return trace + } + + if m.nativeRouter { + return m.handleNativeRouter(trace) + } + + return m.handleRouteACLs(trace, d, srcIP, dstIP) +} + +func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool { + allowed := m.isValidTrackedConnection(d, srcIP, dstIP) + msg := "No existing connection found" + if allowed { + msg = m.buildConntrackStateMessage(d) + trace.AddResult(StageConntrack, msg, true) + trace.AddResult(StageCompleted, "Packet allowed by connection tracking", true) + return true + } + trace.AddResult(StageConntrack, msg, false) + return false +} + +func (m *Manager) buildConntrackStateMessage(d *decoder) string { + msg := "Matched existing connection state" + switch d.decoded[1] { + case layers.LayerTypeTCP: + flags := getTCPFlags(&d.tcp) + msg += fmt.Sprintf(" (TCP Flags: SYN=%v ACK=%v RST=%v FIN=%v)", + flags&conntrack.TCPSyn != 0, + flags&conntrack.TCPAck != 0, + flags&conntrack.TCPRst != 0, + flags&conntrack.TCPFin != 0) + case layers.LayerTypeICMPv4: + msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq) + } + return msg +} + +func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool { + if !m.localForwarding { + trace.AddResult(StageRouting, "Local forwarding disabled", false) + trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false) + return true + } + + trace.AddResult(StageRouting, "Packet destined for local delivery", true) + blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) + + msg := "Allowed by peer ACL rules" + if blocked { + msg = "Blocked by peer ACL rules" + } + trace.AddResult(StagePeerACL, msg, !blocked) + + if m.netstack { + m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked) + } + + trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked) + return true +} + +func (m *Manager) handleRouting(trace *PacketTrace) bool { + if !m.routingEnabled { + trace.AddResult(StageRouting, "Routing disabled", false) + trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false) + return false + } + trace.AddResult(StageRouting, "Routing enabled, checking ACLs", true) + return true +} + +func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace { + trace.AddResult(StageRouteACL, "Using native router, skipping ACL checks", true) + trace.AddResult(StageForwarding, "Forwarding via native router", true) + trace.AddResult(StageCompleted, msgProcessingCompleted, true) + return trace +} + +func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace { + proto := getProtocolFromPacket(d) + srcPort, dstPort := getPortsFromPacket(d) + allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) + + msg := "Allowed by route ACLs" + if !allowed { + msg = "Blocked by route ACLs" + } + trace.AddResult(StageRouteACL, msg, allowed) + + if allowed && m.forwarder != nil { + m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true) + } + + trace.AddResult(StageCompleted, msgProcessingCompleted, allowed) + return trace +} + +func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr string, allowed bool) { + fwdAction := &ForwarderAction{ + Action: action, + RemoteAddr: remoteAddr, + } + trace.AddResultWithForwarder(StageForwarding, + fmt.Sprintf("Forwarding to %s", fwdAction.Action), allowed, fwdAction) +} + +func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace { + // will create or update the connection state + dropped := m.processOutgoingHooks(packetData) + if dropped { + trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false) + } else { + trace.AddResult(StageCompleted, "Packet allowed (outgoing)", true) + } + return trace +} diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 757249b2dd5..889e4cbb1a9 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -1,11 +1,14 @@ package uspfilter import ( + "errors" "fmt" "net" "net/netip" "os" + "slices" "strconv" + "strings" "sync" "github.com/google/gopacket" @@ -14,28 +17,48 @@ import ( log "github.com/sirupsen/logrus" firewall "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" - "github.com/netbirdio/netbird/client/iface" - "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" + nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/statemanager" ) const layerTypeAll = 0 -const EnvDisableConntrack = "NB_DISABLE_CONNTRACK" +const ( + // EnvDisableConntrack disables the stateful filter, replies to outbound traffic won't be allowed. + EnvDisableConntrack = "NB_DISABLE_CONNTRACK" -var ( - errRouteNotSupported = fmt.Errorf("route not supported with userspace firewall") -) + // EnvDisableUserspaceRouting disables userspace routing, to-be-routed packets will be dropped. + EnvDisableUserspaceRouting = "NB_DISABLE_USERSPACE_ROUTING" -// IFaceMapper defines subset methods of interface required for manager -type IFaceMapper interface { - SetFilter(device.PacketFilter) error - Address() iface.WGAddress -} + // EnvForceUserspaceRouter forces userspace routing even if native routing is available. + EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER" + + // EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack + // Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible + EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING" +) // RuleSet is a set of rules grouped by a string key -type RuleSet map[string]Rule +type RuleSet map[string]PeerRule + +type RouteRules []RouteRule + +func (r RouteRules) Sort() { + slices.SortStableFunc(r, func(a, b RouteRule) int { + // Deny rules come first + if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop { + return -1 + } + if a.action != firewall.ActionDrop && b.action == firewall.ActionDrop { + return 1 + } + return strings.Compare(a.id, b.id) + }) +} // Manager userspace firewall manager type Manager struct { @@ -43,17 +66,32 @@ type Manager struct { outgoingRules map[string]RuleSet // incomingRules is used for filtering and hooks incomingRules map[string]RuleSet + routeRules RouteRules wgNetwork *net.IPNet decoders sync.Pool - wgIface IFaceMapper + wgIface common.IFaceMapper nativeFirewall firewall.Manager mutex sync.RWMutex - stateful bool + // indicates whether we forward packets not destined for ourselves + routingEnabled bool + // indicates whether we leave forwarding and filtering to the native firewall + nativeRouter bool + // indicates whether we track outbound connections + stateful bool + // indicates whether wireguards runs in netstack mode + netstack bool + // indicates whether we forward local traffic to the native stack + localForwarding bool + + localipmanager *localIPManager + udpTracker *conntrack.UDPTracker icmpTracker *conntrack.ICMPTracker tcpTracker *conntrack.TCPTracker + forwarder *forwarder.Forwarder + logger *nblog.Logger } // decoder for packages @@ -70,22 +108,32 @@ type decoder struct { } // Create userspace firewall manager constructor -func Create(iface IFaceMapper) (*Manager, error) { - return create(iface) +func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) { + return create(iface, nil, disableServerRoutes) } -func CreateWithNativeFirewall(iface IFaceMapper, nativeFirewall firewall.Manager) (*Manager, error) { - mgr, err := create(iface) +func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) { + if nativeFirewall == nil { + return nil, errors.New("native firewall is nil") + } + + mgr, err := create(iface, nativeFirewall, disableServerRoutes) if err != nil { return nil, err } - mgr.nativeFirewall = nativeFirewall return mgr, nil } -func create(iface IFaceMapper) (*Manager, error) { - disableConntrack, _ := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) +func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) { + disableConntrack, err := strconv.ParseBool(os.Getenv(EnvDisableConntrack)) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvDisableConntrack, err) + } + enableLocalForwarding, err := strconv.ParseBool(os.Getenv(EnvEnableNetstackLocalForwarding)) + if err != nil { + log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err) + } m := &Manager{ decoders: sync.Pool{ @@ -101,52 +149,161 @@ func create(iface IFaceMapper) (*Manager, error) { return d }, }, - outgoingRules: make(map[string]RuleSet), - incomingRules: make(map[string]RuleSet), - wgIface: iface, - stateful: !disableConntrack, + nativeFirewall: nativeFirewall, + outgoingRules: make(map[string]RuleSet), + incomingRules: make(map[string]RuleSet), + wgIface: iface, + localipmanager: newLocalIPManager(), + routingEnabled: false, + stateful: !disableConntrack, + logger: nblog.NewFromLogrus(log.StandardLogger()), + netstack: netstack.IsEnabled(), + // default true for non-netstack, for netstack only if explicitly enabled + localForwarding: !netstack.IsEnabled() || enableLocalForwarding, + } + + if err := m.localipmanager.UpdateLocalIPs(iface); err != nil { + return nil, fmt.Errorf("update local IPs: %w", err) } // Only initialize trackers if stateful mode is enabled if disableConntrack { log.Info("conntrack is disabled") } else { - m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout) - m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout) - m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout) + m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger) + m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger) + m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger) + } + + m.determineRouting(iface, disableServerRoutes) + + if err := m.blockInvalidRouted(iface); err != nil { + log.Errorf("failed to block invalid routed traffic: %v", err) } if err := iface.SetFilter(m); err != nil { - return nil, err + return nil, fmt.Errorf("set filter: %w", err) } return m, nil } +func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error { + if m.forwarder == nil { + return nil + } + wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String()) + if err != nil { + return fmt.Errorf("parse wireguard network: %w", err) + } + log.Debugf("blocking invalid routed traffic for %s", wgPrefix) + + if _, err := m.AddRouteFiltering( + []netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)}, + wgPrefix, + firewall.ProtocolALL, + nil, + nil, + firewall.ActionDrop, + ); err != nil { + return fmt.Errorf("block wg nte : %w", err) + } + + // TODO: Block networks that we're a client of + + return nil +} + +func (m *Manager) determineRouting(iface common.IFaceMapper, disableServerRoutes bool) { + disableUspRouting, _ := strconv.ParseBool(os.Getenv(EnvDisableUserspaceRouting)) + forceUserspaceRouter, _ := strconv.ParseBool(os.Getenv(EnvForceUserspaceRouter)) + + switch { + case disableUspRouting: + m.routingEnabled = false + m.nativeRouter = false + log.Info("userspace routing is disabled") + + case disableServerRoutes: + // if server routes are disabled we will let packets pass to the native stack + m.routingEnabled = true + m.nativeRouter = true + + log.Info("server routes are disabled") + + case forceUserspaceRouter: + m.routingEnabled = true + m.nativeRouter = false + + log.Info("userspace routing is forced") + + case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported(): + // if the OS supports routing natively, then we don't need to filter/route ourselves + // netstack mode won't support native routing as there is no interface + + m.routingEnabled = true + m.nativeRouter = true + + log.Info("native routing is enabled") + + default: + m.routingEnabled = true + m.nativeRouter = false + + log.Info("userspace routing enabled by default") + } + + // netstack needs the forwarder for local traffic + if m.netstack && m.localForwarding || + m.routingEnabled && !m.nativeRouter { + + m.initForwarder(iface) + } +} + +// initForwarder initializes the forwarder, it disables routing on errors +func (m *Manager) initForwarder(iface common.IFaceMapper) { + // Only supported in userspace mode as we need to inject packets back into wireguard directly + intf := iface.GetWGDevice() + if intf == nil { + log.Info("forwarding not supported") + m.routingEnabled = false + return + } + + forwarder, err := forwarder.New(iface, m.logger, m.netstack) + if err != nil { + log.Errorf("failed to create forwarder: %v", err) + m.routingEnabled = false + return + } + + m.forwarder = forwarder +} + func (m *Manager) Init(*statemanager.Manager) error { return nil } func (m *Manager) IsServerRouteSupported() bool { - if m.nativeFirewall == nil { - return false - } else { - return true - } + return m.nativeFirewall != nil || m.routingEnabled && m.forwarder != nil } func (m *Manager) AddNatRule(pair firewall.RouterPair) error { - if m.nativeFirewall == nil { - return errRouteNotSupported + if m.nativeRouter && m.nativeFirewall != nil { + return m.nativeFirewall.AddNatRule(pair) } - return m.nativeFirewall.AddNatRule(pair) + + // userspace routed packets are always SNATed to the inbound direction + // TODO: implement outbound SNAT + return nil } // RemoveNatRule removes a routing firewall rule func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error { - if m.nativeFirewall == nil { - return errRouteNotSupported + if m.nativeRouter && m.nativeFirewall != nil { + return m.nativeFirewall.RemoveNatRule(pair) } - return m.nativeFirewall.RemoveNatRule(pair) + return nil } // AddPeerFiltering rule to the firewall @@ -162,7 +319,7 @@ func (m *Manager) AddPeerFiltering( _ string, comment string, ) ([]firewall.Rule, error) { - r := Rule{ + r := PeerRule{ id: uuid.New().String(), ip: ip, ipLayer: layers.LayerTypeIPv6, @@ -205,18 +362,56 @@ func (m *Manager) AddPeerFiltering( return []firewall.Rule{&r}, nil } -func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) { - if m.nativeFirewall == nil { - return nil, errRouteNotSupported +func (m *Manager) AddRouteFiltering( + sources []netip.Prefix, + destination netip.Prefix, + proto firewall.Protocol, + sPort *firewall.Port, + dPort *firewall.Port, + action firewall.Action, +) (firewall.Rule, error) { + if m.nativeRouter && m.nativeFirewall != nil { + return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) + } + + m.mutex.Lock() + defer m.mutex.Unlock() + + ruleID := uuid.New().String() + rule := RouteRule{ + id: ruleID, + sources: sources, + destination: destination, + proto: proto, + srcPort: sPort, + dstPort: dPort, + action: action, } - return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action) + + m.routeRules = append(m.routeRules, rule) + m.routeRules.Sort() + + return &rule, nil } func (m *Manager) DeleteRouteRule(rule firewall.Rule) error { - if m.nativeFirewall == nil { - return errRouteNotSupported + if m.nativeRouter && m.nativeFirewall != nil { + return m.nativeFirewall.DeleteRouteRule(rule) } - return m.nativeFirewall.DeleteRouteRule(rule) + + m.mutex.Lock() + defer m.mutex.Unlock() + + ruleID := rule.GetRuleID() + idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool { + return r.id == ruleID + }) + if idx < 0 { + return fmt.Errorf("route rule not found: %s", ruleID) + } + + m.routeRules = slices.Delete(m.routeRules, idx, idx+1) + return nil } // DeletePeerRule from the firewall by rule definition @@ -224,7 +419,7 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error { m.mutex.Lock() defer m.mutex.Unlock() - r, ok := rule.(*Rule) + r, ok := rule.(*PeerRule) if !ok { return fmt.Errorf("delete rule: invalid rule type: %T", rule) } @@ -255,10 +450,14 @@ func (m *Manager) DropOutgoing(packetData []byte) bool { // DropIncoming filter incoming packets func (m *Manager) DropIncoming(packetData []byte) bool { - return m.dropFilter(packetData, m.incomingRules) + return m.dropFilter(packetData) +} + +// UpdateLocalIPs updates the list of local IPs +func (m *Manager) UpdateLocalIPs() error { + return m.localipmanager.UpdateLocalIPs(m.wgIface) } -// processOutgoingHooks processes UDP hooks for outgoing packets and tracks TCP/UDP/ICMP func (m *Manager) processOutgoingHooks(packetData []byte) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -279,18 +478,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { return false } - // Always process UDP hooks - if d.decoded[1] == layers.LayerTypeUDP { - // Track UDP state only if enabled - if m.stateful { - m.trackUDPOutbound(d, srcIP, dstIP) - } - return m.checkUDPHooks(d, dstIP, packetData) - } - - // Track other protocols only if stateful mode is enabled + // Track all protocols if stateful mode is enabled if m.stateful { switch d.decoded[1] { + case layers.LayerTypeUDP: + m.trackUDPOutbound(d, srcIP, dstIP) case layers.LayerTypeTCP: m.trackTCPOutbound(d, srcIP, dstIP) case layers.LayerTypeICMPv4: @@ -298,6 +490,11 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool { } } + // Process UDP hooks even if stateful mode is disabled + if d.decoded[1] == layers.LayerTypeUDP { + return m.checkUDPHooks(d, dstIP, packetData) + } + return false } @@ -379,10 +576,9 @@ func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) { } } -// dropFilter implements filtering logic for incoming packets -func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { - // TODO: Disable router if --disable-server-router is set - +// dropFilter implements filtering logic for incoming packets. +// If it returns true, the packet should be dropped. +func (m *Manager) dropFilter(packetData []byte) bool { m.mutex.RLock() defer m.mutex.RUnlock() @@ -390,25 +586,120 @@ func (m *Manager) dropFilter(packetData []byte, rules map[string]RuleSet) bool { defer m.decoders.Put(d) if !m.isValidPacket(d, packetData) { + m.logger.Trace("Invalid packet structure") return true } srcIP, dstIP := m.extractIPs(d) if srcIP == nil { - log.Errorf("unknown layer: %v", d.decoded[0]) + m.logger.Error("Unknown network layer: %v", d.decoded[0]) return true } - if !m.isWireguardTraffic(srcIP, dstIP) { + // For all inbound traffic, first check if it matches a tracked connection. + // This must happen before any other filtering because the packets are statefully tracked. + if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { return false } - // Check connection state only if enabled - if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) { + if m.localipmanager.IsLocalIP(dstIP) { + return m.handleLocalTraffic(d, srcIP, dstIP, packetData) + } + + return m.handleRoutedTraffic(d, srcIP, dstIP, packetData) +} + +// handleLocalTraffic handles local traffic. +// If it returns true, the packet should be dropped. +func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { + if !m.localForwarding { + m.logger.Trace("Dropping local packet (local forwarding disabled): src=%s dst=%s", srcIP, dstIP) + return true + } + + if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) { + m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s", + srcIP, dstIP) + return true + } + + // if running in netstack mode we need to pass this to the forwarder + if m.netstack { + m.handleNetstackLocalTraffic(packetData) + + // don't process this packet further + return true + } + + return false +} +func (m *Manager) handleNetstackLocalTraffic(packetData []byte) { + if m.forwarder == nil { + return + } + + if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { + m.logger.Error("Failed to inject local packet: %v", err) + } +} + +// handleRoutedTraffic handles routed traffic. +// If it returns true, the packet should be dropped. +func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool { + // Drop if routing is disabled + if !m.routingEnabled { + m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s", + srcIP, dstIP) + return true + } + + // Pass to native stack if native router is enabled or forced + if m.nativeRouter { return false } - return m.applyRules(srcIP, packetData, rules, d) + // Get protocol and ports for route ACL check + proto := getProtocolFromPacket(d) + srcPort, dstPort := getPortsFromPacket(d) + + // Check route ACLs + if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) { + m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v", + srcIP, srcPort, dstIP, dstPort, proto) + return true + } + + // Let forwarder handle the packet if it passed route ACLs + if err := m.forwarder.InjectIncomingPacket(packetData); err != nil { + m.logger.Error("Failed to inject incoming packet: %v", err) + } + + // Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture + return true +} + +func getProtocolFromPacket(d *decoder) firewall.Protocol { + switch d.decoded[1] { + case layers.LayerTypeTCP: + return firewall.ProtocolTCP + case layers.LayerTypeUDP: + return firewall.ProtocolUDP + case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6: + return firewall.ProtocolICMP + default: + return firewall.ProtocolALL + } +} + +func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) { + switch d.decoded[1] { + case layers.LayerTypeTCP: + return uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort) + case layers.LayerTypeUDP: + return uint16(d.udp.SrcPort), uint16(d.udp.DstPort) + default: + return 0, 0 + } } func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { @@ -424,10 +715,6 @@ func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool { return true } -func (m *Manager) isWireguardTraffic(srcIP, dstIP net.IP) bool { - return m.wgNetwork.Contains(srcIP) && m.wgNetwork.Contains(dstIP) -} - func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool { switch d.decoded[1] { case layers.LayerTypeTCP: @@ -462,7 +749,22 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool return false } -func (m *Manager) applyRules(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { +// isSpecialICMP returns true if the packet is a special ICMP packet that should be allowed +func (m *Manager) isSpecialICMP(d *decoder) bool { + if d.decoded[1] != layers.LayerTypeICMPv4 { + return false + } + + icmpType := d.icmp4.TypeCode.Type() + return icmpType == layers.ICMPv4TypeDestinationUnreachable || + icmpType == layers.ICMPv4TypeTimeExceeded +} + +func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool { + if m.isSpecialICMP(d) { + return false + } + if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok { return filter } @@ -496,7 +798,7 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool { return false } -func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decoder) (bool, bool) { +func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) { payloadLayer := d.decoded[1] for _, rule := range rules { if rule.matchByIP && !ip.Equal(rule.ip) { @@ -533,6 +835,51 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]Rule, d *decode return false, false } +// routeACLsPass returns treu if the packet is allowed by the route ACLs +func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool { + m.mutex.RLock() + defer m.mutex.RUnlock() + + srcAddr := netip.AddrFrom4([4]byte(srcIP.To4())) + dstAddr := netip.AddrFrom4([4]byte(dstIP.To4())) + + for _, rule := range m.routeRules { + if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) { + return rule.action == firewall.ActionAccept + } + } + return false +} + +func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { + if !rule.destination.Contains(dstAddr) { + return false + } + + sourceMatched := false + for _, src := range rule.sources { + if src.Contains(srcAddr) { + sourceMatched = true + break + } + } + if !sourceMatched { + return false + } + + if rule.proto != firewall.ProtocolALL && rule.proto != proto { + return false + } + + if proto == firewall.ProtocolTCP || proto == firewall.ProtocolUDP { + if !portsMatch(rule.srcPort, srcPort) || !portsMatch(rule.dstPort, dstPort) { + return false + } + } + + return true +} + // SetNetwork of the wireguard interface to which filtering applied func (m *Manager) SetNetwork(network *net.IPNet) { m.wgNetwork = network @@ -544,7 +891,7 @@ func (m *Manager) SetNetwork(network *net.IPNet) { func (m *Manager) AddUDPPacketHook( in bool, ip net.IP, dPort uint16, hook func([]byte) bool, ) string { - r := Rule{ + r := PeerRule{ id: uuid.New().String(), ip: ip, protoLayer: layers.LayerTypeUDP, @@ -561,12 +908,12 @@ func (m *Manager) AddUDPPacketHook( m.mutex.Lock() if in { if _, ok := m.incomingRules[r.ip.String()]; !ok { - m.incomingRules[r.ip.String()] = make(map[string]Rule) + m.incomingRules[r.ip.String()] = make(map[string]PeerRule) } m.incomingRules[r.ip.String()][r.id] = r } else { if _, ok := m.outgoingRules[r.ip.String()]; !ok { - m.outgoingRules[r.ip.String()] = make(map[string]Rule) + m.outgoingRules[r.ip.String()] = make(map[string]PeerRule) } m.outgoingRules[r.ip.String()][r.id] = r } @@ -599,3 +946,10 @@ func (m *Manager) RemovePacketHook(hookID string) error { } return fmt.Errorf("hook with given id not found") } + +// SetLogLevel sets the log level for the firewall manager +func (m *Manager) SetLogLevel(level log.Level) { + if m.logger != nil { + m.logger.SetLevel(nblog.Level(level)) + } +} diff --git a/client/firewall/uspfilter/uspfilter_bench_test.go b/client/firewall/uspfilter/uspfilter_bench_test.go index 46bc4439d83..875bb2425b1 100644 --- a/client/firewall/uspfilter/uspfilter_bench_test.go +++ b/client/firewall/uspfilter/uspfilter_bench_test.go @@ -1,9 +1,12 @@ +//go:build uspbench + package uspfilter import ( "fmt" "math/rand" "net" + "net/netip" "os" "strings" "testing" @@ -155,7 +158,7 @@ func BenchmarkCoreFiltering(b *testing.B) { // Create manager and basic setup manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -185,7 +188,7 @@ func BenchmarkCoreFiltering(b *testing.B) { // Measure inbound packet processing b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(inbound, manager.incomingRules) + manager.dropFilter(inbound) } }) } @@ -200,7 +203,7 @@ func BenchmarkStateScaling(b *testing.B) { b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -228,7 +231,7 @@ func BenchmarkStateScaling(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(testIn, manager.incomingRules) + manager.dropFilter(testIn) } }) } @@ -248,7 +251,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -269,7 +272,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(inbound, manager.incomingRules) + manager.dropFilter(inbound) } }) } @@ -447,7 +450,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { b.Run(sc.name, func(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -472,7 +475,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { manager.processOutgoingHooks(syn) // SYN-ACK synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck)) - manager.dropFilter(synack, manager.incomingRules) + manager.dropFilter(synack) // ACK ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck)) manager.processOutgoingHooks(ack) @@ -481,7 +484,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - manager.dropFilter(inbound, manager.incomingRules) + manager.dropFilter(inbound) } }) } @@ -574,7 +577,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -618,7 +621,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { // SYN-ACK synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) - manager.dropFilter(synack, manager.incomingRules) + manager.dropFilter(synack) // ACK ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], @@ -646,7 +649,7 @@ func BenchmarkLongLivedConnections(b *testing.B) { // First outbound data manager.processOutgoingHooks(outPackets[connIdx]) // Then inbound response - this is what we're actually measuring - manager.dropFilter(inPackets[connIdx], manager.incomingRules) + manager.dropFilter(inPackets[connIdx]) } }) } @@ -665,7 +668,7 @@ func BenchmarkShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -754,17 +757,17 @@ func BenchmarkShortLivedConnections(b *testing.B) { // Connection establishment manager.processOutgoingHooks(p.syn) - manager.dropFilter(p.synAck, manager.incomingRules) + manager.dropFilter(p.synAck) manager.processOutgoingHooks(p.ack) // Data transfer manager.processOutgoingHooks(p.request) - manager.dropFilter(p.response, manager.incomingRules) + manager.dropFilter(p.response) // Connection teardown manager.processOutgoingHooks(p.finClient) - manager.dropFilter(p.ackServer, manager.incomingRules) - manager.dropFilter(p.finServer, manager.incomingRules) + manager.dropFilter(p.ackServer) + manager.dropFilter(p.finServer) manager.processOutgoingHooks(p.ackClient) } }) @@ -784,7 +787,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -825,7 +828,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i], 80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck)) - manager.dropFilter(synack, manager.incomingRules) + manager.dropFilter(synack) ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i], uint16(1024+i), 80, uint16(conntrack.TCPAck)) @@ -852,7 +855,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) { // Simulate bidirectional traffic manager.processOutgoingHooks(outPackets[connIdx]) - manager.dropFilter(inPackets[connIdx], manager.incomingRules) + manager.dropFilter(inPackets[connIdx]) } }) }) @@ -872,7 +875,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { manager, _ := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) defer b.Cleanup(func() { require.NoError(b, manager.Reset(nil)) }) @@ -949,15 +952,15 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) { // Full connection lifecycle manager.processOutgoingHooks(p.syn) - manager.dropFilter(p.synAck, manager.incomingRules) + manager.dropFilter(p.synAck) manager.processOutgoingHooks(p.ack) manager.processOutgoingHooks(p.request) - manager.dropFilter(p.response, manager.incomingRules) + manager.dropFilter(p.response) manager.processOutgoingHooks(p.finClient) - manager.dropFilter(p.ackServer, manager.incomingRules) - manager.dropFilter(p.finServer, manager.incomingRules) + manager.dropFilter(p.ackServer) + manager.dropFilter(p.finServer) manager.processOutgoingHooks(p.ackClient) } }) @@ -996,3 +999,72 @@ func generateTCPPacketWithFlags(b *testing.B, srcIP, dstIP net.IP, srcPort, dstP require.NoError(b, gopacket.SerializeLayers(buf, opts, ipv4, tcp, gopacket.Payload("test"))) return buf.Bytes() } + +func BenchmarkRouteACLs(b *testing.B) { + manager := setupRoutedManager(b, "10.10.0.100/16") + + // Add several route rules to simulate real-world scenario + rules := []struct { + sources []netip.Prefix + dest netip.Prefix + proto fw.Protocol + port *fw.Port + }{ + { + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + port: &fw.Port{Values: []uint16{80, 443}}, + }, + { + sources: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/12"), + netip.MustParsePrefix("10.0.0.0/8"), + }, + dest: netip.MustParsePrefix("0.0.0.0/0"), + proto: fw.ProtocolICMP, + }, + { + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + dest: netip.MustParsePrefix("192.168.0.0/16"), + proto: fw.ProtocolUDP, + port: &fw.Port{Values: []uint16{53}}, + }, + } + + for _, r := range rules { + _, err := manager.AddRouteFiltering( + r.sources, + r.dest, + r.proto, + nil, + r.port, + fw.ActionAccept, + ) + if err != nil { + b.Fatal(err) + } + } + + // Test cases that exercise different matching scenarios + cases := []struct { + srcIP string + dstIP string + proto fw.Protocol + dstPort uint16 + }{ + {"100.10.0.1", "192.168.1.100", fw.ProtocolTCP, 443}, // Match first rule + {"172.16.0.1", "8.8.8.8", fw.ProtocolICMP, 0}, // Match second rule + {"1.1.1.1", "192.168.1.53", fw.ProtocolUDP, 53}, // Match third rule + {"192.168.1.1", "10.0.0.1", fw.ProtocolTCP, 8080}, // No match + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + for _, tc := range cases { + srcIP := net.ParseIP(tc.srcIP) + dstIP := net.ParseIP(tc.dstIP) + manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort) + } + } +} diff --git a/client/firewall/uspfilter/uspfilter_filter_test.go b/client/firewall/uspfilter/uspfilter_filter_test.go new file mode 100644 index 00000000000..d7aebb1aab0 --- /dev/null +++ b/client/firewall/uspfilter/uspfilter_filter_test.go @@ -0,0 +1,1014 @@ +package uspfilter + +import ( + "net" + "net/netip" + "testing" + + "github.com/golang/mock/gomock" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" + "github.com/stretchr/testify/require" + wgdevice "golang.zx2c4.com/wireguard/device" + + fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface" + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/mocks" +) + +func TestPeerACLFiltering(t *testing.T) { + localIP := net.ParseIP("100.10.0.100") + wgNet := &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + } + + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: localIP, + Network: wgNet, + } + }, + } + + manager, err := Create(ifaceMock, false) + require.NoError(t, err) + require.NotNil(t, manager) + + t.Cleanup(func() { + require.NoError(t, manager.Reset(nil)) + }) + + manager.wgNetwork = wgNet + + err = manager.UpdateLocalIPs() + require.NoError(t, err) + + testCases := []struct { + name string + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + ruleIP string + ruleProto fw.Protocol + ruleSrcPort *fw.Port + ruleDstPort *fw.Port + ruleAction fw.Action + shouldBeBlocked bool + }{ + { + name: "Allow TCP traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow UDP traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 53, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolUDP, + ruleDstPort: &fw.Port{Values: []uint16{53}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow ICMP traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolICMP, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolICMP, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow all traffic from WG peer", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolALL, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow traffic from non-WG source", + srcIP: "192.168.1.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "192.168.1.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow all traffic with 0.0.0.0 rule", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + ruleIP: "0.0.0.0", + ruleProto: fw.ProtocolALL, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Allow TCP traffic within port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Block TCP traffic outside port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 7999, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + { + name: "Allow TCP traffic with source port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 32100, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}}, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: false, + }, + { + name: "Block TCP traffic outside source port range", + srcIP: "100.10.0.1", + dstIP: "100.10.0.100", + proto: fw.ProtocolTCP, + srcPort: 31999, + dstPort: 443, + ruleIP: "100.10.0.1", + ruleProto: fw.ProtocolTCP, + ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}}, + ruleDstPort: &fw.Port{Values: []uint16{443}}, + ruleAction: fw.ActionAccept, + shouldBeBlocked: true, + }, + } + + t.Run("Implicit DROP (no rules)", func(t *testing.T) { + packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443) + isDropped := manager.DropIncoming(packet) + require.True(t, isDropped, "Packet should be dropped when no rules exist") + }) + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rules, err := manager.AddPeerFiltering( + net.ParseIP(tc.ruleIP), + tc.ruleProto, + tc.ruleSrcPort, + tc.ruleDstPort, + tc.ruleAction, + "", + tc.name, + ) + require.NoError(t, err) + require.NotEmpty(t, rules) + + t.Cleanup(func() { + for _, rule := range rules { + require.NoError(t, manager.DeletePeerRule(rule)) + } + }) + + packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort) + isDropped := manager.DropIncoming(packet) + require.Equal(t, tc.shouldBeBlocked, isDropped) + }) + } +} + +func createTestPacket(t *testing.T, srcIP, dstIP string, proto fw.Protocol, srcPort, dstPort uint16) []byte { + t.Helper() + + buf := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: true, + FixLengths: true, + } + + ipLayer := &layers.IPv4{ + Version: 4, + TTL: 64, + SrcIP: net.ParseIP(srcIP), + DstIP: net.ParseIP(dstIP), + } + + var err error + switch proto { + case fw.ProtocolTCP: + ipLayer.Protocol = layers.IPProtocolTCP + tcp := &layers.TCP{ + SrcPort: layers.TCPPort(srcPort), + DstPort: layers.TCPPort(dstPort), + } + err = tcp.SetNetworkLayerForChecksum(ipLayer) + require.NoError(t, err) + err = gopacket.SerializeLayers(buf, opts, ipLayer, tcp) + + case fw.ProtocolUDP: + ipLayer.Protocol = layers.IPProtocolUDP + udp := &layers.UDP{ + SrcPort: layers.UDPPort(srcPort), + DstPort: layers.UDPPort(dstPort), + } + err = udp.SetNetworkLayerForChecksum(ipLayer) + require.NoError(t, err) + err = gopacket.SerializeLayers(buf, opts, ipLayer, udp) + + case fw.ProtocolICMP: + ipLayer.Protocol = layers.IPProtocolICMPv4 + icmp := &layers.ICMPv4{ + TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0), + } + err = gopacket.SerializeLayers(buf, opts, ipLayer, icmp) + + default: + err = gopacket.SerializeLayers(buf, opts, ipLayer) + } + + require.NoError(t, err) + return buf.Bytes() +} + +func setupRoutedManager(tb testing.TB, network string) *Manager { + tb.Helper() + + ctrl := gomock.NewController(tb) + dev := mocks.NewMockDevice(ctrl) + dev.EXPECT().MTU().Return(1500, nil).AnyTimes() + + localIP, wgNet, err := net.ParseCIDR(network) + require.NoError(tb, err) + + ifaceMock := &IFaceMock{ + SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: localIP, + Network: wgNet, + } + }, + GetDeviceFunc: func() *device.FilteredDevice { + return &device.FilteredDevice{Device: dev} + }, + GetWGDeviceFunc: func() *wgdevice.Device { + return &wgdevice.Device{} + }, + } + + manager, err := Create(ifaceMock, false) + require.NoError(tb, err) + require.NotNil(tb, manager) + require.True(tb, manager.routingEnabled) + require.False(tb, manager.nativeRouter) + + tb.Cleanup(func() { + require.NoError(tb, manager.Reset(nil)) + }) + + return manager +} + +func TestRouteACLFiltering(t *testing.T) { + manager := setupRoutedManager(t, "10.10.0.100/16") + + type rule struct { + sources []netip.Prefix + dest netip.Prefix + proto fw.Protocol + srcPort *fw.Port + dstPort *fw.Port + action fw.Action + } + + testCases := []struct { + name string + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + rule rule + shouldPass bool + }{ + { + name: "Allow TCP with specific source and destination", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{443}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow any source to specific destination", + srcIP: "172.16.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{443}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow any source to any destination", + srcIP: "172.16.0.1", + dstIP: "203.0.113.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + dest: netip.MustParsePrefix("0.0.0.0/0"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{443}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow UDP DNS traffic", + srcIP: "100.10.0.1", + dstIP: "192.168.1.53", + proto: fw.ProtocolUDP, + srcPort: 54321, + dstPort: 53, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolUDP, + dstPort: &fw.Port{Values: []uint16{53}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow ICMP to any destination", + srcIP: "100.10.0.1", + dstIP: "8.8.8.8", + proto: fw.ProtocolICMP, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("0.0.0.0/0"), + proto: fw.ProtocolICMP, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow all protocols but specific port", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolALL, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Implicit deny - wrong destination port", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "Implicit deny - wrong protocol", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "Implicit deny - wrong source network", + srcIP: "172.16.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "Source port match", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + srcPort: &fw.Port{Values: []uint16{12345}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Multiple source networks", + srcIP: "172.16.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{ + netip.MustParsePrefix("100.10.0.0/16"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow ALL protocol without ports", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolICMP, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolALL, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow ALL protocol with specific ports", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolALL, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Multiple source networks with mismatched protocol", + srcIP: "172.16.0.1", + dstIP: "192.168.1.100", + // Should not match TCP rule + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{ + netip.MustParsePrefix("100.10.0.0/16"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "Allow multiple destination ports", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80, 8080, 443}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow multiple source ports", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Allow ALL protocol with both src and dst ports", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolALL, + srcPort: &fw.Port{Values: []uint16{12345}}, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Port Range - Within Range", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{ + IsRange: true, + Values: []uint16{8000, 8100}, + }, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Port Range - Outside Range", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 7999, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{ + IsRange: true, + Values: []uint16{8000, 8100}, + }, + action: fw.ActionAccept, + }, + shouldPass: false, + }, + { + name: "Source Port Range - Within Range", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 32100, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + srcPort: &fw.Port{ + IsRange: true, + Values: []uint16{32000, 33000}, + }, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Mixed Port Specification - Range and Single", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 32100, + dstPort: 443, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + srcPort: &fw.Port{ + IsRange: true, + Values: []uint16{32000, 33000}, + }, + dstPort: &fw.Port{ + Values: []uint16{443}, + }, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Edge Case - Port at Range Boundary", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8100, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{ + IsRange: true, + Values: []uint16{8000, 8100}, + }, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "UDP Port Range", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolUDP, + srcPort: 12345, + dstPort: 5060, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolUDP, + dstPort: &fw.Port{ + IsRange: true, + Values: []uint16{5060, 5070}, + }, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "ALL Protocol with Port Range", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolALL, + dstPort: &fw.Port{ + IsRange: true, + Values: []uint16{8000, 8100}, + }, + action: fw.ActionAccept, + }, + shouldPass: true, + }, + { + name: "Drop TCP traffic to specific destination", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{443}}, + action: fw.ActionDrop, + }, + shouldPass: false, + }, + { + name: "Drop all traffic to specific destination", + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolALL, + action: fw.ActionDrop, + }, + shouldPass: false, + }, + { + name: "Drop traffic from multiple source networks", + srcIP: "172.16.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + rule: rule{ + sources: []netip.Prefix{ + netip.MustParsePrefix("100.10.0.0/16"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionDrop, + }, + shouldPass: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + rule, err := manager.AddRouteFiltering( + tc.rule.sources, + tc.rule.dest, + tc.rule.proto, + tc.rule.srcPort, + tc.rule.dstPort, + tc.rule.action, + ) + require.NoError(t, err) + require.NotNil(t, rule) + + t.Cleanup(func() { + require.NoError(t, manager.DeleteRouteRule(rule)) + }) + + srcIP := net.ParseIP(tc.srcIP) + dstIP := net.ParseIP(tc.dstIP) + + // testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed + // to the forwarder + isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort) + require.Equal(t, tc.shouldPass, isAllowed) + }) + } +} + +func TestRouteACLOrder(t *testing.T) { + manager := setupRoutedManager(t, "10.10.0.100/16") + + type testCase struct { + name string + rules []struct { + sources []netip.Prefix + dest netip.Prefix + proto fw.Protocol + srcPort *fw.Port + dstPort *fw.Port + action fw.Action + } + packets []struct { + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + shouldPass bool + } + } + + testCases := []testCase{ + { + name: "Drop rules take precedence over accept", + rules: []struct { + sources []netip.Prefix + dest netip.Prefix + proto fw.Protocol + srcPort *fw.Port + dstPort *fw.Port + action fw.Action + }{ + { + // Accept rule added first + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80, 443}}, + action: fw.ActionAccept, + }, + { + // Drop rule added second but should be evaluated first + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{443}}, + action: fw.ActionDrop, + }, + }, + packets: []struct { + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + shouldPass bool + }{ + { + // Should be dropped by the drop rule + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + shouldPass: false, + }, + { + // Should be allowed by the accept rule (port 80 not in drop rule) + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + shouldPass: true, + }, + }, + }, + { + name: "Multiple drop rules take precedence", + rules: []struct { + sources []netip.Prefix + dest netip.Prefix + proto fw.Protocol + srcPort *fw.Port + dstPort *fw.Port + action fw.Action + }{ + { + // Accept all + sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}, + dest: netip.MustParsePrefix("0.0.0.0/0"), + proto: fw.ProtocolALL, + action: fw.ActionAccept, + }, + { + // Drop specific port + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{443}}, + action: fw.ActionDrop, + }, + { + // Drop different port + sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")}, + dest: netip.MustParsePrefix("192.168.1.0/24"), + proto: fw.ProtocolTCP, + dstPort: &fw.Port{Values: []uint16{80}}, + action: fw.ActionDrop, + }, + }, + packets: []struct { + srcIP string + dstIP string + proto fw.Protocol + srcPort uint16 + dstPort uint16 + shouldPass bool + }{ + { + // Should be dropped by first drop rule + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 443, + shouldPass: false, + }, + { + // Should be dropped by second drop rule + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 80, + shouldPass: false, + }, + { + // Should be allowed by the accept rule (different port) + srcIP: "100.10.0.1", + dstIP: "192.168.1.100", + proto: fw.ProtocolTCP, + srcPort: 12345, + dstPort: 8080, + shouldPass: true, + }, + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + var rules []fw.Rule + for _, r := range tc.rules { + rule, err := manager.AddRouteFiltering( + r.sources, + r.dest, + r.proto, + r.srcPort, + r.dstPort, + r.action, + ) + require.NoError(t, err) + require.NotNil(t, rule) + rules = append(rules, rule) + } + + t.Cleanup(func() { + for _, rule := range rules { + require.NoError(t, manager.DeleteRouteRule(rule)) + } + }) + + for i, p := range tc.packets { + srcIP := net.ParseIP(p.srcIP) + dstIP := net.ParseIP(p.dstIP) + + isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort) + require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i) + } + }) + } +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 9d795de691f..089bf8f5531 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -9,17 +9,38 @@ import ( "github.com/google/gopacket" "github.com/google/gopacket/layers" + "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" + wgdevice "golang.zx2c4.com/wireguard/device" fw "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack" + "github.com/netbirdio/netbird/client/firewall/uspfilter/log" "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) +var logger = log.NewFromLogrus(logrus.StandardLogger()) + type IFaceMock struct { - SetFilterFunc func(device.PacketFilter) error - AddressFunc func() iface.WGAddress + SetFilterFunc func(device.PacketFilter) error + AddressFunc func() iface.WGAddress + GetWGDeviceFunc func() *wgdevice.Device + GetDeviceFunc func() *device.FilteredDevice +} + +func (i *IFaceMock) GetWGDevice() *wgdevice.Device { + if i.GetWGDeviceFunc == nil { + return nil + } + return i.GetWGDeviceFunc() +} + +func (i *IFaceMock) GetDevice() *device.FilteredDevice { + if i.GetDeviceFunc == nil { + return nil + } + return i.GetDeviceFunc() } func (i *IFaceMock) SetFilter(iface device.PacketFilter) error { @@ -41,7 +62,7 @@ func TestManagerCreate(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -61,7 +82,7 @@ func TestManagerAddPeerFiltering(t *testing.T) { }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -95,7 +116,7 @@ func TestManagerDeleteRule(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -166,12 +187,12 @@ func TestAddUDPPacketHook(t *testing.T) { t.Run(tt.name, func(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) require.NoError(t, err) manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook) - var addedRule Rule + var addedRule PeerRule if tt.in { if len(manager.incomingRules[tt.ip.String()]) != 1 { t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules)) @@ -215,7 +236,7 @@ func TestManagerReset(t *testing.T) { SetFilterFunc: func(device.PacketFilter) error { return nil }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -247,9 +268,18 @@ func TestManagerReset(t *testing.T) { func TestNotMatchByIP(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, + AddressFunc: func() iface.WGAddress { + return iface.WGAddress{ + IP: net.ParseIP("100.10.0.100"), + Network: &net.IPNet{ + IP: net.ParseIP("100.10.0.0"), + Mask: net.CIDRMask(16, 32), + }, + } + }, } - m, err := Create(ifaceMock) + m, err := Create(ifaceMock, false) if err != nil { t.Errorf("failed to create Manager: %v", err) return @@ -298,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) { return } - if m.dropFilter(buf.Bytes(), m.incomingRules) { + if m.dropFilter(buf.Bytes()) { t.Errorf("expected packet to be accepted") return } @@ -317,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) { } // creating manager instance - manager, err := Create(iface) + manager, err := Create(iface, false) if err != nil { t.Fatalf("Failed to create Manager: %s", err) } @@ -363,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) { func TestProcessOutgoingHooks(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) require.NoError(t, err) manager.wgNetwork = &net.IPNet{ @@ -371,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) { Mask: net.CIDRMask(16, 32), } manager.udpTracker.Close() - manager.udpTracker = conntrack.NewUDPTracker(100 * time.Millisecond) + manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger) defer func() { require.NoError(t, manager.Reset(nil)) }() @@ -449,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { ifaceMock := &IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, } - manager, err := Create(ifaceMock) + manager, err := Create(ifaceMock, false) require.NoError(t, err) time.Sleep(time.Second) @@ -476,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) { func TestStatefulFirewall_UDPTracking(t *testing.T) { manager, err := Create(&IFaceMock{ SetFilterFunc: func(device.PacketFilter) error { return nil }, - }) + }, false) require.NoError(t, err) manager.wgNetwork = &net.IPNet{ @@ -485,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { } manager.udpTracker.Close() // Close the existing tracker - manager.udpTracker = conntrack.NewUDPTracker(200 * time.Millisecond) + manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger) manager.decoders = sync.Pool{ New: func() any { d := &decoder{ @@ -606,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { for _, cp := range checkPoints { time.Sleep(cp.sleep) - drop = manager.dropFilter(inboundBuf.Bytes(), manager.incomingRules) + drop = manager.dropFilter(inboundBuf.Bytes()) require.Equal(t, cp.shouldAllow, !drop, cp.description) // If the connection should still be valid, verify it exists @@ -677,7 +707,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { require.NoError(t, err) // Verify the invalid packet is dropped - drop = manager.dropFilter(testBuf.Bytes(), manager.incomingRules) + drop = manager.dropFilter(testBuf.Bytes()) require.True(t, drop, tc.description) }) } diff --git a/client/iface/device.go b/client/iface/device.go index 0d4e6914554..2a170adfb41 100644 --- a/client/iface/device.go +++ b/client/iface/device.go @@ -3,6 +3,8 @@ package iface import ( + wgdevice "golang.zx2c4.com/wireguard/device" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" ) @@ -15,4 +17,5 @@ type WGTunDevice interface { DeviceName() string Close() error FilteredDevice() *device.FilteredDevice + Device() *wgdevice.Device } diff --git a/client/iface/device/device_darwin.go b/client/iface/device/device_darwin.go index b5a128bc1cc..fe7ed175207 100644 --- a/client/iface/device/device_darwin.go +++ b/client/iface/device/device_darwin.go @@ -117,6 +117,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } +// Device returns the wireguard device +func (t *TunDevice) Device() *device.Device { + return t.device +} + // assignAddr Adds IP address to the tunnel interface and network route based on the range provided func (t *TunDevice) assignAddr() error { cmd := exec.Command("ifconfig", t.name, "inet", t.address.IP.String(), t.address.IP.String()) diff --git a/client/iface/device/device_kernel_unix.go b/client/iface/device/device_kernel_unix.go index 0dfed4d9071..3314b576b25 100644 --- a/client/iface/device/device_kernel_unix.go +++ b/client/iface/device/device_kernel_unix.go @@ -9,6 +9,7 @@ import ( "github.com/pion/transport/v3" log "github.com/sirupsen/logrus" + "golang.zx2c4.com/wireguard/device" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" @@ -151,6 +152,11 @@ func (t *TunKernelDevice) DeviceName() string { return t.name } +// Device returns the wireguard device, not applicable for kernel devices +func (t *TunKernelDevice) Device() *device.Device { + return nil +} + func (t *TunKernelDevice) FilteredDevice() *FilteredDevice { return nil } diff --git a/client/iface/device/device_netstack.go b/client/iface/device/device_netstack.go index f5d39e9e074..c7d297187ed 100644 --- a/client/iface/device/device_netstack.go +++ b/client/iface/device/device_netstack.go @@ -117,3 +117,8 @@ func (t *TunNetstackDevice) DeviceName() string { func (t *TunNetstackDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } + +// Device returns the wireguard device +func (t *TunNetstackDevice) Device() *device.Device { + return t.device +} diff --git a/client/iface/device/device_usp_unix.go b/client/iface/device/device_usp_unix.go index 3562f312ded..4ac87aecbb8 100644 --- a/client/iface/device/device_usp_unix.go +++ b/client/iface/device/device_usp_unix.go @@ -124,6 +124,11 @@ func (t *USPDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } +// Device returns the wireguard device +func (t *USPDevice) Device() *device.Device { + return t.device +} + // assignAddr Adds IP address to the tunnel interface func (t *USPDevice) assignAddr() error { link := newWGLink(t.name) diff --git a/client/iface/device/device_windows.go b/client/iface/device/device_windows.go index 86968d06d7e..e603d7696f9 100644 --- a/client/iface/device/device_windows.go +++ b/client/iface/device/device_windows.go @@ -150,6 +150,11 @@ func (t *TunDevice) FilteredDevice() *FilteredDevice { return t.filteredDevice } +// Device returns the wireguard device +func (t *TunDevice) Device() *device.Device { + return t.device +} + func (t *TunDevice) GetInterfaceGUIDString() (string, error) { if t.nativeTunDevice == nil { return "", fmt.Errorf("interface has not been initialized yet") diff --git a/client/iface/device_android.go b/client/iface/device_android.go index 3d15080fff4..028f6fa7d78 100644 --- a/client/iface/device_android.go +++ b/client/iface/device_android.go @@ -1,6 +1,8 @@ package iface import ( + wgdevice "golang.zx2c4.com/wireguard/device" + "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/device" ) @@ -13,4 +15,5 @@ type WGTunDevice interface { DeviceName() string Close() error FilteredDevice() *device.FilteredDevice + Device() *wgdevice.Device } diff --git a/client/iface/iface.go b/client/iface/iface.go index 1fb9c269179..64219975f5d 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -11,6 +11,8 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + wgdevice "golang.zx2c4.com/wireguard/device" + "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface/bind" "github.com/netbirdio/netbird/client/iface/configurer" @@ -203,6 +205,11 @@ func (w *WGIface) GetDevice() *device.FilteredDevice { return w.tun.FilteredDevice() } +// GetWGDevice returns the WireGuard device +func (w *WGIface) GetWGDevice() *wgdevice.Device { + return w.tun.Device() +} + // GetStats returns the last handshake time, rx and tx bytes for the given peer func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) { return w.configurer.GetStats(peerKey) diff --git a/client/iface/iface_moc.go b/client/iface/iface_moc.go index d91a7224ff2..5f57bc82159 100644 --- a/client/iface/iface_moc.go +++ b/client/iface/iface_moc.go @@ -4,6 +4,7 @@ import ( "net" "time" + wgdevice "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" @@ -29,6 +30,7 @@ type MockWGIface struct { SetFilterFunc func(filter device.PacketFilter) error GetFilterFunc func() device.PacketFilter GetDeviceFunc func() *device.FilteredDevice + GetWGDeviceFunc func() *wgdevice.Device GetStatsFunc func(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDStringFunc func() (string, error) GetProxyFunc func() wgproxy.Proxy @@ -102,11 +104,14 @@ func (m *MockWGIface) GetDevice() *device.FilteredDevice { return m.GetDeviceFunc() } +func (m *MockWGIface) GetWGDevice() *wgdevice.Device { + return m.GetWGDeviceFunc() +} + func (m *MockWGIface) GetStats(peerKey string) (configurer.WGStats, error) { return m.GetStatsFunc(peerKey) } func (m *MockWGIface) GetProxy() wgproxy.Proxy { - //TODO implement me - panic("implement me") + return m.GetProxyFunc() } diff --git a/client/iface/iwginterface.go b/client/iface/iwginterface.go index f5ab2953905..472ab45f9d8 100644 --- a/client/iface/iwginterface.go +++ b/client/iface/iwginterface.go @@ -6,6 +6,7 @@ import ( "net" "time" + wgdevice "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" @@ -32,5 +33,6 @@ type IWGIface interface { SetFilter(filter device.PacketFilter) error GetFilter() device.PacketFilter GetDevice() *device.FilteredDevice + GetWGDevice() *wgdevice.Device GetStats(peerKey string) (configurer.WGStats, error) } diff --git a/client/iface/iwginterface_windows.go b/client/iface/iwginterface_windows.go index 96eec52a502..c9183cafdce 100644 --- a/client/iface/iwginterface_windows.go +++ b/client/iface/iwginterface_windows.go @@ -4,6 +4,7 @@ import ( "net" "time" + wgdevice "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/netbirdio/netbird/client/iface/bind" @@ -30,6 +31,7 @@ type IWGIface interface { SetFilter(filter device.PacketFilter) error GetFilter() device.PacketFilter GetDevice() *device.FilteredDevice + GetWGDevice() *wgdevice.Device GetStats(peerKey string) (configurer.WGStats, error) GetInterfaceGUIDString() (string, error) } diff --git a/client/internal/acl/manager_test.go b/client/internal/acl/manager_test.go index 6049b4f48e2..217dbce9f45 100644 --- a/client/internal/acl/manager_test.go +++ b/client/internal/acl/manager_test.go @@ -49,9 +49,10 @@ func TestDefaultManager(t *testing.T) { IP: ip, Network: network, }).AnyTimes() + ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(ifaceMock, nil) + fw, err := firewall.NewFirewall(ifaceMock, nil, false) if err != nil { t.Errorf("create firewall: %v", err) return @@ -342,9 +343,10 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) { IP: ip, Network: network, }).AnyTimes() + ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes() // we receive one rule from the management so for testing purposes ignore it - fw, err := firewall.NewFirewall(ifaceMock, nil) + fw, err := firewall.NewFirewall(ifaceMock, nil, false) if err != nil { t.Errorf("create firewall: %v", err) return diff --git a/client/internal/acl/mocks/iface_mapper.go b/client/internal/acl/mocks/iface_mapper.go index 3ed12b6dd76..08aa4fd5a01 100644 --- a/client/internal/acl/mocks/iface_mapper.go +++ b/client/internal/acl/mocks/iface_mapper.go @@ -8,6 +8,8 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" + wgdevice "golang.zx2c4.com/wireguard/device" + iface "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" ) @@ -90,3 +92,31 @@ func (mr *MockIFaceMapperMockRecorder) SetFilter(arg0 interface{}) *gomock.Call mr.mock.ctrl.T.Helper() return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetFilter", reflect.TypeOf((*MockIFaceMapper)(nil).SetFilter), arg0) } + +// GetDevice mocks base method. +func (m *MockIFaceMapper) GetDevice() *device.FilteredDevice { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDevice") + ret0, _ := ret[0].(*device.FilteredDevice) + return ret0 +} + +// GetDevice indicates an expected call of GetDevice. +func (mr *MockIFaceMapperMockRecorder) GetDevice() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetDevice)) +} + +// GetWGDevice mocks base method. +func (m *MockIFaceMapper) GetWGDevice() *wgdevice.Device { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetWGDevice") + ret0, _ := ret[0].(*wgdevice.Device) + return ret0 +} + +// GetWGDevice indicates an expected call of GetWGDevice. +func (mr *MockIFaceMapperMockRecorder) GetWGDevice() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetWGDevice", reflect.TypeOf((*MockIFaceMapper)(nil).GetWGDevice)) +} diff --git a/client/internal/dns/server_test.go b/client/internal/dns/server_test.go index c166820c457..14ff1bb713e 100644 --- a/client/internal/dns/server_test.go +++ b/client/internal/dns/server_test.go @@ -849,7 +849,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) { return nil, err } - pf, err := uspfilter.Create(wgIface) + pf, err := uspfilter.Create(wgIface, false) if err != nil { t.Fatalf("failed to create uspfilter: %v", err) return nil, err diff --git a/client/internal/engine.go b/client/internal/engine.go index 4f69adfa6d7..5170bcf2eb9 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -43,13 +43,13 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" + "github.com/netbirdio/netbird/management/domain" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" nbssh "github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/system" nbdns "github.com/netbirdio/netbird/dns" mgm "github.com/netbirdio/netbird/management/client" - "github.com/netbirdio/netbird/management/domain" mgmProto "github.com/netbirdio/netbird/management/proto" auth "github.com/netbirdio/netbird/relay/auth/hmac" relayClient "github.com/netbirdio/netbird/relay/client" @@ -194,6 +194,10 @@ type Peer struct { WgAllowedIps string } +type localIpUpdater interface { + UpdateLocalIPs() error +} + // NewEngine creates a new Connection Engine with probes attached func NewEngine( clientCtx context.Context, @@ -434,7 +438,7 @@ func (e *Engine) createFirewall() error { } var err error - e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager) + e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes) if err != nil || e.firewall == nil { log.Errorf("failed creating firewall manager: %s", err) return nil @@ -884,6 +888,14 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error { e.acl.ApplyFiltering(networkMap) } + if e.firewall != nil { + if localipfw, ok := e.firewall.(localIpUpdater); ok { + if err := localipfw.UpdateLocalIPs(); err != nil { + log.Errorf("failed to update local IPs: %v", err) + } + } + } + // DNS forwarder dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap) dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes()) @@ -1447,6 +1459,11 @@ func (e *Engine) GetRouteManager() routemanager.Manager { return e.routeManager } +// GetFirewallManager returns the firewall manager +func (e *Engine) GetFirewallManager() manager.Manager { + return e.firewall +} + func findIPFromInterfaceName(ifaceName string) (net.IP, error) { iface, err := net.InterfaceByName(ifaceName) if err != nil { @@ -1658,6 +1675,14 @@ func (e *Engine) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) { return nm, nil } +// GetWgAddr returns the wireguard address +func (e *Engine) GetWgAddr() net.IP { + if e.wgInterface == nil { + return nil + } + return e.wgInterface.Address().IP +} + // updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag func (e *Engine) updateDNSForwarder(enabled bool, domains []string) { if !enabled { diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 6f73fb166c6..34bd67893d3 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -422,11 +422,6 @@ func (m *DefaultManager) classifyRoutes(newRoutes []*route.Route) (map[route.ID] haID := newRoute.GetHAUniqueID() if newRoute.Peer == m.pubKey { ownNetworkIDs[haID] = true - // only linux is supported for now - if runtime.GOOS != "linux" { - log.Warnf("received a route to manage, but agent doesn't support router mode on %s OS", runtime.GOOS) - continue - } newServerRoutesMap[newRoute.ID] = newRoute } } diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 413f94a54a5..c1c5ad78def 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -2571,6 +2571,330 @@ func (*SetNetworkMapPersistenceResponse) Descriptor() ([]byte, []int) { return file_daemon_proto_rawDescGZIP(), []int{39} } +type TCPFlags struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Syn bool `protobuf:"varint,1,opt,name=syn,proto3" json:"syn,omitempty"` + Ack bool `protobuf:"varint,2,opt,name=ack,proto3" json:"ack,omitempty"` + Fin bool `protobuf:"varint,3,opt,name=fin,proto3" json:"fin,omitempty"` + Rst bool `protobuf:"varint,4,opt,name=rst,proto3" json:"rst,omitempty"` + Psh bool `protobuf:"varint,5,opt,name=psh,proto3" json:"psh,omitempty"` + Urg bool `protobuf:"varint,6,opt,name=urg,proto3" json:"urg,omitempty"` +} + +func (x *TCPFlags) Reset() { + *x = TCPFlags{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[40] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TCPFlags) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TCPFlags) ProtoMessage() {} + +func (x *TCPFlags) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[40] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TCPFlags.ProtoReflect.Descriptor instead. +func (*TCPFlags) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{40} +} + +func (x *TCPFlags) GetSyn() bool { + if x != nil { + return x.Syn + } + return false +} + +func (x *TCPFlags) GetAck() bool { + if x != nil { + return x.Ack + } + return false +} + +func (x *TCPFlags) GetFin() bool { + if x != nil { + return x.Fin + } + return false +} + +func (x *TCPFlags) GetRst() bool { + if x != nil { + return x.Rst + } + return false +} + +func (x *TCPFlags) GetPsh() bool { + if x != nil { + return x.Psh + } + return false +} + +func (x *TCPFlags) GetUrg() bool { + if x != nil { + return x.Urg + } + return false +} + +type TracePacketRequest struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + SourceIp string `protobuf:"bytes,1,opt,name=source_ip,json=sourceIp,proto3" json:"source_ip,omitempty"` + DestinationIp string `protobuf:"bytes,2,opt,name=destination_ip,json=destinationIp,proto3" json:"destination_ip,omitempty"` + Protocol string `protobuf:"bytes,3,opt,name=protocol,proto3" json:"protocol,omitempty"` + SourcePort uint32 `protobuf:"varint,4,opt,name=source_port,json=sourcePort,proto3" json:"source_port,omitempty"` + DestinationPort uint32 `protobuf:"varint,5,opt,name=destination_port,json=destinationPort,proto3" json:"destination_port,omitempty"` + Direction string `protobuf:"bytes,6,opt,name=direction,proto3" json:"direction,omitempty"` + TcpFlags *TCPFlags `protobuf:"bytes,7,opt,name=tcp_flags,json=tcpFlags,proto3,oneof" json:"tcp_flags,omitempty"` + IcmpType *uint32 `protobuf:"varint,8,opt,name=icmp_type,json=icmpType,proto3,oneof" json:"icmp_type,omitempty"` + IcmpCode *uint32 `protobuf:"varint,9,opt,name=icmp_code,json=icmpCode,proto3,oneof" json:"icmp_code,omitempty"` +} + +func (x *TracePacketRequest) Reset() { + *x = TracePacketRequest{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[41] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TracePacketRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TracePacketRequest) ProtoMessage() {} + +func (x *TracePacketRequest) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[41] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TracePacketRequest.ProtoReflect.Descriptor instead. +func (*TracePacketRequest) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{41} +} + +func (x *TracePacketRequest) GetSourceIp() string { + if x != nil { + return x.SourceIp + } + return "" +} + +func (x *TracePacketRequest) GetDestinationIp() string { + if x != nil { + return x.DestinationIp + } + return "" +} + +func (x *TracePacketRequest) GetProtocol() string { + if x != nil { + return x.Protocol + } + return "" +} + +func (x *TracePacketRequest) GetSourcePort() uint32 { + if x != nil { + return x.SourcePort + } + return 0 +} + +func (x *TracePacketRequest) GetDestinationPort() uint32 { + if x != nil { + return x.DestinationPort + } + return 0 +} + +func (x *TracePacketRequest) GetDirection() string { + if x != nil { + return x.Direction + } + return "" +} + +func (x *TracePacketRequest) GetTcpFlags() *TCPFlags { + if x != nil { + return x.TcpFlags + } + return nil +} + +func (x *TracePacketRequest) GetIcmpType() uint32 { + if x != nil && x.IcmpType != nil { + return *x.IcmpType + } + return 0 +} + +func (x *TracePacketRequest) GetIcmpCode() uint32 { + if x != nil && x.IcmpCode != nil { + return *x.IcmpCode + } + return 0 +} + +type TraceStage struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` + Allowed bool `protobuf:"varint,3,opt,name=allowed,proto3" json:"allowed,omitempty"` + ForwardingDetails *string `protobuf:"bytes,4,opt,name=forwarding_details,json=forwardingDetails,proto3,oneof" json:"forwarding_details,omitempty"` +} + +func (x *TraceStage) Reset() { + *x = TraceStage{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[42] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TraceStage) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TraceStage) ProtoMessage() {} + +func (x *TraceStage) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[42] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TraceStage.ProtoReflect.Descriptor instead. +func (*TraceStage) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{42} +} + +func (x *TraceStage) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *TraceStage) GetMessage() string { + if x != nil { + return x.Message + } + return "" +} + +func (x *TraceStage) GetAllowed() bool { + if x != nil { + return x.Allowed + } + return false +} + +func (x *TraceStage) GetForwardingDetails() string { + if x != nil && x.ForwardingDetails != nil { + return *x.ForwardingDetails + } + return "" +} + +type TracePacketResponse struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Stages []*TraceStage `protobuf:"bytes,1,rep,name=stages,proto3" json:"stages,omitempty"` + FinalDisposition bool `protobuf:"varint,2,opt,name=final_disposition,json=finalDisposition,proto3" json:"final_disposition,omitempty"` +} + +func (x *TracePacketResponse) Reset() { + *x = TracePacketResponse{} + if protoimpl.UnsafeEnabled { + mi := &file_daemon_proto_msgTypes[43] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *TracePacketResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TracePacketResponse) ProtoMessage() {} + +func (x *TracePacketResponse) ProtoReflect() protoreflect.Message { + mi := &file_daemon_proto_msgTypes[43] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TracePacketResponse.ProtoReflect.Descriptor instead. +func (*TracePacketResponse) Descriptor() ([]byte, []int) { + return file_daemon_proto_rawDescGZIP(), []int{43} +} + +func (x *TracePacketResponse) GetStages() []*TraceStage { + if x != nil { + return x.Stages + } + return nil +} + +func (x *TracePacketResponse) GetFinalDisposition() bool { + if x != nil { + return x.FinalDisposition + } + return false +} + var File_daemon_proto protoreflect.FileDescriptor var file_daemon_proto_rawDesc = []byte{ @@ -2920,87 +3244,141 @@ var file_daemon_proto_rawDesc = []byte{ 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x22, 0x22, 0x0a, 0x20, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, - 0x0b, 0x0a, 0x07, 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, - 0x50, 0x41, 0x4e, 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, - 0x10, 0x02, 0x12, 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, - 0x04, 0x57, 0x41, 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, - 0x05, 0x12, 0x09, 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, - 0x54, 0x52, 0x41, 0x43, 0x45, 0x10, 0x07, 0x32, 0x93, 0x09, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, - 0x69, 0x6e, 0x12, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, - 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, - 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, + 0x6e, 0x73, 0x65, 0x22, 0x76, 0x0a, 0x08, 0x54, 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x12, + 0x10, 0x0a, 0x03, 0x73, 0x79, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x73, 0x79, + 0x6e, 0x12, 0x10, 0x0a, 0x03, 0x61, 0x63, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, + 0x61, 0x63, 0x6b, 0x12, 0x10, 0x0a, 0x03, 0x66, 0x69, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, + 0x52, 0x03, 0x66, 0x69, 0x6e, 0x12, 0x10, 0x0a, 0x03, 0x72, 0x73, 0x74, 0x18, 0x04, 0x20, 0x01, + 0x28, 0x08, 0x52, 0x03, 0x72, 0x73, 0x74, 0x12, 0x10, 0x0a, 0x03, 0x70, 0x73, 0x68, 0x18, 0x05, + 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x70, 0x73, 0x68, 0x12, 0x10, 0x0a, 0x03, 0x75, 0x72, 0x67, + 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x75, 0x72, 0x67, 0x22, 0x80, 0x03, 0x0a, 0x12, + 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x12, 0x1b, 0x0a, 0x09, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x69, 0x70, 0x18, + 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x49, 0x70, 0x12, + 0x25, 0x0a, 0x0e, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x5f, 0x69, + 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0d, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x49, 0x70, 0x12, 0x1a, 0x0a, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x12, 0x1f, 0x0a, 0x0b, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x5f, 0x70, 0x6f, 0x72, + 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0a, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x50, + 0x6f, 0x72, 0x74, 0x12, 0x29, 0x0a, 0x10, 0x64, 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, + 0x6f, 0x6e, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x0f, 0x64, + 0x65, 0x73, 0x74, 0x69, 0x6e, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x1c, + 0x0a, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, + 0x09, 0x52, 0x09, 0x64, 0x69, 0x72, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x32, 0x0a, 0x09, + 0x74, 0x63, 0x70, 0x5f, 0x66, 0x6c, 0x61, 0x67, 0x73, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x10, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, 0x43, 0x50, 0x46, 0x6c, 0x61, 0x67, + 0x73, 0x48, 0x00, 0x52, 0x08, 0x74, 0x63, 0x70, 0x46, 0x6c, 0x61, 0x67, 0x73, 0x88, 0x01, 0x01, + 0x12, 0x20, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x08, 0x20, + 0x01, 0x28, 0x0d, 0x48, 0x01, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x54, 0x79, 0x70, 0x65, 0x88, + 0x01, 0x01, 0x12, 0x20, 0x0a, 0x09, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x18, + 0x09, 0x20, 0x01, 0x28, 0x0d, 0x48, 0x02, 0x52, 0x08, 0x69, 0x63, 0x6d, 0x70, 0x43, 0x6f, 0x64, + 0x65, 0x88, 0x01, 0x01, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x74, 0x63, 0x70, 0x5f, 0x66, 0x6c, 0x61, + 0x67, 0x73, 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x74, 0x79, 0x70, 0x65, + 0x42, 0x0c, 0x0a, 0x0a, 0x5f, 0x69, 0x63, 0x6d, 0x70, 0x5f, 0x63, 0x6f, 0x64, 0x65, 0x22, 0x9f, + 0x01, 0x0a, 0x0a, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x67, 0x65, 0x12, 0x12, 0x0a, + 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, + 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x61, + 0x6c, 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, 0x61, 0x6c, + 0x6c, 0x6f, 0x77, 0x65, 0x64, 0x12, 0x32, 0x0a, 0x12, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, + 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x04, 0x20, 0x01, 0x28, + 0x09, 0x48, 0x00, 0x52, 0x11, 0x66, 0x6f, 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x44, + 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x88, 0x01, 0x01, 0x42, 0x15, 0x0a, 0x13, 0x5f, 0x66, 0x6f, + 0x72, 0x77, 0x61, 0x72, 0x64, 0x69, 0x6e, 0x67, 0x5f, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, + 0x22, 0x6e, 0x0a, 0x13, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2a, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x67, 0x65, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x53, 0x74, 0x61, 0x67, 0x65, 0x52, 0x06, 0x73, 0x74, 0x61, + 0x67, 0x65, 0x73, 0x12, 0x2b, 0x0a, 0x11, 0x66, 0x69, 0x6e, 0x61, 0x6c, 0x5f, 0x64, 0x69, 0x73, + 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x10, + 0x66, 0x69, 0x6e, 0x61, 0x6c, 0x44, 0x69, 0x73, 0x70, 0x6f, 0x73, 0x69, 0x74, 0x69, 0x6f, 0x6e, + 0x2a, 0x62, 0x0a, 0x08, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x0b, 0x0a, 0x07, + 0x55, 0x4e, 0x4b, 0x4e, 0x4f, 0x57, 0x4e, 0x10, 0x00, 0x12, 0x09, 0x0a, 0x05, 0x50, 0x41, 0x4e, + 0x49, 0x43, 0x10, 0x01, 0x12, 0x09, 0x0a, 0x05, 0x46, 0x41, 0x54, 0x41, 0x4c, 0x10, 0x02, 0x12, + 0x09, 0x0a, 0x05, 0x45, 0x52, 0x52, 0x4f, 0x52, 0x10, 0x03, 0x12, 0x08, 0x0a, 0x04, 0x57, 0x41, + 0x52, 0x4e, 0x10, 0x04, 0x12, 0x08, 0x0a, 0x04, 0x49, 0x4e, 0x46, 0x4f, 0x10, 0x05, 0x12, 0x09, + 0x0a, 0x05, 0x44, 0x45, 0x42, 0x55, 0x47, 0x10, 0x06, 0x12, 0x09, 0x0a, 0x05, 0x54, 0x52, 0x41, + 0x43, 0x45, 0x10, 0x07, 0x32, 0xdd, 0x09, 0x0a, 0x0d, 0x44, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x53, + 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x36, 0x0a, 0x05, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, + 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, + 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, + 0x0a, 0x0c, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, 0x6e, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, - 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, - 0x0a, 0x02, 0x55, 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x55, 0x70, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, - 0x06, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, - 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, - 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, - 0x12, 0x13, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, - 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, - 0x09, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, - 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, - 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x4b, 0x0a, 0x0c, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, - 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, + 0x6f, 0x67, 0x69, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x57, 0x61, 0x69, 0x74, 0x53, 0x53, 0x4f, 0x4c, 0x6f, 0x67, 0x69, + 0x6e, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x2d, 0x0a, 0x02, 0x55, + 0x70, 0x12, 0x11, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x12, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x55, 0x70, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x06, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x12, 0x15, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, + 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, + 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x33, 0x0a, 0x04, 0x44, 0x6f, 0x77, 0x6e, 0x12, 0x13, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x14, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x6f, 0x77, 0x6e, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x42, 0x0a, 0x09, 0x47, 0x65, + 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x18, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, + 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x43, 0x6f, + 0x6e, 0x66, 0x69, 0x67, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x4b, + 0x0a, 0x0c, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, - 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, - 0x0a, 0x0e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, - 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, - 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, - 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x12, 0x53, 0x0a, 0x10, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, - 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, - 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, - 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, - 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, - 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, - 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, - 0x12, 0x48, 0x0a, 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, - 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, - 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, - 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, - 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, - 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x73, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, - 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, - 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x43, - 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, - 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, - 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, - 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, - 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, - 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, - 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, 0x18, - 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, - 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, - 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, - 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x28, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, - 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, - 0x6e, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, - 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x64, 0x61, + 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, + 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x51, 0x0a, 0x0e, 0x53, + 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x12, 0x1d, 0x2e, + 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, + 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, + 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x53, + 0x0a, 0x10, 0x44, 0x65, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, + 0x6b, 0x73, 0x12, 0x1d, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, + 0x63, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x1e, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x6c, 0x65, 0x63, + 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, 0x64, + 0x6c, 0x65, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, + 0x67, 0x42, 0x75, 0x6e, 0x64, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x42, 0x75, 0x6e, + 0x64, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, + 0x0b, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, + 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, + 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, + 0x6e, 0x2e, 0x47, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, + 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x53, 0x65, 0x74, 0x4c, 0x6f, + 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x53, 0x65, 0x74, 0x4c, 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, + 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4c, + 0x6f, 0x67, 0x4c, 0x65, 0x76, 0x65, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, + 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x12, + 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, + 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x53, 0x74, 0x61, 0x74, 0x65, 0x73, 0x52, 0x65, + 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x45, 0x0a, 0x0a, 0x43, 0x6c, 0x65, 0x61, + 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x19, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, + 0x43, 0x6c, 0x65, 0x61, 0x6e, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, + 0x74, 0x1a, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x43, 0x6c, 0x65, 0x61, 0x6e, + 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, + 0x48, 0x0a, 0x0b, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x12, 0x1a, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, + 0x61, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, + 0x6d, 0x6f, 0x6e, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x53, 0x74, 0x61, 0x74, 0x65, 0x52, + 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x6f, 0x0a, 0x18, 0x53, 0x65, 0x74, + 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, + 0x74, 0x65, 0x6e, 0x63, 0x65, 0x12, 0x27, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, + 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, + 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x28, + 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x53, 0x65, 0x74, 0x4e, 0x65, 0x74, 0x77, 0x6f, + 0x72, 0x6b, 0x4d, 0x61, 0x70, 0x50, 0x65, 0x72, 0x73, 0x69, 0x73, 0x74, 0x65, 0x6e, 0x63, 0x65, + 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x48, 0x0a, 0x0b, 0x54, 0x72, + 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x12, 0x1a, 0x2e, 0x64, 0x61, 0x65, 0x6d, + 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, + 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x64, 0x61, 0x65, 0x6d, 0x6f, 0x6e, 0x2e, 0x54, + 0x72, 0x61, 0x63, 0x65, 0x50, 0x61, 0x63, 0x6b, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x22, 0x00, 0x42, 0x08, 0x5a, 0x06, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -3016,7 +3394,7 @@ func file_daemon_proto_rawDescGZIP() []byte { } var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 41) +var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 45) var file_daemon_proto_goTypes = []interface{}{ (LogLevel)(0), // 0: daemon.LogLevel (*LoginRequest)(nil), // 1: daemon.LoginRequest @@ -3059,16 +3437,20 @@ var file_daemon_proto_goTypes = []interface{}{ (*DeleteStateResponse)(nil), // 38: daemon.DeleteStateResponse (*SetNetworkMapPersistenceRequest)(nil), // 39: daemon.SetNetworkMapPersistenceRequest (*SetNetworkMapPersistenceResponse)(nil), // 40: daemon.SetNetworkMapPersistenceResponse - nil, // 41: daemon.Network.ResolvedIPsEntry - (*durationpb.Duration)(nil), // 42: google.protobuf.Duration - (*timestamppb.Timestamp)(nil), // 43: google.protobuf.Timestamp + (*TCPFlags)(nil), // 41: daemon.TCPFlags + (*TracePacketRequest)(nil), // 42: daemon.TracePacketRequest + (*TraceStage)(nil), // 43: daemon.TraceStage + (*TracePacketResponse)(nil), // 44: daemon.TracePacketResponse + nil, // 45: daemon.Network.ResolvedIPsEntry + (*durationpb.Duration)(nil), // 46: google.protobuf.Duration + (*timestamppb.Timestamp)(nil), // 47: google.protobuf.Timestamp } var file_daemon_proto_depIdxs = []int32{ - 42, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration + 46, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 19, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus - 43, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp - 43, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp - 42, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration + 47, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp + 47, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp + 46, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration 16, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 15, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState 14, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState @@ -3076,48 +3458,52 @@ var file_daemon_proto_depIdxs = []int32{ 17, // 9: daemon.FullStatus.relays:type_name -> daemon.RelayState 18, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState 25, // 11: daemon.ListNetworksResponse.routes:type_name -> daemon.Network - 41, // 12: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry + 45, // 12: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry 0, // 13: daemon.GetLogLevelResponse.level:type_name -> daemon.LogLevel 0, // 14: daemon.SetLogLevelRequest.level:type_name -> daemon.LogLevel 32, // 15: daemon.ListStatesResponse.states:type_name -> daemon.State - 24, // 16: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList - 1, // 17: daemon.DaemonService.Login:input_type -> daemon.LoginRequest - 3, // 18: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest - 5, // 19: daemon.DaemonService.Up:input_type -> daemon.UpRequest - 7, // 20: daemon.DaemonService.Status:input_type -> daemon.StatusRequest - 9, // 21: daemon.DaemonService.Down:input_type -> daemon.DownRequest - 11, // 22: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest - 20, // 23: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest - 22, // 24: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest - 22, // 25: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest - 26, // 26: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest - 28, // 27: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest - 30, // 28: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest - 33, // 29: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest - 35, // 30: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest - 37, // 31: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest - 39, // 32: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest - 2, // 33: daemon.DaemonService.Login:output_type -> daemon.LoginResponse - 4, // 34: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse - 6, // 35: daemon.DaemonService.Up:output_type -> daemon.UpResponse - 8, // 36: daemon.DaemonService.Status:output_type -> daemon.StatusResponse - 10, // 37: daemon.DaemonService.Down:output_type -> daemon.DownResponse - 12, // 38: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse - 21, // 39: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse - 23, // 40: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse - 23, // 41: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse - 27, // 42: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse - 29, // 43: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse - 31, // 44: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse - 34, // 45: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse - 36, // 46: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse - 38, // 47: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse - 40, // 48: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse - 33, // [33:49] is the sub-list for method output_type - 17, // [17:33] is the sub-list for method input_type - 17, // [17:17] is the sub-list for extension type_name - 17, // [17:17] is the sub-list for extension extendee - 0, // [0:17] is the sub-list for field type_name + 41, // 16: daemon.TracePacketRequest.tcp_flags:type_name -> daemon.TCPFlags + 43, // 17: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage + 24, // 18: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList + 1, // 19: daemon.DaemonService.Login:input_type -> daemon.LoginRequest + 3, // 20: daemon.DaemonService.WaitSSOLogin:input_type -> daemon.WaitSSOLoginRequest + 5, // 21: daemon.DaemonService.Up:input_type -> daemon.UpRequest + 7, // 22: daemon.DaemonService.Status:input_type -> daemon.StatusRequest + 9, // 23: daemon.DaemonService.Down:input_type -> daemon.DownRequest + 11, // 24: daemon.DaemonService.GetConfig:input_type -> daemon.GetConfigRequest + 20, // 25: daemon.DaemonService.ListNetworks:input_type -> daemon.ListNetworksRequest + 22, // 26: daemon.DaemonService.SelectNetworks:input_type -> daemon.SelectNetworksRequest + 22, // 27: daemon.DaemonService.DeselectNetworks:input_type -> daemon.SelectNetworksRequest + 26, // 28: daemon.DaemonService.DebugBundle:input_type -> daemon.DebugBundleRequest + 28, // 29: daemon.DaemonService.GetLogLevel:input_type -> daemon.GetLogLevelRequest + 30, // 30: daemon.DaemonService.SetLogLevel:input_type -> daemon.SetLogLevelRequest + 33, // 31: daemon.DaemonService.ListStates:input_type -> daemon.ListStatesRequest + 35, // 32: daemon.DaemonService.CleanState:input_type -> daemon.CleanStateRequest + 37, // 33: daemon.DaemonService.DeleteState:input_type -> daemon.DeleteStateRequest + 39, // 34: daemon.DaemonService.SetNetworkMapPersistence:input_type -> daemon.SetNetworkMapPersistenceRequest + 42, // 35: daemon.DaemonService.TracePacket:input_type -> daemon.TracePacketRequest + 2, // 36: daemon.DaemonService.Login:output_type -> daemon.LoginResponse + 4, // 37: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse + 6, // 38: daemon.DaemonService.Up:output_type -> daemon.UpResponse + 8, // 39: daemon.DaemonService.Status:output_type -> daemon.StatusResponse + 10, // 40: daemon.DaemonService.Down:output_type -> daemon.DownResponse + 12, // 41: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse + 21, // 42: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse + 23, // 43: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse + 23, // 44: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse + 27, // 45: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse + 29, // 46: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse + 31, // 47: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse + 34, // 48: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse + 36, // 49: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse + 38, // 50: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse + 40, // 51: daemon.DaemonService.SetNetworkMapPersistence:output_type -> daemon.SetNetworkMapPersistenceResponse + 44, // 52: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse + 36, // [36:53] is the sub-list for method output_type + 19, // [19:36] is the sub-list for method input_type + 19, // [19:19] is the sub-list for extension type_name + 19, // [19:19] is the sub-list for extension extendee + 0, // [0:19] is the sub-list for field type_name } func init() { file_daemon_proto_init() } @@ -3606,15 +3992,65 @@ func file_daemon_proto_init() { return nil } } + file_daemon_proto_msgTypes[40].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TCPFlags); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[41].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TracePacketRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[42].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TraceStage); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_daemon_proto_msgTypes[43].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*TracePacketResponse); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } } file_daemon_proto_msgTypes[0].OneofWrappers = []interface{}{} + file_daemon_proto_msgTypes[41].OneofWrappers = []interface{}{} + file_daemon_proto_msgTypes[42].OneofWrappers = []interface{}{} type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: file_daemon_proto_rawDesc, NumEnums: 1, - NumMessages: 41, + NumMessages: 45, NumExtensions: 0, NumServices: 1, }, diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index b626276de60..bab0aa9e975 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -57,6 +57,8 @@ service DaemonService { // SetNetworkMapPersistence enables or disables network map persistence rpc SetNetworkMapPersistence(SetNetworkMapPersistenceRequest) returns (SetNetworkMapPersistenceResponse) {} + + rpc TracePacket(TracePacketRequest) returns (TracePacketResponse) {} } @@ -356,3 +358,36 @@ message SetNetworkMapPersistenceRequest { } message SetNetworkMapPersistenceResponse {} + +message TCPFlags { + bool syn = 1; + bool ack = 2; + bool fin = 3; + bool rst = 4; + bool psh = 5; + bool urg = 6; +} + +message TracePacketRequest { + string source_ip = 1; + string destination_ip = 2; + string protocol = 3; + uint32 source_port = 4; + uint32 destination_port = 5; + string direction = 6; + optional TCPFlags tcp_flags = 7; + optional uint32 icmp_type = 8; + optional uint32 icmp_code = 9; +} + +message TraceStage { + string name = 1; + string message = 2; + bool allowed = 3; + optional string forwarding_details = 4; +} + +message TracePacketResponse { + repeated TraceStage stages = 1; + bool final_disposition = 2; +} diff --git a/client/proto/daemon_grpc.pb.go b/client/proto/daemon_grpc.pb.go index 39424aee938..9dcb543a80c 100644 --- a/client/proto/daemon_grpc.pb.go +++ b/client/proto/daemon_grpc.pb.go @@ -51,6 +51,7 @@ type DaemonServiceClient interface { DeleteState(ctx context.Context, in *DeleteStateRequest, opts ...grpc.CallOption) (*DeleteStateResponse, error) // SetNetworkMapPersistence enables or disables network map persistence SetNetworkMapPersistence(ctx context.Context, in *SetNetworkMapPersistenceRequest, opts ...grpc.CallOption) (*SetNetworkMapPersistenceResponse, error) + TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) } type daemonServiceClient struct { @@ -205,6 +206,15 @@ func (c *daemonServiceClient) SetNetworkMapPersistence(ctx context.Context, in * return out, nil } +func (c *daemonServiceClient) TracePacket(ctx context.Context, in *TracePacketRequest, opts ...grpc.CallOption) (*TracePacketResponse, error) { + out := new(TracePacketResponse) + err := c.cc.Invoke(ctx, "/daemon.DaemonService/TracePacket", in, out, opts...) + if err != nil { + return nil, err + } + return out, nil +} + // DaemonServiceServer is the server API for DaemonService service. // All implementations must embed UnimplementedDaemonServiceServer // for forward compatibility @@ -242,6 +252,7 @@ type DaemonServiceServer interface { DeleteState(context.Context, *DeleteStateRequest) (*DeleteStateResponse, error) // SetNetworkMapPersistence enables or disables network map persistence SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) + TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) mustEmbedUnimplementedDaemonServiceServer() } @@ -297,6 +308,9 @@ func (UnimplementedDaemonServiceServer) DeleteState(context.Context, *DeleteStat func (UnimplementedDaemonServiceServer) SetNetworkMapPersistence(context.Context, *SetNetworkMapPersistenceRequest) (*SetNetworkMapPersistenceResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method SetNetworkMapPersistence not implemented") } +func (UnimplementedDaemonServiceServer) TracePacket(context.Context, *TracePacketRequest) (*TracePacketResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method TracePacket not implemented") +} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. @@ -598,6 +612,24 @@ func _DaemonService_SetNetworkMapPersistence_Handler(srv interface{}, ctx contex return interceptor(ctx, in, info, handler) } +func _DaemonService_TracePacket_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TracePacketRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(DaemonServiceServer).TracePacket(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: "/daemon.DaemonService/TracePacket", + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(DaemonServiceServer).TracePacket(ctx, req.(*TracePacketRequest)) + } + return interceptor(ctx, in, info, handler) +} + // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -669,6 +701,10 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{ MethodName: "SetNetworkMapPersistence", Handler: _DaemonService_SetNetworkMapPersistence_Handler, }, + { + MethodName: "TracePacket", + Handler: _DaemonService_TracePacket_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "daemon.proto", diff --git a/client/server/debug.go b/client/server/debug.go index a37195b290c..749220d62c8 100644 --- a/client/server/debug.go +++ b/client/server/debug.go @@ -538,7 +538,24 @@ func (s *Server) SetLogLevel(_ context.Context, req *proto.SetLogLevelRequest) ( } log.SetLevel(level) + + if s.connectClient == nil { + return nil, fmt.Errorf("connect client not initialized") + } + engine := s.connectClient.Engine() + if engine == nil { + return nil, fmt.Errorf("engine not initialized") + } + + fwManager := engine.GetFirewallManager() + if fwManager == nil { + return nil, fmt.Errorf("firewall manager not initialized") + } + + fwManager.SetLogLevel(level) + log.Infof("Log level set to %s", level.String()) + return &proto.SetLogLevelResponse{}, nil } diff --git a/client/server/trace.go b/client/server/trace.go new file mode 100644 index 00000000000..66b83d8cf86 --- /dev/null +++ b/client/server/trace.go @@ -0,0 +1,123 @@ +package server + +import ( + "context" + "fmt" + "net" + + fw "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/firewall/uspfilter" + "github.com/netbirdio/netbird/client/proto" +) + +type packetTracer interface { + TracePacketFromBuilder(builder *uspfilter.PacketBuilder) (*uspfilter.PacketTrace, error) +} + +func (s *Server) TracePacket(_ context.Context, req *proto.TracePacketRequest) (*proto.TracePacketResponse, error) { + s.mutex.Lock() + defer s.mutex.Unlock() + + if s.connectClient == nil { + return nil, fmt.Errorf("connect client not initialized") + } + engine := s.connectClient.Engine() + if engine == nil { + return nil, fmt.Errorf("engine not initialized") + } + + fwManager := engine.GetFirewallManager() + if fwManager == nil { + return nil, fmt.Errorf("firewall manager not initialized") + } + + tracer, ok := fwManager.(packetTracer) + if !ok { + return nil, fmt.Errorf("firewall manager does not support packet tracing") + } + + srcIP := net.ParseIP(req.GetSourceIp()) + if req.GetSourceIp() == "self" { + srcIP = engine.GetWgAddr() + } + + dstIP := net.ParseIP(req.GetDestinationIp()) + if req.GetDestinationIp() == "self" { + dstIP = engine.GetWgAddr() + } + + if srcIP == nil || dstIP == nil { + return nil, fmt.Errorf("invalid IP address") + } + + var tcpState *uspfilter.TCPState + if flags := req.GetTcpFlags(); flags != nil { + tcpState = &uspfilter.TCPState{ + SYN: flags.GetSyn(), + ACK: flags.GetAck(), + FIN: flags.GetFin(), + RST: flags.GetRst(), + PSH: flags.GetPsh(), + URG: flags.GetUrg(), + } + } + + var dir fw.RuleDirection + switch req.GetDirection() { + case "in": + dir = fw.RuleDirectionIN + case "out": + dir = fw.RuleDirectionOUT + default: + return nil, fmt.Errorf("invalid direction") + } + + var protocol fw.Protocol + switch req.GetProtocol() { + case "tcp": + protocol = fw.ProtocolTCP + case "udp": + protocol = fw.ProtocolUDP + case "icmp": + protocol = fw.ProtocolICMP + default: + return nil, fmt.Errorf("invalid protocolcol") + } + + builder := &uspfilter.PacketBuilder{ + SrcIP: srcIP, + DstIP: dstIP, + Protocol: protocol, + SrcPort: uint16(req.GetSourcePort()), + DstPort: uint16(req.GetDestinationPort()), + Direction: dir, + TCPState: tcpState, + ICMPType: uint8(req.GetIcmpType()), + ICMPCode: uint8(req.GetIcmpCode()), + } + trace, err := tracer.TracePacketFromBuilder(builder) + if err != nil { + return nil, fmt.Errorf("trace packet: %w", err) + } + + resp := &proto.TracePacketResponse{} + + for _, result := range trace.Results { + stage := &proto.TraceStage{ + Name: result.Stage.String(), + Message: result.Message, + Allowed: result.Allowed, + } + if result.ForwarderAction != nil { + details := fmt.Sprintf("%s to %s", result.ForwarderAction.Action, result.ForwarderAction.RemoteAddr) + stage.ForwardingDetails = &details + } + resp.Stages = append(resp.Stages, stage) + } + + if len(trace.Results) > 0 { + resp.FinalDisposition = trace.Results[len(trace.Results)-1].Allowed + } + + return resp, nil +} diff --git a/go.mod b/go.mod index e65296a5344..04aa6144c22 100644 --- a/go.mod +++ b/go.mod @@ -102,6 +102,7 @@ require ( gorm.io/driver/postgres v1.5.7 gorm.io/driver/sqlite v1.5.7 gorm.io/gorm v1.25.12 + gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 nhooyr.io/websocket v1.8.11 ) @@ -237,7 +238,6 @@ require ( gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 // indirect - gvisor.dev/gvisor v0.0.0-20231020174304-db3d49b921f9 // indirect k8s.io/apimachinery v0.26.2 // indirect ) @@ -245,7 +245,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 -replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 +replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 diff --git a/go.sum b/go.sum index e3670b99e26..5a4604ca70b 100644 --- a/go.sum +++ b/go.sum @@ -533,8 +533,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d h1:bRq5TKgC7Iq20pDiuC54yXaWnAVeS5PdGpSokFTlR28= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20241010133937-e0df50df217d/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= -github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9 h1:Pu/7EukijT09ynHUOzQYW7cC3M/BKU8O4qyN/TvTGoY= -github.com/netbirdio/wireguard-go v0.0.0-20241125150134-f9cdce5e32e9/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= +github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6 h1:X5h5QgP7uHAv78FWgHV8+WYLjHxK9v3ilkVXT1cpCrQ= +github.com/netbirdio/wireguard-go v0.0.0-20241230120307-6a676aebaaf6/go.mod h1:tkCQ4FQXmpAgYVh++1cq16/dH4QJtmvpRv19DWGAHSA= github.com/nicksnyder/go-i18n/v2 v2.4.0 h1:3IcvPOAvnCKwNm0TB0dLDTuawWEj+ax/RERNC+diLMM= github.com/nicksnyder/go-i18n/v2 v2.4.0/go.mod h1:nxYSZE9M0bf3Y70gPQjN9ha7XNHX7gMc814+6wVyEI4= github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI8A= @@ -1248,8 +1248,8 @@ gorm.io/gorm v1.25.12 h1:I0u8i2hWQItBq1WfE0o2+WuL9+8L21K9e2HHSTE/0f8= gorm.io/gorm v1.25.12/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= gotest.tools/v3 v3.5.0 h1:Ljk6PdHdOhAb5aDMWXjDLMMhph+BpztA4v1QdqEW2eY= gotest.tools/v3 v3.5.0/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= -gvisor.dev/gvisor v0.0.0-20231020174304-db3d49b921f9 h1:sCEaoA7ZmkuFwa2IR61pl4+RYZPwCJOiaSYT0k+BRf8= -gvisor.dev/gvisor v0.0.0-20231020174304-db3d49b921f9/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8= +gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1 h1:qDCwdCWECGnwQSQC01Dpnp09fRHxJs9PbktotUqG+hs= +gvisor.dev/gvisor v0.0.0-20231020174304-b8a429915ff1/go.mod h1:8hmigyCdYtw5xJGfQDJzSH5Ju8XEIDBnpyi8+O6GRt8= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=