Skip to content

Commit

Permalink
Minor tweak
Browse files Browse the repository at this point in the history
  • Loading branch information
dmontagu committed Feb 5, 2025
1 parent 8fc633a commit f504858
Showing 1 changed file with 33 additions and 11 deletions.
44 changes: 33 additions & 11 deletions pydantic_graph/pydantic_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,33 +522,45 @@ def __init__(
self.state = state
self.deps = deps

self._run: GraphRun[StateT, DepsT, RunEndT] | None = None

self._auto_instrument = auto_instrument
self._span: LogfireSpan | None = None

@property
def run(self) -> GraphRun[StateT, DepsT, RunEndT]:
if self._run is None:
raise exceptions.GraphRuntimeError('GraphRunner has not been awaited yet.')
return self._run

def __await__(self) -> Generator[Any, Any, tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]]:
"""Run the graph until it ends, and return the final result."""

async def _run() -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]:
async with self as run:
self._run = run
async for _next_node in run:
pass

assert run.final_result is not None
return run.final_result.data, run.history
return run.final_result, run.history

return _run().__await__()

async def __aenter__(self) -> GraphRun[StateT, DepsT, RunEndT]:
if self._run is not None:
raise exceptions.GraphRuntimeError('A GraphRunner can only start a GraphRun once.')

if self._auto_instrument:
self._span = logfire_api.span('run graph {graph.name}', graph=self.graph)
self._span.__enter__()

return GraphRun(self.graph, self.first_node, history=self.history, state=self.state, deps=self.deps)
self._run = run = GraphRun(self.graph, self.first_node, history=self.history, state=self.state, deps=self.deps)
return run

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._span is not None:
self._span.__exit__(exc_type, exc_val, exc_tb)
self._span = None
self._span = None # make it more obvious if you try to use it after exiting


class GraphRun(Generic[StateT, DepsT, RunEndT]):
Expand All @@ -562,17 +574,27 @@ def __init__(
graph: Graph[StateT, DepsT, RunEndT],
next_node: BaseNode[StateT, DepsT, RunEndT],
*,
history: list[HistoryStep[StateT, RunEndT]] | None = None,
state: StateT = None,
deps: DepsT = None,
history: list[HistoryStep[StateT, RunEndT]],
state: StateT,
deps: DepsT,
):
self.graph = graph
self.next_node = next_node
self.history = history
self.state = state
self.deps = deps

self.history = history or []
self.final_result: End[RunEndT] | None = None
self._final_result: End[RunEndT] | None = None

@property
def is_ended(self):
return self._final_result is not None

@property
def final_result(self) -> RunEndT:
if self._final_result is None:
raise exceptions.GraphRuntimeError('GraphRun has not ended yet.')
return self._final_result.data

async def next(
self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T]
Expand All @@ -585,7 +607,7 @@ async def next(
next_node = await self.graph.next(node, history, state=state, deps=deps, infer_name=False)

if isinstance(next_node, End):
self.final_result = next_node
self._final_result = next_node
else:
self.next_node = next_node
return next_node
Expand All @@ -595,6 +617,6 @@ def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunE

async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]:
"""Use the last returned node as the input to `Graph.next`."""
if self.final_result:
if self._final_result:
raise StopAsyncIteration
return await self.next(self.next_node)

0 comments on commit f504858

Please sign in to comment.