implement streaming for assistants #215

Merged: 22 commits, Jan 17, 2024
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -22,10 +22,10 @@ classifiers = [
requires-python = ">=3.9"
dependencies = [
"aiofiles",
"anyio",
"emoji",
"fastapi",
"httpx",
"httpx_sse",
"importlib_metadata>=4.6; python_version<'3.10'",
"packaging",
"panel>=1.3.6,<1.4",
@@ -38,6 +38,8 @@ dependencies = [
"questionary",
"rich",
"sqlalchemy>=2",
"sse-starlette",
"starlette",
"tomlkit",
"typer",
"uvicorn",
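The new dependencies split along the streaming path: httpx_sse lets the assistants consume Server-Sent Events from the model providers, while starlette and sse-starlette let the REST API emit them to clients. A minimal sketch of the server-side half, assuming a hypothetical /stream endpoint that is not part of this PR:

import asyncio

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

@app.get("/stream")
async def stream() -> EventSourceResponse:
    async def events():
        # Each yielded dict becomes one SSE event on the wire.
        for chunk in ["Hello", ", ", "world!"]:
            yield {"data": chunk}
            await asyncio.sleep(0.1)  # simulate token-by-token generation

    return EventSourceResponse(events())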
29 changes: 26 additions & 3 deletions ragna/_compat.py
@@ -1,7 +1,16 @@
import builtins
import sys
from typing import Callable, Iterable, Iterator, Mapping, TypeVar

__all__ = ["itertools_pairwise", "importlib_metadata_package_distributions"]
from typing import (
AsyncIterator,
Awaitable,
Callable,
Iterable,
Iterator,
Mapping,
TypeVar,
)

__all__ = ["itertools_pairwise", "importlib_metadata_package_distributions", "anext"]

T = TypeVar("T")

@@ -38,3 +47,17 @@ def _importlib_metadata_package_distributions() -> (


importlib_metadata_package_distributions = _importlib_metadata_package_distributions()


def _anext() -> Callable[[AsyncIterator[T]], Awaitable[T]]:
if sys.version_info[:2] >= (3, 10):
anext = builtins.anext
else:

async def anext(ait: AsyncIterator[T]) -> T:
return await ait.__anext__()

return anext


anext = _anext()
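builtins.anext only exists since Python 3.10, hence the backport above. Illustrative usage with a throwaway generator (not from the PR):

import asyncio

from ragna._compat import anext

async def tokens():
    yield "Hello"
    yield "world"

async def main():
    stream = tokens()
    print(await anext(stream))  # "Hello"
    print(await anext(stream))  # "world"

asyncio.run(main())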
31 changes: 21 additions & 10 deletions ragna/assistants/_anthropic.py
@@ -1,4 +1,7 @@
from typing import cast
import json
from typing import AsyncIterator, cast

import httpx_sse

from ragna.core import RagnaException, Source

@@ -30,9 +33,11 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> str:
# See https://docs.anthropic.com/claude/reference/complete_post
response = await self._client.post(
) -> AsyncIterator[str]:
# See https://docs.anthropic.com/claude/reference/streaming
async with httpx_sse.aconnect_sse(
self._client,
"POST",
"https://api.anthropic.com/v1/complete",
headers={
"accept": "application/json",
@@ -45,13 +50,19 @@ async def _call_api(
"prompt": self._instructize_prompt(prompt, sources),
"max_tokens_to_sample": max_new_tokens,
"temperature": 0.0,
"stream": True,
},
)
if response.is_error:
raise RagnaException(
status_code=response.status_code, response=response.json()
)
return cast(str, response.json()["completion"])
) as event_source:
async for sse in event_source.aiter_sse():
data = json.loads(sse.data)
if data["type"] != "completion":
continue
elif "error" in data:
raise RagnaException(data["error"].pop("message"), **data["error"])
elif data["stop_reason"] is not None:
break

yield cast(str, data["completion"])


class ClaudeInstant(AnthropicApiAssistant):
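The aconnect_sse context manager wraps the request/response lifecycle and exposes the decoded events. The consumption pattern, reduced to its essentials against a hypothetical URL (a sketch, not part of the diff):

import json

import httpx
import httpx_sse

async def consume(url: str) -> None:
    async with httpx.AsyncClient() as client:
        async with httpx_sse.aconnect_sse(client, "POST", url) as event_source:
            async for sse in event_source.aiter_sse():
                # Each sse.data payload is a JSON document in Anthropic's case.
                print(json.loads(sse.data))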
14 changes: 9 additions & 5 deletions ragna/assistants/_api.py
@@ -1,5 +1,8 @@
import abc
import os
from typing import AsyncIterator

import httpx

import ragna
from ragna.core import Assistant, EnvVarRequirement, Requirement, Source
@@ -13,8 +16,6 @@ def requirements(cls) -> list[Requirement]:
return [EnvVarRequirement(cls._API_KEY_ENV_VAR)]

def __init__(self) -> None:
import httpx

self._client = httpx.AsyncClient(
headers={"User-Agent": f"{ragna.__version__}/{self}"},
timeout=60,
@@ -23,11 +24,14 @@ def __init__(self) -> None:

async def answer(
self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
) -> str:
return await self._call_api(prompt, sources, max_new_tokens=max_new_tokens)
) -> AsyncIterator[str]:
async for chunk in self._call_api( # type: ignore[attr-defined, misc]
prompt, sources, max_new_tokens=max_new_tokens
):
yield chunk

@abc.abstractmethod
async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> str:
) -> AsyncIterator[str]:
...
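With this change, subclasses implement _call_api as an async generator instead of a coroutine that returns the full string. A hypothetical minimal subclass illustrating the new contract (other abstract members of Assistant are omitted for brevity, so this sketch is not instantiable as-is):

from typing import AsyncIterator

from ragna.core import Source
from ragna.assistants._api import ApiAssistant

class EchoAssistant(ApiAssistant):
    """Hypothetical assistant that streams the prompt back word by word."""

    _API_KEY_ENV_VAR = "ECHO_API_KEY"

    async def _call_api(
        self, prompt: str, sources: list[Source], *, max_new_tokens: int
    ) -> AsyncIterator[str]:
        for word in prompt.split()[:max_new_tokens]:
            yield word + " "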
7 changes: 4 additions & 3 deletions ragna/assistants/_demo.py
@@ -1,6 +1,7 @@
import re
import sys
import textwrap
from typing import Iterator

from ragna.core import Assistant, Source

@@ -26,11 +27,11 @@ def display_name(cls) -> str:
def max_input_size(self) -> int:
return sys.maxsize

def answer(self, prompt: str, sources: list[Source]) -> str:
def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
if re.search("markdown", prompt, re.IGNORECASE):
return self._markdown_answer()
yield self._markdown_answer()
else:
return self._default_answer(prompt, sources)
yield self._default_answer(prompt, sources)

def _markdown_answer(self) -> str:
return textwrap.dedent(
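A subtlety worth noting: because the body now contains yield, Python turns answer into a generator function, so return can no longer hand back the string directly; each branch instead yields the complete answer as a single chunk. Reduced sketch (illustrative, not the PR's code):

from typing import Iterator

def answer(prompt: str) -> Iterator[str]:
    # One chunk per branch; callers iterate regardless of chunk count.
    if "markdown" in prompt.lower():
        yield "markdown answer"
    else:
        yield "default answer"

print(list(answer("hi")))  # ['default answer']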
6 changes: 3 additions & 3 deletions ragna/assistants/_mosaicml.py
@@ -1,4 +1,4 @@
from typing import cast
from typing import AsyncIterator, cast

from ragna.core import RagnaException, Source

@@ -29,7 +29,7 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> str:
) -> AsyncIterator[str]:
instruction = self._instructize_prompt(prompt, sources)
# https://docs.mosaicml.com/en/latest/inference.html#text-completion-requests
response = await self._client.post(
@@ -47,7 +47,7 @@ async def _call_api(
raise RagnaException(
status_code=response.status_code, response=response.json()
)
return cast(str, response.json()["outputs"][0]).replace(instruction, "").strip()
yield cast(str, response.json()["outputs"][0]).replace(instruction, "").strip()


class Mpt7bInstruct(MosaicmlApiAssistant):
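Since the MosaicML endpoint is queried without streaming here, the adaptation is simply to yield the complete response once, which satisfies the new AsyncIterator[str] contract with a single chunk.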
33 changes: 19 additions & 14 deletions ragna/assistants/_openai.py
@@ -1,6 +1,9 @@
from typing import cast
import json
from typing import AsyncIterator, cast

from ragna.core import RagnaException, Source
import httpx_sse

from ragna.core import Source

from ._api import ApiAssistant

@@ -29,10 +32,12 @@ def _make_system_content(self, sources: list[Source]) -> str:

async def _call_api(
self, prompt: str, sources: list[Source], *, max_new_tokens: int
) -> str:
) -> AsyncIterator[str]:
# See https://platform.openai.com/docs/api-reference/chat/create
# and https://platform.openai.com/docs/api-reference/chat/object
response = await self._client.post(
# and https://platform.openai.com/docs/api-reference/chat/streaming
async with httpx_sse.aconnect_sse(
self._client,
"POST",
"https://api.openai.com/v1/chat/completions",
headers={
"Content-Type": "application/json",
@@ -52,13 +57,16 @@ async def _call_api(
"model": self._MODEL,
"temperature": 0.0,
"max_tokens": max_new_tokens,
"stream": True,
},
)
if response.is_error:
raise RagnaException(
status_code=response.status_code, response=response.json()
)
return cast(str, response.json()["choices"][0]["message"]["content"])
) as event_source:
async for sse in event_source.aiter_sse():
data = json.loads(sse.data)
choice = data["choices"][0]
if choice["finish_reason"] is not None:
break

yield cast(str, choice["delta"]["content"])


class Gpt35Turbo16k(OpenaiApiAssistant):
@@ -73,9 +81,6 @@ class Gpt35Turbo16k(OpenaiApiAssistant):
_CONTEXT_SIZE = 16_384



class Gpt4(OpenaiApiAssistant):
"""[OpenAI GPT-4](https://platform.openai.com/docs/models/gpt-4)

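For reference, a representative (abbreviated) streamed chat-completion event as parsed by the loop above; the values are illustrative:

event = {
    "choices": [
        {
            "delta": {"content": "Hello"},  # incremental text, not the full message
            "finish_reason": None,  # becomes e.g. "stop" on the final event
        }
    ]
}

choice = event["choices"][0]
if choice["finish_reason"] is None:
    print(choice["delta"]["content"])  # -> "Hello"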
61 changes: 54 additions & 7 deletions ragna/core/_components.py
@@ -4,7 +4,7 @@
import enum
import functools
import inspect
from typing import Type
from typing import AsyncIterable, AsyncIterator, Iterator, Optional, Type, Union

import pydantic
import pydantic.utils
@@ -138,11 +138,10 @@ class MessageRole(enum.Enum):
ASSISTANT = "assistant"


class Message(pydantic.BaseModel):
class Message:
"""Data class for messages.

Attributes:
content: The content of the message.
role: The message producer.
sources: The sources used to produce the message.

@@ -152,13 +151,61 @@ class Message(pydantic.BaseModel):
- [ragna.core.Chat.answer][]
"""

content: str
role: MessageRole
sources: list[Source] = pydantic.Field(default_factory=list)
def __init__(
self,
content: Union[str, AsyncIterable[str]],
*,
role: MessageRole = MessageRole.SYSTEM,
sources: Optional[list[Source]] = None,
) -> None:
if isinstance(content, str):
self._content: str = content
else:
self._content_stream: AsyncIterable[str] = content

self.role = role
self.sources = sources or []

async def __aiter__(self) -> AsyncIterator[str]:
if hasattr(self, "_content"):
yield self._content
return

chunks = []
async for chunk in self._content_stream:
chunks.append(chunk)
yield chunk

self._content = "".join(chunks)

async def read(self) -> str:
if not hasattr(self, "_content"):
# Since self.__aiter__ is already setting the self._content attribute, we
# only need to exhaust the content stream here.
async for _ in self:
pass
return self._content

@property
def content(self) -> str:
if not hasattr(self, "_content"):
raise RuntimeError(
"Message content cannot be accessed without having iterated over it, "
"e.g. `async for chunk in message`, or reading the content, e.g. "
"`await message.read()`, first."
)
return self._content

def __str__(self) -> str:
return self.content

def __repr__(self) -> str:
return (
f"{type(self).__name__}("
f"content={self.content}, role={self.role}, sources={self.sources}"
f")"
)


class Assistant(Component, abc.ABC):
"""Abstract base class for assistants used in [ragna.core.Chat][]"""
@@ -171,7 +218,7 @@ def max_input_size(self) -> int:
...

@abc.abstractmethod
def answer(self, prompt: str, sources: list[Source]) -> str:
def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
"""Answer a prompt given some sources.

Args:
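Converting Message from a pydantic model into a hand-rolled class makes it consumable either incrementally (async iteration) or all at once (read()), with the content cached after the first full pass. A consumption sketch using a hypothetical chunk stream:

import asyncio

from ragna.core import Message, MessageRole

async def stream():
    for chunk in ["Streaming", " ", "works!"]:
        yield chunk

async def main():
    message = Message(stream(), role=MessageRole.ASSISTANT)
    async for chunk in message:  # incremental consumption
        print(chunk, end="", flush=True)
    print()
    print(message.content)  # cached after the stream is exhausted

asyncio.run(main())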