From f504858cb9eb05a97f208103eeff5f3e45eb076e Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 3 Feb 2025 14:57:46 -0700 Subject: [PATCH] Minor tweak --- pydantic_graph/pydantic_graph/graph.py | 44 +++++++++++++++++++------- 1 file changed, 33 insertions(+), 11 deletions(-) diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index a72cde051..7656e6cc3 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -522,33 +522,45 @@ def __init__( self.state = state self.deps = deps + self._run: GraphRun[StateT, DepsT, RunEndT] | None = None + self._auto_instrument = auto_instrument self._span: LogfireSpan | None = None + @property + def run(self) -> GraphRun[StateT, DepsT, RunEndT]: + if self._run is None: + raise exceptions.GraphRuntimeError('GraphRunner has not been awaited yet.') + return self._run + def __await__(self) -> Generator[Any, Any, tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]]: """Run the graph until it ends, and return the final result.""" async def _run() -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: async with self as run: + self._run = run async for _next_node in run: pass - assert run.final_result is not None - return run.final_result.data, run.history + return run.final_result, run.history return _run().__await__() async def __aenter__(self) -> GraphRun[StateT, DepsT, RunEndT]: + if self._run is not None: + raise exceptions.GraphRuntimeError('A GraphRunner can only start a GraphRun once.') + if self._auto_instrument: self._span = logfire_api.span('run graph {graph.name}', graph=self.graph) self._span.__enter__() - return GraphRun(self.graph, self.first_node, history=self.history, state=self.state, deps=self.deps) + self._run = run = GraphRun(self.graph, self.first_node, history=self.history, state=self.state, deps=self.deps) + return run async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: if self._span is not None: self._span.__exit__(exc_type, exc_val, exc_tb) - self._span = None + self._span = None # make it more obvious if you try to use it after exiting class GraphRun(Generic[StateT, DepsT, RunEndT]): @@ -562,17 +574,27 @@ def __init__( graph: Graph[StateT, DepsT, RunEndT], next_node: BaseNode[StateT, DepsT, RunEndT], *, - history: list[HistoryStep[StateT, RunEndT]] | None = None, - state: StateT = None, - deps: DepsT = None, + history: list[HistoryStep[StateT, RunEndT]], + state: StateT, + deps: DepsT, ): self.graph = graph self.next_node = next_node + self.history = history self.state = state self.deps = deps - self.history = history or [] - self.final_result: End[RunEndT] | None = None + self._final_result: End[RunEndT] | None = None + + @property + def is_ended(self): + return self._final_result is not None + + @property + def final_result(self) -> RunEndT: + if self._final_result is None: + raise exceptions.GraphRuntimeError('GraphRun has not ended yet.') + return self._final_result.data async def next( self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T] @@ -585,7 +607,7 @@ async def next( next_node = await self.graph.next(node, history, state=state, deps=deps, infer_name=False) if isinstance(next_node, End): - self.final_result = next_node + self._final_result = next_node else: self.next_node = next_node return next_node @@ -595,6 +617,6 @@ def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunE async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: """Use the last returned node as the input to `Graph.next`.""" - if self.final_result: + if self._final_result: raise StopAsyncIteration return await self.next(self.next_node)