Skip to content

Commit

Permalink
Cache and reuse generated tokens.
Browse files Browse the repository at this point in the history
  • Loading branch information
floor114 committed Dec 16, 2024
1 parent 20061e5 commit b9dfdef
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
ruby:
- '3.1.4'
- '3.2.2'
- '3.3.0'
- '3.3.5'
steps:
- uses: actions/checkout@v3
- name: Set up Ruby
Expand Down
2 changes: 1 addition & 1 deletion .ruby-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.3.0
3.3.5
48 changes: 48 additions & 0 deletions lib/mysql2/aws_rds_iam/auth_token/expirable_token.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# frozen_string_literal: true

module Mysql2
module AwsRdsIam
module AuthToken
class ExpirableToken
# By default token is valid for up to 15 minutes, here we expire it after 14 minutes
DEFAULT_EXPIRE_AT = (15 * 60) # 15 minutes
EXPIRATION_THRESHOLD = (1 * 60) # 1 minute
EXPIRE_HEADER = 'x-amz-expires'

def initialize(token)
@token = token
@created_at = now
@expire_at = parse_expiration || DEFAULT_EXPIRE_AT
end

def value
token unless expired?
end

private

attr_reader :token, :created_at, :expire_at

def expired?
(now - created_at) > (expire_at - EXPIRATION_THRESHOLD)
end

def now
Process.clock_gettime(Process::CLOCK_MONOTONIC)
end

def parse_expiration
query = URI.parse("https://#{token}").query

return nil unless query

URI.decode_www_form(query)
.filter_map { |(key, value)| Integer(value) if key.downcase == EXPIRE_HEADER }
.first
rescue StandardError
nil
end
end
end
end
end
27 changes: 22 additions & 5 deletions lib/mysql2/aws_rds_iam/auth_token/generator.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,31 @@ def initialize

@generator = Aws::RDS::AuthTokenGenerator.new(credentials: aws_config.credentials)
@region = aws_config.region

@cache = {}
@cache_mutex = Mutex.new
end

def call(host:, port:, username:)
generator.auth_token(
region: region,
endpoint: "#{host}:#{port}",
user_name: username.to_s
)
cache_key = "#{host}:#{port}:#{username}"

cached_token = @cache[cache_key]&.value
return cached_token if cached_token

@cache_mutex.synchronize do
# :nocov: Executed only when parallel thread just created token
cached_token = @cache[cache_key]&.value
return cached_token if cached_token
# :nocov:

generator.auth_token(
region: region,
endpoint: "#{host}:#{port}",
user_name: username.to_s
).tap do |token|
@cache[cache_key] = ExpirableToken.new(token)
end
end
end

private
Expand Down
51 changes: 51 additions & 0 deletions test/mysql2/auth_token/test_expirable_token.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
require 'test_helper'

module Mysql2
module AwsRdsIam
module AuthToken
class TestExpirableToken < Minitest::Test
def setup
@valid_token = "https://example.com?x-amz-expires=900"
@no_expiration_token = "https://example.com?other=test"
@malformed_token = "https://example.com?x-amz-expires=test"
@no_query_token = "https://example.com"
end

def test_that_token_is_valid_when_not_expired
token = ExpirableToken.new(@valid_token)
Process.stub(:clock_gettime, token.send(:created_at) + 60) do
assert_equal @valid_token, token.value
end
end

def test_that_tokenis_valid_when_expiry_is_missing
token = ExpirableToken.new(@no_expiration_token)
Process.stub(:clock_gettime, token.send(:created_at) + 840) do
assert_equal @no_expiration_token, token.value
end
end

def test_that_tokenis_valid_when_expiry_is_invalid
token = ExpirableToken.new(@malformed_token)
Process.stub(:clock_gettime, token.send(:created_at) + 840) do
assert_equal @malformed_token, token.value
end
end

def test_that_tokenis_valid_when_no_query
token = ExpirableToken.new(@no_query_token)
Process.stub(:clock_gettime, token.send(:created_at) + 840) do
assert_equal @no_query_token, token.value
end
end

def test_that_token_is_invalid_when_expired
token = ExpirableToken.new(@valid_token)
Process.stub(:clock_gettime, token.send(:created_at) + 900) do
assert_nil token.value
end
end
end
end
end
end
61 changes: 59 additions & 2 deletions test/mysql2/auth_token/test_generator.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ def setup

def test_that_it_calls_aws_libraries_and_generates_token
aws_generator = mock('generator')
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host:port', user_name: 'username')
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host:port', user_name: 'username').returns('aws_generated_token')

Aws::RDS::Client.expects(:new).once.returns(@aws_rds_client)
Aws::RDS::AuthTokenGenerator.expects(:new).with(credentials: { at: :at, st: :st }).once.returns(aws_generator)

Mysql2::AwsRdsIam::AuthToken::Generator.new.call(host: 'host', port: 'port', username: 'username')
token = Mysql2::AwsRdsIam::AuthToken::Generator.new.call(host: 'host', port: 'port', username: 'username')
assert_equal 'aws_generated_token', token
end

def test_that_when_username_passed_as_symbol
Expand All @@ -30,6 +31,62 @@ def test_that_when_username_passed_as_symbol

Mysql2::AwsRdsIam::AuthToken::Generator.new.call(host: 'host', port: 'port', username: :username)
end

def test_that_it_uses_cached_token
aws_generator = mock('generator')
aws_generator.expects(:auth_token).never

Aws::RDS::Client.expects(:new).once.returns(@aws_rds_client)
Aws::RDS::AuthTokenGenerator.expects(:new).with(credentials: { at: :at, st: :st }).once.returns(aws_generator)

generator = Mysql2::AwsRdsIam::AuthToken::Generator.new
cached_token = mock('ExpirableToken', value: 'cached-token')
generator.instance_variable_get(:@cache)["host:port:username"] = cached_token

token = generator.call(host: 'host', port: 'port', username: 'username')

assert_equal 'cached-token', token
end

def test_that_it_refreshes_token_when_cache_is_invalid
aws_generator = mock('generator')
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host:port', user_name: 'username').returns('aws_generated_token')

Aws::RDS::Client.expects(:new).once.returns(@aws_rds_client)
Aws::RDS::AuthTokenGenerator.expects(:new).with(credentials: { at: :at, st: :st }).once.returns(aws_generator)

generator = Mysql2::AwsRdsIam::AuthToken::Generator.new
expired_token = mock('ExpirableToken')
expired_token.expects(:value).twice.returns(nil)
generator.instance_variable_get(:@cache)["host:port:username"] = expired_token

token = generator.call(host: 'host', port: 'port', username: 'username')

assert_equal 'aws_generated_token', token
end

def test_thread_safety_with_cache_access
token1 = mock('ExpirableToken', value: 'token1')
token2 = mock('ExpirableToken', value: 'token2')
aws_generator = mock('generator')
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host1:port1', user_name: 'username1').returns('aws_generated_token1')
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host2:port2', user_name: 'username2').returns('aws_generated_token2')

Aws::RDS::Client.expects(:new).once.returns(@aws_rds_client)
Aws::RDS::AuthTokenGenerator.expects(:new).with(credentials: { at: :at, st: :st }).once.returns(aws_generator)

generator = Mysql2::AwsRdsIam::AuthToken::Generator.new
ExpirableToken.stubs(:new).returns(token1, token2)

threads = []
threads << Thread.new { generator.call(host: 'host1', port: 'port1', username: 'username1') }
threads << Thread.new { generator.call(host: 'host2', port: 'port2', username: 'username2') }

threads.each(&:join)

assert_equal 'token1', generator.instance_variable_get(:@cache)['host1:port1:username1'].value
assert_equal 'token2', generator.instance_variable_get(:@cache)['host2:port2:username2'].value
end
end
end
end
Expand Down

0 comments on commit b9dfdef

Please sign in to comment.