Skip to content

Commit

Permalink
add tests for step_unpack
Browse files Browse the repository at this point in the history
  • Loading branch information
pchila committed Feb 2, 2024
1 parent cd2e9de commit bf2e849
Show file tree
Hide file tree
Showing 2 changed files with 396 additions and 60 deletions.
181 changes: 140 additions & 41 deletions internal/pkg/agent/application/upgrade/step_unpack.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"io"
"io/fs"
"os"
"path"
"path/filepath"
"runtime"
"strings"
Expand All @@ -24,10 +25,12 @@ import (
"github.com/elastic/elastic-agent/pkg/core/logger"
)

// UnpackResult contains the location and hash of the unpacked agent files
type UnpackResult struct {
// Hash contains the unpacked agent commit hash, limited to a length of 6 for backward compatibility
Hash string `json:"hash" yaml:"hash"`
// TODO add mapped path of executable
// agentExecutable string
// VersionedHome indicates the path (forward slash separated) where to find the unpacked agent files
// The value depends on the mappings specified in manifest.yaml, if no manifest is found it assumes the legacy data/elastic-agent-<hash> format
VersionedHome string `json:"versioned-home" yaml:"versioned-home"`
}

Expand Down Expand Up @@ -64,7 +67,9 @@ func unzip(log *logger.Logger, archivePath, dataDir string) (UnpackResult, error

pm := pathMapper{}
versionedHome := ""
manifestFile, err := r.Open("manifest.yaml")

// Load manifest, the use of path.Join is intentional since in .zip file paths use slash ('/') as separator
manifestFile, err := r.Open(path.Join(fileNamePrefix, "manifest.yaml"))
if err != nil && !errors.Is(err, fs.ErrNotExist) {
// we got a real error looking up for the manifest
return UnpackResult{}, fmt.Errorf("looking up manifest in package: %w", err)
Expand All @@ -77,7 +82,28 @@ func unzip(log *logger.Logger, archivePath, dataDir string) (UnpackResult, error
return UnpackResult{}, fmt.Errorf("parsing package manifest: %w", err)
}
pm.mappings = manifest.Package.PathMappings
versionedHome = filepath.Clean(pm.Map(manifest.Package.VersionedHome))
versionedHome = path.Clean(pm.Map(manifest.Package.VersionedHome))
}

// Load hash, the use of path.Join is intentional since in .zip file paths use slash ('/') as separator
hashFile, err := r.Open(path.Join(fileNamePrefix, agentCommitFile))
if err != nil {
// we got a real error looking up for the manifest
return UnpackResult{}, fmt.Errorf("looking up %q in package: %w", agentCommitFile, err)
}
defer hashFile.Close()

hashBytes, err := io.ReadAll(hashFile)
if err != nil {
return UnpackResult{}, fmt.Errorf("reading elastic-agent hash file content: %w", err)
}
if len(hashBytes) < hashLen {
return UnpackResult{}, fmt.Errorf("elastic-agent hash %q is too short (minimum %d)", string(hashBytes), hashLen)
}
hash = string(hashBytes[:hashLen])
if versionedHome == "" {
// if at this point we didn't load the manifest et the versioned to the backup value
versionedHome = createVersionedHomeFromHash(hash)
}

unpackFile := func(f *zip.File) (err error) {
Expand All @@ -91,34 +117,46 @@ func unzip(log *logger.Logger, archivePath, dataDir string) (UnpackResult, error
}
}()

//get hash
fileName := strings.TrimPrefix(f.Name, fileNamePrefix)
if fileName == agentCommitFile {
hashBytes, err := io.ReadAll(rc)
if err != nil || len(hashBytes) < hashLen {
return err
}

hash = string(hashBytes[:hashLen])
// we already loaded the hash, skip this one
return nil
}

mappedPackagePath := pm.Map(fileName)

// skip everything outside data/
if !strings.HasPrefix(fileName, "data/") {
if !strings.HasPrefix(mappedPackagePath, "data/") {
return nil
}

path := filepath.Join(dataDir, strings.TrimPrefix(fileName, "data/"))
dstPath := strings.TrimPrefix(mappedPackagePath, "data/")
dstPath = filepath.Join(dataDir, dstPath)

if f.FileInfo().IsDir() {
log.Debugw("Unpacking directory", "archive", "zip", "file.path", path)
log.Debugw("Unpacking directory", "archive", "zip", "file.path", dstPath)
// remove any world permissions from the directory
_ = os.MkdirAll(path, f.Mode()&0770)
_, err = os.Stat(dstPath)
if errors.Is(err, fs.ErrNotExist) {
if err := os.MkdirAll(dstPath, f.Mode().Perm()&0770); err != nil {
return fmt.Errorf("creating directory %q: %w", dstPath, err)
}
} else if err != nil {
return fmt.Errorf("stat() directory %q: %w", dstPath, err)
} else {
// set the appropriate permissions
err = os.Chmod(dstPath, f.Mode().Perm()&0o770)
if err != nil {
return fmt.Errorf("setting permissions %O for directory %q: %w", f.Mode().Perm()&0o770, dstPath, err)
}
}

_ = os.MkdirAll(dstPath, f.Mode()&0770)
} else {
log.Debugw("Unpacking file", "archive", "zip", "file.path", path)
log.Debugw("Unpacking file", "archive", "zip", "file.path", dstPath)
// remove any world permissions from the directory/file
_ = os.MkdirAll(filepath.Dir(path), f.Mode()&0770)
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()&0770)
_ = os.MkdirAll(filepath.Dir(dstPath), 0770)
f, err := os.OpenFile(dstPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode()&0770)
if err != nil {
return err
}
Expand Down Expand Up @@ -158,17 +196,20 @@ func unzip(log *logger.Logger, archivePath, dataDir string) (UnpackResult, error

func untar(log *logger.Logger, version string, archivePath, dataDir string) (UnpackResult, error) {

var versionedHome string
var rootDir string
var hash string

// Look up manifest in the archive and prepare path mappings, if any
pm := pathMapper{}

// quickly open the archive and look up manifest.yaml file
manifestReader, err := getManifestFromTar(archivePath)

fileContents, err := getFilesContentFromTar(archivePath, "manifest.yaml", agentCommitFile)
if err != nil {
return UnpackResult{}, fmt.Errorf("looking for package manifest: %w", err)
return UnpackResult{}, fmt.Errorf("looking for package metadata files: %w", err)
}

versionedHome := ""
manifestReader := fileContents["manifest.yaml"]
if manifestReader != nil {
manifest, err := v1.ParseManifest(manifestReader)
if err != nil {
Expand All @@ -177,7 +218,24 @@ func untar(log *logger.Logger, version string, archivePath, dataDir string) (Unp

// set the path mappings
pm.mappings = manifest.Package.PathMappings
versionedHome = filepath.Clean(pm.Map(manifest.Package.VersionedHome))
versionedHome = path.Clean(pm.Map(manifest.Package.VersionedHome))
}

if agentCommitReader, ok := fileContents[agentCommitFile]; ok {
commitBytes, err := io.ReadAll(agentCommitReader)
if err != nil {
return UnpackResult{}, fmt.Errorf("reading agent commit hash file: %w", err)
}
if len(commitBytes) < hashLen {
return UnpackResult{}, fmt.Errorf("hash %q is shorter than minimum length %d", string(commitBytes), hashLen)
}

agentCommitHash := string(commitBytes)
hash = agentCommitHash[:hashLen]
if versionedHome == "" {
// set default value of versioned home if it wasn't set by reading the manifest
versionedHome = createVersionedHomeFromHash(agentCommitHash)
}
}

r, err := os.Open(archivePath)
Expand All @@ -192,8 +250,7 @@ func untar(log *logger.Logger, version string, archivePath, dataDir string) (Unp
}

tr := tar.NewReader(zr)
var rootDir string
var hash string

fileNamePrefix := getFileNamePrefix(archivePath)

// go through all the content of a tar archive
Expand All @@ -213,16 +270,9 @@ func untar(log *logger.Logger, version string, archivePath, dataDir string) (Unp
return UnpackResult{}, errors.New("tar contained invalid filename: %q", f.Name, errors.TypeFilesystem, errors.M(errors.MetaKeyPath, f.Name))
}

//get hash
fileName := strings.TrimPrefix(f.Name, fileNamePrefix)

if fileName == agentCommitFile {
hashBytes, err := io.ReadAll(tr)
if err != nil || len(hashBytes) < hashLen {
return UnpackResult{}, err
}

hash = string(hashBytes[:hashLen])
continue
}

Expand Down Expand Up @@ -324,23 +374,67 @@ func (pm pathMapper) Map(path string) string {
return path
}

func getManifestFromTar(archivePath string) (io.Reader, error) {
type tarCloser struct {
tarFile *os.File
gzipReader *gzip.Reader
}

func (tc *tarCloser) Close() error {
var err error
if tc.gzipReader != nil {
err = multierror.Append(err, tc.gzipReader.Close())
}
// prevent double Close() call to fzip reader
tc.gzipReader = nil
if tc.tarFile != nil {
err = multierror.Append(err, tc.tarFile.Close())
}
// prevent double Close() call the underlying file
tc.tarFile = nil
return err
}

// openTar is a convenience function to open a tar.gz file.
// It returns a *tar.Reader, an io.Closer implementation to be called to release resources and an error
// In case of errors the *tar.Reader will be nil, but the io.Closer is always returned and must be called also in case
// of errors to close the underlying readers.
func openTar(archivePath string) (*tar.Reader, io.Closer, error) {
tc := new(tarCloser)
r, err := os.Open(archivePath)
if err != nil {
return nil, fmt.Errorf("opening package %s: %w", archivePath, err)
return nil, tc, fmt.Errorf("opening package %s: %w", archivePath, err)
}
defer r.Close()
tc.tarFile = r

zr, err := gzip.NewReader(r)
if err != nil {
return nil, fmt.Errorf("package %s does not seem to have a valid gzip compression: %w", archivePath, err)
return nil, tc, fmt.Errorf("package %s does not seem to have a valid gzip compression: %w", archivePath, err)
}
tc.gzipReader = zr

return tar.NewReader(zr), tc, nil
}

// getFilesContentFromTar is a small utility function which will load in memory the contents of a list of files from the tar archive.
// It's meant to be used to load package information/metadata stored in small files within the .tar.gz archive
func getFilesContentFromTar(archivePath string, files ...string) (map[string]io.Reader, error) {
tr, tc, err := openTar(archivePath)
if err != nil {
return nil, fmt.Errorf("opening tar.gz package %s: %w", archivePath, err)
}
defer tc.Close()

tr := tar.NewReader(zr)
prefix := getFileNamePrefix(archivePath)

result := make(map[string]io.Reader, len(files))
fileset := make(map[string]struct{}, len(files))
// load the fileset with the names we are looking for
for _, fName := range files {
fileset[fName] = struct{}{}
}

// go through all the content of a tar archive
// if manifest.yaml is found, read the contents and return a bytereader, nil otherwise ,
// if one of the listed files is found, read the contents and set a byte reader into the result map
for {
f, err := tr.Next()
if errors.Is(err, io.EOF) {
Expand All @@ -352,17 +446,22 @@ func getManifestFromTar(archivePath string) (io.Reader, error) {
}

fileName := strings.TrimPrefix(f.Name, prefix)
if fileName == "manifest.yaml" {
if _, ok := fileset[fileName]; ok {
// it's one of the files we are looking for, retrieve the content and set a reader into the result map
manifestBytes, err := io.ReadAll(tr)
if err != nil {
return nil, fmt.Errorf("reading manifest bytes: %w", err)
}

reader := bytes.NewReader(manifestBytes)
return reader, nil
result[fileName] = reader
}

}

return nil, nil
return result, nil
}

func createVersionedHomeFromHash(hash string) string {
return fmt.Sprintf("data/elastic-agent-%s", hash[:hashLen])
}
Loading

0 comments on commit bf2e849

Please sign in to comment.