Skip to content

Commit

Permalink
Merge branch '8-checksum-failure-when-downloading-using-edl-token'
Browse files Browse the repository at this point in the history
  • Loading branch information
bmflynn committed Jan 2, 2025
2 parents 98fe38b + eddd434 commit 03dc6ee
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 36 deletions.
47 changes: 17 additions & 30 deletions internal/fetch_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import (
)

var defaultNetrcFinder = findNetrc
var edlToken = ""

type FailedDownload struct {
RequestID string
Expand Down Expand Up @@ -64,14 +63,6 @@ func ResolveEDLToken(token string) string {
return resolvedToken
}

// Sets token auth header
func newRedirectWithToken(bearer string) (func(*http.Request, []*http.Request) error, error) {
return func(req *http.Request, via []*http.Request) error {
req.Header.Add("Authorization", "Bearer "+bearer)
return nil
}, nil
}

// Sets basic auth on redirect if the host is in the netrc file.
func newRedirectWithNetrcCredentials() (func(*http.Request, []*http.Request) error, error) {
fpath, err := defaultNetrcFinder()
Expand Down Expand Up @@ -101,6 +92,8 @@ func newRedirectWithNetrcCredentials() (func(*http.Request, []*http.Request) err
type HTTPFetcher struct {
client *http.Client
readSize int64
// If provided an authorization header is added to every request
bearerToken string
}

func NewHTTPFetcher(netrc bool, edlToken string) (*HTTPFetcher, error) {
Expand All @@ -118,17 +111,28 @@ func NewHTTPFetcher(netrc bool, edlToken string) (*HTTPFetcher, error) {
client.Jar = jar
client.CheckRedirect, err = newRedirectWithNetrcCredentials()
if err != nil {
return nil, err
return nil, fmt.Errorf("configuring netrc token redirect: %w", err)
}
}
return &HTTPFetcher{
client: client,
readSize: 2 << 19,
client: client,
readSize: 2 << 19,
bearerToken: edlToken,
}, nil
}

func (f *HTTPFetcher) newRequest(ctx context.Context, url string) (*http.Request, error) {
return http.NewRequestWithContext(ctx, "GET", url, nil)
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, err
}
if f.bearerToken != "" {
if req.URL.Scheme != "https" {
return nil, fmt.Errorf("refusing to add bearer token to non-https url %s", req.URL)
}
req.Header.Add("Authorization", "Bearer "+f.bearerToken)
}
return req, nil
}

// Fetch url to destdir using url's basename as the filename and update hash with the file
Expand All @@ -139,10 +143,6 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string, w io.Writer) (int64
return 0, err
}

if edlToken != "" {
req.Header.Add("Authorization", "Bearer "+edlToken)
}

resp, err := f.client.Do(req)
if err != nil {
return 0, err
Expand All @@ -154,10 +154,6 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string, w io.Writer) (int64
}
defer resp.Body.Close()

// if err := validateResponse(req, resp); err != nil {
// return 0, err
// }

var size int64
buf := make([]byte, f.readSize)
r := bufio.NewReader(resp.Body)
Expand All @@ -179,12 +175,3 @@ func (f *HTTPFetcher) Fetch(ctx context.Context, url string, w io.Writer) (int64
}
return size, nil
}

func validateResponse(req *http.Request, resp *http.Response) error {
wanted := req.URL.Hostname()
found := resp.Request.URL.Hostname()
if wanted != found {
return fmt.Errorf("probable auth redirect error; expected host %s, found %s", wanted, found)
}
return nil
}
22 changes: 16 additions & 6 deletions internal/fetch_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,21 @@ func TestHTTPFetcher(t *testing.T) {
})
}

func Test_validateResponse(t *testing.T) {
req := httptest.NewRequest("GET", "http://localhost/", nil)
resp := &http.Response{
Request: httptest.NewRequest("GET", "http://nope/", nil),
}
func TestHTTPFetcherRequest(t *testing.T) {
fetcher, err := NewHTTPFetcher(false, "XXX")
require.NoError(t, err)

t.Run("token with non-https is error", func(t *testing.T) {
req, err := fetcher.newRequest(context.TODO(), "http://server/path/file.ext")
require.Nil(t, req)
require.Error(t, err, "expected error with http scheme")
})

t.Run("token header present", func(t *testing.T) {
req, err := fetcher.newRequest(context.TODO(), "https://server/path/file.ext")
require.NotNil(t, req)
require.NoError(t, err)

require.Error(t, validateResponse(req, resp))
require.Equal(t, "Bearer XXX", req.Header.Get("Authorization"))
})
}

0 comments on commit 03dc6ee

Please sign in to comment.