Skip to content

Commit

Permalink
chore!: refactor ip filters and also fix some docs
Browse files Browse the repository at this point in the history
  • Loading branch information
cavivie committed Apr 10, 2024
1 parent 1a92a79 commit b07f21e
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 90 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,5 @@ at your option.
### Contribution

Unless you explicitly state otherwise, any contribution intentionally submitted
for inclusion in cc-rs by you, as defined in the Apache-2.0 license, shall be
dual licensed as above, without any additional terms or conditions.
for inclusion in netstack-smoltcp by you, as defined in the Apache-2.0 license,
shall be dual licensed as above, without any additional terms or conditions.
6 changes: 3 additions & 3 deletions examples/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ async fn main_exec(opt: Opt) {
let mut builder = StackBuilder::default();
if let Some(device_broadcast) = get_device_broadcast(&device) {
builder = builder
.add_src_v4_filter(move |v4| *v4 != device_broadcast)
.add_dst_v4_filter(move |v4| *v4 != device_broadcast);
// .add_ip_filter(Box::new(move |src, dst| *src != device_broadcast && *dst != device_broadcast));
.add_ip_filter_fn(move |src, dst| *src != device_broadcast && *dst != device_broadcast);
}

let (runner, udp_socket, tcp_listener, stack) = builder.build();
Expand Down Expand Up @@ -147,7 +147,7 @@ async fn main_exec(opt: Opt) {
futs.push(tokio_spawn!(async move {
while let Some(pkt) = tun_stream.next().await {
if let Ok(pkt) = pkt {
match stack_sink.send(pkt.to_vec()).await {
match stack_sink.send(pkt).await {
Ok(_) => {}
Err(e) => warn!("failed to send packet to stack, err: {:?}", e),
};
Expand Down
83 changes: 37 additions & 46 deletions src/filter.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,55 @@
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::net::IpAddr;

pub type Ipv4Filter<'a> = Box<dyn Fn(&Ipv4Addr) -> bool + Send + Sync + 'a>;
pub type Ipv6Filter<'a> = Box<dyn Fn(&Ipv6Addr) -> bool + Send + Sync + 'a>;
pub type IpFilter<'a> = Box<dyn Fn(&IpAddr, &IpAddr) -> bool + Send + Sync + 'a>;

pub struct Filters<'a> {
ipv4_filters: Vec<Ipv4Filter<'a>>,
ipv6_filters: Vec<Ipv6Filter<'a>>,
pub struct IpFilters<'a> {
filters: Vec<IpFilter<'a>>,
}

impl<'a> Default for Filters<'a> {
impl<'a> Default for IpFilters<'a> {
fn default() -> Self {
Self::new(
vec![
Box::new(|v4| !v4.is_broadcast()),
Box::new(|v4| !v4.is_multicast()),
Box::new(|v4| !v4.is_unspecified()),
],
vec![
Box::new(|v6| !v6.is_multicast()),
Box::new(|v6| !v6.is_unspecified()),
],
)
Self::new(vec![])
}
}

impl<'a> Filters<'a> {
pub fn new(ipv4_filters: Vec<Ipv4Filter<'a>>, ipv6_filters: Vec<Ipv6Filter<'a>>) -> Self {
Self {
ipv4_filters,
ipv6_filters,
}
impl<'a> IpFilters<'a> {
pub fn new(filters: Vec<IpFilter<'a>>) -> Self {
Self { filters }
}

pub fn is_allowed(&self, addr: &IpAddr) -> bool {
match addr {
IpAddr::V4(v4) => {
for filter in &self.ipv4_filters {
if !filter(v4) {
return false;
}
}
}
IpAddr::V6(v6) => {
for filter in &self.ipv6_filters {
if !filter(v6) {
return false;
pub fn with_non_broadcast() -> Self {
Self::new(vec![Box::new(|src, dst| {
macro_rules! non_broadcast {
($addr:expr) => {
match $addr {
IpAddr::V4(v4) => {
!(v4.is_broadcast() || v4.is_multicast() || v4.is_multicast())
}
IpAddr::V6(v6) => !(v6.is_multicast() || v6.is_unspecified()),
}
}
};
}
}
true
non_broadcast!(src) && non_broadcast!(dst)
})])
}

#[allow(unused)]
pub fn add_v4(&mut self, filter: Ipv4Filter<'a>) {
self.ipv4_filters.push(filter);
pub fn add(&mut self, filter: IpFilter<'a>) {
self.filters.push(filter);
}

#[allow(unused)]
pub fn add_v6(&mut self, filter: Ipv6Filter<'a>) {
self.ipv6_filters.push(filter);
pub fn add_fn<F>(&mut self, filter: F)
where
F: Fn(&IpAddr, &IpAddr) -> bool + Send + Sync + 'a,
{
self.filters.push(Box::new(filter));
}

pub fn is_allowed(&self, src: &IpAddr, dst: &IpAddr) -> bool {
for filter in &self.filters {
if !filter(src, dst) {
return false;
}
}
true
}
}
5 changes: 4 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mod packet;
pub use packet::AnyIpPktFrame;

mod filter;
pub use filter::{Filters, Ipv4Filter, Ipv6Filter};
pub use filter::{IpFilter, IpFilters};

pub mod udp;
pub use udp::UdpSocket;
Expand All @@ -17,3 +17,6 @@ pub use tcp::{TcpListener, TcpStream};

pub mod stack;
pub use stack::{Stack, StackBuilder};

/// Re-export
pub use smoltcp;
61 changes: 23 additions & 38 deletions src/stack.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::{
io,
net::{Ipv4Addr, Ipv6Addr},
net::IpAddr,
pin::Pin,
task::{Context, Poll},
};
Expand All @@ -10,14 +10,19 @@ use smoltcp::wire::IpProtocol;
use tokio::sync::mpsc::{channel, Receiver, Sender};
use tracing::{debug, trace};

use super::{packet::IpPacket, AnyIpPktFrame, Filters, Runner, TcpListener, UdpSocket};
use crate::{
filter::{IpFilter, IpFilters},
packet::{AnyIpPktFrame, IpPacket},
runner::Runner,
tcp::TcpListener,
udp::UdpSocket,
};

pub struct StackBuilder {
stack_buffer_size: usize,
udp_buffer_size: usize,
tcp_buffer_size: usize,
src_filters: Filters<'static>,
dst_filters: Filters<'static>,
ip_filters: IpFilters<'static>,
}

impl Default for StackBuilder {
Expand All @@ -26,8 +31,7 @@ impl Default for StackBuilder {
stack_buffer_size: 1024,
udp_buffer_size: 256,
tcp_buffer_size: 512,
src_filters: Default::default(),
dst_filters: Default::default(),
ip_filters: IpFilters::with_non_broadcast(),
}
}
}
Expand All @@ -49,35 +53,21 @@ impl StackBuilder {
self
}

pub fn add_src_v4_filter<F>(mut self, filter: F) -> Self
where
F: Fn(&Ipv4Addr) -> bool + Send + Sync + 'static,
{
self.src_filters.add_v4(Box::new(filter));
pub fn set_ip_filters(mut self, filters: IpFilters<'static>) -> Self {
self.ip_filters = filters;
self
}

pub fn add_dst_v4_filter<F>(mut self, filter: F) -> Self
where
F: Fn(&Ipv4Addr) -> bool + Send + Sync + 'static,
{
self.dst_filters.add_v4(Box::new(filter));
pub fn add_ip_filter(mut self, filter: IpFilter<'static>) -> Self {
self.ip_filters.add(filter);
self
}

pub fn add_src_v6_filter<F>(mut self, filter: F) -> Self
pub fn add_ip_filter_fn<F>(mut self, filter: F) -> Self
where
F: Fn(&Ipv6Addr) -> bool + Send + Sync + 'static,
F: Fn(&IpAddr, &IpAddr) -> bool + Send + Sync + 'static,
{
self.src_filters.add_v6(Box::new(filter));
self
}

pub fn add_dst_v6_filter<F>(mut self, filter: F) -> Self
where
F: Fn(&Ipv6Addr) -> bool + Send + Sync + 'static,
{
self.dst_filters.add_v6(Box::new(filter));
self.ip_filters.add_fn(filter);
self
}

Expand All @@ -89,8 +79,7 @@ impl StackBuilder {
let udp_socket = UdpSocket::new(udp_rx, stack_tx.clone());
let (tcp_runner, tcp_listener) = TcpListener::new(tcp_rx, stack_tx);
let stack = Stack {
src_filters: self.src_filters,
dst_filters: self.dst_filters,
ip_filters: self.ip_filters,
sink_buf: None,
stack_rx,
udp_tx,
Expand All @@ -114,8 +103,7 @@ impl StackBuilder {
}

pub struct Stack {
src_filters: Filters<'static>,
dst_filters: Filters<'static>,
ip_filters: IpFilters<'static>,
sink_buf: Option<AnyIpPktFrame>,
udp_tx: Sender<AnyIpPktFrame>,
tcp_tx: Sender<AnyIpPktFrame>,
Expand Down Expand Up @@ -174,16 +162,13 @@ impl Sink<AnyIpPktFrame> for Stack {
let src_ip = packet.src_addr();
let dst_ip = packet.dst_addr();

let src_allowed = self.src_filters.is_allowed(&src_ip);
let dst_allowed = self.dst_filters.is_allowed(&dst_ip);

if !(src_allowed && dst_allowed) {
let addr_allowed = self.ip_filters.is_allowed(&src_ip, &dst_ip);
if !addr_allowed {
trace!(
"IP packet {} (allowed? {}) -> {} (allowed? {}) throwing away",
"IP packet {} -> {} (allowed? {}) throwing away",
src_ip,
src_allowed,
dst_ip,
dst_allowed,
addr_allowed,
);
return Poll::Ready(Ok(()));
}
Expand Down

0 comments on commit b07f21e

Please sign in to comment.