Skip to content

Commit

Permalink
feat: okta authentication support
Browse files Browse the repository at this point in the history
  • Loading branch information
karenc-bq committed Aug 13, 2024
1 parent c114d75 commit 1e47d14
Show file tree
Hide file tree
Showing 19 changed files with 603 additions and 33 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ ENDIF(WIN32)
#------------ find the AWS SDK for C++ package---------
LIST(APPEND CMAKE_PREFIX_PATH "${CMAKE_SOURCE_DIR}/aws_sdk/install")

FIND_PACKAGE(AWSSDK REQUIRED COMPONENTS rds secretsmanager)
FIND_PACKAGE(AWSSDK REQUIRED COMPONENTS rds secretsmanager sts)

#------------------------------------------------------

Expand Down Expand Up @@ -807,7 +807,7 @@ if(APPLE)
set(BUNDLED_LIBS
libssl libcrypto ssleay libeay
libaws-c-auth libaws-c-cal libaws-c-common libaws-c-compression libaws-c-event-stream libaws-c-http libaws-c-io
libaws-c-mqtt libaws-c-s3 libaws-c-sdkutils libaws-checksums libaws-cpp-sdk-core libaws-cpp-sdk-rds libaws-cpp-sdk-secretsmanager libaws-crt-cpp
libaws-c-mqtt libaws-c-s3 libaws-c-sdkutils libaws-checksums libaws-cpp-sdk-core libaws-cpp-sdk-rds libaws-cpp-sdk-secretsmanager libaws-cpp-sdk-sts libaws-crt-cpp
)
else(APPLE)
set(BUNDLED_LIBS
Expand Down
32 changes: 26 additions & 6 deletions driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,14 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
my_stmt.cc
mylog.cc
mysql_proxy.cc
okta_proxy.cc
options.cc
parse.cc
prepare.cc
query_parsing.cc
results.cc
saml_http_client.cc
saml_util.cc
secrets_manager_proxy.cc
topology_service.cc
transact.cc
Expand Down Expand Up @@ -149,8 +152,11 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
mylog.h
mysql_proxy.h
myutil.h
okta_proxy.h
parse.h
query_parsing.h
saml_http_client.h
saml_util.h
secrets_manager_proxy.h
topology_service.h
../MYODBC_MYSQL.h ../MYODBC_CONF.h ../MYODBC_ODBC.h)
Expand Down Expand Up @@ -253,7 +259,7 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})

SET(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${ODBC_DRIVER_LINK_FLAGS}")
TARGET_LINK_LIBRARIES(${DRIVER_NAME}
${MYSQL_CLIENT_LIBS} ${CMAKE_THREAD_LIBS_INIT} m myodbc-util otel_api)
${MYSQL_CLIENT_LIBS} ${CMAKE_THREAD_LIBS_INIT} m myodbc-util otel_ap)
TARGET_LINK_LIBRARIES(${DRIVER_NAME_STATIC}
${MYSQL_CLIENT_LIBS} ${CMAKE_THREAD_LIBS_INIT} m myodbc-util otel_api)

Expand Down Expand Up @@ -295,19 +301,33 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})

MATH(EXPR DRIVER_INDEX "${DRIVER_INDEX} + 1")

include(FetchContent)

FetchContent_Declare(
json
URL https://github.com/nlohmann/json/releases/download/v3.10.5/json.tar.xz
)

FetchContent_Declare(
httplib
URL https://github.com/yhirose/cpp-httplib/archive/refs/tags/v0.14.0.zip
)

FetchContent_MakeAvailable(httplib json)

#------------AWS SDK------------------
LIST(APPEND SERVICE_LIST rds secretsmanager)
LIST(APPEND SERVICE_LIST rds secretsmanager sts)

MESSAGE(STATUS "CMAKE_BUILD_TYPE is ${CMAKE_BUILD_TYPE}")
IF(MSVC)
MESSAGE(STATUS "Copying AWS SDK libraries to ${LIBRARY_OUTPUT_PATH}/${CMAKE_BUILD_TYPE}")
AWSSDK_CPY_DYN_LIBS(SERVICE_LIST "" ${LIBRARY_OUTPUT_PATH}/${CMAKE_BUILD_TYPE})
ENDIF(MSVC)

TARGET_INCLUDE_DIRECTORIES(${DRIVER_NAME} PUBLIC ${AWSSDK_INCLUDE_DIR})
TARGET_LINK_LIBRARIES(${DRIVER_NAME} ${AWSSDK_LINK_LIBRARIES})
TARGET_INCLUDE_DIRECTORIES(${DRIVER_NAME_STATIC} PUBLIC ${AWSSDK_INCLUDE_DIR})
TARGET_LINK_LIBRARIES(${DRIVER_NAME_STATIC} ${AWSSDK_LINK_LIBRARIES})
TARGET_INCLUDE_DIRECTORIES(${DRIVER_NAME} PUBLIC "${httplib_SOURCE_DIR}" ${AWSSDK_INCLUDE_DIR} ${OPENSSL_INCLUDE_DIR})
TARGET_LINK_LIBRARIES(${DRIVER_NAME} nlohmann_json::nlohmann_json ${AWSSDK_LINK_LIBRARIES})
TARGET_INCLUDE_DIRECTORIES(${DRIVER_NAME_STATIC} PUBLIC "${httplib_SOURCE_DIR}" ${AWSSDK_INCLUDE_DIR} ${OPENSSL_INCLUDE_DIR})
TARGET_LINK_LIBRARIES(${DRIVER_NAME_STATIC} nlohmann_json::nlohmann_json ${AWSSDK_LINK_LIBRARIES})
#------------------------------

ENDWHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
10 changes: 5 additions & 5 deletions driver/auth_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,19 @@ namespace {
AWS_SDK_HELPER SDK_HELPER;
}

AUTH_UTIL::AUTH_UTIL(const char* region) {
++SDK_HELPER;
AUTH_UTIL::AUTH_UTIL(const char* region)
: AUTH_UTIL(region, Aws::Auth::DefaultAWSCredentialsProviderChain().GetAWSCredentials()) {};

Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider;
Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials();
AUTH_UTIL::AUTH_UTIL(const char* region, Aws::Auth::AWSCredentials credentials) {
++SDK_HELPER;

Aws::RDS::RDSClientConfiguration client_config;
if (region) {
client_config.region = region;
}

this->rds_client = std::make_shared<Aws::RDS::RDSClient>(credentials, client_config);
};
}

std::string AUTH_UTIL::get_auth_token(const char* host, const char* region, unsigned int port, const char* user) {
return this->rds_client->GenerateConnectAuthToken(host, region, port, user);
Expand Down
1 change: 1 addition & 0 deletions driver/auth_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class AUTH_UTIL {
public:
AUTH_UTIL() {};
AUTH_UTIL(const char* region);
AUTH_UTIL(const char* region, Aws::Auth::AWSCredentials credentials);
~AUTH_UTIL();

virtual std::string get_auth_token(const char* host, const char* region, unsigned int port, const char* user);
Expand Down
3 changes: 2 additions & 1 deletion driver/connect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,8 @@ SQLRETURN DBC::connect(DataSource *dsrc, bool failover_enabled, bool is_monitor_
#if (MYSQL_VERSION_ID >= 50527 && MYSQL_VERSION_ID < 50600) || MYSQL_VERSION_ID >= 50607
// IAM authentication requires the plugin to be set.
if (dsrc->opt_ENABLE_CLEARTEXT_PLUGIN ||
(dsrc->opt_AUTH_MODE && !myodbc_strcasecmp(AUTH_MODE_IAM, (const char*)dsrc->opt_AUTH_MODE)))
(dsrc->opt_AUTH_MODE && !myodbc_strcasecmp(AUTH_MODE_IAM, (const char*)dsrc->opt_AUTH_MODE))
|| dsrc->opt_FED_AUTH_MODE)
{
connection_proxy->options(MYSQL_ENABLE_CLEARTEXT_PLUGIN, (char *)&on);
}
Expand Down
16 changes: 16 additions & 0 deletions driver/handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@
* *
****************************************************************************/

#include "adfs_proxy.h"
#include "driver.h"
#include "efm_proxy.h"
#include "iam_proxy.h"
#include "mysql_proxy.h"
#include "okta_proxy.h"
#include "secrets_manager_proxy.h"

#include <mutex>
Expand Down Expand Up @@ -141,6 +143,20 @@ void DBC::init_proxy_chain(DataSource* dsrc)
}
}

if (dsrc->opt_FED_AUTH_MODE) {
const char* fed_auth_mode = (const char*)dsrc->opt_FED_AUTH_MODE;
if (!myodbc_strcasecmp(FED_AUTH_MODE_ADFS, fed_auth_mode)) {
CONNECTION_PROXY* adfs_proxy = new ADFS_PROXY(this, dsrc);
adfs_proxy->set_next_proxy(head);
head = adfs_proxy;
}
else if (!myodbc_strcasecmp(FED_AUTH_MODE_OKTA, fed_auth_mode)) {
CONNECTION_PROXY* okta_proxy = new OKTA_PROXY(this, dsrc);
okta_proxy->set_next_proxy(head);
head = okta_proxy;
}
}

this->connection_proxy = head;
}

Expand Down
170 changes: 170 additions & 0 deletions driver/okta_proxy.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
//
// This program is free software; you can redistribute it and/or modify
// it under the terms of the GNU General Public License, version 2.0
// (GPLv2), as published by the Free Software Foundation, with the
// following additional permissions:
//
// This program is distributed with certain software that is licensed
// under separate terms, as designated in a particular file or component
// or in the license documentation. Without limiting your rights under
// the GPLv2, the authors of this program hereby grant you an additional
// permission to link the program and your derivative works with the
// separately licensed software that they have included with the program.
//
// Without limiting the foregoing grant of rights under the GPLv2 and
// additional permission as to separately licensed software, this
// program is also subject to the Universal FOSS Exception, version 1.0,
// a copy of which can be found along with its FAQ at
// http://oss.oracle.com/licenses/universal-foss-exception.
//
// This program is distributed in the hope that it will be useful, but
// WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
// See the GNU General Public License, version 2.0, for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see
// http://www.gnu.org/licenses/gpl-2.0.html.

#include <functional>

#include "driver.h"
#include "okta_proxy.h"
#include "saml_http_client.h"

#define OKTA_AWS_APP_NAME "amazon_aws"

std::unordered_map<std::string, TOKEN_INFO> OKTA_PROXY::token_cache;
std::mutex OKTA_PROXY::token_cache_mutex;

OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds) : OKTA_PROXY(dbc, ds, nullptr) {};

OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) : CONNECTION_PROXY(dbc, ds) {
this->next_proxy = next_proxy;
const std::string idp_host{static_cast<const char*>(ds->opt_IDP_ENDPOINT)};
this->saml_util = std::make_shared<OKTA_SAML_UTIL>(idp_host);
}

bool OKTA_PROXY::connect(const char* host, const char* user, const char* password, const char* database,
unsigned int port, const char* socket, unsigned long flags) {
auto f =
std::bind(&CONNECTION_PROXY::connect, next_proxy, host, user, std::placeholders::_1, database, port, socket, flags);
return invoke_func_with_fed_credentials(f);
}

bool OKTA_PROXY::invoke_func_with_fed_credentials(std::function<bool(const char*)> func) {
const char* region = ds->opt_AUTH_REGION ? static_cast<const char*>(ds->opt_AUTH_REGION) : Aws::Region::US_EAST_1;
const std::string assertion = this->saml_util->get_saml_assertion(ds);
const char* idp_host = static_cast<const char*>(ds->opt_IDP_ENDPOINT);
const char* iam_role_arn = static_cast<const char*>(ds->opt_IAM_ROLE_ARN);
const char* idp_arn = static_cast<const char*>(ds->opt_IAM_IDP_ARN);
const Aws::Auth::AWSCredentials credentials =
this->saml_util->get_aws_credentials(idp_host, region, iam_role_arn, idp_arn, assertion);
this->auth_util = std::make_shared<AUTH_UTIL>(region, credentials);

const char* AUTH_HOST =
ds->opt_AUTH_HOST ? static_cast<const char*>(ds->opt_AUTH_HOST) : static_cast<const char*>(ds->opt_SERVER);
int iam_port = ds->opt_AUTH_PORT;
if (iam_port == UNDEFINED_PORT) {
// Use regular port if user does not provide IAM port
iam_port = ds->opt_PORT;
}

std::string auth_token = this->auth_util->get_auth_token(AUTH_HOST, region, iam_port, (const char*)ds->opt_UID);

bool connect_result = func(auth_token.c_str());
if (!connect_result) {
if (using_cached_token) {
// Retry func with a fresh token
auth_token = this->auth_util->get_auth_token(AUTH_HOST, region, iam_port, (const char*)ds->opt_UID);
if (func(auth_token.c_str())) {
return true;
}
}

Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider;
Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials();
if (credentials.IsEmpty()) {
this->set_custom_error_message(
"Could not find AWS Credentials for IAM Authentication. Please set up AWS credentials.");
} else if (credentials.IsExpired()) {
this->set_custom_error_message(
"AWS Credentials for IAM Authentication are expired. Please refresh AWS credentials.");
}
}

return connect_result;
}

OKTA_PROXY::~OKTA_PROXY() {
this->auth_util.reset();
this->saml_util.reset();
}

#ifdef UNIT_TEST_BUILD
OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy,
const std::shared_ptr<AUTH_UTIL>& auth_util, const std::shared_ptr<OKTA_SAML_UTIL>& saml_util)
: CONNECTION_PROXY(dbc, ds) {
this->next_proxy = next_proxy;
this->auth_util = auth_util;
this->saml_util = saml_util;
}
#endif

void OKTA_PROXY::clear_token_cache() {
std::unique_lock<std::mutex> lock(token_cache_mutex);
token_cache.clear();
}

OKTA_SAML_UTIL::OKTA_SAML_UTIL(std::string host) {
this->http_client = std::make_shared<SAML_HTTP_CLIENT>("https://" + host);
}

std::string OKTA_SAML_UTIL::get_saml_url(DataSource* ds) {
const std::string app_id{static_cast<const char*>(ds->opt_APP_ID)};

return "/app/" + std::string(OKTA_AWS_APP_NAME) + "/" + app_id + "/sso/saml";
}

std::string OKTA_SAML_UTIL::get_session_token(DataSource* ds) const {
const std::string username = ds->opt_IDP_USERNAME;
const std::string password = ds->opt_IDP_PASSWORD;

const std::string session_token_endpoint = "/api/v1/authn";
const nlohmann::json request_body = {{"username", username}, {"password", password}};
const nlohmann::json res = this->http_client->post(session_token_endpoint, request_body);
if (res.empty()) {
return "";
}
return res["sessionToken"];
}

std::string OKTA_SAML_UTIL::get_saml_assertion(DataSource* ds) {
const std::string token = this->get_session_token(ds);
const nlohmann::json params = {{"onetimetoken", token}};
const nlohmann::json res = this->http_client->get(this->get_saml_url(ds) + "?onetimetoken=" + token);
const std::string body = std::string(res);
auto f = [body](const std::regex pattern) {
std::smatch m;
if (std::regex_search(body, m, pattern)) {
std::string saml = m.str(1);

saml = OKTA_SAML_UTIL::replace_all(saml, "&#x2b;", "+");
saml = OKTA_SAML_UTIL::replace_all(saml, "&#x3d;", "=");
return saml;
}
return std::string();
};

return f(SAML_RESPONSE_PATTERN);
}

std::string OKTA_SAML_UTIL::replace_all(std::string str, const std::string& from, const std::string& to) {
size_t start_pos = 0;
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
str.replace(start_pos, from.length(), to);
start_pos += to.length();
}
return str;
}
Loading

0 comments on commit 1e47d14

Please sign in to comment.