From d46ea2d6f60ea7123a78b52ed8ca9e9cf15ad9f5 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Mon, 3 Feb 2025 15:13:51 -0700 Subject: [PATCH] Make graph.run(...) return an instance of GraphRun --- docs/graph.md | 10 +- pydantic_ai_slim/pydantic_ai/agent.py | 3 +- pydantic_graph/README.md | 6 +- pydantic_graph/pydantic_graph/__init__.py | 3 +- pydantic_graph/pydantic_graph/graph.py | 138 +++++++++------------- tests/graph/test_graph.py | 20 ++-- tests/graph/test_history.py | 10 +- tests/graph/test_mermaid.py | 6 +- tests/graph/test_state.py | 6 +- tests/typed_graph.py | 6 +- 10 files changed, 90 insertions(+), 118 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index 8fea97a15..9db4cbffe 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -156,11 +156,11 @@ class Increment(BaseNode): # (2)! fives_graph = Graph(nodes=[DivisibleBy5, Increment]) # (3)! -result, history = fives_graph.run_sync(DivisibleBy5(4)) # (4)! -print(result) +graph_run = fives_graph.run_sync(DivisibleBy5(4)) # (4)! +print(graph_run.result) #> 5 # the full history is quite verbose (see below), so we'll just print the summary -print([item.data_snapshot() for item in history]) +print([item.data_snapshot() for item in graph_run.history]) #> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] ``` @@ -464,8 +464,8 @@ async def main(): ) state = State(user) feedback_graph = Graph(nodes=(WriteEmail, Feedback)) - email, _ = await feedback_graph.run(WriteEmail(), state=state) - print(email) + graph_run = await feedback_graph.run(WriteEmail(), state=state) + print(graph_run.result) """ Email( subject='Welcome to our tech blog!', diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 32854ad8e..bfeb9850a 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -337,12 +337,13 @@ async def main(): ) # Actually run - end_result, _ = await graph.run( + graph_run = await graph.run( start_node, state=state, deps=graph_deps, infer_name=False, ) + end_result = graph_run.result # Build final run result # We don't do any advanced checking if the data is actually from a final result or not diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md index 15a4062e0..3e4ffb24d 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -50,10 +50,10 @@ class Increment(BaseNode): fives_graph = Graph(nodes=[DivisibleBy5, Increment]) -result, history = fives_graph.run_sync(DivisibleBy5(4)) -print(result) +graph_run = fives_graph.run_sync(DivisibleBy5(4)) +print(graph_run.result) #> 5 # the full history is quite verbose (see below), so we'll just print the summary -print([item.data_snapshot() for item in history]) +print([item.data_snapshot() for item in graph_run.history]) #> [DivisibleBy5(foo=4), Increment(foo=4), DivisibleBy5(foo=5), End(data=5)] ``` diff --git a/pydantic_graph/pydantic_graph/__init__.py b/pydantic_graph/pydantic_graph/__init__.py index 84f05a391..f5f2a01c0 100644 --- a/pydantic_graph/pydantic_graph/__init__.py +++ b/pydantic_graph/pydantic_graph/__init__.py @@ -1,12 +1,11 @@ from .exceptions import GraphRuntimeError, GraphSetupError -from .graph import Graph, GraphRun, GraphRunner +from .graph import Graph, GraphRun from .nodes import BaseNode, Edge, End, GraphRunContext from .state import EndStep, HistoryStep, NodeStep __all__ = ( 'Graph', 'GraphRun', - 'GraphRunner', 'BaseNode', 'End', 'GraphRunContext', diff --git a/pydantic_graph/pydantic_graph/graph.py b/pydantic_graph/pydantic_graph/graph.py index 7656e6cc3..9f6b37222 100644 --- a/pydantic_graph/pydantic_graph/graph.py +++ b/pydantic_graph/pydantic_graph/graph.py @@ -30,7 +30,7 @@ logfire._internal.stack_info.NON_USER_CODE_PREFIXES += (str(Path(__file__).parent.absolute()),) -__all__ = ('Graph', 'GraphRun', 'GraphRunner') +__all__ = ('Graph', 'GraphRun') _logfire = logfire_api.Logfire(otel_scope='pydantic-graph') @@ -133,7 +133,7 @@ def run( state: StateT = None, deps: DepsT = None, infer_name: bool = True, - ) -> GraphRunner[StateT, DepsT, T]: + ) -> GraphRun[StateT, DepsT, T]: """Run the graph from a starting node until it ends. Args: @@ -153,24 +153,24 @@ def run( async def main(): state = MyState(1) - _, history = await never_42_graph.run(Increment(), state=state) + graph_run = await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=2) - print(len(history)) + print(len(graph_run.history)) #> 3 state = MyState(41) - _, history = await never_42_graph.run(Increment(), state=state) + graph_run = await never_42_graph.run(Increment(), state=state) print(state) #> MyState(number=43) - print(len(history)) + print(len(graph_run.history)) #> 5 ``` """ if infer_name and self.name is None: self._infer_name(inspect.currentframe()) - return GraphRunner[StateT, DepsT, T]( + return GraphRun[StateT, DepsT, T]( self, start_node, history=[], state=state, deps=deps, auto_instrument=self._auto_instrument ) @@ -181,7 +181,7 @@ def run_sync( state: StateT = None, deps: DepsT = None, infer_name: bool = True, - ) -> tuple[T, list[HistoryStep[StateT, T]]]: + ) -> GraphRun[StateT, DepsT, T]: """Run the graph synchronously. This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`. @@ -499,11 +499,10 @@ def _infer_name(self, function_frame: types.FrameType | None) -> None: return -class GraphRunner(Generic[StateT, DepsT, RunEndT]): - """An object that can be awaited to perform a graph run. +class GraphRun(Generic[StateT, DepsT, RunEndT]): + """A stateful run of a graph. - This object can also be used as a contextmanager to get a handle to a specific graph run, - allowing you to iterate over nodes, and possibly perform modifications to the nodes as they are run. + After being entered, can be used like an async generator to listen to / modify nodes as the run is executed. """ def __init__( @@ -517,84 +516,25 @@ def __init__( auto_instrument: bool, ): self.graph = graph - self.first_node = first_node self.history = history self.state = state self.deps = deps - - self._run: GraphRun[StateT, DepsT, RunEndT] | None = None - self._auto_instrument = auto_instrument - self._span: LogfireSpan | None = None - - @property - def run(self) -> GraphRun[StateT, DepsT, RunEndT]: - if self._run is None: - raise exceptions.GraphRuntimeError('GraphRunner has not been awaited yet.') - return self._run - - def __await__(self) -> Generator[Any, Any, tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]]: - """Run the graph until it ends, and return the final result.""" - - async def _run() -> tuple[RunEndT, list[HistoryStep[StateT, RunEndT]]]: - async with self as run: - self._run = run - async for _next_node in run: - pass - - return run.final_result, run.history - - return _run().__await__() - - async def __aenter__(self) -> GraphRun[StateT, DepsT, RunEndT]: - if self._run is not None: - raise exceptions.GraphRuntimeError('A GraphRunner can only start a GraphRun once.') - - if self._auto_instrument: - self._span = logfire_api.span('run graph {graph.name}', graph=self.graph) - self._span.__enter__() - - self._run = run = GraphRun(self.graph, self.first_node, history=self.history, state=self.state, deps=self.deps) - return run - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - if self._span is not None: - self._span.__exit__(exc_type, exc_val, exc_tb) - self._span = None # make it more obvious if you try to use it after exiting - -class GraphRun(Generic[StateT, DepsT, RunEndT]): - """A stateful run of a graph. - - Can be used like an async generator to listen to / modify nodes as the run is executed. - """ - - def __init__( - self, - graph: Graph[StateT, DepsT, RunEndT], - next_node: BaseNode[StateT, DepsT, RunEndT], - *, - history: list[HistoryStep[StateT, RunEndT]], - state: StateT, - deps: DepsT, - ): - self.graph = graph - self.next_node = next_node - self.history = history - self.state = state - self.deps = deps - - self._final_result: End[RunEndT] | None = None + self._next_node = first_node + self._started: bool = False + self._result: End[RunEndT] | None = None + self._span: LogfireSpan | None = None @property def is_ended(self): - return self._final_result is not None + return self._result is not None @property - def final_result(self) -> RunEndT: - if self._final_result is None: + def result(self) -> RunEndT: + if self._result is None: raise exceptions.GraphRuntimeError('GraphRun has not ended yet.') - return self._final_result.data + return self._result.data async def next( self: GraphRun[StateT, DepsT, T], node: BaseNode[StateT, DepsT, T] @@ -607,16 +547,48 @@ async def next( next_node = await self.graph.next(node, history, state=state, deps=deps, infer_name=False) if isinstance(next_node, End): - self._final_result = next_node + self._result = next_node else: - self.next_node = next_node + self._next_node = next_node return next_node + def __await__(self) -> Generator[Any, Any, typing_extensions.Self]: + """Run the graph until it ends, and return the final result.""" + + async def _run() -> typing_extensions.Self: + with self: + async for _next_node in self: + pass + + return self + + return _run().__await__() + + def __enter__(self) -> typing_extensions.Self: + if self._started: + raise exceptions.GraphRuntimeError('A GraphRun can only be started once.') + + if self._auto_instrument: + self._span = logfire_api.span('run graph {graph.name}', graph=self.graph) + self._span.__enter__() + + self._started = True + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._span is not None: + self._span.__exit__(exc_type, exc_val, exc_tb) + self._span = None # make it more obvious if you try to use it after exiting + def __aiter__(self) -> AsyncIterator[BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]]: return self async def __anext__(self) -> BaseNode[StateT, DepsT, RunEndT] | End[RunEndT]: """Use the last returned node as the input to `Graph.next`.""" - if self._final_result: + if self._result: raise StopAsyncIteration - return await self.next(self.next_node) + if not self._started: + raise exceptions.GraphRuntimeError( + 'You must enter the GraphRun as a contextmanager before you can iterate over it.' + ) + return await self.next(self._next_node) diff --git a/tests/graph/test_graph.py b/tests/graph/test_graph.py index ebd254a37..4668bc2c6 100644 --- a/tests/graph/test_graph.py +++ b/tests/graph/test_graph.py @@ -57,11 +57,11 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # assert my_graph.name is None assert my_graph._get_state_type() is type(None) assert my_graph._get_run_end_type() is int - result, history = await my_graph.run(Float2String(3.14)) + graph_run = await my_graph.run(Float2String(3.14)) # len('3.14') * 2 == 8 - assert result == 8 + assert graph_run.result == 8 assert my_graph.name == 'my_graph' - assert history == snapshot( + assert graph_run.history == snapshot( [ NodeStep( state=None, @@ -84,10 +84,10 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # EndStep(result=End(data=8), ts=IsNow(tz=timezone.utc)), ] ) - result, history = await my_graph.run(Float2String(3.14159)) + graph_run = await my_graph.run(Float2String(3.14159)) # len('3.14159') == 7, 21 * 2 == 42 - assert result == 42 - assert history == snapshot( + assert graph_run.result == 42 + assert graph_run.history == snapshot( [ NodeStep( state=None, @@ -122,7 +122,7 @@ async def run(self, ctx: GraphRunContext) -> Union[String2Length, End[int]]: # EndStep(result=End(data=42), ts=IsNow(tz=timezone.utc)), ] ) - assert [e.data_snapshot() for e in history] == snapshot( + assert [e.data_snapshot() for e in graph_run.history] == snapshot( [ Float2String(input_data=3.14159), String2Length(input_data='3.14159'), @@ -320,10 +320,10 @@ async def run(self, ctx: GraphRunContext[None, Deps]) -> End[int]: return End(123) g = Graph(nodes=(Foo, Bar)) - result, history = await g.run(Foo(), deps=Deps(1, 2)) + graph_run = await g.run(Foo(), deps=Deps(1, 2)) - assert result == 123 - assert history == snapshot( + assert graph_run.result == 123 + assert graph_run.history == snapshot( [ NodeStep(state=None, node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), NodeStep(state=None, node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), diff --git a/tests/graph/test_history.py b/tests/graph/test_history.py index 2508a5347..bcd8dca19 100644 --- a/tests/graph/test_history.py +++ b/tests/graph/test_history.py @@ -46,16 +46,16 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[int]: ], ) async def test_dump_load_history(graph: Graph[MyState, None, int]): - result, history = await graph.run(Foo(), state=MyState(1, '')) - assert result == snapshot(4) - assert history == snapshot( + graph_run = await graph.run(Foo(), state=MyState(1, '')) + assert graph_run.result == snapshot(4) + assert graph_run.history == snapshot( [ NodeStep(state=MyState(x=2, y=''), node=Foo(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), NodeStep(state=MyState(x=2, y='y'), node=Bar(), start_ts=IsNow(tz=timezone.utc), duration=IsFloat()), EndStep(result=End(4), ts=IsNow(tz=timezone.utc)), ] ) - history_json = graph.dump_history(history) + history_json = graph.dump_history(graph_run.history) assert json.loads(history_json) == snapshot( [ { @@ -76,7 +76,7 @@ async def test_dump_load_history(graph: Graph[MyState, None, int]): ] ) history_loaded = graph.load_history(history_json) - assert history == history_loaded + assert graph_run.history == history_loaded custom_history = [ { diff --git a/tests/graph/test_mermaid.py b/tests/graph/test_mermaid.py index 9f76d93cd..041fe6027 100644 --- a/tests/graph/test_mermaid.py +++ b/tests/graph/test_mermaid.py @@ -58,9 +58,9 @@ async def run(self, ctx: GraphRunContext) -> Annotated[End[None], Edge(label='eg async def test_run_graph(): - result, history = await graph1.run(Foo()) - assert result is None - assert history == snapshot( + graph_run = await graph1.run(Foo()) + assert graph_run.result is None + assert graph_run.history == snapshot( [ NodeStep( state=None, diff --git a/tests/graph/test_state.py b/tests/graph/test_state.py index fbb570cf0..8c59667ae 100644 --- a/tests/graph/test_state.py +++ b/tests/graph/test_state.py @@ -36,9 +36,9 @@ async def run(self, ctx: GraphRunContext[MyState]) -> End[str]: assert graph._get_state_type() is MyState assert graph._get_run_end_type() is str state = MyState(1, '') - result, history = await graph.run(Foo(), state=state) - assert result == snapshot('x=2 y=y') - assert history == snapshot( + graph_run = await graph.run(Foo(), state=state) + assert graph_run.result == snapshot('x=2 y=y') + assert graph_run.history == snapshot( [ NodeStep( state=MyState(x=2, y=''), diff --git a/tests/typed_graph.py b/tests/typed_graph.py index d0b6a02b7..deba4dd45 100644 --- a/tests/typed_graph.py +++ b/tests/typed_graph.py @@ -109,6 +109,6 @@ def run_g5() -> None: g5.run_sync(A()) # pyright: ignore[reportArgumentType] g5.run_sync(A(), state=MyState(x=1)) # pyright: ignore[reportArgumentType] g5.run_sync(A(), deps=MyDeps(y='y')) # pyright: ignore[reportArgumentType] - answer, history = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) - assert_type(answer, int) - assert_type(history, list[HistoryStep[MyState, int]]) + graph_run = g5.run_sync(A(), state=MyState(x=1), deps=MyDeps(y='y')) + assert_type(graph_run.result, int) + assert_type(graph_run.history, list[HistoryStep[MyState, int]])