Skip to content

Commit

Permalink
Use http.NewRequestWithContext for outgoing HTTP requests
Browse files Browse the repository at this point in the history
This simple change mainly affects the distribution client. By respecting
the context the caller passes in, timeouts and cancellations will work
as expected. Also, transports which rely on the context (such as tracing
transports that retrieve a span from the context) will work properly.

Signed-off-by: Aaron Lehmann <alehmann@netflix.com>
  • Loading branch information
aaronlehmann committed Aug 10, 2022
1 parent 26163d8 commit fbdfd1a
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 20 deletions.
27 changes: 16 additions & 11 deletions registry/client/auth/session.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package auth

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -258,7 +259,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st
}.String())
}

token, err := th.getToken(params, additionalScopes...)
token, err := th.getToken(req.Context(), params, additionalScopes...)
if err != nil {
return err
}
Expand All @@ -268,7 +269,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st
return nil
}

func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...string) (string, error) {
func (th *tokenHandler) getToken(ctx context.Context, params map[string]string, additionalScopes ...string) (string, error) {
th.tokenLock.Lock()
defer th.tokenLock.Unlock()
scopes := make([]string, 0, len(th.scopes)+len(additionalScopes))
Expand All @@ -286,7 +287,7 @@ func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...s

now := th.clock.Now()
if now.After(th.tokenExpiration) || addedScopes {
token, expiration, err := th.fetchToken(params, scopes)
token, expiration, err := th.fetchToken(ctx, params, scopes)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -320,7 +321,7 @@ type postTokenResponse struct {
Scope string `json:"scope"`
}

func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) {
func (th *tokenHandler) fetchTokenWithOAuth(ctx context.Context, realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) {
form := url.Values{}
form.Set("scope", strings.Join(scopes, " "))
form.Set("service", service)
Expand Down Expand Up @@ -348,7 +349,12 @@ func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, servic
return "", time.Time{}, fmt.Errorf("no supported grant type")
}

resp, err := th.client().PostForm(realm.String(), form)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, realm.String(), strings.NewReader(form.Encode()))
if err != nil {
return "", time.Time{}, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := th.client().Do(req)
if err != nil {
return "", time.Time{}, err
}
Expand Down Expand Up @@ -396,9 +402,8 @@ type getTokenResponse struct {
RefreshToken string `json:"refresh_token"`
}

func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) {

req, err := http.NewRequest("GET", realm.String(), nil)
func (th *tokenHandler) fetchTokenWithBasicAuth(ctx context.Context, realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, realm.String(), nil)
if err != nil {
return "", time.Time{}, err
}
Expand Down Expand Up @@ -479,7 +484,7 @@ func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string,
return tr.Token, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil
}

func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (token string, expiration time.Time, err error) {
func (th *tokenHandler) fetchToken(ctx context.Context, params map[string]string, scopes []string) (token string, expiration time.Time, err error) {
realm, ok := params["realm"]
if !ok {
return "", time.Time{}, errors.New("no realm specified for token auth challenge")
Expand All @@ -500,10 +505,10 @@ func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (t
}

if refreshToken != "" || th.forceOAuth {
return th.fetchTokenWithOAuth(realmURL, refreshToken, service, scopes)
return th.fetchTokenWithOAuth(ctx, realmURL, refreshToken, service, scopes)
}

return th.fetchTokenWithBasicAuth(realmURL, service, scopes)
return th.fetchTokenWithBasicAuth(ctx, realmURL, service, scopes)
}

type basicHandler struct {
Expand Down
10 changes: 6 additions & 4 deletions registry/client/blob_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
)

type httpBlobUpload struct {
ctx context.Context

statter distribution.BlobStatter
client *http.Client

Expand All @@ -36,7 +38,7 @@ func (hbu *httpBlobUpload) handleErrorResponse(resp *http.Response) error {
}

func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) {
req, err := http.NewRequest("PATCH", hbu.location, ioutil.NopCloser(r))
req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, ioutil.NopCloser(r))
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -69,7 +71,7 @@ func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) {
}

func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) {
req, err := http.NewRequest("PATCH", hbu.location, bytes.NewReader(p))
req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, bytes.NewReader(p))
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -117,7 +119,7 @@ func (hbu *httpBlobUpload) StartedAt() time.Time {

func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) {
// TODO(dmcgowan): Check if already finished, if so just fetch
req, err := http.NewRequest("PUT", hbu.location, nil)
req, err := http.NewRequestWithContext(hbu.ctx, "PUT", hbu.location, nil)
if err != nil {
return distribution.Descriptor{}, err
}
Expand All @@ -140,7 +142,7 @@ func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descrip
}

func (hbu *httpBlobUpload) Cancel(ctx context.Context) error {
req, err := http.NewRequest("DELETE", hbu.location, nil)
req, err := http.NewRequestWithContext(hbu.ctx, "DELETE", hbu.location, nil)
if err != nil {
return err
}
Expand Down
5 changes: 5 additions & 0 deletions registry/client/blob_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package client

import (
"bytes"
"context"
"fmt"
"net/http"
"testing"
Expand Down Expand Up @@ -126,6 +127,7 @@ func TestUploadReadFrom(t *testing.T) {
defer c()

blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
}

Expand Down Expand Up @@ -265,6 +267,7 @@ func TestUploadSize(t *testing.T) {

// Writing with ReadFrom
blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
location: e + readFromLocationPath,
}
Expand All @@ -284,6 +287,7 @@ func TestUploadSize(t *testing.T) {

// Writing with Write
blobUpload = &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
location: e + writeLocationPath,
}
Expand Down Expand Up @@ -409,6 +413,7 @@ func TestUploadWrite(t *testing.T) {
defer c()

blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
}

Expand Down
10 changes: 5 additions & 5 deletions registry/client/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,7 @@ func (r *registry) Repositories(ctx context.Context, entries []string, last stri
return 0, err
}

for cnt := range ctlg.Repositories {
entries[cnt] = ctlg.Repositories[cnt]
}
copy(entries, ctlg.Repositories)
numFilled = len(ctlg.Repositories)

link := resp.Header.Get("Link")
Expand Down Expand Up @@ -373,7 +371,7 @@ func (t *tags) Untag(ctx context.Context, tag string) error {
return err
}

req, err := http.NewRequest("DELETE", u, nil)
req, err := http.NewRequestWithContext(ctx, "DELETE", u, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -792,7 +790,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO
return nil, err
}

req, err := http.NewRequest("POST", u, nil)
req, err := http.NewRequestWithContext(ctx, "POST", u, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -827,6 +825,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO
}

return &httpBlobUpload{
ctx: ctx,
statter: bs.statter,
client: bs.client,
uuid: uuid,
Expand All @@ -845,6 +844,7 @@ func (bs *blobs) Resume(ctx context.Context, id string) (distribution.BlobWriter
}

return &httpBlobUpload{
ctx: ctx,
statter: bs.statter,
client: bs.client,
uuid: id,
Expand Down

0 comments on commit fbdfd1a

Please sign in to comment.