Skip to content

Commit

Permalink
pr feedback (sethf)
Browse files Browse the repository at this point in the history
  • Loading branch information
jurassix committed Feb 9, 2024
1 parent e21c919 commit 4f9a8c3
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,6 @@ import (
"github.com/fullstorydev/relay-core/relay/version"
)

type Encoding int

const (
Identity Encoding = iota
Gzip
)

func TestContentBlocking(t *testing.T) {
testCases := []contentBlockerTestCase{
{
Expand Down Expand Up @@ -141,8 +134,8 @@ func TestContentBlocking(t *testing.T) {
}

for _, testCase := range testCases {
runContentBlockerTest(t, testCase, Identity)
runContentBlockerTest(t, testCase, Gzip)
runContentBlockerTest(t, testCase, traffic.Identity)
runContentBlockerTest(t, testCase, traffic.Gzip)
}
}

Expand Down Expand Up @@ -194,12 +187,12 @@ type contentBlockerTestCase struct {
expectedHeaders map[string]string
}

func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encoding Encoding) {
func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encoding traffic.Encoding) {
var encodingStr string
switch encoding {
case Gzip:
case traffic.Gzip:
encodingStr = "gzip"
case Identity:
case traffic.Identity:
encodingStr = ""
}

Expand All @@ -223,7 +216,7 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encodi
expectedHeaders[content_blocker_plugin.PluginVersionHeaderName] = version.RelayRelease

test.WithCatcherAndRelay(t, testCase.config, plugins, func(catcherService *catcher.Service, relayService *relay.Service) {
b, err := traffic.EncodeData([]byte(testCase.originalBody), encodingStr)
b, err := traffic.EncodeData([]byte(testCase.originalBody), encoding)
if err != nil {
t.Errorf("Test '%v': Error encoding data: %v", desc, err)
return
Expand All @@ -239,7 +232,7 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encodi
return
}

if encoding == Gzip {
if encoding == traffic.Gzip {
request.Header.Set("Content-Encoding", "gzip")
}

Expand Down Expand Up @@ -309,7 +302,7 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encodi
)
}

decodedRequestBody, err := traffic.DecodeData(lastRequestBody, encodingStr)
decodedRequestBody, err := traffic.DecodeData(lastRequestBody, encoding)
if err != nil {
t.Errorf("Test '%v': Error decoding data: %v", desc, err)
return
Expand Down
49 changes: 35 additions & 14 deletions relay/traffic/encoding.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,22 @@ package traffic
import (
"bytes"
"compress/gzip"
"fmt"
"io"
"net/http"
"net/url"
"strings"
)

func GetContentEncoding(request *http.Request) (string, error) {
type Encoding int

const (
Unsupported Encoding = iota
Identity
Gzip
)

func GetContentEncoding(request *http.Request) (Encoding, error) {
// NOTE: This is a workaround for a bug in post-Go 1.17. See golang.org/issue/25192.
// Our algorithm differs from the logic of AllowQuerySemicolons by replacing semicolons with encoded semicolons instead
// of with ampersands. This is because we want to preserve the original query string as much as possible.
Expand All @@ -19,37 +28,47 @@ func GetContentEncoding(request *http.Request) (string, error) {

queryParams, err := url.ParseQuery(request.URL.RawQuery)
if err != nil {
return "", err
return Unsupported, err
}

// request query parameter takes precedence over request header
encoding := queryParams.Get("ContentEncoding")
if encoding == "" {
encoding = request.Header.Get("Content-Encoding")
}
return encoding, nil

switch encoding {
case "gzip":
return Gzip, nil
case "":
return Identity, nil
default:
return Unsupported, fmt.Errorf("unsupported encoding: %v", encoding)
}
}

// WrapReader checks if the request Content-Encoding or request query parameter indicates gzip compression.
// If so, it returns a gzip.Reader that decompresses the content.
func WrapReader(request *http.Request, encoding string) (io.ReadCloser, error) {
func WrapReader(request *http.Request, encoding Encoding) (io.ReadCloser, error) {
if request.Body == nil {
return nil, nil
}

switch encoding {
case "gzip":
case Gzip:
// Create a new gzip.Reader to decompress the request body
return gzip.NewReader(request.Body)
default:
case Identity:
// If the content is not gzip-compressed, return the original request body
return request.Body, nil
default:
return nil, fmt.Errorf("unsupported encoding: %v", encoding)
}
}

func EncodeData(data []byte, encoding string) ([]byte, error) {
func EncodeData(data []byte, encoding Encoding) ([]byte, error) {
switch encoding {
case "gzip":
case Gzip:
var buf bytes.Buffer
gz := gzip.NewWriter(&buf)

Expand All @@ -65,15 +84,16 @@ func EncodeData(data []byte, encoding string) ([]byte, error) {

compressedData := buf.Bytes()
return compressedData, nil
default:
// identity encoding
case Identity:
return data, nil
default:
return nil, fmt.Errorf("unsupported encoding: %v", encoding)
}
}

func DecodeData(data []byte, encoding string) ([]byte, error) {
func DecodeData(data []byte, encoding Encoding) ([]byte, error) {
switch encoding {
case "gzip":
case Gzip:
reader, err := gzip.NewReader(bytes.NewReader(data))
if err != nil {
return nil, err
Expand All @@ -85,8 +105,9 @@ func DecodeData(data []byte, encoding string) ([]byte, error) {
}

return decodedData, nil
default:
// identity encoding
case Identity:
return data, nil
default:
return nil, fmt.Errorf("unsupported encoding: %v", encoding)
}
}
58 changes: 31 additions & 27 deletions relay/traffic/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (handler *Handler) ServeHTTP(response http.ResponseWriter, request *http.Re

encoding, err := GetContentEncoding(request)
if err != nil {
logger.Printf("URL %v error getting request content encoding: %v", request.URL, err)
logger.Printf("URL %v error in request content encoding: %v", request.URL, err)
request.Body = http.NoBody
return
}
Expand Down Expand Up @@ -95,7 +95,7 @@ func (handler *Handler) ServeHTTP(response http.ResponseWriter, request *http.Re
}

// prepareRequestBody wraps the request Body with a reader that will decode the content if necessary.
func (handler *Handler) prepareRequestBody(clientRequest *http.Request, encoding string) error {
func (handler *Handler) prepareRequestBody(clientRequest *http.Request, encoding Encoding) error {
if reader, err := WrapReader(clientRequest, encoding); err != nil {
return err
} else if reader != nil && reader != http.NoBody {
Expand All @@ -104,7 +104,7 @@ func (handler *Handler) prepareRequestBody(clientRequest *http.Request, encoding
return nil
}

func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, clientRequest *http.Request, serviced bool, encoding string) bool {
func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, clientRequest *http.Request, serviced bool, encoding Encoding) bool {
if serviced {
return false
}
Expand All @@ -124,35 +124,39 @@ func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, client
}
}

func (handler *Handler) ensureBodyContentEncoding(clientRequest *http.Request, encoding string) {
if encoding == "" || encoding == "identity" {
func (handler *Handler) ensureBodyContentEncoding(clientRequest *http.Request, encoding Encoding) {
switch encoding {
case Unsupported:
logger.Println("Error unsupported content-encoding")
return
}

servicedBody, err := io.ReadAll(clientRequest.Body)
if err != nil {
logger.Printf("Error reading request body: %s", err)
clientRequest.Body = http.NoBody
case Identity:
return
}
case Gzip:
servicedBody, err := io.ReadAll(clientRequest.Body)
if err != nil {
logger.Printf("Error reading request body: %s", err)
clientRequest.Body = http.NoBody
return
}

if encodedData, err := EncodeData(servicedBody, encoding); err != nil {
logger.Printf("Error encoding request body: %s", err)
clientRequest.Body = http.NoBody
return
} else {
servicedBody = encodedData
}
if encodedData, err := EncodeData(servicedBody, encoding); err != nil {
logger.Printf("Error encoding request body: %s", err)
clientRequest.Body = http.NoBody
return
} else {
servicedBody = encodedData
}

// If the length of the body has changed, we should update the
// Content-Length header too.
contentLength := int64(len(servicedBody))
if contentLength != clientRequest.ContentLength {
clientRequest.ContentLength = contentLength
clientRequest.Header.Set("Content-Length", strconv.FormatInt(contentLength, 10))
}
// If the length of the body has changed, we should update the
// Content-Length header too.
contentLength := int64(len(servicedBody))
if contentLength != clientRequest.ContentLength {
clientRequest.ContentLength = contentLength
clientRequest.Header.Set("Content-Length", strconv.FormatInt(contentLength, 10))
}

clientRequest.Body = io.NopCloser(bytes.NewBuffer(servicedBody))
clientRequest.Body = io.NopCloser(bytes.NewBuffer(servicedBody))
}

}

Expand Down
27 changes: 10 additions & 17 deletions relay/traffic/traffic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,33 +151,26 @@ func TestMaxBodySize(t *testing.T) {
})
}

type Encoding int

const (
Identity Encoding = iota
Gzip
)

func TestRelaySupportsContentEncoding(t *testing.T) {
testCases := map[string]struct {
encoding Encoding
encoding traffic.Encoding
bodyContentStr string
headers map[string]string
customUrl func(relayServiceURL string) string
}{
"identity": {
encoding: Identity,
encoding: traffic.Identity,
bodyContentStr: "Hello, world!",
},
"gzip - with header": {
encoding: Gzip,
encoding: traffic.Gzip,
bodyContentStr: "Hello, world!",
headers: map[string]string{
"Content-Encoding": "gzip",
},
},
"gzip - with query param": {
encoding: Gzip,
encoding: traffic.Gzip,
bodyContentStr: "Hello, world!",
customUrl: func(relayServiceURL string) string {
return fmt.Sprintf("%v?ContentEncoding=gzip", relayServiceURL)
Expand All @@ -190,14 +183,14 @@ func TestRelaySupportsContentEncoding(t *testing.T) {
// convert the body content to a reader with the proper content encoding applied
var body io.Reader
switch testCase.encoding {
case Gzip:
b, err := traffic.EncodeData([]byte(testCase.bodyContentStr), "gzip")
case traffic.Gzip:
b, err := traffic.EncodeData([]byte(testCase.bodyContentStr), traffic.Gzip)
if err != nil {
t.Errorf("Test %s - Error encoding data: %v", desc, err)
return
}
body = bytes.NewReader(b)
case Identity:
case traffic.Identity:
body = strings.NewReader(testCase.bodyContentStr)
}

Expand Down Expand Up @@ -235,16 +228,16 @@ func TestRelaySupportsContentEncoding(t *testing.T) {
}

switch testCase.encoding {
case Gzip:
decodedData, err := traffic.DecodeData(lastRequest, "gzip")
case traffic.Gzip:
decodedData, err := traffic.DecodeData(lastRequest, traffic.Gzip)
if err != nil {
t.Errorf("Test %s - Error decoding data: %v", desc, err)
return
}
if string(decodedData) != testCase.bodyContentStr {
t.Errorf("Test %s - Expected body '%v' but got: %v", desc, testCase.bodyContentStr, string(decodedData))
}
case Identity:
case traffic.Identity:
if string(lastRequest) != testCase.bodyContentStr {
t.Errorf("Test %s - Expected body '%v' but got: %v", desc, testCase.bodyContentStr, string(lastRequest))
}
Expand Down

0 comments on commit 4f9a8c3

Please sign in to comment.