Skip to content

Commit

Permalink
Limit duration format according to spec (#20)
Browse files Browse the repository at this point in the history
We used to accept anything `time.ParseDuration` would parse which is too
broad and violates the spec.
Also ensured that we use a custom `formatDuration` implementation to
ensure compatibility.
  • Loading branch information
bergundy authored Sep 13, 2024
1 parent 0fbe710 commit 71165c7
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 11 deletions.
32 changes: 31 additions & 1 deletion nexus/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"mime"
"net/http"
"net/url"
"regexp"
"strconv"
"strings"
"time"
)
Expand Down Expand Up @@ -202,7 +204,7 @@ func addContextTimeoutToHTTPHeader(ctx context.Context, httpHeader http.Header)
if !ok {
return httpHeader
}
httpHeader.Set(HeaderRequestTimeout, time.Until(deadline).String())
httpHeader.Set(HeaderRequestTimeout, formatDuration(time.Until(deadline)))
return httpHeader
}

Expand Down Expand Up @@ -333,3 +335,31 @@ func validateLinkType(value string) error {
}
return nil
}

var durationRegexp = regexp.MustCompile(`^(\d+(?:\.\d+)?)(ms|s|m)$`)

func parseDuration(value string) (time.Duration, error) {
m := durationRegexp.FindStringSubmatch(value)
if len(m) == 0 {
return 0, fmt.Errorf("invalid duration: %q", value)
}
v, err := strconv.ParseFloat(m[1], 64)
if err != nil {
return 0, err
}

switch m[2] {
case "ms":
return time.Millisecond * time.Duration(v), nil
case "s":
return time.Millisecond * time.Duration(v*1e3), nil
case "m":
return time.Millisecond * time.Duration(v*1e3*60), nil
}
panic("unreachable")
}

// formatDuration converts a duration into a string representation in millisecond resolution.
func formatDuration(d time.Duration) string {
return strconv.FormatInt(d.Milliseconds(), 10) + "ms"
}
21 changes: 21 additions & 0 deletions nexus/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"net/url"
"reflect"
"testing"
"time"

"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -508,3 +509,23 @@ func TestDecodeLink(t *testing.T) {
})
}
}

func TestParseDuration(t *testing.T) {
_, err := parseDuration("invalid")
require.ErrorContains(t, err, "invalid duration:")
d, err := parseDuration("10ms")
require.NoError(t, err)
require.Equal(t, 10*time.Millisecond, d)
d, err = parseDuration("10.1ms")
require.NoError(t, err)
require.Equal(t, 10*time.Millisecond, d)
d, err = parseDuration("1s")
require.NoError(t, err)
require.Equal(t, 1*time.Second, d)
d, err = parseDuration("999m")
require.NoError(t, err)
require.Equal(t, 999*time.Minute, d)
d, err = parseDuration("1.3s")
require.NoError(t, err)
require.Equal(t, 1300*time.Millisecond, d)
}
2 changes: 1 addition & 1 deletion nexus/cancel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func TestCancel_RequestTimeoutHeaderOverridesContextDeadline(t *testing.T) {

handle, err := client.NewHandle("foo", "timeout")
require.NoError(t, err)
err = handle.Cancel(ctx, CancelOperationOptions{Header: Header{HeaderRequestTimeout: timeout.String()}})
err = handle.Cancel(ctx, CancelOperationOptions{Header: Header{HeaderRequestTimeout: formatDuration(timeout)}})
require.NoError(t, err)
}

Expand Down
2 changes: 1 addition & 1 deletion nexus/get_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func TestGetInfo_RequestTimeoutHeaderOverridesContextDeadline(t *testing.T) {

handle, err := client.NewHandle("foo", "timeout")
require.NoError(t, err)
_, err = handle.GetInfo(ctx, GetOperationInfoOptions{Header: Header{HeaderRequestTimeout: timeout.String()}})
_, err = handle.GetInfo(ctx, GetOperationInfoOptions{Header: Header{HeaderRequestTimeout: formatDuration(timeout)}})
require.NoError(t, err)
}

Expand Down
2 changes: 1 addition & 1 deletion nexus/get_result_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ func TestWaitResult_RequestTimeout(t *testing.T) {

timeout := 200 * time.Millisecond
deadline := time.Now().Add(200 * time.Millisecond)
_, err = handle.GetResult(ctx, GetOperationResultOptions{Wait: time.Second, Header: Header{HeaderRequestTimeout: timeout.String()}})
_, err = handle.GetResult(ctx, GetOperationResultOptions{Wait: time.Second, Header: Header{HeaderRequestTimeout: formatDuration(timeout)}})
require.ErrorIs(t, err, ErrOperationStillRunning)
require.WithinDuration(t, deadline, handler.requests[0].deadline, 1*time.Millisecond)
}
Expand Down
3 changes: 1 addition & 2 deletions nexus/handle.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package nexus
import (
"context"
"errors"
"fmt"
"net/http"
"net/url"
"time"
Expand Down Expand Up @@ -87,7 +86,7 @@ func (h *OperationHandle[T]) GetResult(ctx context.Context, options GetOperation
}

q := request.URL.Query()
q.Set(queryWait, fmt.Sprintf("%dms", wait.Milliseconds()))
q.Set(queryWait, formatDuration(wait))
request.URL.RawQuery = q.Encode()
} else {
// We may reuse the request object multiple times and will need to reset the query when wait becomes 0 or
Expand Down
4 changes: 2 additions & 2 deletions nexus/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func (h *httpHandler) getOperationResult(service, operation, operationID string,
}
waitStr := request.URL.Query().Get(queryWait)
if waitStr != "" {
waitDuration, err := time.ParseDuration(waitStr)
waitDuration, err := parseDuration(waitStr)
if err != nil {
h.logger.Warn("invalid wait duration query parameter", "wait", waitStr)
h.writeFailure(writer, HandlerErrorf(HandlerErrorTypeBadRequest, "invalid wait query parameter"))
Expand Down Expand Up @@ -401,7 +401,7 @@ func (h *httpHandler) cancelOperation(service, operation, operationID string, wr
func (h *httpHandler) parseRequestTimeoutHeader(writer http.ResponseWriter, request *http.Request) (time.Duration, bool) {
timeoutStr := request.Header.Get(HeaderRequestTimeout)
if timeoutStr != "" {
timeoutDuration, err := time.ParseDuration(timeoutStr)
timeoutDuration, err := parseDuration(timeoutStr)
if err != nil {
h.logger.Warn("invalid request timeout header", "timeout", timeoutStr)
h.writeFailure(writer, HandlerErrorf(HandlerErrorTypeBadRequest, "invalid request timeout header"))
Expand Down
6 changes: 3 additions & 3 deletions nexus/start_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ func (h *timeoutEchoHandler) StartOperation(ctx context.Context, service, operat
}, nil
}
return &HandlerStartOperationResultSync[any]{
Value: []byte(time.Until(deadline).String()),
Value: []byte(formatDuration(time.Until(deadline))),
}, nil
}

Expand All @@ -289,7 +289,7 @@ func TestStart_RequestTimeoutHeaderOverridesContextDeadline(t *testing.T) {
defer teardown()

timeout := 100 * time.Millisecond
result, err := client.StartOperation(ctx, "foo", nil, StartOperationOptions{Header: Header{HeaderRequestTimeout: timeout.String()}})
result, err := client.StartOperation(ctx, "foo", nil, StartOperationOptions{Header: Header{HeaderRequestTimeout: formatDuration(timeout)}})

require.NoError(t, err)
requireTimeoutPropagated(t, result, timeout)
Expand All @@ -301,7 +301,7 @@ func requireTimeoutPropagated(t *testing.T, result *ClientStartOperationResult[*
var responseBody []byte
err := response.Consume(&responseBody)
require.NoError(t, err)
parsedTimeout, err := time.ParseDuration(string(responseBody))
parsedTimeout, err := parseDuration(string(responseBody))
require.NoError(t, err)
require.NotZero(t, parsedTimeout)
require.LessOrEqual(t, parsedTimeout, expected)
Expand Down

0 comments on commit 71165c7

Please sign in to comment.