From 521a7b9c6ac64547fb9ffc83ae5f2c30b247b1e1 Mon Sep 17 00:00:00 2001 From: sbiscigl Date: Tue, 18 Feb 2025 13:07:06 -0500 Subject: [PATCH] fix refresh interval for STS profile credentials provider --- .../include/aws/core/auth/AWSCredentials.h | 8 ++ .../source/auth/AWSCredentialsProvider.cpp | 2 +- .../auth/STSProfileCredentialsProvider.cpp | 4 +- .../STSProfileCredentialsProviderTest.cpp | 79 ++++++++++++++++++- 4 files changed, 87 insertions(+), 6 deletions(-) diff --git a/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentials.h b/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentials.h index ca73b5e5c70..850664637eb 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentials.h +++ b/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentials.h @@ -91,6 +91,14 @@ namespace Aws inline bool IsExpired() const { return m_expiration <= Aws::Utils::DateTime::Now(); } + /** + * Checks to see if the credentials will expire in a threshold of time + * + * @param millisecondThreshold the milliseconds of threshold we will check for expiry. + * @return true if the credentials will expire before the threshold + */ + inline bool ExpiresSoon(int64_t millisecondThreshold = 5000) const { return (m_expiration - Aws::Utils::DateTime::Now()).count() < millisecondThreshold; } + inline bool IsExpiredOrEmpty() const { return IsEmpty() || IsExpired(); } /** diff --git a/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp b/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp index b1e1471a632..69fc27e4514 100644 --- a/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp +++ b/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProvider.cpp @@ -274,7 +274,7 @@ bool InstanceProfileCredentialsProvider::ExpiresSoon() const credentials = profileIter->second.GetCredentials(); } - return ((credentials.GetExpiration() - Aws::Utils::DateTime::Now()).count() < AWS_CREDENTIAL_PROVIDER_EXPIRATION_GRACE_PERIOD); + return credentials.ExpiresSoon(AWS_CREDENTIAL_PROVIDER_EXPIRATION_GRACE_PERIOD); } void InstanceProfileCredentialsProvider::Reload() diff --git a/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp b/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp index fd82b678fba..904932ee57d 100644 --- a/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp +++ b/src/aws-cpp-sdk-identity-management/source/auth/STSProfileCredentialsProvider.cpp @@ -45,13 +45,13 @@ AWSCredentials STSProfileCredentialsProvider::GetAWSCredentials() void STSProfileCredentialsProvider::RefreshIfExpired() { Utils::Threading::ReaderLockGuard guard(m_reloadLock); - if (!IsTimeToRefresh(static_cast(m_reloadFrequency.count())) || !m_credentials.IsExpiredOrEmpty()) + if (!IsTimeToRefresh(static_cast(m_reloadFrequency.count())) && !m_credentials.IsEmpty() && !m_credentials.ExpiresSoon(m_reloadFrequency.count())) { return; } guard.UpgradeToWriterLock(); - if (!IsTimeToRefresh(static_cast(m_reloadFrequency.count())) || !m_credentials.IsExpiredOrEmpty()) // double-checked lock to avoid refreshing twice + if (!IsTimeToRefresh(static_cast(m_reloadFrequency.count())) && !m_credentials.IsEmpty() && !m_credentials.ExpiresSoon(m_reloadFrequency.count())) // double-checked lock to avoid refreshing twice { return; } diff --git a/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp b/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp index 197535a6a2e..e096f098bcc 100644 --- a/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp +++ b/tests/aws-cpp-sdk-identity-management-tests/auth/STSProfileCredentialsProviderTest.cpp @@ -34,12 +34,17 @@ class MockSTSClient : public STSClient Model::AssumeRoleOutcome AssumeRole(const Model::AssumeRoleRequest& request) const override { m_capturedRequest = request; - return m_mockedOutcome; + if (!m_mockedOutcomes.empty()) { + auto outcome = m_mockedOutcomes.front(); + m_mockedOutcomes.pop(); + return outcome; + } + return STSError{}; } void MockAssumeRole(const Model::AssumeRoleOutcome& outcome) { - m_mockedOutcome = outcome; + m_mockedOutcomes.push(outcome); } const Model::AssumeRoleRequest& CapturedRequest() const @@ -54,7 +59,7 @@ class MockSTSClient : public STSClient private: mutable Model::AssumeRoleRequest m_capturedRequest; - Model::AssumeRoleOutcome m_mockedOutcome; + mutable Aws::Queue m_mockedOutcomes; AWSCredentials m_credentials; }; @@ -621,4 +626,72 @@ TEST_F(STSProfileCredentialsProviderTest, AssumeRoleRecursivelyCircularReference ASSERT_TRUE(actualCredentials.IsExpiredOrEmpty()); } + +TEST_F(STSProfileCredentialsProviderTest, ShouldRefreshCredentialsNearExpiry) +{ + Aws::OFStream configFile {m_configFilename.c_str(), Aws::OFStream::out | Aws::OFStream::trunc}; + + configFile << std::endl; + configFile << "[default]" << std::endl; + configFile << "source_profile = default" << std::endl; + configFile << "role_arn = " << ROLE_ARN_1 << std::endl; + configFile << "aws_access_key_id = " << ACCESS_KEY_ID_1 << std::endl; + configFile << "aws_secret_access_key = " << SECRET_ACCESS_KEY_ID_1 << std::endl; + configFile.close(); + Aws::Config::ReloadCachedConfigFile(); + + constexpr auto roleSessionDuration = std::chrono::seconds(5); + const DateTime expiryTime{DateTime::Now() + roleSessionDuration}; + + Model::Credentials stsCredentials; + stsCredentials.WithAccessKeyId(ACCESS_KEY_ID_2) + .WithSecretAccessKey(SECRET_ACCESS_KEY_ID_2) + .WithSessionToken(SESSION_TOKEN) + .WithExpiration(expiryTime); + + Model::Credentials refreshedStsCredentials; + refreshedStsCredentials.WithAccessKeyId(ACCESS_KEY_ID_3) + .WithSecretAccessKey(SECRET_ACCESS_KEY_ID_3) + .WithSessionToken(SESSION_TOKEN) + .WithExpiration(expiryTime); + + Model::AssumeRoleResult mockResult; + mockResult.SetCredentials(stsCredentials); + Model::AssumeRoleResult refreshedMockResult; + refreshedMockResult.SetCredentials(refreshedStsCredentials); + Aws::UniquePtr stsClient; + std::once_flag stsClientInitialized; + + int stsCallCounter = 0; + STSProfileCredentialsProvider credsProvider("default", std::chrono::minutes(60), [&](const AWSCredentials& creds) + { + ++stsCallCounter; + std::call_once(stsClientInitialized, [&] { + stsClient = Aws::MakeUnique(CLASS_TAG, creds); + stsClient->MockAssumeRole(mockResult); + stsClient->MockAssumeRole(refreshedMockResult); + }); + return stsClient.get(); + }); + + auto actualCredentials = credsProvider.GetAWSCredentials(); + + ASSERT_STREQ(ACCESS_KEY_ID_2, actualCredentials.GetAWSAccessKeyId().c_str()); + ASSERT_STREQ(SECRET_ACCESS_KEY_ID_2, actualCredentials.GetAWSSecretKey().c_str()); + ASSERT_STREQ(SESSION_TOKEN, actualCredentials.GetSessionToken().c_str()); + ASSERT_EQ(expiryTime, actualCredentials.GetExpiration()); + + ASSERT_EQ(1, stsCallCounter); + ASSERT_TRUE(stsClient); + ASSERT_STREQ(ACCESS_KEY_ID_1, stsClient->Credentials().GetAWSAccessKeyId().c_str()); + ASSERT_STREQ(SECRET_ACCESS_KEY_ID_1, stsClient->Credentials().GetAWSSecretKey().c_str()); + + actualCredentials = credsProvider.GetAWSCredentials(); + ASSERT_STREQ(ACCESS_KEY_ID_3, actualCredentials.GetAWSAccessKeyId().c_str()); + ASSERT_STREQ(SECRET_ACCESS_KEY_ID_3, actualCredentials.GetAWSSecretKey().c_str()); + ASSERT_STREQ(SESSION_TOKEN, actualCredentials.GetSessionToken().c_str()); + ASSERT_EQ(expiryTime, actualCredentials.GetExpiration()); + //should have called refresh + ASSERT_EQ(2, stsCallCounter); +} } // namespace