Skip to content

Commit

Permalink
add UA to requests (#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin authored Nov 18, 2024
1 parent 6b1a111 commit 3c3041f
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 2 deletions.
13 changes: 12 additions & 1 deletion pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,4 +261,15 @@ def cached_async_http_client(timeout: int = 600, connect: int = 5) -> httpx.Asyn
The default timeouts match those of OpenAI,
see <https://github.com/openai/openai-python/blob/v1.54.4/src/openai/_constants.py#L9>.
"""
return httpx.AsyncClient(timeout=httpx.Timeout(timeout=timeout, connect=connect))
return httpx.AsyncClient(
timeout=httpx.Timeout(timeout=timeout, connect=connect),
headers={'User-Agent': get_user_agent()},
)


@cache
def get_user_agent() -> str:
"""Get the user agent string for the HTTP client."""
from .. import __version__

return f'pydantic-ai/{__version__}'
2 changes: 2 additions & 0 deletions pydantic_ai/models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
StreamTextResponse,
cached_async_http_client,
check_allow_model_requests,
get_user_agent,
)

GeminiModelName = Literal['gemini-1.5-flash', 'gemini-1.5-flash-8b', 'gemini-1.5-pro', 'gemini-1.0-pro']
Expand Down Expand Up @@ -182,6 +183,7 @@ async def _make_request(self, messages: list[Message], streamed: bool) -> AsyncI
headers = {
'X-Goog-Api-Key': self.api_key,
'Content-Type': 'application/json',
'User-Agent': get_user_agent(),
}
url = self.url_template.format(
model=self.model_name, function='streamGenerateContent' if streamed else 'generateContent'
Expand Down
5 changes: 4 additions & 1 deletion tests/models/test_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,9 +344,12 @@ async def get_gemini_client(
def create_client(response_or_list: ResOrList) -> httpx.AsyncClient:
index = 0

def handler(_request: httpx.Request) -> httpx.Response:
def handler(request: httpx.Request) -> httpx.Response:
nonlocal index

ua = request.headers.get('User-Agent')
assert isinstance(ua, str) and ua.startswith('pydantic-ai')

if isinstance(response_or_list, Sequence):
response = response_or_list[index]
index += 1
Expand Down

0 comments on commit 3c3041f

Please sign in to comment.