Skip to content

Commit

Permalink
Add TCP tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal committed Dec 20, 2024
1 parent dadf64e commit fa38d8e
Show file tree
Hide file tree
Showing 4 changed files with 459 additions and 75 deletions.
5 changes: 5 additions & 0 deletions client/firewall/uspfilter/allow_netbird.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ func (m *Manager) Reset(stateManager *statemanager.Manager) error {
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}

if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}

if m.nativeFirewall != nil {
return m.nativeFirewall.Reset(stateManager)
}
Expand Down
5 changes: 5 additions & 0 deletions client/firewall/uspfilter/allow_netbird_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ func (m *Manager) Reset(*statemanager.Manager) error {
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout)
}

if m.tcpTracker != nil {
m.tcpTracker.Close()
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout)
}

if !isWindowsFirewallReachable() {
return nil
}
Expand Down
310 changes: 310 additions & 0 deletions client/firewall/uspfilter/conntrack/tcp.go
Original file line number Diff line number Diff line change
@@ -1 +1,311 @@
package conntrack

// TODO: Send RST packets for invalid/timed-out connections

import (
"net"
"slices"
"sync"
"time"
)

const (
// MSL (Maximum Segment Lifetime) is typically 2 minutes
MSL = 2 * time.Minute
// TimeWaitTimeout (TIME-WAIT) should last 2*MSL
TimeWaitTimeout = 2 * MSL
)

const (
TCPSyn uint8 = 0x02
TCPAck uint8 = 0x10
TCPFin uint8 = 0x01
TCPRst uint8 = 0x04
TCPPush uint8 = 0x08
TCPUrg uint8 = 0x20
)

const (
// DefaultTCPTimeout is the default timeout for established TCP connections
DefaultTCPTimeout = 3 * time.Hour
// TCPHandshakeTimeout is timeout for TCP handshake completion
TCPHandshakeTimeout = 60 * time.Second
// TCPCleanupInterval is how often we check for stale connections
TCPCleanupInterval = 5 * time.Minute
)

// TCPState represents the state of a TCP connection
type TCPState int

const (
TCPStateNew TCPState = iota
TCPStateSynSent
TCPStateSynReceived
TCPStateEstablished
TCPStateFinWait1
TCPStateFinWait2
TCPStateClosing
TCPStateTimeWait
TCPStateCloseWait
TCPStateLastAck
TCPStateClosed
)

// TCPConnKey uniquely identifies a TCP connection
type TCPConnKey struct {
SrcIP [16]byte
DstIP [16]byte
SrcPort uint16
DstPort uint16
}

// TCPConnTrack represents a TCP connection state
type TCPConnTrack struct {
SourceIP net.IP
DestIP net.IP
SourcePort uint16
DestPort uint16
State TCPState
LastSeen time.Time
established bool
}

// TCPTracker manages TCP connection states
type TCPTracker struct {
connections map[TCPConnKey]*TCPConnTrack
mutex sync.RWMutex
cleanupTicker *time.Ticker
done chan struct{}
timeout time.Duration
}

// NewTCPTracker creates a new TCP connection tracker
func NewTCPTracker(timeout time.Duration) *TCPTracker {
tracker := &TCPTracker{
connections: make(map[TCPConnKey]*TCPConnTrack),
cleanupTicker: time.NewTicker(TCPCleanupInterval),
done: make(chan struct{}),
timeout: timeout,
}

go tracker.cleanupRoutine()
return tracker
}

// TrackOutbound processes an outbound TCP packet and updates connection state
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
t.mutex.Lock()
defer t.mutex.Unlock()

key := makeTCPKey(srcIP, dstIP, srcPort, dstPort)
now := time.Now()

conn, exists := t.connections[key]
if !exists {
conn = &TCPConnTrack{
SourceIP: slices.Clone(srcIP),
DestIP: slices.Clone(dstIP),
SourcePort: srcPort,
DestPort: dstPort,
State: TCPStateNew,
LastSeen: now,
established: false,
}
t.connections[key] = conn
}

// Update connection state based on TCP flags
t.updateState(conn, flags, true)
conn.LastSeen = now
}

// IsValidInbound checks if an inbound TCP packet matches a tracked connection
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
t.mutex.Lock()
defer t.mutex.Unlock()

// For SYN packets (new connection attempts), always allow
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
key := makeTCPKey(dstIP, srcIP, dstPort, srcPort)
t.connections[key] = &TCPConnTrack{
SourceIP: slices.Clone(dstIP),
DestIP: slices.Clone(srcIP),
SourcePort: dstPort,
DestPort: srcPort,
State: TCPStateSynReceived,
LastSeen: time.Now(),
established: false,
}
return true
}

key := makeTCPKey(dstIP, srcIP, dstPort, srcPort)
conn, exists := t.connections[key]
if !exists {
return false
}

// Update state and check validity
if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.established = false
return true
}

// Special handling for FIN state
if conn.State == TCPStateFinWait1 || conn.State == TCPStateFinWait2 {
t.updateState(conn, flags, false)
conn.LastSeen = time.Now()
return true
}

t.updateState(conn, flags, false)
conn.LastSeen = time.Now()

// Allow if established or in a valid state for the flags
return conn.established || t.isValidStateForFlags(conn.State, flags)
}

// updateState updates the TCP connection state based on flags
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
// Handle RST flag specially - it always causes transition to closed
if flags&TCPRst != 0 {
conn.State = TCPStateClosed
conn.established = false
return
}

switch conn.State {
case TCPStateNew:
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
conn.State = TCPStateSynSent
}

case TCPStateSynSent:
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
if isOutbound {
conn.State = TCPStateSynReceived
} else {
// Simultaneous open
conn.State = TCPStateEstablished
conn.established = true
}
}

case TCPStateSynReceived:
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
conn.State = TCPStateEstablished
conn.established = true
}

case TCPStateEstablished:
if flags&TCPFin != 0 {
if isOutbound {
conn.State = TCPStateFinWait1
} else {
conn.State = TCPStateCloseWait
}
conn.established = false
}

case TCPStateFinWait1:
if flags&TCPFin != 0 && flags&TCPAck != 0 {

Check failure on line 212 in client/firewall/uspfilter/conntrack/tcp.go

View workflow job for this annotation

GitHub Actions / lint (ubuntu-latest)

ifElseChain: rewrite if-else to switch statement (gocritic)

Check failure on line 212 in client/firewall/uspfilter/conntrack/tcp.go

View workflow job for this annotation

GitHub Actions / lint (macos-latest)

ifElseChain: rewrite if-else to switch statement (gocritic)
// Simultaneous close
conn.State = TCPStateClosing
} else if flags&TCPFin != 0 {
conn.State = TCPStateFinWait2
} else if flags&TCPAck != 0 {
conn.State = TCPStateFinWait2
}

case TCPStateFinWait2:
if flags&TCPFin != 0 {
conn.State = TCPStateTimeWait
}

case TCPStateClosing:
if flags&TCPAck != 0 {
conn.State = TCPStateTimeWait
}

case TCPStateCloseWait:
if flags&TCPFin != 0 {
conn.State = TCPStateLastAck
}

case TCPStateLastAck:
if flags&TCPAck != 0 {
conn.State = TCPStateClosed
}

case TCPStateTimeWait:
// Stay in TIME-WAIT for 2MSL before transitioning to closed
// This is handled by the cleanup routine
}
}

// isValidStateForFlags checks if the TCP flags are valid for the current connection state
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
switch state {
case TCPStateSynSent:
return flags&TCPSyn != 0 && flags&TCPAck != 0
case TCPStateSynReceived:
return flags&TCPAck != 0
case TCPStateEstablished:
return true // Allow all flags in established state
case TCPStateFinWait1, TCPStateFinWait2:
return flags&TCPFin != 0 || flags&TCPAck != 0
}
return false
}

func (t *TCPTracker) cleanupRoutine() {
for {
select {
case <-t.cleanupTicker.C:
t.cleanup()
case <-t.done:
return
}
}
}

func (t *TCPTracker) cleanup() {
t.mutex.Lock()
defer t.mutex.Unlock()

now := time.Now()
for key, conn := range t.connections {
var timeout time.Duration
switch {
case conn.State == TCPStateTimeWait:
timeout = TimeWaitTimeout
case conn.established:
timeout = t.timeout
default:
timeout = TCPHandshakeTimeout
}

if now.Sub(conn.LastSeen) > timeout {
delete(t.connections, key)
}
}
}

// Close stops the cleanup routine and releases resources
func (t *TCPTracker) Close() {
t.cleanupTicker.Stop()
close(t.done)
}

func makeTCPKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) TCPConnKey {
var srcAddr, dstAddr [16]byte
copy(srcAddr[:], srcIP.To16())
copy(dstAddr[:], dstIP.To16())
return TCPConnKey{
SrcIP: srcAddr,
DstIP: dstAddr,
SrcPort: srcPort,
DstPort: dstPort,
}
}
Loading

0 comments on commit fa38d8e

Please sign in to comment.