diff --git a/cache/cache.go b/cache/cache.go index 3abf836..709a66d 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -48,14 +48,14 @@ func getCacheSlug(role string, assume []string) string { return strings.Join(elements, "/") } -func (cc *CredentialCache) Get(role string, assumeChain []string) (*creds.RefreshableProvider, error) { +func (cc *CredentialCache) Get(searchString string, assumeChain []string) (*creds.RefreshableProvider, error) { log.WithFields(logrus.Fields{ - "role": role, - "assumeChain": assumeChain, + "searchString": searchString, + "assumeChain": assumeChain, }).Info("retrieving credentials") - c, ok := cc.get(getCacheSlug(role, assumeChain)) + c, ok := cc.get(getCacheSlug(searchString, assumeChain)) if ok { - log.Debugf("found credentials for %s in cache", role) + log.Debugf("found credentials for %s in cache", searchString) return c, nil } return nil, errors.NoCredentialsFoundInCache diff --git a/cache/cache_test.go b/cache/cache_test.go index a880d9b..827433e 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -44,44 +44,44 @@ func TestCredentialCache_Get(t *testing.T) { { Description: "role in cache", CacheContents: map[string]*creds.RefreshableProvider{ - "a": {Role: "a"}, + "a": {RoleName: "a"}, }, Role: "a", AssumeChain: []string{}, ExpectedError: nil, - ExpectedResult: &creds.RefreshableProvider{Role: "a"}, + ExpectedResult: &creds.RefreshableProvider{RoleName: "a"}, }, { Description: "role in cache with assume", CacheContents: map[string]*creds.RefreshableProvider{ - "a": {Role: "a"}, - "a/b/c": {Role: "a/b/c"}, + "a": {RoleName: "a"}, + "a/b/c": {RoleName: "a/b/c"}, }, Role: "a", AssumeChain: []string{}, ExpectedError: nil, - ExpectedResult: &creds.RefreshableProvider{Role: "a"}, + ExpectedResult: &creds.RefreshableProvider{RoleName: "a"}, }, { Description: "assume role in cache", CacheContents: map[string]*creds.RefreshableProvider{ - "a/b/c": {Role: "a/b/c"}, + "a/b/c": {RoleName: "a/b/c"}, }, Role: "a", AssumeChain: []string{"b", "c"}, ExpectedError: nil, - ExpectedResult: &creds.RefreshableProvider{Role: "a/b/c"}, + ExpectedResult: &creds.RefreshableProvider{RoleName: "a/b/c"}, }, { Description: "assume role in cache with non-assume", CacheContents: map[string]*creds.RefreshableProvider{ - "a": {Role: "a"}, - "a/b/c": {Role: "a/b/c"}, + "a": {RoleName: "a"}, + "a/b/c": {RoleName: "a/b/c"}, }, Role: "a", AssumeChain: []string{"b", "c"}, ExpectedError: nil, - ExpectedResult: &creds.RefreshableProvider{Role: "a/b/c"}, + ExpectedResult: &creds.RefreshableProvider{RoleName: "a/b/c"}, }, } @@ -95,7 +95,7 @@ func TestCredentialCache_Get(t *testing.T) { t.Errorf("%s failed: expected %v error, got %v", tc.Description, tc.ExpectedError, actualError) continue } - if actualResult != nil && actualResult.Role != tc.ExpectedResult.Role { + if actualResult != nil && actualResult.RoleArn != tc.ExpectedResult.RoleArn { t.Errorf("%s failed: expected %v result, got %v", tc.Description, tc.ExpectedResult, actualResult) } } @@ -120,16 +120,16 @@ func TestCredentialCache_GetDefault(t *testing.T) { Description: "default role in cache", DefaultRole: "a", CacheContents: map[string]*creds.RefreshableProvider{ - "a": {Role: "a"}, + "a": {RoleName: "a"}, }, ExpectedError: nil, - ExpectedResult: &creds.RefreshableProvider{Role: "a"}, + ExpectedResult: &creds.RefreshableProvider{RoleName: "a"}, }, { Description: "no default role set", DefaultRole: "", CacheContents: map[string]*creds.RefreshableProvider{ - "a": {Role: "a"}, + "a": {RoleName: "a"}, }, ExpectedError: errors.NoDefaultRoleSet, ExpectedResult: nil, @@ -138,30 +138,30 @@ func TestCredentialCache_GetDefault(t *testing.T) { Description: "default role in cache with assume", DefaultRole: "a", CacheContents: map[string]*creds.RefreshableProvider{ - "a": {Role: "a"}, - "a/b/c": {Role: "a/b/c"}, + "a": {RoleName: "a"}, + "a/b/c": {RoleName: "a/b/c"}, }, ExpectedError: nil, - ExpectedResult: &creds.RefreshableProvider{Role: "a"}, + ExpectedResult: &creds.RefreshableProvider{RoleName: "a"}, }, { Description: "default assume role in cache", DefaultRole: "a/b/c", CacheContents: map[string]*creds.RefreshableProvider{ - "a/b/c": {Role: "a/b/c"}, + "a/b/c": {RoleName: "a/b/c"}, }, ExpectedError: nil, - ExpectedResult: &creds.RefreshableProvider{Role: "a/b/c"}, + ExpectedResult: &creds.RefreshableProvider{RoleName: "a/b/c"}, }, { Description: "default assume role in cache with non-assume", DefaultRole: "a/b/c", CacheContents: map[string]*creds.RefreshableProvider{ - "a": {Role: "a"}, - "a/b/c": {Role: "a/b/c"}, + "a": {RoleName: "a"}, + "a/b/c": {RoleName: "a/b/c"}, }, ExpectedError: nil, - ExpectedResult: &creds.RefreshableProvider{Role: "a/b/c"}, + ExpectedResult: &creds.RefreshableProvider{RoleName: "a/b/c"}, }, } @@ -176,7 +176,7 @@ func TestCredentialCache_GetDefault(t *testing.T) { t.Errorf("%s failed: expected %v error, got %v", tc.Description, tc.ExpectedError, actualError) continue } - if actualResult != nil && actualResult.Role != tc.ExpectedResult.Role { + if actualResult != nil && actualResult.RoleArn != tc.ExpectedResult.RoleArn { t.Errorf("%s failed: expected %v result, got %v", tc.Description, tc.ExpectedResult, actualResult) } } @@ -296,7 +296,7 @@ func TestCredentialCache_GetOrSet(t *testing.T) { cases := []struct { CacheContents map[string]*creds.RefreshableProvider ClientResponse interface{} - Role string + SearchString string AssumeChain []string Region string Description string @@ -306,30 +306,30 @@ func TestCredentialCache_GetOrSet(t *testing.T) { { Description: "role not in cache", CacheContents: make(map[string]*creds.RefreshableProvider), - Role: "a", + SearchString: "a", AssumeChain: []string{}, ExpectedError: nil, - ExpectedResult: &creds.RefreshableProvider{Role: "a"}, + ExpectedResult: &creds.RefreshableProvider{RoleArn: "arn:aws:iam::012345678901:role/coolRole1"}, }, { Description: "role not in cache with assume", CacheContents: map[string]*creds.RefreshableProvider{ - "a/b/c": {Role: "a/b/c"}, + "a/b/c": {RoleName: "a/b/c"}, }, - Role: "a", + SearchString: "a", AssumeChain: []string{}, ExpectedError: nil, - ExpectedResult: &creds.RefreshableProvider{Role: "a"}, + ExpectedResult: &creds.RefreshableProvider{RoleArn: "arn:aws:iam::012345678901:role/coolRole2"}, }, { Description: "role already in cache", CacheContents: map[string]*creds.RefreshableProvider{ - "a": {Role: "a"}, + "a": {RoleArn: "arn:aws:iam::012345678901:role/coolRole3"}, }, - Role: "a", + SearchString: "a", AssumeChain: []string{}, ExpectedError: nil, - ExpectedResult: &creds.RefreshableProvider{Role: "a"}, + ExpectedResult: &creds.RefreshableProvider{RoleArn: "arn:aws:iam::012345678901:role/coolRole3"}, }, } @@ -344,14 +344,14 @@ func TestCredentialCache_GetOrSet(t *testing.T) { SecretAccessKey: "b", SessionToken: "c", Expiration: creds.Time(time.Unix(1, 0)), - RoleArn: "e", + RoleArn: tc.ExpectedResult.RoleArn, }, }) if err != nil { t.Errorf("test setup failure: %e", err) continue } - result, actualError := testCache.GetOrSet(client, tc.Role, tc.Region, tc.AssumeChain) + result, actualError := testCache.GetOrSet(client, tc.SearchString, tc.Region, tc.AssumeChain) if actualError != tc.ExpectedError { t.Errorf("%s failed: expected %v error, got %v", tc.Description, tc.ExpectedError, actualError) continue @@ -360,8 +360,8 @@ func TestCredentialCache_GetOrSet(t *testing.T) { t.Errorf("%s failed: got nil result, expected %v", tc.Description, tc.ExpectedResult) continue } - if result != nil && result.Role != tc.ExpectedResult.Role { - t.Errorf("%s failed: expected role %v, got %v", tc.Description, tc.ExpectedResult.Role, result.Role) + if result != nil && result.RoleArn != tc.ExpectedResult.RoleArn { + t.Errorf("%s failed: expected role %v, got %v", tc.Description, tc.ExpectedResult.RoleArn, result.RoleArn) continue } } diff --git a/creds/refreshable.go b/creds/refreshable.go index af6b29a..03e9488 100644 --- a/creds/refreshable.go +++ b/creds/refreshable.go @@ -34,7 +34,7 @@ func NewRefreshableProvider(client HTTPClient, role, region string, assumeChain splitRole := strings.Split(role, "/") roleName := splitRole[len(splitRole)-1] rp := &RefreshableProvider{ - Role: roleName, + RoleName: roleName, RoleArn: role, Region: region, NoIpRestrict: noIpRestrict, @@ -66,7 +66,7 @@ func (rp *RefreshableProvider) AutoRefresh() { } func (rp *RefreshableProvider) checkAndRefresh(threshold int) (bool, error) { - log.Debugf("checking credentials for %s", rp.Role) + log.Debugf("checking credentials for %s", rp.RoleName) // refresh creds if we're within 10 minutes of them expiring diff := time.Duration(threshold*-1) * time.Minute thresh := rp.Expiration.Add(diff) @@ -80,14 +80,14 @@ func (rp *RefreshableProvider) checkAndRefresh(threshold int) (bool, error) { } func (rp *RefreshableProvider) refresh() error { - log.Debugf("refreshing credentials for %s", rp.Role) + log.Debugf("refreshing credentials for %s", rp.RoleArn) var err error var newCreds *AwsCredentials rp.Lock() defer rp.Unlock() - newCreds, err = GetCredentialsC(rp.client, rp.Role, rp.NoIpRestrict, rp.AssumeChain) + newCreds, err = GetCredentialsC(rp.client, rp.RoleArn, rp.NoIpRestrict, rp.AssumeChain) if err != nil { if err == errors.MutualTLSCertNeedsRefreshError { log.Error(err) @@ -106,11 +106,12 @@ func (rp *RefreshableProvider) refresh() error { rp.value.SecretAccessKey = newCreds.SecretAccessKey rp.value.AccessKeyID = newCreds.AccessKeyId rp.LastRefreshed = Time(time.Now()) + // We favor the role ARN from ConsoleMe over the one from the user, which could just be a search string. rp.RoleArn = newCreds.RoleArn if rp.value.ProviderName == "" { rp.value.ProviderName = "WeepRefreshableProvider" } - log.Debugf("successfully refreshed credentials for %s", rp.Role) + log.Debugf("successfully refreshed credentials for %s", rp.RoleArn) return nil } diff --git a/creds/refreshable_test.go b/creds/refreshable_test.go index fb22f3c..b94d660 100644 --- a/creds/refreshable_test.go +++ b/creds/refreshable_test.go @@ -71,7 +71,7 @@ func TestNewRefreshableProvider(t *testing.T) { Expiration: testExpiration, LastRefreshed: Time{}, Region: testRegion, - Role: testRole, + RoleName: testRole, RoleArn: testRoleArn, NoIpRestrict: false, AssumeChain: make([]string, 0), @@ -89,7 +89,7 @@ func TestNewRefreshableProvider(t *testing.T) { Expiration: testExpiration, LastRefreshed: Time{}, Region: testRegion, - Role: testRole, + RoleName: testRole, RoleArn: testRoleArn, NoIpRestrict: true, AssumeChain: make([]string, 0), @@ -120,8 +120,8 @@ func TestNewRefreshableProvider(t *testing.T) { t.Errorf("%s failed: got %v region, expected %v", tc.Description, actualResult.Region, tc.ExpectedResult.Region) continue } - if actualResult != nil && actualResult.Role != tc.ExpectedResult.Role { - t.Errorf("%s failed: got %v role, expected %v", tc.Description, actualResult.Role, tc.ExpectedResult.Role) + if actualResult != nil && actualResult.RoleName != tc.ExpectedResult.RoleName { + t.Errorf("%s failed: got %v role, expected %v", tc.Description, actualResult.RoleName, tc.ExpectedResult.RoleName) continue } if actualResult != nil && actualResult.RoleArn != tc.ExpectedResult.RoleArn { @@ -194,7 +194,7 @@ func TestRefreshableProvider_refresh(t *testing.T) { retries: tc.Retries, retryDelay: tc.RetryDelay, Region: tc.Region, - Role: tc.Role, + RoleName: tc.Role, RoleArn: tc.RoleArn, NoIpRestrict: tc.NoIpRestrict, AssumeChain: tc.AssumeChain, diff --git a/creds/types.go b/creds/types.go index f73dfa4..29e99f9 100644 --- a/creds/types.go +++ b/creds/types.go @@ -43,7 +43,7 @@ type RefreshableProvider struct { Expiration Time LastRefreshed Time Region string - Role string + RoleName string RoleArn string NoIpRestrict bool AssumeChain []string diff --git a/server/credentialsHandler.go b/server/credentialsHandler.go index e49cd2a..40408dd 100644 --- a/server/credentialsHandler.go +++ b/server/credentialsHandler.go @@ -32,7 +32,7 @@ func RoleHandler(w http.ResponseWriter, r *http.Request) { util.WriteError(w, "error", 500) return } - if _, err := w.Write([]byte(defaultRole.Role)); err != nil { + if _, err := w.Write([]byte(defaultRole.RoleName)); err != nil { log.Errorf("failed to write response: %v", err) } }