From 3f4a25e77e5be91d9dedec0be356d87fa33c0224 Mon Sep 17 00:00:00 2001 From: Tong Jia Date: Tue, 24 Jul 2018 08:52:39 -0400 Subject: [PATCH] make done channel private and create public Shutdown method --- websocketproxy.go | 23 +++++++++++++++++++---- websocketproxy_test.go | 4 ++-- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/websocketproxy.go b/websocketproxy.go index b6c341c..cfc0014 100644 --- a/websocketproxy.go +++ b/websocketproxy.go @@ -2,6 +2,7 @@ package websocketproxy import ( + "context" "fmt" "io" "log" @@ -47,9 +48,9 @@ type ( // If nil, DefaultDialer is used. Dialer *websocket.Dialer - // Done specifies a channel for which all proxied websocket connections + // done specifies a channel for which all proxied websocket connections // can be closed on demand by closing the channel. - Done chan struct{} + done chan struct{} } websocketMsg struct { @@ -186,6 +187,9 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { errClient := make(chan error, 1) errBackend := make(chan error, 1) + if w.done == nil { + w.done = make(chan struct{}) + } replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) { websocketMsgRcverC := make(chan websocketMsg, 1) @@ -214,7 +218,7 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { errc <- err break } - case <-w.Done: + case <-w.done: m := websocket.FormatCloseMessage(websocket.CloseGoingAway, "websocketproxy: closing connection") dst.WriteMessage(websocket.CloseMessage, m) break @@ -234,8 +238,19 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure { log.Printf("websocketproxy: Error when copying from client to backend: %v", err) } - case <-w.Done: + case <-w.done: + } +} + +// Shutdown gracefully closes proxied websocket connections by closing the +// done channel they are subscribed to. +func (w *WebsocketProxy) Shutdown(ctx context.Context) error { + // TODO: support using context for control and return error when applicable + // Currently implemented such that the method signature matches http.Server.Shutdown() + if w.done != nil { + close(w.done) } + return nil } func copyHeader(dst, src http.Header) { diff --git a/websocketproxy_test.go b/websocketproxy_test.go index 3f31646..68cbc7f 100644 --- a/websocketproxy_test.go +++ b/websocketproxy_test.go @@ -1,6 +1,7 @@ package websocketproxy import ( + "context" "log" "net/http" "net/url" @@ -30,7 +31,6 @@ func TestProxy(t *testing.T) { u, _ := url.Parse(backendURL) proxy := NewProxy(u) proxy.Upgrader = upgrader - proxy.Done = make(chan struct{}) mux := http.NewServeMux() mux.Handle("/proxy", proxy) @@ -123,5 +123,5 @@ func TestProxy(t *testing.T) { t.Errorf("expecting: %s, got: %s", msg, string(p)) } - close(proxy.Done) + proxy.Shutdown(context.Background()) }