Skip to content

Commit

Permalink
Fix mTLS reload on failed credential requests (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
patricksanders authored Mar 12, 2021
1 parent 1afbcd0 commit 4753ff5
Show file tree
Hide file tree
Showing 10 changed files with 160 additions and 221 deletions.
6 changes: 3 additions & 3 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ func (cc *CredentialCache) Get(role string, assumeChain []string) (*creds.Refres
return nil, errors.NoCredentialsFoundInCache
}

func (cc *CredentialCache) GetOrSet(client *creds.Client, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) {
func (cc *CredentialCache) GetOrSet(client creds.HTTPClient, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) {
c, err := cc.Get(role, assumeChain)
if err == nil {
return c, nil
Expand All @@ -74,7 +74,7 @@ func (cc *CredentialCache) GetOrSet(client *creds.Client, role, region string, a
return c, nil
}

func (cc *CredentialCache) SetDefault(client *creds.Client, role, region string, assumeChain []string) error {
func (cc *CredentialCache) SetDefault(client creds.HTTPClient, role, region string, assumeChain []string) error {
_, err := cc.set(client, role, region, assumeChain)
if err != nil {
return err
Expand Down Expand Up @@ -119,7 +119,7 @@ func (cc *CredentialCache) get(slug string) (*creds.RefreshableProvider, bool) {
return c, ok
}

func (cc *CredentialCache) set(client *creds.Client, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) {
func (cc *CredentialCache) set(client creds.HTTPClient, role, region string, assumeChain []string) (*creds.RefreshableProvider, error) {
c, err := creds.NewRefreshableProvider(client, role, region, assumeChain, false)
if err != nil {
return nil, fmt.Errorf("could not generate creds: %w", err)
Expand Down
90 changes: 0 additions & 90 deletions cmd/metadata.go

This file was deleted.

8 changes: 2 additions & 6 deletions cmd/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,10 @@ func (p *program) run() {
exitCode := 0
args := viper.GetStringSlice("service.args")
switch command := viper.GetString("service.command"); command {
case "serve":
case "ecs_credential_provider":
err := runWeepServer(nil, args)
if err != nil {
log.Error(err)
exitCode = 1
}
case "metadata":
err := runMetadata(nil, args)
err := runWeepServer(nil, args)
if err != nil {
log.Error(err)
exitCode = 1
Expand Down
25 changes: 4 additions & 21 deletions creds/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,6 @@ import (
"github.com/aws/aws-sdk-go/service/sts"
)

// getAwsCredentials uses the provided Client to request credentials from ConsoleMe.
func getAwsCredentials(client *Client, role string, ipRestrict bool) (string, string, string, string, Time, error) {
tempCreds, err := client.GetRoleCredentials(role, ipRestrict)
if err != nil {
return "", "", "", "", Time{}, err
}

return tempCreds.AccessKeyId, tempCreds.SecretAccessKey, tempCreds.SessionToken, tempCreds.RoleArn, tempCreds.Expiration, nil
}

// getSessionName returns the AWS session name, or defaults to weep if we can't find one.
func getSessionName(session *sts.STS) string {
identity, err := session.GetCallerIdentity(&sts.GetCallerIdentityInput{})
Expand Down Expand Up @@ -82,27 +72,20 @@ func getAssumeRoleCredentials(id, secret, token, roleArn string) (string, string
// GetCredentialsC uses the provided Client to request credentials from ConsoleMe then
// follows the provided chain of roles to assume. Roles are assumed in the order in which
// they appear in the assumeRole slice.
func GetCredentialsC(client *Client, role string, ipRestrict bool, assumeRole []string) (*AwsCredentials, error) {
id, secret, token, roleArn, expiration, err := getAwsCredentials(client, role, ipRestrict)
func GetCredentialsC(client HTTPClient, role string, ipRestrict bool, assumeRole []string) (*AwsCredentials, error) {
resp, err := client.GetRoleCredentials(role, ipRestrict)
if err != nil {
return nil, err
}

for _, assumeRoleArn := range assumeRole {
id, secret, token, err = getAssumeRoleCredentials(id, secret, token, assumeRoleArn)
resp.AccessKeyId, resp.SecretAccessKey, resp.SessionToken, err = getAssumeRoleCredentials(resp.AccessKeyId, resp.SecretAccessKey, resp.SessionToken, assumeRoleArn)
if err != nil {
return nil, fmt.Errorf("role assumption failed for %s: %s", assumeRoleArn, err)
}
}

finalCreds := &AwsCredentials{
AccessKeyId: id,
SecretAccessKey: secret,
SessionToken: token,
Expiration: expiration,
RoleArn: roleArn,
}
return finalCreds, nil
return resp, nil
}

// GetCredentials requests credentials from ConsoleMe then follows the provided chain of roles to
Expand Down
Loading

0 comments on commit 4753ff5

Please sign in to comment.