From b07f21e13b7cd7e292f549d2a0073d266d9bea0e Mon Sep 17 00:00:00 2001 From: cavivie Date: Wed, 10 Apr 2024 18:08:07 +0800 Subject: [PATCH] chore!: refactor ip filters and also fix some docs --- README.md | 4 +-- examples/forward.rs | 6 ++-- src/filter.rs | 83 ++++++++++++++++++++------------------------- src/lib.rs | 5 ++- src/stack.rs | 61 +++++++++++++-------------------- 5 files changed, 69 insertions(+), 90 deletions(-) diff --git a/README.md b/README.md index f1bca7f..5825d46 100644 --- a/README.md +++ b/README.md @@ -73,5 +73,5 @@ at your option. ### Contribution Unless you explicitly state otherwise, any contribution intentionally submitted -for inclusion in cc-rs by you, as defined in the Apache-2.0 license, shall be -dual licensed as above, without any additional terms or conditions. +for inclusion in netstack-smoltcp by you, as defined in the Apache-2.0 license, +shall be dual licensed as above, without any additional terms or conditions. diff --git a/examples/forward.rs b/examples/forward.rs index 5ebd9f5..4097ac3 100644 --- a/examples/forward.rs +++ b/examples/forward.rs @@ -118,8 +118,8 @@ async fn main_exec(opt: Opt) { let mut builder = StackBuilder::default(); if let Some(device_broadcast) = get_device_broadcast(&device) { builder = builder - .add_src_v4_filter(move |v4| *v4 != device_broadcast) - .add_dst_v4_filter(move |v4| *v4 != device_broadcast); + // .add_ip_filter(Box::new(move |src, dst| *src != device_broadcast && *dst != device_broadcast)); + .add_ip_filter_fn(move |src, dst| *src != device_broadcast && *dst != device_broadcast); } let (runner, udp_socket, tcp_listener, stack) = builder.build(); @@ -147,7 +147,7 @@ async fn main_exec(opt: Opt) { futs.push(tokio_spawn!(async move { while let Some(pkt) = tun_stream.next().await { if let Ok(pkt) = pkt { - match stack_sink.send(pkt.to_vec()).await { + match stack_sink.send(pkt).await { Ok(_) => {} Err(e) => warn!("failed to send packet to stack, err: {:?}", e), }; diff --git a/src/filter.rs b/src/filter.rs index d47c144..0e475bb 100644 --- a/src/filter.rs +++ b/src/filter.rs @@ -1,64 +1,55 @@ -use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; +use std::net::IpAddr; -pub type Ipv4Filter<'a> = Box bool + Send + Sync + 'a>; -pub type Ipv6Filter<'a> = Box bool + Send + Sync + 'a>; +pub type IpFilter<'a> = Box bool + Send + Sync + 'a>; -pub struct Filters<'a> { - ipv4_filters: Vec>, - ipv6_filters: Vec>, +pub struct IpFilters<'a> { + filters: Vec>, } -impl<'a> Default for Filters<'a> { +impl<'a> Default for IpFilters<'a> { fn default() -> Self { - Self::new( - vec![ - Box::new(|v4| !v4.is_broadcast()), - Box::new(|v4| !v4.is_multicast()), - Box::new(|v4| !v4.is_unspecified()), - ], - vec![ - Box::new(|v6| !v6.is_multicast()), - Box::new(|v6| !v6.is_unspecified()), - ], - ) + Self::new(vec![]) } } -impl<'a> Filters<'a> { - pub fn new(ipv4_filters: Vec>, ipv6_filters: Vec>) -> Self { - Self { - ipv4_filters, - ipv6_filters, - } +impl<'a> IpFilters<'a> { + pub fn new(filters: Vec>) -> Self { + Self { filters } } - pub fn is_allowed(&self, addr: &IpAddr) -> bool { - match addr { - IpAddr::V4(v4) => { - for filter in &self.ipv4_filters { - if !filter(v4) { - return false; - } - } - } - IpAddr::V6(v6) => { - for filter in &self.ipv6_filters { - if !filter(v6) { - return false; + pub fn with_non_broadcast() -> Self { + Self::new(vec![Box::new(|src, dst| { + macro_rules! non_broadcast { + ($addr:expr) => { + match $addr { + IpAddr::V4(v4) => { + !(v4.is_broadcast() || v4.is_multicast() || v4.is_multicast()) + } + IpAddr::V6(v6) => !(v6.is_multicast() || v6.is_unspecified()), } - } + }; } - } - true + non_broadcast!(src) && non_broadcast!(dst) + })]) } - #[allow(unused)] - pub fn add_v4(&mut self, filter: Ipv4Filter<'a>) { - self.ipv4_filters.push(filter); + pub fn add(&mut self, filter: IpFilter<'a>) { + self.filters.push(filter); } - #[allow(unused)] - pub fn add_v6(&mut self, filter: Ipv6Filter<'a>) { - self.ipv6_filters.push(filter); + pub fn add_fn(&mut self, filter: F) + where + F: Fn(&IpAddr, &IpAddr) -> bool + Send + Sync + 'a, + { + self.filters.push(Box::new(filter)); + } + + pub fn is_allowed(&self, src: &IpAddr, dst: &IpAddr) -> bool { + for filter in &self.filters { + if !filter(src, dst) { + return false; + } + } + true } } diff --git a/src/lib.rs b/src/lib.rs index 24bb7c3..ab72413 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,7 +7,7 @@ mod packet; pub use packet::AnyIpPktFrame; mod filter; -pub use filter::{Filters, Ipv4Filter, Ipv6Filter}; +pub use filter::{IpFilter, IpFilters}; pub mod udp; pub use udp::UdpSocket; @@ -17,3 +17,6 @@ pub use tcp::{TcpListener, TcpStream}; pub mod stack; pub use stack::{Stack, StackBuilder}; + +/// Re-export +pub use smoltcp; diff --git a/src/stack.rs b/src/stack.rs index fc720ea..4ffaf13 100644 --- a/src/stack.rs +++ b/src/stack.rs @@ -1,6 +1,6 @@ use std::{ io, - net::{Ipv4Addr, Ipv6Addr}, + net::IpAddr, pin::Pin, task::{Context, Poll}, }; @@ -10,14 +10,19 @@ use smoltcp::wire::IpProtocol; use tokio::sync::mpsc::{channel, Receiver, Sender}; use tracing::{debug, trace}; -use super::{packet::IpPacket, AnyIpPktFrame, Filters, Runner, TcpListener, UdpSocket}; +use crate::{ + filter::{IpFilter, IpFilters}, + packet::{AnyIpPktFrame, IpPacket}, + runner::Runner, + tcp::TcpListener, + udp::UdpSocket, +}; pub struct StackBuilder { stack_buffer_size: usize, udp_buffer_size: usize, tcp_buffer_size: usize, - src_filters: Filters<'static>, - dst_filters: Filters<'static>, + ip_filters: IpFilters<'static>, } impl Default for StackBuilder { @@ -26,8 +31,7 @@ impl Default for StackBuilder { stack_buffer_size: 1024, udp_buffer_size: 256, tcp_buffer_size: 512, - src_filters: Default::default(), - dst_filters: Default::default(), + ip_filters: IpFilters::with_non_broadcast(), } } } @@ -49,35 +53,21 @@ impl StackBuilder { self } - pub fn add_src_v4_filter(mut self, filter: F) -> Self - where - F: Fn(&Ipv4Addr) -> bool + Send + Sync + 'static, - { - self.src_filters.add_v4(Box::new(filter)); + pub fn set_ip_filters(mut self, filters: IpFilters<'static>) -> Self { + self.ip_filters = filters; self } - pub fn add_dst_v4_filter(mut self, filter: F) -> Self - where - F: Fn(&Ipv4Addr) -> bool + Send + Sync + 'static, - { - self.dst_filters.add_v4(Box::new(filter)); + pub fn add_ip_filter(mut self, filter: IpFilter<'static>) -> Self { + self.ip_filters.add(filter); self } - pub fn add_src_v6_filter(mut self, filter: F) -> Self + pub fn add_ip_filter_fn(mut self, filter: F) -> Self where - F: Fn(&Ipv6Addr) -> bool + Send + Sync + 'static, + F: Fn(&IpAddr, &IpAddr) -> bool + Send + Sync + 'static, { - self.src_filters.add_v6(Box::new(filter)); - self - } - - pub fn add_dst_v6_filter(mut self, filter: F) -> Self - where - F: Fn(&Ipv6Addr) -> bool + Send + Sync + 'static, - { - self.dst_filters.add_v6(Box::new(filter)); + self.ip_filters.add_fn(filter); self } @@ -89,8 +79,7 @@ impl StackBuilder { let udp_socket = UdpSocket::new(udp_rx, stack_tx.clone()); let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx); let stack = Stack { - src_filters: self.src_filters, - dst_filters: self.dst_filters, + ip_filters: self.ip_filters, sink_buf: None, stack_rx, udp_tx, @@ -114,8 +103,7 @@ impl StackBuilder { } pub struct Stack { - src_filters: Filters<'static>, - dst_filters: Filters<'static>, + ip_filters: IpFilters<'static>, sink_buf: Option, udp_tx: Sender, tcp_tx: Sender, @@ -174,16 +162,13 @@ impl Sink for Stack { let src_ip = packet.src_addr(); let dst_ip = packet.dst_addr(); - let src_allowed = self.src_filters.is_allowed(&src_ip); - let dst_allowed = self.dst_filters.is_allowed(&dst_ip); - - if !(src_allowed && dst_allowed) { + let addr_allowed = self.ip_filters.is_allowed(&src_ip, &dst_ip); + if !addr_allowed { trace!( - "IP packet {} (allowed? {}) -> {} (allowed? {}) throwing away", + "IP packet {} -> {} (allowed? {}) throwing away", src_ip, - src_allowed, dst_ip, - dst_allowed, + addr_allowed, ); return Poll::Ready(Ok(())); }