From 10ab726e7459c17aa9961db1e5c63218941706a6 Mon Sep 17 00:00:00 2001 From: Tong Jia Date: Fri, 20 Jul 2018 10:15:40 -0400 Subject: [PATCH 1/5] add websocket connection cancellation via done channel --- websocketproxy.go | 106 +++++++++++++++++++++++++---------------- websocketproxy_test.go | 3 ++ 2 files changed, 69 insertions(+), 40 deletions(-) diff --git a/websocketproxy.go b/websocketproxy.go index 29f5bda..b6c341c 100644 --- a/websocketproxy.go +++ b/websocketproxy.go @@ -25,27 +25,39 @@ var ( DefaultDialer = websocket.DefaultDialer ) -// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket -// connection and proxies it to another server. -type WebsocketProxy struct { - // Director, if non-nil, is a function that may copy additional request - // headers from the incoming WebSocket connection into the output headers - // which will be forwarded to another server. - Director func(incoming *http.Request, out http.Header) - - // Backend returns the backend URL which the proxy uses to reverse proxy - // the incoming WebSocket connection. Request is the initial incoming and - // unmodified request. - Backend func(*http.Request) *url.URL - - // Upgrader specifies the parameters for upgrading a incoming HTTP - // connection to a WebSocket connection. If nil, DefaultUpgrader is used. - Upgrader *websocket.Upgrader - - // Dialer contains options for connecting to the backend WebSocket server. - // If nil, DefaultDialer is used. - Dialer *websocket.Dialer -} +type ( + // WebsocketProxy is an HTTP Handler that takes an incoming WebSocket + // connection and proxies it to another server. + WebsocketProxy struct { + // Director, if non-nil, is a function that may copy additional request + // headers from the incoming WebSocket connection into the output headers + // which will be forwarded to another server. + Director func(incoming *http.Request, out http.Header) + + // Backend returns the backend URL which the proxy uses to reverse proxy + // the incoming WebSocket connection. Request is the initial incoming and + // unmodified request. + Backend func(*http.Request) *url.URL + + // Upgrader specifies the parameters for upgrading a incoming HTTP + // connection to a WebSocket connection. If nil, DefaultUpgrader is used. + Upgrader *websocket.Upgrader + + // Dialer contains options for connecting to the backend WebSocket server. + // If nil, DefaultDialer is used. + Dialer *websocket.Dialer + + // Done specifies a channel for which all proxied websocket connections + // can be closed on demand by closing the channel. + Done chan struct{} + } + + websocketMsg struct { + msgType int + msg []byte + err error + } +) // ProxyHandler returns a new http.Handler interface that reverse proxies the // request to the given target. @@ -174,41 +186,55 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { errClient := make(chan error, 1) errBackend := make(chan error, 1) + replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) { - for { + websocketMsgRcverC := make(chan websocketMsg, 1) + websocketMsgRcver := func() <-chan websocketMsg { msgType, msg, err := src.ReadMessage() - if err != nil { - m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err)) - if e, ok := err.(*websocket.CloseError); ok { - if e.Code != websocket.CloseNoStatusReceived { - m = websocket.FormatCloseMessage(e.Code, e.Text) + websocketMsgRcverC <- websocketMsg{msgType, msg, err} + return websocketMsgRcverC + } + + for { + select { + case websocketMsgRcv := <-websocketMsgRcver(): + if websocketMsgRcv.err != nil { + m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", websocketMsgRcv.err)) + if e, ok := websocketMsgRcv.err.(*websocket.CloseError); ok { + if e.Code != websocket.CloseNoStatusReceived { + m = websocket.FormatCloseMessage(e.Code, e.Text) + } } + errc <- websocketMsgRcv.err + dst.WriteMessage(websocket.CloseMessage, m) + break } - errc <- err + err = dst.WriteMessage(websocketMsgRcv.msgType, websocketMsgRcv.msg) + if err != nil { + errc <- err + break + } + case <-w.Done: + m := websocket.FormatCloseMessage(websocket.CloseGoingAway, "websocketproxy: closing connection") dst.WriteMessage(websocket.CloseMessage, m) break } - err = dst.WriteMessage(msgType, msg) - if err != nil { - errc <- err - break - } } } go replicateWebsocketConn(connPub, connBackend, errClient) go replicateWebsocketConn(connBackend, connPub, errBackend) - var message string select { case err = <-errClient: - message = "websocketproxy: Error when copying from backend to client: %v" + if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure { + log.Printf("websocketproxy: Error when copying from backend to client: %v", err) + } case err = <-errBackend: - message = "websocketproxy: Error when copying from client to backend: %v" - - } - if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure { - log.Printf(message, err) + if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure { + log.Printf("websocketproxy: Error when copying from client to backend: %v", err) + } + case <-w.Done: } } diff --git a/websocketproxy_test.go b/websocketproxy_test.go index 63f7861..3f31646 100644 --- a/websocketproxy_test.go +++ b/websocketproxy_test.go @@ -30,6 +30,7 @@ func TestProxy(t *testing.T) { u, _ := url.Parse(backendURL) proxy := NewProxy(u) proxy.Upgrader = upgrader + proxy.Done = make(chan struct{}) mux := http.NewServeMux() mux.Handle("/proxy", proxy) @@ -121,4 +122,6 @@ func TestProxy(t *testing.T) { if msg != string(p) { t.Errorf("expecting: %s, got: %s", msg, string(p)) } + + close(proxy.Done) } From 624f1943d3f130ce618dabc0a2bad2635c61d41e Mon Sep 17 00:00:00 2001 From: Tong Jia Date: Tue, 24 Jul 2018 08:52:39 -0400 Subject: [PATCH 2/5] make done channel private and create public Shutdown method --- websocketproxy.go | 27 +++++++++++++++++++++------ websocketproxy_test.go | 4 ++-- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/websocketproxy.go b/websocketproxy.go index b6c341c..2354f87 100644 --- a/websocketproxy.go +++ b/websocketproxy.go @@ -2,6 +2,7 @@ package websocketproxy import ( + "context" "fmt" "io" "log" @@ -47,9 +48,9 @@ type ( // If nil, DefaultDialer is used. Dialer *websocket.Dialer - // Done specifies a channel for which all proxied websocket connections + // done specifies a channel for which all proxied websocket connections // can be closed on demand by closing the channel. - Done chan struct{} + done chan struct{} } websocketMsg struct { @@ -186,6 +187,9 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { errClient := make(chan error, 1) errBackend := make(chan error, 1) + if w.done == nil { + w.done = make(chan struct{}) + } replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) { websocketMsgRcverC := make(chan websocketMsg, 1) @@ -214,9 +218,7 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { errc <- err break } - case <-w.Done: - m := websocket.FormatCloseMessage(websocket.CloseGoingAway, "websocketproxy: closing connection") - dst.WriteMessage(websocket.CloseMessage, m) + case <-w.done: break } } @@ -234,8 +236,21 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure { log.Printf("websocketproxy: Error when copying from client to backend: %v", err) } - case <-w.Done: + case <-w.done: + m := websocket.FormatCloseMessage(websocket.CloseGoingAway, "websocketproxy: closing connection") + connPub.WriteMessage(websocket.CloseMessage, m) + connBackend.WriteMessage(websocket.CloseMessage, m) + } +} + +// Shutdown closes ws connections by closing the done channel they are subscribed to. +func (w *WebsocketProxy) Shutdown(ctx context.Context) error { + // TODO: support using context for control and return error when applicable + // Currently implemented such that the method signature matches http.Server.Shutdown() + if w.done != nil { + close(w.done) } + return nil } func copyHeader(dst, src http.Header) { diff --git a/websocketproxy_test.go b/websocketproxy_test.go index 3f31646..68cbc7f 100644 --- a/websocketproxy_test.go +++ b/websocketproxy_test.go @@ -1,6 +1,7 @@ package websocketproxy import ( + "context" "log" "net/http" "net/url" @@ -30,7 +31,6 @@ func TestProxy(t *testing.T) { u, _ := url.Parse(backendURL) proxy := NewProxy(u) proxy.Upgrader = upgrader - proxy.Done = make(chan struct{}) mux := http.NewServeMux() mux.Handle("/proxy", proxy) @@ -123,5 +123,5 @@ func TestProxy(t *testing.T) { t.Errorf("expecting: %s, got: %s", msg, string(p)) } - close(proxy.Done) + proxy.Shutdown(context.Background()) } From 0a477f53f5e77e6c50b54a81ef1da31d27ed88ed Mon Sep 17 00:00:00 2001 From: Tong Jia Date: Thu, 26 Jul 2018 10:33:32 -0400 Subject: [PATCH 3/5] add test for shutdown procedure; minor refactoring --- websocketproxy_test.go | 44 +++++++++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/websocketproxy_test.go b/websocketproxy_test.go index 68cbc7f..e727c15 100644 --- a/websocketproxy_test.go +++ b/websocketproxy_test.go @@ -18,6 +18,7 @@ var ( func TestProxy(t *testing.T) { // websocket proxy + u, _ := url.Parse(backendURL) supportedSubProtocols := []string{"test-protocol"} upgrader := &websocket.Upgrader{ ReadBufferSize: 4096, @@ -28,13 +29,12 @@ func TestProxy(t *testing.T) { Subprotocols: supportedSubProtocols, } - u, _ := url.Parse(backendURL) proxy := NewProxy(u) proxy.Upgrader = upgrader - mux := http.NewServeMux() - mux.Handle("/proxy", proxy) go func() { + mux := http.NewServeMux() + mux.Handle("/proxy", proxy) if err := http.ListenAndServe(":7777", mux); err != nil { t.Fatal("ListenAndServe: ", err) } @@ -43,6 +43,7 @@ func TestProxy(t *testing.T) { time.Sleep(time.Millisecond * 100) // backend echo server + websocketMsgRcverCBackend := make(chan websocketMsg, 1) go func() { mux2 := http.NewServeMux() mux2.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { @@ -52,13 +53,16 @@ func TestProxy(t *testing.T) { return } - messageType, p, err := conn.ReadMessage() - if err != nil { - return - } + for { + messageType, p, err := conn.ReadMessage() + if err != nil { + websocketMsgRcverCBackend <- websocketMsg{messageType, p, err} + return + } - if err = conn.WriteMessage(messageType, p); err != nil { - return + if err = conn.WriteMessage(messageType, p); err != nil { + return + } } }) @@ -70,8 +74,8 @@ func TestProxy(t *testing.T) { time.Sleep(time.Millisecond * 100) - // let's us define two subprotocols, only one is supported by the server - clientSubProtocols := []string{"test-protocol", "test-notsupported"} + // define subprotocols for client, appending one not supported by the server + clientSubProtocols := append(supportedSubProtocols, []string{"test-notsupported"}...) h := http.Header{} for _, subprot := range clientSubProtocols { h.Add("Sec-WebSocket-Protocol", subprot) @@ -102,8 +106,7 @@ func TestProxy(t *testing.T) { t.Error("test-notsupported should be not recevied from the server.") } - // now write a message and send it to the backend server (which goes trough - // proxy..) + // send msg to the backend server which goes through proxy msg := "hello kite" err = conn.WriteMessage(websocket.TextMessage, []byte(msg)) if err != nil { @@ -123,5 +126,20 @@ func TestProxy(t *testing.T) { t.Errorf("expecting: %s, got: %s", msg, string(p)) } + // shutdown procedure + // + backendErrMsg := "websocketproxy: closing connection" proxy.Shutdown(context.Background()) + + wsErrBackend := <-websocketMsgRcverCBackend + e, ok := wsErrBackend.err.(*websocket.CloseError) + if !ok { + t.Fatal("backend error is not websocket.CloseError") + } + if e.Code != websocket.CloseGoingAway { + t.Error("backend error code is not websocket.CloseGoingAway") + } + if e.Text != backendErrMsg { + t.Errorf("backend error test expecting: %s, got: %s", backendErrMsg, e.Text) + } } From 9b9f7ab5ab79faf350b9a79293228fc5d864ee76 Mon Sep 17 00:00:00 2001 From: Tong Jia Date: Thu, 26 Jul 2018 10:48:20 -0400 Subject: [PATCH 4/5] make websocketProxyClosingMsg const --- websocketproxy.go | 6 +++++- websocketproxy_test.go | 5 ++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/websocketproxy.go b/websocketproxy.go index 2354f87..9a2d312 100644 --- a/websocketproxy.go +++ b/websocketproxy.go @@ -14,6 +14,10 @@ import ( "github.com/gorilla/websocket" ) +const ( + websocketProxyClosingMsg = "websocketproxy: closing connection" +) + var ( // DefaultUpgrader specifies the parameters for upgrading an HTTP // connection to a WebSocket connection. @@ -237,7 +241,7 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { log.Printf("websocketproxy: Error when copying from client to backend: %v", err) } case <-w.done: - m := websocket.FormatCloseMessage(websocket.CloseGoingAway, "websocketproxy: closing connection") + m := websocket.FormatCloseMessage(websocket.CloseGoingAway, websocketProxyClosingMsg) connPub.WriteMessage(websocket.CloseMessage, m) connBackend.WriteMessage(websocket.CloseMessage, m) } diff --git a/websocketproxy_test.go b/websocketproxy_test.go index e727c15..2f6a9e3 100644 --- a/websocketproxy_test.go +++ b/websocketproxy_test.go @@ -128,7 +128,6 @@ func TestProxy(t *testing.T) { // shutdown procedure // - backendErrMsg := "websocketproxy: closing connection" proxy.Shutdown(context.Background()) wsErrBackend := <-websocketMsgRcverCBackend @@ -139,7 +138,7 @@ func TestProxy(t *testing.T) { if e.Code != websocket.CloseGoingAway { t.Error("backend error code is not websocket.CloseGoingAway") } - if e.Text != backendErrMsg { - t.Errorf("backend error test expecting: %s, got: %s", backendErrMsg, e.Text) + if e.Text != websocketProxyClosingMsg { + t.Errorf("backend error test expecting: %s, got: %s", websocketProxyClosingMsg, e.Text) } } From 80a05c01b6003aab27116c13915bc065354308b3 Mon Sep 17 00:00:00 2001 From: Tong Jia Date: Thu, 26 Jul 2018 11:09:03 -0400 Subject: [PATCH 5/5] check close msg received on client side as well --- websocketproxy_test.go | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/websocketproxy_test.go b/websocketproxy_test.go index 2f6a9e3..88cedf8 100644 --- a/websocketproxy_test.go +++ b/websocketproxy_test.go @@ -130,8 +130,22 @@ func TestProxy(t *testing.T) { // proxy.Shutdown(context.Background()) + // check close msg received in client + messageType, p, err = conn.ReadMessage() + e, ok := err.(*websocket.CloseError) + if !ok { + t.Fatal("client error is not websocket.CloseError") + } + if e.Code != websocket.CloseGoingAway { + t.Error("client error code is not websocket.CloseGoingAway") + } + if e.Text != websocketProxyClosingMsg { + t.Errorf("client error test expecting: %s, got: %s", websocketProxyClosingMsg, e.Text) + } + + // check close msg received in backend wsErrBackend := <-websocketMsgRcverCBackend - e, ok := wsErrBackend.err.(*websocket.CloseError) + e, ok = wsErrBackend.err.(*websocket.CloseError) if !ok { t.Fatal("backend error is not websocket.CloseError") }