diff --git a/client/firewall/uspfilter/forwarder/endpoint.go b/client/firewall/uspfilter/forwarder/endpoint.go index c234ca24136..7ff6101d050 100644 --- a/client/firewall/uspfilter/forwarder/endpoint.go +++ b/client/firewall/uspfilter/forwarder/endpoint.go @@ -67,6 +67,7 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) } func (e *endpoint) Wait() { + // not required } func (e *endpoint) ARPHardwareType() header.ARPHardwareType { @@ -74,6 +75,7 @@ func (e *endpoint) ARPHardwareType() header.ARPHardwareType { } func (e *endpoint) AddHeader(*stack.PacketBuffer) { + // not required } func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool { diff --git a/client/firewall/uspfilter/forwarder/icmp.go b/client/firewall/uspfilter/forwarder/icmp.go index e04464dd97f..14cdc37be85 100644 --- a/client/firewall/uspfilter/forwarder/icmp.go +++ b/client/firewall/uspfilter/forwarder/icmp.go @@ -41,16 +41,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf // For Echo Requests, send and handle response switch icmpHdr.Type() { case header.ICMPv4Echo: - _, err = conn.WriteTo(payload, dst) - if err != nil { - f.logger.Error("Failed to write ICMP packet for %v: %v", id, err) - return true - } - - f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v", - id, icmpHdr.Type(), icmpHdr.Code()) - - return f.handleEchoResponse(conn, id) + return f.handleEchoResponse(icmpHdr, payload, dst, conn, id) case header.ICMPv4EchoReply: // dont process our own replies return true @@ -70,10 +61,18 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf return true } -func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEndpointID) bool { +func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool { + if _, err := conn.WriteTo(payload, dst); err != nil { + f.logger.Error("Failed to write ICMP packet for %v: %v", id, err) + return true + } + + f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v", + id, icmpHdr.Type(), icmpHdr.Code()) + if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil { f.logger.Error("Failed to set read deadline for ICMP response: %v", err) - return false + return true } response := make([]byte, f.endpoint.mtu) @@ -82,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn if !isTimeout(err) { f.logger.Error("Failed to read ICMP response: %v", err) } - return false + return true } ipHdr := make([]byte, header.IPv4MinimumSize) @@ -102,7 +101,7 @@ func (f *Forwarder) handleEchoResponse(conn net.PacketConn, id stack.TransportEn if err := f.InjectIncomingPacket(fullPacket); err != nil { f.logger.Error("Failed to inject ICMP response: %v", err) - return false + return true } f.logger.Trace("Forwarded ICMP echo reply for %v", id) diff --git a/client/firewall/uspfilter/uspfilter.go b/client/firewall/uspfilter/uspfilter.go index 11ef68a4d3a..e869b0f4b84 100644 --- a/client/firewall/uspfilter/uspfilter.go +++ b/client/firewall/uspfilter/uspfilter.go @@ -732,47 +732,47 @@ func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, sr srcAddr, _ := netip.AddrFromSlice(srcIP) dstAddr, _ := netip.AddrFromSlice(dstIP) - // Default deny if no rules match matched := false - for _, rule := range m.routeRules { - // Check destination - if !rule.destination.Contains(dstAddr) { - continue - } - - // Check if source matches any source prefix - sourceMatched := false - for _, src := range rule.sources { - if src.Contains(srcAddr) { - sourceMatched = true - break + if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) { + matched = true + if rule.action == firewall.ActionDrop { + return false } } - if !sourceMatched { - continue - } + } - // Check protocol - if rule.proto != firewall.ProtocolALL && rule.proto != proto { - continue - } + return matched +} - // Check ports if specified - if rule.srcPort != nil && rule.srcPort.Values[0] != int(srcPort) { - continue - } - if rule.dstPort != nil && rule.dstPort.Values[0] != int(dstPort) { - continue - } +func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool { + if !rule.destination.Contains(dstAddr) { + return false + } - matched = true - if rule.action == firewall.ActionDrop { - return false + sourceMatched := false + for _, src := range rule.sources { + if src.Contains(srcAddr) { + sourceMatched = true + break } } + if !sourceMatched { + return false + } - return matched + if rule.proto != firewall.ProtocolALL && rule.proto != proto { + return false + } + + if rule.srcPort != nil && rule.srcPort.Values[0] != int(srcPort) { + return false + } + if rule.dstPort != nil && rule.dstPort.Values[0] != int(dstPort) { + return false + } + + return true } // SetNetwork of the wireguard interface to which filtering applied