Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for websocket connection cancellation via Shutdown() method #24

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 85 additions & 40 deletions websocketproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
package websocketproxy

import (
"context"
"fmt"
"io"
"log"
Expand All @@ -13,6 +14,10 @@ import (
"github.com/gorilla/websocket"
)

const (
websocketProxyClosingMsg = "websocketproxy: closing connection"
)

var (
// DefaultUpgrader specifies the parameters for upgrading an HTTP
// connection to a WebSocket connection.
Expand All @@ -25,27 +30,39 @@ var (
DefaultDialer = websocket.DefaultDialer
)

// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket
// connection and proxies it to another server.
type WebsocketProxy struct {
// Director, if non-nil, is a function that may copy additional request
// headers from the incoming WebSocket connection into the output headers
// which will be forwarded to another server.
Director func(incoming *http.Request, out http.Header)

// Backend returns the backend URL which the proxy uses to reverse proxy
// the incoming WebSocket connection. Request is the initial incoming and
// unmodified request.
Backend func(*http.Request) *url.URL

// Upgrader specifies the parameters for upgrading a incoming HTTP
// connection to a WebSocket connection. If nil, DefaultUpgrader is used.
Upgrader *websocket.Upgrader

// Dialer contains options for connecting to the backend WebSocket server.
// If nil, DefaultDialer is used.
Dialer *websocket.Dialer
}
type (
// WebsocketProxy is an HTTP Handler that takes an incoming WebSocket
// connection and proxies it to another server.
WebsocketProxy struct {
// Director, if non-nil, is a function that may copy additional request
// headers from the incoming WebSocket connection into the output headers
// which will be forwarded to another server.
Director func(incoming *http.Request, out http.Header)

// Backend returns the backend URL which the proxy uses to reverse proxy
// the incoming WebSocket connection. Request is the initial incoming and
// unmodified request.
Backend func(*http.Request) *url.URL

// Upgrader specifies the parameters for upgrading a incoming HTTP
// connection to a WebSocket connection. If nil, DefaultUpgrader is used.
Upgrader *websocket.Upgrader

// Dialer contains options for connecting to the backend WebSocket server.
// If nil, DefaultDialer is used.
Dialer *websocket.Dialer

// done specifies a channel for which all proxied websocket connections
// can be closed on demand by closing the channel.
done chan struct{}
}

websocketMsg struct {
msgType int
msg []byte
err error
}
)

// ProxyHandler returns a new http.Handler interface that reverse proxies the
// request to the given target.
Expand Down Expand Up @@ -174,23 +191,38 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {

errClient := make(chan error, 1)
errBackend := make(chan error, 1)
if w.done == nil {
w.done = make(chan struct{})
}

replicateWebsocketConn := func(dst, src *websocket.Conn, errc chan error) {
for {
websocketMsgRcverC := make(chan websocketMsg, 1)
websocketMsgRcver := func() <-chan websocketMsg {
msgType, msg, err := src.ReadMessage()
if err != nil {
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", err))
if e, ok := err.(*websocket.CloseError); ok {
if e.Code != websocket.CloseNoStatusReceived {
m = websocket.FormatCloseMessage(e.Code, e.Text)
websocketMsgRcverC <- websocketMsg{msgType, msg, err}
return websocketMsgRcverC
}

for {
select {
case websocketMsgRcv := <-websocketMsgRcver():
if websocketMsgRcv.err != nil {
m := websocket.FormatCloseMessage(websocket.CloseNormalClosure, fmt.Sprintf("%v", websocketMsgRcv.err))
if e, ok := websocketMsgRcv.err.(*websocket.CloseError); ok {
if e.Code != websocket.CloseNoStatusReceived {
m = websocket.FormatCloseMessage(e.Code, e.Text)
}
}
errc <- websocketMsgRcv.err
dst.WriteMessage(websocket.CloseMessage, m)
break
}
errc <- err
dst.WriteMessage(websocket.CloseMessage, m)
break
}
err = dst.WriteMessage(msgType, msg)
if err != nil {
errc <- err
err = dst.WriteMessage(websocketMsgRcv.msgType, websocketMsgRcv.msg)
if err != nil {
errc <- err
break
}
case <-w.done:
break
}
}
Expand All @@ -199,17 +231,30 @@ func (w *WebsocketProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
go replicateWebsocketConn(connPub, connBackend, errClient)
go replicateWebsocketConn(connBackend, connPub, errBackend)

var message string
select {
case err = <-errClient:
message = "websocketproxy: Error when copying from backend to client: %v"
if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure {
log.Printf("websocketproxy: Error when copying from backend to client: %v", err)
}
case err = <-errBackend:
message = "websocketproxy: Error when copying from client to backend: %v"

if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure {
log.Printf("websocketproxy: Error when copying from client to backend: %v", err)
}
case <-w.done:
m := websocket.FormatCloseMessage(websocket.CloseGoingAway, websocketProxyClosingMsg)
connPub.WriteMessage(websocket.CloseMessage, m)
connBackend.WriteMessage(websocket.CloseMessage, m)
}
if e, ok := err.(*websocket.CloseError); !ok || e.Code == websocket.CloseAbnormalClosure {
log.Printf(message, err)
}

// Shutdown closes ws connections by closing the done channel they are subscribed to.
func (w *WebsocketProxy) Shutdown(ctx context.Context) error {
// TODO: support using context for control and return error when applicable
// Currently implemented such that the method signature matches http.Server.Shutdown()
if w.done != nil {
close(w.done)
}
return nil
}

func copyHeader(dst, src http.Header) {
Expand Down
60 changes: 47 additions & 13 deletions websocketproxy_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package websocketproxy

import (
"context"
"log"
"net/http"
"net/url"
Expand All @@ -17,6 +18,7 @@ var (

func TestProxy(t *testing.T) {
// websocket proxy
u, _ := url.Parse(backendURL)
supportedSubProtocols := []string{"test-protocol"}
upgrader := &websocket.Upgrader{
ReadBufferSize: 4096,
Expand All @@ -27,13 +29,12 @@ func TestProxy(t *testing.T) {
Subprotocols: supportedSubProtocols,
}

u, _ := url.Parse(backendURL)
proxy := NewProxy(u)
proxy.Upgrader = upgrader

mux := http.NewServeMux()
mux.Handle("/proxy", proxy)
go func() {
mux := http.NewServeMux()
mux.Handle("/proxy", proxy)
if err := http.ListenAndServe(":7777", mux); err != nil {
t.Fatal("ListenAndServe: ", err)
}
Expand All @@ -42,6 +43,7 @@ func TestProxy(t *testing.T) {
time.Sleep(time.Millisecond * 100)

// backend echo server
websocketMsgRcverCBackend := make(chan websocketMsg, 1)
go func() {
mux2 := http.NewServeMux()
mux2.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -51,13 +53,16 @@ func TestProxy(t *testing.T) {
return
}

messageType, p, err := conn.ReadMessage()
if err != nil {
return
}
for {
messageType, p, err := conn.ReadMessage()
if err != nil {
websocketMsgRcverCBackend <- websocketMsg{messageType, p, err}
return
}

if err = conn.WriteMessage(messageType, p); err != nil {
return
if err = conn.WriteMessage(messageType, p); err != nil {
return
}
}
})

Expand All @@ -69,8 +74,8 @@ func TestProxy(t *testing.T) {

time.Sleep(time.Millisecond * 100)

// let's us define two subprotocols, only one is supported by the server
clientSubProtocols := []string{"test-protocol", "test-notsupported"}
// define subprotocols for client, appending one not supported by the server
clientSubProtocols := append(supportedSubProtocols, []string{"test-notsupported"}...)
h := http.Header{}
for _, subprot := range clientSubProtocols {
h.Add("Sec-WebSocket-Protocol", subprot)
Expand Down Expand Up @@ -101,8 +106,7 @@ func TestProxy(t *testing.T) {
t.Error("test-notsupported should be not recevied from the server.")
}

// now write a message and send it to the backend server (which goes trough
// proxy..)
// send msg to the backend server which goes through proxy
msg := "hello kite"
err = conn.WriteMessage(websocket.TextMessage, []byte(msg))
if err != nil {
Expand All @@ -121,4 +125,34 @@ func TestProxy(t *testing.T) {
if msg != string(p) {
t.Errorf("expecting: %s, got: %s", msg, string(p))
}

// shutdown procedure
//
proxy.Shutdown(context.Background())

// check close msg received in client
messageType, p, err = conn.ReadMessage()
e, ok := err.(*websocket.CloseError)
if !ok {
t.Fatal("client error is not websocket.CloseError")
}
if e.Code != websocket.CloseGoingAway {
t.Error("client error code is not websocket.CloseGoingAway")
}
if e.Text != websocketProxyClosingMsg {
t.Errorf("client error test expecting: %s, got: %s", websocketProxyClosingMsg, e.Text)
}

// check close msg received in backend
wsErrBackend := <-websocketMsgRcverCBackend
e, ok = wsErrBackend.err.(*websocket.CloseError)
if !ok {
t.Fatal("backend error is not websocket.CloseError")
}
if e.Code != websocket.CloseGoingAway {
t.Error("backend error code is not websocket.CloseGoingAway")
}
if e.Text != websocketProxyClosingMsg {
t.Errorf("backend error test expecting: %s, got: %s", websocketProxyClosingMsg, e.Text)
}
}