diff --git a/asyncssh/connection.py b/asyncssh/connection.py index 157a4ae..65abc68 100644 --- a/asyncssh/connection.py +++ b/asyncssh/connection.py @@ -6977,7 +6977,7 @@ async def construct(cls, options: Optional['_OptionsSelf'] = None, loop = asyncio.get_event_loop() return cast(_OptionsSelf, await loop.run_in_executor( - None, functools.partial(cls, options, **kwargs))) + None, functools.partial(cls, options, loop=loop, **kwargs))) # pylint: disable=arguments-differ def prepare(self, config: SSHConfig, # type: ignore @@ -7255,10 +7255,12 @@ class SSHClientConnectionOptions(SSHConnectionOptions): A list of optional certificates which can be paired with the provided client keys. :param passphrase: (optional) - The passphrase to use to decrypt client keys when loading them, - if they are encrypted. If this is not specified, only unencrypted - client keys can be loaded. If the keys passed into client_keys - are already loaded, this argument is ignored. + The passphrase to use to decrypt client keys if they are + encrypted, or a `callable` or coroutine which takes a filename + as a parameter and returns the passphrase to use to decrypt + that file. If not specified, only unencrypted client keys can + be loaded. If the keys passed into client_keys are already + loaded, this argument is ignored. :param ignore_encrypted: (optional) Whether or not to ignore encrypted keys when no passphrase is specified. This defaults to `True` when keys are specified via @@ -7605,7 +7607,9 @@ class SSHClientConnectionOptions(SSHConnectionOptions): max_pktsize: int # pylint: disable=arguments-differ - def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore + def prepare(self, # type: ignore + loop: Optional[asyncio.AbstractEventLoop] = None, + last_config: Optional[SSHConfig] = None, config: DefTuple[ConfigPaths] = None, reload: bool = False, client_factory: Optional[_ClientFactory] = None, client_version: _VersionArg = (), host: str = '', @@ -7761,7 +7765,7 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore self.client_host_keypairs = \ load_keypairs(cast(KeyPairListArg, client_host_keys), - passphrase, client_host_certs) + passphrase, client_host_certs, loop=loop) self.client_host_keysign = client_host_keysign self.client_host = client_host @@ -7839,7 +7843,8 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore if client_keys: self.client_keys = \ load_keypairs(cast(KeyPairListArg, client_keys), passphrase, - client_certs, identities_only, ignore_encrypted) + client_certs, identities_only, ignore_encrypted, + loop=loop) elif client_keys is not None: self.client_keys = load_default_keypairs(passphrase, client_certs) else: @@ -7914,11 +7919,12 @@ class SSHServerConnectionOptions(SSHConnectionOptions): A list of optional certificates which can be paired with the provided server host keys. :param passphrase: (optional) - The passphrase to use to decrypt server host keys when loading - them, if they are encrypted. If this is not specified, only - unencrypted server host keys can be loaded. If the keys passed - into server_host_keys are already loaded, this argument is - ignored. + The passphrase to use to decrypt server host keys if they are + encrypted, or a `callable` or coroutine which takes a filename + as a parameter and returns the passphrase to use to decrypt + that file. If not specified, only unencrypted server host keys + can be loaded. If the keys passed into server_host_keys are + already loaded, this argument is ignored. :param known_client_hosts: (optional) A list of client hosts which should be trusted to perform host-based client authentication. If this is not specified, @@ -8224,7 +8230,9 @@ class SSHServerConnectionOptions(SSHConnectionOptions): max_pktsize: int # pylint: disable=arguments-differ - def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore + def prepare(self, # type: ignore + loop: Optional[asyncio.AbstractEventLoop] = None, + last_config: Optional[SSHConfig] = None, config: DefTuple[ConfigPaths] = None, reload: bool = False, accept_addr: str = '', accept_port: int = 0, username: str = '', client_host: str = '', @@ -8320,7 +8328,7 @@ def prepare(self, last_config: Optional[SSHConfig] = None, # type: ignore config.get('HostCertificate', ())) server_keys = load_keypairs(server_host_keys, passphrase, - server_host_certs) + server_host_certs, loop=loop) self.server_host_keys = OrderedDict() diff --git a/asyncssh/public_key.py b/asyncssh/public_key.py index 335a1c3..8de2048 100644 --- a/asyncssh/public_key.py +++ b/asyncssh/public_key.py @@ -1,4 +1,4 @@ -# Copyright (c) 2013-2023 by Ron Frederick and others. +# Copyright (c) 2013-2024 by Ron Frederick and others. # # This program and the accompanying materials are made available under # the terms of the Eclipse Public License v2.0 which accompanies this @@ -20,7 +20,9 @@ """SSH asymmetric encryption handlers""" +import asyncio import binascii +import inspect import os import re import time @@ -3472,7 +3474,8 @@ def load_keypairs( keylist: KeyPairListArg, passphrase: Optional[BytesOrStr] = None, certlist: CertListArg = (), skip_public: bool = False, ignore_encrypted: bool = False, - unsafe_skip_rsa_key_validation: Optional[bool] = None) -> \ + unsafe_skip_rsa_key_validation: Optional[bool] = None, + loop: Optional[asyncio.AbstractEventLoop] = None) -> \ Sequence[SSHKeyPair]: """Load SSH private keys and optional matching certificates @@ -3521,6 +3524,10 @@ def load_keypairs( else: resolved_passphrase = passphrase + if loop and inspect.isawaitable(resolved_passphrase): + resolved_passphrase = asyncio.run_coroutine_threadsafe( + resolved_passphrase, loop).result() + priv_keys = read_private_key_list(keylist, resolved_passphrase, unsafe_skip_rsa_key_validation) @@ -3559,6 +3566,10 @@ def load_keypairs( else: resolved_passphrase = passphrase + if loop and inspect.isawaitable(resolved_passphrase): + resolved_passphrase = asyncio.run_coroutine_threadsafe( + resolved_passphrase, loop).result() + if allow_certs: key, certs_to_load = read_private_key_and_certs( key_to_load, resolved_passphrase, diff --git a/tests/test_connection_auth.py b/tests/test_connection_auth.py index e93b1da..49ae6b9 100644 --- a/tests/test_connection_auth.py +++ b/tests/test_connection_auth.py @@ -1090,6 +1090,56 @@ async def test_encrypted_client_key(self): passphrase='passphrase'): pass + @asynctest + async def test_encrypted_client_key_callable(self): + """Test public key auth with callable passphrase""" + + def _passphrase(filename): + self.assertEqual(filename, 'ckey_encrypted') + return 'passphrase' + + async with self.connect(username='ckey', client_keys='ckey_encrypted', + passphrase=_passphrase): + pass + + @asynctest + async def test_encrypted_client_key_awaitable(self): + """Test public key auth with awaitable passphrase""" + + async def _passphrase(filename): + self.assertEqual(filename, 'ckey_encrypted') + return 'passphrase' + + async with self.connect(username='ckey', client_keys='ckey_encrypted', + passphrase=_passphrase): + pass + + @asynctest + async def test_encrypted_client_key_list_callable(self): + """Test public key auth with callable passphrase""" + + def _passphrase(filename): + self.assertEqual(filename, 'ckey_encrypted') + return 'passphrase' + + async with self.connect(username='ckey', + client_keys=['ckey_encrypted'], + passphrase=_passphrase): + pass + + @asynctest + async def test_encrypted_client_key_list_awaitable(self): + """Test public key auth with awaitable passphrase""" + + async def _passphrase(filename): + self.assertEqual(filename, 'ckey_encrypted') + return 'passphrase' + + async with self.connect(username='ckey', + client_keys=['ckey_encrypted'], + passphrase=_passphrase): + pass + @asynctest async def test_encrypted_client_key_bad_passphrase(self): """Test wrong passphrase for encrypted client key""" diff --git a/tests/test_public_key.py b/tests/test_public_key.py index a776cf3..70dbf73 100644 --- a/tests/test_public_key.py +++ b/tests/test_public_key.py @@ -252,10 +252,6 @@ def validate_x509(self, cert, user_principal=None): def check_private(self, format_name, passphrase=None): """Check for a private key match""" - def _passphrase(filename): - self.assertEqual(filename, 'new') - return passphrase - newkey = asyncssh.read_private_key('new', passphrase) algorithm = newkey.get_algorithm() keydata = newkey.export_private_key() @@ -279,9 +275,6 @@ def _passphrase(filename): keypair = asyncssh.load_keypairs('new', passphrase)[0] self.assertEqual(keypair.public_data, pubdata) - keypair = asyncssh.load_keypairs('new', _passphrase)[0] - self.assertEqual(keypair.public_data, pubdata) - keypair = asyncssh.load_keypairs([newkey])[0] self.assertEqual(keypair.public_data, pubdata) @@ -297,15 +290,9 @@ def _passphrase(filename): keypair = asyncssh.load_keypairs(['new'], passphrase)[0] self.assertEqual(keypair.public_data, pubdata) - keypair = asyncssh.load_keypairs(['new'], _passphrase)[0] - self.assertEqual(keypair.public_data, pubdata) - keypair = asyncssh.load_keypairs([('new', None)], passphrase)[0] self.assertEqual(keypair.public_data, pubdata) - keypair = asyncssh.load_keypairs([('new', None)], _passphrase)[0] - self.assertEqual(keypair.public_data, pubdata) - keypair = asyncssh.load_keypairs(Path('new'), passphrase)[0] self.assertEqual(keypair.public_data, pubdata)