diff --git a/auth.go b/auth.go index 7579b4a..9c473a1 100644 --- a/auth.go +++ b/auth.go @@ -3,6 +3,7 @@ package reggie import ( "encoding/json" "errors" + "fmt" "regexp" "strings" @@ -36,6 +37,13 @@ func (client *Client) retryRequestWithAuth(originalRequest *Request, originalRes originalRequest.QueryParam.Del(k) } + if originalRequest.retryCallback != nil { + err := originalRequest.retryCallback(originalRequest) + if err != nil { + return nil, fmt.Errorf("retry callback returned error: %s", err) + } + } + authenticationType := authHeaderMatcher.ReplaceAllString(authHeaderRaw, "$1") if strings.EqualFold(authenticationType, "bearer") { h := parseAuthHeader(authHeaderRaw) diff --git a/client.go b/client.go index 85bf4de..4514bcd 100644 --- a/client.go +++ b/client.go @@ -133,7 +133,10 @@ func (client *Client) NewRequest(method string, path string, opts ...requestOpti restyRequest.URL = url restyRequest.SetHeader("User-Agent", client.Config.UserAgent) - return &Request{restyRequest} + return &Request{ + Request: restyRequest, + retryCallback: r.RetryCallback, + } } // Do executes a Request and returns a Response. diff --git a/client_test.go b/client_test.go index e51d9e2..4ac26df 100644 --- a/client_test.go +++ b/client_test.go @@ -3,6 +3,7 @@ package reggie import ( "bytes" "encoding/base64" + "errors" "fmt" "net/http" "net/http/httptest" @@ -260,6 +261,34 @@ func TestClient(t *testing.T) { t.Fatalf("Expected body to be \"abc\" but instead got %s", lastCapturedRequestBodyStr) } + // Test that the retry callback is invoked, if configured. + newBody := "not the original body" + req = client.NewRequest(PUT, "/a/b/c", WithRetryCallback(func(r *Request) error { + r.SetBody(newBody) + return nil + })).SetBody([]byte("original body")) + _, err = client.Do(req) + if err != nil { + t.Fatalf("Errors executing request: %s", err) + } + + // Ensure the request ended up with the new body. + if lastCapturedRequestBodyStr != newBody { + t.Fatalf("Expected body to be %q, but instead got %q", newBody, lastCapturedRequestBodyStr) + } + + // Test the case where the retry callback returns an error. + req = client.NewRequest(PUT, "/a/b/c", WithRetryCallback(func(r *Request) error { + return errors.New("uh oh") + })).SetBody([]byte("original body")) + _, err = client.Do(req) + if err == nil { + t.Fatalf("Expected error from callback function, but request returned no error") + } + if !strings.Contains(err.Error(), "uh oh") { + t.Fatalf("Expected error to contain callback error \"uh oh\", but instead got %q", err) + } + // test for access_token vs. token authUseAccessToken = true req = client.NewRequest(GET, "/v2//tags/list") diff --git a/request.go b/request.go index 45425da..6be5ae0 100644 --- a/request.go +++ b/request.go @@ -8,16 +8,22 @@ import ( ) type ( + // RetryCallbackFunc is a function that can mutate a request prior to it + // being retried. + RetryCallbackFunc func(*Request) error + // Request is an HTTP request to be sent to an OCI registry. Request struct { *resty.Request + retryCallback RetryCallbackFunc } requestConfig struct { - Name string - Reference string - Digest string - SessionID string + Name string + Reference string + Digest string + SessionID string + RetryCallback RetryCallbackFunc } requestOption func(c *requestConfig) @@ -51,6 +57,15 @@ func WithSessionID(id string) requestOption { } } +// WithRetryCallback specifies a callback that will be invoked before a request +// is retried. This is useful for, e.g., ensuring an io.Reader used for the body +// will produce the right content on retry. +func WithRetryCallback(cb RetryCallbackFunc) requestOption { + return func(c *requestConfig) { + c.RetryCallback = cb + } +} + // SetBody wraps the resty SetBody and returns the request, allowing method chaining func (req *Request) SetBody(body interface{}) *Request { req.Request.SetBody(body)