Skip to content

Commit

Permalink
add websocket connection cancellation via done channel
Browse files Browse the repository at this point in the history
  • Loading branch information
ttjiaa committed Jul 20, 2018
1 parent 0fa3f99 commit 10ab726
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 40 deletions.
106 changes: 66 additions & 40 deletions websocketproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,39 @@ var (
DefaultDialer = websocket.DefaultDialer
)

// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket
// connection and proxies it to another server.
type WebsocketProxy struct {
// Director, if non-nil, is a function that may copy additional request
// headers from the incoming WebSocket connection into the output headers
// which will be forwarded to another server.
Director func(incoming *http.Request, out http.Header)

// Backend returns the backend URL which the proxy uses to reverse proxy
// the incoming WebSocket connection. Request is the initial incoming and
// unmodified request.
Backend func(*http.Request) *url.URL

// Upgrader specifies the parameters for upgrading a incoming HTTP
// connection to a WebSocket connection. If nil, DefaultUpgrader is used.
Upgrader *websocket.Upgrader

// Dialer contains options for connecting to the backend WebSocket server.
// If nil, DefaultDialer is used.
Dialer *websocket.Dialer
}
type (
// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket
// connection and proxies it to another server.
WebsocketProxy struct {
// Director, if non-nil, is a function that may copy additional request
// headers from the incoming WebSocket connection into the output headers
// which will be forwarded to another server.
Director func(incoming *http.Request, out http.Header)

// Backend returns the backend URL which the proxy uses to reverse proxy
// the incoming WebSocket connection. Request is the initial incoming and
// unmodified request.
Backend func(*http.Request) *url.URL

// Upgrader specifies the parameters for upgrading a incoming HTTP
// connection to a WebSocket connection. If nil, DefaultUpgrader is used.
Upgrader *websocket.Upgrader

// Dialer contains options for connecting to the backend WebSocket server.
// If nil, DefaultDialer is used.
Dialer *websocket.Dialer

// Done specifies a channel for which all proxied websocket connections
// can be closed on demand by closing the channel.
Done chan struct{}
}

websocketMsg struct {
msgType int
msg []byte
err error
}
)

// ProxyHandler returns a new http.Handler interface that reverse proxies the
// request to the given target.
Expand Down Expand Up @@ -174,41 +186,55 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {

errClient := make(chan error, 1)
errBackend := make(chan error, 1)

replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) {
for {
websocketMsgRcverC := make(chan websocketMsg, 1)
websocketMsgRcver := func() <-chan websocketMsg {
msgType, msg, err := src.ReadMessage()
if err != nil {
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err))
if e, ok := err.(*websocket.CloseError); ok {
if e.Code != websocket.CloseNoStatusReceived {
m = websocket.FormatCloseMessage(e.Code, e.Text)
websocketMsgRcverC <- websocketMsg{msgType, msg, err}
return websocketMsgRcverC
}

for {
select {
case websocketMsgRcv := <-websocketMsgRcver():
if websocketMsgRcv.err != nil {
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", websocketMsgRcv.err))
if e, ok := websocketMsgRcv.err.(*websocket.CloseError); ok {
if e.Code != websocket.CloseNoStatusReceived {
m = websocket.FormatCloseMessage(e.Code, e.Text)
}
}
errc <- websocketMsgRcv.err
dst.WriteMessage(websocket.CloseMessage, m)
break
}
errc <- err
err = dst.WriteMessage(websocketMsgRcv.msgType, websocketMsgRcv.msg)
if err != nil {
errc <- err
break
}
case <-w.Done:
m := websocket.FormatCloseMessage(websocket.CloseGoingAway, "websocketproxy: closing connection")
dst.WriteMessage(websocket.CloseMessage, m)
break
}
err = dst.WriteMessage(msgType, msg)
if err != nil {
errc <- err
break
}
}
}

go replicateWebsocketConn(connPub, connBackend, errClient)
go replicateWebsocketConn(connBackend, connPub, errBackend)

var message string
select {
case err = <-errClient:
message = "websocketproxy: Error when copying from backend to client: %v"
if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure {
log.Printf("websocketproxy: Error when copying from backend to client: %v", err)
}
case err = <-errBackend:
message = "websocketproxy: Error when copying from client to backend: %v"

}
if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure {
log.Printf(message, err)
if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure {
log.Printf("websocketproxy: Error when copying from client to backend: %v", err)
}
case <-w.Done:
}
}

Expand Down
3 changes: 3 additions & 0 deletions websocketproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ func TestProxy(t *testing.T) {
u, _ := url.Parse(backendURL)
proxy := NewProxy(u)
proxy.Upgrader = upgrader
proxy.Done = make(chan struct{})

mux := http.NewServeMux()
mux.Handle("/proxy", proxy)
Expand Down Expand Up @@ -121,4 +122,6 @@ func TestProxy(t *testing.T) {
if msg != string(p) {
t.Errorf("expecting: %s, got: %s", msg, string(p))
}

close(proxy.Done)
}

0 comments on commit 10ab726

Please sign in to comment.