diff --git a/peers/e2e_test.go b/peers/e2e_test.go index 9ad808b..f8f2f02 100644 --- a/peers/e2e_test.go +++ b/peers/e2e_test.go @@ -3,6 +3,7 @@ package peers import ( + "context" "fmt" "log" "net/http" @@ -15,7 +16,7 @@ import ( func TestE2E(t *testing.T) { success := make(chan bool) - a := Peer{Handler: HandlerFunc(func(u *sticktable.EntryUpdate) { + a := Peer{Handler: HandlerFunc(func(_ context.Context, u *sticktable.EntryUpdate) { log.Println(u) success <- true })} diff --git a/peers/example/dump/main.go b/peers/example/dump/main.go index ad6c986..89d2301 100644 --- a/peers/example/dump/main.go +++ b/peers/example/dump/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "log" "github.com/dropmorepackets/haproxy-go/peers" @@ -10,7 +11,7 @@ import ( func main() { log.SetFlags(log.LstdFlags | log.Lshortfile) - err := peers.ListenAndServe(":21000", peers.HandlerFunc(func(u *sticktable.EntryUpdate) { + err := peers.ListenAndServe(":21000", peers.HandlerFunc(func(_ context.Context, u *sticktable.EntryUpdate) { log.Println(u.String()) })) if err != nil { diff --git a/peers/example/prometheus-exporter/main.go b/peers/example/prometheus-exporter/main.go index 2391579..9921872 100644 --- a/peers/example/prometheus-exporter/main.go +++ b/peers/example/prometheus-exporter/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "log" "net/http" @@ -24,7 +25,7 @@ func main() { go http.ListenAndServe(":8081", promhttp.Handler()) - err := peers.ListenAndServe(":21000", peers.HandlerFunc(func(update *sticktable.EntryUpdate) { + err := peers.ListenAndServe(":21000", peers.HandlerFunc(func(_ context.Context, update *sticktable.EntryUpdate) { for i, d := range update.Data { dt := update.StickTable.DataTypes[i].DataType switch d := d.(type) { diff --git a/peers/handler.go b/peers/handler.go index 107d638..b2b3ba2 100644 --- a/peers/handler.go +++ b/peers/handler.go @@ -1,13 +1,27 @@ package peers -import "github.com/dropmorepackets/haproxy-go/peers/sticktable" +import ( + "context" + + "github.com/dropmorepackets/haproxy-go/peers/sticktable" +) type Handler interface { - Update(*sticktable.EntryUpdate) + HandleUpdate(context.Context, *sticktable.EntryUpdate) + HandleHandshake(context.Context, *Handshake) + Close() error } -type HandlerFunc func(*sticktable.EntryUpdate) +type HandlerFunc func(context.Context, *sticktable.EntryUpdate) + +func (HandlerFunc) Close() error { return nil } -func (h HandlerFunc) Update(u *sticktable.EntryUpdate) { - h(u) +func (HandlerFunc) HandleHandshake(context.Context, *Handshake) {} + +func (h HandlerFunc) HandleUpdate(ctx context.Context, u *sticktable.EntryUpdate) { + h(ctx, u) } + +var ( + _ Handler = (HandlerFunc)(nil) +) diff --git a/peers/peers.go b/peers/peers.go index d5b1a6a..b1d13f2 100644 --- a/peers/peers.go +++ b/peers/peers.go @@ -1,7 +1,6 @@ package peers import ( - "bufio" "context" "fmt" "log" @@ -9,9 +8,10 @@ import ( ) type Peer struct { - Addr string - Handler Handler - BaseContext context.Context + Addr string + Handler Handler + HandlerSource func() Handler + BaseContext context.Context } func ListenAndServe(addr string, handler Handler) error { @@ -40,23 +40,28 @@ func (a *Peer) Serve(l net.Listener) error { l.Close() }() + if a.Handler != nil && a.HandlerSource != nil { + return fmt.Errorf("cannot set Handler and HandlerSource at the same time") + } + + if a.Handler != nil { + a.HandlerSource = func() Handler { + return a.Handler + } + } + for { nc, err := l.Accept() if err != nil { return fmt.Errorf("accepting conn: %w", err) } - conn := &Conn{ - ctx: a.BaseContext, - conn: nc, - r: bufio.NewReader(nc), - handler: a.Handler, - } - + p := newProtocolClient(a.BaseContext, nc, a.HandlerSource()) go func() { defer nc.Close() + defer p.Close() - if err := conn.Serve(); err != nil && err != conn.ctx.Err() { + if err := p.Serve(); err != nil && err != p.ctx.Err() { log.Println(err) } }() diff --git a/peers/proto.go b/peers/proto.go deleted file mode 100644 index 0185751..0000000 --- a/peers/proto.go +++ /dev/null @@ -1,99 +0,0 @@ -package peers - -import ( - "fmt" - "log" - - "github.com/dropmorepackets/haproxy-go/peers/sticktable" -) - -func (t ErrorMessageType) OnMessage(m *rawMessage, c *Conn) error { - switch t { - case ErrorMessageProtocol: - return fmt.Errorf("protocol error") - case ErrorMessageSizeLimit: - return fmt.Errorf("message size limit") - default: - return fmt.Errorf("unknown error message type: %s", t) - } -} - -func (t ControlMessageType) OnMessage(m *rawMessage, c *Conn) error { - switch t { - case ControlMessageSyncRequest: - _, _ = c.conn.Write([]byte{byte(MessageClassControl), byte(ControlMessageSyncPartial)}) - return nil - case ControlMessageSyncFinished: - return nil - case ControlMessageSyncPartial: - return nil - case ControlMessageSyncConfirmed: - return nil - case ControlMessageHeartbeat: - return nil - default: - return fmt.Errorf("unknown control message type: %s", t) - } -} - -func (t StickTableUpdateMessageType) OnMessage(m *rawMessage, c *Conn) error { - switch t { - case StickTableUpdateMessageTypeStickTableDefinition: - var std sticktable.Definition - if _, err := std.Unmarshal(m.Data); err != nil { - return err - } - c.lastTableDefinition = &std - - return nil - case StickTableUpdateMessageTypeStickTableSwitch: - log.Printf("not implemented: %s", t) - return nil - case StickTableUpdateMessageTypeUpdateAcknowledge: - log.Printf("not implemented: %s", t) - return nil - case StickTableUpdateMessageTypeEntryUpdate, - StickTableUpdateMessageTypeUpdateTimed, - StickTableUpdateMessageTypeIncrementalEntryUpdate, - StickTableUpdateMessageTypeIncrementalEntryUpdateTimed: - // All entry update messages are handled in a separate switch case - // following this one. - break - default: - return fmt.Errorf("unknown stick-table update message type: %s", t) - } - - if c.lastTableDefinition == nil { - return fmt.Errorf("cannot process entry update without table definition") - } - - e := sticktable.EntryUpdate{ - StickTable: c.lastTableDefinition, - } - log.Printf("%+v", e) - - if c.lastEntryUpdate != nil { - e.LocalUpdateID = c.lastEntryUpdate.LocalUpdateID + 1 - } - - switch t { - case StickTableUpdateMessageTypeEntryUpdate: - e.WithLocalUpdateID = true - case StickTableUpdateMessageTypeUpdateTimed: - e.WithLocalUpdateID = true - e.WithExpiry = true - case StickTableUpdateMessageTypeIncrementalEntryUpdate: - case StickTableUpdateMessageTypeIncrementalEntryUpdateTimed: - e.WithExpiry = true - } - - if _, err := e.Unmarshal(m.Data); err != nil { - return err - } - - c.lastEntryUpdate = &e - - c.handler.Update(&e) - - return nil -} diff --git a/peers/conn.go b/peers/protocol.go similarity index 56% rename from peers/conn.go rename to peers/protocol.go index 71fe6fa..22e69bf 100644 --- a/peers/conn.go +++ b/peers/protocol.go @@ -14,10 +14,11 @@ import ( "github.com/dropmorepackets/haproxy-go/pkg/encoding" ) -type Conn struct { - ctx context.Context - conn net.Conn - r *bufio.Reader +type protocolClient struct { + ctx context.Context + ctxCancel context.CancelFunc + rw io.ReadWriter + br *bufio.Reader nextHeartbeat *time.Ticker lastMessageTimer *time.Timer @@ -27,24 +28,37 @@ type Conn struct { handler Handler } -func (c *Conn) Close() error { - return c.conn.Close() +func newProtocolClient(ctx context.Context, rw io.ReadWriter, handler Handler) *protocolClient { + var c protocolClient + c.rw = rw + c.br = bufio.NewReader(rw) + c.handler = handler + c.ctx, c.ctxCancel = context.WithCancel(ctx) + return &c } -func (c *Conn) peerHandshake() error { +func (c *protocolClient) Close() error { + defer c.ctxCancel() + + return c.handler.Close() +} + +func (c *protocolClient) peerHandshake() error { var h Handshake - if _, err := h.ReadFrom(c.r); err != nil { + if _, err := h.ReadFrom(c.br); err != nil { return err } - if _, err := c.conn.Write([]byte(fmt.Sprintf("%d\n", HandshakeStatusHandshakeSucceeded))); err != nil { + c.handler.HandleHandshake(c.ctx, &h) + + if _, err := c.rw.Write([]byte(fmt.Sprintf("%d\n", HandshakeStatusHandshakeSucceeded))); err != nil { return fmt.Errorf("handshake failed: %v", err) } return nil } -func (c *Conn) resetHeartbeat() { +func (c *protocolClient) resetHeartbeat() { // a peer sends heartbeat messages to peers it is // connected to after periods of 3s of inactivity (i.e. when there is no // stick-table to synchronize for 3s). @@ -56,7 +70,7 @@ func (c *Conn) resetHeartbeat() { c.nextHeartbeat.Reset(time.Second * 3) } -func (c *Conn) resetLastMessage() { +func (c *protocolClient) resetLastMessage() { // After a successful peer protocol handshake between two peers, // if one of them does not send any other peer // protocol messages (i.e. no heartbeat and no stick-table update messages) @@ -71,9 +85,9 @@ func (c *Conn) resetLastMessage() { c.lastMessageTimer.Reset(time.Second * 5) } -func (c *Conn) heartbeat() { +func (c *protocolClient) heartbeat() { for range c.nextHeartbeat.C { - _, err := c.conn.Write([]byte{byte(MessageClassControl), byte(ControlMessageHeartbeat)}) + _, err := c.rw.Write([]byte{byte(MessageClassControl), byte(ControlMessageHeartbeat)}) if err != nil { _ = c.Close() return @@ -81,13 +95,13 @@ func (c *Conn) heartbeat() { } } -func (c *Conn) lastMessage() { +func (c *protocolClient) lastMessage() { <-c.lastMessageTimer.C log.Println("last message timer expired: closing connection") _ = c.Close() } -func (c *Conn) Serve() error { +func (c *protocolClient) Serve() error { if err := c.peerHandshake(); err != nil { return fmt.Errorf("handshake: %v", err) } @@ -100,7 +114,7 @@ func (c *Conn) Serve() error { for { var m rawMessage - if _, err := m.ReadFrom(c.r); err != nil { + if _, err := m.ReadFrom(c.br); err != nil { if c.ctx.Err() != nil { return c.ctx.Err() } @@ -119,7 +133,7 @@ func (c *Conn) Serve() error { } } -func (c *Conn) messageHandler(m *rawMessage) error { +func (c *protocolClient) messageHandler(m *rawMessage) error { switch m.MessageClass { case MessageClassControl: return ControlMessageType(m.MessageType).OnMessage(m, c) @@ -214,3 +228,93 @@ func (h *Handshake) ReadFrom(r io.Reader) (n int64, err error) { //TODO: find out how many bytes where read. return -1, scanner.Err() } + +func (t ErrorMessageType) OnMessage(m *rawMessage, c *protocolClient) error { + switch t { + case ErrorMessageProtocol: + return fmt.Errorf("protocol error") + case ErrorMessageSizeLimit: + return fmt.Errorf("message size limit") + default: + return fmt.Errorf("unknown error message type: %s", t) + } +} + +func (t ControlMessageType) OnMessage(m *rawMessage, c *protocolClient) error { + switch t { + case ControlMessageSyncRequest: + _, _ = c.rw.Write([]byte{byte(MessageClassControl), byte(ControlMessageSyncPartial)}) + return nil + case ControlMessageSyncFinished: + return nil + case ControlMessageSyncPartial: + return nil + case ControlMessageSyncConfirmed: + return nil + case ControlMessageHeartbeat: + return nil + default: + return fmt.Errorf("unknown control message type: %s", t) + } +} + +func (t StickTableUpdateMessageType) OnMessage(m *rawMessage, c *protocolClient) error { + switch t { + case StickTableUpdateMessageTypeStickTableDefinition: + var std sticktable.Definition + if _, err := std.Unmarshal(m.Data); err != nil { + return err + } + c.lastTableDefinition = &std + + return nil + case StickTableUpdateMessageTypeStickTableSwitch: + log.Printf("not implemented: %s", t) + return nil + case StickTableUpdateMessageTypeUpdateAcknowledge: + log.Printf("not implemented: %s", t) + return nil + case StickTableUpdateMessageTypeEntryUpdate, + StickTableUpdateMessageTypeUpdateTimed, + StickTableUpdateMessageTypeIncrementalEntryUpdate, + StickTableUpdateMessageTypeIncrementalEntryUpdateTimed: + // All entry update messages are handled in a separate switch case + // following this one. + break + default: + return fmt.Errorf("unknown stick-table update message type: %s", t) + } + + if c.lastTableDefinition == nil { + return fmt.Errorf("cannot process entry update without table definition") + } + + e := sticktable.EntryUpdate{ + StickTable: c.lastTableDefinition, + } + + if c.lastEntryUpdate != nil { + e.LocalUpdateID = c.lastEntryUpdate.LocalUpdateID + 1 + } + + switch t { + case StickTableUpdateMessageTypeEntryUpdate: + e.WithLocalUpdateID = true + case StickTableUpdateMessageTypeUpdateTimed: + e.WithLocalUpdateID = true + e.WithExpiry = true + case StickTableUpdateMessageTypeIncrementalEntryUpdate: + case StickTableUpdateMessageTypeIncrementalEntryUpdateTimed: + e.WithExpiry = true + } + + if _, err := e.Unmarshal(m.Data); err != nil { + return err + } + + c.lastEntryUpdate = &e + + c.handler.HandleUpdate(c.ctx, &e) + + return nil +}