Skip to content

Commit

Permalink
Fix races, improve performance and add benchmarks
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal committed Dec 22, 2024
1 parent 1a9a82b commit 3c158d4
Show file tree
Hide file tree
Showing 9 changed files with 631 additions and 168 deletions.
137 changes: 137 additions & 0 deletions client/firewall/uspfilter/conntrack/common.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// common.go
package conntrack

import (
"net"
"sync"
"sync/atomic"
"time"
)

// BaseConnTrack provides common fields and locking for all connection types
type BaseConnTrack struct {
sync.RWMutex
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
lastSeen atomic.Int64 // Unix nano for atomic access
established atomic.Bool
}

// these small methods will be inlined by the compiler

// UpdateLastSeen safely updates the last seen timestamp
func (b *BaseConnTrack) UpdateLastSeen() {
b.lastSeen.Store(time.Now().UnixNano())
}

// IsEstablished safely checks if connection is established
func (b *BaseConnTrack) IsEstablished() bool {
return b.established.Load()
}

// SetEstablished safely sets the established state
func (b *BaseConnTrack) SetEstablished(state bool) {
b.established.Store(state)
}

// GetLastSeen safely gets the last seen timestamp
func (b *BaseConnTrack) GetLastSeen() time.Time {
return time.Unix(0, b.lastSeen.Load())
}

// timeoutExceeded checks if the connection has exceeded the given timeout
func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
lastSeen := time.Unix(0, b.lastSeen.Load())
return time.Since(lastSeen) > timeout
}

// IPAddr is a fixed-size IP address to avoid allocations
type IPAddr [16]byte

// makeIPAddr creates an IPAddr from net.IP
func makeIPAddr(ip net.IP) (addr IPAddr) {
// Optimization: check for v4 first as it's more common
if ip4 := ip.To4(); ip4 != nil {
copy(addr[12:], ip4)
} else {
copy(addr[:], ip.To16())
}
return addr
}

// ConnKey uniquely identifies a connection
type ConnKey struct {
SrcIP IPAddr
DstIP IPAddr
SrcPort uint16
DstPort uint16
}

// makeConnKey creates a connection key
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
return ConnKey{
SrcIP: makeIPAddr(srcIP),
DstIP: makeIPAddr(dstIP),
SrcPort: srcPort,
DstPort: dstPort,
}
}

// ValidateIPs checks if IPs match without allocation
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
if ip4 := pktIP.To4(); ip4 != nil {
// Compare IPv4 addresses (last 4 bytes)
for i := 0; i < 4; i++ {
if connIP[12+i] != ip4[i] {
return false
}
}
return true
}
// Compare full IPv6 addresses
ip6 := pktIP.To16()
for i := 0; i < 16; i++ {
if connIP[i] != ip6[i] {
return false
}
}
return true
}

// PreallocatedIPs is a pool of IP byte slices to reduce allocations
type PreallocatedIPs struct {
sync.Pool
}

// NewPreallocatedIPs creates a new IP pool
func NewPreallocatedIPs() *PreallocatedIPs {
return &PreallocatedIPs{
Pool: sync.Pool{
New: func() interface{} {
return make(net.IP, 16)
},
},
}
}

// Get retrieves an IP from the pool
func (p *PreallocatedIPs) Get() net.IP {
return p.Pool.Get().(net.IP)
}

// Put returns an IP to the pool
func (p *PreallocatedIPs) Put(ip net.IP) {
p.Pool.Put(ip)
}

// copyIP copies an IP address efficiently
func copyIP(dst, src net.IP) {
if len(src) == 16 {
copy(dst, src)
} else {
// Handle IPv4
copy(dst[12:], src.To4())
}
}
115 changes: 115 additions & 0 deletions client/firewall/uspfilter/conntrack/common_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package conntrack

import (
"net"
"testing"
)

func BenchmarkIPOperations(b *testing.B) {
b.Run("makeIPAddr", func(b *testing.B) {
ip := net.ParseIP("192.168.1.1")
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = makeIPAddr(ip)
}
})

b.Run("ValidateIPs", func(b *testing.B) {
ip1 := net.ParseIP("192.168.1.1")
ip2 := net.ParseIP("192.168.1.1")
addr := makeIPAddr(ip1)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = ValidateIPs(addr, ip2)
}
})

b.Run("IPPool", func(b *testing.B) {
pool := NewPreallocatedIPs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
ip := pool.Get()
pool.Put(ip)
}
})

}
func BenchmarkAtomicOperations(b *testing.B) {
conn := &BaseConnTrack{}
b.Run("UpdateLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.UpdateLastSeen()
}
})

b.Run("IsEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.IsEstablished()
}
})

b.Run("SetEstablished", func(b *testing.B) {
for i := 0; i < b.N; i++ {
conn.SetEstablished(i%2 == 0)
}
})

b.Run("GetLastSeen", func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = conn.GetLastSeen()
}
})
}

// Memory pressure tests
func BenchmarkMemoryPressure(b *testing.B) {
b.Run("TCPHighLoad", func(b *testing.B) {
tracker := NewTCPTracker(DefaultTCPTimeout)
defer tracker.Close()

// Generate different IPs
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
}

b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)

// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
}
}
})

b.Run("UDPHighLoad", func(b *testing.B) {
tracker := NewUDPTracker(DefaultUDPTimeout)
defer tracker.Close()

// Generate different IPs
srcIPs := make([]net.IP, 100)
dstIPs := make([]net.IP, 100)
for i := 0; i < 100; i++ {
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
}

b.ResetTimer()
for i := 0; i < b.N; i++ {
srcIdx := i % len(srcIPs)
dstIdx := (i + 1) % len(dstIPs)
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)

// Simulate some valid inbound packets
if i%3 == 0 {
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
}
}
})
}
Loading

0 comments on commit 3c158d4

Please sign in to comment.