Skip to content

Commit

Permalink
refactor: fix typo and clean up interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
karenc-bq committed Aug 16, 2024
1 parent 1e47d14 commit 5c14514
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 64 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/failover.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ jobs:
unzip -d C:/mysql-${{ vars.MYSQL_VERSION }}-winx64-debug mysql-debug.zip
mv -Force C:/mysql-${{ vars.MYSQL_VERSION }}-winx64-debug/mysql-${{ vars.MYSQL_VERSION }}-winx64/lib/debug/mysqlclient.lib C:/mysql-${{ vars.MYSQL_VERSION }}-winx64/lib/mysqlclient.lib
- name: Install OpenSSL 3
run: |
curl -L https://download.firedaemon.com/FireDaemon-OpenSSL/openssl-3.3.1.zip -o openssl3.zip
unzip -d C:/ openssl3.zip
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v2

Expand Down Expand Up @@ -71,6 +76,7 @@ jobs:
-DMYSQLCLIENT_STATIC_LINKING=TRUE
-DENABLE_UNIT_TESTS=TRUE
-DENABLE_INTEGRATION_TESTS=FALSE
-DOPENSSL_INCLUDE_DIR="C:/openssl-3/x64/include/"

# Configure test environment
- name: Build Driver
Expand Down
14 changes: 0 additions & 14 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,6 @@ name: Integration Tests

on:
workflow_dispatch:
push:
branches:
- main
pull_request:
branches:
- '*'
paths-ignore:
- '**/*.md'
- '**/*.jpg'
- '**/README.txt'
- '**/LICENSE.txt'
- 'docs/**'
- 'ISSUE_TEMPLATE/**'
- '**/remove-old-artifacts.yml'

env:
BUILD_TYPE: Release
Expand Down
14 changes: 10 additions & 4 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ jobs:
curl -L https://dev.mysql.com/get/Downloads/MySQL-8.3/mysql-${{ vars.MYSQL_VERSION }}-winx64.zip -o mysql.zip
unzip -d C:/ mysql.zip
- name: Install OpenSSL 3
run: |
curl -L https://download.firedaemon.com/FireDaemon-OpenSSL/openssl-3.3.1.zip -o openssl3.zip
unzip -d C:/ openssl3.zip
- name: Add msbuild to PATH
uses: microsoft/setup-msbuild@v2

Expand Down Expand Up @@ -66,6 +71,7 @@ jobs:
-DMYSQL_SQL="C:/mysql-${{ vars.MYSQL_VERSION }}-winx64"
-DCMAKE_BUILD_TYPE=$BUILD_TYPE
-DMYSQLCLIENT_STATIC_LINKING=TRUE
-DOPENSSL_INCLUDE_DIR="C:/openssl-3/x64/include/"

# Configure test environment
- name: Build Driver and Copy files
Expand Down Expand Up @@ -157,11 +163,11 @@ jobs:
brew update
brew unlink unixodbc
brew install libiodbc mysql@8.3 mysql-client@8.3
brew install libiodbc mysql@8.4 mysql-client@8.4
brew link --overwrite --force libiodbc
brew link --overwrite --force mysql@8.3
echo 'export PATH="/usr/local/opt/mysql@8.3/bin:$PATH"' >> /Users/runner/.bash_profile
echo 'export PATH="/usr/local/opt/mysql-client@8.3/bin:$PATH"' >> /Users/runner/.bash_profile
brew link --overwrite --force mysql@8.4
echo 'export PATH="/usr/local/opt/mysql@8.4/bin:$PATH"' >> /Users/runner/.bash_profile
echo 'export PATH="/usr/local/opt/mysql-client@8.4/bin:$PATH"' >> /Users/runner/.bash_profile
brew install openssl@3
rm -f /usr/local/lib/libssl.3.dylib
Expand Down
26 changes: 16 additions & 10 deletions driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})

SET(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} ${ODBC_DRIVER_LINK_FLAGS}")
TARGET_LINK_LIBRARIES(${DRIVER_NAME}
${MYSQL_CLIENT_LIBS} ${CMAKE_THREAD_LIBS_INIT} m myodbc-util otel_ap)
${MYSQL_CLIENT_LIBS} ${CMAKE_THREAD_LIBS_INIT} m myodbc-util otel_api)
TARGET_LINK_LIBRARIES(${DRIVER_NAME_STATIC}
${MYSQL_CLIENT_LIBS} ${CMAKE_THREAD_LIBS_INIT} m myodbc-util otel_api)

Expand Down Expand Up @@ -301,20 +301,26 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})

MATH(EXPR DRIVER_INDEX "${DRIVER_INDEX} + 1")

#------------DEPENDENCIES FOR FEDERATED AUTH---------
include(FetchContent)

FetchContent_Declare(
json
URL https://github.com/nlohmann/json/releases/download/v3.10.5/json.tar.xz
json
URL https://github.com/nlohmann/json/releases/download/v3.10.5/json.tar.xz
)

FetchContent_Declare(
httplib
URL https://github.com/yhirose/cpp-httplib/archive/refs/tags/v0.14.0.zip
httplib
URL https://github.com/yhirose/cpp-httplib/archive/refs/tags/v0.16.1.zip
)

FetchContent_MakeAvailable(httplib json)


TARGET_INCLUDE_DIRECTORIES(${DRIVER_NAME} PUBLIC "${httplib_SOURCE_DIR}" ${OPENSSL_INCLUDE_DIR})
TARGET_LINK_LIBRARIES(${DRIVER_NAME} nlohmann_json::nlohmann_json)
TARGET_INCLUDE_DIRECTORIES(${DRIVER_NAME_STATIC} PUBLIC "${httplib_SOURCE_DIR}" ${OPENSSL_INCLUDE_DIR})
TARGET_LINK_LIBRARIES(${DRIVER_NAME_STATIC} nlohmann_json::nlohmann_json)

#------------AWS SDK------------------
LIST(APPEND SERVICE_LIST rds secretsmanager sts)

Expand All @@ -324,10 +330,10 @@ WHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
AWSSDK_CPY_DYN_LIBS(SERVICE_LIST "" ${LIBRARY_OUTPUT_PATH}/${CMAKE_BUILD_TYPE})
ENDIF(MSVC)

TARGET_INCLUDE_DIRECTORIES(${DRIVER_NAME} PUBLIC "${httplib_SOURCE_DIR}" ${AWSSDK_INCLUDE_DIR} ${OPENSSL_INCLUDE_DIR})
TARGET_LINK_LIBRARIES(${DRIVER_NAME} nlohmann_json::nlohmann_json ${AWSSDK_LINK_LIBRARIES})
TARGET_INCLUDE_DIRECTORIES(${DRIVER_NAME_STATIC} PUBLIC "${httplib_SOURCE_DIR}" ${AWSSDK_INCLUDE_DIR} ${OPENSSL_INCLUDE_DIR})
TARGET_LINK_LIBRARIES(${DRIVER_NAME_STATIC} nlohmann_json::nlohmann_json ${AWSSDK_LINK_LIBRARIES})
TARGET_INCLUDE_DIRECTORIES(${DRIVER_NAME} PUBLIC ${AWSSDK_INCLUDE_DIR})
TARGET_LINK_LIBRARIES(${DRIVER_NAME} ${AWSSDK_LINK_LIBRARIES})
TARGET_INCLUDE_DIRECTORIES(${DRIVER_NAME_STATIC} PUBLIC ${AWSSDK_INCLUDE_DIR})
TARGET_LINK_LIBRARIES(${DRIVER_NAME_STATIC} ${AWSSDK_LINK_LIBRARIES})
#------------------------------

ENDWHILE(${DRIVER_INDEX} LESS ${DRIVERS_COUNT})
50 changes: 22 additions & 28 deletions driver/okta_proxy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,49 +48,45 @@ OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy) :

bool OKTA_PROXY::connect(const char* host, const char* user, const char* password, const char* database,
unsigned int port, const char* socket, unsigned long flags) {
auto f =
std::bind(&CONNECTION_PROXY::connect, next_proxy, host, user, std::placeholders::_1, database, port, socket, flags);
auto f = std::bind(&CONNECTION_PROXY::connect, next_proxy, host, user, std::placeholders::_1, database, port, socket,
flags);
return invoke_func_with_fed_credentials(f);
}

bool OKTA_PROXY::invoke_func_with_fed_credentials(std::function<bool(const char*)> func) {
const char* region = ds->opt_AUTH_REGION ? static_cast<const char*>(ds->opt_AUTH_REGION) : Aws::Region::US_EAST_1;
const std::string assertion = this->saml_util->get_saml_assertion(ds);
const char* idp_host = static_cast<const char*>(ds->opt_IDP_ENDPOINT);
const char* iam_role_arn = static_cast<const char*>(ds->opt_IAM_ROLE_ARN);
const char* idp_arn = static_cast<const char*>(ds->opt_IAM_IDP_ARN);
auto idp_host = static_cast<const char*>(ds->opt_IDP_ENDPOINT);
auto iam_role_arn = static_cast<const char*>(ds->opt_IAM_ROLE_ARN);
auto idp_arn = static_cast<const char*>(ds->opt_IAM_IDP_ARN);
const Aws::Auth::AWSCredentials credentials =
this->saml_util->get_aws_credentials(idp_host, region, iam_role_arn, idp_arn, assertion);
this->saml_util->get_aws_credentials(idp_host, region, iam_role_arn, idp_arn, assertion);
this->auth_util = std::make_shared<AUTH_UTIL>(region, credentials);

const char* AUTH_HOST =
ds->opt_AUTH_HOST ? static_cast<const char*>(ds->opt_AUTH_HOST) : static_cast<const char*>(ds->opt_SERVER);
ds->opt_AUTH_HOST ? static_cast<const char*>(ds->opt_AUTH_HOST) : static_cast<const char*>(ds->opt_SERVER);
int iam_port = ds->opt_AUTH_PORT;
if (iam_port == UNDEFINED_PORT) {
// Use regular port if user does not provide IAM port
iam_port = ds->opt_PORT;
}

std::string auth_token = this->auth_util->get_auth_token(AUTH_HOST, region, iam_port, (const char*)ds->opt_UID);
std::string auth_token = this->auth_util->get_auth_token(AUTH_HOST, region, iam_port, ds->opt_UID);

bool connect_result = func(auth_token.c_str());
if (!connect_result) {
if (using_cached_token) {
// Retry func with a fresh token
auth_token = this->auth_util->get_auth_token(AUTH_HOST, region, iam_port, (const char*)ds->opt_UID);
auth_token = this->auth_util->get_auth_token(AUTH_HOST, region, iam_port, ds->opt_UID);
if (func(auth_token.c_str())) {
return true;
}
}

Aws::Auth::DefaultAWSCredentialsProviderChain credentials_provider;
Aws::Auth::AWSCredentials credentials = credentials_provider.GetAWSCredentials();
if (credentials.IsEmpty()) {
this->set_custom_error_message(
"Could not find AWS Credentials for IAM Authentication. Please set up AWS credentials.");
} else if (credentials.IsExpired()) {
this->set_custom_error_message(
"AWS Credentials for IAM Authentication are expired. Please refresh AWS credentials.");
"Unable to generate temporary AWS credentials from the SAML assertion. Please ensure the Okta identity "
"provider is correctly configured with AWS.");
}
}

Expand All @@ -105,7 +101,7 @@ OKTA_PROXY::~OKTA_PROXY() {
#ifdef UNIT_TEST_BUILD
OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy,
const std::shared_ptr<AUTH_UTIL>& auth_util, const std::shared_ptr<OKTA_SAML_UTIL>& saml_util)
: CONNECTION_PROXY(dbc, ds) {
: CONNECTION_PROXY(dbc, ds) {
this->next_proxy = next_proxy;
this->auth_util = auth_util;
this->saml_util = saml_util;
Expand All @@ -122,14 +118,14 @@ OKTA_SAML_UTIL::OKTA_SAML_UTIL(std::string host) {
}

std::string OKTA_SAML_UTIL::get_saml_url(DataSource* ds) {
const std::string app_id{static_cast<const char*>(ds->opt_APP_ID)};
const std::string app_id{(ds->opt_APP_ID)};

return "/app/" + std::string(OKTA_AWS_APP_NAME) + "/" + app_id + "/sso/saml";
}

std::string OKTA_SAML_UTIL::get_session_token(DataSource* ds) const {
const std::string username = ds->opt_IDP_USERNAME;
const std::string password = ds->opt_IDP_PASSWORD;
const std::string username = static_cast<const char*>(ds->opt_IDP_USERNAME);
const std::string password = static_cast<const char*>(ds->opt_IDP_PASSWORD);

const std::string session_token_endpoint = "/api/v1/authn";
const nlohmann::json request_body = {{"username", username}, {"password", password}};
Expand All @@ -142,16 +138,14 @@ std::string OKTA_SAML_UTIL::get_session_token(DataSource* ds) const {

std::string OKTA_SAML_UTIL::get_saml_assertion(DataSource* ds) {
const std::string token = this->get_session_token(ds);
const nlohmann::json params = {{"onetimetoken", token}};
const nlohmann::json res = this->http_client->get(this->get_saml_url(ds) + "?onetimetoken=" + token);
const std::string body = std::string(res);
auto f = [body](const std::regex pattern) {
std::smatch m;
if (std::regex_search(body, m, pattern)) {
const auto body = std::string(res);
auto f = [body](const std::regex& pattern) {
if (std::smatch m; std::regex_search(body, m, pattern)) {
std::string saml = m.str(1);

saml = OKTA_SAML_UTIL::replace_all(saml, "&#x2b;", "+");
saml = OKTA_SAML_UTIL::replace_all(saml, "&#x3d;", "=");
saml = replace_all(saml, "&#x2b;", "+");
saml = replace_all(saml, "&#x3d;", "=");
return saml;
}
return std::string();
Expand All @@ -163,8 +157,8 @@ std::string OKTA_SAML_UTIL::get_saml_assertion(DataSource* ds) {
std::string OKTA_SAML_UTIL::replace_all(std::string str, const std::string& from, const std::string& to) {
size_t start_pos = 0;
while ((start_pos = str.find(from, start_pos)) != std::string::npos) {
str.replace(start_pos, from.length(), to);
str = str.replace(start_pos, from.length(), to);
start_pos += to.length();
}
return str;
}
}
4 changes: 2 additions & 2 deletions driver/okta_proxy.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ class OKTA_PROXY : public CONNECTION_PROXY {
OKTA_PROXY(DBC* dbc, DataSource* ds);
OKTA_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy);
#ifdef UNIT_TEST_BUILD
OKTA_PROXY::OKTA_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy,
const std::shared_ptr<AUTH_UTIL>& auth_util, const std::shared_ptr<OKTA_SAML_UTIL>& saml_util);
OKTA_PROXY(DBC* dbc, DataSource* ds, CONNECTION_PROXY* next_proxy,
const std::shared_ptr<AUTH_UTIL>& auth_util, const std::shared_ptr<OKTA_SAML_UTIL>& saml_util);
#endif
~OKTA_PROXY() override;

Expand Down
2 changes: 1 addition & 1 deletion driver/saml_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
// http://www.gnu.org/licenses/gpl-2.0.html.

#include "saml_util.h"
#include <aws/core/Auth/AWSCredentials.h>
#include <aws/core/auth/AWSCredentials.h>
#include <aws/sts/STSClient.h>
#include <aws/sts/model/AssumeRoleWithSAMLRequest.h>
#include <aws/sts/model/AssumeRoleWithSAMLResult.h>
Expand Down
2 changes: 1 addition & 1 deletion driver/saml_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
class SAML_UTIL {
public:
SAML_UTIL() = default;
~SAML_UTIL() = default;
virtual ~SAML_UTIL() = default;
Aws::Auth::AWSCredentials get_aws_credentials(const char* host, const char* region, const char* role_arn,
const char* idp_arn, const std::string& assertion);
virtual std::string get_saml_assertion(DataSource* ds) = 0;
Expand Down
4 changes: 0 additions & 4 deletions installer/myodbc-installer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,6 @@ int list_datasource_details(DataSource *ds)
OPTION_CUSTOM_OUTPUT(HOST_PATTERN, "Failover Instance Host Pattern");
OPTION_CUSTOM_OUTPUT(CLUSTER_ID, "Failover Cluster ID");
OPTION_CUSTOM_OUTPUT(FAILOVER_MODE, "Failover Mode");
OPTION_CUSTOM_OUTPUT(ENABLE_CLUSTER_FAILOVER, "Enable Cluster Failover");

/* AWS Authentication */


bool bool_mode = false;

Expand Down

0 comments on commit 5c14514

Please sign in to comment.