diff --git a/websocketproxy.go b/websocketproxy.go index 29f5bda..9a2d312 100644 --- a/websocketproxy.go +++ b/websocketproxy.go @@ -2,6 +2,7 @@ package websocketproxy import ( + "context" "fmt" "io" "log" @@ -13,6 +14,10 @@ import ( "github.com/gorilla/websocket" ) +const ( + websocketProxyClosingMsg = "websocketproxy: closing connection" +) + var ( // DefaultUpgrader specifies the parameters for upgrading an HTTP // connection to a WebSocket connection. @@ -25,27 +30,39 @@ var ( DefaultDialer = websocket.DefaultDialer ) -// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket -// connection and proxies it to another server. -type WebsocketProxy struct { - // Director, if non-nil, is a function that may copy additional request - // headers from the incoming WebSocket connection into the output headers - // which will be forwarded to another server. - Director func(incoming *http.Request, out http.Header) - - // Backend returns the backend URL which the proxy uses to reverse proxy - // the incoming WebSocket connection. Request is the initial incoming and - // unmodified request. - Backend func(*http.Request) *url.URL - - // Upgrader specifies the parameters for upgrading a incoming HTTP - // connection to a WebSocket connection. If nil, DefaultUpgrader is used. - Upgrader *websocket.Upgrader - - // Dialer contains options for connecting to the backend WebSocket server. - // If nil, DefaultDialer is used. - Dialer *websocket.Dialer -} +type ( + // WebsocketProxy is an HTTP Handler that takes an incoming WebSocket + // connection and proxies it to another server. + WebsocketProxy struct { + // Director, if non-nil, is a function that may copy additional request + // headers from the incoming WebSocket connection into the output headers + // which will be forwarded to another server. + Director func(incoming *http.Request, out http.Header) + + // Backend returns the backend URL which the proxy uses to reverse proxy + // the incoming WebSocket connection. Request is the initial incoming and + // unmodified request. + Backend func(*http.Request) *url.URL + + // Upgrader specifies the parameters for upgrading a incoming HTTP + // connection to a WebSocket connection. If nil, DefaultUpgrader is used. + Upgrader *websocket.Upgrader + + // Dialer contains options for connecting to the backend WebSocket server. + // If nil, DefaultDialer is used. + Dialer *websocket.Dialer + + // done specifies a channel for which all proxied websocket connections + // can be closed on demand by closing the channel. + done chan struct{} + } + + websocketMsg struct { + msgType int + msg []byte + err error + } +) // ProxyHandler returns a new http.Handler interface that reverse proxies the // request to the given target. @@ -174,23 +191,38 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { errClient := make(chan error, 1) errBackend := make(chan error, 1) + if w.done == nil { + w.done = make(chan struct{}) + } + replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) { - for { + websocketMsgRcverC := make(chan websocketMsg, 1) + websocketMsgRcver := func() <-chan websocketMsg { msgType, msg, err := src.ReadMessage() - if err != nil { - m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err)) - if e, ok := err.(*websocket.CloseError); ok { - if e.Code != websocket.CloseNoStatusReceived { - m = websocket.FormatCloseMessage(e.Code, e.Text) + websocketMsgRcverC <- websocketMsg{msgType, msg, err} + return websocketMsgRcverC + } + + for { + select { + case websocketMsgRcv := <-websocketMsgRcver(): + if websocketMsgRcv.err != nil { + m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", websocketMsgRcv.err)) + if e, ok := websocketMsgRcv.err.(*websocket.CloseError); ok { + if e.Code != websocket.CloseNoStatusReceived { + m = websocket.FormatCloseMessage(e.Code, e.Text) + } } + errc <- websocketMsgRcv.err + dst.WriteMessage(websocket.CloseMessage, m) + break } - errc <- err - dst.WriteMessage(websocket.CloseMessage, m) - break - } - err = dst.WriteMessage(msgType, msg) - if err != nil { - errc <- err + err = dst.WriteMessage(websocketMsgRcv.msgType, websocketMsgRcv.msg) + if err != nil { + errc <- err + break + } + case <-w.done: break } } @@ -199,17 +231,30 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { go replicateWebsocketConn(connPub, connBackend, errClient) go replicateWebsocketConn(connBackend, connPub, errBackend) - var message string select { case err = <-errClient: - message = "websocketproxy: Error when copying from backend to client: %v" + if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure { + log.Printf("websocketproxy: Error when copying from backend to client: %v", err) + } case err = <-errBackend: - message = "websocketproxy: Error when copying from client to backend: %v" - + if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure { + log.Printf("websocketproxy: Error when copying from client to backend: %v", err) + } + case <-w.done: + m := websocket.FormatCloseMessage(websocket.CloseGoingAway, websocketProxyClosingMsg) + connPub.WriteMessage(websocket.CloseMessage, m) + connBackend.WriteMessage(websocket.CloseMessage, m) } - if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure { - log.Printf(message, err) +} + +// Shutdown closes ws connections by closing the done channel they are subscribed to. +func (w *WebsocketProxy) Shutdown(ctx context.Context) error { + // TODO: support using context for control and return error when applicable + // Currently implemented such that the method signature matches http.Server.Shutdown() + if w.done != nil { + close(w.done) } + return nil } func copyHeader(dst, src http.Header) { diff --git a/websocketproxy_test.go b/websocketproxy_test.go index 63f7861..88cedf8 100644 --- a/websocketproxy_test.go +++ b/websocketproxy_test.go @@ -1,6 +1,7 @@ package websocketproxy import ( + "context" "log" "net/http" "net/url" @@ -17,6 +18,7 @@ var ( func TestProxy(t *testing.T) { // websocket proxy + u, _ := url.Parse(backendURL) supportedSubProtocols := []string{"test-protocol"} upgrader := &websocket.Upgrader{ ReadBufferSize: 4096, @@ -27,13 +29,12 @@ func TestProxy(t *testing.T) { Subprotocols: supportedSubProtocols, } - u, _ := url.Parse(backendURL) proxy := NewProxy(u) proxy.Upgrader = upgrader - mux := http.NewServeMux() - mux.Handle("/proxy", proxy) go func() { + mux := http.NewServeMux() + mux.Handle("/proxy", proxy) if err := http.ListenAndServe(":7777", mux); err != nil { t.Fatal("ListenAndServe: ", err) } @@ -42,6 +43,7 @@ func TestProxy(t *testing.T) { time.Sleep(time.Millisecond * 100) // backend echo server + websocketMsgRcverCBackend := make(chan websocketMsg, 1) go func() { mux2 := http.NewServeMux() mux2.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { @@ -51,13 +53,16 @@ func TestProxy(t *testing.T) { return } - messageType, p, err := conn.ReadMessage() - if err != nil { - return - } + for { + messageType, p, err := conn.ReadMessage() + if err != nil { + websocketMsgRcverCBackend <- websocketMsg{messageType, p, err} + return + } - if err = conn.WriteMessage(messageType, p); err != nil { - return + if err = conn.WriteMessage(messageType, p); err != nil { + return + } } }) @@ -69,8 +74,8 @@ func TestProxy(t *testing.T) { time.Sleep(time.Millisecond * 100) - // let's us define two subprotocols, only one is supported by the server - clientSubProtocols := []string{"test-protocol", "test-notsupported"} + // define subprotocols for client, appending one not supported by the server + clientSubProtocols := append(supportedSubProtocols, []string{"test-notsupported"}...) h := http.Header{} for _, subprot := range clientSubProtocols { h.Add("Sec-WebSocket-Protocol", subprot) @@ -101,8 +106,7 @@ func TestProxy(t *testing.T) { t.Error("test-notsupported should be not recevied from the server.") } - // now write a message and send it to the backend server (which goes trough - // proxy..) + // send msg to the backend server which goes through proxy msg := "hello kite" err = conn.WriteMessage(websocket.TextMessage, []byte(msg)) if err != nil { @@ -121,4 +125,34 @@ func TestProxy(t *testing.T) { if msg != string(p) { t.Errorf("expecting: %s, got: %s", msg, string(p)) } + + // shutdown procedure + // + proxy.Shutdown(context.Background()) + + // check close msg received in client + messageType, p, err = conn.ReadMessage() + e, ok := err.(*websocket.CloseError) + if !ok { + t.Fatal("client error is not websocket.CloseError") + } + if e.Code != websocket.CloseGoingAway { + t.Error("client error code is not websocket.CloseGoingAway") + } + if e.Text != websocketProxyClosingMsg { + t.Errorf("client error test expecting: %s, got: %s", websocketProxyClosingMsg, e.Text) + } + + // check close msg received in backend + wsErrBackend := <-websocketMsgRcverCBackend + e, ok = wsErrBackend.err.(*websocket.CloseError) + if !ok { + t.Fatal("backend error is not websocket.CloseError") + } + if e.Code != websocket.CloseGoingAway { + t.Error("backend error code is not websocket.CloseGoingAway") + } + if e.Text != websocketProxyClosingMsg { + t.Errorf("backend error test expecting: %s, got: %s", websocketProxyClosingMsg, e.Text) + } }