Skip to content

Commit

Permalink
Reduce complexity
Browse files Browse the repository at this point in the history
  • Loading branch information
lixmal committed Jan 3, 2025
1 parent d711172 commit 9490e90
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 45 deletions.
2 changes: 2 additions & 0 deletions client/firewall/uspfilter/forwarder/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
}

func (e *endpoint) Wait() {
// not required
}

func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}

func (e *endpoint) AddHeader(*stack.PacketBuffer) {
// not required
}

func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
Expand Down
27 changes: 13 additions & 14 deletions client/firewall/uspfilter/forwarder/icmp.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
// For Echo Requests, send and handle response
switch icmpHdr.Type() {
case header.ICMPv4Echo:
_, err = conn.WriteTo(payload, dst)
if err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}

f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())

return f.handleEchoResponse(conn, id)
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
case header.ICMPv4EchoReply:
// dont process our own replies
return true
Expand All @@ -70,10 +61,18 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
return true
}

func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) bool {
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
if _, err := conn.WriteTo(payload, dst); err != nil {
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
return true
}

f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
id, icmpHdr.Type(), icmpHdr.Code())

if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
return false
return true
}

response := make([]byte, f.endpoint.mtu)
Expand All @@ -82,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn
if !isTimeout(err) {
f.logger.Error("Failed to read ICMP response: %v", err)
}
return false
return true
}

ipHdr := make([]byte, header.IPv4MinimumSize)
Expand All @@ -102,7 +101,7 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn

if err := f.InjectIncomingPacket(fullPacket); err != nil {
f.logger.Error("Failed to inject ICMP response: %v", err)
return false
return true
}

f.logger.Trace("Forwarded ICMP echo reply for %v", id)
Expand Down
62 changes: 31 additions & 31 deletions client/firewall/uspfilter/uspfilter.go
Original file line number Diff line number Diff line change
Expand Up @@ -732,47 +732,47 @@ func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, sr
srcAddr, _ := netip.AddrFromSlice(srcIP)
dstAddr, _ := netip.AddrFromSlice(dstIP)

// Default deny if no rules match
matched := false

for _, rule := range m.routeRules {
// Check destination
if !rule.destination.Contains(dstAddr) {
continue
}

// Check if source matches any source prefix
sourceMatched := false
for _, src := range rule.sources {
if src.Contains(srcAddr) {
sourceMatched = true
break
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
matched = true
if rule.action == firewall.ActionDrop {
return false
}
}
if !sourceMatched {
continue
}
}

// Check protocol
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
continue
}
return matched
}

// Check ports if specified
if rule.srcPort != nil && rule.srcPort.Values[0] != int(srcPort) {
continue
}
if rule.dstPort != nil && rule.dstPort.Values[0] != int(dstPort) {
continue
}
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
if !rule.destination.Contains(dstAddr) {
return false
}

matched = true
if rule.action == firewall.ActionDrop {
return false
sourceMatched := false
for _, src := range rule.sources {
if src.Contains(srcAddr) {
sourceMatched = true
break
}
}
if !sourceMatched {
return false
}

return matched
if rule.proto != firewall.ProtocolALL && rule.proto != proto {
return false
}

if rule.srcPort != nil && rule.srcPort.Values[0] != int(srcPort) {
return false
}
if rule.dstPort != nil && rule.dstPort.Values[0] != int(dstPort) {
return false
}

return true
}

// SetNetwork of the wireguard interface to which filtering applied
Expand Down

0 comments on commit 9490e90

Please sign in to comment.