diff --git a/examples/api_node_examples/api_client.py b/examples/api_node_examples/api_client.py index 9cdeca3..7d17266 100644 --- a/examples/api_node_examples/api_client.py +++ b/examples/api_node_examples/api_client.py @@ -18,6 +18,7 @@ def __init__( input_tick_channel: str, input_response_channel: str, output_channel: str, + node_name: str, redis_url: str, ): response_class = get_rest_response_class(AnyDataModel) @@ -28,6 +29,7 @@ def __init__( (input_response_channel, response_class), ], output_channel_types=[(output_channel, request_class)], + node_name=node_name, redis_url=redis_url, ) diff --git a/examples/chatbot_examples/nodes.py b/examples/chatbot_examples/nodes.py index 562d52e..c2d3bd0 100644 --- a/examples/chatbot_examples/nodes.py +++ b/examples/chatbot_examples/nodes.py @@ -7,10 +7,13 @@ @NodeFactory.register("gpt4_text_chatbot_node") class GPT4TextChatbotNode(Node[Text, Text]): - def __init__(self, input_channel: str, output_channel: str, redis_url: str): + def __init__( + self, input_channel: str, output_channel: str, node_name: str, redis_url: str + ): super().__init__( input_channel_types=[(input_channel, Text)], output_channel_types=[(output_channel, Text)], + node_name=node_name, redis_url=redis_url, ) self.input_channel = input_channel diff --git a/examples/node_manager_examples/example.toml b/examples/node_manager_examples/example.toml new file mode 100644 index 0000000..fbca631 --- /dev/null +++ b/examples/node_manager_examples/example.toml @@ -0,0 +1,12 @@ +redis_url = "redis://localhost:6379/0" # required + +[[nodes]] +node_name = "special_print" +node_class = "special_print" + +[nodes.node_args.print_channel_types] +"tick/secs/1" = "tick" + +[[nodes]] +node_name = "tick" +node_class = "tick" diff --git a/src/aact/cli/launch/launch.py b/src/aact/cli/launch/launch.py index 8c7faeb..b3f7723 100644 --- a/src/aact/cli/launch/launch.py +++ b/src/aact/cli/launch/launch.py @@ -1,28 +1,24 @@ import asyncio import logging -import os -import signal 
-import sys import time -from typing import Annotated, Any, Optional, TypeVar +from typing import Annotated, Optional, TypeVar from ..app import app from ..reader import get_dataflow_config, draw_dataflow_mermaid, NodeConfig, Config import typer from ...nodes import NodeFactory -from subprocess import Popen -if sys.version_info >= (3, 11): - import tomllib -else: - import tomlkit as tomllib +from ...utils import tomllib + from rq import Queue from rq.exceptions import InvalidJobOperation from rq.job import Job from rq.command import send_stop_job_command from redis import Redis +from ...manager import NodeManager + InputType = TypeVar("InputType") OutputType = TypeVar("OutputType") @@ -35,6 +31,7 @@ async def _run_node(node_config: NodeConfig, redis_url: str) -> None: async with NodeFactory.make( node_config.node_class, **node_config.node_args.model_dump(), + node_name=node_config.node_name, redis_url=redis_url, ) as node: logger.info(f"Starting eventloop {node_config.node_name}") @@ -131,43 +128,8 @@ def run_dataflow( finally: return - subprocesses: list[Popen[bytes]] = [] - - try: - # Nodes that run w/ subprocess - for node in config.nodes: - command = f"aact run-node --dataflow-toml {dataflow_toml} --node-name {node.node_name} --redis-url {config.redis_url}" - logger.info(f"executing {command}") - node_process = Popen( - [command], - shell=True, - preexec_fn=os.setsid, # Start the subprocess in a new process group - ) - subprocesses.append(node_process) - - def _cleanup_subprocesses( - signum: int | None = None, frame: Any | None = None - ) -> None: - for node_process in subprocesses: - try: - os.killpg(os.getpgid(node_process.pid), signal.SIGTERM) - logger.info(f"Terminating process group {node_process.pid}") - except ProcessLookupError: - logger.warning( - f"Process group {node_process.pid} has been terminated." 
- ) - - signal.signal(signal.SIGTERM, _cleanup_subprocesses) - signal.signal(signal.SIGINT, _cleanup_subprocesses) - - for node_process in subprocesses: - node_process.wait() - - except Exception as e: - logger.warning("Error in multiprocessing: ", e) - _cleanup_subprocesses() - finally: - _cleanup_subprocesses() + with NodeManager(dataflow_toml, False, config.redis_url) as node_manager: + node_manager.wait() @app.command(help="A nice debugging feature. Draw dataflows with Mermaid.") diff --git a/src/aact/manager/__init__.py b/src/aact/manager/__init__.py new file mode 100644 index 0000000..075664d --- /dev/null +++ b/src/aact/manager/__init__.py @@ -0,0 +1,3 @@ +from .manager import NodeManager + +__all__ = ["NodeManager"] diff --git a/src/aact/manager/manager.py b/src/aact/manager/manager.py new file mode 100644 index 0000000..da28a6d --- /dev/null +++ b/src/aact/manager/manager.py @@ -0,0 +1,148 @@ +import asyncio +import datetime +from logging import Logger +import os +import signal +from subprocess import Popen +import threading +from ..utils import tomllib +from typing import Any, Literal +from uuid import uuid4 + +from redis.asyncio import Redis +from redis import Redis as SyncRedis + +from ..cli.reader import Config + +from ..utils import Self + +logger = Logger("NodeManager") + +Health = Literal["Started", "Running", "No Response", "Stopped"] + + +def run_event_loop_in_thread(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_forever() + + +class NodeManager(object): + def __init__( + self, + dataflow_toml: str, + with_rq: bool = False, + redis_url: str = "redis://localhost:6379/0", + ): + self.id = f"manager-{str(uuid4())}" + self.dataflow_toml = dataflow_toml + self.with_rq = with_rq + self.subprocesses: dict[str, Popen[bytes]] = {} + self.pubsub = Redis.from_url(redis_url).pubsub() + self.shutdown_pubsub = SyncRedis.from_url(redis_url).pubsub() + self.background_tasks: list[asyncio.Task[None]] = [] + self.node_health: 
dict[str, Health] = {} + self.last_heartbeat: dict[str, float] = {} + self.loop: asyncio.AbstractEventLoop | None = None + self.shutdown_signal: bool = False + + def __enter__( + self, + ) -> Self: + config = Config.model_validate(tomllib.load(open(self.dataflow_toml, "rb"))) + + # Nodes that run w/ subprocess + for node in config.nodes: + try: + command = f"aact run-node --dataflow-toml {self.dataflow_toml} --node-name {node.node_name} --redis-url {config.redis_url}" + node_process = Popen( + [command], + shell=True, + preexec_fn=os.setsid, # Start the subprocess in a new process group + ) + logger.info( + f"Starting subprocess {node_process} for node {node.node_name}" + ) + assert ( + node.node_name not in self.subprocesses + ), f"Node {node.node_name} is duplicated." + self.subprocesses[node.node_name] = node_process + self.node_health[node.node_name] = "Started" + except Exception as e: + logger.error( + f"Error starting subprocess {node.node_name}: {e}. Stopping other nodes as well." + ) + for node_name, node_process in self.subprocesses.items(): + logger.info( + f"Terminating Node {node_name}. 
Process: {node_process}" + ) + try: + os.killpg(os.getpgid(node_process.pid), signal.SIGTERM) + except ProcessLookupError: + logger.info(f"Process group {node_process.pid} not found.") + self.subprocesses = {} + raise e + + thread = threading.Thread(target=run_event_loop_in_thread, daemon=True) + thread.start() + self.loop = asyncio.get_event_loop() + + self.background_tasks.append(self.loop.create_task(self.wait_for_heartbeat())) + self.background_tasks.append(self.loop.create_task(self.update_health_status())) + + return self + + async def wait_for_heartbeat( + self, + ) -> None: + for node_name in self.subprocesses.keys(): + await self.pubsub.subscribe(f"heartbeat:{node_name}") + + async for message in self.pubsub.listen(): + node_name = ":".join(message["channel"].decode("utf-8").split(":")[1:]) + self.last_heartbeat[node_name] = datetime.datetime.now().timestamp() + + async def update_health_status( + self, + ) -> None: + while True: + for node_name, last_heartbeat in self.last_heartbeat.items(): + if datetime.datetime.now().timestamp() - last_heartbeat > 10: + self.node_health[node_name] = "No Response" + else: + self.node_health[node_name] = "Running" + await asyncio.sleep(1) + + def wait( + self, + ) -> None: + for node_name in self.subprocesses.keys(): + self.shutdown_pubsub.subscribe(f"shutdown:{node_name}") + for message in self.shutdown_pubsub.listen(): + node_name = ":".join(message["channel"].decode("utf-8").split(":")[1:]) + if message["data"] == b"shutdown": + logger.info(f"Received shutdown signal for node {node_name}") + self.shutdown_signal = True + break + self.shutdown_pubsub.unsubscribe() + self.shutdown_pubsub.close() + + def __exit__( + self, + signum: int | None = None, + frame: Any | None = None, + traceback: Any | None = None, + ) -> None: + for _, node_process in self.subprocesses.items(): + try: + os.killpg(os.getpgid(node_process.pid), signal.SIGTERM) + logger.info(f"Terminating process group {node_process.pid}") + except 
ProcessLookupError: + logger.warning(f"Process group {node_process.pid} not found.") + for task in self.background_tasks: + task.cancel() + + if self.loop: + self.loop.run_until_complete(self.pubsub.unsubscribe()) + self.loop.run_until_complete(self.pubsub.close()) + self.loop.stop() diff --git a/src/aact/nodes/__init__.py b/src/aact/nodes/__init__.py index 60718c7..f049bd1 100644 --- a/src/aact/nodes/__init__.py +++ b/src/aact/nodes/__init__.py @@ -18,6 +18,7 @@ from .tts import TTSNode from .registry import NodeFactory from .api import RestAPINode +from .special_print import SpecialPrintNode __all__ = [ "Node", @@ -32,4 +33,5 @@ "PrintNode", "TTSNode", "RestAPINode", + "SpecialPrintNode", ] diff --git a/src/aact/nodes/api.py b/src/aact/nodes/api.py index 624f750..6cba819 100644 --- a/src/aact/nodes/api.py +++ b/src/aact/nodes/api.py @@ -77,6 +77,7 @@ def __init__( output_channel: str, input_type_str: str, output_type_str: str, + node_name: str, redis_url: str, ): if input_type_str not in DataModelFactory.registry: @@ -98,6 +99,7 @@ def __init__( super().__init__( input_channel_types=[(input_channel, request_class)], output_channel_types=[(output_channel, response_class)], + node_name=node_name, redis_url=redis_url, ) diff --git a/src/aact/nodes/base.py b/src/aact/nodes/base.py index c2dbcbf..fd06847 100644 --- a/src/aact/nodes/base.py +++ b/src/aact/nodes/base.py @@ -1,16 +1,13 @@ from asyncio import CancelledError +import asyncio import logging -import sys -if sys.version_info >= (3, 11): - from typing import Self -else: - from typing_extensions import Self +from ..utils import Self from typing import Any, AsyncIterator, Generic, Type, TypeVar from pydantic import BaseModel, ConfigDict, ValidationError from abc import abstractmethod -from ..messages import Message +from ..messages import Message, Tick from redis.asyncio import Redis from ..messages.base import DataModel @@ -145,6 +142,10 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) 
-> None """ A dictionary that maps the output channel names to the corresponding output message types. """ + node_name: str + """ + The name of the node. When using NodeManager, the node name should be unique. + """ redis_url: str """ The URL of the Redis server. It should be in the format of `redis://<host>:<port>/<db>`. @@ -158,12 +159,15 @@ def __init__( self, input_channel_types: list[tuple[str, Type[InputType]]], output_channel_types: list[tuple[str, Type[OutputType]]], + node_name: str, redis_url: str = "redis://localhost:6379/0", ): try: - super().__init__( + BaseModel.__init__( + self, input_channel_types=dict(input_channel_types), output_channel_types=dict(output_channel_types), + node_name=node_name, redis_url=redis_url, ) except ValidationError as _: @@ -174,7 +178,6 @@ def __init__( f"The required output channel types are: {self.model_fields['output_channel_types'].annotation}\n" f"The output channel types are: {output_channel_types}\n" ) - self.r: Redis = Redis.from_url(redis_url) """ @private @@ -187,6 +190,7 @@ def __init__( """ @private """ + self._background_tasks: list[asyncio.Task[None]] = [] async def __aenter__(self) -> Self: try: @@ -196,12 +200,27 @@ async def __aenter__(self) -> Self: f"Could not connect to Redis with the provided url. 
{self.redis_url}" ) await self.pubsub.subscribe(*self.input_channel_types.keys()) + self._background_tasks.append(asyncio.create_task(self._send_heartbeat())) return self async def __aexit__(self, _: Any, __: Any, ___: Any) -> None: + for task in self._background_tasks: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass await self.pubsub.unsubscribe() await self.r.aclose() + async def _send_heartbeat(self) -> None: + while True: + await asyncio.sleep(1) + await self.r.publish( + f"heartbeat:{self.node_name}", + Message[Tick](data=Tick(tick=0)).model_dump_json(), + ) + async def _wait_for_input( self, ) -> AsyncIterator[tuple[str, Message[InputType]]]: diff --git a/src/aact/nodes/listener.py b/src/aact/nodes/listener.py index 53c1650..73adecb 100644 --- a/src/aact/nodes/listener.py +++ b/src/aact/nodes/listener.py @@ -1,11 +1,7 @@ import asyncio -import sys from typing import Any, AsyncIterator, Optional, TYPE_CHECKING -if sys.version_info >= (3, 11): - from typing import Self -else: - from typing_extensions import Self +from ..utils import Self from ..messages.base import DataModel, Message from .base import Node @@ -30,6 +26,7 @@ class ListenerNode(Node[Zero, Audio]): def __init__( self, output_channel: str, + name_name: str, redis_url: str, channels: int = 1, rate: int = 44100, @@ -44,6 +41,7 @@ def __init__( super().__init__( input_channel_types=[], output_channel_types=[(output_channel, Audio)], + node_name=name_name, redis_url=redis_url, ) self.output_channel = output_channel diff --git a/src/aact/nodes/performance.py b/src/aact/nodes/performance.py index 26946d2..7e54d10 100644 --- a/src/aact/nodes/performance.py +++ b/src/aact/nodes/performance.py @@ -10,7 +10,12 @@ @NodeFactory.register("performance") class PerformanceMeasureNode(Node[Tick | Image, Image]): def __init__( - self, input_channel: str, output_channel: str, message_size: int, redis_url: str + self, + input_channel: str, + output_channel: str, + message_size: int, + 
node_name: str, + redis_url: str, ): super().__init__( input_channel_types=[ @@ -18,6 +23,7 @@ def __init__( (output_channel, Image), ], output_channel_types=[(output_channel, Image)], + node_name=node_name, redis_url=redis_url, ) self.input_channel = input_channel diff --git a/src/aact/nodes/print.py b/src/aact/nodes/print.py index ee7a7c1..6a82d45 100644 --- a/src/aact/nodes/print.py +++ b/src/aact/nodes/print.py @@ -1,10 +1,6 @@ import asyncio -import sys -if sys.version_info >= (3, 11): - from typing import Self -else: - from typing_extensions import Self +from ..utils import Self from typing import Any, AsyncIterator from ..messages.commons import DataEntry @@ -23,6 +19,7 @@ class PrintNode(Node[DataModel, Zero]): def __init__( self, print_channel_types: dict[str, str], + node_name: str, redis_url: str, ): input_channel_types: list[tuple[str, type[DataModel]]] = [] @@ -34,6 +31,7 @@ def __init__( super().__init__( input_channel_types=input_channel_types, output_channel_types=[], + node_name=node_name, redis_url=redis_url, ) self.output: AsyncTextIndirectIOWrapper | None = None diff --git a/src/aact/nodes/random.py b/src/aact/nodes/random.py index 0f029a0..a56e279 100644 --- a/src/aact/nodes/random.py +++ b/src/aact/nodes/random.py @@ -13,11 +13,13 @@ def __init__( self, input_channel: str, output_channel: str, + node_name: str, redis_url: str = "redis://localhost:6379/0", ): super().__init__( input_channel_types=[(input_channel, Tick)], output_channel_types=[(output_channel, Float)], + node_name=node_name, redis_url=redis_url, ) self.input_channel = input_channel diff --git a/src/aact/nodes/record.py b/src/aact/nodes/record.py index a1a465f..79976a5 100644 --- a/src/aact/nodes/record.py +++ b/src/aact/nodes/record.py @@ -1,11 +1,7 @@ import asyncio from datetime import datetime -import sys -if sys.version_info >= (3, 11): - from typing import Self -else: - from typing_extensions import Self +from ..utils import Self from typing import Any, AsyncIterator from 
..messages.commons import DataEntry @@ -26,6 +22,7 @@ def __init__( self, record_channel_types: dict[str, str], jsonl_file_path: str, + node_name: str, redis_url: str, add_datetime: bool = True, ): @@ -45,6 +42,7 @@ def __init__( super().__init__( input_channel_types=input_channel_types, output_channel_types=[], + node_name=node_name, redis_url=redis_url, ) self.jsonl_file_path = jsonl_file_path diff --git a/src/aact/nodes/speaker.py b/src/aact/nodes/speaker.py index e237ba3..80f9c34 100644 --- a/src/aact/nodes/speaker.py +++ b/src/aact/nodes/speaker.py @@ -1,11 +1,6 @@ -import sys from typing import Any, AsyncIterator, Optional, TYPE_CHECKING -if sys.version_info >= (3, 11): - from typing import Self -else: - from typing_extensions import Self - +from ..utils import Self from .base import Node from .registry import NodeFactory from ..messages import Audio, Zero, Message @@ -28,6 +23,7 @@ class SpeakerNode(Node[Audio, Zero]): def __init__( self, input_channel: str, + node_name: str, redis_url: str, channels: int = 1, rate: int = 44100, @@ -42,6 +38,7 @@ def __init__( super().__init__( input_channel_types=[(input_channel, Audio)], output_channel_types=[], + node_name=node_name, redis_url=redis_url, ) self.input_channel = input_channel diff --git a/src/aact/nodes/special_print.py b/src/aact/nodes/special_print.py new file mode 100644 index 0000000..49e5220 --- /dev/null +++ b/src/aact/nodes/special_print.py @@ -0,0 +1,17 @@ +from .print import PrintNode +from aact.nodes import NodeFactory + + +@NodeFactory.register("special_print") +class SpecialPrintNode(PrintNode): + async def write_to_screen(self) -> None: + count = 0 + while self.output: + if count > 10: + await self.r.publish(f"shutdown:{self.node_name}", "shutdown") + break + data_entry = await self.write_queue.get() + await self.output.write(data_entry.model_dump_json() + "\n") + await self.output.flush() + self.write_queue.task_done() + count += 1 diff --git a/src/aact/nodes/tick.py b/src/aact/nodes/tick.py 
index 7b10f94..594d962 100644 --- a/src/aact/nodes/tick.py +++ b/src/aact/nodes/tick.py @@ -1,10 +1,6 @@ import asyncio -import sys -if sys.version_info >= (3, 11): - from typing import Self -else: - from typing_extensions import Self +from ..utils import Self from typing import AsyncIterator from ..messages import Tick, Message, Zero @@ -15,7 +11,7 @@ @NodeFactory.register("tick") class TickNode(Node[Zero, Tick]): - def __init__(self, redis_url: str = "redis://localhost:6379/0"): + def __init__(self, node_name: str, redis_url: str = "redis://localhost:6379/0"): super().__init__( input_channel_types=[], output_channel_types=[ @@ -26,6 +22,7 @@ def __init__(self, redis_url: str = "redis://localhost:6379/0"): ("tick/millis/100", Tick), ("tick/secs/1", Tick), ], + node_name=node_name, redis_url=redis_url, ) diff --git a/src/aact/nodes/transcriber.py b/src/aact/nodes/transcriber.py index 63d1a52..ef667e7 100644 --- a/src/aact/nodes/transcriber.py +++ b/src/aact/nodes/transcriber.py @@ -1,10 +1,6 @@ import asyncio -import sys -if sys.version_info >= (3, 11): - from typing import Self -else: - from typing_extensions import Self +from ..utils import Self from typing import TYPE_CHECKING, Any, AsyncIterator from ..messages.base import Message @@ -34,6 +30,7 @@ def __init__( output_channel: str, rate: int, api_key: str, + node_name: str, redis_url: str, ) -> None: if not GOOLE_CLOUD_SPEECH_AVAILABLE: @@ -44,6 +41,7 @@ def __init__( super().__init__( input_channel_types=[(input_channel, Audio)], output_channel_types=[(output_channel, Text)], + node_name=node_name, redis_url=redis_url, ) self.input_channel = input_channel diff --git a/src/aact/nodes/tts.py b/src/aact/nodes/tts.py index 5b66afa..bce6289 100644 --- a/src/aact/nodes/tts.py +++ b/src/aact/nodes/tts.py @@ -1,10 +1,6 @@ import asyncio -import sys -if sys.version_info >= (3, 11): - from typing import Self -else: - from typing_extensions import Self +from ..utils import Self from typing import TYPE_CHECKING, Any, 
AsyncIterator from ..messages.base import Message @@ -35,6 +31,7 @@ def __init__( output_channel: str, api_key: str, rate: int, + node_name: str, redis_url: str, ) -> None: if not GOOGLE_CLOUD_TEXTTOSPEECH_AVAILABLE: @@ -45,6 +42,7 @@ def __init__( super().__init__( input_channel_types=[(input_channel, Text)], output_channel_types=[(output_channel, Audio)], + node_name=node_name, redis_url=redis_url, ) self.input_channel = input_channel diff --git a/src/aact/utils/__init__.py b/src/aact/utils/__init__.py index e69de29..597d1c7 100644 --- a/src/aact/utils/__init__.py +++ b/src/aact/utils/__init__.py @@ -0,0 +1,4 @@ +from .types import Self +from .tomllib import tomllib + +__all__ = ["Self", "tomllib"] diff --git a/src/aact/utils/tomllib.py b/src/aact/utils/tomllib.py new file mode 100644 index 0000000..05a4331 --- /dev/null +++ b/src/aact/utils/tomllib.py @@ -0,0 +1,7 @@ +import sys + + +if sys.version_info >= (3, 11): + pass +else: + pass diff --git a/src/aact/utils/types.py b/src/aact/utils/types.py new file mode 100644 index 0000000..05a4331 --- /dev/null +++ b/src/aact/utils/types.py @@ -0,0 +1,7 @@ +import sys + + +if sys.version_info >= (3, 11): + pass +else: + pass diff --git a/tests/nodes/main.py b/tests/nodes/main.py index 51bf021..72f2827 100644 --- a/tests/nodes/main.py +++ b/tests/nodes/main.py @@ -1,4 +1,4 @@ -from fastapi import FastAPI, Response, HTTPException, File, UploadFile +from fastapi import FastAPI, Response, File, UploadFile from fastapi.responses import PlainTextResponse, HTMLResponse, StreamingResponse from typing import Generator, Dict, Any from pydantic import BaseModel @@ -73,17 +73,6 @@ async def get_binary() -> Response: return Response(content=content, media_type="application/octet-stream") -# Status codes -@app.get("/error/404") -async def error_404() -> HTTPException: - raise HTTPException(status_code=404, detail="Item not found") - - -@app.get("/error/500") -async def error_500() -> HTTPException: - raise 
HTTPException(status_code=500, detail="Internal server error") - - class FileResponse(BaseModel): filename: str content_type: str