From 3c158d4f29c4d51c4c11b636baa5ee4e53ae8601 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Sun, 22 Dec 2024 13:06:19 +0100 Subject: [PATCH] Fix races, improve performance and add benchmarks --- client/firewall/uspfilter/conntrack/common.go | 137 +++++++++++++++ .../uspfilter/conntrack/common_test.go | 115 +++++++++++++ client/firewall/uspfilter/conntrack/icmp.go | 93 +++++----- .../firewall/uspfilter/conntrack/icmp_test.go | 39 +++++ client/firewall/uspfilter/conntrack/tcp.go | 162 ++++++++++-------- .../firewall/uspfilter/conntrack/tcp_test.go | 103 +++++++++++ client/firewall/uspfilter/conntrack/udp.go | 111 ++++++------ .../firewall/uspfilter/conntrack/udp_test.go | 37 +++- client/firewall/uspfilter/uspfilter_test.go | 2 +- 9 files changed, 631 insertions(+), 168 deletions(-) create mode 100644 client/firewall/uspfilter/conntrack/common.go create mode 100644 client/firewall/uspfilter/conntrack/common_test.go create mode 100644 client/firewall/uspfilter/conntrack/icmp_test.go diff --git a/client/firewall/uspfilter/conntrack/common.go b/client/firewall/uspfilter/conntrack/common.go new file mode 100644 index 00000000000..079a0175f2f --- /dev/null +++ b/client/firewall/uspfilter/conntrack/common.go @@ -0,0 +1,137 @@ +// common.go +package conntrack + +import ( + "net" + "sync" + "sync/atomic" + "time" +) + +// BaseConnTrack provides common fields and locking for all connection types +type BaseConnTrack struct { + sync.RWMutex + SourceIP net.IP + DestIP net.IP + SourcePort uint16 + DestPort uint16 + lastSeen atomic.Int64 // Unix nano for atomic access + established atomic.Bool +} + +// these small methods will be inlined by the compiler + +// UpdateLastSeen safely updates the last seen timestamp +func (b *BaseConnTrack) UpdateLastSeen() { + b.lastSeen.Store(time.Now().UnixNano()) +} + +// IsEstablished safely checks if connection is established +func (b *BaseConnTrack) IsEstablished() bool { + return b.established.Load() +} + +// SetEstablished safely sets the established state +func (b *BaseConnTrack) SetEstablished(state bool) { + b.established.Store(state) +} + +// GetLastSeen safely gets the last seen timestamp +func (b *BaseConnTrack) GetLastSeen() time.Time { + return time.Unix(0, b.lastSeen.Load()) +} + +// timeoutExceeded checks if the connection has exceeded the given timeout +func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool { + lastSeen := time.Unix(0, b.lastSeen.Load()) + return time.Since(lastSeen) > timeout +} + +// IPAddr is a fixed-size IP address to avoid allocations +type IPAddr [16]byte + +// makeIPAddr creates an IPAddr from net.IP +func makeIPAddr(ip net.IP) (addr IPAddr) { + // Optimization: check for v4 first as it's more common + if ip4 := ip.To4(); ip4 != nil { + copy(addr[12:], ip4) + } else { + copy(addr[:], ip.To16()) + } + return addr +} + +// ConnKey uniquely identifies a connection +type ConnKey struct { + SrcIP IPAddr + DstIP IPAddr + SrcPort uint16 + DstPort uint16 +} + +// makeConnKey creates a connection key +func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey { + return ConnKey{ + SrcIP: makeIPAddr(srcIP), + DstIP: makeIPAddr(dstIP), + SrcPort: srcPort, + DstPort: dstPort, + } +} + +// ValidateIPs checks if IPs match without allocation +func ValidateIPs(connIP IPAddr, pktIP net.IP) bool { + if ip4 := pktIP.To4(); ip4 != nil { + // Compare IPv4 addresses (last 4 bytes) + for i := 0; i < 4; i++ { + if connIP[12+i] != ip4[i] { + return false + } + } + return true + } + // Compare full IPv6 addresses + ip6 := pktIP.To16() + for i := 0; i < 16; i++ { + if connIP[i] != ip6[i] { + return false + } + } + return true +} + +// PreallocatedIPs is a pool of IP byte slices to reduce allocations +type PreallocatedIPs struct { + sync.Pool +} + +// NewPreallocatedIPs creates a new IP pool +func NewPreallocatedIPs() *PreallocatedIPs { + return &PreallocatedIPs{ + Pool: sync.Pool{ + New: func() interface{} { + return make(net.IP, 16) + }, + }, + } +} + +// Get retrieves an IP from the pool +func (p *PreallocatedIPs) Get() net.IP { + return p.Pool.Get().(net.IP) +} + +// Put returns an IP to the pool +func (p *PreallocatedIPs) Put(ip net.IP) { + p.Pool.Put(ip) +} + +// copyIP copies an IP address efficiently +func copyIP(dst, src net.IP) { + if len(src) == 16 { + copy(dst, src) + } else { + // Handle IPv4 + copy(dst[12:], src.To4()) + } +} diff --git a/client/firewall/uspfilter/conntrack/common_test.go b/client/firewall/uspfilter/conntrack/common_test.go new file mode 100644 index 00000000000..a337f649b47 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/common_test.go @@ -0,0 +1,115 @@ +package conntrack + +import ( + "net" + "testing" +) + +func BenchmarkIPOperations(b *testing.B) { + b.Run("makeIPAddr", func(b *testing.B) { + ip := net.ParseIP("192.168.1.1") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = makeIPAddr(ip) + } + }) + + b.Run("ValidateIPs", func(b *testing.B) { + ip1 := net.ParseIP("192.168.1.1") + ip2 := net.ParseIP("192.168.1.1") + addr := makeIPAddr(ip1) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ValidateIPs(addr, ip2) + } + }) + + b.Run("IPPool", func(b *testing.B) { + pool := NewPreallocatedIPs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ip := pool.Get() + pool.Put(ip) + } + }) + +} +func BenchmarkAtomicOperations(b *testing.B) { + conn := &BaseConnTrack{} + b.Run("UpdateLastSeen", func(b *testing.B) { + for i := 0; i < b.N; i++ { + conn.UpdateLastSeen() + } + }) + + b.Run("IsEstablished", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = conn.IsEstablished() + } + }) + + b.Run("SetEstablished", func(b *testing.B) { + for i := 0; i < b.N; i++ { + conn.SetEstablished(i%2 == 0) + } + }) + + b.Run("GetLastSeen", func(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = conn.GetLastSeen() + } + }) +} + +// Memory pressure tests +func BenchmarkMemoryPressure(b *testing.B) { + b.Run("TCPHighLoad", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + // Generate different IPs + srcIPs := make([]net.IP, 100) + dstIPs := make([]net.IP, 100) + for i := 0; i < 100; i++ { + srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) + dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + srcIdx := i % len(srcIPs) + dstIdx := (i + 1) % len(dstIPs) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn) + + // Simulate some valid inbound packets + if i%3 == 0 { + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck) + } + } + }) + + b.Run("UDPHighLoad", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + // Generate different IPs + srcIPs := make([]net.IP, 100) + dstIPs := make([]net.IP, 100) + for i := 0; i < 100; i++ { + srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256)) + dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + srcIdx := i % len(srcIPs) + dstIdx := (i + 1) % len(dstIPs) + tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80) + + // Simulate some valid inbound packets + if i%3 == 0 { + tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535)) + } + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/icmp.go b/client/firewall/uspfilter/conntrack/icmp.go index 9b76c370a46..4cab4cb0e72 100644 --- a/client/firewall/uspfilter/conntrack/icmp.go +++ b/client/firewall/uspfilter/conntrack/icmp.go @@ -2,7 +2,6 @@ package conntrack import ( "net" - "slices" "sync" "time" @@ -27,12 +26,9 @@ type ICMPConnKey struct { // ICMPConnTrack represents an ICMP connection state type ICMPConnTrack struct { - SourceIP net.IP - DestIP net.IP - Sequence uint16 - ID uint16 - LastSeen time.Time - established bool + BaseConnTrack + Sequence uint16 + ID uint16 } // ICMPTracker manages ICMP connection states @@ -42,6 +38,7 @@ type ICMPTracker struct { cleanupTicker *time.Ticker mutex sync.RWMutex done chan struct{} + ipPool *PreallocatedIPs } // NewICMPTracker creates a new ICMP connection tracker @@ -55,6 +52,7 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker { timeout: timeout, cleanupTicker: time.NewTicker(ICMPCleanupInterval), done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), } go tracker.cleanupRoutine() @@ -64,29 +62,41 @@ func NewICMPTracker(timeout time.Duration) *ICMPTracker { // TrackOutbound records an outbound ICMP Echo Request func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) { key := makeICMPKey(srcIP, dstIP, id, seq) + now := time.Now().UnixNano() t.mutex.Lock() - defer t.mutex.Unlock() - - t.connections[key] = &ICMPConnTrack{ - SourceIP: slices.Clone(srcIP), - DestIP: slices.Clone(dstIP), - ID: id, - Sequence: seq, - LastSeen: time.Now(), - established: true, + conn, exists := t.connections[key] + if !exists { + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &ICMPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + }, + ID: id, + Sequence: seq, + } + conn.lastSeen.Store(now) + conn.established.Store(true) + t.connections[key] = conn } + t.mutex.Unlock() + + conn.lastSeen.Store(now) } // IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool { switch icmpType { - // For Destination Unreachable and Time Exceeded, always allow - case uint8(layers.ICMPv4TypeDestinationUnreachable), uint8(layers.ICMPv4TypeTimeExceeded): + case uint8(layers.ICMPv4TypeDestinationUnreachable), + uint8(layers.ICMPv4TypeTimeExceeded): return true - // For Echo Reply, check if we have a matching request case uint8(layers.ICMPv4TypeEchoReply): - // continue further down + // continue processing default: return false } @@ -94,29 +104,22 @@ func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq key := makeICMPKey(dstIP, srcIP, id, seq) t.mutex.RLock() - defer t.mutex.RUnlock() - conn, exists := t.connections[key] + t.mutex.RUnlock() + if !exists { return false } - // Check if connection is still valid - if time.Since(conn.LastSeen) > t.timeout { + if conn.timeoutExceeded(t.timeout) { return false } - if conn.established && - conn.DestIP.Equal(srcIP) && - conn.SourceIP.Equal(dstIP) && + return conn.IsEstablished() && + ValidateIPs(makeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(makeIPAddr(dstIP), conn.SourceIP) && conn.ID == id && - conn.Sequence == seq { - - conn.LastSeen = time.Now() - return true - } - - return false + conn.Sequence == seq } func (t *ICMPTracker) cleanupRoutine() { @@ -129,14 +132,14 @@ func (t *ICMPTracker) cleanupRoutine() { } } } - func (t *ICMPTracker) cleanup() { t.mutex.Lock() defer t.mutex.Unlock() - now := time.Now() for key, conn := range t.connections { - if now.Sub(conn.LastSeen) > t.timeout { + if conn.timeoutExceeded(t.timeout) { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) delete(t.connections, key) } } @@ -146,15 +149,21 @@ func (t *ICMPTracker) cleanup() { func (t *ICMPTracker) Close() { t.cleanupTicker.Stop() close(t.done) + + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() } +// makeICMPKey creates an ICMP connection key func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey { - var srcAddr, dstAddr [16]byte - copy(srcAddr[:], srcIP.To16()) - copy(dstAddr[:], dstIP.To16()) return ICMPConnKey{ - SrcIP: srcAddr, - DstIP: dstAddr, + SrcIP: makeIPAddr(srcIP), + DstIP: makeIPAddr(dstIP), ID: id, Sequence: seq, } diff --git a/client/firewall/uspfilter/conntrack/icmp_test.go b/client/firewall/uspfilter/conntrack/icmp_test.go new file mode 100644 index 00000000000..21176e719d4 --- /dev/null +++ b/client/firewall/uspfilter/conntrack/icmp_test.go @@ -0,0 +1,39 @@ +package conntrack + +import ( + "net" + "testing" +) + +func BenchmarkICMPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewICMPTracker(DefaultICMPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535)) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewICMPTracker(DefaultICMPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0) + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/tcp.go b/client/firewall/uspfilter/conntrack/tcp.go index 22c37184c5f..e8d20f41c67 100644 --- a/client/firewall/uspfilter/conntrack/tcp.go +++ b/client/firewall/uspfilter/conntrack/tcp.go @@ -4,7 +4,6 @@ package conntrack import ( "net" - "slices" "sync" "time" ) @@ -61,31 +60,28 @@ type TCPConnKey struct { // TCPConnTrack represents a TCP connection state type TCPConnTrack struct { - SourceIP net.IP - DestIP net.IP - SourcePort uint16 - DestPort uint16 - State TCPState - LastSeen time.Time - established bool + BaseConnTrack + State TCPState } // TCPTracker manages TCP connection states type TCPTracker struct { - connections map[TCPConnKey]*TCPConnTrack + connections map[ConnKey]*TCPConnTrack mutex sync.RWMutex cleanupTicker *time.Ticker done chan struct{} timeout time.Duration + ipPool *PreallocatedIPs } // NewTCPTracker creates a new TCP connection tracker func NewTCPTracker(timeout time.Duration) *TCPTracker { tracker := &TCPTracker{ - connections: make(map[TCPConnKey]*TCPConnTrack), + connections: make(map[ConnKey]*TCPConnTrack), cleanupTicker: time.NewTicker(TCPCleanupInterval), done: make(chan struct{}), timeout: timeout, + ipPool: NewPreallocatedIPs(), } go tracker.cleanupRoutine() @@ -94,88 +90,108 @@ func NewTCPTracker(timeout time.Duration) *TCPTracker { // TrackOutbound processes an outbound TCP packet and updates connection state func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) { - - key := makeTCPKey(srcIP, dstIP, srcPort, dstPort) - now := time.Now() + // Create key before lock + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + now := time.Now().UnixNano() t.mutex.Lock() - defer t.mutex.Unlock() - conn, exists := t.connections[key] if !exists { + // Use preallocated IPs + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + conn = &TCPConnTrack{ - SourceIP: slices.Clone(srcIP), - DestIP: slices.Clone(dstIP), - SourcePort: srcPort, - DestPort: dstPort, - State: TCPStateNew, - LastSeen: now, - established: false, + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: srcPort, + DestPort: dstPort, + }, + State: TCPStateNew, } + conn.lastSeen.Store(now) + conn.established.Store(false) t.connections[key] = conn } + t.mutex.Unlock() - // Update connection state based on TCP flags + // Lock individual connection for state update + conn.Lock() t.updateState(conn, flags, true) - conn.LastSeen = now + conn.Unlock() + conn.lastSeen.Store(now) } // IsValidInbound checks if an inbound TCP packet matches a tracked connection func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool { - t.mutex.Lock() - defer t.mutex.Unlock() - - // Always validate flag combinations first if !isValidFlagCombination(flags) { return false } - // For SYN packets (new connection attempts), allow only pure SYN - if flags&TCPSyn != 0 { - if flags&TCPAck == 0 { - key := makeTCPKey(dstIP, srcIP, dstPort, srcPort) - t.connections[key] = &TCPConnTrack{ - SourceIP: slices.Clone(dstIP), - DestIP: slices.Clone(srcIP), - SourcePort: dstPort, - DestPort: srcPort, - State: TCPStateSynReceived, - LastSeen: time.Now(), - established: false, + // Handle new SYN packets + if flags&TCPSyn != 0 && flags&TCPAck == 0 { + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + t.mutex.Lock() + if _, exists := t.connections[key]; !exists { + // Use preallocated IPs + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, dstIP) + copyIP(dstIPCopy, srcIP) + + conn := &TCPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: dstPort, + DestPort: srcPort, + }, + State: TCPStateSynReceived, } - return true + conn.lastSeen.Store(time.Now().UnixNano()) + conn.established.Store(false) + t.connections[key] = conn } - // If it's SYN+ACK, let it fall through to normal processing + t.mutex.Unlock() + return true } - key := makeTCPKey(dstIP, srcIP, dstPort, srcPort) + // Look up existing connection + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) + t.mutex.RLock() conn, exists := t.connections[key] + t.mutex.RUnlock() + if !exists { return false } - // Handle RST packets - only allow for existing connections + // Handle RST packets if flags&TCPRst != 0 { - if conn.established || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived { + conn.Lock() + isEstablished := conn.IsEstablished() + if isEstablished || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived { conn.State = TCPStateClosed - conn.established = false + conn.SetEstablished(false) + conn.Unlock() return true } + conn.Unlock() return false } - // Special handling for FIN state - if conn.State == TCPStateFinWait1 || conn.State == TCPStateFinWait2 { - t.updateState(conn, flags, false) - conn.LastSeen = time.Now() - return true - } - + // Update state + conn.Lock() t.updateState(conn, flags, false) - conn.LastSeen = time.Now() + conn.UpdateLastSeen() + isEstablished := conn.IsEstablished() + isValidState := t.isValidStateForFlags(conn.State, flags) + conn.Unlock() - // Allow if established or in a valid state for the flags - return conn.established || t.isValidStateForFlags(conn.State, flags) + return isEstablished || isValidState } // updateState updates the TCP connection state based on flags @@ -183,7 +199,7 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo // Handle RST flag specially - it always causes transition to closed if flags&TCPRst != 0 { conn.State = TCPStateClosed - conn.established = false + conn.SetEstablished(false) return } @@ -200,14 +216,14 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo } else { // Simultaneous open conn.State = TCPStateEstablished - conn.established = true + conn.SetEstablished(true) } } case TCPStateSynReceived: if flags&TCPAck != 0 && flags&TCPSyn == 0 { conn.State = TCPStateEstablished - conn.established = true + conn.SetEstablished(true) } case TCPStateEstablished: @@ -217,7 +233,7 @@ func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound boo } else { conn.State = TCPStateCloseWait } - conn.established = false + conn.SetEstablished(false) } case TCPStateFinWait1: @@ -309,19 +325,22 @@ func (t *TCPTracker) cleanup() { t.mutex.Lock() defer t.mutex.Unlock() - now := time.Now() for key, conn := range t.connections { var timeout time.Duration switch { case conn.State == TCPStateTimeWait: timeout = TimeWaitTimeout - case conn.established: + case conn.IsEstablished(): timeout = t.timeout default: timeout = TCPHandshakeTimeout } - if now.Sub(conn.LastSeen) > timeout { + lastSeen := conn.GetLastSeen() + if time.Since(lastSeen) > timeout { + // Return IPs to pool + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) delete(t.connections, key) } } @@ -331,18 +350,15 @@ func (t *TCPTracker) cleanup() { func (t *TCPTracker) Close() { t.cleanupTicker.Stop() close(t.done) -} -func makeTCPKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) TCPConnKey { - var srcAddr, dstAddr [16]byte - copy(srcAddr[:], srcIP.To16()) - copy(dstAddr[:], dstIP.To16()) - return TCPConnKey{ - SrcIP: srcAddr, - DstIP: dstAddr, - SrcPort: srcPort, - DstPort: dstPort, + // Clean up all remaining IPs + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) } + t.connections = nil + t.mutex.Unlock() } func isValidFlagCombination(flags uint8) bool { diff --git a/client/firewall/uspfilter/conntrack/tcp_test.go b/client/firewall/uspfilter/conntrack/tcp_test.go index 7bca8d9960e..42a2f708a47 100644 --- a/client/firewall/uspfilter/conntrack/tcp_test.go +++ b/client/firewall/uspfilter/conntrack/tcp_test.go @@ -3,6 +3,7 @@ package conntrack import ( "net" "testing" + "time" "github.com/stretchr/testify/require" ) @@ -160,3 +161,105 @@ func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck) } + +// Benchmarks for the optimized implementation +func (t *TCPTracker) benchmarkTrackOutbound(b *testing.B) { + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + t.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + } +} + +func (t *TCPTracker) benchmarkIsValidInbound(b *testing.B) { + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + t.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + t.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) + } +} + +func BenchmarkTCPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck) + } + }) + + b.Run("ConcurrentAccess", func(b *testing.B) { + tracker := NewTCPTracker(DefaultTCPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.RunParallel(func(pb *testing.PB) { + i := 0 + for pb.Next() { + if i%2 == 0 { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn) + } else { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck) + } + i++ + } + }) + }) +} + +// Benchmark connection cleanup +func BenchmarkCleanup(b *testing.B) { + b.Run("TCPCleanup", func(b *testing.B) { + tracker := NewTCPTracker(100 * time.Millisecond) // Short timeout for testing + defer tracker.Close() + + // Pre-populate with expired connections + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + for i := 0; i < 10000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn) + } + + // Wait for connections to expire + time.Sleep(200 * time.Millisecond) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.cleanup() + } + }) +} diff --git a/client/firewall/uspfilter/conntrack/udp.go b/client/firewall/uspfilter/conntrack/udp.go index 94ab9e15273..4d55ec0df83 100644 --- a/client/firewall/uspfilter/conntrack/udp.go +++ b/client/firewall/uspfilter/conntrack/udp.go @@ -2,7 +2,6 @@ package conntrack import ( "net" - "slices" "sync" "time" ) @@ -14,22 +13,9 @@ const ( UDPCleanupInterval = 15 * time.Second ) -type ConnKey struct { - // Supports both IPv4 and IPv6 - SrcIP [16]byte - DstIP [16]byte - SrcPort uint16 - DstPort uint16 -} - // UDPConnTrack represents a UDP connection state type UDPConnTrack struct { - SourceIP net.IP - DestIP net.IP - SourcePort uint16 - DestPort uint16 - LastSeen time.Time - established bool + BaseConnTrack } // UDPTracker manages UDP connection states @@ -39,6 +25,7 @@ type UDPTracker struct { cleanupTicker *time.Ticker mutex sync.RWMutex done chan struct{} + ipPool *PreallocatedIPs } // NewUDPTracker creates a new UDP connection tracker @@ -52,6 +39,7 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker { timeout: timeout, cleanupTicker: time.NewTicker(UDPCleanupInterval), done: make(chan struct{}), + ipPool: NewPreallocatedIPs(), } go tracker.cleanupRoutine() @@ -60,49 +48,55 @@ func NewUDPTracker(timeout time.Duration) *UDPTracker { // TrackOutbound records an outbound UDP connection func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) { - key := makeKey(srcIP, srcPort, dstIP, dstPort) + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) + now := time.Now().UnixNano() t.mutex.Lock() - defer t.mutex.Unlock() - - t.connections[key] = &UDPConnTrack{ - SourceIP: slices.Clone(srcIP), - DestIP: slices.Clone(dstIP), - SourcePort: srcPort, - DestPort: dstPort, - LastSeen: time.Now(), - established: true, + conn, exists := t.connections[key] + if !exists { + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, srcIP) + copyIP(dstIPCopy, dstIP) + + conn = &UDPConnTrack{ + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: srcPort, + DestPort: dstPort, + }, + } + conn.lastSeen.Store(now) + conn.established.Store(true) + t.connections[key] = conn } + t.mutex.Unlock() + + conn.lastSeen.Store(now) } // IsValidInbound checks if an inbound packet matches a tracked connection func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool { - t.mutex.RLock() - defer t.mutex.RUnlock() + key := makeConnKey(dstIP, srcIP, dstPort, srcPort) - key := makeKey(dstIP, dstPort, srcIP, srcPort) + t.mutex.RLock() conn, exists := t.connections[key] + t.mutex.RUnlock() + if !exists { return false } - // Check if connection is still valid - if time.Since(conn.LastSeen) > t.timeout { + if conn.timeoutExceeded(t.timeout) { return false } - if conn.established && - conn.DestIP.Equal(srcIP) && - conn.SourceIP.Equal(dstIP) && + return conn.IsEstablished() && + ValidateIPs(makeIPAddr(srcIP), conn.DestIP) && + ValidateIPs(makeIPAddr(dstIP), conn.SourceIP) && conn.DestPort == srcPort && - conn.SourcePort == dstPort { - - conn.LastSeen = time.Now() - - return true - } - - return false + conn.SourcePort == dstPort } // cleanupRoutine periodically removes stale connections @@ -121,9 +115,10 @@ func (t *UDPTracker) cleanup() { t.mutex.Lock() defer t.mutex.Unlock() - now := time.Now() for key, conn := range t.connections { - if now.Sub(conn.LastSeen) > t.timeout { + if conn.timeoutExceeded(t.timeout) { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) delete(t.connections, key) } } @@ -133,6 +128,14 @@ func (t *UDPTracker) cleanup() { func (t *UDPTracker) Close() { t.cleanupTicker.Stop() close(t.done) + + t.mutex.Lock() + for _, conn := range t.connections { + t.ipPool.Put(conn.SourceIP) + t.ipPool.Put(conn.DestIP) + } + t.connections = nil + t.mutex.Unlock() } // GetConnection safely retrieves a connection state @@ -140,20 +143,28 @@ func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, d t.mutex.RLock() defer t.mutex.RUnlock() - key := makeKey(srcIP, srcPort, dstIP, dstPort) + key := makeConnKey(srcIP, dstIP, srcPort, dstPort) conn, exists := t.connections[key] if !exists { return nil, false } + // Create a copy with new IP allocations + srcIPCopy := t.ipPool.Get() + dstIPCopy := t.ipPool.Get() + copyIP(srcIPCopy, conn.SourceIP) + copyIP(dstIPCopy, conn.DestIP) + connCopy := &UDPConnTrack{ - SourceIP: slices.Clone(conn.SourceIP), - DestIP: slices.Clone(conn.DestIP), - SourcePort: conn.SourcePort, - DestPort: conn.DestPort, - LastSeen: conn.LastSeen, - established: conn.established, + BaseConnTrack: BaseConnTrack{ + SourceIP: srcIPCopy, + DestIP: dstIPCopy, + SourcePort: conn.SourcePort, + DestPort: conn.DestPort, + }, } + connCopy.lastSeen.Store(conn.lastSeen.Load()) + connCopy.established.Store(conn.IsEstablished()) return connCopy, true } diff --git a/client/firewall/uspfilter/conntrack/udp_test.go b/client/firewall/uspfilter/conntrack/udp_test.go index 938dc18ea59..1a8afc21a01 100644 --- a/client/firewall/uspfilter/conntrack/udp_test.go +++ b/client/firewall/uspfilter/conntrack/udp_test.go @@ -58,8 +58,8 @@ func TestUDPTracker_TrackOutbound(t *testing.T) { assert.True(t, conn.DestIP.Equal(dstIP)) assert.Equal(t, srcPort, conn.SourcePort) assert.Equal(t, dstPort, conn.DestPort) - assert.True(t, conn.established) - assert.WithinDuration(t, time.Now(), conn.LastSeen, 1*time.Second) + assert.True(t, conn.IsEstablished()) + assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second) } func TestUDPTracker_IsValidInbound(t *testing.T) { @@ -232,3 +232,36 @@ func TestUDPTracker_Close(t *testing.T) { _, ok := <-tracker.done assert.False(t, ok, "done channel should be closed") } + +func BenchmarkUDPTracker(b *testing.B) { + b.Run("TrackOutbound", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80) + } + }) + + b.Run("IsValidInbound", func(b *testing.B) { + tracker := NewUDPTracker(DefaultUDPTimeout) + defer tracker.Close() + + srcIP := net.ParseIP("192.168.1.1") + dstIP := net.ParseIP("192.168.1.2") + + // Pre-populate some connections + for i := 0; i < 1000; i++ { + tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000)) + } + }) +} diff --git a/client/firewall/uspfilter/uspfilter_test.go b/client/firewall/uspfilter/uspfilter_test.go index 23f575843a3..ea78a013abc 100644 --- a/client/firewall/uspfilter/uspfilter_test.go +++ b/client/firewall/uspfilter/uspfilter_test.go @@ -630,7 +630,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) { if cp.shouldAllow { conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort) require.True(t, exists, "Connection should still exist during valid window") - require.True(t, time.Since(conn.LastSeen) < manager.udpTracker.Timeout(), + require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(), "LastSeen should be updated for valid responses") } }