Skip to content

Commit

Permalink
Merge pull request #36 from fullstorydev/clint/WEB-14177
Browse files Browse the repository at this point in the history
Support scrubbing gzip encoded bundles
  • Loading branch information
jurassix authored Feb 12, 2024
2 parents 5a34026 + a52dacd commit a272202
Show file tree
Hide file tree
Showing 9 changed files with 349 additions and 44 deletions.
3 changes: 1 addition & 2 deletions catcher/catcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
Expand Down Expand Up @@ -82,7 +81,7 @@ func (service *Service) LastRequestBody() ([]byte, error) {
}

defer request.Body.Close()
body, err := ioutil.ReadAll(request.Body)
body, err := io.ReadAll(request.Body)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions relay/main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@ package main

import (
"flag"
"io/ioutil"
"io"
"log"
"os"
"time"

"github.com/fullstorydev/relay-core/relay"
"github.com/fullstorydev/relay-core/relay/config"
"github.com/fullstorydev/relay-core/relay/environment"
"github.com/fullstorydev/relay-core/relay/traffic/plugin-loader"
plugin_loader "github.com/fullstorydev/relay-core/relay/traffic/plugin-loader"
)

var logger = log.New(os.Stdout, "[relay] ", 0)

func readConfigFile(path string) (rawConfigFileBytes []byte, err error) {
if path == "-" {
rawConfigFileBytes, err = ioutil.ReadAll(os.Stdin)
rawConfigFileBytes, err = io.ReadAll(os.Stdin)
return
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ package content_blocker_plugin
import (
"bytes"
"fmt"
"io/ioutil"
"io"
"log"
"net/http"
"os"
Expand Down Expand Up @@ -78,7 +78,7 @@ func (f contentBlockerPluginFactory) New(configSection *config.Section) (traffic
}

if regexp, err := regexp.Compile(pattern); err != nil {
return fmt.Errorf(`Could not compile regular expression "%v": %v`, pattern, err)
return fmt.Errorf(`could not compile regular expression "%v": %v`, pattern, err)
} else {
logger.Printf("Added rule: %s %s content matching \"%s\"", mode, contentKind, regexp)
blockers = append(blockers, &contentBlocker{
Expand All @@ -94,7 +94,7 @@ func (f contentBlockerPluginFactory) New(configSection *config.Section) (traffic
case "header":
plugin.headerBlockers = append(plugin.headerBlockers, blockers...)
default:
return fmt.Errorf(`Unexpected content kind %s`, contentKind)
return fmt.Errorf(`unexpected content kind %s`, contentKind)
}

return nil
Expand Down Expand Up @@ -222,28 +222,26 @@ func (plug contentBlockerPlugin) blockBodyContent(response http.ResponseWriter,
return false
}

processedBody, err := ioutil.ReadAll(request.Body)
processedBody, err := io.ReadAll(request.Body)
if err != nil {
http.Error(response, fmt.Sprintf("Error reading request body: %s", err), 500)
request.Body = http.NoBody
return true
}
initialLength := len(processedBody)

for _, blocker := range plug.bodyBlockers {
processedBody = blocker.Block(processedBody)
}

// If the length of the body has changed, we should update the
// Content-Length header too.
finalLength := len(processedBody)
if finalLength != initialLength {
contentLength := int64(finalLength)
contentLength := int64(len(processedBody))
if contentLength != request.ContentLength {
request.ContentLength = contentLength
request.Header.Set("Content-Length", strconv.FormatInt(contentLength, 10))
}

request.Body = ioutil.NopCloser(bytes.NewBuffer(processedBody))
request.Body = io.NopCloser(bytes.NewBuffer(processedBody))
return false
}

Expand Down Expand Up @@ -283,7 +281,7 @@ func (b *contentBlocker) Block(content []byte) []byte {
case excludeMode:
return b.regexp.ReplaceAllLiteral(content, []byte{})
default:
panic(fmt.Errorf("Invalid content blocking mode: %v", b.mode))
panic(fmt.Errorf("invalid content blocking mode: %v", b.mode))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package content_blocker_plugin_test

import (
"bytes"
"fmt"
"net/http"
"strconv"
"testing"

"github.com/fullstorydev/relay-core/catcher"
"github.com/fullstorydev/relay-core/relay"
"github.com/fullstorydev/relay-core/relay/plugins/traffic/content-blocker-plugin"
content_blocker_plugin "github.com/fullstorydev/relay-core/relay/plugins/traffic/content-blocker-plugin"
"github.com/fullstorydev/relay-core/relay/test"
"github.com/fullstorydev/relay-core/relay/traffic"
"github.com/fullstorydev/relay-core/relay/version"
Expand Down Expand Up @@ -133,7 +134,8 @@ func TestContentBlocking(t *testing.T) {
}

for _, testCase := range testCases {
runContentBlockerTest(t, testCase)
runContentBlockerTest(t, testCase, traffic.Identity)
runContentBlockerTest(t, testCase, traffic.Gzip)
}
}

Expand Down Expand Up @@ -185,7 +187,18 @@ type contentBlockerTestCase struct {
expectedHeaders map[string]string
}

func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase) {
func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encoding traffic.Encoding) {
var encodingStr string
switch encoding {
case traffic.Gzip:
encodingStr = "gzip"
case traffic.Identity:
encodingStr = ""
}

// Add encoding to the test description
desc := fmt.Sprintf("%s (encoding: %v)", testCase.desc, encodingStr)

plugins := []traffic.PluginFactory{
content_blocker_plugin.Factory,
}
Expand All @@ -203,36 +216,46 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase) {
expectedHeaders[content_blocker_plugin.PluginVersionHeaderName] = version.RelayRelease

test.WithCatcherAndRelay(t, testCase.config, plugins, func(catcherService *catcher.Service, relayService *relay.Service) {
b, err := traffic.EncodeData([]byte(testCase.originalBody), encoding)
if err != nil {
t.Errorf("Test '%v': Error encoding data: %v", desc, err)
return
}

request, err := http.NewRequest(
"POST",
relayService.HttpUrl(),
bytes.NewBufferString(testCase.originalBody),
bytes.NewBuffer(b),
)
if err != nil {
t.Errorf("Test '%v': Error creating request: %v", testCase.desc, err)
t.Errorf("Test '%v': Error creating request: %v", desc, err)
return
}

if encoding == traffic.Gzip {
request.Header.Set("Content-Encoding", "gzip")
}

request.Header.Set("Content-Type", "application/json")
for header, headerValue := range originalHeaders {
request.Header.Set(header, headerValue)
}

response, err := http.DefaultClient.Do(request)
if err != nil {
t.Errorf("Test '%v': Error POSTing: %v", testCase.desc, err)
t.Errorf("Test '%v': Error POSTing: %v", desc, err)
return
}
defer response.Body.Close()

if response.StatusCode != 200 {
t.Errorf("Test '%v': Expected 200 response: %v", testCase.desc, response)
t.Errorf("Test '%v': Expected 200 response: %v", desc, response)
return
}

lastRequest, err := catcherService.LastRequest()
if err != nil {
t.Errorf("Test '%v': Error reading last request from catcher: %v", testCase.desc, err)
t.Errorf("Test '%v': Error reading last request from catcher: %v", desc, err)
return
}

Expand All @@ -241,43 +264,58 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase) {
if expectedHeaderValue != actualHeaderValue {
t.Errorf(
"Test '%v': Expected header '%v' with value '%v' but got: %v",
testCase.desc,
desc,
expectedHeader,
expectedHeaderValue,
actualHeaderValue,
)
}
}

if lastRequest.Header.Get("Content-Encoding") != encodingStr {
t.Errorf(
"Test '%v': Expected Content-Encoding '%v' but got: %v",
desc,
encodingStr,
lastRequest.Header.Get("Content-Encoding"),
)
}

lastRequestBody, err := catcherService.LastRequestBody()
if err != nil {
t.Errorf("Test '%v': Error reading last request body from catcher: %v", testCase.desc, err)
t.Errorf("Test '%v': Error reading last request body from catcher: %v", desc, err)
return
}

lastRequestBodyStr := string(lastRequestBody)
if testCase.expectedBody != lastRequestBodyStr {
t.Errorf(
"Test '%v': Expected body '%v' but got: %v",
testCase.desc,
testCase.expectedBody,
lastRequestBodyStr,
)
}

contentLength, err := strconv.Atoi(lastRequest.Header.Get("Content-Length"))
if err != nil {
t.Errorf("Test '%v': Error parsing Content-Length: %v", testCase.desc, err)
t.Errorf("Test '%v': Error parsing Content-Length: %v", desc, err)
return
}

if contentLength != len(lastRequestBody) {
t.Errorf(
"Test '%v': Content-Length is %v but actual body length is %v",
testCase.desc,
desc,
contentLength,
len(lastRequestBody),
)
}

decodedRequestBody, err := traffic.DecodeData(lastRequestBody, encoding)
if err != nil {
t.Errorf("Test '%v': Error decoding data: %v", desc, err)
return
}

lastRequestBodyStr := string(decodedRequestBody)
if testCase.expectedBody != lastRequestBodyStr {
t.Errorf(
"Test '%v': Expected body '%v' but got: %v",
desc,
testCase.expectedBody,
lastRequestBodyStr,
)
}
})
}
2 changes: 1 addition & 1 deletion relay/plugins/traffic/paths-plugin/paths-plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ func runPathsPluginTest(t *testing.T, testCase pathsPluginTestCase) {
lastRequest, err = altCatcherService.LastRequest()
}
if err != nil {
t.Errorf("Error reading last request from catcher: %v", err)
t.Errorf("Text '%v': Error reading last request from catcher: %v", testCase.desc, err)
return
}

Expand Down
Loading

0 comments on commit a272202

Please sign in to comment.