1
1
use std:: { fs, net:: { SocketAddr , IpAddr , Ipv4Addr } } ;
2
+ use std:: collections:: { HashMap , HashSet } ;
3
+ use std:: future:: Future ;
4
+ use std:: sync:: { Arc , Mutex } ;
2
5
3
- use log:: info;
6
+ use log:: { error , info} ;
4
7
5
8
use openssl:: pkey:: Private ;
6
9
use openssl:: rsa:: Rsa ;
@@ -13,59 +16,69 @@ use osp_protocol::OSPUrl;
13
16
use crate :: connection:: inbound:: { InboundConnection , TransferState } ;
14
17
use crate :: connection:: outbound:: OutboundConnection ;
15
18
19
+ pub struct InitState {
20
+ private_key : Option < Rsa < Private > > ,
21
+ }
16
22
17
- pub struct OSProtocolNodeBuilder {
23
+ pub struct ConnectionState {
24
+ private_key : Rsa < Private > ,
25
+ }
26
+
27
+ #[ derive( Clone ) ]
28
+ pub struct OSProtocolNode < TState > {
18
29
bind_addr : SocketAddr ,
19
30
hostname : String ,
20
- private_key : Option < Rsa < Private > > ,
31
+ state : Arc < Mutex < TState > > ,
21
32
}
22
33
23
- impl OSProtocolNodeBuilder {
24
- pub fn bind_to ( mut self , addr : SocketAddr ) -> Self {
34
+ impl OSProtocolNode < InitState > {
35
+ pub fn new ( ) -> Self {
36
+ OSProtocolNode :: < InitState > {
37
+ bind_addr : SocketAddr :: new ( IpAddr :: from ( Ipv4Addr :: LOCALHOST ) , 57401 ) ,
38
+ hostname : "" . to_string ( ) ,
39
+ state : Arc :: new ( Mutex :: new ( InitState {
40
+ private_key : None ,
41
+ } ) ) ,
42
+ }
43
+ }
44
+
45
+ pub fn set_addr ( & mut self , addr : SocketAddr ) {
25
46
self . bind_addr = addr;
26
- self
27
47
}
28
48
29
- pub fn hostname ( mut self , hostname : String ) -> Self {
49
+ pub fn set_hostname ( & mut self , hostname : String ) {
30
50
self . hostname = hostname;
31
- self
32
51
}
33
52
34
- pub fn private_key_file ( mut self , path : String ) -> Self {
53
+ pub fn set_private_key_file ( & mut self , path : String ) {
35
54
let key_contents = fs:: read_to_string ( path. clone ( ) ) . expect ( format ! ( "Unable to open private key file {}" , path) . as_str ( ) ) ;
36
- self . private_key = Some ( Rsa :: private_key_from_pem ( key_contents. as_bytes ( ) ) . unwrap ( ) ) ;
37
- self
55
+ self . state . lock ( ) . unwrap ( ) . private_key = Some ( Rsa :: private_key_from_pem ( key_contents. as_bytes ( ) ) . unwrap ( ) ) ;
38
56
}
39
57
40
- pub fn build ( self ) -> OSProtocolNode {
41
- OSProtocolNode {
42
- bind_addr : self . bind_addr ,
43
- hostname : self . hostname ,
44
- private_key : self . private_key . unwrap ( ) ,
58
+ pub fn init ( & mut self ) -> OSProtocolNode < ConnectionState > {
59
+ let bind_addr = self . bind_addr . clone ( ) ;
60
+ let hostname = self . hostname . clone ( ) ;
61
+ let private_key = self . state . lock ( ) . unwrap ( ) . private_key . clone ( ) . unwrap ( ) ;
62
+ OSProtocolNode :: < ConnectionState > {
63
+ bind_addr,
64
+ hostname,
65
+ state : Arc :: new ( Mutex :: new ( ConnectionState {
66
+ private_key,
67
+ } ) ) ,
45
68
}
46
69
}
47
70
}
48
71
49
- #[ derive( Clone ) ]
50
- pub struct OSProtocolNode {
51
- bind_addr : SocketAddr ,
52
- hostname : String ,
53
- private_key : Rsa < Private > ,
54
- }
55
-
56
- impl OSProtocolNode {
57
- pub fn builder ( ) -> OSProtocolNodeBuilder {
58
- OSProtocolNodeBuilder {
59
- bind_addr : SocketAddr :: new ( IpAddr :: from ( Ipv4Addr :: LOCALHOST ) , 57401 ) ,
60
- hostname : "" . to_string ( ) ,
61
- private_key : None ,
62
- }
63
- }
64
-
65
- pub async fn listen ( & self ) -> io:: Result < ( ) > {
72
+ impl OSProtocolNode < ConnectionState > {
73
+ pub async fn listen < ' a , F , Fut > ( & ' a mut self , conn_handler : F ) -> io:: Result < ( ) >
74
+ where
75
+ F : Fn ( InboundConnection < TransferState > , & Arc < Mutex < ConnectionState > > ) -> Fut + Send + Copy + ' static ,
76
+ Fut : Future < Output = Result < ( ) , ( ) > > + Send + ' static ,
77
+ {
66
78
let port = self . bind_addr . port ( ) ;
67
79
let listener = TcpListener :: bind ( self . bind_addr ) . await ?;
68
80
info ! ( "Listening started on port {port}, ready to accept connections" ) ;
81
+
69
82
loop {
70
83
// The second item contains the IP and port of the new connection.
71
84
let ( stream, _) = listener. accept ( ) . await ?;
@@ -78,21 +91,30 @@ impl OSProtocolNode {
78
91
. unwrap_or( "unknown address" . to_string( ) )
79
92
) ;
80
93
81
- self . start_connection ( stream) ;
94
+ let state_rc = self . state . clone ( ) ;
95
+ tokio:: spawn ( async move {
96
+ let mut connection_handshake = InboundConnection :: with_stream ( stream) . unwrap ( ) ;
97
+ match connection_handshake. begin ( ) . await {
98
+ Ok ( _) => {
99
+ let connection_transfer = InboundConnection :: < TransferState > :: from ( connection_handshake) ;
100
+
101
+ let _ = conn_handler ( connection_transfer, & state_rc) . await ?;
102
+ }
103
+ Err ( e) => {
104
+ error ! ( "Handshake failed: {e}" ) ;
105
+ }
106
+ }
107
+ } ) ;
82
108
}
83
109
}
84
110
85
- fn start_connection ( & self , stream : TcpStream ) {
86
- tokio:: spawn ( async move {
87
- let mut connection_handshake = InboundConnection :: with_stream ( stream) . unwrap ( ) ;
88
- connection_handshake. begin ( ) . await ?;
89
- let mut connection_transfer = InboundConnection :: < TransferState > :: from ( connection_handshake) ;
90
- } ) ;
91
- }
92
-
93
111
pub async fn create_outbound ( & self , url : OSPUrl ) -> io:: Result < ( ) > {
94
112
info ! ( "Starting outbound connection to {url}" ) ;
95
- let mut conn = OutboundConnection :: create ( url, self . private_key . clone ( ) , self . hostname . clone ( ) ) . await ?;
113
+ let mut conn = OutboundConnection :: create (
114
+ url,
115
+ self . state . lock ( ) . unwrap ( ) . private_key . clone ( ) ,
116
+ self . hostname . clone ( )
117
+ ) . await ?;
96
118
let mut conn_in_handshake = conn. begin ( ) . await ?;
97
119
conn_in_handshake. handshake ( ) . await
98
120
}
0 commit comments