Skip to content

Commit

Permalink
Add GraphRun object
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Feb 1, 2025
1 parent c4e9180 commit 04fc74c
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 12 deletions.
76 changes: 65 additions & 11 deletions pydantic_graph/pydantic_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import inspect
import types
from collections.abc import Sequence
from collections.abc import AsyncGenerator, Sequence
from contextlib import ExitStack
from dataclasses import dataclass, field
from functools import cached_property
Expand Down Expand Up @@ -170,7 +170,7 @@ async def main():
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())

history: list[HistoryStep[StateT, T]] = []
graph_run = GraphRun[StateT, DepsT, T](self, state=state, deps=deps)
with ExitStack() as stack:
run_span: logfire_api.LogfireSpan | None = None
if self._auto_instrument:
Expand All @@ -184,19 +184,12 @@ async def main():

next_node = start_node
while True:
next_node = await self.next(next_node, history, state=state, deps=deps, infer_name=False)
next_node = await graph_run.next(next_node)
if isinstance(next_node, End):
history.append(EndStep(result=next_node))
history = graph_run.history
if run_span is not None:
run_span.set_attribute('history', history)
return next_node.data, history
elif not isinstance(next_node, BaseNode):
if TYPE_CHECKING:
typing_extensions.assert_never(next_node)
else:
raise exceptions.GraphRuntimeError(
f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.'
)

def run_sync(
self: Graph[StateT, DepsT, T],
Expand Down Expand Up @@ -510,3 +503,64 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None:
if item is self:
self.name = name
return


class GraphRun(Generic[StateT, DepsT, RunEndT]):
def __init__(
self,
graph: Graph[StateT, DepsT, RunEndT],
*,
state: StateT = None,
deps: DepsT = None,
):
self.graph = graph
self.state = state
self.deps = deps

self.history: list[HistoryStep[StateT, RunEndT]] = []
self.final_result: End[RunEndT] | None = None

self._agen: (
AsyncGenerator[BaseNode[StateT, DepsT, RunEndT] | End[RunEndT], BaseNode[StateT, DepsT, RunEndT]] | None
) = None

async def next(
self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T]
) -> BaseNode[StateT, DepsT, Any] | End[T]:
agen = await self._get_primed_agen()
return await agen.asend(node)

async def _get_primed_agen(
self: GraphRun[StateT, DepsT, T],
) -> AsyncGenerator[BaseNode[StateT, DepsT, T] | End[T], BaseNode[StateT, DepsT, T]]:
graph = self.graph
state = self.state
deps = self.deps
history = self.history

if self._agen is None:

async def _agen() -> AsyncGenerator[BaseNode[StateT, DepsT, T] | End[T], BaseNode[StateT, DepsT, T]]:
next_node = yield # pyright: ignore[reportReturnType] # we prime the generator immediately below
while True:
next_node = await graph.next(next_node, history, state=state, deps=deps, infer_name=False)
if isinstance(next_node, End):
history.append(EndStep(result=next_node))
self.final_result = next_node
yield next_node
return
elif isinstance(next_node, BaseNode):
next_node = yield next_node # Give user a chance to modify the next node
else:
if TYPE_CHECKING:
typing_extensions.assert_never(next_node)
else:
raise exceptions.GraphRuntimeError(
f'Invalid node return type: `{type(next_node).__name__}`. Expected `BaseNode` or `End`.'
)

agen = _agen()
await agen.__anext__() # prime the generator

self._agen = agen
return self._agen
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,4 @@ skip = '.git*,*.svg,*.lock,*.css'
check-hidden = true
# Ignore "formatting" like **L**anguage
ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b'
# ignore-words-list = ''
ignore-words-list = 'asend'

0 comments on commit 04fc74c

Please sign in to comment.