Skip to content

Commit cc10b7f

Browse files
committed
chore(data): restructure node impl
1 parent 174a392 commit cc10b7f

File tree

2 files changed

+77
-49
lines changed

2 files changed

+77
-49
lines changed

crates/server/src/node.rs

+65-43
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
use std::{fs, net::{SocketAddr, IpAddr, Ipv4Addr}};
2+
use std::collections::{HashMap, HashSet};
3+
use std::future::Future;
4+
use std::sync::{Arc, Mutex};
25

3-
use log::info;
6+
use log::{error, info};
47

58
use openssl::pkey::Private;
69
use openssl::rsa::Rsa;
@@ -13,59 +16,69 @@ use osp_protocol::OSPUrl;
1316
use crate::connection::inbound::{InboundConnection, TransferState};
1417
use crate::connection::outbound::OutboundConnection;
1518

19+
pub struct InitState {
20+
private_key: Option<Rsa<Private>>,
21+
}
1622

17-
pub struct OSProtocolNodeBuilder {
23+
pub struct ConnectionState {
24+
private_key: Rsa<Private>,
25+
}
26+
27+
#[derive(Clone)]
28+
pub struct OSProtocolNode<TState> {
1829
bind_addr: SocketAddr,
1930
hostname: String,
20-
private_key: Option<Rsa<Private>>,
31+
state: Arc<Mutex<TState>>,
2132
}
2233

23-
impl OSProtocolNodeBuilder {
24-
pub fn bind_to(mut self, addr: SocketAddr) -> Self {
34+
impl OSProtocolNode<InitState> {
35+
pub fn new() -> Self {
36+
OSProtocolNode::<InitState> {
37+
bind_addr: SocketAddr::new(IpAddr::from(Ipv4Addr::LOCALHOST), 57401),
38+
hostname: "".to_string(),
39+
state: Arc::new(Mutex::new(InitState {
40+
private_key: None,
41+
})),
42+
}
43+
}
44+
45+
pub fn set_addr(&mut self, addr: SocketAddr) {
2546
self.bind_addr = addr;
26-
self
2747
}
2848

29-
pub fn hostname(mut self, hostname: String) -> Self {
49+
pub fn set_hostname(&mut self, hostname: String) {
3050
self.hostname = hostname;
31-
self
3251
}
3352

34-
pub fn private_key_file(mut self, path: String) -> Self {
53+
pub fn set_private_key_file(&mut self, path: String) {
3554
let key_contents = fs::read_to_string(path.clone()).expect(format!("Unable to open private key file {}", path).as_str());
36-
self.private_key = Some(Rsa::private_key_from_pem(key_contents.as_bytes()).unwrap());
37-
self
55+
self.state.lock().unwrap().private_key = Some(Rsa::private_key_from_pem(key_contents.as_bytes()).unwrap());
3856
}
3957

40-
pub fn build(self) -> OSProtocolNode {
41-
OSProtocolNode {
42-
bind_addr: self.bind_addr,
43-
hostname: self.hostname,
44-
private_key: self.private_key.unwrap(),
58+
pub fn init(&mut self) -> OSProtocolNode<ConnectionState> {
59+
let bind_addr = self.bind_addr.clone();
60+
let hostname = self.hostname.clone();
61+
let private_key = self.state.lock().unwrap().private_key.clone().unwrap();
62+
OSProtocolNode::<ConnectionState> {
63+
bind_addr,
64+
hostname,
65+
state: Arc::new(Mutex::new(ConnectionState {
66+
private_key,
67+
})),
4568
}
4669
}
4770
}
4871

49-
#[derive(Clone)]
50-
pub struct OSProtocolNode {
51-
bind_addr: SocketAddr,
52-
hostname: String,
53-
private_key: Rsa<Private>,
54-
}
55-
56-
impl OSProtocolNode {
57-
pub fn builder() -> OSProtocolNodeBuilder {
58-
OSProtocolNodeBuilder {
59-
bind_addr: SocketAddr::new(IpAddr::from(Ipv4Addr::LOCALHOST), 57401),
60-
hostname: "".to_string(),
61-
private_key: None,
62-
}
63-
}
64-
65-
pub async fn listen(&self) -> io::Result<()> {
72+
impl OSProtocolNode<ConnectionState> {
73+
pub async fn listen<'a, F, Fut>(&'a mut self, conn_handler: F) -> io::Result<()>
74+
where
75+
F: Fn(InboundConnection<TransferState>, &Arc<Mutex<ConnectionState>>) -> Fut + Send + Copy + 'static,
76+
Fut: Future<Output = Result<(), ()>> + Send + 'static,
77+
{
6678
let port = self.bind_addr.port();
6779
let listener = TcpListener::bind(self.bind_addr).await?;
6880
info!("Listening started on port {port}, ready to accept connections");
81+
6982
loop {
7083
// The second item contains the IP and port of the new connection.
7184
let (stream, _) = listener.accept().await?;
@@ -78,21 +91,30 @@ impl OSProtocolNode {
7891
.unwrap_or("unknown address".to_string())
7992
);
8093

81-
self.start_connection(stream);
94+
let state_rc = self.state.clone();
95+
tokio::spawn(async move {
96+
let mut connection_handshake = InboundConnection::with_stream(stream).unwrap();
97+
match connection_handshake.begin().await {
98+
Ok(_) => {
99+
let connection_transfer = InboundConnection::<TransferState>::from(connection_handshake);
100+
101+
let _ = conn_handler(connection_transfer, &state_rc).await?;
102+
}
103+
Err(e) => {
104+
error!("Handshake failed: {e}");
105+
}
106+
}
107+
});
82108
}
83109
}
84110

85-
fn start_connection(&self, stream: TcpStream) {
86-
tokio::spawn(async move {
87-
let mut connection_handshake = InboundConnection::with_stream(stream).unwrap();
88-
connection_handshake.begin().await?;
89-
let mut connection_transfer = InboundConnection::<TransferState>::from(connection_handshake);
90-
});
91-
}
92-
93111
pub async fn create_outbound(&self, url: OSPUrl) -> io::Result<()> {
94112
info!("Starting outbound connection to {url}");
95-
let mut conn = OutboundConnection::create(url, self.private_key.clone(), self.hostname.clone()).await?;
113+
let mut conn = OutboundConnection::create(
114+
url,
115+
self.state.lock().unwrap().private_key.clone(),
116+
self.hostname.clone()
117+
).await?;
96118
let mut conn_in_handshake = conn.begin().await?;
97119
conn_in_handshake.handshake().await
98120
}

examples/test_server_implementation/src/bin/server.rs

+12-6
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,19 @@ async fn main() -> io::Result<()> {
3535

3636
let args = Args::parse();
3737
let addr = SocketAddrV4::new(args.bind.parse().expect("Invalid bind address"), args.port);
38-
let node = OSProtocolNode::builder()
39-
.bind_to(SocketAddr::from(addr))
40-
.private_key_file(args.private_key)
41-
.hostname(args.hostname)
42-
.build();
38+
let mut node = OSProtocolNode::new();
39+
node.set_addr(SocketAddr::from(addr));
40+
node.set_private_key_file(args.private_key);
41+
node.set_hostname(args.hostname);
4342

44-
node.listen().await
43+
44+
let mut connection_node = node.init();
45+
connection_node.listen(|connection, state| async move {
46+
47+
Ok(())
48+
}).await?;
49+
50+
Ok(())
4551

4652
// for uri in args.push_to {
4753
// let osp_url = OSPUrl::from(Url::parse(uri.as_str()).unwrap());

0 commit comments

Comments
 (0)