Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache and reuse generated tokens. #5

Merged
merged 2 commits into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ jobs:
ruby:
- '3.1.4'
- '3.2.2'
- '3.3.0'
- '3.3.5'
steps:
- uses: actions/checkout@v3
- name: Set up Ruby
Expand Down
2 changes: 2 additions & 0 deletions .rubocop.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ Metrics/MethodLength:

Metrics/AbcSize:
Max: 20
Exclude:
- test/**/**.rb

Metrics/ClassLength:
Exclude:
Expand Down
2 changes: 1 addition & 1 deletion .ruby-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.3.0
3.3.5
9 changes: 7 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,16 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased](https://github.com/haines/pg-aws_rds_iam/compare/v0.1.0...HEAD)
## [Unreleased](https://github.com/floor114/mysql2-aws_rds_iam/compare/v0.1.0...HEAD)

No notable changes.

## [0.1.0](https://github.com/haines/pg-aws_rds_iam/compare/191a63e3c0222ac05bf06faaa496da954e352bbb...v0.1.0) - 2024-01-14
## [0.2.0](https://github.com/floor114/mysql2-aws_rds_iam/compare/v0.1.0...v0.2.0) - 2024-12-16

### Added
* Cache and reuse generated tokens ([#5](https://github.com/floor114/mysql2-aws_rds_iam/pull/5))

## [0.1.0](https://github.com/floor114/mysql2-aws_rds_iam/compare/f7035d3fea3ac90e6c1b8193f8befe797a425179...v0.1.0) - 2024-01-14

### Added
* `Mysql2::AwsRdsIam` is an extension of [mysql2](https://github.com/brianmario/mysql2) gem that adds support of [IAM authentication](https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/UsingWithRDS.IAMDBAuth.html) when connecting to MySQL in Amazon RDS.
2 changes: 1 addition & 1 deletion Gemfile.lock
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
PATH
remote: .
specs:
mysql2-aws_rds_iam (0.1.0)
mysql2-aws_rds_iam (0.2.0)
aws-sdk-rds (~> 1)
mysql2
zeitwerk (~> 2)
Expand Down
48 changes: 48 additions & 0 deletions lib/mysql2/aws_rds_iam/auth_token/expirable_token.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# frozen_string_literal: true

module Mysql2
module AwsRdsIam
module AuthToken
class ExpirableToken
# By default token is valid for up to 15 minutes, here we expire it after 14 minutes
DEFAULT_EXPIRE_AT = (15 * 60) # 15 minutes
EXPIRATION_THRESHOLD = (1 * 60) # 1 minute
EXPIRE_HEADER = 'x-amz-expires'

def initialize(token)
@token = token
@created_at = now
@expire_at = parse_expiration || DEFAULT_EXPIRE_AT
end

def value
token unless expired?
end

private

attr_reader :token, :created_at, :expire_at

def expired?
(now - created_at) > (expire_at - EXPIRATION_THRESHOLD)
end

def now
Process.clock_gettime(Process::CLOCK_MONOTONIC)
end

def parse_expiration
query = URI.parse("https://#{token}").query

return nil unless query

URI.decode_www_form(query)
.filter_map { |(key, value)| Integer(value) if key.downcase == EXPIRE_HEADER }
.first
rescue StandardError
nil
end
end
end
end
end
28 changes: 23 additions & 5 deletions lib/mysql2/aws_rds_iam/auth_token/generator.rb
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,32 @@ def initialize

@generator = Aws::RDS::AuthTokenGenerator.new(credentials: aws_config.credentials)
@region = aws_config.region

@cache = {}
@cache_mutex = Mutex.new
end

def call(host:, port:, username:)
generator.auth_token(
region: region,
endpoint: "#{host}:#{port}",
user_name: username.to_s
)
cache_key = "#{host}:#{port}:#{username}"

cached_token = @cache[cache_key]&.value
return cached_token if cached_token

@cache_mutex.synchronize do
# :nocov: Executed only when parallel thread just created token
cached_token = @cache[cache_key]&.value
return cached_token if cached_token

# :nocov:

generator.auth_token(
region: region,
endpoint: "#{host}:#{port}",
user_name: username.to_s
).tap do |token|
@cache[cache_key] = ExpirableToken.new(token)
end
end
end

private
Expand Down
2 changes: 1 addition & 1 deletion lib/mysql2/aws_rds_iam/version.rb
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

module Mysql2
module AwsRdsIam
VERSION = '0.1.0'
VERSION = '0.2.0'
end
end
58 changes: 58 additions & 0 deletions test/mysql2/auth_token/test_expirable_token.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# frozen_string_literal: true

require 'test_helper'

module Mysql2
module AwsRdsIam
module AuthToken
class TestExpirableToken < Minitest::Test
def setup
@valid_token = 'https://example.com?x-amz-expires=900'
@no_expiration_token = 'https://example.com?other=test'
@malformed_token = 'https://example.com?x-amz-expires=test'
@no_query_token = 'https://example.com'
end

def test_that_token_is_valid_when_not_expired
token = ExpirableToken.new(@valid_token)

Process.stub(:clock_gettime, token.send(:created_at) + 60) do
assert_equal @valid_token, token.value
end
end

def test_that_tokenis_valid_when_expiry_is_missing
token = ExpirableToken.new(@no_expiration_token)

Process.stub(:clock_gettime, token.send(:created_at) + 840) do
assert_equal @no_expiration_token, token.value
end
end

def test_that_tokenis_valid_when_expiry_is_invalid
token = ExpirableToken.new(@malformed_token)

Process.stub(:clock_gettime, token.send(:created_at) + 840) do
assert_equal @malformed_token, token.value
end
end

def test_that_tokenis_valid_when_no_query
token = ExpirableToken.new(@no_query_token)

Process.stub(:clock_gettime, token.send(:created_at) + 840) do
assert_equal @no_query_token, token.value
end
end

def test_that_token_is_invalid_when_expired
token = ExpirableToken.new(@valid_token)

Process.stub(:clock_gettime, token.send(:created_at) + 900) do
assert_nil token.value
end
end
end
end
end
end
66 changes: 64 additions & 2 deletions test/mysql2/auth_token/test_generator.rb
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@ def setup

def test_that_it_calls_aws_libraries_and_generates_token
aws_generator = mock('generator')
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host:port', user_name: 'username')
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host:port',
user_name: 'username').returns('aws_generated_token')

Aws::RDS::Client.expects(:new).once.returns(@aws_rds_client)
Aws::RDS::AuthTokenGenerator.expects(:new).with(credentials: { at: :at, st: :st }).once.returns(aws_generator)

Mysql2::AwsRdsIam::AuthToken::Generator.new.call(host: 'host', port: 'port', username: 'username')
token = Mysql2::AwsRdsIam::AuthToken::Generator.new.call(host: 'host', port: 'port', username: 'username')

assert_equal 'aws_generated_token', token
end

def test_that_when_username_passed_as_symbol
Expand All @@ -30,6 +33,65 @@ def test_that_when_username_passed_as_symbol

Mysql2::AwsRdsIam::AuthToken::Generator.new.call(host: 'host', port: 'port', username: :username)
end

def test_that_it_uses_cached_token
aws_generator = mock('generator')
aws_generator.expects(:auth_token).never

Aws::RDS::Client.expects(:new).once.returns(@aws_rds_client)
Aws::RDS::AuthTokenGenerator.expects(:new).with(credentials: { at: :at, st: :st }).once.returns(aws_generator)

generator = Mysql2::AwsRdsIam::AuthToken::Generator.new
cached_token = mock('ExpirableToken', value: 'cached-token')
generator.instance_variable_get(:@cache)['host:port:username'] = cached_token

token = generator.call(host: 'host', port: 'port', username: 'username')

assert_equal 'cached-token', token
end

def test_that_it_refreshes_token_when_cache_is_invalid
aws_generator = mock('generator')
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host:port',
user_name: 'username').returns('aws_generated_token')

Aws::RDS::Client.expects(:new).once.returns(@aws_rds_client)
Aws::RDS::AuthTokenGenerator.expects(:new).with(credentials: { at: :at, st: :st }).once.returns(aws_generator)

generator = Mysql2::AwsRdsIam::AuthToken::Generator.new
expired_token = mock('ExpirableToken')
expired_token.expects(:value).twice.returns(nil)
generator.instance_variable_get(:@cache)['host:port:username'] = expired_token

token = generator.call(host: 'host', port: 'port', username: 'username')

assert_equal 'aws_generated_token', token
end

def test_thread_safety_with_cache_access
token1 = mock('ExpirableToken', value: 'token1')
token2 = mock('ExpirableToken', value: 'token2')
aws_generator = mock('generator')
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host1:port1',
user_name: 'username1').returns('aws_generated_token1')
aws_generator.expects(:auth_token).with(region: 'region', endpoint: 'host2:port2',
user_name: 'username2').returns('aws_generated_token2')

Aws::RDS::Client.expects(:new).once.returns(@aws_rds_client)
Aws::RDS::AuthTokenGenerator.expects(:new).with(credentials: { at: :at, st: :st }).once.returns(aws_generator)

generator = Mysql2::AwsRdsIam::AuthToken::Generator.new
ExpirableToken.stubs(:new).returns(token1, token2)

threads = []
threads << Thread.new { generator.call(host: 'host1', port: 'port1', username: 'username1') }
threads << Thread.new { generator.call(host: 'host2', port: 'port2', username: 'username2') }

threads.each(&:join)

assert_equal 'token1', generator.instance_variable_get(:@cache)['host1:port1:username1'].value
assert_equal 'token2', generator.instance_variable_get(:@cache)['host2:port2:username2'].value
end
end
end
end
Expand Down
Loading