Skip to content

Commit

Permalink
Binary and Numpy data serialization for Chains. Fixes BT-10089 (baset…
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten authored Dec 4, 2024
1 parent d6fcf15 commit 017ecf3
Show file tree
Hide file tree
Showing 15 changed files with 725 additions and 393 deletions.
7 changes: 2 additions & 5 deletions truss-chains/examples/audio-transcription/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ class DeployedWhisper(chains.StubBase):
async def run_remote(
self, whisper_input: data_types.WhisperInput
) -> data_types.WhisperResult:
resp = await self._remote.predict_async(
json_payload={"whisper_input": whisper_input.model_dump()},
)
return data_types.WhisperResult.parse_obj(resp)
return await self.predict_async(whisper_input, data_types.WhisperResult)


class MacroChunkWorker(chains.ChainletBase):
Expand Down Expand Up @@ -93,7 +90,7 @@ async def run_remote(
t1 = time.time()
return data_types.SegmentList(
segments=segments,
chunk_info=macro_chunk.copy(update={"processing_duration": t1 - t0}),
chunk_info=macro_chunk.model_copy(update={"processing_duration": t1 - t0}),
)


Expand Down
92 changes: 92 additions & 0 deletions truss-chains/examples/numpy_and_binary/chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import numpy as np
import pydantic

import truss_chains as chains
from truss_chains import pydantic_numpy


class DataModel(pydantic.BaseModel):
    """Example payload model combining a plain text field with a numpy array.

    Used by the chainlets below to exercise both JSON and binary
    serialization of numpy data (see ``use_binary`` in the host chainlets).
    """

    # Free-form text message; each chainlet overwrites it to mark its hop.
    msg: str
    # Numpy array field with chains-provided (de)serialization support.
    np_array: pydantic_numpy.NumpyArrayField


class SyncChainlet(chains.ChainletBase):
    """Synchronously echoes the payload back, re-tagging the message."""

    def run_remote(self, data: DataModel) -> DataModel:
        print(data)
        tagged = data.model_copy(update={"msg": "From sync"})
        return tagged


class AsyncChainlet(chains.ChainletBase):
    """Asynchronously echoes the payload back, re-tagging the message."""

    async def run_remote(self, data: DataModel) -> DataModel:
        print(data)
        tagged = data.model_copy(update={"msg": "From async"})
        return tagged


class AsyncChainletNoInput(chains.ChainletBase):
    """Produces a fixed payload while taking no input at all."""

    async def run_remote(self) -> DataModel:
        constant_array = np.full((2, 2), 3)
        result = DataModel(msg="From async no input", np_array=constant_array)
        print(result)
        return result


class AsyncChainletNoOutput(chains.ChainletBase):
    """Consumes a payload for its side effect (printing) and returns nothing."""

    async def run_remote(self, data: DataModel) -> None:
        print(data)


class HostJSON(chains.ChainletBase):
    """Calls various chainlets in JSON mode."""

    def __init__(
        self,
        sync_chainlet=chains.depends(SyncChainlet, use_binary=False),
        async_chainlet=chains.depends(AsyncChainlet, use_binary=False),
        async_chainlet_no_output=chains.depends(
            AsyncChainletNoOutput, use_binary=False
        ),
        async_chainlet_no_input=chains.depends(AsyncChainletNoInput, use_binary=False),
    ):
        # Hold on to all dependency chainlets for use in `run_remote`.
        self._sync = sync_chainlet
        self._async = async_chainlet
        self._async_no_output = async_chainlet_no_output
        self._async_no_input = async_chainlet_no_input

    async def run_remote(self) -> tuple[DataModel, DataModel, DataModel]:
        """Invokes each dependency once and returns the three results."""
        ones = np.ones((3, 2, 1))
        payload = DataModel(msg="From Host", np_array=ones)
        result_sync = self._sync.run_remote(payload)
        print(result_sync)
        result_async = await self._async.run_remote(payload)
        print(result_async)
        await self._async_no_output.run_remote(payload)
        result_no_input = await self._async_no_input.run_remote()
        print(result_no_input)
        return result_sync, result_async, result_no_input


class HostBinary(chains.ChainletBase):
    """Calls various chainlets in binary mode."""

    def __init__(
        self,
        sync_chainlet=chains.depends(SyncChainlet, use_binary=True),
        async_chainlet=chains.depends(AsyncChainlet, use_binary=True),
        async_chainlet_no_output=chains.depends(AsyncChainletNoOutput, use_binary=True),
        async_chainlet_no_input=chains.depends(AsyncChainletNoInput, use_binary=True),
    ):
        # Hold on to all dependency chainlets for use in `run_remote`.
        self._sync = sync_chainlet
        self._async = async_chainlet
        self._async_no_output = async_chainlet_no_output
        self._async_no_input = async_chainlet_no_input

    async def run_remote(self) -> tuple[DataModel, DataModel, DataModel]:
        """Invokes each dependency once and returns the three results."""
        ones = np.ones((3, 2, 1))
        payload = DataModel(msg="From Host", np_array=ones)
        result_sync = self._sync.run_remote(payload)
        print(result_sync)
        result_async = await self._async.run_remote(payload)
        print(result_async)
        await self._async_no_output.run_remote(payload)
        result_no_input = await self._async_no_input.run_remote()
        print(result_no_input)
        return result_sync, result_async, result_no_input
4 changes: 2 additions & 2 deletions truss-chains/examples/rag/rag_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ async def run_remote(self, new_bio: str, bios: list[str]) -> str:
f"{PERSON_MATCHING_PROMPT}\nPerson you're matching: {new_bio}\n"
f"People from database: {bios_info}"
)
resp = await self._remote.predict_async(
json_payload={
resp = await self.predict_async(
{
"messages": [{"role": "user", "content": prompt}],
"stream": False,
"max_new_tokens": 32,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,30 +123,31 @@ async def test_chain_local():

@pytest.mark.integration
def test_streaming_chain():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "streaming" / "streaming_chain.py"
with framework.import_target(chain_root, "Consumer") as entrypoint:
service = remote.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
chain_name="stream",
only_generate_trusses=False,
use_local_chains_src=True,
),
)
assert service is not None
response = service.run_remote({})
assert response.status_code == 200
print(response.json())
result = response.json()
print(result)
assert result["header"]["msg"] == "Start."
assert result["chunks"][0]["words"] == ["G"]
assert result["chunks"][1]["words"] == ["G", "HH"]
assert result["chunks"][2]["words"] == ["G", "HH", "III"]
assert result["chunks"][3]["words"] == ["G", "HH", "III", "JJJJ"]
assert result["footer"]["duration_sec"] > 0
assert result["strings"] == "First second last."
with ensure_kill_all():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "streaming" / "streaming_chain.py"
with framework.import_target(chain_root, "Consumer") as entrypoint:
service = remote.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
chain_name="integration-test-stream",
only_generate_trusses=False,
use_local_chains_src=True,
),
)
assert service is not None
response = service.run_remote({})
assert response.status_code == 200
print(response.json())
result = response.json()
print(result)
assert result["header"]["msg"] == "Start."
assert result["chunks"][0]["words"] == ["G"]
assert result["chunks"][1]["words"] == ["G", "HH"]
assert result["chunks"][2]["words"] == ["G", "HH", "III"]
assert result["chunks"][3]["words"] == ["G", "HH", "III", "JJJJ"]
assert result["footer"]["duration_sec"] > 0
assert result["strings"] == "First second last."


@pytest.mark.asyncio
Expand All @@ -164,3 +165,28 @@ async def test_streaming_chain_local():
assert result.chunks[3].words == ["G", "HH", "III", "JJJJ"]
assert result.footer.duration_sec > 0
assert result.strings == "First second last."


@pytest.mark.integration
@pytest.mark.parametrize("mode", ["json", "binary"])
def test_numpy_chain(mode):
    """Pushes the numpy example chain and invokes it in JSON or binary mode."""
    target = "HostJSON" if mode == "json" else "HostBinary"
    with ensure_kill_all():
        examples_root = Path(__file__).parent.parent.resolve() / "examples"
        chain_root = examples_root / "numpy_and_binary" / "chain.py"
        with framework.import_target(chain_root, target) as entrypoint:
            service = remote.push(
                entrypoint,
                options=definitions.PushOptionsLocalDocker(
                    chain_name=f"integration-test-numpy-{mode}",
                    only_generate_trusses=False,
                    use_local_chains_src=True,
                ),
            )
            assert service is not None
            response = service.run_remote({})
            assert response.status_code == 200
            print(response.json())
35 changes: 15 additions & 20 deletions truss-chains/truss_chains/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,36 +251,32 @@ def _stub_endpoint_body_src(
E.g.:
```
json_result = await self._remote.predict_async(
SplitTextInput(inputs=inputs, extra_arg=extra_arg).model_dump())
return SplitTextOutput.model_validate(json_result).output
return await self.predict_async(
SplitTextInput(inputs=inputs, extra_arg=extra_arg), SplitTextOutput).root
```
"""
imports: set[str] = set()
args = [f"{arg.name}={arg.name}" for arg in endpoint.input_args]
if args:
inputs = (
f"{_get_input_model_name(chainlet_name)}({', '.join(args)}).model_dump()"
)
inputs = f"{_get_input_model_name(chainlet_name)}({', '.join(args)})"
else:
inputs = "{}"

parts = []
# Invoke remote.
if not endpoint.is_streaming:
output_model_name = _get_output_model_name(chainlet_name)
if endpoint.is_async:
remote_call = f"await self._remote.predict_async({inputs})"
parts = [
f"return (await self.predict_async({inputs}, {output_model_name})).root"
]
else:
remote_call = f"self._remote.predict_sync({inputs})"
parts = [f"return self.predict_sync({inputs}, {output_model_name}).root"]

parts = [f"json_result = {remote_call}"]
# Unpack response and parse as pydantic models if needed.
output_model_name = _get_output_model_name(chainlet_name)
parts.append(f"return {output_model_name}.model_validate(json_result).root")
else:
if endpoint.is_async:
parts.append(
f"async for data in await self._remote.predict_async_stream({inputs}):",
f"async for data in await self.predict_async_stream({inputs}):",
)
if endpoint.streaming_type.is_string:
parts.append(_indent("yield data.decode()"))
Expand Down Expand Up @@ -312,9 +308,8 @@ class SplitText(stub.StubBase):
async def run_remote(
self, inputs: shared_chainlet.SplitTextInput, extra_arg: int
) -> tuple[shared_chainlet.SplitTextOutput, int]:
json_result = await self._remote.predict_async(
SplitTextInput(inputs=inputs, extra_arg=extra_arg).model_dump())
return SplitTextOutput.model_validate(json_result).root
return await self.predict_async(
SplitTextInput(inputs=inputs, extra_arg=extra_arg), SplitTextOutput).root
```
"""
imports = {"from truss_chains import stub"}
Expand Down Expand Up @@ -428,7 +423,7 @@ def _gen_load_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _So

def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source:
"""Generates AST for the `predict` method of the truss model."""
imports: set[str] = {"from truss_chains import utils"}
imports: set[str] = {"from truss_chains import stub"}
parts: list[str] = []
def_str = "async def" if chainlet_descriptor.endpoint.is_async else "def"
input_model_name = _get_input_model_name(chainlet_descriptor.name)
Expand All @@ -440,16 +435,16 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) ->
output_type_name = streaming_src.src
else:
output_type_name = _get_output_model_name(chainlet_descriptor.name)

imports.add("import starlette.requests")
imports.add("from truss_chains import stub")
parts.append(
f"{def_str} predict(self, inputs: {input_model_name}, "
f"request: starlette.requests.Request) -> {output_type_name}:"
)
# Add error handling context manager:
parts.append(
_indent(
f"with stub.trace_parent(request), utils.exception_to_http_error("
f"with stub.trace_parent(request), stub.exception_to_http_error("
f'chainlet_name="{chainlet_descriptor.name}"):'
)
)
Expand All @@ -463,7 +458,7 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) ->
maybe_await = ""
run_remote = chainlet_descriptor.endpoint.name
# See docs of `pydantic_set_field_dict` for why this is needed.
args = "**utils.pydantic_set_field_dict(inputs)"
args = "**stub.pydantic_set_field_dict(inputs)"
parts.append(
_indent(f"result = {maybe_await}self._chainlet.{run_remote}({args})", 2)
)
Expand Down
19 changes: 16 additions & 3 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def __init__(

def get_spec(self) -> AssetSpec:
"""Returns parsed and validated assets."""
return self._spec.copy(deep=True)
return self._spec.model_copy(deep=True)


class ChainletOptions(SafeModelNonSerializable):
Expand Down Expand Up @@ -397,10 +397,23 @@ def get_asset_spec(self) -> AssetSpec:


class RPCOptions(SafeModel):
"""Options to customize RPCs to dependency chainlets."""
"""Options to customize RPCs to dependency chainlets.
Args:
retries: The number of times to retry the remote chainlet in case of failures
(e.g. due to transient network issues). For streaming, retries are only made
if the request fails before streaming any results back. Failures mid-stream
are not retried.
timeout_sec: Timeout for the HTTP request to this chainlet.
use_binary: Whether to send data in binary format. This can give a parsing
speedup and message size reduction (~25%) for numpy arrays. Use
``NumpyArrayField`` as a field type on pydantic models for integration and set
this option to ``True``. For simple text data, there is no significant benefit.
"""

timeout_sec: int = DEFAULT_TIMEOUT_SEC
retries: int = 1
timeout_sec: int = DEFAULT_TIMEOUT_SEC
use_binary: bool = False


class ServiceDescriptor(SafeModel):
Expand Down
13 changes: 11 additions & 2 deletions truss-chains/truss_chains/public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def depends(
chainlet_cls: Type[framework.ChainletT],
retries: int = 1,
timeout_sec: int = definitions.DEFAULT_TIMEOUT_SEC,
use_binary: bool = False,
) -> framework.ChainletT:
"""Sets a "symbolic marker" to indicate to the framework that a chainlet is a
dependency of another chainlet. The return value of ``depends`` is intended to be
Expand All @@ -58,14 +59,22 @@ def depends(
Args:
chainlet_cls: The chainlet class of the dependency.
retries: The number of times to retry the remote chainlet in case of failures
(e.g. due to transient network issues).
(e.g. due to transient network issues). For streaming, retries are only made
if the request fails before streaming any results back. Failures mid-stream
are not retried.
timeout_sec: Timeout for the HTTP request to this chainlet.
use_binary: Whether to send data in binary format. This can give a parsing
speedup and message size reduction (~25%) for numpy arrays. Use
``NumpyArrayField`` as a field type on pydantic models for integration and set
this option to ``True``. For simple text data, there is no significant benefit.
Returns:
A "symbolic marker" to be used as a default argument in a chainlet's
initializer.
"""
options = definitions.RPCOptions(retries=retries, timeout_sec=timeout_sec)
options = definitions.RPCOptions(
retries=retries, timeout_sec=timeout_sec, use_binary=use_binary
)
# The type error is silenced to because chains framework will at runtime inject
# a corresponding instance. Nonetheless, we want to use a type annotation here,
# to facilitate type inference, code-completion and type checking within the code
Expand Down
Loading

0 comments on commit 017ecf3

Please sign in to comment.