Skip to content

Commit

Permalink
Use raw OTel and actual event loggers in InstrumentedModel (#945)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexmojaki authored Feb 21, 2025
1 parent dfc919c commit 4a472bd
Show file tree
Hide file tree
Showing 4 changed files with 291 additions and 322 deletions.
59 changes: 42 additions & 17 deletions pydantic_ai_slim/pydantic_ai/models/instrumented.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

from collections.abc import AsyncIterator
from collections.abc import AsyncIterator, Iterator
from contextlib import asynccontextmanager, contextmanager
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Literal
from typing import Any, Callable, Literal

import logfire_api
from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider

from ..messages import (
ModelMessage,
Expand All @@ -22,7 +24,7 @@
)
from ..settings import ModelSettings
from ..usage import Usage
from . import ModelRequestParameters, StreamedResponse
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse
from .wrapper import WrapperModel

MODEL_SETTING_ATTRIBUTES: tuple[
Expand Down Expand Up @@ -51,10 +53,33 @@
class InstrumentedModel(WrapperModel):
"""Model which is instrumented with logfire."""

logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE
tracer: Tracer = field(repr=False)
event_logger: EventLogger = field(repr=False)

def __post_init__(self):
self.logfire_instance = self.logfire_instance.with_settings(custom_scope_suffix='pydantic_ai')
def __init__(
self,
wrapped: Model | KnownModelName,
tracer_provider: TracerProvider | None = None,
event_logger_provider: EventLoggerProvider | None = None,
):
super().__init__(wrapped)
tracer_provider = tracer_provider or get_tracer_provider()
event_logger_provider = event_logger_provider or get_event_logger_provider()
self.tracer = tracer_provider.get_tracer('pydantic-ai')
self.event_logger = event_logger_provider.get_event_logger('pydantic-ai')

@classmethod
def from_logfire(
cls,
wrapped: Model | KnownModelName,
logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
) -> InstrumentedModel:
if hasattr(logfire_instance.config, 'get_event_logger_provider'):
event_provider = logfire_instance.config.get_event_logger_provider()
else:
event_provider = None
tracer_provider = logfire_instance.config.get_tracer_provider()
return cls(wrapped, tracer_provider, event_provider)

async def request(
self,
Expand Down Expand Up @@ -90,7 +115,7 @@ def _instrument(
self,
messages: list[ModelMessage],
model_settings: ModelSettings | None,
):
) -> Iterator[Callable[[ModelResponse, Usage], None]]:
operation = 'chat'
model_name = self.model_name
span_name = f'{operation} {model_name}'
Expand All @@ -114,7 +139,7 @@ def _instrument(

emit_event = partial(self._emit_event, system)

with self.logfire_instance.span(span_name, **attributes) as span:
with self.tracer.start_as_current_span(span_name, attributes=attributes) as span:
if span.is_recording():
for message in messages:
if isinstance(message, ModelRequest):
Expand Down Expand Up @@ -157,27 +182,27 @@ def finish(response: ModelResponse, usage: Usage):
yield finish

def _emit_event(self, system: str, event_name: str, body: dict[str, Any]) -> None:
self.logfire_instance.info(event_name, **{'gen_ai.system': system}, **body)
self.event_logger.emit(Event(event_name, body=body, attributes={'gen_ai.system': system}))


def _request_part_body(part: ModelRequestPart) -> tuple[str, dict[str, Any]]:
if isinstance(part, SystemPromptPart):
return 'gen_ai.system.message', {'content': part.content}
return 'gen_ai.system.message', {'content': part.content, 'role': 'system'}
elif isinstance(part, UserPromptPart):
return 'gen_ai.user.message', {'content': part.content}
return 'gen_ai.user.message', {'content': part.content, 'role': 'user'}
elif isinstance(part, ToolReturnPart):
return 'gen_ai.tool.message', {'content': part.content, 'id': part.tool_call_id}
return 'gen_ai.tool.message', {'content': part.content, 'role': 'tool', 'id': part.tool_call_id}
elif isinstance(part, RetryPromptPart):
if part.tool_name is None:
return 'gen_ai.user.message', {'content': part.model_response()}
return 'gen_ai.user.message', {'content': part.model_response(), 'role': 'user'}
else:
return 'gen_ai.tool.message', {'content': part.model_response(), 'id': part.tool_call_id}
return 'gen_ai.tool.message', {'content': part.model_response(), 'role': 'tool', 'id': part.tool_call_id}
else:
return '', {}


def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
body: dict[str, Any] = {}
body: dict[str, Any] = {'role': 'assistant'}
result = [body]
for part in message.parts:
if isinstance(part, ToolCallPart):
Expand All @@ -193,7 +218,7 @@ def _response_bodies(message: ModelResponse) -> list[dict[str, Any]]:
)
elif isinstance(part, TextPart):
if body.get('content'):
body = {}
body = {'role': 'assistant'}
result.append(body)
body['content'] = part.content

Expand Down
7 changes: 5 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@
from ..messages import ModelMessage, ModelResponse
from ..settings import ModelSettings
from ..usage import Usage
from . import Model, ModelRequestParameters, StreamedResponse
from . import KnownModelName, Model, ModelRequestParameters, StreamedResponse, infer_model


@dataclass
@dataclass(init=False)
class WrapperModel(Model):
"""Model which wraps another model."""

wrapped: Model

def __init__(self, wrapped: Model | KnownModelName):
self.wrapped = infer_model(wrapped)

async def request(self, *args: Any, **kwargs: Any) -> tuple[ModelResponse, Usage]:
return await self.wrapped.request(*args, **kwargs)

Expand Down
Loading

0 comments on commit 4a472bd

Please sign in to comment.