From b34acc820207bff49910f293ab77e75e1195c22c Mon Sep 17 00:00:00 2001 From: Marius Killinger <155577904+marius-baseten@users.noreply.github.com> Date: Thu, 12 Dec 2024 13:00:21 -0800 Subject: [PATCH] Update Chains Docs. (#1282) * Update Chains Docs. * Fix imblanced braces * Fix docstring --- docs/chains/doc_gen/API-reference.mdx | 151 ++++++--- docs/chains/doc_gen/generate_reference.py | 4 +- docs/chains/doc_gen/generated-reference.mdx | 108 ++++-- docs/chains/doc_gen/mdx_adapter.py | 2 +- docs/chains/doc_gen/reference.patch | 312 ++++++++++-------- truss-chains/truss_chains/definitions.py | 4 +- truss-chains/truss_chains/public_api.py | 2 +- .../truss_chains/remote_chainlet/stub.py | 56 ++-- 8 files changed, 398 insertions(+), 241 deletions(-) diff --git a/docs/chains/doc_gen/API-reference.mdx b/docs/chains/doc_gen/API-reference.mdx index fe3d42a1f..4e3b6166e 100644 --- a/docs/chains/doc_gen/API-reference.mdx +++ b/docs/chains/doc_gen/API-reference.mdx @@ -7,6 +7,7 @@ https://github.com/basetenlabs/truss/tree/main/docs/chains/doc_gen APIs for creating user-defined Chainlets. + ### *class* `truss_chains.ChainletBase` Base class for all chainlets. @@ -18,6 +19,7 @@ Refer to [the docs](https://docs.baseten.co/chains/getting-started) and this [example chainlet](https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py) for more guidance on how to create subclasses. + ### `truss_chains.depends` Sets a “symbolic marker” to indicate to the framework that a chainlet is a @@ -38,15 +40,18 @@ chainlet instance is provided. **Parameters:** -| Name | Type | Description | -|----------------|----------------------------------------------------------|--------------------------------------------------------------------------------------------------------------| -| `chainlet_cls` | *Type[[ChainletBase](#class-truss-chains-chainletbase)]* | The chainlet class of the dependency. 
| -| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). | -| `timeout_sec` | *int* | Timeout for the HTTP request to this chainlet. | +| Name | Type | Description | +|----------------|----------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `chainlet_cls` | *Type[[ChainletBase](#class-truss-chains-chainletbase)]* | The chainlet class of the dependency. | +| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). | +| `timeout_sec` | *int* | Timeout for the HTTP request to this chainlet. | +| `use_binary` | *bool* | Whether to send data in binary format. This can give a parsing speedup and message size reduction (~25%) for numpy arrays. Use `NumpyArrayField` as a field type on pydantic models for integration and set this option to `True`. For simple text data, there is no significant benefit. | + * **Returns:** A “symbolic marker” to be used as a default argument in a chainlet’s initializer. + ### `truss_chains.depends_context` Sets a “symbolic marker” for injecting a context object at runtime. @@ -65,6 +70,7 @@ context instance is provided. A “symbolic marker” to be used as a default argument in a chainlet’s initializer. + ### *class* `truss_chains.DeploymentContext` Bases: `pydantic.BaseModel` @@ -78,12 +84,12 @@ an access token for downloading model weights). 
**Parameters:** -| Name | Type | Description | -|-----------------------|---------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `data_dir` | *Path\|None* | The directory where the chainlet can store and access data, e.g. for downloading model weights. | -| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#class-truss-chains-servicedescriptor)]* | A mapping from chainlet names to service descriptors. This is used create RPCs sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. | -| `secrets` | *Mapping[str,str]* | A mapping from secret names to secret values. It contains only the secrets that are listed in `remote_config.assets.secret_keys` of the current chainlet. | -| `environment` | *[Environment](#class-truss-chains-definitions-environment)\|None* | The environment that the chainlet is deployed in. None if the chainlet is not associated with an environment. | +| Name | Type | Description | +|-----------------------|-------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `data_dir` | *Path\|None* | The directory where the chainlet can store and access data, e.g. for downloading model weights. | +| `chainlet_to_service` | *Mapping[str,[DeployedServiceDescriptor](#class-truss-chains-deployedservicedescriptor)]* | A mapping from chainlet names to service descriptors. This is used to create RPC sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. 
| +| `secrets` | *Mapping[str,str]* | A mapping from secret names to secret values. It contains only the secrets that are listed in `remote_config.assets.secret_keys` of the current chainlet. | +| `environment` | *[Environment](#class-truss-chains-definitions-environment)\|None* | The environment that the chainlet is deployed in. None if the chainlet is not associated with an environment. | #### get_baseten_api_key() @@ -99,7 +105,8 @@ an access token for downloading model weights). | `chainlet_name` | *str* | The name of the chainlet. | * **Return type:** - [*ServiceDescriptor*](#class-truss-chains-servicedescriptor) + [*DeployedServiceDescriptor*](#class-truss-chains-deployedservicedescriptor) + ### *class* `truss_chains.definitions.Environment` @@ -110,6 +117,7 @@ The environment the chainlet is deployed in. * **Parameters:** **name** (*str*) – The name of the environment. + ### *class* `truss_chains.ChainletOptions` Bases: `pydantic.BaseModel` @@ -122,6 +130,7 @@ Bases: `pydantic.BaseModel` | `enable_b10_tracing` | *bool* | enables baseten-internal trace data collection. This helps baseten engineers better analyze chain performance in case of issues. It is independent of a potentially user-configured tracing instrumentation. Turning this on, could add performance overhead. | | `env_variables` | *Mapping[str,str]* | static environment variables available to the deployed chainlet. | + ### *class* `truss_chains.RPCOptions` Bases: `pydantic.BaseModel` @@ -130,10 +139,12 @@ Options to customize RPCs to dependency chainlets. 
**Parameters:** -| Name | Type | Description | -|---------------|-------|-------------| -| `timeout_sec` | *int* | | -| `retries` | *int* | | +| Name | Type | Description | +|---------------|--------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `timeout_sec` | *int* | | +| `retries` | *int* | | +| `use_binary` | *bool* | Whether to send data in binary format. This can give a parsing speedup and message size reduction (~25%) for numpy arrays. Use `NumpyArrayField` as a field type on pydantic models for integration and set this option to `True`. For simple text data, there is no significant benefit. | + ### `truss_chains.mark_entrypoint` @@ -167,6 +178,7 @@ class MyChainlet(ChainletBase): These data structures specify for each chainlet how it gets deployed remotely, e.g. dependencies and compute resources. + ### *class* `truss_chains.RemoteConfig` Bases: `pydantic.BaseModel` @@ -199,6 +211,7 @@ class MyChainlet(chains.ChainletBase): | `name` | *str\|None* | | | `options` | *[ChainletOptions](#class-truss-chains-chainletoptions)* | | + ### *class* `truss_chains.DockerImage` Bases: `pydantic.BaseModel` @@ -223,6 +236,7 @@ modules and keep their requirement files right next their python source files. | `data_dir` | *AbsPath\|None* | Data from this directory is copied into the docker image and accessible to the remote chainlet at runtime. | | `external_package_dirs` | *list[AbsPath]\|None* | A list of directories containing additional python packages outside the chain’s workspace dir, e.g. a shared library. This code is copied into the docker image and importable at runtime. | + ### *class* `truss_chains.BasetenImage` Bases: `Enum` @@ -251,6 +265,7 @@ Configures the usage of a custom image hosted on dockerhub. 
| `python_executable_path` | *str\|None* | Absolute path to python executable (if default `python` is ambiguous). | | `docker_auth` | *DockerAuthSettings\|None* | See [corresponding truss config](https://docs.baseten.co/truss-reference/config#base-image-docker-auth). | + ### *class* `truss_chains.Compute` Specifies which compute resources a chainlet has in the *remote* deployment. @@ -284,6 +299,7 @@ two ways: - With a threadpool if it’s a synchronous function. This requires that the threads don’t have significant CPU load (due to the GIL). + ### *class* `truss_chains.Assets` Specifies which assets a chainlet can access in the remote deployment. @@ -316,6 +332,7 @@ for more details on caching. General framework and helper functions. + ### `truss_chains.push` Deploys a chain remotely (with all dependent chainlets). @@ -323,26 +340,24 @@ Deploys a chain remotely (with all dependent chainlets). **Parameters:** -| Name | Type | Description | -|-------------------------|----------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `entrypoint` | *Type[[ChainletBase](#class-truss-chains-chainletbase)]* | The chainlet class that serves as the entrypoint to the chain. | -| `chain_name` | *str* | The name of the chain. | -| `publish` | *bool* | Whether to publish the chain as a published deployment (it is a draft deployment otherwise) | -| `promote` | *bool* | Whether to promote the chain to be the production deployment (this implies publishing as well). | -| `only_generate_trusses` | *bool* | Used for debugging purposes. If set to True, only the the underlying truss models for the chainlets are generated in `/tmp/.chains_generated`. | -| `remote` | *str\|None* | name of a remote config in .trussrc. If not provided, it will be inquired. 
| -| `environment` | *str\|None* | The name of an environment to promote deployment into. | +| Name | Type | Description | +|-------------------------|----------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------| +| `entrypoint` | *Type[[ChainletBase](#class-truss-chains-chainletbase)]* | The chainlet class that serves as the entrypoint to the chain. | +| `chain_name` | *str* | The name of the chain. | +| `publish` | *bool* | Whether to publish the chain as a published deployment (it is a draft deployment otherwise) | +| `promote` | *bool* | Whether to promote the chain to be the production deployment (this implies publishing as well). | +| `only_generate_trusses` | *bool* | Used for debugging purposes. If set to True, only the the underlying truss models for the chainlets are generated in `/tmp/.chains_generated`. | +| `remote` | *str\|None* | name of a remote config in .trussrc. If not provided, it will be inquired. | +| `environment` | *str\|None* | The name of an environment to promote deployment into. | +| `progress_bar` | *Type[progress.Progress]\|None* | Optional `rich.progress.Progress` if output is desired. | * **Returns:** A chain service handle to the deployed chain. * **Return type:** - [*ChainService*](#class-truss-chains-remote-chainservice) + [*ChainService*](#class-truss-chains-deployment-deployment-client-chainservice) -### `truss_chains.deploy_remotely` -Deprecated, use [`push`](#truss-chains-push) instead. - -### *class* `truss_chains.remote.ChainService` +### *class* `truss_chains.deployment.deployment_client.ChainService` Handle for a deployed chain. @@ -386,6 +401,7 @@ URL to invoke the entrypoint. Link to status page on Baseten. + ### `truss_chains.make_abs_path_here` Helper to specify file paths relative to the *immediately calling* module.
@@ -447,6 +463,7 @@ foo("./somewhere") * **Return type:** *AbsPath* + ### `truss_chains.run_local` Context manager local debug execution of a chain. @@ -457,11 +474,11 @@ corresponding fields of **Parameters:** -| Name | Type | Description | -|-----------------------|--------------------------------------------------------------------------|----------------------------------------------------------------| -| `secrets` | *Mapping[str,str]\|None* | A dict of secrets keys and values to provide to the chainlets. | -| `data_dir` | *Path\|str\|None* | Path to a directory with data files. | -| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#class-truss-chains-servicedescriptor)* | A dict of chainlet names to service descriptors. | +| Name | Type | Description | +|-----------------------|-------------------------------------------------------------------------------------------|----------------------------------------------------------------| +| `secrets` | *Mapping[str,str]\|None* | A dict of secrets keys and values to provide to the chainlets. | +| `data_dir` | *Path\|str\|None* | Path to a directory with data files. | +| `chainlet_to_service` | *Mapping[str,[DeployedServiceDescriptor](#class-truss-chains-deployedservicedescriptor)]* | A dict of chainlet names to service descriptors. | * **Return type:** *ContextManager*[None] @@ -481,8 +498,9 @@ if __name__ == "__main__": with chains.run_local( secrets={"some_token": os.environ["SOME_TOKEN"]}, chainlet_to_service={ - "SomeChainlet": chains.ServiceDescriptor( + "SomeChainlet": chains.DeployedServiceDescriptor( name="SomeChainlet", + display_name="SomeChainlet", predict_url="https://...", options=chains.RPCOptions(), ) @@ -498,9 +516,10 @@ Refer to the [local debugging guide](https://docs.baseten.co/chains/guide#test-a-chain-locally) for more details. 
-### *class* `truss_chains.ServiceDescriptor` -Bases: `pydantic.BaseModel` +### *class* `truss_chains.DeployedServiceDescriptor` + +Bases: `ServiceDescriptor` Bundles values to establish an RPC session to a dependency chainlet, specifically with `StubBase`. @@ -512,14 +531,21 @@ specifically with `StubBase`. | `name` | *str* | | | `predict_url` | *str* | | | `options` | *[RPCOptions](#class-truss-chains-rpcoptions)* | | +| `display_name` | *str* | | + -## *class* `truss_chains.StubBase` +### *class* `truss_chains.StubBase` + +Bases: `BasetenSession`, `ABC` Base class for stubs that invoke remote chainlets. +Extends `BasetenSession` with methods for data serialization, de-serialization +and invoking other endpoints. + It is used internally for RPCs to dependency chainlets, but it can also be used -in user-code for wrapping a deployed truss model into the chains framework, e.g. -like that: +in user-code for wrapping a deployed truss model into the Chains framework. It +flexibly supports JSON and pydantic inputs and output. Example usage: ```python import pydantic @@ -532,10 +558,20 @@ class WhisperOutput(pydantic.BaseModel): class DeployedWhisper(chains.StubBase): + # Input JSON, output JSON. + async def run_remote(self, audio_b64: str) -> Any: + return await self.predict_async( + inputs={"audio": audio_b64}) + # resp == {"text": ..., "language": ...} + + # OR Input JSON, output pydantic model. async def run_remote(self, audio_b64: str) -> WhisperOutput: - resp = await self._remote.predict_async( - json_payload={"audio": audio_b64}) - return WhisperOutput(text=resp["text"], language=resp["language"]) + return await self.predict_async( + inputs={"audio": audio_b64}, output_model=WhisperOutput) + + # OR Input and output are pydantic models.
+ async def run_remote(self, data: WhisperInput) -> WhisperOutput: + return await self.predict_async(data, output_model=WhisperOutput) class MyChainlet(chains.ChainletBase): @@ -547,14 +583,19 @@ class MyChainlet(chains.ChainletBase): context, options=chains.RPCOptions(retries=3), ) + + async def run_remote(self, ...): + await self._whisper.run_remote(...) ``` + **Parameters:** -| Name | Type | Description | -|----------------------|--------------------------------------------------------------|-------------------------------------------| -| `service_descriptor` | *[ServiceDescriptor](#class-truss-chains-servicedescriptor)* | Contains the URL and other configuration. | -| `api_key` | *str* | A baseten API key to authorize requests. | +| Name | Type | Description | +|----------------------|-------------------------------------------------------------------------------|-------------------------------------------| +| `service_descriptor` | *[DeployedServiceDescriptor](#class-truss-chains-deployedservicedescriptor)* | Contains the URL and other configuration. | +| `api_key` | *str* | A baseten API key to authorize requests. | + #### *classmethod* from_url(predict_url, context, options=None) @@ -568,6 +609,21 @@ Factory method, convenient to be used in chainlet’s `__init__`-method. | `context` | *[DeploymentContext](#class-truss-chains-deploymentcontext)* | Deployment context object, obtained in the chainlet’s `__init__`. | | `options` | *[RPCOptions](#class-truss-chains-rpcoptions)* | RPC options, e.g. retries.
| +#### *async* predict_async(inputs: PydanticModel, output_model: Type[PydanticModel]) → PydanticModel + +#### *async* predict_async(inputs: JSON, output_model: Type[PydanticModel]) → PydanticModel + +#### *async* predict_async(inputs: JSON) → JSON + +#### *async* predict_async_stream(inputs: PydanticModel | JSON) → AsyncIterator[bytes] + +#### predict_sync(inputs: PydanticModel, output_model: Type[PydanticModel]) → PydanticModel + +#### predict_sync(inputs: JSON, output_model: Type[PydanticModel]) → PydanticModel + +#### predict_sync(inputs: JSON) → JSON + + ### *class* `truss_chains.RemoteErrorDetail` Bases: `pydantic.BaseModel` @@ -580,7 +636,6 @@ error response. | Name | Type | Description | |-------------------------|--------------------|-------------| -| `remote_name` | *str* | | | `exception_cls_name` | *str* | | | `exception_module_name` | *str\|None* | | | `exception_message` | *str* | | diff --git a/docs/chains/doc_gen/generate_reference.py b/docs/chains/doc_gen/generate_reference.py index 9c92c5112..131afe1c2 100644 --- a/docs/chains/doc_gen/generate_reference.py +++ b/docs/chains/doc_gen/generate_reference.py @@ -30,7 +30,7 @@ NON_PUBLIC_SYMBOLS = [ # "truss_chains.definitions.AssetSpec", # "truss_chains.definitions.ComputeSpec", - "truss_chains.remote.ChainService", + "truss_chains.deployment.deployment_client.ChainService", "truss_chains.definitions.Environment", ] @@ -69,7 +69,7 @@ "General framework and helper functions.", [ "truss_chains.push", - "truss_chains.remote.ChainService", + "truss_chains.deployment.deployment_client.ChainService", "truss_chains.make_abs_path_here", "truss_chains.run_local", "truss_chains.DeployedServiceDescriptor", diff --git a/docs/chains/doc_gen/generated-reference.mdx b/docs/chains/doc_gen/generated-reference.mdx index d0a3bb48b..bbc226cda 100644 --- a/docs/chains/doc_gen/generated-reference.mdx +++ b/docs/chains/doc_gen/generated-reference.mdx @@ -7,6 +7,7 @@
https://github.com/basetenlabs/truss/tree/main/docs/chains/doc_gen APIs for creating user-defined Chainlets. + ### *class* `truss_chains.ChainletBase` Base class for all chainlets. @@ -41,8 +42,9 @@ chainlet instance is provided. | Name | Type | Description | |------|------|-------------| | `chainlet_cls` | *Type[ChainletT]* | The chainlet class of the dependency. | -| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). | +| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). For streaming, retries are only made if the request fails before streaming any results back. Failures mid-stream not retried. | | `timeout_sec` | *int* | Timeout for the HTTP request to this chainlet. | +| `use_binary` | *bool* | Whether to send data in binary format. This can give a parsing speedup and message size reduction (~25%) for numpy arrays. Use `NumpyArrayField` as a field type on pydantic models for integration and set this option to `True`. For simple text data, there is no significant benefit. | * **Returns:** A “symbolic marker” to be used as a default argument in a chainlet’s @@ -69,6 +71,7 @@ context instance is provided. * **Return type:** [*DeploymentContext*](#truss_chains.DeploymentContext) + ### *class* `truss_chains.DeploymentContext` Bases: `pydantic.BaseModel` @@ -85,12 +88,11 @@ an access token for downloading model weights). | Name | Type | Description | |------|------|-------------| | `data_dir` | *Path\|None* | The directory where the chainlet can store and access data, e.g. for downloading model weights. | -| `user_config` | ** | User-defined configuration for the chainlet. | -| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#truss_chains.ServiceDescriptor* | A mapping from chainlet names to service descriptors. This is used create RPCs sessions to dependency chainlets. 
It contains only the chainlet services that are dependencies of the current chainlet. | +| `chainlet_to_service` | *Mapping[str,[DeployedServiceDescriptor](#truss_chains.DeployedServiceDescriptor* | A mapping from chainlet names to service descriptors. This is used to create RPC sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. | | `secrets` | *MappingNoIter[str,str]* | A mapping from secret names to secret values. It contains only the secrets that are listed in `remote_config.assets.secret_keys` of the current chainlet. | | `environment` | *[Environment](#truss_chains.definitions.Environment* | The environment that the chainlet is deployed in. None if the chainlet is not associated with an environment. | -#### chainlet_to_service *: Mapping[str, [ServiceDescriptor](#truss_chains.ServiceDescriptor)]* +#### chainlet_to_service *: Mapping[str, [DeployedServiceDescriptor](#truss_chains.DeployedServiceDescriptor)]* #### data_dir *: Path | None* @@ -106,10 +108,11 @@ an access token for downloading model weights). * **Parameters:** **chainlet_name** (*str*) * **Return type:** - [*ServiceDescriptor*](#truss_chains.ServiceDescriptor) + [*DeployedServiceDescriptor*](#truss_chains.DeployedServiceDescriptor) #### secrets *: MappingNoIter[str, str]* + ### *class* `truss_chains.definitions.Environment` Bases: `pydantic.BaseModel` @@ -120,6 +123,7 @@ The environment the chainlet is deployed in. **name** (*str*) – The name of the environment. #### name *: str* + ### *class* `truss_chains.ChainletOptions` Bases: `pydantic.BaseModel` @@ -136,24 +140,28 @@ Bases: `pydantic.BaseModel` #### env_variables *: Mapping[str, str]* + ### *class* `truss_chains.RPCOptions` Bases: `pydantic.BaseModel` Options to customize RPCs to dependency chainlets. 
+ **Parameters:** | Name | Type | Description | |------|------|-------------| -| `timeout_sec` | *int* | | -| `retries` | *int* | | - +| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). For streaming, retries are only made if the request fails before streaming any results back. Failures mid-stream not retried. | +| `timeout_sec` | *int* | Timeout for the HTTP request to this chainlet. | +| `use_binary` | *bool* | Whether to send data in binary format. This can give a parsing speedup and message size reduction (~25%) for numpy arrays. Use `NumpyArrayField` as a field type on pydantic models for integration and set this option to `True`. For simple text data, there is no significant benefit. | #### retries *: int* #### timeout_sec *: int* +#### use_binary *: bool* + ### `truss_chains.mark_entrypoint` Decorator to mark a chainlet as the entrypoint of a chain. @@ -181,6 +189,7 @@ class MyChainlet(ChainletBase): These data structures specify for each chainlet how it gets deployed remotely, e.g. dependencies and compute resources. + ### *class* `truss_chains.RemoteConfig` Bases: `pydantic.BaseModel` @@ -234,6 +243,7 @@ class MyChainlet(chains.ChainletBase): #### options *: [ChainletOptions](#truss_chains.ChainletOptions)* + ### *class* `truss_chains.DockerImage` Bases: `pydantic.BaseModel` @@ -266,10 +276,13 @@ modules and keep their requirement files right next their python source files. #### external_package_dirs *: list[AbsPath] | None* +#### *classmethod* migrate_fields(values) + #### pip_requirements *: list[str]* #### pip_requirements_file *: AbsPath | None* + ### *class* `truss_chains.BasetenImage` Bases: `Enum` @@ -283,6 +296,7 @@ uses GPUs, drivers will be included in the image. #### PY39 *= 'py39'* + ### *class* `truss_chains.CustomImage` Bases: `pydantic.BaseModel` @@ -304,6 +318,7 @@ Configures the usage of a custom image hosted on dockerhub. 
#### python_executable_path *: str | None* + ### *class* `truss_chains.Compute` Specifies which compute resources a chainlet has in the *remote* deployment. @@ -342,6 +357,7 @@ two ways: * **Return type:** *ComputeSpec* + ### *class* `truss_chains.Assets` Specifies which assets a chainlet can access in the remote deployment. @@ -355,7 +371,7 @@ from truss.base import truss_config mistral_cache = truss_config.ModelRepo( repo_id="mistralai/Mistral-7B-Instruct-v0.2", allow_patterns=["*.json", "*.safetensors", ".model"] - ) +) chains.Assets(cached=[mistral_cache], ...) ``` @@ -402,15 +418,17 @@ Deploys a chain remotely (with all dependent chainlets). | `publish` | *bool* | Whether to publish the chain as a published deployment (it is a draft deployment otherwise) | | `promote` | *bool* | Whether to promote the chain to be the production deployment (this implies publishing as well). | | `only_generate_trusses` | *bool* | Used for debugging purposes. If set to True, only the the underlying truss models for the chainlets are generated in `/tmp/.chains_generated`. | -| `remote` | *str\|None* | name of a remote config in .trussrc. If not provided, it will be inquired. | +| `remote` | *str* | name of a remote config in .trussrc. If not provided, it will be inquired. | | `environment` | *str\|None* | The name of an environment to promote deployment into. | +| `progress_bar` | *Type[progress.Progress]\|None* | Optional rich.progress.Progress if output is desired. | * **Returns:** A chain service handle to the deployed chain. * **Return type:** *BasetenChainService* -### *class* `truss_chains.remote.ChainService` + +### *class* `truss_chains.deployment.deployment_client.ChainService` Bases: `ABC` @@ -535,7 +553,7 @@ corresponding fields of `DeploymentContext`. |------|------|-------------| | `secrets` | *Mapping[str,str]\|None* | A dict of secrets keys and values to provide to the chainlets. | | `data_dir` | *Path\|str\|None* | Path to a directory with data files. 
| -| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#truss_chains.ServiceDescriptor* | A dict of chainlet names to service descriptors. | +| `chainlet_to_service` | *Mapping[str,[DeployedServiceDescriptor](#truss_chains.DeployedServiceDescriptor* | A dict of chainlet names to service descriptors. | * **Return type:** *ContextManager*[None] @@ -555,8 +573,9 @@ if __name__ == "__main__": with chains.run_local( secrets={"some_token": os.environ["SOME_TOKEN"]}, chainlet_to_service={ - "SomeChainlet": chains.ServiceDescriptor( + "SomeChainlet": chains.DeployedServiceDescriptor( name="SomeChainlet", + display_name="SomeChainlet", predict_url="https://...", options=chains.RPCOptions(), ) @@ -571,52 +590,61 @@ if __name__ == "__main__": Refer to the [local debugging guide](https://docs.baseten.co/chains/guide#test-a-chain-locally) for more details. -### *class* `truss_chains.ServiceDescriptor` -Bases: `pydantic.BaseModel` +### *class* `truss_chains.DeployedServiceDescriptor` -Bundles values to establish an RPC session to a dependency chainlet, -specifically with `StubBase`. +Bases: `ServiceDescriptor` **Parameters:** | Name | Type | Description | |------|------|-------------| | `name` | *str* | | -| `predict_url` | *str* | | +| `display_name` | *str* | | | `options` | *[RPCOptions](#truss_chains.RPCOptions* | | +| `predict_url` | *str* | | -#### name *: str* - -#### options *: [RPCOptions](#truss_chains.RPCOptions)* - #### predict_url *: str* + ### *class* `truss_chains.StubBase` -Bases: `ABC` +Bases: `BasetenSession`, `ABC` Base class for stubs that invoke remote chainlets. +Extends `BasetenSession` with methods for data serialization, de-serialization +and invoking other endpoints. + It is used internally for RPCs to dependency chainlets, but it can also be used -in user-code for wrapping a deployed truss model into the chains framework, e.g. -like that: +in user-code for wrapping a deployed truss model into the Chains framework. 
It +flexibly supports JSON and pydantic inputs and output. Example usage: ```default import pydantic import truss_chains as chains + class WhisperOutput(pydantic.BaseModel): ... class DeployedWhisper(chains.StubBase): + # Input JSON, output JSON. + async def run_remote(self, audio_b64: str) -> Any: + return await self.predict_async( + inputs={"audio": audio_b64}) + # resp == {"text": ..., "language": ...} + # OR Input JSON, output pydantic model. async def run_remote(self, audio_b64: str) -> WhisperOutput: - resp = await self._remote.predict_async( - json_payload={"audio": audio_b64}) - return WhisperOutput(text=resp["text"], language=resp["language"]) + return await self.predict_async( + inputs={"audio": audio_b64}, output_model=WhisperOutput) + + # OR Input and output are pydantic models. + async def run_remote(self, data: WhisperInput) -> WhisperOutput: + return await self.predict_async(data, output_model=WhisperOutput) class MyChainlet(chains.ChainletBase): @@ -628,6 +656,9 @@ class MyChainlet(chains.ChainletBase): context, options=chains.RPCOptions(retries=3), ) + + async def run_remote(self, ...): + await self._whisper.run_remote(...) ``` @@ -635,7 +666,7 @@ class MyChainlet(chains.ChainletBase): | Name | Type | Description | |------|------|-------------| -| `service_descriptor` | *[ServiceDescriptor](#truss_chains.ServiceDescriptor* | Contains the URL and other configuration. | +| `service_descriptor` | *[DeployedServiceDescriptor](#truss_chains.DeployedServiceDescriptor* | Contains the URL and other configuration. | | `api_key` | *str* | A baseten API key to authorize requests. | @@ -653,6 +684,22 @@ Factory method, convenient to be used in chainlet’s `__init__`-method. | `options` | *[RPCOptions](#truss_chains.RPCOptions* | RPC options, e.g. retries. 
| +#### *async* predict_async(inputs: InputT, output_model: Type[OutputModelT]) → OutputModelT + +#### *async* predict_async(inputs: InputT, output_model: None = None) → Any + +#### *async* predict_async_stream(inputs) + +* **Parameters:** + **inputs** (*InputT*) +* **Return type:** + *AsyncIterator*[bytes] + +#### predict_sync(inputs: InputT, output_model: Type[OutputModelT]) → OutputModelT + +#### predict_sync(inputs: InputT, output_model: None = None) → Any + + ### *class* `truss_chains.RemoteErrorDetail` Bases: `pydantic.BaseModel` @@ -665,7 +712,6 @@ error response. | Name | Type | Description | |------|------|-------------| -| `remote_name` | *str* | | | `exception_cls_name` | *str* | | | `exception_module_name` | *str\|None* | | | `exception_message` | *str* | | @@ -686,6 +732,4 @@ with stack traces. * **Return type:** str -#### remote_name *: str* - #### user_stack_trace *: list[StackFrame]* diff --git a/docs/chains/doc_gen/mdx_adapter.py b/docs/chains/doc_gen/mdx_adapter.py index 5661b5fb6..16a4ba9a4 100644 --- a/docs/chains/doc_gen/mdx_adapter.py +++ b/docs/chains/doc_gen/mdx_adapter.py @@ -75,7 +75,7 @@ def _line_replacements(line: str) -> str: first_brace = line.find("(") if first_brace > 0: line = line[:first_brace] - return f"### *class* `{line}`" + return f"\n### *class* `{line}`" elif line.startswith("### "): line = line.replace("### ", "").strip() if not any(sym in line for sym in NON_PUBLIC_SYMBOLS): diff --git a/docs/chains/doc_gen/reference.patch b/docs/chains/doc_gen/reference.patch index fd8d8fa29..fa9ba9dd5 100644 --- a/docs/chains/doc_gen/reference.patch +++ b/docs/chains/doc_gen/reference.patch @@ -1,6 +1,13 @@ ---- docs/chains/doc_gen/generated-reference.mdx 2024-11-14 15:10:37.862189314 -0800 -+++ docs/chains/doc_gen/API-reference.mdx 2024-11-18 12:04:23.725353699 -0800 -@@ -24,31 +24,28 @@ +--- docs/chains/doc_gen/generated-reference.mdx 2024-12-12 12:51:17.671757358 -0800 ++++ docs/chains/doc_gen/API-reference.mdx 2024-12-12 
12:56:32.358153491 -0800 +@@ -19,38 +19,38 @@ + [example chainlet](https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py) + for more guidance on how to create subclasses. + ++ + ### `truss_chains.depends` + + Sets a “symbolic marker” to indicate to the framework that a chainlet is a dependency of another chainlet. The return value of `depends` is intended to be used as a default argument in a chainlet’s `__init__`-method. When deploying a chain remotely, a corresponding stub to the remote is injected in @@ -24,23 +31,26 @@ -| Name | Type | Description | -|------|------|-------------| -| `chainlet_cls` | *Type[ChainletT]* | The chainlet class of the dependency. | --| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). | +-| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). For streaming, retries are only made if the request fails before streaming any results back. Failures mid-stream not retried. | -| `timeout_sec` | *int* | Timeout for the HTTP request to this chainlet. | -- -+| Name | Type | Description | -+|----------------|----------------------------------------------------------|--------------------------------------------------------------------------------------------------------------| -+| `chainlet_cls` | *Type[[ChainletBase](#class-truss-chains-chainletbase)]* | The chainlet class of the dependency. | -+| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). | -+| `timeout_sec` | *int* | Timeout for the HTTP request to this chainlet. | +-| `use_binary` | *bool* | Whether to send data in binary format. This can give a parsing speedup and message size reduction (~25%) for numpy arrays. Use `NumpyArrayField` as a field type on pydantic models for integration and set this option to `True`. 
For simple text data, there is no significant benefit. | ++| Name | Type | Description | ++|----------------|----------------------------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ++| `chainlet_cls` | *Type[[ChainletBase](#class-truss-chains-chainletbase)]* | The chainlet class of the dependency. | ++| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). | ++| `timeout_sec` | *int* | Timeout for the HTTP request to this chainlet. | ++| `use_binary` | *bool* | Whether to send data in binary format. This can give a parsing speedup and message size reduction (~25%) for numpy arrays. Use `NumpyArrayField` as a field type on pydantic models for integration and set this option to `True`. For simple text data, there is no significant benefit. | + * **Returns:** A “symbolic marker” to be used as a default argument in a chainlet’s initializer. -* **Return type:** - *ChainletT* ++ ### `truss_chains.depends_context` -@@ -58,16 +55,15 @@ +@@ -60,16 +60,15 @@ [example chainlet](https://github.com/basetenlabs/truss/blob/main/truss-chains/truss_chains/example_chainlet.py) for more guidance on the `__init__`-signature of chainlets. @@ -57,42 +67,41 @@ -* **Return type:** - [*DeploymentContext*](#truss_chains.DeploymentContext) - ### *class* `truss_chains.DeploymentContext` -@@ -82,19 +78,12 @@ + ### *class* `truss_chains.DeploymentContext` +@@ -85,18 +84,12 @@ **Parameters:** -| Name | Type | Description | -|------|------|-------------| -| `data_dir` | *Path\|None* | The directory where the chainlet can store and access data, e.g. for downloading model weights. | --| `user_config` | ** | User-defined configuration for the chainlet. 
| --| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#truss_chains.ServiceDescriptor* | A mapping from chainlet names to service descriptors. This is used create RPCs sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. | +-| `chainlet_to_service` | *Mapping[str,[DeployedServiceDescriptor](#truss_chains.DeployedServiceDescriptor* | A mapping from chainlet names to service descriptors. This is used to create RPC sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. | -| `secrets` | *MappingNoIter[str,str]* | A mapping from secret names to secret values. It contains only the secrets that are listed in `remote_config.assets.secret_keys` of the current chainlet. | -| `environment` | *[Environment](#truss_chains.definitions.Environment* | The environment that the chainlet is deployed in. None if the chainlet is not associated with an environment. | - --#### chainlet_to_service *: Mapping[str, [ServiceDescriptor](#truss_chains.ServiceDescriptor)]* +-#### chainlet_to_service *: Mapping[str, [DeployedServiceDescriptor](#truss_chains.DeployedServiceDescriptor)]* - -#### data_dir *: Path | None* - -#### environment *: [Environment](#truss_chains.definitions.Environment) | None* -+| Name | Type | Description | -+|-----------------------|---------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -+| `data_dir` | *Path\|None* | The directory where the chainlet can store and access data, e.g. for downloading model weights. | -+| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#class-truss-chains-servicedescriptor)]* | A mapping from chainlet names to service descriptors. 
This is used create RPCs sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. | -+| `secrets` | *Mapping[str,str]* | A mapping from secret names to secret values. It contains only the secrets that are listed in `remote_config.assets.secret_keys` of the current chainlet. | -+| `environment` | *[Environment](#class-truss-chains-definitions-environment)\|None* | The environment that the chainlet is deployed in. None if the chainlet is not associated with an environment. | ++| Name | Type | Description | ++|-----------------------|-------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ++| `data_dir` | *Path\|None* | The directory where the chainlet can store and access data, e.g. for downloading model weights. | ++| `chainlet_to_service` | *Mapping[str,[DeployedServiceDescriptor](#class-truss-chains-deployedservicedescriptor)]* | A mapping from chainlet names to service descriptors. This is used to create RPC sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. | ++| `secrets` | *Mapping[str,str]* | A mapping from secret names to secret values. It contains only the secrets that are listed in `remote_config.assets.secret_keys` of the current chainlet. | ++| `environment` | *[Environment](#class-truss-chains-definitions-environment)\|None* | The environment that the chainlet is deployed in. None if the chainlet is not associated with an environment. 
| #### get_baseten_api_key() -@@ -103,12 +92,14 @@ +@@ -105,12 +98,14 @@ #### get_service_descriptor(chainlet_name) -* **Parameters:** - **chainlet_name** (*str*) -* **Return type:** -- [*ServiceDescriptor*](#truss_chains.ServiceDescriptor) +- [*DeployedServiceDescriptor*](#truss_chains.DeployedServiceDescriptor) +**Parameters:** -#### secrets *: MappingNoIter[str, str]* @@ -101,19 +110,19 @@ +| `chainlet_name` | *str* | The name of the chainlet. | + +* **Return type:** -+ [*ServiceDescriptor*](#class-truss-chains-servicedescriptor) ++ [*DeployedServiceDescriptor*](#class-truss-chains-deployedservicedescriptor) - ### *class* `truss_chains.definitions.Environment` -@@ -118,7 +109,6 @@ + ### *class* `truss_chains.definitions.Environment` +@@ -121,7 +116,6 @@ * **Parameters:** **name** (*str*) – The name of the environment. -#### name *: str* - ### *class* `truss_chains.ChainletOptions` -@@ -127,14 +117,10 @@ + ### *class* `truss_chains.ChainletOptions` +@@ -131,14 +125,10 @@ **Parameters:** @@ -130,29 +139,35 @@ +| `enable_b10_tracing` | *bool* | enables baseten-internal trace data collection. This helps baseten engineers better analyze chain performance in case of issues. It is independent of a potentially user-configured tracing instrumentation. Turning this on, could add performance overhead. | +| `env_variables` | *Mapping[str,str]* | static environment variables available to the deployed chainlet. | + ### *class* `truss_chains.RPCOptions` +@@ -147,20 +137,14 @@ -@@ -144,15 +130,10 @@ + Options to customize RPCs to dependency chainlets. +- **Parameters:** -| Name | Type | Description | -|------|------|-------------| --| `timeout_sec` | *int* | | --| `retries` | *int* | | -- +-| `retries` | *int* | The number of times to retry the remote chainlet in case of failures (e.g. due to transient network issues). For streaming, retries are only made if the request fails before streaming any results back. Failures mid-stream not retried. 
| +-| `timeout_sec` | *int* | Timeout for the HTTP request to this chainlet. | +-| `use_binary` | *bool* | Whether to send data in binary format. This can give a parsing speedup and message size reduction (~25%) for numpy arrays. Use `NumpyArrayField` as a field type on pydantic models for integration and set this option to `True`. For simple text data, there is no significant benefit. | - -#### retries *: int* - -#### timeout_sec *: int* -+| Name | Type | Description | -+|---------------|-------|-------------| -+| `timeout_sec` | *int* | | -+| `retries` | *int* | | ++| Name | Type | Description | ++|---------------|--------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| ++| `timeout_sec` | *int* | | ++| `retries` | *int* | | ++| `use_binary` | *bool* | Whether to send data in binary format. This can give a parsing speedup and message size reduction (~25%) for numpy arrays. Use `NumpyArrayField` as a field type on pydantic models for integration and set this option to `True`. For simple text data, there is no significant benefit. | + +-#### use_binary *: bool* ### `truss_chains.mark_entrypoint` -@@ -164,18 +145,23 @@ +@@ -172,18 +156,23 @@ Example usage: @@ -180,7 +195,7 @@ # Remote Configuration -@@ -189,7 +175,7 @@ +@@ -198,7 +187,7 @@ This is specified as a class variable for each chainlet class, e.g.: @@ -189,7 +204,7 @@ import truss_chains as chains -@@ -205,34 +191,13 @@ +@@ -214,34 +203,13 @@ **Parameters:** @@ -229,9 +244,9 @@ +| `name` | *str\|None* | | +| `options` | *[ChainletOptions](#class-truss-chains-chainletoptions)* | | - ### *class* `truss_chains.DockerImage` -@@ -240,35 +205,23 @@ + ### *class* `truss_chains.DockerImage` +@@ -250,37 +218,23 @@ Configures the docker image in which a remoted chainlet is deployed. 
@@ -264,6 +279,8 @@ - -#### external_package_dirs *: list[AbsPath] | None* - +-#### *classmethod* migrate_fields(values) +- -#### pip_requirements *: list[str]* - -#### pip_requirements_file *: AbsPath | None* @@ -276,26 +293,26 @@ +| `data_dir` | *AbsPath\|None* | Data from this directory is copied into the docker image and accessible to the remote chainlet at runtime. | +| `external_package_dirs` | *list[AbsPath]\|None* | A list of directories containing additional python packages outside the chain’s workspace dir, e.g. a shared library. This code is copied into the docker image and importable at runtime. | - ### *class* `truss_chains.BasetenImage` -@@ -277,11 +230,12 @@ + ### *class* `truss_chains.BasetenImage` +@@ -290,11 +244,11 @@ Default images, curated by baseten, for different python versions. If a Chainlet uses GPUs, drivers will be included in the image. -#### PY310 *= 'py310'* - -#### PY311 *= 'py311'* +- +-#### PY39 *= 'py39'* +| Enum Member | Value | +|-------------|---------| +| `PY310` | *py310* | +| `PY311 ` | *py311* | +| `PY39` | *py39* | --#### PY39 *= 'py39'* ### *class* `truss_chains.CustomImage` - -@@ -291,42 +245,35 @@ +@@ -305,43 +259,36 @@ **Parameters:** @@ -317,6 +334,7 @@ +| `python_executable_path` | *str\|None* | Absolute path to python executable (if default `python` is ambiguous). | +| `docker_auth` | *DockerAuthSettings\|None* | See [corresponding truss config](https://docs.baseten.co/truss-reference/config#base-image-docker-auth). | + ### *class* `truss_chains.Compute` Specifies which compute resources a chainlet has in the *remote* deployment. @@ -354,7 +372,7 @@ It is important to understand the difference between predict_concurrency and the concurrency target (used for autoscaling, i.e. adding or removing replicas). Furthermore, the `predict_concurrency` of a single instance is implemented in -@@ -337,52 +284,33 @@ +@@ -352,11 +299,6 @@ - With a threadpool if it’s a synchronous function. 
This requires that the threads don’t have significant CPU load (due to the GIL). @@ -363,9 +381,10 @@ -* **Return type:** - *ComputeSpec* - + ### *class* `truss_chains.Assets` - Specifies which assets a chainlet can access in the remote deployment. +@@ -364,7 +306,7 @@ For example, model weight caching can be used like this: @@ -374,14 +393,7 @@ import truss_chains as chains from truss.base import truss_config - mistral_cache = truss_config.ModelRepo( - repo_id="mistralai/Mistral-7B-Instruct-v0.2", - allow_patterns=["*.json", "*.safetensors", ".model"] -- ) -+) - chains.Assets(cached=[mistral_cache], ...) - ``` - +@@ -378,32 +320,19 @@ See [truss caching guide](https://docs.baseten.co/deploy/guides/model-cache#enabling-caching-for-a-model) for more details on caching. @@ -414,7 +426,13 @@ # Core -@@ -395,24 +323,26 @@ + General framework and helper functions. + ++ + ### `truss_chains.push` + + Deploys a chain remotely (with all dependent chainlets). +@@ -411,56 +340,38 @@ **Parameters:** @@ -425,17 +443,19 @@ -| `publish` | *bool* | Whether to publish the chain as a published deployment (it is a draft deployment otherwise) | -| `promote` | *bool* | Whether to promote the chain to be the production deployment (this implies publishing as well). | -| `only_generate_trusses` | *bool* | Used for debugging purposes. If set to True, only the the underlying truss models for the chainlets are generated in `/tmp/.chains_generated`. | --| `remote` | *str\|None* | name of a remote config in .trussrc. If not provided, it will be inquired. | +-| `remote` | *str* | name of a remote config in .trussrc. If not provided, it will be inquired. | -| `environment` | *str\|None* | The name of an environment to promote deployment into. 
| -+| Name | Type | Description | -+|-------------------------|----------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -+| `entrypoint` | *Type[[ChainletBase](#class-truss-chains-chainletbase)]* | The chainlet class that serves as the entrypoint to the chain. | -+| `chain_name` | *str* | The name of the chain. | -+| `publish` | *bool* | Whether to publish the chain as a published deployment (it is a draft deployment otherwise) | -+| `promote` | *bool* | Whether to promote the chain to be the production deployment (this implies publishing as well). | -+| `only_generate_trusses` | *bool* | Used for debugging purposes. If set to True, only the the underlying truss models for the chainlets are generated in `/tmp/.chains_generated`. | -+| `remote` | *str\|None* | name of a remote config in .trussrc. If not provided, it will be inquired. | -+| `environment` | *str\|None* | The name of an environment to promote deployment into. | +-| `progress_bar` | *Type[progress.Progress]\|None* | Optional rich.progress.Progress if output is desired. | ++| Name | Type | Description | ++|-------------------------|----------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------------| ++| `entrypoint` | *Type[[ChainletBase](#class-truss-chains-chainletbase)]* | The chainlet class that serves as the entrypoint to the chain. | ++| `chain_name` | *str* | The name of the chain. | ++| `publish` | *bool* | Whether to publish the chain as a published deployment (it is a draft deployment otherwise) | ++| `promote` | *bool* | Whether to promote the chain to be the production deployment (this implies publishing as well). | ++| `only_generate_trusses` | *bool* | Used for debugging purposes. 
If set to True, only the the underlying truss models for the chainlets are generated in `/tmp/.chains_generated`. | ++| `remote` | *str\|None* | name of a remote config in .trussrc. If not provided, it will be inquired. | ++| `environment` | *str\|None* | The name of an environment to promote deployment into. | ++| `progress_bar` | *Type[progress.Progress]\|None* | Optional `rich.progress.Progress` if output is desired. | * **Returns:** A chain service handle to the deployed chain. @@ -443,17 +463,14 @@ - *BasetenChainService* + [*ChainService*](#class-truss-chains-remote-chainservice) --### *class* `truss_chains.remote.ChainService` -+### `truss_chains.deploy_remotely` -+ -+Deprecated, use [`push`](#truss-chains-push) instead. --Bases: `ABC` -+### *class* `truss_chains.remote.ChainService` + ### *class* `truss_chains.deployment.deployment_client.ChainService` +-Bases: `ABC` +- Handle for a deployed chain. -@@ -420,29 +350,13 @@ + A `ChainService` is created and returned when using `push`. It bundles the individual services for each chainlet in the chain, and provides utilities to query their status, invoke the entrypoint etc. @@ -486,7 +503,7 @@ * **Return type:** list[*DeployedChainlet*] -@@ -452,18 +366,23 @@ +@@ -470,21 +381,27 @@ Invokes the entrypoint with JSON data. @@ -514,7 +531,11 @@ Link to status page on Baseten. -@@ -485,12 +404,12 @@ ++ + ### `truss_chains.make_abs_path_here` + + Helper to specify file paths relative to the *immediately calling* module. +@@ -503,12 +420,12 @@ You can now in `root/sub_package/chainlet.py` point to the requirements file like this: @@ -529,7 +550,7 @@ This helper uses the directory of the immediately calling module as an absolute reference point for resolving the file location. 
Therefore, you MUST NOT wrap the instantiation of `make_abs_path_here` into a -@@ -498,7 +417,7 @@ +@@ -516,7 +433,7 @@ Ok: @@ -538,7 +559,7 @@ def foo(path: AbsPath): abs_path = path.abs_path -@@ -508,7 +427,7 @@ +@@ -526,7 +443,7 @@ Not Ok: @@ -547,7 +568,7 @@ def foo(path: str): dangerous_value = make_abs_path_here(path).abs_path -@@ -516,8 +435,15 @@ +@@ -534,33 +451,41 @@ foo("./somewhere") ``` @@ -565,7 +586,9 @@ * **Return type:** *AbsPath* -@@ -526,23 +452,23 @@ ++ + ### `truss_chains.run_local` + Context manager local debug execution of a chain. The arguments only need to be provided if the chainlets explicitly access any the @@ -580,12 +603,12 @@ -|------|------|-------------| -| `secrets` | *Mapping[str,str]\|None* | A dict of secrets keys and values to provide to the chainlets. | -| `data_dir` | *Path\|str\|None* | Path to a directory with data files. | --| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#truss_chains.ServiceDescriptor* | A dict of chainlet names to service descriptors. | -+| Name | Type | Description | -+|-----------------------|--------------------------------------------------------------------------|----------------------------------------------------------------| -+| `secrets` | *Mapping[str,str]\|None* | A dict of secrets keys and values to provide to the chainlets. | -+| `data_dir` | *Path\|str\|None* | Path to a directory with data files. | -+| `chainlet_to_service` | *Mapping[str,[ServiceDescriptor](#class-truss-chains-servicedescriptor)* | A dict of chainlet names to service descriptors. | +-| `chainlet_to_service` | *Mapping[str,[DeployedServiceDescriptor](#truss_chains.DeployedServiceDescriptor* | A dict of chainlet names to service descriptors. 
| ++| Name | Type | Description | ++|-----------------------|-------------------------------------------------------------------------------------------|----------------------------------------------------------------| ++| `secrets` | *Mapping[str,str]\|None* | A dict of secrets keys and values to provide to the chainlets. | ++| `data_dir` | *Path\|str\|None* | Path to a directory with data files. | ++| `chainlet_to_service` | *Mapping[str,[DeployedServiceDescriptor](#class-truss-chains-deployedservicedescriptor)]* | A dict of chainlet names to service descriptors. | * **Return type:** *ContextManager*[None] @@ -597,7 +620,7 @@ import os import truss_chains as chains -@@ -568,7 +494,8 @@ +@@ -587,7 +512,8 @@ print(result) ``` @@ -606,73 +629,80 @@ +[local debugging guide](https://docs.baseten.co/chains/guide#test-a-chain-locally) for more details. - ### *class* `truss_chains.ServiceDescriptor` -@@ -580,22 +507,13 @@ - **Parameters:** +@@ -595,17 +521,17 @@ + + Bases: `ServiceDescriptor` +-**Parameters:** +- -| Name | Type | Description | -|------|------|-------------| -| `name` | *str* | | --| `predict_url` | *str* | | +-| `display_name` | *str* | | -| `options` | *[RPCOptions](#truss_chains.RPCOptions* | | -- -- --#### name *: str* -- --#### options *: [RPCOptions](#truss_chains.RPCOptions)* -- +-| `predict_url` | *str* | | ++Bundles values to establish an RPC session to a dependency chainlet, ++specifically with `StubBase`. + ++**Parameters:** + -#### predict_url *: str* +| Name | Type | Description | +|---------------|------------------------------------------------|-------------| +| `name` | *str* | | +| `predict_url` | *str* | | +| `options` | *[RPCOptions](#class-truss-chains-rpcoptions)* | | ++| `predict_url` | *str* | | --### *class* `truss_chains.StubBase` -- --Bases: `ABC` -+## *class* `truss_chains.StubBase` - - Base class for stubs that invoke remote chainlets. 
-@@ -603,17 +521,18 @@ - in user-code for wrapping a deployed truss model into the chains framework, e.g. - like that: + ### *class* `truss_chains.StubBase` +@@ -621,7 +547,7 @@ + in user-code for wrapping a deployed truss model into the Chains framework. It + flexibly supports JSON and pydantic inputs and output. Example usage: -```default +```python import pydantic import truss_chains as chains -+ - class WhisperOutput(pydantic.BaseModel): - ... +@@ -631,19 +557,20 @@ class DeployedWhisper(chains.StubBase): - ++ + # Input JSON, output JSON. +- async def run_remote(self, audio_b64: str) -> Any: ++ async def run_remote(self, audio_b64: str) -> Any: + return await self.predict_async( + inputs={"audio": audio_b64}) + # resp == {"text": ..., "language": ...} + + # OR Input JSON, output pydantic model. - async def run_remote(self, audio_b64: str) -> WhisperOutput: + async def run_remote(self, audio_b64: str) -> WhisperOutput: - resp = await self._remote.predict_async( - json_payload={"audio": audio_b64}) - return WhisperOutput(text=resp["text"], language=resp["language"]) -@@ -630,28 +549,24 @@ - ) - ``` + return await self.predict_async( + inputs={"audio": audio_b64}, output_model=WhisperOutput) + + # OR Input and output are pydantic models. +- async def run_remote(self, data: WhisperInput) -> WhisperOutput: ++ async def run_remote(self, data: WhisperInput) -> WhisperOutput: + return await self.predict_async(data, output_model=WhisperOutput) + + +@@ -664,40 +591,37 @@ -- **Parameters:** -| Name | Type | Description | -|------|------|-------------| --| `service_descriptor` | *[ServiceDescriptor](#truss_chains.ServiceDescriptor* | Contains the URL and other configuration. | +-| `service_descriptor` | *[DeployedServiceDescriptor](#truss_chains.DeployedServiceDescriptor* | Contains the URL and other configuration. | -| `api_key` | *str* | A baseten API key to authorize requests. 
| -- -+| Name | Type | Description | -+|----------------------|--------------------------------------------------------------|-------------------------------------------| -+| `service_descriptor` | *[ServiceDescriptor](#class-truss-chains-servicedescriptor)* | Contains the URL and other configuration. | -+| `api_key` | *str* | A baseten API key to authorize requests. | ++| Name | Type | Description | ++|----------------------|-------------------------------------------------------------------------------|-------------------------------------------| ++| `service_descriptor` | *[DeployedServiceDescriptor](#class-truss-chains-deployedservicedescriptor)]* | Contains the URL and other configuration. | ++| `api_key` | *str* | A baseten API key to authorize requests. | + #### *classmethod* from_url(predict_url, context, options=None) @@ -686,22 +716,43 @@ -| `predict_url` | *str* | URL to predict endpoint of another chain / truss model. | -| `context` | *[DeploymentContext](#truss_chains.DeploymentContext* | Deployment context object, obtained in the chainlet’s `__init__`. | -| `options` | *[RPCOptions](#truss_chains.RPCOptions* | RPC options, e.g. retries. | -- +| Name | Type | Description | +|---------------|--------------------------------------------------------------|-------------------------------------------------------------------| +| `predict_url` | *str* | URL to predict endpoint of another chain / truss model. | +| `context` | *[DeploymentContext](#class-truss-chains-deploymentcontext)* | Deployment context object, obtained in the chainlet’s `__init__`. | +| `options` | *[RPCOptions](#class-truss-chains-rpcoptions)* | RPC options, e.g. retries. 
| - ### *class* `truss_chains.RemoteErrorDetail` ++#### *async* predict_async(inputs: PydanticModel, output_model: Type[PydanticModel]) → PydanticModel + +-#### *async* predict_async(inputs: InputT, output_model: Type[OutputModelT]) → OutputModelT ++#### *async* predict_async(inputs: JSON, output_model: Type[PydanticModel]) → PydanticModel -@@ -663,20 +578,13 @@ +-#### *async* predict_async(inputs: InputT, output_model: None = None) → Any ++#### *async* predict_async(inputs: JSON) → JSON + +-#### *async* predict_async_stream(inputs) ++#### *async* predict_async_stream(inputs: PydanticModel | JSON) -> AsyncIterator[bytes] + +-* **Parameters:** +- **inputs** (*InputT*) +-* **Return type:** +- *AsyncIterator*[bytes] ++#### predict_sync(inputs: PydanticModel, output_model: Type[PydanticModel]) → PydanticModel + +-#### predict_sync(inputs: InputT, output_model: Type[OutputModelT]) → OutputModelT ++#### predict_sync(inputs: JSON, output_model: Type[PydanticModel]) → PydanticModel + +-#### predict_sync(inputs: InputT, output_model: None = None) → Any ++#### predict_sync(inputs: JSON) → JSON + + + ### *class* `truss_chains.RemoteErrorDetail` +@@ -710,19 +634,12 @@ **Parameters:** -| Name | Type | Description | -|------|------|-------------| --| `remote_name` | *str* | | -| `exception_cls_name` | *str* | | -| `exception_module_name` | *str\|None* | | -| `exception_message` | *str* | | @@ -715,7 +766,6 @@ -#### exception_module_name *: str | None* +| Name | Type | Description | +|-------------------------|--------------------|-------------| -+| `remote_name` | *str* | | +| `exception_cls_name` | *str* | | +| `exception_module_name` | *str\|None* | | +| `exception_message` | *str* | | @@ -723,11 +773,9 @@ #### format() -@@ -685,7 +593,3 @@ +@@ -731,5 +648,3 @@ * **Return type:** str - --#### remote_name *: str* -- -#### user_stack_trace *: list[StackFrame]* diff --git a/truss-chains/truss_chains/definitions.py b/truss-chains/truss_chains/definitions.py index 
7de3c5188..c9e051161 100644 --- a/truss-chains/truss_chains/definitions.py +++ b/truss-chains/truss_chains/definitions.py @@ -405,7 +405,7 @@ class RPCOptions(SafeModel): if the request fails before streaming any results back. Failures mid-stream not retried. timeout_sec: Timeout for the HTTP request to this chainlet. - use_binary: whether to send data data in binary format. This can give a parsing + use_binary: Whether to send data in binary format. This can give a parsing speedup and message size reduction (~25%) for numpy arrays. Use ``NumpyArrayField`` as a field type on pydantic models for integration and set this option to ``True``. For simple text data, there is no significant benefit. @@ -451,7 +451,7 @@ class DeploymentContext(SafeModelNonSerializable): data_dir: The directory where the chainlet can store and access data, e.g. for downloading model weights. chainlet_to_service: A mapping from chainlet names to service descriptors. - This is used create RPCs sessions to dependency chainlets. It contains only + This is used to create RPC sessions to dependency chainlets. It contains only the chainlet services that are dependencies of the current chainlet. secrets: A mapping from secret names to secret values. It contains only the secrets that are listed in ``remote_config.assets.secret_keys`` of the diff --git a/truss-chains/truss_chains/public_api.py b/truss-chains/truss_chains/public_api.py index 47d0db950..d0d64c163 100644 --- a/truss-chains/truss_chains/public_api.py +++ b/truss-chains/truss_chains/public_api.py @@ -63,7 +63,7 @@ def depends( if the request fails before streaming any results back. Failures mid-stream not retried. timeout_sec: Timeout for the HTTP request to this chainlet. - use_binary: whether to send data data in binary format. This can give a parsing + use_binary: Whether to send data in binary format. This can give a parsing speedup and message size reduction (~25%) for numpy arrays. 
Use ``NumpyArrayField`` as a field type on pydantic models for integration and set this option to ``True``. For simple text data, there is no significant benefit. diff --git a/truss-chains/truss_chains/remote_chainlet/stub.py b/truss-chains/truss_chains/remote_chainlet/stub.py index 177c87ada..81bd940a0 100644 --- a/truss-chains/truss_chains/remote_chainlet/stub.py +++ b/truss-chains/truss_chains/remote_chainlet/stub.py @@ -36,8 +36,8 @@ _RetryPolicyT = TypeVar("_RetryPolicyT", tenacity.AsyncRetrying, tenacity.Retrying) -_InputT = TypeVar("_InputT", pydantic.BaseModel, Any) # Any signifies "JSON". -_OutputT = TypeVar("_OutputT", bound=pydantic.BaseModel) +InputT = TypeVar("InputT", pydantic.BaseModel, Any) # Any signifies "JSON". +OutputModelT = TypeVar("OutputModelT", bound=pydantic.BaseModel) _trace_parent_context: contextvars.ContextVar[str] = contextvars.ContextVar( @@ -177,22 +177,32 @@ class StubBase(BasetenSession, abc.ABC): and invoking other endpoints. It is used internally for RPCs to dependency chainlets, but it can also be used - in user-code for wrapping a deployed truss model into the chains framework, e.g. - like that:: + in user-code for wrapping a deployed truss model into the Chains framework. It + flexibly supports JSON and pydantic inputs and output. Example usage:: import pydantic import truss_chains as chains + class WhisperOutput(pydantic.BaseModel): ... class DeployedWhisper(chains.StubBase): + # Input JSON, output JSON. + async def run_remote(self, audio_b64: str) -> Any: + return await self.predict_async( + inputs={"audio": audio_b64}) + # resp == {"text": ..., "language": ...} + # OR Input JSON, output pydantic model. 
async def run_remote(self, audio_b64: str) -> WhisperOutput: - resp = await self.predict_async( - json_payload={"audio": audio_b64}) - return WhisperOutput(text=resp["text"], language=resp["language"]) + return await self.predict_async( + inputs={"audio": audio_b64}, output_model=WhisperOutput) + + # OR Input and output are pydantic models. + async def run_remote(self, data: WhisperInput) -> WhisperOutput: + return await self.predict_async(data, output_model=WhisperOutput) class MyChainlet(chains.ChainletBase): @@ -205,6 +215,8 @@ def __init__(self, ..., context=chains.depends_context()): options=chains.RPCOptions(retries=3), ) + async def run_remote(self, ...): + await self._whisper.run_remote(...) """ @final @@ -246,7 +258,7 @@ def from_url( ) def _make_request_params( - self, inputs: _InputT, for_httpx: bool = False + self, inputs: InputT, for_httpx: bool = False ) -> Mapping[str, Any]: kwargs: Dict[str, Any] = {} headers = { @@ -273,8 +285,8 @@ def _make_request_params( return kwargs def _response_to_pydantic( - self, response: bytes, output_model: Type[_OutputT] - ) -> _OutputT: + self, response: bytes, output_model: Type[OutputModelT] + ) -> OutputModelT: if self._service_descriptor.options.use_binary: data_dict = serialization.truss_msgpack_deserialize(response) return output_model.model_validate(data_dict) @@ -287,15 +299,15 @@ def _response_to_json(self, response: bytes) -> Any: @overload def predict_sync( - self, inputs: _InputT, output_model: Type[_OutputT] - ) -> _OutputT: ... + self, inputs: InputT, output_model: Type[OutputModelT] + ) -> OutputModelT: ... @overload # Returns JSON - def predict_sync(self, inputs: _InputT, output_model: None = None) -> Any: ... + def predict_sync(self, inputs: InputT, output_model: None = None) -> Any: ... 
def predict_sync( - self, inputs: _InputT, output_model: Optional[Type[_OutputT]] = None - ) -> Union[_OutputT, Any]: + self, inputs: InputT, output_model: Optional[Type[OutputModelT]] = None + ) -> Union[OutputModelT, Any]: retry = self._make_retry_policy(tenacity.Retrying) params = self._make_request_params(inputs, for_httpx=True) @@ -313,17 +325,15 @@ def _rpc() -> bytes: @overload async def predict_async( - self, inputs: _InputT, output_model: Type[_OutputT] - ) -> _OutputT: ... + self, inputs: InputT, output_model: Type[OutputModelT] + ) -> OutputModelT: ... @overload # Returns JSON. - async def predict_async( - self, inputs: _InputT, output_model: None = None - ) -> Any: ... + async def predict_async(self, inputs: InputT, output_model: None = None) -> Any: ... async def predict_async( - self, inputs: _InputT, output_model: Optional[Type[_OutputT]] = None - ) -> Union[_OutputT, Any]: + self, inputs: InputT, output_model: Optional[Type[OutputModelT]] = None + ) -> Union[OutputModelT, Any]: retry = self._make_retry_policy(tenacity.AsyncRetrying) params = self._make_request_params(inputs) @@ -341,7 +351,7 @@ async def _rpc() -> bytes: return self._response_to_pydantic(response_bytes, output_model) return self._response_to_json(response_bytes) - async def predict_async_stream(self, inputs: _InputT) -> AsyncIterator[bytes]: + async def predict_async_stream(self, inputs: InputT) -> AsyncIterator[bytes]: retry = self._make_retry_policy(tenacity.AsyncRetrying) params = self._make_request_params(inputs)