Use raw OTel and actual event loggers in InstrumentedModel #945

Merged: 12 commits, Feb 21, 2025
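This PR swaps `InstrumentedModel` from the Logfire span/event API over to raw OpenTelemetry: spans go through an OTel `Tracer` and message events through an OTel `EventLogger`. A minimal usage sketch of the resulting API; the model name is illustrative, and the first call assumes global OTel providers are already configured:

```python
import logfire_api

from pydantic_ai.models.instrumented import InstrumentedModel

# Default: picks up the globally registered TracerProvider and EventLoggerProvider.
model = InstrumentedModel('openai:gpt-4o')

# Alternative: derive both providers from a configured Logfire instance.
model = InstrumentedModel.from_logfire('openai:gpt-4o', logfire_api.DEFAULT_LOGFIRE_INSTANCE)
```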
59 changes: 42 additions & 17 deletions pydantic_ai_slim/pydantic_ai/models/instrumented.py
@@ -1,12 +1,14 @@
 from __future__ import annotations
 
-from collections.abc import AsyncIterator
+from collections.abc import AsyncIterator, Iterator
 from contextlib import asynccontextmanager, contextmanager
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Literal
+from typing import Any, Callable, Literal
 
 import logfire_api
+from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
+from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider
 
 from ..messages import (
     ModelMessage,
@@ -22,7 +24,7 @@
 )
 from ..settings import ModelSettings
 from ..usage import Usage
-from . import ModelRequestParameters, StreamedResponse
+from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
 from .wrapper import WrapperModel
 
 MODEL_SETTING_ATTRIBUTES: tuple[
@@ -51,10 +53,33 @@
 class InstrumentedModel(WrapperModel):
     """Model which is instrumented with logfire."""
 
-    logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE
+    tracer: Tracer = field(repr=False)
+    event_logger: EventLogger = field(repr=False)
 
-    def __post_init__(self):
-        self.logfire_instance = self.logfire_instance.with_settings(custom_scope_suffix='pydantic_ai')
+    def __init__(
+        self,
+        wrapped: Model | KnownModelName,
+        tracer_provider: TracerProvider | None = None,
+        event_logger_provider: EventLoggerProvider | None = None,
+    ):
+        super().__init__(wrapped)
+        tracer_provider = tracer_provider or get_tracer_provider()
+        event_logger_provider = event_logger_provider or get_event_logger_provider()
+        self.tracer = tracer_provider.get_tracer('pydantic-ai')
+        self.event_logger = event_logger_provider.get_event_logger('pydantic-ai')
+
+    @classmethod
+    def from_logfire(
+        cls,
+        wrapped: Model | KnownModelName,
+        logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
+    ) -> InstrumentedModel:
+        if hasattr(logfire_instance.config, 'get_event_logger_provider'):
+            event_provider = logfire_instance.config.get_event_logger_provider()
+        else:
+            event_provider = None
+        tracer_provider = logfire_instance.config.get_tracer_provider()
+        return cls(wrapped, tracer_provider, event_provider)
 
     async def request(
         self,
@@ -90,7 +115,7 @@ def _instrument(
         self,
         messages: list[ModelMessage],
         model_settings: ModelSettings | None,
-    ):
+    ) -> Iterator[Callable[[ModelResponse, Usage], None]]:
         operation = 'chat'
         model_name = self.model_name
         span_name = f'{operation} {model_name}'
@@ -114,7 +139,7 @@
 
         emit_event = partial(self._emit_event, system)
 
-        with self.logfire_instance.span(span_name, **attributes) as span:
+        with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
             if span.is_recording():
                 for message in messages:
                     if isinstance(message, ModelRequest):
@@ -157,27 +182,27 @@ def finish(response: ModelResponse, usage: Usage):
             yield finish
 
     def _emit_event(self, system: str, event_name: str, body: dict[str, Any]) -> None:
-        self.logfire_instance.info(event_name, **{'gen_ai.system': system}, **body)
+        self.event_logger.emit(Event(event_name, body=body, attributes={'gen_ai.system': system}))
 
 
 def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
     if isinstance(part, SystemPromptPart):
-        return 'gen_ai.system.message', {'content': part.content}
+        return 'gen_ai.system.message', {'content': part.content, 'role': 'system'}
     elif isinstance(part, UserPromptPart):
-        return 'gen_ai.user.message', {'content': part.content}
+        return 'gen_ai.user.message', {'content': part.content, 'role': 'user'}
     elif isinstance(part, ToolReturnPart):
-        return 'gen_ai.tool.message', {'content': part.content, 'id': part.tool_call_id}
+        return 'gen_ai.tool.message', {'content': part.content, 'role': 'tool', 'id': part.tool_call_id}
     elif isinstance(part, RetryPromptPart):
         if part.tool_name is None:
-            return 'gen_ai.user.message', {'content': part.model_response()}
+            return 'gen_ai.user.message', {'content': part.model_response(), 'role': 'user'}
         else:
-            return 'gen_ai.tool.message', {'content': part.model_response(), 'id': part.tool_call_id}
+            return 'gen_ai.tool.message', {'content': part.model_response(), 'role': 'tool', 'id': part.tool_call_id}
     else:
         return '', {}
 
 
 def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
-    body: dict[str, Any] = {}
+    body: dict[str, Any] = {'role': 'assistant'}
     result = [body]
     for part in message.parts:
         if isinstance(part, ToolCallPart):
@@ -193,7 +218,7 @@ def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
             )
         elif isinstance(part, TextPart):
             if body.get('content'):
-                body = {}
+                body = {'role': 'assistant'}
                 result.append(body)
             body['content'] = part.content
 
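On the event side, `_emit_event` now wraps each message body in an `opentelemetry._events.Event` carrying a `gen_ai.system` attribute, instead of calling `logfire_instance.info`. A standalone sketch of the equivalent emission; the scope name matches the diff, while the payload values are illustrative:

```python
from opentelemetry._events import Event, get_event_logger_provider

event_logger = get_event_logger_provider().get_event_logger('pydantic-ai')

# Roughly what InstrumentedModel._emit_event does for a user prompt.
# Note the 'role' key this PR adds to every message body.
event_logger.emit(
    Event(
        'gen_ai.user.message',
        body={'content': 'What is the capital of France?', 'role': 'user'},
        attributes={'gen_ai.system': 'openai'},
    )
)
```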
7 changes: 5 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/wrapper.py
@@ -8,15 +8,18 @@
 from ..messages import ModelMessage, ModelResponse
 from ..settings import ModelSettings
 from ..usage import Usage
-from . import Model, ModelRequestParameters, StreamedResponse
+from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model
 
 
-@dataclass
+@dataclass(init=False)
 class WrapperModel(Model):
     """Model which wraps another model."""
 
     wrapped: Model
 
+    def __init__(self, wrapped: Model | KnownModelName):
+        self.wrapped = infer_model(wrapped)
+
     async def request(self, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
         return await self.wrapped.request(*args, **kwargs)
 
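With `@dataclass(init=False)` and the explicit `__init__`, `WrapperModel` (and any subclass such as `InstrumentedModel`) now accepts either a `Model` instance or a `KnownModelName` string, resolved via `infer_model`. A small sketch; the model name is illustrative:

```python
from pydantic_ai.models.wrapper import WrapperModel

# A known-model-name string is resolved to a concrete Model before being stored.
wrapper = WrapperModel('openai:gpt-4o')

# request/request_stream simply delegate to the wrapped model, e.g.:
# response, usage = await wrapper.request(messages, model_settings, model_request_parameters)
```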