Skip to content

Commit

Permalink
handle older logfire, make providers optional
Browse files Browse the repository at this point in the history
  • Loading branch information
alexmojaki committed Feb 20, 2025
1 parent e4b9716 commit 8890086
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions pydantic_ai_slim/pydantic_ai/models/instrumented.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from typing import Any, Callable, Literal

import logfire_api
from opentelemetry._events import Event, EventLogger, EventLoggerProvider
from opentelemetry.trace import Tracer, TracerProvider
from opentelemetry._events import Event, EventLogger, EventLoggerProvider, get_event_logger_provider
from opentelemetry.trace import Tracer, TracerProvider, get_tracer_provider

from ..messages import (
ModelMessage,
Expand Down Expand Up @@ -59,10 +59,12 @@ class InstrumentedModel(WrapperModel):
def __init__(
self,
wrapped: Model | KnownModelName,
tracer_provider: TracerProvider,
event_logger_provider: EventLoggerProvider,
tracer_provider: TracerProvider | None = None,
event_logger_provider: EventLoggerProvider | None = None,
):
super().__init__(wrapped)
tracer_provider = tracer_provider or get_tracer_provider()
event_logger_provider = event_logger_provider or get_event_logger_provider()
self.tracer = tracer_provider.get_tracer('pydantic-ai')
self.event_logger = event_logger_provider.get_event_logger('pydantic-ai')

Expand All @@ -72,11 +74,12 @@ def from_logfire(
wrapped: Model | KnownModelName,
logfire_instance: logfire_api.Logfire = logfire_api.DEFAULT_LOGFIRE_INSTANCE,
) -> InstrumentedModel:
return cls(
wrapped,
logfire_instance.config.get_tracer_provider(),
logfire_instance.config.get_event_logger_provider(),
)
if hasattr(logfire_instance.config, 'get_event_logger_provider'):
event_provider = logfire_instance.config.get_event_logger_provider()
else:
event_provider = None
tracer_provider = logfire_instance.config.get_tracer_provider()
return cls(wrapped, tracer_provider, event_provider)

async def request(
self,
Expand Down

0 comments on commit 8890086

Please sign in to comment.