Skip to content

Commit

Permalink
Merge pull request distribution#3711 from aaronlehmann/request-with-c…
Browse files Browse the repository at this point in the history
…ontext

Use http.NewRequestWithContext for outgoing HTTP requests
  • Loading branch information
milosgajdos authored Aug 16, 2022
2 parents 1db54ec + fbdfd1a commit 6c23795
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 20 deletions.
27 changes: 16 additions & 11 deletions registry/client/auth/session.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package auth

import (
"context"
"encoding/json"
"errors"
"fmt"
Expand Down Expand Up @@ -258,7 +259,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st
}.String())
}

token, err := th.getToken(params, additionalScopes...)
token, err := th.getToken(req.Context(), params, additionalScopes...)
if err != nil {
return err
}
Expand All @@ -268,7 +269,7 @@ func (th *tokenHandler) AuthorizeRequest(req *http.Request, params map[string]st
return nil
}

func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...string) (string, error) {
func (th *tokenHandler) getToken(ctx context.Context, params map[string]string, additionalScopes ...string) (string, error) {
th.tokenLock.Lock()
defer th.tokenLock.Unlock()
scopes := make([]string, 0, len(th.scopes)+len(additionalScopes))
Expand All @@ -286,7 +287,7 @@ func (th *tokenHandler) getToken(params map[string]string, additionalScopes ...s

now := th.clock.Now()
if now.After(th.tokenExpiration) || addedScopes {
token, expiration, err := th.fetchToken(params, scopes)
token, expiration, err := th.fetchToken(ctx, params, scopes)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -320,7 +321,7 @@ type postTokenResponse struct {
Scope string `json:"scope"`
}

func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) {
func (th *tokenHandler) fetchTokenWithOAuth(ctx context.Context, realm *url.URL, refreshToken, service string, scopes []string) (token string, expiration time.Time, err error) {
form := url.Values{}
form.Set("scope", strings.Join(scopes, " "))
form.Set("service", service)
Expand Down Expand Up @@ -348,7 +349,12 @@ func (th *tokenHandler) fetchTokenWithOAuth(realm *url.URL, refreshToken, servic
return "", time.Time{}, fmt.Errorf("no supported grant type")
}

resp, err := th.client().PostForm(realm.String(), form)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, realm.String(), strings.NewReader(form.Encode()))
if err != nil {
return "", time.Time{}, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
resp, err := th.client().Do(req)
if err != nil {
return "", time.Time{}, err
}
Expand Down Expand Up @@ -396,9 +402,8 @@ type getTokenResponse struct {
RefreshToken string `json:"refresh_token"`
}

func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) {

req, err := http.NewRequest("GET", realm.String(), nil)
func (th *tokenHandler) fetchTokenWithBasicAuth(ctx context.Context, realm *url.URL, service string, scopes []string) (token string, expiration time.Time, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, realm.String(), nil)
if err != nil {
return "", time.Time{}, err
}
Expand Down Expand Up @@ -479,7 +484,7 @@ func (th *tokenHandler) fetchTokenWithBasicAuth(realm *url.URL, service string,
return tr.Token, tr.IssuedAt.Add(time.Duration(tr.ExpiresIn) * time.Second), nil
}

func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (token string, expiration time.Time, err error) {
func (th *tokenHandler) fetchToken(ctx context.Context, params map[string]string, scopes []string) (token string, expiration time.Time, err error) {
realm, ok := params["realm"]
if !ok {
return "", time.Time{}, errors.New("no realm specified for token auth challenge")
Expand All @@ -500,10 +505,10 @@ func (th *tokenHandler) fetchToken(params map[string]string, scopes []string) (t
}

if refreshToken != "" || th.forceOAuth {
return th.fetchTokenWithOAuth(realmURL, refreshToken, service, scopes)
return th.fetchTokenWithOAuth(ctx, realmURL, refreshToken, service, scopes)
}

return th.fetchTokenWithBasicAuth(realmURL, service, scopes)
return th.fetchTokenWithBasicAuth(ctx, realmURL, service, scopes)
}

type basicHandler struct {
Expand Down
10 changes: 6 additions & 4 deletions registry/client/blob_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
)

type httpBlobUpload struct {
ctx context.Context

statter distribution.BlobStatter
client *http.Client

Expand All @@ -36,7 +38,7 @@ func (hbu *httpBlobUpload) handleErrorResponse(resp *http.Response) error {
}

func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) {
req, err := http.NewRequest("PATCH", hbu.location, ioutil.NopCloser(r))
req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, ioutil.NopCloser(r))
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -69,7 +71,7 @@ func (hbu *httpBlobUpload) ReadFrom(r io.Reader) (n int64, err error) {
}

func (hbu *httpBlobUpload) Write(p []byte) (n int, err error) {
req, err := http.NewRequest("PATCH", hbu.location, bytes.NewReader(p))
req, err := http.NewRequestWithContext(hbu.ctx, "PATCH", hbu.location, bytes.NewReader(p))
if err != nil {
return 0, err
}
Expand Down Expand Up @@ -117,7 +119,7 @@ func (hbu *httpBlobUpload) StartedAt() time.Time {

func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descriptor) (distribution.Descriptor, error) {
// TODO(dmcgowan): Check if already finished, if so just fetch
req, err := http.NewRequest("PUT", hbu.location, nil)
req, err := http.NewRequestWithContext(hbu.ctx, "PUT", hbu.location, nil)
if err != nil {
return distribution.Descriptor{}, err
}
Expand All @@ -140,7 +142,7 @@ func (hbu *httpBlobUpload) Commit(ctx context.Context, desc distribution.Descrip
}

func (hbu *httpBlobUpload) Cancel(ctx context.Context) error {
req, err := http.NewRequest("DELETE", hbu.location, nil)
req, err := http.NewRequestWithContext(hbu.ctx, "DELETE", hbu.location, nil)
if err != nil {
return err
}
Expand Down
5 changes: 5 additions & 0 deletions registry/client/blob_writer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package client

import (
"bytes"
"context"
"fmt"
"net/http"
"testing"
Expand Down Expand Up @@ -126,6 +127,7 @@ func TestUploadReadFrom(t *testing.T) {
defer c()

blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
}

Expand Down Expand Up @@ -265,6 +267,7 @@ func TestUploadSize(t *testing.T) {

// Writing with ReadFrom
blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
location: e + readFromLocationPath,
}
Expand All @@ -284,6 +287,7 @@ func TestUploadSize(t *testing.T) {

// Writing with Write
blobUpload = &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
location: e + writeLocationPath,
}
Expand Down Expand Up @@ -409,6 +413,7 @@ func TestUploadWrite(t *testing.T) {
defer c()

blobUpload := &httpBlobUpload{
ctx: context.Background(),
client: &http.Client{},
}

Expand Down
10 changes: 5 additions & 5 deletions registry/client/repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,7 @@ func (r *registry) Repositories(ctx context.Context, entries []string, last stri
return 0, err
}

for cnt := range ctlg.Repositories {
entries[cnt] = ctlg.Repositories[cnt]
}
copy(entries, ctlg.Repositories)
numFilled = len(ctlg.Repositories)

link := resp.Header.Get("Link")
Expand Down Expand Up @@ -373,7 +371,7 @@ func (t *tags) Untag(ctx context.Context, tag string) error {
return err
}

req, err := http.NewRequest("DELETE", u, nil)
req, err := http.NewRequestWithContext(ctx, "DELETE", u, nil)
if err != nil {
return err
}
Expand Down Expand Up @@ -792,7 +790,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO
return nil, err
}

req, err := http.NewRequest("POST", u, nil)
req, err := http.NewRequestWithContext(ctx, "POST", u, nil)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -827,6 +825,7 @@ func (bs *blobs) Create(ctx context.Context, options ...distribution.BlobCreateO
}

return &httpBlobUpload{
ctx: ctx,
statter: bs.statter,
client: bs.client,
uuid: uuid,
Expand All @@ -845,6 +844,7 @@ func (bs *blobs) Resume(ctx context.Context, id string) (distribution.BlobWriter
}

return &httpBlobUpload{
ctx: ctx,
statter: bs.statter,
client: bs.client,
uuid: id,
Expand Down

0 comments on commit 6c23795

Please sign in to comment.