-
Notifications
You must be signed in to change notification settings - Fork 521
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1f7a174
commit 2350965
Showing
28 changed files
with
1,322 additions
and
384 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
from dataclasses import dataclass | ||
|
||
from pydantic_ai import Agent | ||
from devtools import debug | ||
|
||
system_prompt = """\ | ||
Given the following PostgreSQL table of records, your job is to write a SQL query that suits the user's request. | ||
CREATE TABLE records AS ( | ||
start_timestamp timestamp with time zone, | ||
created_at timestamp with time zone, | ||
trace_id text, | ||
span_id text, | ||
parent_span_id text, | ||
kind span_kind, | ||
end_timestamp timestamp with time zone, | ||
level smallint, | ||
span_name text, | ||
message text, | ||
attributes_json_schema text, | ||
attributes jsonb, | ||
tags text[], | ||
otel_links jsonb, | ||
otel_events jsonb, | ||
is_exception boolean, | ||
otel_status_code status_code, | ||
otel_status_message text, | ||
otel_scope_name text, | ||
otel_scope_version text, | ||
otel_scope_attributes jsonb, | ||
service_namespace text, | ||
service_name text, | ||
service_version text, | ||
service_instance_id text, | ||
process_pid integer | ||
); | ||
today's date = 2024-10-09 | ||
Example | ||
request: show me records where foobar is false | ||
response: SELECT * FROM records WHERE attributes->>'foobar' = false' | ||
Example | ||
request: show me records from yesterday | ||
response: SELECT * FROM records WHERE start_timestamp::date > CURRENT_TIMESTAMP - INTERVAL '1 day' | ||
Example | ||
request: show me error records with the tag "foobar" | ||
response: SELECT * FROM records WHERE level = 'error' and 'foobar' = ANY(tags) | ||
""" | ||
|
||
|
||
@dataclass | ||
class Response: | ||
sql_query: str | ||
|
||
|
||
agent = Agent('gemini-1.5-flash', result_type=Response, system_prompt=system_prompt, deps=None) | ||
|
||
|
||
if __name__ == '__main__': | ||
with debug.timer('SQL Generation'): | ||
result = agent.run_sync('show me logs from yesterday, with level "error"') | ||
debug(result.response.sql_query) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from .agent import Agent | ||
from .call import CallContext, Retry | ||
from .shared import AgentError, CallContext, Retry, UnexpectedModelBehaviour, UserError | ||
|
||
__all__ = 'Agent', 'CallContext', 'Retry' | ||
__all__ = 'Agent', 'AgentError', 'CallContext', 'Retry', 'UnexpectedModelBehaviour', 'UserError' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
from __future__ import annotations as _annotations | ||
|
||
import inspect | ||
from collections.abc import Awaitable | ||
from dataclasses import dataclass | ||
from typing import Any, Callable, Generic, Union, cast | ||
|
||
from . import _utils, messages | ||
from .result import ResultData | ||
from .shared import AgentDeps, CallContext, Retry | ||
|
||
# A function that always takes `ResultData` and returns `ResultData`, | ||
# but may or maybe not take `CallInfo` as a first argument, and may or may not be async. | ||
# Usage `ResultValidator[AgentDeps, ResultData]` | ||
ResultValidatorFunc = Union[ | ||
Callable[[CallContext[AgentDeps], ResultData], ResultData], | ||
Callable[[CallContext[AgentDeps], ResultData], Awaitable[ResultData]], | ||
Callable[[ResultData], ResultData], | ||
Callable[[ResultData], Awaitable[ResultData]], | ||
] | ||
|
||
|
||
@dataclass | ||
class ResultValidator(Generic[AgentDeps, ResultData]): | ||
function: ResultValidatorFunc[AgentDeps, ResultData] | ||
_takes_ctx: bool = False | ||
_is_async: bool = False | ||
|
||
def __post_init__(self): | ||
self._takes_ctx = len(inspect.signature(self.function).parameters) > 1 | ||
self._is_async = inspect.iscoroutinefunction(self.function) | ||
|
||
async def validate( | ||
self, result: ResultData, deps: AgentDeps, retry: int, tool_call: messages.ToolCall | ||
) -> ResultData: | ||
"""Validate a result but calling the function. | ||
Args: | ||
result: The result data after Pydantic validation the message content. | ||
deps: The agent dependencies. | ||
retry: The current retry number. | ||
tool_call: The original tool call message. | ||
Returns: | ||
Result of either the validated result data (ok) or a retry message (Err). | ||
""" | ||
if self._takes_ctx: | ||
args = CallContext(deps, retry), result | ||
else: | ||
args = (result,) | ||
|
||
try: | ||
if self._is_async: | ||
function = cast(Callable[[Any], Awaitable[ResultData]], self.function) | ||
result_data = await function(*args) | ||
else: | ||
function = cast(Callable[[Any], ResultData], self.function) | ||
result_data = await _utils.run_in_executor(function, *args) | ||
except Retry as r: | ||
m = messages.ToolRetry( | ||
tool_name=tool_call.tool_name, | ||
content=r.message, | ||
tool_id=tool_call.tool_id, | ||
) | ||
raise ToolRetryError(m) from r | ||
else: | ||
return result_data | ||
|
||
|
||
class ToolRetryError(Exception): | ||
"""Internal exception used to indicate a signal a `ToolRetry` message should be returned to the LLM""" | ||
|
||
def __init__(self, tool_retry: messages.ToolRetry): | ||
self.tool_retry = tool_retry | ||
super().__init__() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.