Skip to content

Commit

Permalink
state persistence
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Feb 19, 2025
1 parent 8fcf8c9 commit b4e2bda
Show file tree
Hide file tree
Showing 16 changed files with 417 additions and 249 deletions.
7 changes: 3 additions & 4 deletions docs/graph.md
Original file line number Diff line number Diff line change
Expand Up @@ -576,26 +576,25 @@ In this example, an AI asks the user a question, the user provides an answer, th

_(This example is complete, it can be run "as is" with Python 3.10+)_


```python {title="ai_q_and_a_run.py" noqa="I001" py="3.10"}
from rich.prompt import Prompt

from pydantic_graph import End, HistoryStep
from pydantic_graph import End, Snapshot

from ai_q_and_a_graph import Ask, question_graph, QuestionState, Answer


async def main():
state = QuestionState() # (1)!
node = Ask() # (2)!
history: list[HistoryStep[QuestionState]] = [] # (3)!
history: list[Snapshot[QuestionState]] = [] # (3)!
while True:
node = await question_graph.next(node, history, state=state) # (4)!
if isinstance(node, Answer):
node.answer = Prompt.ask(node.question) # (5)!
elif isinstance(node, End): # (6)!
print(f'Correct answer! {node.data}')
#> Correct answer! Well done, 1 + 1 = 2
# > Correct answer! Well done, 1 + 1 = 2
print([e.data_snapshot() for e in history])
"""
[
Expand Down
4 changes: 2 additions & 2 deletions examples/pydantic_ai_examples/question_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import logfire
from devtools import debug
from pydantic_graph import BaseNode, Edge, End, Graph, GraphRunContext, HistoryStep
from pydantic_graph import BaseNode, Edge, End, Graph, GraphRunContext, Snapshot

from pydantic_ai import Agent
from pydantic_ai.format_as_xml import format_as_xml
Expand Down Expand Up @@ -116,7 +116,7 @@ async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask:
async def run_as_continuous():
state = QuestionState()
node = Ask()
history: list[HistoryStep[QuestionState, None]] = []
history: list[Snapshot[QuestionState, None]] = []
with logfire.span('run questions graph'):
while True:
node = await question_graph.next(node, history, state=state)
Expand Down
4 changes: 2 additions & 2 deletions pydantic_ai_slim/pydantic_ai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import logfire_api
from typing_extensions import TypeVar, deprecated

from pydantic_graph import Graph, GraphRunContext, HistoryStep
from pydantic_graph import Graph, GraphRunContext, Snapshot
from pydantic_graph.nodes import End

from . import (
Expand Down Expand Up @@ -583,7 +583,7 @@ async def main():

# Actually run
node = start_node
history: list[HistoryStep[_agent_graph.GraphAgentState, RunResultDataT]] = []
history: list[Snapshot[_agent_graph.GraphAgentState, RunResultDataT]] = []
while True:
if isinstance(node, _agent_graph.StreamModelRequestNode):
node = cast(
Expand Down
8 changes: 4 additions & 4 deletions pydantic_graph/pydantic_graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from .exceptions import GraphRuntimeError, GraphSetupError
from .graph import Graph
from .nodes import BaseNode, Edge, End, GraphRunContext
from .state import EndStep, HistoryStep, NodeStep
from .state import EndSnapshot, NodeSnapshot, Snapshot

__all__ = (
'Graph',
'BaseNode',
'End',
'GraphRunContext',
'Edge',
'EndStep',
'HistoryStep',
'NodeStep',
'EndSnapshot',
'Snapshot',
'NodeSnapshot',
'GraphSetupError',
'GraphRuntimeError',
)
5 changes: 0 additions & 5 deletions pydantic_graph/pydantic_graph/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import sys
import types
from datetime import datetime, timezone
from typing import Annotated, Any, TypeVar, Union, get_args, get_origin

import typing_extensions
Expand Down Expand Up @@ -80,10 +79,6 @@ def get_parent_namespace(frame: types.FrameType | None) -> dict[str, Any] | None
return back.f_locals


def now_utc() -> datetime:
return datetime.now(tz=timezone.utc)


class Unset:
"""A singleton to represent an unset value.
Expand Down
124 changes: 58 additions & 66 deletions pydantic_graph/pydantic_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,16 @@
from contextlib import ExitStack
from dataclasses import dataclass, field
from functools import cached_property
from pathlib import Path
from time import perf_counter
from typing import TYPE_CHECKING, Annotated, Any, Callable, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar

import logfire_api
import pydantic
import typing_extensions

from . import _utils, exceptions, mermaid
from .nodes import BaseNode, DepsT, End, GraphRunContext, NodeDef, RunEndT
from .state import EndStep, HistoryStep, NodeStep, StateT, deep_copy_state, nodes_schema_var
from .state import StatePersistence, StateT, node_type_adapter
from .state.memory import LatestMemoryStatePersistence

# while waiting for https://github.com/pydantic/logfire/issues/745
try:
Expand Down Expand Up @@ -84,7 +83,6 @@ async def run(self, ctx: GraphRunContext) -> Increment | End[int]:

name: str | None
node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]]
snapshot_state: Callable[[StateT], StateT]
_state_type: type[StateT] | _utils.Unset = field(repr=False)
_run_end_type: type[RunEndT] | _utils.Unset = field(repr=False)
_auto_instrument: bool = field(repr=False)
Expand All @@ -96,7 +94,6 @@ def __init__(
name: str | None = None,
state_type: type[StateT] | _utils.Unset = _utils.UNSET,
run_end_type: type[RunEndT] | _utils.Unset = _utils.UNSET,
snapshot_state: Callable[[StateT], StateT] = deep_copy_state,
auto_instrument: bool = True,
):
"""Create a graph from a sequence of nodes.
Expand All @@ -108,16 +105,12 @@ def __init__(
on the first call to a graph method.
state_type: The type of the state for the graph, this can generally be inferred from `nodes`.
run_end_type: The type of the result of running the graph, this can generally be inferred from `nodes`.
snapshot_state: A function to snapshot the state of the graph, this is used in
[`NodeStep`][pydantic_graph.state.NodeStep] and [`EndStep`][pydantic_graph.state.EndStep] to record
the state before each step.
auto_instrument: Whether to create a span for the graph run and the execution of each node's run method.
"""
self.name = name
self._state_type = state_type
self._run_end_type = run_end_type
self._auto_instrument = auto_instrument
self.snapshot_state = snapshot_state

parent_namespace = _utils.get_parent_namespace(inspect.currentframe())
self.node_defs: dict[str, NodeDef[StateT, DepsT, RunEndT]] = {}
Expand All @@ -132,15 +125,18 @@ async def run(
*,
state: StateT = None,
deps: DepsT = None,
state_persistence: StatePersistence[StateT, T] | None = None,
infer_name: bool = True,
) -> tuple[T, list[HistoryStep[StateT, T]]]:
) -> T:
"""Run the graph from a starting node until it ends.
Args:
start_node: the first node to run, since the graph definition doesn't define the entry point in the graph,
you need to provide the starting node.
state: The initial state of the graph.
deps: The dependencies of the graph.
state_persistence: State persistence interface, defaults to
[`LatestMemoryStatePersistence`][pydantic_graph.state.memory.LatestMemoryStatePersistence] if `None`.
infer_name: Whether to infer the graph name from the calling frame.
Returns:
Expand Down Expand Up @@ -170,11 +166,16 @@ async def main():
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())

history: list[HistoryStep[StateT, T]] = []
if state_persistence is None:
state_persistence = LatestMemoryStatePersistence()

# have to snapshot state before iterating over nodes, as we'll expect a snapshot in the
# state_persistence soon
await state_persistence.snapshot_node(state, start_node, self._node_type_adapter)

with ExitStack() as stack:
run_span: logfire_api.LogfireSpan | None = None
if self._auto_instrument:
run_span = stack.enter_context(
stack.enter_context(
_logfire.span(
'{graph_name} run {start=}',
graph_name=self.name or 'graph',
Expand All @@ -184,12 +185,9 @@ async def main():

next_node = start_node
while True:
next_node = await self.next(next_node, history, state=state, deps=deps, infer_name=False)
next_node = await self.next(next_node, state_persistence, state=state, deps=deps, infer_name=False)
if isinstance(next_node, End):
history.append(EndStep(result=next_node))
if run_span is not None:
run_span.set_attribute('history', history)
return next_node.data, history
return next_node.data
elif not isinstance(next_node, BaseNode):
if TYPE_CHECKING:
typing_extensions.assert_never(next_node)
Expand All @@ -204,8 +202,9 @@ def run_sync(
*,
state: StateT = None,
deps: DepsT = None,
state_persistence: StatePersistence[StateT, T] | None = None,
infer_name: bool = True,
) -> tuple[T, list[HistoryStep[StateT, T]]]:
) -> T:
"""Run the graph synchronously.
This is a convenience method that wraps [`self.run`][pydantic_graph.Graph.run] with `loop.run_until_complete(...)`.
Expand All @@ -216,6 +215,8 @@ def run_sync(
you need to provide the starting node.
state: The initial state of the graph.
deps: The dependencies of the graph.
state_persistence: State persistence interface, defaults to
[`LatestMemoryStatePersistence`][pydantic_graph.state.memory.LatestMemoryStatePersistence] if `None`.
infer_name: Whether to infer the graph name from the calling frame.
Returns:
Expand All @@ -224,13 +225,13 @@ def run_sync(
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())
return asyncio.get_event_loop().run_until_complete(
self.run(start_node, state=state, deps=deps, infer_name=False)
self.run(start_node, state=state, deps=deps, state_persistence=state_persistence, infer_name=False)
)

async def next(
self: Graph[StateT, DepsT, T],
node: BaseNode[StateT, DepsT, T],
history: list[HistoryStep[StateT, T]],
state_persistence: StatePersistence[StateT, T],
*,
state: StateT = None,
deps: DepsT = None,
Expand All @@ -240,7 +241,7 @@ async def next(
Args:
node: The node to run.
history: The history of the graph run so far. NOTE: this will be mutated to add the new step.
state_persistence: State persistence interface.
state: The current state of the graph.
deps: The dependencies of the graph.
infer_name: Whether to infer the graph name from the calling frame.
Expand All @@ -258,52 +259,33 @@ async def next(
if self._auto_instrument:
stack.enter_context(_logfire.span('run node {node_id}', node_id=node_id, node=node))
ctx = GraphRunContext(state, deps)
start_ts = _utils.now_utc()
start = perf_counter()
next_node = await node.run(ctx)
duration = perf_counter() - start
async with state_persistence.record_run():
next_or_end = await node.run(ctx)

history.append(
NodeStep(state=state, node=node, start_ts=start_ts, duration=duration, snapshot_state=self.snapshot_state)
)
return next_node

def dump_history(
self: Graph[StateT, DepsT, T], history: list[HistoryStep[StateT, T]], *, indent: int | None = None
) -> bytes:
"""Dump the history of a graph run as JSON.
Args:
history: The history of the graph run.
indent: The number of spaces to indent the JSON.
Returns:
The JSON representation of the history.
"""
return self.history_type_adapter.dump_json(history, indent=indent)

def load_history(self, json_bytes: str | bytes | bytearray) -> list[HistoryStep[StateT, RunEndT]]:
"""Load the history of a graph run from JSON.
Args:
json_bytes: The JSON representation of the history.
if isinstance(next_or_end, BaseNode):
await state_persistence.snapshot_node(state, next_or_end, self._node_type_adapter)
else:
await state_persistence.snapshot_end(state, next_or_end, self._end_data_type_adapter)
return next_or_end

Returns:
The history of the graph run.
"""
return self.history_type_adapter.validate_json(json_bytes)
async def next_from_persistence(
self: Graph[StateT, DepsT, T],
state_persistence: StatePersistence[StateT, T],
*,
deps: DepsT = None,
infer_name: bool = True,
) -> BaseNode[StateT, DepsT, Any] | End[T]:
if infer_name and self.name is None:
self._infer_name(inspect.currentframe())

@cached_property
def history_type_adapter(self) -> pydantic.TypeAdapter[list[HistoryStep[StateT, RunEndT]]]:
nodes = [node_def.node for node_def in self.node_defs.values()]
state_t = self._get_state_type()
end_t = self._get_run_end_type()
token = nodes_schema_var.set(nodes)
try:
ta = pydantic.TypeAdapter(list[Annotated[HistoryStep[state_t, end_t], pydantic.Discriminator('kind')]])
finally:
nodes_schema_var.reset(token)
return ta
snapshot = await state_persistence.restore_node_snapshot()
return await self.next(
snapshot.node,
state_persistence,
state=snapshot.state,
deps=deps,
infer_name=False,
)

def mermaid_code(
self,
Expand Down Expand Up @@ -428,6 +410,16 @@ def mermaid_save(
kwargs['title'] = self.name
mermaid.save_image(path, self, **kwargs)

@cached_property
def _node_type_adapter(self) -> pydantic.TypeAdapter[BaseNode[StateT, Any, RunEndT]]:
nodes = [node_def.node for node_def in self.node_defs.values()]
return node_type_adapter(nodes, self._get_state_type(), self._get_run_end_type())

@cached_property
def _end_data_type_adapter(self) -> pydantic.TypeAdapter[RunEndT]:
end_t = self._get_run_end_type()
return pydantic.TypeAdapter(end_t)

def _get_state_type(self) -> type[StateT]:
if _utils.is_set(self._state_type):
return self._state_type
Expand Down
Loading

0 comments on commit b4e2bda

Please sign in to comment.