Skip to content

Commit

Permalink
Fix for keyring errors when initializing Flyte for_sandbox config cli…
Browse files Browse the repository at this point in the history
…ent (#2962)

Signed-off-by: taieeuu <taieeuu@gmail.com>
  • Loading branch information
taieeuu authored Feb 7, 2025
1 parent c9bfbd6 commit ef6d7d4
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 11 deletions.
14 changes: 10 additions & 4 deletions flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,12 @@ def upgrade_channel_to_proxy_authenticated(cfg: PlatformConfig, in_channel: grpc
:param in_channel: grpc.Channel Precreated channel
:return: grpc.Channel. New composite channel
"""

def authenticator_factory():
return get_proxy_authenticator(cfg)

if cfg.proxy_command:
proxy_authenticator = get_proxy_authenticator(cfg)
return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(proxy_authenticator))
return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator_factory))
else:
return in_channel

Expand All @@ -137,8 +140,11 @@ def upgrade_channel_to_authenticated(cfg: PlatformConfig, in_channel: grpc.Chann
:param in_channel: grpc.Channel Precreated channel
:return: grpc.Channel. New composite channel
"""
authenticator = get_authenticator(cfg, RemoteClientConfigStore(in_channel))
return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator))

def authenticator_factory():
return get_authenticator(cfg, RemoteClientConfigStore(in_channel))

return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator_factory))


def get_authenticated_channel(cfg: PlatformConfig) -> grpc.Channel:
Expand Down
15 changes: 12 additions & 3 deletions flytekit/clients/grpc_utils/auth_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,22 @@ class AuthUnaryInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamCli
is needed.
"""

def __init__(self, authenticator: Authenticator):
self._authenticator = authenticator
def __init__(self, get_authenticator: typing.Callable[[], Authenticator]):
self._get_authenticator = get_authenticator
self._authenticator = None

@property
def authenticator(self) -> Authenticator:
if self._authenticator is None:
self._authenticator = self._get_authenticator()
return self._authenticator

def _call_details_with_auth_metadata(self, client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails:
"""
Returns new ClientCallDetails with metadata added.
"""
metadata = client_call_details.metadata
auth_metadata = self._authenticator.fetch_grpc_call_auth_metadata()
auth_metadata = self.authenticator.fetch_grpc_call_auth_metadata()
if auth_metadata:
metadata = []
if client_call_details.metadata:
Expand Down Expand Up @@ -65,6 +72,7 @@ def intercept_unary_unary(
raise e
if e.code() == grpc.StatusCode.UNAUTHENTICATED or e.code() == grpc.StatusCode.UNKNOWN:
self._authenticator.refresh_credentials()
self.authenticator.refresh_credentials()
updated_call_details = self._call_details_with_auth_metadata(client_call_details)
return continuation(updated_call_details, request)
return fut
Expand All @@ -77,6 +85,7 @@ def intercept_unary_stream(self, continuation, client_call_details, request):
c: grpc.Call = continuation(updated_call_details, request)
if c.code() == grpc.StatusCode.UNAUTHENTICATED:
self._authenticator.refresh_credentials()
self.authenticator.refresh_credentials()
updated_call_details = self._call_details_with_auth_metadata(client_call_details)
return continuation(updated_call_details, request)
return c
40 changes: 40 additions & 0 deletions tests/flytekit/unit/clients/auth/test_keyring_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@

from flytekit.clients.auth.keyring import Credentials, KeyringStore

from flytekit.clients.auth_helper import upgrade_channel_to_authenticated, upgrade_channel_to_proxy_authenticated

from flytekit.configuration import PlatformConfig

import pytest

from flytekit.clients.auth.authenticator import CommandAuthenticator

from flytekit.clients.grpc_utils.auth_interceptor import AuthUnaryInterceptor

@patch("keyring.get_password")
def test_keyring_store_get(kr_get_password: MagicMock):
Expand All @@ -30,3 +39,34 @@ def test_keyring_store_set(kr_set_password: MagicMock):

kr_set_password.side_effect = NoKeyringError()
assert KeyringStore.retrieve("example2.com") is None

@patch("flytekit.clients.auth.authenticator.KeyringStore")
def test_upgrade_channel_to_authenticated_with_keyring_exception(mock_keyring_store):
mock_keyring_store.retrieve.side_effect = Exception("mock exception")

mock_channel = MagicMock()

platform_config = PlatformConfig()

try:
out_ch = upgrade_channel_to_authenticated(platform_config, mock_channel)
except Exception as e:
pytest.fail(f"upgrade_channel_to_authenticated Exception: {e}")

assert isinstance(out_ch._interceptor, AuthUnaryInterceptor)

@patch("flytekit.clients.auth.authenticator.KeyringStore")
def test_upgrade_channel_to_proxy_authenticated_with_keyring_exception(mock_keyring_store):
mock_keyring_store.retrieve.side_effect = Exception("mock exception")

mock_channel = MagicMock()

platform_config = PlatformConfig(auth_mode="Pkce", proxy_command=["echo", "foo-bar"])

try:
out_ch = upgrade_channel_to_proxy_authenticated(platform_config, mock_channel)
except Exception as e:
pytest.fail(f"upgrade_channel_to_proxy_authenticated Exception: {e}")

assert isinstance(out_ch._interceptor, AuthUnaryInterceptor)
assert isinstance(out_ch._interceptor.authenticator, CommandAuthenticator)
2 changes: 1 addition & 1 deletion tests/flytekit/unit/clients/test_auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def test_upgrade_channel_to_proxy_auth():
ch,
)
assert isinstance(out_ch._interceptor, AuthUnaryInterceptor)
assert isinstance(out_ch._interceptor._authenticator, CommandAuthenticator)
assert isinstance(out_ch._interceptor.authenticator, CommandAuthenticator)


def test_get_proxy_authenticated_session():
Expand Down
1 change: 0 additions & 1 deletion tests/flytekit/unit/clients/test_friendly.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from flytekit.configuration import PlatformConfig
from flytekit.models.project import Project as _Project


@mock.patch("flytekit.clients.friendly._RawSynchronousFlyteClient.update_project")
def test_update_project(mock_raw_update_project):
client = _SynchronousFlyteClient(PlatformConfig.for_endpoint("a.b.com", True))
Expand Down
2 changes: 0 additions & 2 deletions tests/flytekit/unit/clients/test_raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from flytekit.clients.raw import RawSynchronousFlyteClient
from flytekit.configuration import PlatformConfig


@mock.patch("flytekit.clients.raw._admin_service")
@mock.patch("flytekit.clients.raw.grpc.insecure_channel")
def test_update_project(mock_channel, mock_admin):
Expand All @@ -14,7 +13,6 @@ def test_update_project(mock_channel, mock_admin):
client.update_project(project)
mock_admin.AdminServiceStub().UpdateProject.assert_called_with(project, metadata=None)


@mock.patch("flytekit.clients.raw._admin_service")
@mock.patch("flytekit.clients.raw.grpc.insecure_channel")
def test_list_projects_paginated(mock_channel, mock_admin):
Expand Down

0 comments on commit ef6d7d4

Please sign in to comment.