Skip to content

Commit 5e96948

Browse files
Add serilazer to where it needs
1 parent 97cffe8 commit 5e96948

File tree

11 files changed

+77
-31
lines changed

11 files changed

+77
-31
lines changed

engineio/src/asynchronous/async_socket.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,10 @@ impl Socket {
120120
}
121121

122122
/// Helper method that parses bytes and returns an iterator over the elements.
123-
fn parse_payload(bytes: Bytes, serializer: Arc<PacketSerializer>) -> impl Stream<Item = Result<Packet>> {
123+
fn parse_payload(
124+
bytes: Bytes,
125+
serializer: Arc<PacketSerializer>,
126+
) -> impl Stream<Item = Result<Packet>> {
124127
try_stream! {
125128
// let payload = Payload::try_from(bytes);
126129
let payload = serializer.decode_payload(bytes);

engineio/src/asynchronous/async_transports/websocket.rs

+13-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::sync::Arc;
44

55
use crate::asynchronous::transport::AsyncTransport;
66
use crate::error::Result;
7+
use crate::PacketSerializer;
78
use async_trait::async_trait;
89
use bytes::Bytes;
910
use futures_util::stream::StreamExt;
@@ -27,7 +28,11 @@ pub struct WebsocketTransport {
2728

2829
impl WebsocketTransport {
2930
/// Creates a new instance over a request that might hold additional headers and an URL.
30-
pub async fn new(base_url: Url, headers: Option<HeaderMap>) -> Result<Self> {
31+
pub async fn new(
32+
base_url: Url,
33+
headers: Option<HeaderMap>,
34+
serializer: Arc<PacketSerializer>,
35+
) -> Result<Self> {
3136
let mut url = base_url;
3237
url.query_pairs_mut().append_pair("transport", "websocket");
3338
url.set_scheme("ws").unwrap();
@@ -41,7 +46,7 @@ impl WebsocketTransport {
4146
let (ws_stream, _) = connect_async(req).await?;
4247
let (sen, rec) = ws_stream.split();
4348

44-
let inner = AsyncWebsocketGeneralTransport::new(sen, rec).await;
49+
let inner = AsyncWebsocketGeneralTransport::new(sen, rec, serializer).await;
4550
Ok(WebsocketTransport {
4651
inner,
4752
base_url: Arc::new(RwLock::new(url)),
@@ -118,7 +123,12 @@ mod test {
118123
let url = crate::test::engine_io_server()?.to_string()
119124
+ "engine.io/?EIO="
120125
+ &ENGINE_IO_VERSION.to_string();
121-
WebsocketTransport::new(Url::from_str(&url[..])?, None).await
126+
WebsocketTransport::new(
127+
Url::from_str(&url[..])?,
128+
None,
129+
PacketSerializer::default_arc(),
130+
)
131+
.await
122132
}
123133

124134
#[tokio::test]

engineio/src/asynchronous/async_transports/websocket_general.rs

+16-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::{borrow::Cow, str::from_utf8, sync::Arc, task::Poll};
22

3-
use crate::{error::Result, Error, Packet, PacketId};
3+
use crate::{error::Result, Error, Packet, PacketId, PacketSerializer};
44
use bytes::{BufMut, Bytes, BytesMut};
55
use futures_util::{
66
ready,
@@ -22,16 +22,19 @@ type AsyncWebsocketReceiver = SplitStream<WebSocketStream<MaybeTlsStream<TcpStre
2222
pub(crate) struct AsyncWebsocketGeneralTransport {
2323
sender: Arc<Mutex<AsyncWebsocketSender>>,
2424
receiver: Arc<Mutex<AsyncWebsocketReceiver>>,
25+
serializer: Arc<PacketSerializer>,
2526
}
2627

2728
impl AsyncWebsocketGeneralTransport {
2829
pub(crate) async fn new(
2930
sender: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
3031
receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
32+
serializer: Arc<PacketSerializer>,
3133
) -> Self {
3234
AsyncWebsocketGeneralTransport {
3335
sender: Arc::new(Mutex::new(sender)),
3436
receiver: Arc::new(Mutex::new(receiver)),
37+
serializer,
3538
}
3639
}
3740

@@ -41,25 +44,30 @@ impl AsyncWebsocketGeneralTransport {
4144
let mut receiver = self.receiver.lock().await;
4245
let mut sender = self.sender.lock().await;
4346

47+
let ping_packet = Packet::new(PacketId::Ping, Bytes::from("probe"));
48+
let ping_packet = self.serializer.encode(ping_packet);
49+
4450
sender
45-
.send(Message::text(Cow::Borrowed(from_utf8(&Bytes::from(
46-
Packet::new(PacketId::Ping, Bytes::from("probe")),
47-
))?)))
51+
.send(Message::text(Cow::Borrowed(from_utf8(&ping_packet)?)))
4852
.await?;
4953

5054
let msg = receiver
5155
.next()
5256
.await
5357
.ok_or(Error::IllegalWebsocketUpgrade())??;
5458

55-
if msg.into_data() != Bytes::from(Packet::new(PacketId::Pong, Bytes::from("probe"))) {
59+
let pong_packet = Packet::new(PacketId::Pong, Bytes::from("probe"));
60+
let pong_packet = self.serializer.encode(pong_packet);
61+
62+
if msg.into_data() != pong_packet {
5663
return Err(Error::InvalidPacket());
5764
}
5865

66+
let upgrade_packet = Packet::new(PacketId::Upgrade, Bytes::from(""));
67+
let upgrade_packet = self.serializer.encode(upgrade_packet);
68+
5969
sender
60-
.send(Message::text(Cow::Borrowed(from_utf8(&Bytes::from(
61-
Packet::new(PacketId::Upgrade, Bytes::from("")),
62-
))?)))
70+
.send(Message::text(Cow::Borrowed(from_utf8(&upgrade_packet)?)))
6371
.await?;
6472

6573
Ok(())

engineio/src/asynchronous/async_transports/websocket_secure.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::sync::Arc;
44

55
use crate::asynchronous::transport::AsyncTransport;
66
use crate::error::Result;
7+
use crate::PacketSerializer;
78
use async_trait::async_trait;
89
use bytes::Bytes;
910
use futures_util::Stream;
@@ -34,6 +35,7 @@ impl WebsocketSecureTransport {
3435
base_url: Url,
3536
tls_config: Option<TlsConnector>,
3637
headers: Option<HeaderMap>,
38+
serializer: Arc<PacketSerializer>,
3739
) -> Result<Self> {
3840
let mut url = base_url;
3941
url.query_pairs_mut().append_pair("transport", "websocket");
@@ -61,7 +63,7 @@ impl WebsocketSecureTransport {
6163
.await?;
6264

6365
let (sen, rec) = ws_stream.split();
64-
let inner = AsyncWebsocketGeneralTransport::new(sen, rec).await;
66+
let inner = AsyncWebsocketGeneralTransport::new(sen, rec, serializer).await;
6567

6668
Ok(WebsocketSecureTransport {
6769
inner,
@@ -143,6 +145,7 @@ mod test {
143145
Url::from_str(&url[..])?,
144146
Some(crate::test::tls_connector()?),
145147
None,
148+
PacketSerializer::default_arc(),
146149
)
147150
.await
148151
}

engineio/src/asynchronous/client/builder.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,9 @@ impl ClientBuilder {
227227

228228
match self.url.scheme() {
229229
"http" | "ws" => {
230-
let mut transport = WebsocketTransport::new(self.url.clone(), headers).await?;
230+
let mut transport =
231+
WebsocketTransport::new(self.url.clone(), headers, self.serializer.clone())
232+
.await?;
231233

232234
if self.handshake.is_some() {
233235
transport.upgrade().await?;
@@ -252,6 +254,7 @@ impl ClientBuilder {
252254
self.url.clone(),
253255
self.tls_config.clone(),
254256
headers,
257+
self.serializer.clone(),
255258
)
256259
.await?;
257260

engineio/src/client/client.rs

+9-5
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ impl ClientBuilder {
6565
}
6666

6767
/// Specify Packet Serializer
68-
pub fn packet_serializer(mut self, packet_serializer: PacketSerializer) -> Self {
69-
self.serializer = Arc::new(packet_serializer);
68+
pub fn packet_serializer(mut self, packet_serializer: Arc<PacketSerializer>) -> Self {
69+
self.serializer = packet_serializer;
7070

7171
self
7272
}
@@ -228,7 +228,7 @@ impl ClientBuilder {
228228

229229
match url.scheme() {
230230
"http" | "ws" => {
231-
let transport = WebsocketTransport::new(url, headers)?;
231+
let transport = WebsocketTransport::new(url, headers, self.serializer.clone())?;
232232
if self.handshake.is_some() {
233233
transport.upgrade()?;
234234
} else {
@@ -250,8 +250,12 @@ impl ClientBuilder {
250250
})
251251
}
252252
"https" | "wss" => {
253-
let transport =
254-
WebsocketSecureTransport::new(url, self.tls_config.clone(), headers)?;
253+
let transport = WebsocketSecureTransport::new(
254+
url,
255+
self.tls_config.clone(),
256+
headers,
257+
self.serializer.clone(),
258+
)?;
255259
if self.handshake.is_some() {
256260
transport.upgrade()?;
257261
} else {

engineio/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ pub mod error;
9797

9898
pub use client::{Client, ClientBuilder};
9999
pub use error::Error;
100-
pub use packet::{Packet, PacketId};
100+
pub use packet::{Packet, PacketId, PacketSerializer};
101101

102102
#[cfg(test)]
103103
pub(crate) mod test {

engineio/src/packet.rs

+4
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ impl PacketSerializer {
9191
let _ = buf.split_off(buf.len() - 1);
9292
buf.freeze()
9393
}
94+
95+
pub fn default_arc() -> std::sync::Arc<Self> {
96+
std::sync::Arc::new(Self::default())
97+
}
9498
}
9599

96100
impl Default for PacketSerializer {

engineio/src/transports/websocket.rs

+13-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::{
44
},
55
error::Result,
66
transport::Transport,
7-
Error,
7+
Error, PacketSerializer,
88
};
99
use bytes::Bytes;
1010
use http::HeaderMap;
@@ -20,12 +20,17 @@ pub struct WebsocketTransport {
2020

2121
impl WebsocketTransport {
2222
/// Creates an instance of `WebsocketTransport`.
23-
pub fn new(base_url: Url, headers: Option<HeaderMap>) -> Result<Self> {
23+
pub fn new(
24+
base_url: Url,
25+
headers: Option<HeaderMap>,
26+
serializer: Arc<PacketSerializer>,
27+
) -> Result<Self> {
2428
let runtime = tokio::runtime::Builder::new_current_thread()
2529
.enable_all()
2630
.build()?;
2731

28-
let inner = runtime.block_on(AsyncWebsocketTransport::new(base_url, headers))?;
32+
let inner =
33+
runtime.block_on(AsyncWebsocketTransport::new(base_url, headers, serializer))?;
2934

3035
Ok(WebsocketTransport {
3136
runtime: Arc::new(runtime),
@@ -90,7 +95,11 @@ mod test {
9095
let url = crate::test::engine_io_server()?.to_string()
9196
+ "engine.io/?EIO="
9297
+ &ENGINE_IO_VERSION.to_string();
93-
WebsocketTransport::new(Url::from_str(&url[..])?, None)
98+
WebsocketTransport::new(
99+
Url::from_str(&url[..])?,
100+
None,
101+
PacketSerializer::default_arc(),
102+
)
94103
}
95104

96105
#[test]

engineio/src/transports/websocket_secure.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::{
55
},
66
error::Result,
77
transport::Transport,
8-
Error,
8+
Error, PacketSerializer,
99
};
1010
use bytes::Bytes;
1111
use http::HeaderMap;
@@ -26,13 +26,14 @@ impl WebsocketSecureTransport {
2626
base_url: Url,
2727
tls_config: Option<TlsConnector>,
2828
headers: Option<HeaderMap>,
29+
serializer: Arc<PacketSerializer>,
2930
) -> Result<Self> {
3031
let runtime = tokio::runtime::Builder::new_current_thread()
3132
.enable_all()
3233
.build()?;
3334

3435
let inner = runtime.block_on(AsyncWebsocketSecureTransport::new(
35-
base_url, tls_config, headers,
36+
base_url, tls_config, headers, serializer,
3637
))?;
3738

3839
Ok(WebsocketSecureTransport {
@@ -99,6 +100,7 @@ mod test {
99100
Url::from_str(&url[..])?,
100101
Some(crate::test::tls_connector()?),
101102
None,
103+
PacketSerializer::default_arc(),
102104
)
103105
}
104106

socketio/src/client/builder.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ impl Default for TransportType {
3333
}
3434
}
3535

36-
pub use rust_engineio::Packet::PacketSerializer;
36+
pub use rust_engineio::PacketSerializer;
3737

3838
/// A builder class for a `socket.io` socket. This handles setting up the client and
3939
/// configuring the callback, the namespace and metadata of the socket. If no
@@ -48,7 +48,7 @@ pub struct ClientBuilder {
4848
tls_config: Option<TlsConnector>,
4949
opening_headers: Option<HeaderMap>,
5050
transport_type: TransportType,
51-
packet_serializer: PacketSerializer,
51+
packet_serializer: Arc<PacketSerializer>,
5252
auth: Option<serde_json::Value>,
5353
pub(crate) reconnect: bool,
5454
pub(crate) reconnect_on_disconnect: bool,
@@ -100,7 +100,7 @@ impl ClientBuilder {
100100
tls_config: None,
101101
opening_headers: None,
102102
transport_type: TransportType::default(),
103-
packet_serializer: PacketSerializer::default(),
103+
packet_serializer: PacketSerializer::default_arc(),
104104
auth: None,
105105
reconnect: true,
106106
reconnect_on_disconnect: false,
@@ -335,7 +335,7 @@ impl ClientBuilder {
335335
/// }
336336
/// ```
337337
pub fn packet_serializer(mut self, packet_serializer: PacketSerializer) -> Self {
338-
self.packet_serializer = packet_serializer;
338+
self.packet_serializer = Arc::new(packet_serializer);
339339

340340
self
341341
}
@@ -376,7 +376,7 @@ impl ClientBuilder {
376376
}
377377

378378
let mut builder =
379-
EngineIoClientBuilder::new(url).packet_serializer(self.packet_serializer.into());
379+
EngineIoClientBuilder::new(url).packet_serializer(self.packet_serializer.clone());
380380

381381
if let Some(tls_config) = self.tls_config {
382382
builder = builder.tls_config(tls_config);

0 commit comments

Comments
 (0)