Skip to content

Commit

Permalink
Add support for passphrase to be a callable
Browse files Browse the repository at this point in the history
This commit allows the passphrase argument in load_private_keys and
connection options to be a callable which accepts the filename of the
key being decrypted and returns the passphrase to use. Thanks go to
GitHub user goblin for suggesting this enhancement.
  • Loading branch information
ronf committed Mar 23, 2024
1 parent f4df7f4 commit 56e533b
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 5 deletions.
26 changes: 21 additions & 5 deletions asyncssh/public_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -3485,7 +3485,8 @@ def load_keypairs(
:param keylist:
The list of private keys and certificates to load.
:param passphrase: (optional)
The passphrase to use to decrypt private keys.
The passphrase to use to decrypt the keys, or a `callable` which
takes a filename and returns the passphrase to decrypt it.
:param certlist: (optional)
A list of certificates to attempt to pair with the provided
list of private keys.
Expand Down Expand Up @@ -3515,9 +3516,19 @@ def load_keypairs(

if isinstance(keylist, (PurePath, str)):
try:
priv_keys = read_private_key_list(keylist, passphrase,
if callable(passphrase):
resolved_passphrase = passphrase(str(keylist))
else:
resolved_passphrase = passphrase

priv_keys = read_private_key_list(keylist, resolved_passphrase,
unsafe_skip_rsa_key_validation)
keys_to_load = [keylist] if len(priv_keys) <= 1 else priv_keys

if len(priv_keys) <= 1:
keys_to_load = [keylist]
passphrase = resolved_passphrase
else:
keys_to_load = priv_keys
except KeyImportError:
keys_to_load = [keylist]
elif isinstance(keylist, (tuple, bytes, SSHKey, SSHKeyPair)):
Expand All @@ -3543,15 +3554,20 @@ def load_keypairs(
if isinstance(key_to_load, (PurePath, str)):
key_prefix = str(key_to_load)

if callable(passphrase):
resolved_passphrase = passphrase(key_prefix)
else:
resolved_passphrase = passphrase

if allow_certs:
key, certs_to_load = read_private_key_and_certs(
key_to_load, passphrase,
key_to_load, resolved_passphrase,
unsafe_skip_rsa_key_validation)

if not certs_to_load:
certs_to_load = key_prefix + '-cert.pub'
else:
key = read_private_key(key_to_load, passphrase,
key = read_private_key(key_to_load, resolved_passphrase,
unsafe_skip_rsa_key_validation)

pubkey_to_load = key_prefix + '.pub'
Expand Down
13 changes: 13 additions & 0 deletions tests/test_public_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,10 @@ def validate_x509(self, cert, user_principal=None):
def check_private(self, format_name, passphrase=None):
"""Check for a private key match"""

def _passphrase(filename):
self.assertEqual(filename, 'new')
return passphrase

newkey = asyncssh.read_private_key('new', passphrase)
algorithm = newkey.get_algorithm()
keydata = newkey.export_private_key()
Expand All @@ -275,6 +279,9 @@ def check_private(self, format_name, passphrase=None):
keypair = asyncssh.load_keypairs('new', passphrase)[0]
self.assertEqual(keypair.public_data, pubdata)

keypair = asyncssh.load_keypairs('new', _passphrase)[0]
self.assertEqual(keypair.public_data, pubdata)

keypair = asyncssh.load_keypairs([newkey])[0]
self.assertEqual(keypair.public_data, pubdata)

Expand All @@ -290,9 +297,15 @@ def check_private(self, format_name, passphrase=None):
keypair = asyncssh.load_keypairs(['new'], passphrase)[0]
self.assertEqual(keypair.public_data, pubdata)

keypair = asyncssh.load_keypairs(['new'], _passphrase)[0]
self.assertEqual(keypair.public_data, pubdata)

keypair = asyncssh.load_keypairs([('new', None)], passphrase)[0]
self.assertEqual(keypair.public_data, pubdata)

keypair = asyncssh.load_keypairs([('new', None)], _passphrase)[0]
self.assertEqual(keypair.public_data, pubdata)

keypair = asyncssh.load_keypairs(Path('new'), passphrase)[0]
self.assertEqual(keypair.public_data, pubdata)

Expand Down

0 comments on commit 56e533b

Please sign in to comment.