From b21e758d30b51205c8140116f9335abb69e4bbcc Mon Sep 17 00:00:00 2001 From: Ron Frederick Date: Mon, 15 Jul 2024 14:30:33 -0700 Subject: [PATCH] Fix logging and typing issues in SFTP high-level copy functions This commit fixes logging and typing issues with SFTP get, put, copy, mget, mput, and mcopy functions. AsyncSSH should now properly handle sequences which mix bytes, str, and PurePath entries and also fixes type annotations for these functions to indicate that they accept either a single path or a list of paths. Thanks go to GitHub user eyalgolan1337 for reporting these issues! --- asyncssh/logging.py | 34 ++++++++++++++++++++++------------ asyncssh/sftp.py | 22 ++++++++++------------ 2 files changed, 32 insertions(+), 24 deletions(-) diff --git a/asyncssh/logging.py b/asyncssh/logging.py index 828d85cd..d28eb3bf 100644 --- a/asyncssh/logging.py +++ b/asyncssh/logging.py @@ -61,14 +61,28 @@ def get_child(self, child: str = '', context: str = '') -> 'SSHLogger': def log(self, level: int, msg: object, *args, **kwargs) -> None: """Log a message to the underlying logger""" - def _text(arg: _LogArg) -> str: + def _item_text(item: _LogArg) -> str: + """Convert a list item to text""" + + if isinstance(item, bytes): + result = item.decode('utf-8', errors='replace') + + if not result.isprintable(): + result = repr(result)[1:-1] + elif not isinstance(item, str): + result = str(item) + else: + result = item + + return result + + def _text(arg: _LogArg) -> _LogArg: """Convert a log argument to text""" + result: _LogArg + if isinstance(arg, list): - if arg and isinstance(arg[0], bytes): - result = b','.join(arg).decode('utf-8', errors='replace') - else: - result = ','.join(arg) + result = ','.join(_item_text(item) for item in arg) elif isinstance(arg, tuple): host, port = arg @@ -76,14 +90,10 @@ def _text(arg: _LogArg) -> str: result = '%s, port %d' % (host, port) if port else host else: result = 'port %d' % port if port else 'dynamic port' + elif isinstance(arg, bytes): + result = _item_text(arg) else: - result = cast(str, arg) - - if isinstance(result, bytes): - result = result.decode('ascii', errors='backslashreplace') - - if not result.isprintable(): - result = repr(result)[1:-1] + result = arg return result diff --git a/asyncssh/sftp.py b/asyncssh/sftp.py index 31570965..d3498fc4 100644 --- a/asyncssh/sftp.py +++ b/asyncssh/sftp.py @@ -3679,8 +3679,7 @@ async def _copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, raise async def _begin_copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, - srcpaths: Sequence[_SFTPPath], - dstpath: Optional[_SFTPPath], + srcpaths: _SFTPPaths, dstpath: Optional[_SFTPPath], copy_type: str, expand_glob: bool, preserve: bool, recurse: bool, follow_symlinks: bool, block_size: int, max_requests: int, @@ -3688,15 +3687,14 @@ async def _begin_copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, error_handler: SFTPErrorHandler) -> None: """Begin a new file upload, download, or copy""" - if isinstance(srcpaths, tuple): + if isinstance(srcpaths, (bytes, str, PurePath)): + srcpaths = [srcpaths] + elif not isinstance(srcpaths, list): srcpaths = list(srcpaths) self.logger.info('Starting SFTP %s of %s to %s', copy_type, srcpaths, dstpath) - if isinstance(srcpaths, (bytes, str, PurePath)): - srcpaths = [srcpaths] - srcnames: List[SFTPName] = [] if expand_glob: @@ -3741,7 +3739,7 @@ async def _begin_copy(self, srcfs: _SFTPFSProtocol, dstfs: _SFTPFSProtocol, preserve, recurse, follow_symlinks, block_size, max_requests, progress_handler, error_handler) - async def get(self, remotepaths: Sequence[_SFTPPath], + async def get(self, remotepaths: _SFTPPaths, localpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, @@ -3846,7 +3844,7 @@ async def get(self, remotepaths: Sequence[_SFTPPath], block_size, max_requests, progress_handler, error_handler) - async def put(self, localpaths: Sequence[_SFTPPath], + async def put(self, localpaths: _SFTPPaths, remotepath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, @@ -3951,7 +3949,7 @@ async def put(self, localpaths: Sequence[_SFTPPath], block_size, max_requests, progress_handler, error_handler) - async def copy(self, srcpaths: Sequence[_SFTPPath], + async def copy(self, srcpaths: _SFTPPaths, dstpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, @@ -4056,7 +4054,7 @@ async def copy(self, srcpaths: Sequence[_SFTPPath], block_size, max_requests, progress_handler, error_handler) - async def mget(self, remotepaths: Sequence[_SFTPPath], + async def mget(self, remotepaths: _SFTPPaths, localpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, @@ -4080,7 +4078,7 @@ async def mget(self, remotepaths: Sequence[_SFTPPath], block_size, max_requests, progress_handler, error_handler) - async def mput(self, localpaths: Sequence[_SFTPPath], + async def mput(self, localpaths: _SFTPPaths, remotepath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False, @@ -4104,7 +4102,7 @@ async def mput(self, localpaths: Sequence[_SFTPPath], block_size, max_requests, progress_handler, error_handler) - async def mcopy(self, srcpaths: Sequence[_SFTPPath], + async def mcopy(self, srcpaths: _SFTPPaths, dstpath: Optional[_SFTPPath] = None, *, preserve: bool = False, recurse: bool = False, follow_symlinks: bool = False,