From 24b60d131238a707895ba82722bae6752a60a846 Mon Sep 17 00:00:00 2001 From: Ron Frederick Date: Tue, 2 Jul 2024 06:03:05 -0700 Subject: [PATCH] Fix flow control for AsyncFileWriter and StreamWriter This commit is a follow up to f2020ed, adding proper back pressure when output is redirected to an AsyncFileWriter or StreamWriter and data is arriving on the SSH channel faster than these writers can consume it. Once the queue of outstanding data begins to grow, reading from the SSH channel will be paused to allow the queue to drain somewhat before continuing, limiting the amount of memory needed. --- asyncssh/process.py | 32 +++++++++++++++--- tests/test_process.py | 79 +++++++++++++++++++++++++++---------------- 2 files changed, 77 insertions(+), 34 deletions(-) diff --git a/asyncssh/process.py b/asyncssh/process.py index 811b4f3..d1ce484 100644 --- a/asyncssh/process.py +++ b/asyncssh/process.py @@ -65,6 +65,10 @@ MaybeAwait[None]] +_QUEUE_LOW_WATER = 8 +_QUEUE_HIGH_WATER = 16 + + class _AsyncFileProtocol(Protocol[AnyStr]): """Protocol for an async file""" @@ -304,12 +308,14 @@ class _AsyncFileWriter(_UnicodeWriter[AnyStr]): def __init__(self, process: 'SSHProcess[AnyStr]', file: _AsyncFileProtocol[bytes], needs_close: bool, - encoding: Optional[str], errors: str): + datatype: Optional[int], encoding: Optional[str], errors: str): super().__init__(encoding, errors, hasattr(file, 'encoding')) self._process: 'SSHProcess[AnyStr]' = process self._file = file self._needs_close = needs_close + self._datatype = datatype + self._paused = False self._queue: asyncio.Queue[Optional[AnyStr]] = asyncio.Queue() self._write_task: Optional[asyncio.Task[None]] = \ process.channel.get_connection().create_task(self._writer()) @@ -327,6 +333,10 @@ async def _writer(self) -> None: await self._file.write(self.encode(data)) self._queue.task_done() + if self._paused and self._queue.qsize() < _QUEUE_LOW_WATER: + self._process.resume_feeding(self._datatype) + self._paused = False + if self._needs_close: await self._file.close() @@ -335,6 +345,10 @@ def write(self, data: AnyStr) -> None: self._queue.put_nowait(data) + if not self._paused and self._queue.qsize() >= _QUEUE_HIGH_WATER: + self._paused = True + self._process.pause_feeding(self._datatype) + def write_eof(self) -> None: """Close output file when end of file is received""" @@ -573,12 +587,14 @@ class _StreamWriter(_UnicodeWriter[AnyStr]): def __init__(self, process: 'SSHProcess[AnyStr]', writer: asyncio.StreamWriter, recv_eof: bool, - encoding: Optional[str], errors: str): + datatype: Optional[int], encoding: Optional[str], errors: str): super().__init__(encoding, errors) self._process: 'SSHProcess[AnyStr]' = process self._writer = writer self._recv_eof = recv_eof + self._datatype = datatype + self._paused = False self._queue: asyncio.Queue[Optional[AnyStr]] = asyncio.Queue() self._write_task: Optional[asyncio.Task[None]] = \ process.channel.get_connection().create_task(self._feed()) @@ -597,6 +613,10 @@ async def _feed(self) -> None: await self._writer.drain() self._queue.task_done() + if self._paused and self._queue.qsize() < _QUEUE_LOW_WATER: + self._process.resume_feeding(self._datatype) + self._paused = False + if self._recv_eof: self._writer.write_eof() @@ -605,6 +625,10 @@ def write(self, data: AnyStr) -> None: self._queue.put_nowait(data) + if not self._paused and self._queue.qsize() >= _QUEUE_HIGH_WATER: + self._paused = True + self._process.pause_feeding(self._datatype) + def write_eof(self) -> None: """Write EOF to the stream""" @@ -953,7 +977,7 @@ def pipe_factory() -> _PipeWriter: writer_process.set_reader(reader, send_eof, writer_datatype) writer = _ProcessWriter[AnyStr](writer_process, writer_datatype) elif isinstance(target, asyncio.StreamWriter): - writer = _StreamWriter(self, target, recv_eof, + writer = _StreamWriter(self, target, recv_eof, datatype, self._encoding, self._errors) else: file: _File @@ -978,7 +1002,7 @@ def pipe_factory() -> _PipeWriter: inspect.isgeneratorfunction(file.write)): writer = _AsyncFileWriter( self, cast(_AsyncFileProtocol, file), needs_close, - self._encoding, self._errors) + datatype, self._encoding, self._errors) elif _is_regular_file(cast(IO[bytes], file)): writer = _FileWriter(cast(IO[bytes], file), needs_close, self._encoding, self._errors) diff --git a/tests/test_process.py b/tests/test_process.py index addc8bd..d2c24eb 100644 --- a/tests/test_process.py +++ b/tests/test_process.py @@ -1310,6 +1310,20 @@ async def test_pause_async_file_reader(self): self.assertEqual(result.stdout, data) + @asynctest + async def test_pause_async_file_writer(self): + """Test pausing and resuming writing to an aiofile""" + + data = 4*1024*1024*'*' + + async with aiofiles.open('stdout', 'w') as file: + async with self.connect() as conn: + await conn.run('delay', input=data, stdout=file, + stderr=asyncssh.DEVNULL) + + with open('stdout', 'r') as file: + self.assertEqual(file.read(), data) + @unittest.skipIf(sys.platform == 'win32', 'skip pipe tests on Windows') class _TestProcessPipes(_TestProcess): @@ -1538,50 +1552,55 @@ async def test_stdout_socketpair(self): self.assertEqual(result.stderr, data) @asynctest - async def test_pause_socketpair_reader(self): - """Test pausing and resuming reading from a socketpair""" + async def test_pause_socketpair_pipes(self): + """Test pausing and resuming reading from and writing to pipes""" - data = 4*1024*1024*'*' + data = 4*1024*1024*b'*' sock1, sock2 = socket.socketpair() + sock3, sock4 = socket.socketpair() - _, writer = await asyncio.open_unix_connection(sock=sock1) - writer.write(data.encode()) - writer.close() + _, writer1 = await asyncio.open_unix_connection(sock=sock1) + writer1.write(data) + writer1.close() - async with self.connect() as conn: - result = await conn.run('delay', stdin=sock2, - stderr=asyncssh.DEVNULL) + reader2, writer2 = await asyncio.open_unix_connection(sock=sock4) - self.assertEqual(result.stdout, data) - - @asynctest - async def test_pause_socketpair_writer(self): - """Test pausing and resuming writing to a socketpair""" + async with self.connect() as conn: + process = await conn.create_process('delay', encoding=None, + stdin=sock2, stdout=sock3, + stderr=asyncssh.DEVNULL) - data = 4*1024*1024*'*' + self.assertEqual((await reader2.read()), data) + await process.wait() - rsock1, wsock1 = socket.socketpair() - rsock2, wsock2 = socket.socketpair() + writer2.close() - reader1, writer1 = await asyncio.open_unix_connection(sock=rsock1) - reader2, writer2 = await asyncio.open_unix_connection(sock=rsock2) + @asynctest + async def test_pause_socketpair_streams(self): + """Test pausing and resuming reading from and writing to streams""" - async with self.connect() as conn: - process = await conn.create_process(input=data) + data = 4*1024*1024*b'*' - await asyncio.sleep(1) + sock1, sock2 = socket.socketpair() + sock3, sock4 = socket.socketpair() - await process.redirect_stdout(wsock1) - await process.redirect_stderr(wsock2) + _, writer1 = await asyncio.open_unix_connection(sock=sock1) + writer1.write(data) + writer1.close() - stdout_data, stderr_data = \ - await asyncio.gather(reader1.read(), reader2.read()) + reader2, writer2 = await asyncio.open_unix_connection(sock=sock2) + _, writer3 = await asyncio.open_unix_connection(sock=sock3) + reader4, writer4 = await asyncio.open_unix_connection(sock=sock4) - writer1.close() - writer2.close() + async with self.connect() as conn: + process = await conn.create_process('delay', encoding=None, + stdin=reader2, stdout=writer3, + stderr=asyncssh.DEVNULL) + self.assertEqual((await reader4.read()), data) await process.wait() - self.assertEqual(stdout_data.decode(), data) - self.assertEqual(stderr_data.decode(), data) + writer2.close() + writer3.close() + writer4.close()