Skip to content

Commit

Permalink
Use RoleArn for RefreshableProvider requests (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
patricksanders authored Mar 25, 2021
1 parent 1d4fb4b commit 27dcdd3
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 53 deletions.
10 changes: 5 additions & 5 deletions cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,14 @@ func getCacheSlug(role string, assume []string) string {
return strings.Join(elements, "/")
}

func (cc *CredentialCache) Get(role string, assumeChain []string) (*creds.RefreshableProvider, error) {
func (cc *CredentialCache) Get(searchString string, assumeChain []string) (*creds.RefreshableProvider, error) {
log.WithFields(logrus.Fields{
"role": role,
"assumeChain": assumeChain,
"searchString": searchString,
"assumeChain": assumeChain,
}).Info("retrieving credentials")
c, ok := cc.get(getCacheSlug(role, assumeChain))
c, ok := cc.get(getCacheSlug(searchString, assumeChain))
if ok {
log.Debugf("found credentials for %s in cache", role)
log.Debugf("found credentials for %s in cache", searchString)
return c, nil
}
return nil, errors.NoCredentialsFoundInCache
Expand Down
72 changes: 36 additions & 36 deletions cache/cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,44 +44,44 @@ func TestCredentialCache_Get(t *testing.T) {
{
Description: "role in cache",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a": {RoleName: "a"},
},
Role: "a",
AssumeChain: []string{},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a"},
},
{
Description: "role in cache with assume",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a/b/c": {Role: "a/b/c"},
"a": {RoleName: "a"},
"a/b/c": {RoleName: "a/b/c"},
},
Role: "a",
AssumeChain: []string{},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a"},
},
{
Description: "assume role in cache",
CacheContents: map[string]*creds.RefreshableProvider{
"a/b/c": {Role: "a/b/c"},
"a/b/c": {RoleName: "a/b/c"},
},
Role: "a",
AssumeChain: []string{"b", "c"},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a/b/c"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a/b/c"},
},
{
Description: "assume role in cache with non-assume",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a/b/c": {Role: "a/b/c"},
"a": {RoleName: "a"},
"a/b/c": {RoleName: "a/b/c"},
},
Role: "a",
AssumeChain: []string{"b", "c"},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a/b/c"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a/b/c"},
},
}

Expand All @@ -95,7 +95,7 @@ func TestCredentialCache_Get(t *testing.T) {
t.Errorf("%s failed: expected %v error, got %v", tc.Description, tc.ExpectedError, actualError)
continue
}
if actualResult != nil && actualResult.Role != tc.ExpectedResult.Role {
if actualResult != nil && actualResult.RoleArn != tc.ExpectedResult.RoleArn {
t.Errorf("%s failed: expected %v result, got %v", tc.Description, tc.ExpectedResult, actualResult)
}
}
Expand All @@ -120,16 +120,16 @@ func TestCredentialCache_GetDefault(t *testing.T) {
Description: "default role in cache",
DefaultRole: "a",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a": {RoleName: "a"},
},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a"},
},
{
Description: "no default role set",
DefaultRole: "",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a": {RoleName: "a"},
},
ExpectedError: errors.NoDefaultRoleSet,
ExpectedResult: nil,
Expand All @@ -138,30 +138,30 @@ func TestCredentialCache_GetDefault(t *testing.T) {
Description: "default role in cache with assume",
DefaultRole: "a",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a/b/c": {Role: "a/b/c"},
"a": {RoleName: "a"},
"a/b/c": {RoleName: "a/b/c"},
},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a"},
},
{
Description: "default assume role in cache",
DefaultRole: "a/b/c",
CacheContents: map[string]*creds.RefreshableProvider{
"a/b/c": {Role: "a/b/c"},
"a/b/c": {RoleName: "a/b/c"},
},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a/b/c"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a/b/c"},
},
{
Description: "default assume role in cache with non-assume",
DefaultRole: "a/b/c",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a/b/c": {Role: "a/b/c"},
"a": {RoleName: "a"},
"a/b/c": {RoleName: "a/b/c"},
},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a/b/c"},
ExpectedResult: &creds.RefreshableProvider{RoleName: "a/b/c"},
},
}

Expand All @@ -176,7 +176,7 @@ func TestCredentialCache_GetDefault(t *testing.T) {
t.Errorf("%s failed: expected %v error, got %v", tc.Description, tc.ExpectedError, actualError)
continue
}
if actualResult != nil && actualResult.Role != tc.ExpectedResult.Role {
if actualResult != nil && actualResult.RoleArn != tc.ExpectedResult.RoleArn {
t.Errorf("%s failed: expected %v result, got %v", tc.Description, tc.ExpectedResult, actualResult)
}
}
Expand Down Expand Up @@ -296,7 +296,7 @@ func TestCredentialCache_GetOrSet(t *testing.T) {
cases := []struct {
CacheContents map[string]*creds.RefreshableProvider
ClientResponse interface{}
Role string
SearchString string
AssumeChain []string
Region string
Description string
Expand All @@ -306,30 +306,30 @@ func TestCredentialCache_GetOrSet(t *testing.T) {
{
Description: "role not in cache",
CacheContents: make(map[string]*creds.RefreshableProvider),
Role: "a",
SearchString: "a",
AssumeChain: []string{},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a"},
ExpectedResult: &creds.RefreshableProvider{RoleArn: "arn:aws:iam::012345678901:role/coolRole1"},
},
{
Description: "role not in cache with assume",
CacheContents: map[string]*creds.RefreshableProvider{
"a/b/c": {Role: "a/b/c"},
"a/b/c": {RoleName: "a/b/c"},
},
Role: "a",
SearchString: "a",
AssumeChain: []string{},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a"},
ExpectedResult: &creds.RefreshableProvider{RoleArn: "arn:aws:iam::012345678901:role/coolRole2"},
},
{
Description: "role already in cache",
CacheContents: map[string]*creds.RefreshableProvider{
"a": {Role: "a"},
"a": {RoleArn: "arn:aws:iam::012345678901:role/coolRole3"},
},
Role: "a",
SearchString: "a",
AssumeChain: []string{},
ExpectedError: nil,
ExpectedResult: &creds.RefreshableProvider{Role: "a"},
ExpectedResult: &creds.RefreshableProvider{RoleArn: "arn:aws:iam::012345678901:role/coolRole3"},
},
}

Expand All @@ -344,14 +344,14 @@ func TestCredentialCache_GetOrSet(t *testing.T) {
SecretAccessKey: "b",
SessionToken: "c",
Expiration: creds.Time(time.Unix(1, 0)),
RoleArn: "e",
RoleArn: tc.ExpectedResult.RoleArn,
},
})
if err != nil {
t.Errorf("test setup failure: %e", err)
continue
}
result, actualError := testCache.GetOrSet(client, tc.Role, tc.Region, tc.AssumeChain)
result, actualError := testCache.GetOrSet(client, tc.SearchString, tc.Region, tc.AssumeChain)
if actualError != tc.ExpectedError {
t.Errorf("%s failed: expected %v error, got %v", tc.Description, tc.ExpectedError, actualError)
continue
Expand All @@ -360,8 +360,8 @@ func TestCredentialCache_GetOrSet(t *testing.T) {
t.Errorf("%s failed: got nil result, expected %v", tc.Description, tc.ExpectedResult)
continue
}
if result != nil && result.Role != tc.ExpectedResult.Role {
t.Errorf("%s failed: expected role %v, got %v", tc.Description, tc.ExpectedResult.Role, result.Role)
if result != nil && result.RoleArn != tc.ExpectedResult.RoleArn {
t.Errorf("%s failed: expected role %v, got %v", tc.Description, tc.ExpectedResult.RoleArn, result.RoleArn)
continue
}
}
Expand Down
11 changes: 6 additions & 5 deletions creds/refreshable.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func NewRefreshableProvider(client HTTPClient, role, region string, assumeChain
splitRole := strings.Split(role, "/")
roleName := splitRole[len(splitRole)-1]
rp := &RefreshableProvider{
Role: roleName,
RoleName: roleName,
RoleArn: role,
Region: region,
NoIpRestrict: noIpRestrict,
Expand Down Expand Up @@ -66,7 +66,7 @@ func (rp *RefreshableProvider) AutoRefresh() {
}

func (rp *RefreshableProvider) checkAndRefresh(threshold int) (bool, error) {
log.Debugf("checking credentials for %s", rp.Role)
log.Debugf("checking credentials for %s", rp.RoleName)
// refresh creds if we're within 10 minutes of them expiring
diff := time.Duration(threshold*-1) * time.Minute
thresh := rp.Expiration.Add(diff)
Expand All @@ -80,14 +80,14 @@ func (rp *RefreshableProvider) checkAndRefresh(threshold int) (bool, error) {
}

func (rp *RefreshableProvider) refresh() error {
log.Debugf("refreshing credentials for %s", rp.Role)
log.Debugf("refreshing credentials for %s", rp.RoleArn)
var err error
var newCreds *AwsCredentials

rp.Lock()
defer rp.Unlock()

newCreds, err = GetCredentialsC(rp.client, rp.Role, rp.NoIpRestrict, rp.AssumeChain)
newCreds, err = GetCredentialsC(rp.client, rp.RoleArn, rp.NoIpRestrict, rp.AssumeChain)
if err != nil {
if err == errors.MutualTLSCertNeedsRefreshError {
log.Error(err)
Expand All @@ -106,11 +106,12 @@ func (rp *RefreshableProvider) refresh() error {
rp.value.SecretAccessKey = newCreds.SecretAccessKey
rp.value.AccessKeyID = newCreds.AccessKeyId
rp.LastRefreshed = Time(time.Now())
// We favor the role ARN from ConsoleMe over the one from the user, which could just be a search string.
rp.RoleArn = newCreds.RoleArn
if rp.value.ProviderName == "" {
rp.value.ProviderName = "WeepRefreshableProvider"
}
log.Debugf("successfully refreshed credentials for %s", rp.Role)
log.Debugf("successfully refreshed credentials for %s", rp.RoleArn)
return nil
}

Expand Down
10 changes: 5 additions & 5 deletions creds/refreshable_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func TestNewRefreshableProvider(t *testing.T) {
Expiration: testExpiration,
LastRefreshed: Time{},
Region: testRegion,
Role: testRole,
RoleName: testRole,
RoleArn: testRoleArn,
NoIpRestrict: false,
AssumeChain: make([]string, 0),
Expand All @@ -89,7 +89,7 @@ func TestNewRefreshableProvider(t *testing.T) {
Expiration: testExpiration,
LastRefreshed: Time{},
Region: testRegion,
Role: testRole,
RoleName: testRole,
RoleArn: testRoleArn,
NoIpRestrict: true,
AssumeChain: make([]string, 0),
Expand Down Expand Up @@ -120,8 +120,8 @@ func TestNewRefreshableProvider(t *testing.T) {
t.Errorf("%s failed: got %v region, expected %v", tc.Description, actualResult.Region, tc.ExpectedResult.Region)
continue
}
if actualResult != nil && actualResult.Role != tc.ExpectedResult.Role {
t.Errorf("%s failed: got %v role, expected %v", tc.Description, actualResult.Role, tc.ExpectedResult.Role)
if actualResult != nil && actualResult.RoleName != tc.ExpectedResult.RoleName {
t.Errorf("%s failed: got %v role, expected %v", tc.Description, actualResult.RoleName, tc.ExpectedResult.RoleName)
continue
}
if actualResult != nil && actualResult.RoleArn != tc.ExpectedResult.RoleArn {
Expand Down Expand Up @@ -194,7 +194,7 @@ func TestRefreshableProvider_refresh(t *testing.T) {
retries: tc.Retries,
retryDelay: tc.RetryDelay,
Region: tc.Region,
Role: tc.Role,
RoleName: tc.Role,
RoleArn: tc.RoleArn,
NoIpRestrict: tc.NoIpRestrict,
AssumeChain: tc.AssumeChain,
Expand Down
2 changes: 1 addition & 1 deletion creds/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ type RefreshableProvider struct {
Expiration Time
LastRefreshed Time
Region string
Role string
RoleName string
RoleArn string
NoIpRestrict bool
AssumeChain []string
Expand Down
2 changes: 1 addition & 1 deletion server/credentialsHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func RoleHandler(w http.ResponseWriter, r *http.Request) {
util.WriteError(w, "error", 500)
return
}
if _, err := w.Write([]byte(defaultRole.Role)); err != nil {
if _, err := w.Write([]byte(defaultRole.RoleName)); err != nil {
log.Errorf("failed to write response: %v", err)
}
}
Expand Down

0 comments on commit 27dcdd3

Please sign in to comment.