Skip to content

Commit

Permalink
fix refresh interval for STS profile credentials provider
Browse files Browse the repository at this point in the history
  • Loading branch information
sbiscigl committed Feb 18, 2025
1 parent b667af2 commit f721e3b
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 7 deletions.
8 changes: 8 additions & 0 deletions src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentials.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ namespace Aws

inline bool IsExpired() const { return m_expiration <= Aws::Utils::DateTime::Now(); }

/**
* Checks to see if the credentials will expire in a threshold of time
*
* @param millisecondThreshold the milliseconds of threshold we will check for expiry.
* @return true if the credentials will expire before the threshold
*/
inline bool ExpiresSoon(int64_t millisecondThreshold = 5000) const { return (m_expiration - Aws::Utils::DateTime::Now()).count() < millisecondThreshold; }

inline bool IsExpiredOrEmpty() const { return IsEmpty() || IsExpired(); }

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ bool InstanceProfileCredentialsProvider::ExpiresSoon() const
credentials = profileIter->second.GetCredentials();
}

return ((credentials.GetExpiration() - Aws::Utils::DateTime::Now()).count() < AWS_CREDENTIAL_PROVIDER_EXPIRATION_GRACE_PERIOD);
return credentials.ExpiresSoon(AWS_CREDENTIAL_PROVIDER_EXPIRATION_GRACE_PERIOD);
}

void InstanceProfileCredentialsProvider::Reload()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ AWSCredentials STSProfileCredentialsProvider::GetAWSCredentials()
void STSProfileCredentialsProvider::RefreshIfExpired()
{
Utils::Threading::ReaderLockGuard guard(m_reloadLock);
if (!IsTimeToRefresh(static_cast<long>(m_reloadFrequency.count())) || !m_credentials.IsExpiredOrEmpty())
if (!IsTimeToRefresh(static_cast<long>(m_reloadFrequency.count())) && !m_credentials.IsEmpty() && !m_credentials.ExpiresSoon(m_reloadFrequency.count()))
{
return;
}

guard.UpgradeToWriterLock();
if (!IsTimeToRefresh(static_cast<long>(m_reloadFrequency.count())) || !m_credentials.IsExpiredOrEmpty()) // double-checked lock to avoid refreshing twice
if (!IsTimeToRefresh(static_cast<long>(m_reloadFrequency.count())) && !m_credentials.IsEmpty() && !m_credentials.ExpiresSoon(m_reloadFrequency.count())) // double-checked lock to avoid refreshing twice
{
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,17 @@ class MockSTSClient : public STSClient
Model::AssumeRoleOutcome AssumeRole(const Model::AssumeRoleRequest& request) const override
{
m_capturedRequest = request;
return m_mockedOutcome;
if (!m_mockedOutcomes.empty()) {
auto outcome = m_mockedOutcomes.front();
m_mockedOutcomes.pop();
return outcome;
}
return STSError{};
}

void MockAssumeRole(const Model::AssumeRoleOutcome& outcome)
{
m_mockedOutcome = outcome;
m_mockedOutcomes.push(outcome);
}

const Model::AssumeRoleRequest& CapturedRequest() const
Expand All @@ -54,7 +59,7 @@ class MockSTSClient : public STSClient

private:
mutable Model::AssumeRoleRequest m_capturedRequest;
Model::AssumeRoleOutcome m_mockedOutcome;
mutable Aws::Queue<Model::AssumeRoleOutcome> m_mockedOutcomes;
AWSCredentials m_credentials;
};

Expand Down Expand Up @@ -483,7 +488,7 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleRecursively)
/**
* Test that profile that sources itself.
*/
TEST_F(STSProfileCredentialsProviderTest, AssumeRoleSelfReferencing)
TEST_F(STSProfileCredentialsProviderTest, AssumeRoleSelfReerencing)
{
Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc};

Expand Down Expand Up @@ -621,4 +626,72 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleRecursivelyCircularReference

ASSERT_TRUE(actualCredentials.IsExpiredOrEmpty());
}

TEST_F(STSProfileCredentialsProviderTest, ShouldRefreshCredentialsNearExpiry)
{
Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc};

configFile << std::endl;
configFile << "[default]" << std::endl;
configFile << "source_profile = default" << std::endl;
configFile << "role_arn = " << ROLE_ARN_1 << std::endl;
configFile << "aws_access_key_id = " << ACCESS_KEY_ID_1 << std::endl;
configFile << "aws_secret_access_key = " << SECRET_ACCESS_KEY_ID_1 << std::endl;
configFile.close();
Aws::Config::ReloadCachedConfigFile();

constexpr auto roleSessionDuration = std::chrono::seconds(5);
const DateTime expiryTime{DateTime::Now() + roleSessionDuration};

Model::Credentials stsCredentials;
stsCredentials.WithAccessKeyId(ACCESS_KEY_ID_2)
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_2)
.WithSessionToken(SESSION_TOKEN)
.WithExpiration(expiryTime);

Model::Credentials refreshedStsCredentials;
refreshedStsCredentials.WithAccessKeyId(ACCESS_KEY_ID_3)
.WithSecretAccessKey(SECRET_ACCESS_KEY_ID_3)
.WithSessionToken(SESSION_TOKEN)
.WithExpiration(expiryTime);

Model::AssumeRoleResult mockResult;
mockResult.SetCredentials(stsCredentials);
Model::AssumeRoleResult refreshedMockResult;
refreshedMockResult.SetCredentials(refreshedStsCredentials);
Aws::UniquePtr<MockSTSClient> stsClient;
std::once_flag stsClientInitialized;

int stsCallCounter = 0;
STSProfileCredentialsProvider credsProvider("default", std::chrono::minutes(60), [&](const AWSCredentials& creds)
{
++stsCallCounter;
std::call_once(stsClientInitialized, [&] {
stsClient = Aws::MakeUnique<MockSTSClient>(CLASS_TAG, creds);
stsClient->MockAssumeRole(mockResult);
stsClient->MockAssumeRole(refreshedMockResult);
});
return stsClient.get();
});

auto actualCredentials = credsProvider.GetAWSCredentials();

ASSERT_STREQ(ACCESS_KEY_ID_2, actualCredentials.GetAWSAccessKeyId().c_str());
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_2, actualCredentials.GetAWSSecretKey().c_str());
ASSERT_STREQ(SESSION_TOKEN, actualCredentials.GetSessionToken().c_str());
ASSERT_EQ(expiryTime, actualCredentials.GetExpiration());

ASSERT_EQ(1, stsCallCounter);
ASSERT_TRUE(stsClient);
ASSERT_STREQ(ACCESS_KEY_ID_1, stsClient->Credentials().GetAWSAccessKeyId().c_str());
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_1, stsClient->Credentials().GetAWSSecretKey().c_str());

actualCredentials = credsProvider.GetAWSCredentials();
ASSERT_STREQ(ACCESS_KEY_ID_3, actualCredentials.GetAWSAccessKeyId().c_str());
ASSERT_STREQ(SECRET_ACCESS_KEY_ID_3, actualCredentials.GetAWSSecretKey().c_str());
ASSERT_STREQ(SESSION_TOKEN, actualCredentials.GetSessionToken().c_str());
ASSERT_EQ(expiryTime, actualCredentials.GetExpiration());
//should have called refresh
ASSERT_EQ(2, stsCallCounter);
}
} // namespace

0 comments on commit f721e3b

Please sign in to comment.