Skip to content

Commit

Permalink
Revert "Update configurePackage to use fixed download method"
Browse files Browse the repository at this point in the history
This reverts commit 3b05aa40e178ddefcfe4479d7a578a2564ff52c6.

cr: https://code.amazon.com/reviews/CR-160873742
  • Loading branch information
Yagnesh-Suribhatla authored and Chnwanze committed Nov 19, 2024
1 parent 6ae9397 commit c78facf
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 229 deletions.
104 changes: 3 additions & 101 deletions agent/fileutil/artifact/artifact.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ func s3Download(context context.T, amazonS3URL s3util.AmazonS3URL, destFile stri
params.ExpectedBucketOwner = aws.String(expectedBucketOwner)
}

if fileutil.Exists(destFile) && fileutil.Exists(eTagFile) {
if fileutil.Exists(destFile) == true && fileutil.Exists(eTagFile) == true {
var existingETag string
existingETag, err = fileutil.ReadAllText(eTagFile)
if err != nil {
Expand Down Expand Up @@ -404,7 +404,7 @@ func Download(context context.T, input DownloadInput) (output DownloadOutput, er
err = nil
}

if isLocalFile {
if isLocalFile == true {
err = fmt.Errorf("source is a local file, skipping download. %v", input.SourceURL)
output.LocalFilePath = input.SourceURL
output.IsUpdated = false
Expand Down Expand Up @@ -435,112 +435,14 @@ func Download(context context.T, input DownloadInput) (output DownloadOutput, er
}

isLocalFile, err = fileutil.LocalFileExist(output.LocalFilePath)
if isLocalFile {
if isLocalFile == true {
output.IsHashMatched, err = VerifyHash(log, input, output)
}
}

return
}

func setupDestinationDirectory(context context.T, input DownloadInput) (localFilePath string, err error) {
log := context.Log()

fileURL, err := url.Parse(input.SourceURL)
if err != nil {
log.Errorf("url parsing failed. %v", err)
return
}

// default destination directory is app config download root
destinationDir := input.DestinationDirectory
if destinationDir == "" {
destinationDir = appconfig.DownloadRoot
}

err = fileutil.MakeDirs(destinationDir)
if err != nil {
err = fmt.Errorf("failed to create directory=%v, err=%v", destinationDir, err)
}
urlHash := sha1.Sum([]byte(fileURL.String()))
localFilePath = filepath.Join(destinationDir, fmt.Sprintf("%x", urlHash))
return
}

func DownloadUsingHttp(context context.T, input DownloadInput) (*DownloadOutput, error) {
log := context.Log()
output := DownloadOutput{}
var err error

output.LocalFilePath, err = setupDestinationDirectory(context, input)
if err != nil {
return nil, err
}

output, err = httpDownload(context, input.SourceURL, output.LocalFilePath, "")
if err != nil {
err = fmt.Errorf("Download failed due to %v", err)
return nil, err
}

doesLocalFileExist, err := fileutil.LocalFileExist(output.LocalFilePath)
if err != nil {
err = fmt.Errorf("could not read output file %v", err)
return nil, err
}
if doesLocalFileExist {
output.IsHashMatched, err = VerifyHash(log, input, output)
if err != nil {
err = fmt.Errorf("could not verify hash - %v", err)
return nil, err
}
}
return &output, nil
}

func DownloadUsingS3(context context.T, input DownloadInput) (*DownloadOutput, error) {
log := context.Log()
output := DownloadOutput{}
var err error

output.LocalFilePath, err = setupDestinationDirectory(context, input)
if err != nil {
return nil, err
}

fileURL, err := url.Parse(input.SourceURL)
if err != nil {
err = fmt.Errorf("url parsing failed. %v", err)
return nil, err
}

amazonS3URL := s3util.ParseAmazonS3URL(log, fileURL)
if !amazonS3URL.IsBucketAndKeyPresent() {
err = fmt.Errorf("could not find bucket and key in the s3 url - %v", input.SourceURL)
return nil, err
}

output, err = s3Download(context, amazonS3URL, output.LocalFilePath, input.ExpectedBucketOwner)
if err != nil {
err = fmt.Errorf("an error occurred when attempting s3 download - %v", err)
return nil, err
}

doesLocalFileExist, err := fileutil.LocalFileExist(output.LocalFilePath)
if err != nil {
err = fmt.Errorf("could not read output file %v", err)
return nil, err
}
if doesLocalFileExist {
output.IsHashMatched, err = VerifyHash(log, input, output)
if err != nil {
err = fmt.Errorf("could not verify hash - %v", err)
return nil, err
}
}
return &output, nil
}

// VerifyHash verifies the hash of the url file as per specified hash algorithm type and its value
func VerifyHash(log log.T, input DownloadInput, output DownloadOutput) (bool, error) {
hasMatchingHash := false
Expand Down
103 changes: 4 additions & 99 deletions agent/fileutil/artifact/artifact_integ_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@ package artifact

import (
"fmt"
"os"
"path/filepath"
"testing"

"github.com/aws/amazon-ssm-agent/agent/mocks/context"
"github.com/aws/amazon-ssm-agent/agent/mocks/log"
"github.com/stretchr/testify/assert"
"os"
"path/filepath"
"testing"
)

type DownloadTest struct {
Expand Down Expand Up @@ -284,101 +283,6 @@ func TestHttpHttpsDownloadArtifact(t *testing.T) {
assert.Equal(t, expectedOutput, output)
}

func TestDownloadUsingHttpInvalidUrl(t *testing.T) {
downloadInput := DownloadInput{
DestinationDirectory: ".",
SourceURL: "xyz@amazon.com",
SourceChecksums: map[string]string{
"sha256": "0c0f36c238e6c4c00f39d94dc6381930df2851db0ea2e2543d931474ddce1f8f",
},
}
output, err := DownloadUsingHttp(mockContext, downloadInput)
assert.ErrorContains(t, err, "unsupported protocol scheme")
assert.Nil(t, output)
}

func TestDownloadUsingHttp(t *testing.T) {
testFilePath := "https://amazon-ssm-us-east-1.s3.amazonaws.com/3.3.40.0/VERSION"
downloadInput := DownloadInput{
DestinationDirectory: ".",
SourceURL: testFilePath,
SourceChecksums: map[string]string{
"sha256": "0c0f36c238e6c4c00f39d94dc6381930df2851db0ea2e2543d931474ddce1f8f",
},
}
var expectedLocalPath = "b9f961391ec1ae061db3afcbed5571b2463139c8"
os.Remove(expectedLocalPath)
os.Remove(expectedLocalPath + ".etag")
expectedOutput := DownloadOutput{
expectedLocalPath,
true,
true}

output, err := DownloadUsingHttp(mockContext, downloadInput)
assert.NoError(t, err, "Failed to download %v", downloadInput)
mockLog.Infof("Download Result is %v and err:%v", output, err)

defer func() {
os.Remove(expectedLocalPath)
os.Remove(expectedLocalPath + ".etag")
}()
assert.Equal(t, expectedOutput, *output)

// now since we have downloaded the file, try to download again should result in cache hit!
expectedOutput = DownloadOutput{
expectedLocalPath,
false,
true}
output, err = DownloadUsingHttp(mockContext, downloadInput)
assert.NoError(t, err, "Failed to download %v", downloadInput)
mockLog.Infof("Download Result is %v and err:%v", output, err)
assert.Equal(t, expectedOutput, *output)
}

func TestDownloadUsingHttpMismatchingHash(t *testing.T) {
testFilePath := "https://amazon-ssm-us-east-1.s3.amazonaws.com/3.3.40.0/VERSION"
downloadInput := DownloadInput{
DestinationDirectory: ".",
SourceURL: testFilePath,
SourceChecksums: map[string]string{
"sha256": "invalidhash",
},
}
var expectedLocalPath = "b9f961391ec1ae061db3afcbed5571b2463139c8"
os.Remove(expectedLocalPath)
os.Remove(expectedLocalPath + ".etag")

output, err := DownloadUsingHttp(mockContext, downloadInput)
assert.ErrorContains(t, err, "failed to verify hash of downloadinput")
mockLog.Infof("Download Result is %v and err:%v", output, err)

defer func() {
os.Remove(expectedLocalPath)
os.Remove(expectedLocalPath + ".etag")
}()
assert.Nil(t, output)
}

func TestDownloadUsingS3InvalidUrl(t *testing.T) {
testFilePath := "https://not-an-s3-url/file.zip"
downloadInput := DownloadInput{
DestinationDirectory: ".",
SourceURL: testFilePath,
SourceChecksums: map[string]string{
"sha256": "0c0f36c238e6c4c00f39d94dc6381930df2851db0ea2e2543d931474ddce1f8f",
},
}
var expectedLocalPath = "b9f961391ec1ae061db3afcbed5571b2463139c8"
os.Remove(expectedLocalPath)
os.Remove(expectedLocalPath + ".etag")

// S3 download cannot be tesed with mock context credentials
output, err := DownloadUsingS3(mockContext, downloadInput)
assert.ErrorContains(t, err, "could not find bucket and key in the s3 url")
mockLog.Infof("Download Result is %v and err:%v", output, err)
assert.Nil(t, output)
}

func ExampleMd5HashValue() {
path := filepath.Join("testdata", "CheckMyHash.txt")
mockLog := log.NewMockLog()
Expand All @@ -391,4 +295,5 @@ func ExampleSha256HashValue() {
mockLog := log.NewMockLog()
content, _ := Sha256HashValue(mockLog, path)
fmt.Println(content)

}
11 changes: 3 additions & 8 deletions agent/plugins/configurepackage/birdwatcher/birdwatcher_dep.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,13 @@ import (

// dependency on S3 and downloaded artifacts
type networkDep interface {
DownloadUsingHttp(context context.T, input artifact.DownloadInput) (*artifact.DownloadOutput, error)
DownloadUsingS3(context context.T, input artifact.DownloadInput) (*artifact.DownloadOutput, error)
Download(context context.T, input artifact.DownloadInput) (artifact.DownloadOutput, error)
}

var Networkdep networkDep = &networkDepImp{}

type networkDepImp struct{}

func (networkDepImp) DownloadUsingHttp(context context.T, input artifact.DownloadInput) (*artifact.DownloadOutput, error) {
return artifact.DownloadUsingHttp(context, input)
}

func (networkDepImp) DownloadUsingS3(context context.T, input artifact.DownloadInput) (*artifact.DownloadOutput, error) {
return artifact.DownloadUsingS3(context, input)
func (networkDepImp) Download(context context.T, input artifact.DownloadInput) (artifact.DownloadOutput, error) {
return artifact.Download(context, input)
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ import (
// networkMock
type networkMock struct {
downloadInput artifact.DownloadInput
downloadOutput *artifact.DownloadOutput
downloadOutput artifact.DownloadOutput
downloadError error
}

func (p *networkMock) Download(log log.T, input artifact.DownloadInput) (*artifact.DownloadOutput, error) {
func (p *networkMock) Download(log log.T, input artifact.DownloadInput) (artifact.DownloadOutput, error) {
p.downloadInput = input
return p.downloadOutput, p.downloadError
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,11 @@ import (
// networkMock
type networkMock struct {
downloadInput artifact.DownloadInput
downloadOutput *artifact.DownloadOutput
downloadOutput artifact.DownloadOutput
downloadError error
}

func (p *networkMock) DownloadUsingHttp(context context.T, input artifact.DownloadInput) (*artifact.DownloadOutput, error) {
p.downloadInput = input
return p.downloadOutput, p.downloadError
}
func (p *networkMock) DownloadUsingS3(context context.T, input artifact.DownloadInput) (*artifact.DownloadOutput, error) {
func (p *networkMock) Download(context context.T, input artifact.DownloadInput) (artifact.DownloadOutput, error) {
p.downloadInput = input
return p.downloadOutput, p.downloadError
}
Original file line number Diff line number Diff line change
Expand Up @@ -261,14 +261,8 @@ func downloadFile(ds *PackageService, tracer trace.Tracer, file *archive.File, p
}

log := tracer.CurrentTrace().Logger
var downloadOutput *artifact.DownloadOutput
var downloadErr error
if ds.packageArchive.Name() == archive.PackageArchiveBirdwatcher { // birdwatcher packages use public s3 buckets
downloadOutput, downloadErr = birdwatcher.Networkdep.DownloadUsingS3(ds.Context, downloadInput)
} else { // modern packages use presigned urls
downloadOutput, downloadErr = birdwatcher.Networkdep.DownloadUsingHttp(ds.Context, downloadInput)
}
if downloadErr != nil || downloadOutput == nil || downloadOutput.LocalFilePath == "" {
downloadOutput, downloadErr := birdwatcher.Networkdep.Download(ds.Context, downloadInput)
if downloadErr != nil || downloadOutput.LocalFilePath == "" {
errMessage := fmt.Sprintf("failed to download installation package reliably, %v", downloadInput.SourceURL)
if downloadErr != nil {
errMessage = fmt.Sprintf("%v, %v", errMessage, downloadErr.Error())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,7 @@ func TestDownloadFile(t *testing.T) {
{
"working file download",
networkMock{
downloadOutput: &artifact.DownloadOutput{
downloadOutput: artifact.DownloadOutput{
LocalFilePath: "agent.zip",
},
},
Expand All @@ -1011,7 +1011,7 @@ func TestDownloadFile(t *testing.T) {
{
"empty local file location",
networkMock{
downloadOutput: &artifact.DownloadOutput{
downloadOutput: artifact.DownloadOutput{
LocalFilePath: "",
},
},
Expand Down Expand Up @@ -1102,7 +1102,7 @@ func TestDownloadFileFromDocumentArchive(t *testing.T) {
{
"working file download",
networkMock{
downloadOutput: &artifact.DownloadOutput{
downloadOutput: artifact.DownloadOutput{
LocalFilePath: "agent.zip",
},
},
Expand All @@ -1123,7 +1123,7 @@ func TestDownloadFileFromDocumentArchive(t *testing.T) {
{
"empty local file location",
networkMock{
downloadOutput: &artifact.DownloadOutput{
downloadOutput: artifact.DownloadOutput{
LocalFilePath: "",
},
},
Expand Down Expand Up @@ -1227,7 +1227,7 @@ func TestDownloadArtifact(t *testing.T) {
"packageName",
"1234",
networkMock{
downloadOutput: &artifact.DownloadOutput{
downloadOutput: artifact.DownloadOutput{
LocalFilePath: "agent.zip",
},
},
Expand Down

0 comments on commit c78facf

Please sign in to comment.