diff --git a/truss-chains/examples/audio-transcription/transcribe.py b/truss-chains/examples/audio-transcription/transcribe.py index 4e4016b08..f9e06c29f 100644 --- a/truss-chains/examples/audio-transcription/transcribe.py +++ b/truss-chains/examples/audio-transcription/transcribe.py @@ -33,10 +33,7 @@ class DeployedWhisper(chains.StubBase): async def run_remote( self, whisper_input: data_types.WhisperInput ) -> data_types.WhisperResult: - resp = await self._remote.predict_async( - json_payload={"whisper_input": whisper_input.model_dump()}, - ) - return data_types.WhisperResult.parse_obj(resp) + return await self.predict_async(whisper_input, data_types.WhisperResult) class MacroChunkWorker(chains.ChainletBase): @@ -93,7 +90,7 @@ async def run_remote( t1 = time.time() return data_types.SegmentList( segments=segments, - chunk_info=macro_chunk.copy(update={"processing_duration": t1 - t0}), + chunk_info=macro_chunk.model_copy(update={"processing_duration": t1 - t0}), ) diff --git a/truss-chains/examples/numpy_and_binary/chain.py b/truss-chains/examples/numpy_and_binary/chain.py new file mode 100644 index 000000000..ebd8a7c92 --- /dev/null +++ b/truss-chains/examples/numpy_and_binary/chain.py @@ -0,0 +1,92 @@ +import numpy as np +import pydantic + +import truss_chains as chains +from truss_chains import pydantic_numpy + + +class DataModel(pydantic.BaseModel): + msg: str + np_array: pydantic_numpy.NumpyArrayField + + +class SyncChainlet(chains.ChainletBase): + def run_remote(self, data: DataModel) -> DataModel: + print(data) + return data.model_copy(update={"msg": "From sync"}) + + +class AsyncChainlet(chains.ChainletBase): + async def run_remote(self, data: DataModel) -> DataModel: + print(data) + return data.model_copy(update={"msg": "From async"}) + + +class AsyncChainletNoInput(chains.ChainletBase): + async def run_remote(self) -> DataModel: + data = DataModel(msg="From async no input", np_array=np.full((2, 2), 3)) + print(data) + return data + + +class AsyncChainletNoOutput(chains.ChainletBase): + async def run_remote(self, data: DataModel) -> None: + print(data) + + +class HostJSON(chains.ChainletBase): + """Calls various chainlets in JSON mode.""" + + def __init__( + self, + sync_chainlet=chains.depends(SyncChainlet, use_binary=False), + async_chainlet=chains.depends(AsyncChainlet, use_binary=False), + async_chainlet_no_output=chains.depends( + AsyncChainletNoOutput, use_binary=False + ), + async_chainlet_no_input=chains.depends(AsyncChainletNoInput, use_binary=False), + ): + self._sync_chainlet = sync_chainlet + self._async_chainlet = async_chainlet + self._async_chainlet_no_output = async_chainlet_no_output + self._async_chainlet_no_input = async_chainlet_no_input + + async def run_remote(self) -> tuple[DataModel, DataModel, DataModel]: + a = np.ones((3, 2, 1)) + data = DataModel(msg="From Host", np_array=a) + sync_result = self._sync_chainlet.run_remote(data) + print(sync_result) + async_result = await self._async_chainlet.run_remote(data) + print(async_result) + await self._async_chainlet_no_output.run_remote(data) + async_no_input = await self._async_chainlet_no_input.run_remote() + print(async_no_input) + return sync_result, async_result, async_no_input + + +class HostBinary(chains.ChainletBase): + """Calls various chainlets in binary mode.""" + + def __init__( + self, + sync_chainlet=chains.depends(SyncChainlet, use_binary=True), + async_chainlet=chains.depends(AsyncChainlet, use_binary=True), + async_chainlet_no_output=chains.depends(AsyncChainletNoOutput, use_binary=True), + 
async_chainlet_no_input=chains.depends(AsyncChainletNoInput, use_binary=True), + ): + self._sync_chainlet = sync_chainlet + self._async_chainlet = async_chainlet + self._async_chainlet_no_output = async_chainlet_no_output + self._async_chainlet_no_input = async_chainlet_no_input + + async def run_remote(self) -> tuple[DataModel, DataModel, DataModel]: + a = np.ones((3, 2, 1)) + data = DataModel(msg="From Host", np_array=a) + sync_result = self._sync_chainlet.run_remote(data) + print(sync_result) + async_result = await self._async_chainlet.run_remote(data) + print(async_result) + await self._async_chainlet_no_output.run_remote(data) + async_no_input = await self._async_chainlet_no_input.run_remote() + print(async_no_input) + return sync_result, async_result, async_no_input diff --git a/truss-chains/examples/rag/rag_chain.py b/truss-chains/examples/rag/rag_chain.py index a83ae65d0..1e32171fd 100644 --- a/truss-chains/examples/rag/rag_chain.py +++ b/truss-chains/examples/rag/rag_chain.py @@ -115,8 +115,8 @@ async def run_remote(self, new_bio: str, bios: list[str]) -> str: f"{PERSON_MATCHING_PROMPT}\nPerson you're matching: {new_bio}\n" f"People from database: {bios_info}" ) - resp = await self._remote.predict_async( - json_payload={ + resp = await self.predict_async( + { "messages": [{"role": "user", "content": prompt}], "stream": False, "max_new_tokens": 32, diff --git a/truss-chains/tests/chains_e2e_test.py b/truss-chains/tests/test_e2e.py similarity index 68% rename from truss-chains/tests/chains_e2e_test.py rename to truss-chains/tests/test_e2e.py index 29d7ca894..6a89f35bc 100644 --- a/truss-chains/tests/chains_e2e_test.py +++ b/truss-chains/tests/test_e2e.py @@ -123,30 +123,31 @@ async def test_chain_local(): @pytest.mark.integration def test_streaming_chain(): - examples_root = Path(__file__).parent.parent.resolve() / "examples" - chain_root = examples_root / "streaming" / "streaming_chain.py" - with framework.import_target(chain_root, "Consumer") as entrypoint: - service = remote.push( - entrypoint, - options=definitions.PushOptionsLocalDocker( - chain_name="stream", - only_generate_trusses=False, - use_local_chains_src=True, - ), - ) - assert service is not None - response = service.run_remote({}) - assert response.status_code == 200 - print(response.json()) - result = response.json() - print(result) - assert result["header"]["msg"] == "Start." - assert result["chunks"][0]["words"] == ["G"] - assert result["chunks"][1]["words"] == ["G", "HH"] - assert result["chunks"][2]["words"] == ["G", "HH", "III"] - assert result["chunks"][3]["words"] == ["G", "HH", "III", "JJJJ"] - assert result["footer"]["duration_sec"] > 0 - assert result["strings"] == "First second last." + with ensure_kill_all(): + examples_root = Path(__file__).parent.parent.resolve() / "examples" + chain_root = examples_root / "streaming" / "streaming_chain.py" + with framework.import_target(chain_root, "Consumer") as entrypoint: + service = remote.push( + entrypoint, + options=definitions.PushOptionsLocalDocker( + chain_name="integration-test-stream", + only_generate_trusses=False, + use_local_chains_src=True, + ), + ) + assert service is not None + response = service.run_remote({}) + assert response.status_code == 200 + print(response.json()) + result = response.json() + print(result) + assert result["header"]["msg"] == "Start." 
+ assert result["chunks"][0]["words"] == ["G"] + assert result["chunks"][1]["words"] == ["G", "HH"] + assert result["chunks"][2]["words"] == ["G", "HH", "III"] + assert result["chunks"][3]["words"] == ["G", "HH", "III", "JJJJ"] + assert result["footer"]["duration_sec"] > 0 + assert result["strings"] == "First second last." @pytest.mark.asyncio @@ -164,3 +165,28 @@ async def test_streaming_chain_local(): assert result.chunks[3].words == ["G", "HH", "III", "JJJJ"] assert result.footer.duration_sec > 0 assert result.strings == "First second last." + + +@pytest.mark.integration +@pytest.mark.parametrize("mode", ["json", "binary"]) +def test_numpy_chain(mode): + if mode == "json": + target = "HostJSON" + else: + target = "HostBinary" + with ensure_kill_all(): + examples_root = Path(__file__).parent.parent.resolve() / "examples" + chain_root = examples_root / "numpy_and_binary" / "chain.py" + with framework.import_target(chain_root, target) as entrypoint: + service = remote.push( + entrypoint, + options=definitions.PushOptionsLocalDocker( + chain_name=f"integration-test-numpy-{mode}", + only_generate_trusses=False, + use_local_chains_src=True, + ), + ) + assert service is not None + response = service.run_remote({}) + assert response.status_code == 200 + print(response.json()) diff --git a/truss-chains/truss_chains/code_gen.py b/truss-chains/truss_chains/code_gen.py index 6ec2e98ca..1392fb307 100644 --- a/truss-chains/truss_chains/code_gen.py +++ b/truss-chains/truss_chains/code_gen.py @@ -251,36 +251,32 @@ def _stub_endpoint_body_src( E.g.: ``` - json_result = await self._remote.predict_async( - SplitTextInput(inputs=inputs, extra_arg=extra_arg).model_dump()) - return SplitTextOutput.model_validate(json_result).output + return await self.predict_async( + SplitTextInput(inputs=inputs, extra_arg=extra_arg), SplitTextOutput).root ``` """ imports: set[str] = set() args = [f"{arg.name}={arg.name}" for arg in endpoint.input_args] if args: - inputs = ( - f"{_get_input_model_name(chainlet_name)}({', '.join(args)}).model_dump()" - ) + inputs = f"{_get_input_model_name(chainlet_name)}({', '.join(args)})" else: inputs = "{}" parts = [] # Invoke remote. if not endpoint.is_streaming: + output_model_name = _get_output_model_name(chainlet_name) if endpoint.is_async: - remote_call = f"await self._remote.predict_async({inputs})" + parts = [ + f"return (await self.predict_async({inputs}, {output_model_name})).root" + ] else: - remote_call = f"self._remote.predict_sync({inputs})" + parts = [f"return self.predict_sync({inputs}, {output_model_name}).root"] - parts = [f"json_result = {remote_call}"] - # Unpack response and parse as pydantic models if needed. 
- output_model_name = _get_output_model_name(chainlet_name) - parts.append(f"return {output_model_name}.model_validate(json_result).root") else: if endpoint.is_async: parts.append( - f"async for data in await self._remote.predict_async_stream({inputs}):", + f"async for data in await self.predict_async_stream({inputs}):", ) if endpoint.streaming_type.is_string: parts.append(_indent("yield data.decode()")) @@ -312,9 +308,8 @@ class SplitText(stub.StubBase): async def run_remote( self, inputs: shared_chainlet.SplitTextInput, extra_arg: int ) -> tuple[shared_chainlet.SplitTextOutput, int]: - json_result = await self._remote.predict_async( - SplitTextInput(inputs=inputs, extra_arg=extra_arg).model_dump()) - return SplitTextOutput.model_validate(json_result).root + return await self.predict_async( + SplitTextInput(inputs=inputs, extra_arg=extra_arg), SplitTextOutput).root ``` """ imports = {"from truss_chains import stub"} @@ -428,7 +423,7 @@ def _gen_load_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _So def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source: """Generates AST for the `predict` method of the truss model.""" - imports: set[str] = {"from truss_chains import utils"} + imports: set[str] = {"from truss_chains import stub"} parts: list[str] = [] def_str = "async def" if chainlet_descriptor.endpoint.is_async else "def" input_model_name = _get_input_model_name(chainlet_descriptor.name) @@ -440,8 +435,8 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> output_type_name = streaming_src.src else: output_type_name = _get_output_model_name(chainlet_descriptor.name) + imports.add("import starlette.requests") - imports.add("from truss_chains import stub") parts.append( f"{def_str} predict(self, inputs: {input_model_name}, " f"request: starlette.requests.Request) -> {output_type_name}:" @@ -449,7 +444,7 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> # Add error handling context manager: parts.append( _indent( - f"with stub.trace_parent(request), utils.exception_to_http_error(" + f"with stub.trace_parent(request), stub.exception_to_http_error(" f'chainlet_name="{chainlet_descriptor.name}"):' ) ) @@ -463,7 +458,7 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> maybe_await = "" run_remote = chainlet_descriptor.endpoint.name # See docs of `pydantic_set_field_dict` for why this is needed. - args = "**utils.pydantic_set_field_dict(inputs)" + args = "**stub.pydantic_set_field_dict(inputs)" parts.append( _indent(f"result = {maybe_await}self._chainlet.{run_remote}({args})", 2) ) diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index 0510f9f4c..47c853819 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -344,7 +344,7 @@ def __init__( def get_spec(self) -> AssetSpec: """Returns parsed and validated assets.""" - return self._spec.copy(deep=True) + return self._spec.model_copy(deep=True) class ChainletOptions(SafeModelNonSerializable): @@ -397,10 +397,23 @@ def get_asset_spec(self) -> AssetSpec: class RPCOptions(SafeModel): - """Options to customize RPCs to dependency chainlets.""" + """Options to customize RPCs to dependency chainlets. + + Args: + retries: The number of times to retry the remote chainlet in case of failures + (e.g. due to transient network issues). 
For streaming, retries are only made + if the request fails before streaming any results back. Failures mid-stream + are not retried. + timeout_sec: Timeout for the HTTP request to this chainlet. + use_binary: Whether to send data in binary format. This can give a parsing + speedup and message size reduction (~25%) for numpy arrays. To integrate, + use ``NumpyArrayField`` as a field type on pydantic models and set + this option to ``True``. For simple text data, there is no significant benefit. + """ - timeout_sec: int = DEFAULT_TIMEOUT_SEC retries: int = 1 + retries: int = 1 + timeout_sec: int = DEFAULT_TIMEOUT_SEC + use_binary: bool = False class ServiceDescriptor(SafeModel): diff --git a/truss-chains/truss_chains/public_api.py b/truss-chains/truss_chains/public_api.py index ec95df886..f07e57cdf 100644 --- a/truss-chains/truss_chains/public_api.py +++ b/truss-chains/truss_chains/public_api.py @@ -38,6 +38,7 @@ def depends( chainlet_cls: Type[framework.ChainletT], retries: int = 1, timeout_sec: int = definitions.DEFAULT_TIMEOUT_SEC, + use_binary: bool = False, ) -> framework.ChainletT: """Sets a "symbolic marker" to indicate to the framework that a chainlet is a dependency of another chainlet. The return value of ``depends`` is intended to be @@ -58,14 +59,22 @@ def depends( Args: chainlet_cls: The chainlet class of the dependency. retries: The number of times to retry the remote chainlet in case of failures - (e.g. due to transient network issues). + (e.g. due to transient network issues). For streaming, retries are only made + if the request fails before streaming any results back. Failures mid-stream + are not retried. timeout_sec: Timeout for the HTTP request to this chainlet. + use_binary: Whether to send data in binary format. This can give a parsing + speedup and message size reduction (~25%) for numpy arrays. To integrate, + use ``NumpyArrayField`` as a field type on pydantic models and set + this option to ``True``. For simple text data, there is no significant benefit. Returns: A "symbolic marker" to be used as a default argument in a chainlet's initializer. """ - options = definitions.RPCOptions(retries=retries, timeout_sec=timeout_sec) + options = definitions.RPCOptions( + retries=retries, timeout_sec=timeout_sec, use_binary=use_binary + ) # The type error is silenced because the chains framework will at runtime inject # a corresponding instance. Nonetheless, we want to use a type annotation here, # to facilitate type inference, code-completion and type checking within the code diff --git a/truss-chains/truss_chains/pydantic_numpy.py b/truss-chains/truss_chains/pydantic_numpy.py new file mode 100644 index 000000000..ad112eca6 --- /dev/null +++ b/truss-chains/truss_chains/pydantic_numpy.py @@ -0,0 +1,131 @@ +import base64 +from typing import TYPE_CHECKING, Any, ClassVar + +import pydantic +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import core_schema + +if TYPE_CHECKING: + from numpy.typing import NDArray + + +class NumpyArrayField: + """Wrapper class to support numpy arrays as fields on pydantic models and provide + JSON or binary serialization implementations. + + The JSON serialization exposes (data, shape, dtype), and the data is base-64 + encoded, which leads to ~33% overhead. A more compact serialization can be achieved + using ``msgpack_numpy`` (integrated in chains, if RPC-option ``use_binary`` is + enabled).
+ + Usage example: + + ``` + import numpy as np + + class MyModel(pydantic.BaseModel): + my_array: NumpyArrayField + + m = MyModel(my_array=np.arange(4).reshape((2, 2))) + m.my_array.array += 10 # Work with the numpy array. + print(m) + # my_array=NumpyArrayField( + # shape=(2, 2), + # dtype=int64, + # data=[[10 11] [12 13]]) + m_json = m.model_dump_json() # Serialize. + print(m_json) + # {"my_array":{"data_b64":"CgAAAAAAAAALAAAAAAAAAAwAAAAAAAAADQAAAAAAAAA=","shape":[2,2],"dtype":"int64"}} + m2 = MyModel.model_validate_json(m_json) # De-serialize. + ``` + """ + + data_key: ClassVar[str] = "data_b64" + shape_key: ClassVar[str] = "shape" + dtype_key: ClassVar[str] = "dtype" + array: "NDArray" + + def __init__(self, array: "NDArray"): + self.array = array + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(shape={self.array.shape}, " + f"dtype={self.array.dtype}, data={self.array})" + ) + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: pydantic.GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + return core_schema.no_info_after_validator_function( + cls.validate_numpy_array, + core_schema.any_schema(), + serialization=core_schema.plain_serializer_function_ser_schema( + cls.serialize_numpy_array, info_arg=True + ), + ) + + @classmethod + def validate_numpy_array(cls, value: Any): + import numpy as np + + keys = {cls.data_key, cls.shape_key, cls.dtype_key} + if isinstance(value, dict) and keys.issubset(value): + try: + data = base64.b64decode(value[cls.data_key]) + array = np.frombuffer(data, dtype=value[cls.dtype_key]).reshape( + value[cls.shape_key] + ) + return cls(array) + except (ValueError, TypeError) as e: + raise TypeError( + "numpy_array_validation" + f"Invalid data, shape, or dtype for NumPy array: {str(e)}", + ) + if isinstance(value, np.ndarray): + return cls(value) + if isinstance(value, cls): + return value + + raise TypeError( + "numpy_array_validation\n" + f"Expected a NumPy array or a dictionary with keys {keys}.\n" + f"Got:\n{value}" + ) + + @classmethod + def serialize_numpy_array( + cls, obj: "NumpyArrayField", info: core_schema.SerializationInfo + ): + if info.mode == "json": + return { + cls.data_key: base64.b64encode(obj.array.tobytes()).decode("utf-8"), + cls.shape_key: obj.array.shape, + cls.dtype_key: str(obj.array.dtype), + } + return obj.array + + @classmethod + def __get_pydantic_json_schema__( + cls, + _core_schema: core_schema.CoreSchema, + handler: pydantic.GetJsonSchemaHandler, + ) -> JsonSchemaValue: + json_schema = handler(_core_schema) + json_schema.update( + { + "type": "object", + "properties": { + "data": {"type": "string", "format": "byte"}, + "shape": { + "type": "array", + "items": {"type": "integer"}, + "minItems": 1, + }, + "dtype": {"type": "string"}, + }, + "required": ["data", "shape", "dtype"], + } + ) + return json_schema diff --git a/truss-chains/truss_chains/streaming.py b/truss-chains/truss_chains/streaming.py index 9d9a1cae8..7b20539ba 100644 --- a/truss-chains/truss_chains/streaming.py +++ b/truss-chains/truss_chains/streaming.py @@ -4,14 +4,12 @@ import struct import sys from collections.abc import AsyncIterator -from typing import Generic, Optional, Protocol, Type, TypeVar, Union, overload +from typing import Generic, Optional, Protocol, Type, TypeVar, overload import pydantic _TAG_SIZE = 5 # uint8 + uint32. 
-_JSONType = Union[ - str, int, float, bool, None, list["_JSONType"], dict[str, "_JSONType"] -] + _T = TypeVar("_T") if sys.version_info < (3, 10): diff --git a/truss-chains/truss_chains/stub.py b/truss-chains/truss_chains/stub.py index 5de4f66de..dd4c905cd 100644 --- a/truss-chains/truss_chains/stub.py +++ b/truss-chains/truss_chains/stub.py @@ -1,33 +1,222 @@ import abc import asyncio +import builtins import contextlib import contextvars +import json import logging -import ssl +import sys +import textwrap import threading import time +import traceback from typing import ( Any, AsyncIterator, ClassVar, + Dict, Iterator, Mapping, + NoReturn, Optional, Type, TypeVar, + Union, final, + overload, ) import aiohttp +import fastapi import httpx +import pydantic import starlette.requests import tenacity +from truss.templates.shared import serialization from truss_chains import definitions, utils DEFAULT_MAX_CONNECTIONS = 1000 DEFAULT_MAX_KEEPALIVE_CONNECTIONS = 400 + +_RetryPolicyT = TypeVar("_RetryPolicyT", tenacity.AsyncRetrying, tenacity.Retrying) +_InputT = TypeVar("_InputT", pydantic.BaseModel, Any) # Any signifies "JSON". +_OutputT = TypeVar("_OutputT", bound=pydantic.BaseModel) + + +# Error Propagation Utils. ############################################################# + + +def _handle_exception(exception: Exception, chainlet_name: str) -> NoReturn: + """Raises `fastapi.HTTPException` with `RemoteErrorDetail` as detail.""" + if hasattr(exception, "__module__"): + exception_module_name = exception.__module__ + else: + exception_module_name = None + + error_stack = traceback.extract_tb(exception.__traceback__) + # Exclude the error handling functions from the stack trace. + exclude_frames = { + exception_to_http_error.__name__, + _response_raise_errors.__name__, + _async_response_raise_errors.__name__, + } + final_tb = [frame for frame in error_stack if frame.name not in exclude_frames] + stack = list( + [definitions.StackFrame.from_frame_summary(frame) for frame in final_tb] + ) + error = definitions.RemoteErrorDetail( + remote_name=chainlet_name, + exception_cls_name=exception.__class__.__name__, + exception_module_name=exception_module_name, + exception_message=str(exception), + user_stack_trace=stack, + ) + raise fastapi.HTTPException( + status_code=500, detail=error.model_dump() + ) from exception + + +@contextlib.contextmanager +def exception_to_http_error(chainlet_name: str) -> Iterator[None]: + # TODO: move chainlet name from here to caller side. + try: + yield + except Exception as e: + _handle_exception(e, chainlet_name) + + +def _resolve_exception_class( + error: definitions.RemoteErrorDetail, +) -> Type[Exception]: + """Tries to find the exception class in builtins or imported libs, + falls back to `definitions.GenericRemoteException` if not found.""" + exception_cls = None + if error.exception_module_name is None: + exception_cls = getattr(builtins, error.exception_cls_name, None) + else: + if mod := sys.modules.get(error.exception_module_name): + exception_cls = getattr(mod, error.exception_cls_name, None) + + if exception_cls is None: + logging.warning( + f"Could not resolve exception with name `{error.exception_cls_name}` " + f"and module `{error.exception_module_name}` - fall back to " + f"`{definitions.GenericRemoteException.__name__}`." + ) + exception_cls = definitions.GenericRemoteException + + if issubclass(exception_cls, pydantic.ValidationError): + # Cannot re-raise naively. + # https://github.com/pydantic/pydantic/issues/6734.
+ exception_cls = definitions.GenericRemoteException + + return exception_cls + + +def _handle_response_error(response_json: dict, remote_name: str): + try: + error_json = response_json["error"] + except KeyError as e: + logging.error(f"response_json: {response_json}") + raise ValueError( + "Could not get `error` field from JSON from error response" + ) from e + try: + error = definitions.RemoteErrorDetail.model_validate(error_json) + except pydantic.ValidationError as e: + if isinstance(error_json, str): + msg = f"Remote error occurred in `{remote_name}`: '{error_json}'" + raise definitions.GenericRemoteException(msg) from None + raise ValueError( + "Could not parse error. Error details are expected to be either a " + "plain string (old truss models) or a serialized " + f"`definitions.RemoteErrorDetail.__name__`, got:\n{repr(error_json)}" + ) from e + exception_cls = _resolve_exception_class(error) + msg = ( + f"(showing remote errors, root message at the bottom)\n" + f"--> Preceding Remote Cause:\n" + f"{textwrap.indent(error.format(), ' ')}" + ) + raise exception_cls(msg) + + +def _response_raise_errors(response: httpx.Response, remote_name: str) -> None: + """In case of error, raise it. + + If the response error contains `RemoteErrorDetail`, it tries to re-raise + the same exception that was raised remotely and falls back to + `GenericRemoteException` if the exception class could not be resolved. + + Exception messages are chained to trace back to the root cause, i.e. the first + Chainlet that raised an exception. E.g. the message might look like this: + + ``` + RemoteChainletError in "Chain" + Traceback (most recent call last): + File "/app/model/Chainlet.py", line 112, in predict + result = await self._chainlet.run( + File "/app/model/Chainlet.py", line 79, in run + value += self._text_to_num.run(part) + File "/packages/remote_stubs.py", line 21, in run + json_result = self.predict_sync(json_args) + File "/packages/truss_chains/stub.py", line 37, in predict_sync + return utils.handle_response( + ValueError: (showing remote errors, root message at the bottom) + --> Preceding Remote Cause: + RemoteChainletError in "TextToNum" + Traceback (most recent call last): + File "/app/model/Chainlet.py", line 113, in predict + result = self._chainlet.run(data=payload["data"]) + File "/app/model/Chainlet.py", line 54, in run + generated_text = self._replicator.run(data) + File "/packages/remote_stubs.py", line 7, in run + json_result = self.predict_sync(json_args) + File "/packages/truss_chains/stub.py", line 37, in predict_sync + return utils.handle_response( + ValueError: (showing remote errors, root message at the bottom) + --> Preceding Remote Cause: + RemoteChainletError in "TextReplicator" + Traceback (most recent call last): + File "/app/model/Chainlet.py", line 112, in predict + result = self._chainlet.run(data=payload["data"]) + File "/app/model/Chainlet.py", line 36, in run + raise ValueError(f"This input is too long: {len(data)}.") + ValueError: This input is too long: 100. + + ``` + """ + if response.is_error: + try: + response_json = response.json() + except Exception as e: + raise ValueError( + "Could not get JSON from error response. Status: " + f"`{response.status_code}`." 
+ ) from e + _handle_response_error(response_json=response_json, remote_name=remote_name) + + +async def _async_response_raise_errors( + response: aiohttp.ClientResponse, remote_name: str +) -> None: + """Async version of `_response_raise_errors`.""" + if response.status >= 400: + try: + response_json = await response.json() + except Exception as e: + raise ValueError( + "Could not get JSON from error response. Status: " + f"`{response.status}`." + ) from e + _handle_response_error(response_json=response_json, remote_name=remote_name) + + +######################################################################################## + + _trace_parent_context: contextvars.ContextVar[str] = contextvars.ContextVar( "trace_parent" ) @@ -44,8 +233,27 @@ def trace_parent(request: starlette.requests.Request) -> Iterator[None]: _trace_parent_context.reset(token) +def pydantic_set_field_dict(obj: pydantic.BaseModel) -> dict[str, pydantic.BaseModel]: + """Like `BaseModel.model_dump(exclude_unset=True)`, but only top-level. + + This is used to get kwargs for invoking a function, while dropping fields for which + there is no value explicitly set in the pydantic model. A field is considered unset + if the key was not present in the incoming JSON request (from which the model was + parsed/initialized) and the pydantic model has a default value, such as `None`. + + By dropping these unset fields, the default values from the function definition + will be used instead. This behavior ensures correct handling of arguments where + the function has a default, such as in the case of `run_remote`. If the model has + an optional field defaulting to `None`, this approach differentiates between + the user explicitly passing a value of `None` and the field being unset in the + request. + + """ + return {name: getattr(obj, name) for name in obj.model_fields_set} + + class BasetenSession: - """Helper to invoke predict method on Baseten deployments.""" + """Provides configured HTTP clients, retries, rate-limit warnings, etc.""" _client_cycle_time_sec: ClassVar[int] = 3600 * 1 # 1 hour. _client_limits: ClassVar[httpx.Limits] = httpx.Limits( @@ -97,7 +305,19 @@ def _client_cycle_needed(self, cached_client: Optional[tuple[Any, int]]) -> bool or (int(time.time()) - cached_client[1]) > self._client_cycle_time_sec ) - def _client_sync(self) -> httpx.Client: + def _log_retry(self, retry_state: tenacity.RetryCallState) -> None: + logging.info(f"Retrying `{self.name}`, attempt {retry_state.attempt_number}") + + def _make_retry_policy(self, retrying: Type[_RetryPolicyT]) -> _RetryPolicyT: + return retrying( + stop=tenacity.stop_after_attempt(self._service_descriptor.options.retries), + retry=tenacity.retry_if_exception_type(Exception), + reraise=True, + before_sleep=self._log_retry, + ) + + @contextlib.contextmanager + def _client_sync(self) -> Iterator[httpx.Client]: # Check `_client_cycle_needed` before and after locking to avoid # needing a lock each time the client is accessed.
if self._client_cycle_needed(self._cached_sync_client): @@ -112,9 +332,14 @@ def _client_sync(self) -> httpx.Client: int(time.time()), ) assert self._cached_sync_client is not None - return self._cached_sync_client[0] + client = self._cached_sync_client[0] + + with self._sync_num_requests as num_requests: + self._maybe_warn_for_overload(num_requests) + yield client - async def _client_async(self) -> aiohttp.ClientSession: + @contextlib.asynccontextmanager + async def _client_async(self) -> AsyncIterator[aiohttp.ClientSession]: # Check `_client_cycle_needed` before and after locking to avoid # needing a lock each time the client is accessed. if self._client_cycle_needed(self._cached_async_client): @@ -134,103 +359,19 @@ async def _client_async(self) -> aiohttp.ClientSession: int(time.time()), ) assert self._cached_async_client is not None - return self._cached_async_client[0] + client = self._cached_async_client[0] - def predict_sync(self, json_payload): - headers = { - definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() - } - retrying = tenacity.Retrying( - stop=tenacity.stop_after_attempt(self._service_descriptor.options.retries), - retry=tenacity.retry_if_exception_type(Exception), - reraise=True, - ) - for attempt in retrying: - with attempt: - if (num := attempt.retry_state.attempt_number) > 1: - logging.info(f"Retrying `{self.name}`, " f"attempt {num}") - try: - with self._sync_num_requests as num_requests: - self._maybe_warn_for_overload(num_requests) - response = self._client_sync().post( - self._service_descriptor.predict_url, - json=json_payload, - headers=headers, - ) - utils.response_raise_errors(response, self.name) - return response.json() - - # As a special case we invalidate the client in case of certificate - # errors. This has happened in the past and is a defensive measure. - except ssl.SSLError: - self._cached_sync_client = None - raise - - async def predict_async(self, json_payload): - headers = { - definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() - } - retrying = tenacity.AsyncRetrying( - stop=tenacity.stop_after_attempt(self._service_descriptor.options.retries), - retry=tenacity.retry_if_exception_type(Exception), - reraise=True, - ) - async for attempt in retrying: - with attempt: - if (num := attempt.retry_state.attempt_number) > 1: - logging.info(f"Retrying `{self.name}`, " f"attempt {num}") - try: - client = await self._client_async() - async with self._async_num_requests as num_requests: - self._maybe_warn_for_overload(num_requests) - async with client.post( - self._service_descriptor.predict_url, - json=json_payload, - headers=headers, - ) as response: - await utils.async_response_raise_errors(response, self.name) - return await response.json() - # As a special case we invalidate the client in case of certificate - # errors. This has happened in the past and is a defensive measure. - except ssl.SSLError: - self._cached_async_client = None - raise - - async def predict_async_stream(self, json_payload) -> AsyncIterator[bytes]: # type: ignore[return] # Handled by retries. 
- headers = { - definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() - } - retrying = tenacity.AsyncRetrying( - stop=tenacity.stop_after_attempt(self._service_descriptor.options.retries), - retry=tenacity.retry_if_exception_type(Exception), - reraise=True, - ) - async for attempt in retrying: - with attempt: - if (num := attempt.retry_state.attempt_number) > 1: - logging.info(f"Retrying `{self.name}`, " f"attempt {num}") - try: - client = await self._client_async() - async with self._async_num_requests as num_requests: - self._maybe_warn_for_overload(num_requests) - response = await client.post( - self._service_descriptor.predict_url, - json=json_payload, - headers=headers, - ) - await utils.async_response_raise_errors(response, self.name) - return response.content.iter_any() - - # As a special case we invalidate the client in case of certificate - # errors. This has happened in the past and is a defensive measure. - except ssl.SSLError: - self._cached_async_client = None - raise - - -class StubBase(abc.ABC): + async with self._async_num_requests as num_requests: + self._maybe_warn_for_overload(num_requests) + yield client + + +class StubBase(BasetenSession, abc.ABC): """Base class for stubs that invoke remote chainlets. + Extends ``BasetenSession`` with methods for data serialization, de-serialization + and invoking other endpoints. + It is used internally for RPCs to dependency chainlets, but it can also be used in user-code for wrapping a deployed truss model into the chains framework, e.g. like that:: @@ -245,7 +386,7 @@ class WhisperOutput(pydantic.BaseModel): class DeployedWhisper(chains.StubBase): async def run_remote(self, audio_b64: str) -> WhisperOutput: - resp = await self._remote.predict_async( + resp = await self.predict_async( json_payload={"audio": audio_b64}) return WhisperOutput(text=resp["text"], language=resp["language"]) @@ -262,8 +403,6 @@ def __init__(self, ..., context=chains.depends_context()): """ - _remote: BasetenSession - @final def __init__( self, @@ -275,7 +414,7 @@ def __init__( service_descriptor: Contains the URL and other configuration. api_key: A baseten API key to authorize requests. """ - self._remote = BasetenSession(service_descriptor, api_key) + super().__init__(service_descriptor, api_key) @classmethod def from_url( @@ -302,6 +441,117 @@ def from_url( api_key=context.get_baseten_api_key(), ) + def _make_request_params( + self, inputs: _InputT, for_httpx: bool = False + ) -> Mapping[str, Any]: + kwargs: Dict[str, Any] = {} + headers = { + definitions.OTEL_TRACE_PARENT_HEADER_KEY: _trace_parent_context.get() + } + if isinstance(inputs, pydantic.BaseModel): + if self._service_descriptor.options.use_binary: + data_dict = inputs.model_dump(mode="python") + kwargs["data"] = serialization.truss_msgpack_serialize(data_dict) + headers["Content-Type"] = "application/octet-stream" + else: + data_key = "content" if for_httpx else "data" + kwargs[data_key] = inputs.model_dump_json() + headers["Content-Type"] = "application/json" + else: # inputs is JSON dict. 
+ if self._service_descriptor.options.use_binary: + kwargs["data"] = serialization.truss_msgpack_serialize(inputs) + headers["Content-Type"] = "application/octet-stream" + else: + kwargs["json"] = inputs + headers["Content-Type"] = "application/json" + + kwargs["headers"] = headers + return kwargs + + def _response_to_pydantic( + self, response: bytes, output_model: Type[_OutputT] + ) -> _OutputT: + if self._service_descriptor.options.use_binary: + data_dict = serialization.truss_msgpack_deserialize(response) + return output_model.model_validate(data_dict) + return output_model.model_validate_json(response) + + def _response_to_json(self, response: bytes) -> Any: + if self._service_descriptor.options.use_binary: + return serialization.truss_msgpack_deserialize(response) + return json.loads(response) + + @overload + def predict_sync( + self, inputs: _InputT, output_model: Type[_OutputT] + ) -> _OutputT: ... + + @overload # Returns JSON + def predict_sync(self, inputs: _InputT, output_model: None = None) -> Any: ... + + def predict_sync( + self, inputs: _InputT, output_model: Optional[Type[_OutputT]] = None + ) -> Union[_OutputT, Any]: + retry = self._make_retry_policy(tenacity.Retrying) + params = self._make_request_params(inputs, for_httpx=True) + + def _rpc() -> bytes: + client: httpx.Client + with self._client_sync() as client: + response = client.post(self._service_descriptor.predict_url, **params) + _response_raise_errors(response, self.name) + return response.content + + response_bytes = retry(_rpc) + if output_model: + return self._response_to_pydantic(response_bytes, output_model) + return self._response_to_json(response_bytes) + + @overload + async def predict_async( + self, inputs: _InputT, output_model: Type[_OutputT] + ) -> _OutputT: ... + + @overload # Returns JSON. + async def predict_async( + self, inputs: _InputT, output_model: None = None + ) -> Any: ... 
+ + async def predict_async( + self, inputs: _InputT, output_model: Optional[Type[_OutputT]] = None + ) -> Union[_OutputT, Any]: + retry = self._make_retry_policy(tenacity.AsyncRetrying) + params = self._make_request_params(inputs) + + async def _rpc() -> bytes: + client: aiohttp.ClientSession + async with self._client_async() as client: + async with client.post( + self._service_descriptor.predict_url, **params + ) as response: + await _async_response_raise_errors(response, self.name) + return await response.read() + + response_bytes: bytes = await retry(_rpc) + if output_model: + return self._response_to_pydantic(response_bytes, output_model) + return self._response_to_json(response_bytes) + + async def predict_async_stream(self, inputs: _InputT) -> AsyncIterator[bytes]: + retry = self._make_retry_policy(tenacity.AsyncRetrying) + params = self._make_request_params(inputs) + + async def _rpc() -> AsyncIterator[bytes]: + client: aiohttp.ClientSession + async with self._client_async() as client: + response = await client.post( + self._service_descriptor.predict_url, **params + ) + await _async_response_raise_errors(response, self.name) + return response.content.iter_any() + + return await retry(_rpc) + StubT = TypeVar("StubT", bound=StubBase) diff --git a/truss-chains/truss_chains/utils.py b/truss-chains/truss_chains/utils.py index 28a485451..7ddb773a4 100644 --- a/truss-chains/truss_chains/utils.py +++ b/truss-chains/truss_chains/utils.py @@ -1,5 +1,4 @@ import asyncio -import builtins import contextlib import enum import inspect @@ -8,26 +7,17 @@ import os import random import socket -import sys -import textwrap import threading -import traceback from typing import ( Any, Dict, Iterable, Iterator, Mapping, - NoReturn, - Type, TypeVar, Union, ) -import aiohttp -import fastapi -import httpx -import pydantic from truss.templates.shared import dynamic_config_resolver from truss_chains import definitions @@ -185,171 +175,6 @@ def populate_chainlet_service_predict_urls( return chainlet_to_deployed_service -# Error Propagation Utils. ############################################################# -# TODO: move request related code into `stub.py`. - - -def _handle_exception(exception: Exception, chainlet_name: str) -> NoReturn: - """Raises `fastapi.HTTPException` with `RemoteErrorDetail` as detail.""" - if hasattr(exception, "__module__"): - exception_module_name = exception.__module__ - else: - exception_module_name = None - - error_stack = traceback.extract_tb(exception.__traceback__) - # Exclude the error handling functions from the stack trace. - exclude_frames = { - exception_to_http_error.__name__, - response_raise_errors.__name__, - async_response_raise_errors.__name__, - } - final_tb = [frame for frame in error_stack if frame.name not in exclude_frames] - stack = list( - [definitions.StackFrame.from_frame_summary(frame) for frame in final_tb] - ) - error = definitions.RemoteErrorDetail( - remote_name=chainlet_name, - exception_cls_name=exception.__class__.__name__, - exception_module_name=exception_module_name, - exception_message=str(exception), - user_stack_trace=stack, - ) - raise fastapi.HTTPException( - status_code=500, detail=error.model_dump() - ) from exception - - -@contextlib.contextmanager -def exception_to_http_error(chainlet_name: str) -> Iterator[None]: - # TODO: move chainlet name from here to caller side. 
- try: - yield - except Exception as e: - _handle_exception(e, chainlet_name) - - -def _resolve_exception_class( - error: definitions.RemoteErrorDetail, -) -> Type[Exception]: - """Tries to find the exception class in builtins or imported libs, - falls back to `definitions.GenericRemoteError` if not found.""" - exception_cls = None - if error.exception_module_name is None: - exception_cls = getattr(builtins, error.exception_cls_name, None) - else: - if mod := sys.modules.get(error.exception_module_name): - exception_cls = getattr(mod, error.exception_cls_name, None) - - if exception_cls is None: - logging.warning( - f"Could not resolve exception with name `{error.exception_cls_name}` " - f"and module `{error.exception_module_name}` - fall back to " - f"`{definitions.GenericRemoteException.__name__}`." - ) - exception_cls = definitions.GenericRemoteException - - return exception_cls - - -def _handle_response_error(response_json: dict, remote_name: str): - try: - error_json = response_json["error"] - except KeyError as e: - logging.error(f"response_json: {response_json}") - raise ValueError( - "Could not get `error` field from JSON from error response" - ) from e - try: - error = definitions.RemoteErrorDetail.model_validate(error_json) - except pydantic.ValidationError as e: - if isinstance(error_json, str): - msg = f"Remote error occurred in `{remote_name}`: '{error_json}'" - raise definitions.GenericRemoteException(msg) from None - raise ValueError( - "Could not parse error. Error details are expected to be either a " - "plain string (old truss models) or a serialized " - f"`definitions.RemoteErrorDetail.__name__`, got:\n{repr(error_json)}" - ) from e - exception_cls = _resolve_exception_class(error) - msg = ( - f"(showing remote errors, root message at the bottom)\n" - f"--> Preceding Remote Cause:\n" - f"{textwrap.indent(error.format(), ' ')}" - ) - raise exception_cls(msg) - - -def response_raise_errors(response: httpx.Response, remote_name: str) -> None: - """In case of error, raise it. - - If the response error contains `RemoteErrorDetail`, it tries to re-raise - the same exception that was raised remotely and falls back to - `GenericRemoteException` if the exception class could not be resolved. - - Exception messages are chained to trace back to the root cause, i.e. the first - Chainlet that raised an exception. E.g. 
the message might look like this: - - ``` - RemoteChainletError in "Chain" - Traceback (most recent call last): - File "/app/model/Chainlet.py", line 112, in predict - result = await self._chainlet.run( - File "/app/model/Chainlet.py", line 79, in run - value += self._text_to_num.run(part) - File "/packages/remote_stubs.py", line 21, in run - json_result = self._remote.predict_sync(json_args) - File "/packages/truss_chains/stub.py", line 37, in predict_sync - return utils.handle_response( - ValueError: (showing remote errors, root message at the bottom) - --> Preceding Remote Cause: - RemoteChainletError in "TextToNum" - Traceback (most recent call last): - File "/app/model/Chainlet.py", line 113, in predict - result = self._chainlet.run(data=payload["data"]) - File "/app/model/Chainlet.py", line 54, in run - generated_text = self._replicator.run(data) - File "/packages/remote_stubs.py", line 7, in run - json_result = self._remote.predict_sync(json_args) - File "/packages/truss_chains/stub.py", line 37, in predict_sync - return utils.handle_response( - ValueError: (showing remote errors, root message at the bottom) - --> Preceding Remote Cause: - RemoteChainletError in "TextReplicator" - Traceback (most recent call last): - File "/app/model/Chainlet.py", line 112, in predict - result = self._chainlet.run(data=payload["data"]) - File "/app/model/Chainlet.py", line 36, in run - raise ValueError(f"This input is too long: {len(data)}.") - ValueError: This input is too long: 100. - - ``` - """ - if response.is_error: - try: - response_json = response.json() - except Exception as e: - raise ValueError( - "Could not get JSON from error response. Status: " - f"`{response.status_code}`." - ) from e - _handle_response_error(response_json=response_json, remote_name=remote_name) - - -async def async_response_raise_errors( - response: aiohttp.ClientResponse, remote_name: str -) -> None: - """Async version of `async_response_raise_errors`.""" - if response.status >= 400: - try: - response_json = await response.json() - except Exception as e: - raise ValueError( - "Could not get JSON from error response. Status: " - f"`{response.status}`." - ) from e - _handle_response_error(response_json=response_json, remote_name=remote_name) - - ######################################################################################## @@ -410,25 +235,6 @@ def issubclass_safe(x: Any, cls: type) -> bool: return isinstance(x, type) and issubclass(x, cls) -def pydantic_set_field_dict(obj: pydantic.BaseModel) -> dict[str, pydantic.BaseModel]: - """Like `BaseModel.model_dump(exclude_unset=True), but only top-level. - - This is used to get kwargs for invoking a function, while dropping fields for which - there is no value explicitly set in the pydantic model. A field is considered unset - if the key was not present in the incoming JSON request (from which the model was - parsed/initialized) and the pydantic model has a default value, such as `None`. - - By dropping these unset fields, the default values from the function definition - will be used instead. This behavior ensures correct handling of arguments where - the function has a default, such as in the case of `run_remote`. If the model has - an optional field defaulting to `None`, this approach differentiates between - the user explicitly passing a value of `None` and the field being unset in the - request. 
- - """ - return {name: getattr(obj, name) for name in obj.__fields_set__} - - class AsyncSafeCounter: def __init__(self, initial: int = 0) -> None: self._counter = initial diff --git a/truss/templates/server/model_wrapper.py b/truss/templates/server/model_wrapper.py index ab28713d2..6feea1eca 100644 --- a/truss/templates/server/model_wrapper.py +++ b/truss/templates/server/model_wrapper.py @@ -36,7 +36,6 @@ from common.retry import retry from common.schema import TrussSchema from opentelemetry import trace -from pydantic import BaseModel from shared import dynamic_config_resolver, serialization from shared.lazy_data_resolver import LazyDataResolver from shared.secrets_resolver import SecretsResolver @@ -64,6 +63,7 @@ Generator[bytes, None, None], AsyncGenerator[bytes, None], "starlette.responses.Response", + pydantic.BaseModel, ] @@ -735,15 +735,7 @@ async def __call__( span_post, "postprocess" ), tracing.detach_context(): postprocess_result = await self.postprocess(predict_result, request) - - final_result: OutputType - if isinstance(postprocess_result, BaseModel): - # If we return a pydantic object, convert it back to a dict - with tracing.section_as_event(span_post, "dump-pydantic"): - final_result = postprocess_result.dict() - else: - final_result = postprocess_result - return final_result + return postprocess_result async def _gather_generator( diff --git a/truss/templates/server/requirements.txt b/truss/templates/server/requirements.txt index e5f0dd23d..924b0d8ff 100644 --- a/truss/templates/server/requirements.txt +++ b/truss/templates/server/requirements.txt @@ -7,7 +7,7 @@ fastapi==0.114.1 joblib==1.2.0 loguru==0.7.2 msgpack-numpy==0.4.8 -msgpack==1.0.2 +msgpack==1.1.0 # Numpy/msgpack versions are finicky (1.0.2 breaks), double-check when changing. numpy>=1.23.5 opentelemetry-api>=1.25.0 opentelemetry-sdk>=1.25.0 diff --git a/truss/templates/server/truss_server.py b/truss/templates/server/truss_server.py index 37ab4c223..8e30a884b 100644 --- a/truss/templates/server/truss_server.py +++ b/truss/templates/server/truss_server.py @@ -16,10 +16,11 @@ from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.responses import ORJSONResponse, StreamingResponse from fastapi.routing import APIRoute as FastAPIRoute -from model_wrapper import InputType, ModelWrapper +from model_wrapper import InputType, ModelWrapper, OutputType from opentelemetry import propagate as otel_propagate from opentelemetry import trace from opentelemetry.sdk import trace as sdk_trace +from pydantic import BaseModel from shared import serialization from shared.logging import setup_logging from shared.secrets_resolver import SecretsResolver @@ -31,6 +32,8 @@ else: from typing_extensions import AsyncGenerator, Generator +PYDANTIC_MAJOR_VERSION = int(pydantic.VERSION.split(".")[0]) + # [IMPORTANT] A lot of things depend on this currently, change with extreme care.
TIMEOUT_GRACEFUL_SHUTDOWN = 120 INFERENCE_SERVER_FAILED_FILE = Path("~/inference_server_crashed.txt").expanduser() @@ -118,14 +121,13 @@ async def _parse_body( ) from e else: if truss_schema: - if truss_schema: - try: - with tracing.section_as_event(span, "parse-pydantic"): - inputs = truss_schema.input_type.parse_raw(body_raw) - except pydantic.ValidationError as e: - raise errors.InputParsingError( - errors.format_pydantic_validation_error(e) - ) from e + try: + with tracing.section_as_event(span, "parse-pydantic"): + inputs = truss_schema.input_type.parse_raw(body_raw) + except pydantic.ValidationError as e: + raise errors.InputParsingError( + errors.format_pydantic_validation_error(e) + ) from e else: try: with tracing.section_as_event(span, "json-deserialize"): @@ -166,7 +168,7 @@ async def predict( ) # Calls ModelWrapper which runs: preprocess, predict, postprocess. with tracing.section_as_event(span, "model-call"): - result: Union[Dict, Generator] = await model(inputs, request) + result: OutputType = await model(inputs, request) # In the case that the model returns a Generator object, return a # StreamingResponse instead. @@ -177,22 +179,42 @@ async def predict( if result.status_code >= HTTPStatus.MULTIPLE_CHOICES.value: errors.add_error_headers_to_user_response(result) return result + return self._serialize_result(result, self.is_binary(request), span) - response_headers = {} - if self.is_binary(request): + def _serialize_result( + self, result: OutputType, is_binary: bool, span: trace.Span + ) -> Response: + response_headers = {} + if is_binary: + if isinstance(result, BaseModel): + with tracing.section_as_event(span, "binary-dump"): + if PYDANTIC_MAJOR_VERSION > 1: + result = result.model_dump(mode="python") + else: + result = result.dict() + # If the result is not already serialized and not a pydantic model, it must + be something that can be serialized with `truss_msgpack_serialize` (some + dict / nested structure). + if not isinstance(result, bytes): with tracing.section_as_event(span, "binary-serialize"): - response_headers["Content-Type"] = "application/octet-stream" - return Response( - content=serialization.truss_msgpack_serialize(result), - headers=response_headers, - ) - else: - with tracing.section_as_event(span, "json-serialize"): - response_headers["Content-Type"] = "application/json" - return Response( - content=json.dumps(result, cls=serialization.DeepNumpyEncoder), - headers=response_headers, - ) + result = serialization.truss_msgpack_serialize(result) + + response_headers["Content-Type"] = "application/octet-stream" + return Response(content=result, headers=response_headers) + else: + with tracing.section_as_event(span, "json-serialize"): + if isinstance(result, BaseModel): + # Note: chains has a pydantic integration for numpy arrays + # `NumpyArrayField`. `result.dict()` passes through the array + # object, which cannot be JSON serialized. + # In pydantic v2 `result.model_dump(mode="json")` could be used. + # For backwards compatibility, we dump the JSON string directly.
+ content = result.json() + else: + content = json.dumps(result, cls=serialization.DeepNumpyEncoder) + + response_headers["Content-Type"] = "application/json" + return Response(content=content, headers=response_headers) async def schema(self, model_name: str) -> Dict: model: ModelWrapper = self._safe_lookup_model(model_name) diff --git a/truss/tests/test_testing_utilities_for_other_tests.py b/truss/tests/test_testing_utilities_for_other_tests.py index 1e3041e90..7b045d11a 100644 --- a/truss/tests/test_testing_utilities_for_other_tests.py +++ b/truss/tests/test_testing_utilities_for_other_tests.py @@ -52,10 +52,11 @@ def _show_container_logs_if_raised(): print("An exception was raised, showing logs of all containers.") containers = get_containers({TRUSS: True}) new_containers = [c for c in containers if c.id not in initial_ids] - parts = [] + parts = ["\n"] for container in new_containers: parts.append(f"Logs for container {container.name} ({container.id}):") parts.append(_human_readable_json_logs(container.logs())) + parts.append("\n") logging.warning("\n".join(parts))
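Reviewer note (illustration only, not part of the patch): the sketch below ties together the main additions in this diff — ``NumpyArrayField``, the new ``use_binary`` option on ``chains.depends`` / ``RPCOptions``, and the reworked ``StubBase.predict_async(inputs, output_model)`` signature. It is a minimal usage sketch; the names ``Embedding``, ``Embedder``, ``DeployedModel`` and ``Client`` are hypothetical.

```python
import numpy as np
import pydantic

import truss_chains as chains
from truss_chains import pydantic_numpy


class Embedding(pydantic.BaseModel):
    # In JSON mode the array is serialized as (data_b64, shape, dtype), ~33%
    # overhead; with `use_binary=True` it is sent via msgpack-numpy instead.
    vector: pydantic_numpy.NumpyArrayField


class Embedder(chains.ChainletBase):
    async def run_remote(self, text: str) -> Embedding:
        # Hypothetical placeholder "model": returns a fixed-size zero vector.
        return Embedding(vector=np.zeros(8))


class DeployedModel(chains.StubBase):
    # Stub for an already-deployed truss model (hypothetical endpoint). The new
    # signature takes the pydantic input plus the output model class; JSON vs.
    # msgpack (de-)serialization is handled inside the stub.
    async def run_remote(self, data: Embedding) -> Embedding:
        return await self.predict_async(data, Embedding)


class Client(chains.ChainletBase):
    def __init__(self, embedder=chains.depends(Embedder, retries=2, use_binary=True)):
        self._embedder = embedder

    async def run_remote(self, text: str) -> float:
        emb = await self._embedder.run_remote(text)
        # The wrapped ndarray is accessed via the `.array` attribute.
        return float(emb.vector.array.sum())
```

In binary mode the stub sets ``Content-Type: application/octet-stream`` and serializes with ``truss_msgpack_serialize``; in JSON mode it posts ``model_dump_json()`` — both paths are visible in ``_make_request_params`` above.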