From 4e8e2b5a743fb7180df3e470f871dc3ab3a76515 Mon Sep 17 00:00:00 2001 From: sujit Date: Wed, 18 Oct 2023 08:59:28 +0545 Subject: [PATCH] feat: add `Hijack()` --- pkg/websocket/config.go | 90 ------------- pkg/websocket/conn.go | 23 +--- pkg/websocket/envelope.go | 7 - pkg/websocket/hub.go | 91 ------------- pkg/websocket/melody.go | 272 -------------------------------------- pkg/websocket/session.go | 251 ----------------------------------- 6 files changed, 2 insertions(+), 732 deletions(-) delete mode 100644 pkg/websocket/config.go delete mode 100644 pkg/websocket/envelope.go delete mode 100644 pkg/websocket/hub.go delete mode 100644 pkg/websocket/melody.go delete mode 100644 pkg/websocket/session.go diff --git a/pkg/websocket/config.go b/pkg/websocket/config.go deleted file mode 100644 index b176fd1..0000000 --- a/pkg/websocket/config.go +++ /dev/null @@ -1,90 +0,0 @@ -package websocket - -import ( - "time" - - "github.com/oarkflow/frame" -) - -// Config hub configuration struct. -type Config struct { - WriteWait time.Duration // Milliseconds until write times out. - PongWait time.Duration // Timeout for waiting on pong. - PingPeriod time.Duration // Milliseconds between pings. - MaxMessageSize int64 // Maximum size in bytes of a message. - MessageBufferSize int // The max amount of messages that can be in a sessions buffer before it starts dropping them. - AutoCleanSession bool - CleanInterval time.Duration - Upgrader *Upgrader - Handlers *Handlers -} - -func defaultConfig(config *Config) { - if config.WriteWait == 0 { - config.WriteWait = 10 * time.Second - } - if config.PongWait == 0 { - config.PongWait = time.Minute - } - if config.PingPeriod == 0 { - config.PingPeriod = (time.Minute * 9) / 10 - } - if config.MaxMessageSize == 0 { - config.MaxMessageSize = 512 - } - if config.MessageBufferSize == 0 { - config.MessageBufferSize = 256 - } - if config.CleanInterval == 0 { - config.CleanInterval = time.Minute - } - if config.Upgrader == nil { - config.Upgrader = &Upgrader{} - } - if config.Handlers == nil { - config.Handlers = &Handlers{} - } - defaultUpgrader(config.Upgrader) - defaultHandlers(config.Handlers) -} - -func defaultUpgrader(upgrader *Upgrader) { - if upgrader.ReadBufferSize == 0 { - upgrader.ReadBufferSize = 1024 - } - if upgrader.WriteBufferSize == 0 { - upgrader.WriteBufferSize = 102400 - } - if upgrader.CheckOrigin == nil { - upgrader.CheckOrigin = func(r *frame.Context) bool { return true } - } -} - -func defaultHandlers(handlers *Handlers) { - if handlers.ConnectHandler == nil { - handlers.ConnectHandler = func(session *Session) {} - } - if handlers.DisconnectHandler == nil { - handlers.DisconnectHandler = func(session *Session) {} - } - if handlers.PongHandler == nil { - handlers.PongHandler = func(session *Session) {} - } - if handlers.BinaryMessageHandler == nil { - handlers.BinaryMessageHandler = func(session *Session, bytes []byte) {} - } - if handlers.MessageHandler == nil { - handlers.MessageHandler = func(session *Session, bytes []byte) {} - } - if handlers.MessageSentHandler == nil { - handlers.MessageSentHandler = func(session *Session, bytes []byte) {} - } - if handlers.BinaryMessageSentHandler == nil { - handlers.BinaryMessageSentHandler = func(session *Session, bytes []byte) {} - } - if handlers.CloseHandler == nil { - handlers.CloseHandler = func(session *Session, i int, s string) error { - return nil - } - } -} diff --git a/pkg/websocket/conn.go b/pkg/websocket/conn.go index 7ff8b94..d62b03e 100644 --- a/pkg/websocket/conn.go +++ b/pkg/websocket/conn.go @@ -252,7 +252,6 @@ type Conn struct { writeBufSize int writeDeadline time.Time writer io.WriteCloser // the current writer returned to the application - isWriting bool // for best-effort concurrent write detection writeErrMu sync.Mutex writeErr error @@ -613,18 +612,8 @@ func (w *messageWriter) flushFrame(final bool, extra []byte) error { // concurrent writes. See the concurrency section in the package // documentation for more info. - if c.isWriting { - panic("concurrent write to websocket connection") - } - c.isWriting = true - err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) - if !c.isWriting { - panic("concurrent write to websocket connection") - } - c.isWriting = false - if err != nil { return w.endMessage(err) } @@ -741,16 +730,8 @@ func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { if err != nil { return err } - if c.isWriting { - panic("concurrent write to websocket connection") - } - c.isWriting = true - err = c.write(frameType, c.writeDeadline, frameData, nil) - if !c.isWriting { - panic("concurrent write to websocket connection") - } - c.isWriting = false - return err + + return c.write(frameType, c.writeDeadline, frameData, nil) } // WriteMessage is a helper method for getting a writer using NextWriter, diff --git a/pkg/websocket/envelope.go b/pkg/websocket/envelope.go deleted file mode 100644 index c6a3f85..0000000 --- a/pkg/websocket/envelope.go +++ /dev/null @@ -1,7 +0,0 @@ -package websocket - -type envelope struct { - t int - msg []byte - filter FilterFunc -} diff --git a/pkg/websocket/hub.go b/pkg/websocket/hub.go deleted file mode 100644 index 551de8e..0000000 --- a/pkg/websocket/hub.go +++ /dev/null @@ -1,91 +0,0 @@ -package websocket - -import ( - "sync" -) - -type hub struct { - sessions map[*Session]bool - broadcast chan *envelope - register chan *Session - unregister chan *Session - exit chan *envelope - open bool - rwmutex *sync.RWMutex -} - -func newHub() *hub { - return &hub{ - sessions: make(map[*Session]bool), - broadcast: make(chan *envelope), - register: make(chan *Session), - unregister: make(chan *Session), - exit: make(chan *envelope), - open: true, - rwmutex: &sync.RWMutex{}, - } -} - -func (h *hub) run() { -loop: - for { - select { - case s := <-h.register: - h.rwmutex.Lock() - h.sessions[s] = true - h.rwmutex.Unlock() - case s := <-h.unregister: - if _, ok := h.sessions[s]; ok { - h.rwmutex.Lock() - delete(h.sessions, s) - h.rwmutex.Unlock() - } - case m := <-h.broadcast: - h.rwmutex.RLock() - for s := range h.sessions { - if m.filter != nil { - if m.filter(s) { - s.writeMessage(m) - } - } else { - s.writeMessage(m) - } - } - h.rwmutex.RUnlock() - case m := <-h.exit: - h.rwmutex.Lock() - for s := range h.sessions { - s.writeMessage(m) - delete(h.sessions, s) - s.Close() - } - h.open = false - h.rwmutex.Unlock() - break loop - } - } -} - -func (h *hub) closed() bool { - h.rwmutex.RLock() - defer h.rwmutex.RUnlock() - return !h.open -} - -func (h *hub) len() int { - h.rwmutex.RLock() - defer h.rwmutex.RUnlock() - - return len(h.sessions) -} - -func (h *hub) all() []*Session { - h.rwmutex.RLock() - defer h.rwmutex.RUnlock() - - s := make([]*Session, 0, len(h.sessions)) - for k := range h.sessions { - s = append(s, k) - } - return s -} diff --git a/pkg/websocket/melody.go b/pkg/websocket/melody.go deleted file mode 100644 index de1210c..0000000 --- a/pkg/websocket/melody.go +++ /dev/null @@ -1,272 +0,0 @@ -package websocket - -import ( - "github.com/oarkflow/frame" - "github.com/oarkflow/frame/pkg/common/xid" -) - -type HandleMessageFunc func(*Session, []byte) -type HandleErrorFunc func(*Session, error) -type HandleCloseFunc func(*Session, int, string) error -type HandleSessionFunc func(*Session) -type FilterFunc func(*Session) bool - -type Handlers struct { - MessageHandler HandleMessageFunc - BinaryMessageHandler HandleMessageFunc - MessageSentHandler HandleMessageFunc - BinaryMessageSentHandler HandleMessageFunc - ErrorHandler HandleErrorFunc - CloseHandler HandleCloseFunc - ConnectHandler HandleSessionFunc - DisconnectHandler HandleSessionFunc - PongHandler HandleSessionFunc -} - -// Hub implements a websocket manager. -type Hub struct { - config *Config - hub *hub -} - -// NewHub creates a new hub instance with default Upgrader and Config. -func NewHub(cfg ...*Config) *Hub { - var config *Config - if len(cfg) > 0 { - config = cfg[0] - } else { - config = &Config{} - } - defaultConfig(config) - hub := newHub() - go hub.run() - return &Hub{ - config: config, - hub: hub, - } -} - -// OnConnect fires fn when a session connects. -func (m *Hub) OnConnect(fn func(*Session)) { - m.config.Handlers.ConnectHandler = fn -} - -// OnDisconnect fires fn when a session disconnects. -func (m *Hub) OnDisconnect(fn func(*Session)) { - m.config.Handlers.DisconnectHandler = fn -} - -// OnPong fires fn when a pong is received from a session. -func (m *Hub) OnPong(fn func(*Session)) { - m.config.Handlers.PongHandler = fn -} - -// OnMessage fires fn when a text message comes in. -func (m *Hub) OnMessage(fn func(*Session, []byte)) { - m.config.Handlers.MessageHandler = fn -} - -// OnBinaryMessage fires fn when a binary message comes in. -func (m *Hub) OnBinaryMessage(fn func(*Session, []byte)) { - m.config.Handlers.BinaryMessageHandler = fn -} - -// OnMessageSent fires fn when a text message is successfully sent. -func (m *Hub) OnMessageSent(fn func(*Session, []byte)) { - m.config.Handlers.MessageSentHandler = fn -} - -// OnBinaryMessageSent fires fn when a binary message is successfully sent. -func (m *Hub) OnBinaryMessageSent(fn func(*Session, []byte)) { - m.config.Handlers.BinaryMessageSentHandler = fn -} - -// OnError fires fn when a session has an error. -func (m *Hub) OnError(fn func(*Session, error)) { - m.config.Handlers.ErrorHandler = fn -} - -// OnClose sets the handler for close messages received from the session. -// The code argument to h is the received close code or CloseNoStatusReceived -// if the close message is empty. The default close handler sends a close frame -// back to the session. -// -// The application must read the connection to process close messages as -// described in the section on Control Frames above. -// -// The connection read methods return a CloseError when a close frame is -// received. Most applications should handle close messages as part of their -// normal error handling. Applications should only set a close handler when the -// application must perform some action before sending a close frame back to -// the session. -func (m *Hub) OnClose(fn func(*Session, int, string) error) { - m.config.Handlers.CloseHandler = fn -} - -// OnRequest upgrades http requests to websocket connections and dispatches them to be handled by the hub instance. -func (m *Hub) OnRequest(ctx *frame.Context) (string, error) { - return m.OnRequestWithKeys(ctx, nil) -} - -// OnRequestWithKeys does the same as HandleRequest but populates session.Keys with keys. -func (m *Hub) OnRequestWithKeys(ctx *frame.Context, keys map[string]interface{}) (string, error) { - if m.hub.closed() { - return "", ErrClosed - } - m.config.Upgrader.Subprotocols = []string{string(ctx.GetHeader("Sec-WebSocket-Protocol"))} - id := xid.New().String() - return id, m.config.Upgrader.Upgrade(ctx, func(conn *Conn) { - session := NewSession(id, ctx, conn, keys, m.config.MessageBufferSize) - session.hub = m - m.hub.register <- session - m.config.Handlers.ConnectHandler(session) - go session.writePump() - session.readPump() - if !m.hub.closed() { - m.hub.unregister <- session - } - session.close() - m.config.Handlers.DisconnectHandler(session) - }) -} - -// Broadcast broadcasts a text message to all sessions. -func (m *Hub) Broadcast(msg []byte) error { - if m.hub.closed() { - return ErrClosed - } - - message := &envelope{t: TextMessage, msg: msg} - m.hub.broadcast <- message - - return nil -} - -// Notify broadcasts a text message to all sessions. -func (m *Hub) Notify(msg []byte, sessionID string) error { - if m.hub.closed() { - return ErrClosed - } - for _, session := range m.hub.all() { - if session.ID == sessionID { - return session.Write(msg) - } - } - return nil -} - -// BroadcastFilter broadcasts a text message to all sessions that fn returns true for. -func (m *Hub) BroadcastFilter(msg []byte, fn func(*Session) bool) error { - if m.hub.closed() { - return ErrClosed - } - - message := &envelope{t: TextMessage, msg: msg, filter: fn} - m.hub.broadcast <- message - - return nil -} - -// BroadcastExcept broadcasts a text message to all sessions except session s. -func (m *Hub) BroadcastExcept(msg []byte, s *Session) error { - return m.BroadcastFilter(msg, func(q *Session) bool { - return s != q - }) -} - -// BroadcastMultiple broadcasts a text message to multiple sessions given in the sessions slice. -func (m *Hub) BroadcastMultiple(msg []byte, sessions []*Session) error { - for _, sess := range sessions { - if writeErr := sess.Write(msg); writeErr != nil { - return writeErr - } - } - return nil -} - -// BroadcastBinary broadcasts a binary message to all sessions. -func (m *Hub) BroadcastBinary(msg []byte) error { - if m.hub.closed() { - return ErrClosed - } - - message := &envelope{t: BinaryMessage, msg: msg} - m.hub.broadcast <- message - - return nil -} - -// BroadcastBinaryFilter broadcasts a binary message to all sessions that fn returns true for. -func (m *Hub) BroadcastBinaryFilter(msg []byte, fn func(*Session) bool) error { - if m.hub.closed() { - return ErrClosed - } - - message := &envelope{t: BinaryMessage, msg: msg, filter: fn} - m.hub.broadcast <- message - - return nil -} - -// BroadcastBinaryOthers broadcasts a binary message to all sessions except session s. -func (m *Hub) BroadcastBinaryOthers(msg []byte, s *Session) error { - return m.BroadcastBinaryFilter(msg, func(q *Session) bool { - return s != q - }) -} - -// Sessions returns all sessions. An error is returned if the hub session is closed. -func (m *Hub) Sessions() ([]*Session, error) { - if m.hub.closed() { - return nil, ErrClosed - } - return m.hub.all(), nil -} - -// DeleteSession delete session -func (m *Hub) DeleteSession(sess *Session) { - delete(m.hub.sessions, sess) -} - -// SessionByID returns all sessions. An error is returned if the hub session is closed. -func (m *Hub) SessionByID(sessionID string) *Session { - for _, session := range m.hub.all() { - if session.ID == sessionID { - return session - } - } - return nil -} - -// Close closes the hub instance and all connected sessions. -func (m *Hub) Close() error { - if m.hub.closed() { - return ErrClosed - } - - m.hub.exit <- &envelope{t: CloseMessage, msg: []byte{}} - - return nil -} - -// CloseWithMsg closes the hub instance with the given close payload and all connected sessions. -// Use the FormatCloseMessage function to format a proper close message payload. -func (m *Hub) CloseWithMsg(msg []byte) error { - if m.hub.closed() { - return ErrClosed - } - - m.hub.exit <- &envelope{t: CloseMessage, msg: msg} - - return nil -} - -// Len return the number of connected sessions. -func (m *Hub) Len() int { - return m.hub.len() -} - -// IsClosed returns the status of the hub instance. -func (m *Hub) IsClosed() bool { - return m.hub.closed() -} diff --git a/pkg/websocket/session.go b/pkg/websocket/session.go deleted file mode 100644 index b9fecb5..0000000 --- a/pkg/websocket/session.go +++ /dev/null @@ -1,251 +0,0 @@ -package websocket - -import ( - "fmt" - "net" - "sync" - "time" - - "github.com/oarkflow/frame" -) - -// Session wrapper around websocket connections. -type Session struct { - ID string - Request *frame.Context - Channels []string - Keys map[string]interface{} - conn *Conn - output chan *envelope - outputDone chan struct{} - hub *Hub - open bool - rwmutex *sync.RWMutex -} - -func NewSession(id string, ctx *frame.Context, conn *Conn, keys map[string]any, messageBufferSize int) *Session { - return &Session{ - ID: id, - Request: ctx, - Keys: keys, - Channels: []string{}, - conn: conn, - output: make(chan *envelope, messageBufferSize), - outputDone: make(chan struct{}), - open: true, - rwmutex: &sync.RWMutex{}, - } -} - -func (s *Session) writeMessage(message *envelope) error { - if s.closed() { - s.hub.config.Handlers.ErrorHandler(s, ErrWriteClosed) - return ErrWriteClosed - } - - select { - case s.output <- message: - default: - s.hub.config.Handlers.ErrorHandler(s, ErrMessageBufferFull) - return ErrMessageBufferFull - } - return nil -} - -func (s *Session) writeRaw(message *envelope) error { - if s.closed() { - return ErrWriteClosed - } - - s.conn.SetWriteDeadline(time.Now().Add(s.hub.config.WriteWait)) - err := s.conn.WriteMessage(message.t, message.msg) - - if err != nil { - return err - } - - return nil -} - -func (s *Session) closed() bool { - s.rwmutex.RLock() - defer s.rwmutex.RUnlock() - - return !s.open -} - -func (s *Session) close() { - s.rwmutex.Lock() - open := s.open - s.open = false - s.rwmutex.Unlock() - if open { - s.conn.Close() - close(s.outputDone) - } -} - -func (s *Session) ping() { - s.writeRaw(&envelope{t: PingMessage, msg: []byte{}}) -} - -func (s *Session) writePump() { - ticker := time.NewTicker(s.hub.config.PingPeriod) - defer ticker.Stop() - -loop: - for { - select { - case msg := <-s.output: - err := s.writeRaw(msg) - if s.hub.config.Handlers.ErrorHandler == nil { - s.hub.config.Handlers.ErrorHandler = func(session *Session, err error) { - fmt.Println("Caught error on", session.ID, err.Error()) - } - } - if err != nil { - s.hub.config.Handlers.ErrorHandler(s, err) - break loop - } - - if msg.t == CloseMessage { - break loop - } - - if msg.t == TextMessage { - s.hub.config.Handlers.MessageSentHandler(s, msg.msg) - } - - if msg.t == BinaryMessage { - s.hub.config.Handlers.BinaryMessageSentHandler(s, msg.msg) - } - case <-ticker.C: - s.ping() - case _, ok := <-s.outputDone: - if !ok { - break loop - } - } - } -} - -func (s *Session) readPump() { - s.conn.SetReadLimit(s.hub.config.MaxMessageSize) - s.conn.SetReadDeadline(time.Now().Add(s.hub.config.PongWait)) - - s.conn.SetPongHandler(func(string) error { - s.conn.SetReadDeadline(time.Now().Add(s.hub.config.PongWait)) - s.hub.config.Handlers.PongHandler(s) - return nil - }) - - if s.hub.config.Handlers.CloseHandler != nil { - s.conn.SetCloseHandler(func(code int, text string) error { - return s.hub.config.Handlers.CloseHandler(s, code, text) - }) - } - - for { - t, message, err := s.conn.ReadMessage() - - if err != nil { - s.hub.config.Handlers.ErrorHandler(s, err) - break - } - - if t == TextMessage { - s.hub.config.Handlers.MessageHandler(s, message) - } - - if t == BinaryMessage { - s.hub.config.Handlers.BinaryMessageSentHandler(s, message) - } - } -} - -// Write writes message to session. -func (s *Session) Write(msg []byte) error { - if s.closed() { - return ErrSessionClosed - } - - return s.writeMessage(&envelope{t: TextMessage, msg: msg}) -} - -// WriteBinary writes a binary message to session. -func (s *Session) WriteBinary(msg []byte) error { - if s.closed() { - return ErrSessionClosed - } - - return s.writeMessage(&envelope{t: BinaryMessage, msg: msg}) -} - -// Close closes session. -func (s *Session) Close() error { - if s.closed() { - return ErrSessionClosed - } - - return s.writeMessage(&envelope{t: CloseMessage, msg: []byte{}}) -} - -// CloseWithMsg closes the session with the provided payload. -// Use the FormatCloseMessage function to format a proper close message payload. -func (s *Session) CloseWithMsg(msg []byte) error { - if s.closed() { - return ErrSessionClosed - } - - return s.writeMessage(&envelope{t: CloseMessage, msg: msg}) -} - -// Set is used to store a new key/value pair exclusively for this session. -// It also lazy initializes s.Keys if it was not used previously. -func (s *Session) Set(key string, value interface{}) { - s.rwmutex.Lock() - defer s.rwmutex.Unlock() - - if s.Keys == nil { - s.Keys = make(map[string]interface{}) - } - - s.Keys[key] = value -} - -// Get returns the value for the given key, ie: (value, true). -// If the value does not exists it returns (nil, false) -func (s *Session) Get(key string) (value interface{}, exists bool) { - s.rwmutex.RLock() - defer s.rwmutex.RUnlock() - - if s.Keys != nil { - value, exists = s.Keys[key] - } - - return -} - -// MustGet returns the value for the given key if it exists, otherwise it panics. -func (s *Session) MustGet(key string) interface{} { - if value, exists := s.Get(key); exists { - return value - } - - panic("Key \"" + key + "\" does not exist") -} - -// IsClosed returns the status of the connection. -func (s *Session) IsClosed() bool { - return s.closed() -} - -// LocalAddr returns the local addr of the connection. -func (s *Session) LocalAddr() net.Addr { - return s.conn.LocalAddr() -} - -// RemoteAddr returns the remote addr of the connection. -func (s *Session) RemoteAddr() net.Addr { - return s.conn.RemoteAddr() -}