From 4f9a8c302ab59be0170c6e30f205f0b428b9bcf7 Mon Sep 17 00:00:00 2001 From: Clint Ayres Date: Fri, 9 Feb 2024 11:13:18 +0000 Subject: [PATCH] pr feedback (sethf) --- .../content-blocker-plugin_test.go | 23 +++----- relay/traffic/encoding.go | 49 +++++++++++----- relay/traffic/handler.go | 58 ++++++++++--------- relay/traffic/traffic_test.go | 27 ++++----- 4 files changed, 84 insertions(+), 73 deletions(-) diff --git a/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin_test.go b/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin_test.go index fe4c344..501a026 100644 --- a/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin_test.go +++ b/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin_test.go @@ -15,13 +15,6 @@ import ( "github.com/fullstorydev/relay-core/relay/version" ) -type Encoding int - -const ( - Identity Encoding = iota - Gzip -) - func TestContentBlocking(t *testing.T) { testCases := []contentBlockerTestCase{ { @@ -141,8 +134,8 @@ func TestContentBlocking(t *testing.T) { } for _, testCase := range testCases { - runContentBlockerTest(t, testCase, Identity) - runContentBlockerTest(t, testCase, Gzip) + runContentBlockerTest(t, testCase, traffic.Identity) + runContentBlockerTest(t, testCase, traffic.Gzip) } } @@ -194,12 +187,12 @@ type contentBlockerTestCase struct { expectedHeaders map[string]string } -func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encoding Encoding) { +func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encoding traffic.Encoding) { var encodingStr string switch encoding { - case Gzip: + case traffic.Gzip: encodingStr = "gzip" - case Identity: + case traffic.Identity: encodingStr = "" } @@ -223,7 +216,7 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encodi expectedHeaders[content_blocker_plugin.PluginVersionHeaderName] = version.RelayRelease test.WithCatcherAndRelay(t, testCase.config, plugins, func(catcherService *catcher.Service, relayService *relay.Service) { - b, err := traffic.EncodeData([]byte(testCase.originalBody), encodingStr) + b, err := traffic.EncodeData([]byte(testCase.originalBody), encoding) if err != nil { t.Errorf("Test '%v': Error encoding data: %v", desc, err) return @@ -239,7 +232,7 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encodi return } - if encoding == Gzip { + if encoding == traffic.Gzip { request.Header.Set("Content-Encoding", "gzip") } @@ -309,7 +302,7 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encodi ) } - decodedRequestBody, err := traffic.DecodeData(lastRequestBody, encodingStr) + decodedRequestBody, err := traffic.DecodeData(lastRequestBody, encoding) if err != nil { t.Errorf("Test '%v': Error decoding data: %v", desc, err) return diff --git a/relay/traffic/encoding.go b/relay/traffic/encoding.go index ed98a57..b8e0e06 100644 --- a/relay/traffic/encoding.go +++ b/relay/traffic/encoding.go @@ -3,13 +3,22 @@ package traffic import ( "bytes" "compress/gzip" + "fmt" "io" "net/http" "net/url" "strings" ) -func GetContentEncoding(request *http.Request) (string, error) { +type Encoding int + +const ( + Unsupported Encoding = iota + Identity + Gzip +) + +func GetContentEncoding(request *http.Request) (Encoding, error) { // NOTE: This is a workaround for a bug in post-Go 1.17. See golang.org/issue/25192. // Our algorithm differs from the logic of AllowQuerySemicolons by replacing semicolons with encoded semicolons instead // of with ampersands. This is because we want to preserve the original query string as much as possible. @@ -19,7 +28,7 @@ func GetContentEncoding(request *http.Request) (string, error) { queryParams, err := url.ParseQuery(request.URL.RawQuery) if err != nil { - return "", err + return Unsupported, err } // request query parameter takes precedence over request header @@ -27,29 +36,39 @@ func GetContentEncoding(request *http.Request) (string, error) { if encoding == "" { encoding = request.Header.Get("Content-Encoding") } - return encoding, nil + + switch encoding { + case "gzip": + return Gzip, nil + case "": + return Identity, nil + default: + return Unsupported, fmt.Errorf("unsupported encoding: %v", encoding) + } } // WrapReader checks if the request Content-Encoding or request query parameter indicates gzip compression. // If so, it returns a gzip.Reader that decompresses the content. -func WrapReader(request *http.Request, encoding string) (io.ReadCloser, error) { +func WrapReader(request *http.Request, encoding Encoding) (io.ReadCloser, error) { if request.Body == nil { return nil, nil } switch encoding { - case "gzip": + case Gzip: // Create a new gzip.Reader to decompress the request body return gzip.NewReader(request.Body) - default: + case Identity: // If the content is not gzip-compressed, return the original request body return request.Body, nil + default: + return nil, fmt.Errorf("unsupported encoding: %v", encoding) } } -func EncodeData(data []byte, encoding string) ([]byte, error) { +func EncodeData(data []byte, encoding Encoding) ([]byte, error) { switch encoding { - case "gzip": + case Gzip: var buf bytes.Buffer gz := gzip.NewWriter(&buf) @@ -65,15 +84,16 @@ func EncodeData(data []byte, encoding string) ([]byte, error) { compressedData := buf.Bytes() return compressedData, nil - default: - // identity encoding + case Identity: return data, nil + default: + return nil, fmt.Errorf("unsupported encoding: %v", encoding) } } -func DecodeData(data []byte, encoding string) ([]byte, error) { +func DecodeData(data []byte, encoding Encoding) ([]byte, error) { switch encoding { - case "gzip": + case Gzip: reader, err := gzip.NewReader(bytes.NewReader(data)) if err != nil { return nil, err @@ -85,8 +105,9 @@ func DecodeData(data []byte, encoding string) ([]byte, error) { } return decodedData, nil - default: - // identity encoding + case Identity: return data, nil + default: + return nil, fmt.Errorf("unsupported encoding: %v", encoding) } } diff --git a/relay/traffic/handler.go b/relay/traffic/handler.go index cdf6b66..ce15a8a 100644 --- a/relay/traffic/handler.go +++ b/relay/traffic/handler.go @@ -60,7 +60,7 @@ func (handler *Handler) ServeHTTP(response http.ResponseWriter, request *http.Re encoding, err := GetContentEncoding(request) if err != nil { - logger.Printf("URL %v error getting request content encoding: %v", request.URL, err) + logger.Printf("URL %v error in request content encoding: %v", request.URL, err) request.Body = http.NoBody return } @@ -95,7 +95,7 @@ func (handler *Handler) ServeHTTP(response http.ResponseWriter, request *http.Re } // prepareRequestBody wraps the request Body with a reader that will decode the content if necessary. -func (handler *Handler) prepareRequestBody(clientRequest *http.Request, encoding string) error { +func (handler *Handler) prepareRequestBody(clientRequest *http.Request, encoding Encoding) error { if reader, err := WrapReader(clientRequest, encoding); err != nil { return err } else if reader != nil && reader != http.NoBody { @@ -104,7 +104,7 @@ func (handler *Handler) prepareRequestBody(clientRequest *http.Request, encoding return nil } -func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, clientRequest *http.Request, serviced bool, encoding string) bool { +func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, clientRequest *http.Request, serviced bool, encoding Encoding) bool { if serviced { return false } @@ -124,35 +124,39 @@ func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, client } } -func (handler *Handler) ensureBodyContentEncoding(clientRequest *http.Request, encoding string) { - if encoding == "" || encoding == "identity" { +func (handler *Handler) ensureBodyContentEncoding(clientRequest *http.Request, encoding Encoding) { + switch encoding { + case Unsupported: + logger.Println("Error unsupported content-encoding") return - } - - servicedBody, err := io.ReadAll(clientRequest.Body) - if err != nil { - logger.Printf("Error reading request body: %s", err) - clientRequest.Body = http.NoBody + case Identity: return - } + case Gzip: + servicedBody, err := io.ReadAll(clientRequest.Body) + if err != nil { + logger.Printf("Error reading request body: %s", err) + clientRequest.Body = http.NoBody + return + } - if encodedData, err := EncodeData(servicedBody, encoding); err != nil { - logger.Printf("Error encoding request body: %s", err) - clientRequest.Body = http.NoBody - return - } else { - servicedBody = encodedData - } + if encodedData, err := EncodeData(servicedBody, encoding); err != nil { + logger.Printf("Error encoding request body: %s", err) + clientRequest.Body = http.NoBody + return + } else { + servicedBody = encodedData + } - // If the length of the body has changed, we should update the - // Content-Length header too. - contentLength := int64(len(servicedBody)) - if contentLength != clientRequest.ContentLength { - clientRequest.ContentLength = contentLength - clientRequest.Header.Set("Content-Length", strconv.FormatInt(contentLength, 10)) - } + // If the length of the body has changed, we should update the + // Content-Length header too. + contentLength := int64(len(servicedBody)) + if contentLength != clientRequest.ContentLength { + clientRequest.ContentLength = contentLength + clientRequest.Header.Set("Content-Length", strconv.FormatInt(contentLength, 10)) + } - clientRequest.Body = io.NopCloser(bytes.NewBuffer(servicedBody)) + clientRequest.Body = io.NopCloser(bytes.NewBuffer(servicedBody)) + } } diff --git a/relay/traffic/traffic_test.go b/relay/traffic/traffic_test.go index a0788ef..a1597ec 100644 --- a/relay/traffic/traffic_test.go +++ b/relay/traffic/traffic_test.go @@ -151,33 +151,26 @@ func TestMaxBodySize(t *testing.T) { }) } -type Encoding int - -const ( - Identity Encoding = iota - Gzip -) - func TestRelaySupportsContentEncoding(t *testing.T) { testCases := map[string]struct { - encoding Encoding + encoding traffic.Encoding bodyContentStr string headers map[string]string customUrl func(relayServiceURL string) string }{ "identity": { - encoding: Identity, + encoding: traffic.Identity, bodyContentStr: "Hello, world!", }, "gzip - with header": { - encoding: Gzip, + encoding: traffic.Gzip, bodyContentStr: "Hello, world!", headers: map[string]string{ "Content-Encoding": "gzip", }, }, "gzip - with query param": { - encoding: Gzip, + encoding: traffic.Gzip, bodyContentStr: "Hello, world!", customUrl: func(relayServiceURL string) string { return fmt.Sprintf("%v?ContentEncoding=gzip", relayServiceURL) @@ -190,14 +183,14 @@ func TestRelaySupportsContentEncoding(t *testing.T) { // convert the body content to a reader with the proper content encoding applied var body io.Reader switch testCase.encoding { - case Gzip: - b, err := traffic.EncodeData([]byte(testCase.bodyContentStr), "gzip") + case traffic.Gzip: + b, err := traffic.EncodeData([]byte(testCase.bodyContentStr), traffic.Gzip) if err != nil { t.Errorf("Test %s - Error encoding data: %v", desc, err) return } body = bytes.NewReader(b) - case Identity: + case traffic.Identity: body = strings.NewReader(testCase.bodyContentStr) } @@ -235,8 +228,8 @@ func TestRelaySupportsContentEncoding(t *testing.T) { } switch testCase.encoding { - case Gzip: - decodedData, err := traffic.DecodeData(lastRequest, "gzip") + case traffic.Gzip: + decodedData, err := traffic.DecodeData(lastRequest, traffic.Gzip) if err != nil { t.Errorf("Test %s - Error decoding data: %v", desc, err) return @@ -244,7 +237,7 @@ func TestRelaySupportsContentEncoding(t *testing.T) { if string(decodedData) != testCase.bodyContentStr { t.Errorf("Test %s - Expected body '%v' but got: %v", desc, testCase.bodyContentStr, string(decodedData)) } - case Identity: + case traffic.Identity: if string(lastRequest) != testCase.bodyContentStr { t.Errorf("Test %s - Expected body '%v' but got: %v", desc, testCase.bodyContentStr, string(lastRequest)) }