add node manager to make nodes peer-stoppable
ProKil committed Dec 2, 2024
1 parent 1f053f3 commit b9eff3b
Showing 23 changed files with 274 additions and 105 deletions.
2 changes: 2 additions & 0 deletions examples/api_node_examples/api_client.py
@@ -18,6 +18,7 @@ def __init__(
         input_tick_channel: str,
         input_response_channel: str,
         output_channel: str,
+        node_name: str,
         redis_url: str,
     ):
         response_class = get_rest_response_class(AnyDataModel)
@@ -28,6 +29,7 @@ def __init__(
                 (input_response_channel, response_class),
             ],
             output_channel_types=[(output_channel, request_class)],
+            node_name=node_name,
             redis_url=redis_url,
         )
5 changes: 4 additions & 1 deletion examples/chatbot_examples/nodes.py
@@ -7,10 +7,13 @@

 @NodeFactory.register("gpt4_text_chatbot_node")
 class GPT4TextChatbotNode(Node[Text, Text]):
-    def __init__(self, input_channel: str, output_channel: str, redis_url: str):
+    def __init__(
+        self, input_channel: str, output_channel: str, node_name: str, redis_url: str
+    ):
         super().__init__(
             input_channel_types=[(input_channel, Text)],
             output_channel_types=[(output_channel, Text)],
+            node_name=node_name,
             redis_url=redis_url,
         )
         self.input_channel = input_channel
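
Every node constructor now threads a required node_name through to the base Node; that name is what lets the manager address per-node heartbeat and shutdown channels. A minimal sketch of constructing the updated chatbot node through the factory (the channel names and the driver script are illustrative, not part of this commit; the async-context-manager pattern mirrors _run_node in launch.py, and event_loop() is assumed to be the coroutine _run_node drives):

    import asyncio

    from aact.nodes import NodeFactory


    async def main() -> None:
        # "chat/input" and "chat/output" are made-up channel names for illustration.
        async with NodeFactory.make(
            "gpt4_text_chatbot_node",
            input_channel="chat/input",
            output_channel="chat/output",
            node_name="chatbot",  # newly required argument from this commit
            redis_url="redis://localhost:6379/0",
        ) as node:
            await node.event_loop()  # assumed entry point, per _run_node's pattern


    asyncio.run(main())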
12 changes: 12 additions & 0 deletions examples/node_manager_examples/example.toml
@@ -0,0 +1,12 @@
redis_url = "redis://localhost:6379/0" # required

[[nodes]]
node_name = "special_print"
node_class = "special_print"

[nodes.node_args.print_channel_types]
"tick/secs/1" = "tick"

[[nodes]]
node_name = "tick"
node_class = "tick"
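
Assuming the Typer-default kebab-case command name for run_dataflow and the dataflow TOML as its positional argument, this two-node dataflow (a tick source driving the new special-print node) is launched with:

    aact run-dataflow examples/node_manager_examples/example.toml

Under the hood the manager spawns one aact run-node subprocess per [[nodes]] entry, using exactly the command string visible in manager.py below, and then blocks until one of the nodes publishes a shutdown message.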
54 changes: 8 additions & 46 deletions src/aact/cli/launch/launch.py
@@ -1,28 +1,24 @@
 import asyncio
 import logging
-import os
-import signal
-import sys
-import time
-from typing import Annotated, Any, Optional, TypeVar
+from typing import Annotated, Optional, TypeVar
 from ..app import app
 from ..reader import get_dataflow_config, draw_dataflow_mermaid, NodeConfig, Config
 import typer

 from ...nodes import NodeFactory

-from subprocess import Popen
-
-if sys.version_info >= (3, 11):
-    import tomllib
-else:
-    import tomlkit as tomllib
+from ...utils import tomllib

 from rq import Queue
 from rq.exceptions import InvalidJobOperation
 from rq.job import Job
 from rq.command import send_stop_job_command
 from redis import Redis

+from ...manager import NodeManager
+
 InputType = TypeVar("InputType")
 OutputType = TypeVar("OutputType")

@@ -35,6 +31,7 @@ async def _run_node(node_config: NodeConfig, redis_url: str) -> None:
     async with NodeFactory.make(
         node_config.node_class,
         **node_config.node_args.model_dump(),
+        node_name=node_config.node_name,
         redis_url=redis_url,
     ) as node:
         logger.info(f"Starting eventloop {node_config.node_name}")
@@ -131,43 +128,8 @@ def run_dataflow(
         finally:
             return

-    subprocesses: list[Popen[bytes]] = []
-
-    try:
-        # Nodes that run w/ subprocess
-        for node in config.nodes:
-            command = f"aact run-node --dataflow-toml {dataflow_toml} --node-name {node.node_name} --redis-url {config.redis_url}"
-            logger.info(f"executing {command}")
-            node_process = Popen(
-                [command],
-                shell=True,
-                preexec_fn=os.setsid,  # Start the subprocess in a new process group
-            )
-            subprocesses.append(node_process)
-
-        def _cleanup_subprocesses(
-            signum: int | None = None, frame: Any | None = None
-        ) -> None:
-            for node_process in subprocesses:
-                try:
-                    os.killpg(os.getpgid(node_process.pid), signal.SIGTERM)
-                    logger.info(f"Terminating process group {node_process.pid}")
-                except ProcessLookupError:
-                    logger.warning(
-                        f"Process group {node_process.pid} has been terminated."
-                    )
-
-        signal.signal(signal.SIGTERM, _cleanup_subprocesses)
-        signal.signal(signal.SIGINT, _cleanup_subprocesses)
-
-        for node_process in subprocesses:
-            node_process.wait()
-
-    except Exception as e:
-        logger.warning("Error in multiprocessing: ", e)
-        _cleanup_subprocesses()
-    finally:
-        _cleanup_subprocesses()
+    with NodeManager(dataflow_toml, False, config.redis_url) as node_manager:
+        node_manager.wait()


 @app.command(help="A nice debugging feature. Draw dataflows with Mermaid.")
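
The net effect of this hunk: roughly forty lines of hand-rolled subprocess bookkeeping, signal handlers, and duplicated cleanup calls collapse into a context manager, so process-group teardown now lives in NodeManager.__exit__ rather than inside run_dataflow. The same manager can be used outside the CLI; a sketch, with the argument names read off NodeManager.__init__:

    from aact.manager import NodeManager

    # Spawns every node in the TOML on entry, blocks until some peer
    # publishes "shutdown" on a shutdown:<node_name> channel, and
    # SIGTERMs each node's process group on exit.
    with NodeManager("example.toml", with_rq=False, redis_url="redis://localhost:6379/0") as manager:
        manager.wait()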
3 changes: 3 additions & 0 deletions src/aact/manager/__init__.py
@@ -0,0 +1,3 @@
from .manager import NodeManager

__all__ = ["NodeManager"]
148 changes: 148 additions & 0 deletions src/aact/manager/manager.py
@@ -0,0 +1,148 @@
import asyncio
import datetime
from logging import Logger
import os
import signal
from subprocess import Popen
import threading
from ..utils import tomllib
from typing import Any, Literal
from uuid import uuid4

from redis.asyncio import Redis
from redis import Redis as SyncRedis

from ..cli.reader import Config

from ..utils import Self

logger = Logger("NodeManager")

Health = Literal["Started", "Running", "No Response", "Stopped"]


def run_event_loop_in_thread():
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    loop.run_forever()


class NodeManager(object):
    def __init__(
        self,
        dataflow_toml: str,
        with_rq: bool = False,
        redis_url: str = "redis://localhost:6379/0",
    ):
        self.id = f"manager-{str(uuid4())}"
        self.dataflow_toml = dataflow_toml
        self.with_rq = with_rq
        self.subprocesses: dict[str, Popen[bytes]] = {}
        self.pubsub = Redis.from_url(redis_url).pubsub()
        self.shutdown_pubsub = SyncRedis.from_url(redis_url).pubsub()
        self.background_tasks: list[asyncio.Task[None]] = []
        self.node_health: dict[str, Health] = {}
        self.last_heartbeat: dict[str, float] = {}
        self.loop: asyncio.AbstractEventLoop | None = None
        self.shutdown_signal: bool = False

    def __enter__(
        self,
    ) -> Self:
        config = Config.model_validate(tomllib.load(open(self.dataflow_toml, "rb")))

        # Nodes that run w/ subprocess
        for node in config.nodes:
            try:
                command = f"aact run-node --dataflow-toml {self.dataflow_toml} --node-name {node.node_name} --redis-url {config.redis_url}"
                node_process = Popen(
                    [command],
                    shell=True,
                    preexec_fn=os.setsid,  # Start the subprocess in a new process group
                )
                logger.info(
                    f"Starting subprocess {node_process} for node {node.node_name}"
                )
                assert (
                    node.node_name not in self.subprocesses
                ), f"Node {node.node_name} is duplicated."
                self.subprocesses[node.node_name] = node_process
                self.node_health[node.node_name] = "Started"
            except Exception as e:
                logger.error(
                    f"Error starting subprocess {node.node_name}: {e}. Stopping other nodes as well."
                )
                for node_name, node_process in self.subprocesses.items():
                    logger.info(
                        f"Terminating Node {node_name}. Process: {node_process}"
                    )
                    try:
                        os.killpg(os.getpgid(node_process.pid), signal.SIGTERM)
                    except ProcessLookupError:
                        logger.info(f"Process group {node_process.pid} not found.")
                self.subprocesses = {}
                raise e

        thread = threading.Thread(target=run_event_loop_in_thread, daemon=True)
        thread.start()
        self.loop = asyncio.get_event_loop()

        self.background_tasks.append(self.loop.create_task(self.wait_for_heartbeat()))
        self.background_tasks.append(self.loop.create_task(self.update_health_status()))

        return self

    async def wait_for_heartbeat(
        self,
    ) -> None:
        for node_name in self.subprocesses.keys():
            await self.pubsub.subscribe(f"heartbeat:{node_name}")

        async for message in self.pubsub.listen():
            node_name = ":".join(message["channel"].decode("utf-8").split(":")[1:])
            self.last_heartbeat[node_name] = datetime.datetime.now().timestamp()

    async def update_health_status(
        self,
    ) -> None:
        while True:
            for node_name, last_heartbeat in self.last_heartbeat.items():
                if datetime.datetime.now().timestamp() - last_heartbeat > 10:
                    self.node_health[node_name] = "No Response"
                else:
                    self.node_health[node_name] = "Running"
            await asyncio.sleep(1)

    def wait(
        self,
    ) -> None:
        for node_name in self.subprocesses.keys():
            self.shutdown_pubsub.subscribe(f"shutdown:{node_name}")
        for message in self.shutdown_pubsub.listen():
            node_name = ":".join(message["channel"].decode("utf-8").split(":")[1:])
            if message["data"] == b"shutdown":
                logger.info(f"Received shutdown signal for node {node_name}")
                self.shutdown_signal = True
                break
        self.shutdown_pubsub.unsubscribe()
        self.shutdown_pubsub.close()

    def __exit__(
        self,
        signum: int | None = None,
        frame: Any | None = None,
        traceback: Any | None = None,
    ) -> None:
        for _, node_process in self.subprocesses.items():
            try:
                os.killpg(os.getpgid(node_process.pid), signal.SIGTERM)
                logger.info(f"Terminating process group {node_process.pid}")
            except ProcessLookupError:
                logger.warning(f"Process group {node_process.pid} not found.")
        for task in self.background_tasks:
            task.cancel()

        if self.loop:
            self.loop.run_until_complete(self.pubsub.unsubscribe())
            self.loop.run_until_complete(self.pubsub.close())
            self.loop.stop()
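
Two Redis pub/sub conventions fall out of this file: a node proves liveness by publishing anything on heartbeat:<node_name> (silence for more than 10 seconds flips its health to "No Response"), and any peer stops the whole dataflow by publishing the exact bytes b"shutdown" on shutdown:<node_name>, which unblocks wait(). The node-side publisher is not shown in this commit; a hypothetical peer using plain redis-py might look like the sketch below (the 1-second cadence and the "alive" payload are assumptions; the manager only timestamps message arrival):

    import asyncio

    from redis.asyncio import Redis


    async def publish_heartbeat(redis_url: str, node_name: str) -> None:
        # Keep the node marked "Running": the manager records only the arrival
        # time of each message, so the payload can be anything.
        r = Redis.from_url(redis_url)
        try:
            while True:
                await r.publish(f"heartbeat:{node_name}", "alive")
                await asyncio.sleep(1)  # anything under the 10 s threshold works
        finally:
            await r.aclose()


    async def stop_dataflow(redis_url: str, node_name: str) -> None:
        # NodeManager.wait() returns (and __exit__ then SIGTERMs every process
        # group) as soon as it sees the exact bytes b"shutdown" on this channel.
        r = Redis.from_url(redis_url)
        await r.publish(f"shutdown:{node_name}", "shutdown")
        await r.aclose()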
2 changes: 2 additions & 0 deletions src/aact/nodes/__init__.py
@@ -18,6 +18,7 @@
 from .tts import TTSNode
 from .registry import NodeFactory
 from .api import RestAPINode
+from .special_print import SpecialPrintNode

 __all__ = [
     "Node",
@@ -32,4 +33,5 @@
     "PrintNode",
     "TTSNode",
     "RestAPINode",
+    "SpecialPrintNode",
 ]
2 changes: 2 additions & 0 deletions src/aact/nodes/api.py
@@ -77,6 +77,7 @@ def __init__(
         output_channel: str,
         input_type_str: str,
         output_type_str: str,
+        node_name: str,
         redis_url: str,
     ):
         if input_type_str not in DataModelFactory.registry:
@@ -98,6 +99,7 @@ def __init__(
         super().__init__(
             input_channel_types=[(input_channel, request_class)],
             output_channel_types=[(output_channel, response_class)],
+            node_name=node_name,
            redis_url=redis_url,
         )
(remaining changed files not shown)