1
1
use std:: { borrow:: Cow , str:: from_utf8, sync:: Arc , task:: Poll } ;
2
2
3
- use crate :: { error:: Result , Error , Packet , PacketId } ;
3
+ use crate :: { error:: Result , Error , Packet , PacketId , PacketSerializer } ;
4
4
use bytes:: { BufMut , Bytes , BytesMut } ;
5
5
use futures_util:: {
6
6
ready,
@@ -22,16 +22,19 @@ type AsyncWebsocketReceiver = SplitStream<WebSocketStream<MaybeTlsStream<TcpStre
22
22
pub ( crate ) struct AsyncWebsocketGeneralTransport {
23
23
sender : Arc < Mutex < AsyncWebsocketSender > > ,
24
24
receiver : Arc < Mutex < AsyncWebsocketReceiver > > ,
25
+ serializer : Arc < PacketSerializer > ,
25
26
}
26
27
27
28
impl AsyncWebsocketGeneralTransport {
28
29
pub ( crate ) async fn new (
29
30
sender : SplitSink < WebSocketStream < MaybeTlsStream < TcpStream > > , Message > ,
30
31
receiver : SplitStream < WebSocketStream < MaybeTlsStream < TcpStream > > > ,
32
+ serializer : Arc < PacketSerializer > ,
31
33
) -> Self {
32
34
AsyncWebsocketGeneralTransport {
33
35
sender : Arc :: new ( Mutex :: new ( sender) ) ,
34
36
receiver : Arc :: new ( Mutex :: new ( receiver) ) ,
37
+ serializer,
35
38
}
36
39
}
37
40
@@ -41,25 +44,30 @@ impl AsyncWebsocketGeneralTransport {
41
44
let mut receiver = self . receiver . lock ( ) . await ;
42
45
let mut sender = self . sender . lock ( ) . await ;
43
46
47
+ let ping_packet = Packet :: new ( PacketId :: Ping , Bytes :: from ( "probe" ) ) ;
48
+ let ping_packet = self . serializer . encode ( ping_packet) ;
49
+
44
50
sender
45
- . send ( Message :: text ( Cow :: Borrowed ( from_utf8 ( & Bytes :: from (
46
- Packet :: new ( PacketId :: Ping , Bytes :: from ( "probe" ) ) ,
47
- ) ) ?) ) )
51
+ . send ( Message :: text ( Cow :: Borrowed ( from_utf8 ( & ping_packet) ?) ) )
48
52
. await ?;
49
53
50
54
let msg = receiver
51
55
. next ( )
52
56
. await
53
57
. ok_or ( Error :: IllegalWebsocketUpgrade ( ) ) ??;
54
58
55
- if msg. into_data ( ) != Bytes :: from ( Packet :: new ( PacketId :: Pong , Bytes :: from ( "probe" ) ) ) {
59
+ let pong_packet = Packet :: new ( PacketId :: Pong , Bytes :: from ( "probe" ) ) ;
60
+ let pong_packet = self . serializer . encode ( pong_packet) ;
61
+
62
+ if msg. into_data ( ) != pong_packet {
56
63
return Err ( Error :: InvalidPacket ( ) ) ;
57
64
}
58
65
66
+ let upgrade_packet = Packet :: new ( PacketId :: Upgrade , Bytes :: from ( "" ) ) ;
67
+ let upgrade_packet = self . serializer . encode ( upgrade_packet) ;
68
+
59
69
sender
60
- . send ( Message :: text ( Cow :: Borrowed ( from_utf8 ( & Bytes :: from (
61
- Packet :: new ( PacketId :: Upgrade , Bytes :: from ( "" ) ) ,
62
- ) ) ?) ) )
70
+ . send ( Message :: text ( Cow :: Borrowed ( from_utf8 ( & upgrade_packet) ?) ) )
63
71
. await ?;
64
72
65
73
Ok ( ( ) )
0 commit comments