-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserver_protocol.py
66 lines (52 loc) · 2.42 KB
/
server_protocol.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from dataclasses import dataclass, field
from struct import unpack
from typing import Callable, Any
from signatures import Verifier
from util import *
@dataclass
class ServerProtocol:
    """Datagram protocol that registers clients and applies their signed updates.

    Two kinds of datagram are distinguished purely by length:

    * Any datagram that is not exactly ``_UPDATE_SIZE`` bytes is treated as a
      ping/registration: its stripped payload is handed to ``Verifier`` and a
      short greeting is sent back to the sender.
    * A datagram of exactly ``_UPDATE_SIZE`` bytes is an update. It starts with
      a ``_HEADER_SIZE``-byte header (see ``_HEADER_FORMAT``) followed by a
      signature over the first ``_SIGNED_SIZE`` bytes.

    The four public containers are supplied by the caller and mutated in place.
    """

    # Un-annotated on purpose so that ``@dataclass`` does not turn them into fields.
    _UPDATE_SIZE = 500           # exact length of an update datagram
    _HEADER_FORMAT = ">ff iii"   # big-endian: two floats, then three ints
    _HEADER_SIZE = 20            # bytes consumed by _HEADER_FORMAT
    _SIGNED_SIZE = 16            # header bytes covered by the signature (all but the length)

    clients: dict[IP, ColorPoint]          # latest point per client, as returned by nudge_point
    ongoing_strokes: dict[IP, Stroke]      # stroke currently being drawn by each client
    finished_strokes: list[Stroke]         # completed strokes, in completion order
    change_flag: Flag                      # ``triggered`` is set when an update was applied
    _verifiers: dict[IP, Verifier] = field(default_factory=dict)
    _transport: Any = None  # used for returning messages

    def connection_made(self, transport) -> None:
        """Remember the transport so replies can be sent from ``datagram_received``."""
        self._transport = transport

    def datagram_received(self, data: bytes, full_address: ClientAddress) -> None:
        """Handle one incoming datagram (registration ping or signed update).

        Args:
            data: Raw datagram payload.
            full_address: Sender address; its first element is used as the client key.
        """
        client_ip = full_address[0]

        # Ping requests
        if len(data) != self._UPDATE_SIZE:
            self._register_client(data, full_address)
            return

        # Verify signature
        client_verifier = self._verifiers.get(client_ip)
        if client_verifier is None:
            # Sender never registered, so there is nothing to verify against.
            print(f"Got a non-valid request from {client_ip}")
            return

        # The starred target absorbs the first three values (two floats and one
        # int); the third one is later used as the stroke color.
        *current_position, pressed, signature_length = unpack(
            self._HEADER_FORMAT, data[:self._HEADER_SIZE]
        )

        # The length field comes from the network. A negative value would turn
        # the slice below into an end-relative one, and an oversized value would
        # be silently truncated, so reject both before verifying.
        signature_end = self._HEADER_SIZE + signature_length
        if signature_length < 0 or signature_end > len(data):
            print(f"Got a non-valid request from {client_ip}")
            return

        signed_part = data[:self._SIGNED_SIZE]
        signature = data[self._HEADER_SIZE:signature_end]
        if not client_verifier.verify(signed_part, signature):
            print(f"Got a non-valid request from {client_ip}")
            return

        # Update positions
        self.clients[client_ip] = nudge_point(self.clients.get(client_ip), current_position)
        if pressed:
            # Only the first two components of the stored point go into the stroke.
            self.ongoing_strokes[client_ip].append(self.clients[client_ip][:2])
        else:
            current_stroke = self.ongoing_strokes[client_ip]
            if current_stroke.points:
                # Pen released after drawing: finalize the stroke and start a new one.
                current_stroke.color = current_position[2]
                self.finished_strokes.append(current_stroke)
                self.ongoing_strokes[client_ip] = Stroke()

        # NOTE(review): the reviewed copy of this file had lost its indentation,
        # so the nesting level of this statement is an assumption (set after
        # every accepted update) - confirm against the original.
        self.change_flag.triggered = True

    def _register_client(self, data: bytes, full_address: ClientAddress) -> None:
        """Register the sender of a ping datagram and reply with a greeting.

        Registration is best-effort: a failure is logged and the datagram is
        dropped, so a malformed ping cannot take the server down.
        """
        client_ip = full_address[0]
        try:
            # NOTE(review): a ping from an already-registered address replaces
            # its verifier and discards its ongoing stroke - confirm intended.
            self._verifiers[client_ip] = Verifier(data.strip())
            self.ongoing_strokes[client_ip] = Stroke()
            print(f"Replying to ping packet from new client at: {client_ip}")
            self._transport.sendto(b"Freddie nice to meet\r\r", full_address)  # add useful info
        except Exception as exc:
            # Report instead of swallowing silently; KeyboardInterrupt and
            # SystemExit are deliberately left to propagate.
            print(f"Failed to register client at {client_ip}: {exc!r}")

    @staticmethod
    def error_received(exc) -> None:
        """Log an error reported by the transport."""
        print("Error received: ", exc)

    @staticmethod
    def connection_lost(exc) -> None:
        """Log that the connection was lost or closed; ``exc`` is printed as given."""
        print("Connection lost: ", exc)
def protocol_factory(clients: dict[IP, ColorPoint], ongoing_strokes: dict[IP, Stroke],
                     finished_strokes: list[Stroke], change_flag: Flag) -> Callable[[], ServerProtocol]:
    """Return a zero-argument callable that builds a ``ServerProtocol``.

    Every protocol created by the returned callable receives the same four
    objects that were passed to this function.
    """

    def build_protocol() -> ServerProtocol:
        return ServerProtocol(
            clients=clients,
            ongoing_strokes=ongoing_strokes,
            finished_strokes=finished_strokes,
            change_flag=change_flag,
        )

    return build_protocol