Skip to content

Commit

Permalink
refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
jurassix committed Feb 5, 2024
1 parent ea49ebe commit 5364b93
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 68 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func (f contentBlockerPluginFactory) New(configSection *config.Section) (traffic
}

if regexp, err := regexp.Compile(pattern); err != nil {
return fmt.Errorf(`Could not compile regular expression "%v": %v`, pattern, err)
return fmt.Errorf(`could not compile regular expression "%v": %v`, pattern, err)
} else {
logger.Printf("Added rule: %s %s content matching \"%s\"", mode, contentKind, regexp)
blockers = append(blockers, &contentBlocker{
Expand All @@ -94,7 +94,7 @@ func (f contentBlockerPluginFactory) New(configSection *config.Section) (traffic
case "header":
plugin.headerBlockers = append(plugin.headerBlockers, blockers...)
default:
return fmt.Errorf(`Unexpected content kind %s`, contentKind)
return fmt.Errorf(`unexpected content kind %s`, contentKind)
}

return nil
Expand Down Expand Up @@ -218,7 +218,13 @@ func (plug contentBlockerPlugin) blockBodyContent(response http.ResponseWriter,
return true
}

if reader, err := traffic.GetBodyReader(request); err != nil {
encoding, err := traffic.GetContentEncoding(request)
if err != nil {
logger.Printf("URL %v error getting request content encoding: %v", request.URL, err)
return false
}

if reader, err := traffic.WrapReader(request, encoding); err != nil {
logger.Printf("URL %v error setting up request body reader: %v", request.URL, err)
return false
} else if reader == nil || reader == http.NoBody {
Expand All @@ -238,6 +244,20 @@ func (plug contentBlockerPlugin) blockBodyContent(response http.ResponseWriter,
processedBody = blocker.Block(processedBody)
}

// Note we encode the data in memory before setting it back on the request so we can update
// the Content-Length header to reflect the new length of the body.
// FIXME: if additional plugins are added that modify the request body, we should consider
// decoding the stream before handling off to plugins and encoding it again after all plugins complete.
// This would allow us to avoid encoding and decoding the body multiple times. We don't do this
// today since this is the only plugin that modifies the request body.
if encodedData, err := traffic.EncodeData(processedBody, encoding); err != nil {
http.Error(response, fmt.Sprintf("Error encoding request body: %s", err), 500)
request.Body = http.NoBody
return true
} else {
processedBody = encodedData
}

// If the length of the body has changed, we should update the
// Content-Length header too.
contentLength := int64(len(processedBody))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package content_blocker_plugin_test

import (
"bytes"
"compress/gzip"
"fmt"
"net/http"
"strconv"
Expand Down Expand Up @@ -192,6 +191,10 @@ type contentBlockerTestCase struct {

func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, withCompression bool) {
desc := fmt.Sprintf("%v (compression=%t)", testCase.desc, withCompression)
encoding := "indentity"
if withCompression {
encoding = "gzip"
}

plugins := []traffic.PluginFactory{
content_blocker_plugin.Factory,
Expand All @@ -210,17 +213,12 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, withCo
expectedHeaders[content_blocker_plugin.PluginVersionHeaderName] = version.RelayRelease

test.WithCatcherAndRelay(t, testCase.config, plugins, func(catcherService *catcher.Service, relayService *relay.Service) {
var body *bytes.Buffer
if withCompression {
b, err := compressString(testCase.originalBody)
if err != nil {
t.Errorf("Test '%v': Error compressing body: %v", desc, err)
return
}
body = &b
} else {
body = bytes.NewBufferString(testCase.originalBody)
encodedBody, err := traffic.EncodeData([]byte(testCase.originalBody), encoding)
if err != nil {
t.Errorf("Test '%v': Error compressing body: %v", desc, err)
return
}
body := bytes.NewBuffer(encodedBody)

request, err := http.NewRequest(
"POST",
Expand All @@ -232,10 +230,7 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, withCo
return
}

if withCompression {
request.Header.Set("Content-Encoding", "gzip")
}

request.Header.Set("Content-Encoding", encoding)
request.Header.Set("Content-Type", "application/json")
for header, headerValue := range originalHeaders {
request.Header.Set(header, headerValue)
Expand Down Expand Up @@ -272,62 +267,55 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, withCo
}
}

lastRequestBody, err := catcherService.LastRequestBody()
if err != nil {
t.Errorf("Test '%v': Error reading last request body from catcher: %v", desc, err)
return
}

lastRequestBodyStr := string(lastRequestBody)
if testCase.expectedBody != lastRequestBodyStr {
// ensure content encoding is preserved
if lastRequest.Header.Get("Content-Encoding") != encoding {
t.Errorf(
"Test '%v': Expected body '%v' but got: %v",
"Test '%v': Expected Content-Encoding '%v' but got: %v",
desc,
testCase.expectedBody,
lastRequestBodyStr,
encoding,
lastRequest.Header.Get("Content-Encoding"),
)
}

// Note "raw" because these are the raw bytes of the request body without decoding,
// which allows us to assert the Content-Length header against the actual body length.
lastRequestBodyRaw, err := catcherService.LastRequestBody()
if err != nil {
t.Errorf("Test '%v': Error reading last request body from catcher: %v", desc, err)
return
}

contentLength, err := strconv.Atoi(lastRequest.Header.Get("Content-Length"))
if err != nil {
t.Errorf("Test '%v': Error parsing Content-Length: %v", desc, err)
return
}

if contentLength != len(lastRequestBody) {
// so the issue is that content length keeps changing we should do this before we decode the body
// actually seems like the data should be sent here compressed then uncompressed
if contentLength != len(lastRequestBodyRaw) {
t.Errorf(
"Test '%v': Content-Length is %v but actual body length is %v",
desc,
contentLength,
len(lastRequestBody),
len(lastRequestBodyRaw),
)
}
})
}

// compressString takes a string, compresses it using gzip, and returns the compressed data as a base64-encoded string.
func compressString(input string) (bytes.Buffer, error) {
if input == "" {
return bytes.Buffer{}, nil
}

// Create a buffer to hold the compressed data
var b bytes.Buffer

// Create a new gzip writer
gz := gzip.NewWriter(&b)

// Write the input string to the gzip writer
_, err := gz.Write([]byte(input))
if err != nil {
return b, err
}

// Close the gzip writer to flush the compression
err = gz.Close()
if err != nil {
return b, err
}
// Try to decode the body, if it was encoded, and then to compare it to the expected value.
lastRequestBodyDecoded, err := traffic.DecodeData(lastRequestBodyRaw, encoding)
if err != nil {
t.Errorf("Test '%v': Error decompressing body: %v", desc, err)
}

return b, nil
lastRequestBodyStr := string(lastRequestBodyDecoded)
if testCase.expectedBody != lastRequestBodyStr {
t.Errorf(
"Test '%v': Expected body '%v' but got: %v",
desc,
testCase.expectedBody,
lastRequestBodyStr,
)
}
})
}
2 changes: 1 addition & 1 deletion relay/plugins/traffic/paths-plugin/paths-plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ func runPathsPluginTest(t *testing.T, testCase pathsPluginTestCase) {
lastRequest, err = altCatcherService.LastRequest()
}
if err != nil {
t.Errorf("Error reading last request from catcher: %v", err)
t.Errorf("Text '%v': Error reading last request from catcher: %v", testCase.desc, err)
return
}

Expand Down
74 changes: 64 additions & 10 deletions relay/traffic/util.go
Original file line number Diff line number Diff line change
@@ -1,29 +1,83 @@
package traffic

import (
"bytes"
"compress/gzip"
"io"
"net/http"
"net/url"
)

// GetBodyReader checks if the request Content-Encoding or request query parameter indicates gzip compression.
func GetContentEncoding(request *http.Request) (string, error) {
queryParams, err := url.ParseQuery(request.URL.RawQuery)
if err != nil {
return "", err
}

encoding := queryParams.Get("ContentEncoding")
if encoding == "" {
encoding = request.Header.Get("Content-Encoding")
}
return encoding, nil
}

// WrapReader checks if the request Content-Encoding or request query parameter indicates gzip compression.
// If so, it returns a gzip.Reader that decompresses the content.
func GetBodyReader(request *http.Request) (io.ReadCloser, error) {
func WrapReader(request *http.Request, encoding string) (io.ReadCloser, error) {
if request.Body == nil || request.Body == http.NoBody {
return nil, nil
}

queryParams, err := url.ParseQuery(request.URL.RawQuery)
if err != nil {
return nil, err
}

if request.Header.Get("Content-Encoding") == "gzip" || queryParams.Get("ContentEncoding") == "gzip" {
switch encoding {
case "gzip":
// Create a new gzip.Reader to decompress the request body
return gzip.NewReader(request.Body)
default:
// If the content is not gzip-compressed, return the original request body
return request.Body, nil
}
}

func EncodeData(data []byte, encoding string) ([]byte, error) {
switch encoding {
case "gzip":
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)

_, err := gz.Write(data)
if err != nil {
return nil, err
}

err = gz.Close()
if err != nil {
return nil, err
}

compressedData := buf.Bytes()
return compressedData, nil
default:
// identity encoding
return data, nil
}
}

func DecodeData(data []byte, encoding string) ([]byte, error) {
switch encoding {
case "gzip":
reader, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return nil, err
}

// If the content is not gzip-compressed, return the original request body
return request.Body, nil
decodedData, err := io.ReadAll(reader)
if err != nil {
return nil, err
}

return decodedData, nil
default:
// identity encoding
return data, nil
}
}

0 comments on commit 5364b93

Please sign in to comment.