diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9f93f131..826b6bb5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,51 +1,53 @@ -# Contributing to TT-STUDIO +# **Contributing to TT-STUDIO** -Thank you for your interest in this project! We want to make contributing as easy and transparent as possible. +Thank you for your interest in this project! We aim to make the contribution process as easy and transparent as possible. -If you're interested in making a contribution, please familiarize yourself with our technical [contribution standards](#contribution-standards) outlined in this guide. +If youβd like to contribute or have suggestions , please familiarize yourself with technical [contribution standards](#contribution-standards) outlined in this guide. -Next, request the appropriate write permissions by [opening an issue](https://github.com/tenstorrent/tt-studio/issues/new/choose) for GitHub permissions. +--- + +## **Contribution Requirements** + +- **Issue Tracking:** -All contributions require: + - File a feature request or bug report in the Issues section to notify maintainers. -- An issue: - - Please file a feature request or bug report under the Issues section to help get the attention of a maintainer. -- A pull request (PR). -- Your PR must be approved by the appropriate reviewers. +- **Pull Requests (PRs):** + - All changes must be submitted via a PR. + - PRs require approval from the appropriate reviewers before merging. + +--- -## Contribution Standards +## **Contribution Standards** -### Code Reviews +### **Code Reviews** We actively welcome your pull requests! To ensure quality contributions, any code change must meet the following criteria: - A PR must be opened and approved by: - A maintaining team member. - Any codeowners whose modules are relevant to the PR. -- Run pre-commit hooks. -- Pass all acceptance criteria mandated in the original issue. -- Pass the automated GitHub Actions workflow tests. -- Pass any testing requirements specified by the relevant codeowners. +- Run **pre-commit hooks**. +- Pass all **acceptance criteria** mandated in the original issue. +- Pass the **automated GitHub Actions workflow tests**. +- Pass any **testing requirements** specified by the relevant codeowners. -### **Git Branching Strategy Overview** +--- -### **1. Main Branches:** +## **Git Branching Strategy Overview** -- **`main`**: Holds production-ready code. +### **1. Main Branches** - - **Rules:** +- **`main`** β Holds production-ready tagged code. + - **Rules:** - No force pushes. + - Requires **rebase and merge** or **squash and merge** from a release cut branch. - - Requires rebase and merge. - -- **`dev`**: The central branch where all feature branches are merged and validated before preparing for release cut branch. - +- **`dev`** β The central branch where all feature branches are merged and validated before preparing a release branch. - **Rules:** - - No force pushes. - - - Requires squash merge. + - Requires **squash merge** from a feature branch. --- @@ -53,54 +55,75 @@ We actively welcome your pull requests! To ensure quality contributions, any cod #### **Development Process** -- Developers create feature branches from `main` to work on new features or bug fixes. -- Once a feature is completed and reviewed, it is **squash merged** into `dev` to maintain a clean history. +- **Feature Branches:** + + - Created from `dev`. + - **Naming convention:** `dev-name/feature` or `dev/github-issue-number`. + - Example: `dev-john/new-feature` or `dev-john/1234`. 
+ +- **Merging to `dev`:** + - Once completed and reviewed, feature branches are **squash merged** into `dev` to maintain a clean history. --- +### **3. Release Process** + #### **Release Preparation** -- When `dev` is stable and ready for release, a **release cut branch** (e.g., `release-v1.xxx`) is created from `dev`. -- Developers **cherry-pick** their completed and validated features from `dev` into the release branch. -- The release branch is tested before deployment. +- **Creating the Release Branch:** + + - When `dev` is stable and ready for release, a **release cut branch** is created from `main`. + - **Naming convention:** `rc-vx.x.x` (e.g., `rc-v1.2.0`). + +- **Feature Inclusion:** + + - Developers **cherry-pick** validated features from `dev` into the release branch. + - Test and resolve any **merge conflicts** as needed. + +- **Testing & Fixes:** + - The release branch undergoes **testing before deployment**. + - **Bug fixes and PR comments** follow the standard development flow and can be cherry-picked into the same release branch. --- #### **Final Deployment** -- Once the release branch is validated, it is merged into `main` for production deployment. -- Merging to `main` requires **at least two approvals** to ensure code quality and stability. -- After merging, the release is tagged following semantic versioning (e.g., `v1.0.0`). +- After validation, the **release branch is merged into `main`** for production. +- **At least two approvals** are required for merging to `main` to ensure quality. +- Merging can be done via **rebase and merge** or **squash and merge** if multiple commits were cherry-picked. +- The release is **tagged following semantic versioning** (e.g., `v1.0.0`). --- -#### Git Tagging - -- Tags are created in main to mark production releases. +### **4. Git Tagging** -- [Semantic versioning (e.g., v1.0.0) is used to track different versions.](#versioning-standards) +- **Tags are created in `main`** to mark production releases. +- **Semantic versioning** (e.g., `v1.0.0`) is used to track different versions. --- -### Versioning Standards +## **Versioning Standards** -To ensure consistency in versioning, we follow the principles of **semantic versioning**: +To ensure consistency, we follow **semantic versioning** principles: -- **MAJOR**: Increment for incompatible or breaking changes to backend or frontend APIs or functionality: +- **MAJOR**: Increment for **breaking changes** to backend or frontend APIs or functionality. - Removing or significantly altering existing features. - - Changing the current networking design. - - Altering backend API flows. - - Changing frontend API calls and/or redoing entire components. + - Changing the networking design. + - Modifying backend API flows. + - Redesigning frontend API calls or components. -- **MINOR**: Increment when adding new features or capabilities in a backward-compatible manner: +- **MINOR**: Increment for **new features** or capabilities that are **backward-compatible**. - - For example, supporting new models like YOLOv4 or adding new functionalities. + - Example: Supporting new models like YOLOv4 or adding additional functionalities. + - If the current version is `1.2.3` and a new **minor** release is introduced, it becomes **`1.3.0`**. + - If additional patches are needed after `1.3.0`, the version will increment to **`1.3.1`**, **`1.3.2`**, and so on. -- **PATCH**: Increment for bug fixes and minor improvements that are backward-compatible. 
+- **PATCH**: Increment for **bug fixes and minor improvements** that are **backward-compatible**. + - If patches are applied to `1.2.3`, the next versions would be `1.2.4`, `1.2.5`, etc. --- -### FlowChart for our Git Branching Strategy +## **Git Branching Strategy Flowchart** diff --git a/Git-management.png b/Git-management.png index ebace6ff..5ce6266c 100644 Binary files a/Git-management.png and b/Git-management.png differ diff --git a/HowToRun_vLLM_Models.md b/HowToRun_vLLM_Models.md index fa30371d..8b19c6a5 100644 --- a/HowToRun_vLLM_Models.md +++ b/HowToRun_vLLM_Models.md @@ -1,6 +1,10 @@ -# Running Llama3.1-70B and Mock vLLM Models in TT-Studio +# Running Llama and Mock vLLM Models in TT-Studio -This guide provides step-by-step instructions on setting up and deploying vLLM Llama3.1-70B and vLLM Mock models using TT-Studio. +This guide walks you through setting up vLLM Llama models and vLLM Mock models via the TT-Inference-Server, and then deploying them via TT-Studio. + +## Supported Models + +For the complete and up-to-date list of models supported by TT-Studio via TT-Inference-Server, please refer to [TT-Inference-Server GitHub README](https://github.com/tenstorrent/tt-inference-server/blob/main/README.md). --- @@ -8,9 +12,8 @@ This guide provides step-by-step instructions on setting up and deploying vLLM L 1. **Docker**: Make sure Docker is installed on your system. Follow the [Docker installation guide](https://docs.docker.com/engine/install/). -2. **Hugging Face Token**: Both models require authentication to Hugging Face repositories. To obtain a token, go to [Hugging Face Account](https://huggingface.co/settings/tokens) and generate a token. Additionally; make sure to accept the terms and conditions on Hugging Face for the Llama3.1 models by visiting [Hugging Face Meta-Llama Page](https://huggingface.co/meta-llama/Llama-3.1-70B-Instruct). +2. **Hugging Face Token**: Both models require authentication to Hugging Face repositories. To obtain a token, go to [Hugging Face Account](https://huggingface.co/settings/tokens) and generate a token. Additionally; make sure to accept the terms and conditions on Hugging Face for the the desired model(s). -3. **Model Access Weight**: To access specific models like Llama3.1, you may need to register with Meta to obtain download links for model weights. Visit [Llama Downloads](https://www.llama.com/llama-downloads/) for more information. --- ## Instructions Overview @@ -19,15 +22,14 @@ This guide provides step-by-step instructions on setting up and deploying vLLM L 1. [Clone repositories](#1-clone-required-repositories) 2. [Pull the mock model Docker image](#2-pull-the-desired-model-docker-images-using-docker-github-registry) 3. [Set up the Hugging Face (HF) token](#3-set-up-environment-variables-and-hugging-face-token) -4. [Run the mock vLLM model via the GUI](#7-deploy-and-run-the-model) +4. [Deploy and run inference for the model via the GUI](#5-deploy-and-run-the-model) -### **For vLLM Llama3.1-70B Model:** -1. [Clone repositories](#1-clone-required-repositories) -2. [Pull the model Docker image](#2-pull-the-desired-model-docker-images-using-docker-github-registry) -3. [Set up the Hugging Face (HF) token in the TT-Studio `.env` file](#3-set-up-environment-variables-and-hugging-face-token) -4. [Run the model setup script](#4-run-the-setup-script-vllm-llama31-70b-only) -5. [Update the vLLM Environment Variable in Environment File](#6-add-the-vllm-environment-variable-in-environment-file--copy-the-file-over-to-tt-studio-persistent-volume) -6. 
[Deploy and run inference for the Llama3.1-70B model via the GUI](#7-deploy-and-run-the-model) + +### **For vLLM Llama Model(s):** +1. [Clone repositories](#1-clone-required-repositories) +2. [Pull the model Docker image](#2-pull-the-desired-model-docker-images-using-docker-github-registry) +3. [Run the model setup script](#4-run-the-setup-script) +4. [Deploy and run inference for the model via the GUI](#6-deploy-and-run-the-model) --- @@ -55,128 +57,106 @@ git clone https://github.com/tenstorrent/tt-inference-server 1. **Navigate to the Docker Images:** - Visit [TT-Inference-Server GitHub Packages](https://github.com/orgs/tenstorrent/packages?repo_name=tt-inference-server). -2. **Pull the Docker Image:** +2. **Pull the Desired Model Docker Image:** ```bash - docker pull ghcr.io/tenstorrent/tt-inference-server: + docker pull ghcr.io/tenstorrent/tt-inference-server/:: ``` -3. **Authenticate Your Terminal (Optional):** +3. **Authenticate Your Terminal (Optional - If Pull Command Fails)):** ```bash echo YOUR_PAT | docker login ghcr.io -u YOUR_USERNAME --password-stdin ``` - + --- + ## 3. Set Up Environment Variables and Hugging Face Token -## 3. Set Up Environment Variables and Hugging Face Token - -Add the Hugging Face Token within the `.env` file in the `tt-studio/app/` directory. - -```bash -HF_TOKEN=hf_******** -``` + Add the Hugging Face Token within the `.env` file in the `tt-studio/app/` directory. + ```bash + HF_TOKEN=hf_******** + ``` --- -## 4. Run the Setup Script (vLLM Llama3.1-70B only) +## 4. Run the Setup Script -Follow these step-by-step instructions for a smooth automated process of model weights setup. +Follow these step-by-step instructions to smoothly automate the process of setting up model weights. -1. **Navigate to the `vllm-tt-metal-llama3-70b/` folder** within the `tt-inference-server`. This folder contains the necessary files and scripts for model setup. +1. **Create the `tt_studio_persistent_volume` folder** + - Either create this folder manually inside `tt-studio/`, or run `./startup.sh` from within `tt-studio` to have it created automatically. -2. **Run the automated setup script** as outlined in the [official documentation](https://github.com/tenstorrent/tt-inference-server/tree/main/vllm-tt-metal-llama3-70b#5-automated-setup-environment-variables-and-weights-files:~:text=70b/docs/development-,5.%20Automated%20Setup%3A%20environment%20variables%20and%20weights%20files,-The%20script%20vllm). This script handles key steps such as configuring environment variables, downloading weight files, repacking weights, and creating directories. +2. **Ensure folder permissions** + - Verify that you (the user) have permission to edit the newly created folder. If not, adjust ownership or permissions using commands like `chmod` or `chown`. -**Note** During the setup process, you will see the following prompt: +3. **Navigate to `tt-inference-server`** + - Consult the [README](https://github.com/tenstorrent/tt-inference-server?tab=readme-ov-file#model-implementations) to see which model servers are supported by TT-Studio. - ``` - Enter your PERSISTENT_VOLUME_ROOT [default: tt-inference-server/tt_inference_server_persistent_volume]: - ``` +4. **Run the automated setup script** - **Do not accept the default path.** Instead, set the persistent volume path to `tt-studio/tt_studio_persistent_volume`. This ensures the configuration matches TT-Studioβs directory structure. Using the default path may result in incorrect configuration. 
+ - **Execute the script** + Navigate to `tt-inference-server`, run: + ```bash + ./setup.sh **Model** + ``` + + - **Choose how to provide the model** + You will see: + ``` + How do you want to provide a model? + 1) Download from π€ Hugging Face (default) + 2) Download from Meta + 3) Local folder + Enter your choice: + ``` + For first-time users, we recommend **option 1** (Hugging Face). -By following these instructions, you will have a properly configured model infrastructure, ready for inference and further development. + - **Next Set `PERSISTENT_VOLUME_ROOT`** + The script will prompt you for a `PERSISTENT_VOLUME_ROOT` path. A default path will be suggested, but **do not accept the default**. Instead, specify the **absolute path** to your `tt-studio/tt_studio_persistent_volume` directory to maintain the correct structure. + Using the default path can lead to incorrect configurations. + - **Validate token and set environment variables** + The script will: + 1. Validate your Hugging Face token (`HF_TOKEN`). + 2. Prompt you for an `HF_HOME` location (default is often `~/.cache/huggingface`). + 3. Ask for a JWT secret, which should match the one in `tt-studio/app/.env` (commonly `test-secret-456`). +By following these steps, your tt-inference-server model infrastructure will be correctly configured and ready for inference via the TT-Studio GUI. --- ## 5. Folder Structure for Model Weights -Verify that the weights are correctly stored in the following structure: - -```bash -/path/to/tt-studio/tt_studio_persistent_volume/ -βββ volume_id_tt-metal-llama-3.1-70b-instructv0.0.1/ - βββ layers_0-4.pth - βββ layers_5-9.pth - βββ params.json - βββ tokenizer.model -``` - -**What to Look For:** -- Ensure all expected weight files (e.g., `layers_0-4.pth`, `params.json`, `tokenizer.model`) are present. -- If any files are missing, re-run the `setup.sh` script to complete the download. - -This folder structure allows TT Studio to automatically recognize and access models without further configuration adjustments. For each model, verify that the weights are correctly copied to this directory to ensure proper access by TT Studio. - - -## 6. Copy the Environment File and Point to it in TT-Studio - -### Step 1: Copy the Environment File -During the model weights download process, an `.env` file will be automatically created. The path to the `.env` file might resemble the following example: - -``` -/path/to/tt-inference-server/vllm-tt-metal-llama3-70b/.env -``` - -To ensure the model can be deployed via the TT-Studio GUI, this `.env` file must be copied to the model's persistent storage location. For example: - -```bash -/path/to/tt_studio_persistent_volume/volume_id_tt-metal-llama-3.1-70b-instructv0.0.1/copied_env -``` - -The following command can be used as a reference (*replace paths as necessary*): - -```bash -sudo cp /$USR/tt-inference-server/vllm-tt-metal-llama3-70b/.env /$USR/tt_studio/tt_studio_persistent_volume/volume_id_tt-metal-llama-3.1-70b-instructv0.0.1/.env -``` - -### Step 2: Point to the Copied Environment File -The `VLLM_LLAMA31_ENV_FILE` variable within the TT-Studio `$USR/tt-studio/app/.env` file must point to *this* copied `.env` file. This should be a **relative path**, for example it can be set as follows: - -``` -VLLM_LLAMA31_ENV_FILE="/tt_studio_persistent_volume/volume_id_tt-metal-llama-3.1-70b-instructv0.0.1/.env" -``` ---- +When using the setup script it creates (or updates) specific directories and files within your `tt_studio_persistent_volume` folder. 
Hereβs what to look for: -### Step 2: Update the TT-Studio Environment File -After copying the `.env` file, update the `VLLM_LLAMA31_ENV_FILE` variable in the `tt-studio/app/.env` file to point to the **copied file path**. This ensures TT-Studio uses the correct environment configuration for the model. +1. **Model Weights Directories** + Verify that the weights are correctly stored in a directory similar to: + ```bash + /path/to/tt-studio/tt_studio_persistent_volume/ + βββ model_envs + β βββ Llama-3.1-70B-Instruct.env + βββ volume_id_tt-metal-llama-3.1-70b-instructv0.0.1/ + βββ layers_0-4.pth + βββ layers_5-9.pth + βββ params.json + βββ tokenizer.model -```bash -VLLM_LLAMA31_ENV_FILE="/path/to/tt_studio_persistent_volume/volume_id_tt-metal-llama-3.1-70b-instructv0.0.1/copied_env" -``` + ``` + - Ensure all expected weight files (e.g., `layers_0-4.pth`, `params.json`, `tokenizer.model`) are present. + - If any files are missing, re-run the `setup.sh` script to complete the download. ---- -Here is an example of a complete `.env` file configuration for reference: +2. **`model_envs` Folder** + Within your `tt_studio_persistent_volume`, you will also find a `model_envs` folder (e.g., `model_envs/Llama-3.1-8B-Instruct.env`). + - Each `.env` file contains the values you input during the setup script run (e.g., `HF_TOKEN`, `HF_HOME`, `JWT_SECRET`). + - Verify that these environment variables match what you entered; if you need to adjust them, re-run the setup process. -```bash -TT_STUDIO_ROOT=/Users/**username**/tt-studio -HOST_PERSISTENT_STORAGE_VOLUME=${TT_STUDIO_ROOT}/tt_studio_persistent_volume -INTERNAL_PERSISTENT_STORAGE_VOLUME=/tt_studio_persistent_volume -BACKEND_API_HOSTNAME="tt-studio-backend-api" -VLLM_LLAMA31_ENV_FILE="/path/to/tt_studio_persistent_volume/volume_id_tt-metal-llama-3.1-70b-instructv0.0.1/**copied_env -# SECURITY WARNING: keep these secret in production! -JWT_SECRET=test-secret-456 -DJANGO_SECRET_KEY=django-insecure-default -HF_TOKEN=hf_**** -``` +This folder and file structure allows TT-Studio to automatically recognize and access models without any additional configuration steps. --- -## 7. Deploy and Run the Model +## 6. Deploy and Run the Model 1. **Start TT-Studio:** Run TT-Studio using the startup command. 2. **Access Model Weights:** In the TT-Studio interface, navigate to the model weights section. -3. **Select Custom Weights:** Use the custom weights option to select the weights for Llama3.1-70B. +3. **Select Weights:** Select the model weights. 4. **Run the Model:** Start the model and wait for it to initialize. --- @@ -242,6 +222,30 @@ curl -s --no-buffer -X POST "http://localhost:7000/v1/chat/completions" -H "Cont If successful, you will receive a response from the model. +#### iv. Sample Command for Changing Ownership (chown) + +If you need to adjust permissions for the `tt_studio_persistent_volume` folder, first determine your user and group IDs by running: (*replace paths as necessary*) + +```bash +id +``` + +You will see an output similar to: + +``` +uid=1001(youruser) gid=1001(yourgroup) groups=... +``` + +Use these numeric IDs to set the correct ownership. For example: + +```bash +sudo chown -R 1001:1001 /home/youruser/tt-studio/tt_studio_persistent_volume/ +``` + +Replace `1001:1001` with your actual UID:GID and `/home/youruser/tt-studio/tt_studio_persistent_volume/` with the path to your persistent volume folder. 
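+
+#### v. Optional: Query the Endpoint from Python
+
+If you prefer to script the endpoint check rather than use `curl`, the sketch below sends an equivalent streaming chat-completions request from Python. This is a hedged example, not part of the official tooling: it assumes the model is reachable at `http://localhost:7000` as in the `curl` example above, that `JWT_SECRET` matches the value in `tt-studio/app/.env` (e.g., `test-secret-456`), and it uses a placeholder model name that you should replace with your deployed model's Hugging Face path.
+
+```python
+# Minimal sketch: stream a chat completion from a locally deployed vLLM model.
+# Assumptions: host port 7000, JWT_SECRET="test-secret-456", placeholder model name.
+import json
+
+import jwt  # PyJWT
+import requests
+
+jwt_secret = "test-secret-456"  # must match JWT_SECRET in tt-studio/app/.env
+token = jwt.encode(
+    {"team_id": "tenstorrent", "token_id": "debug-test"}, jwt_secret, algorithm="HS256"
+)
+
+payload = {
+    "model": "meta-llama/Llama-3.1-70B-Instruct",  # placeholder: use your deployed model
+    "messages": [{"role": "user", "content": "What is Tenstorrent?"}],
+    "temperature": 1,
+    "top_k": 20,
+    "top_p": 0.9,
+    "max_tokens": 128,
+    "stream": True,
+    "stop": ["<|eot_id|>"],
+}
+headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
+
+with requests.post(
+    "http://localhost:7000/v1/chat/completions",
+    json=payload,
+    headers=headers,
+    stream=True,
+    timeout=None,
+) as resp:
+    resp.raise_for_status()
+    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
+        line = chunk.removeprefix("data: ").strip()
+        if not line or line == "[DONE]":
+            continue
+        delta = json.loads(line)["choices"][0]["delta"]
+        print(delta.get("content", ""), end="", flush=True)
+```
+
+If the deployment is healthy, tokens should stream to your terminal just as they do with the `curl` command.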
+ + + ## You're All Set π With the setup complete, youβre ready to run inference on the vLLM models (or any other supported model(s)) within TT-Studio. Refer to the documentation and setup instructions in the repositories for further guidance. \ No newline at end of file diff --git a/README.md b/README.md index e4a2cf4d..ae1ec056 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,11 @@ TT-Studio enables rapid deployment of TT Inference servers locally and is optimi 1. [Prerequisites](#prerequisites) 2. [Overview](#overview) -3. [Quick Start](#quick-start) - - [For General Users](#for-general-users) +3. [Quick Start](#quick-start) + - [For General Users](#for-general-users) + - Clone the Repository + - Set Up the Model Weights. + - Run the App via `startup.sh` - [For Developers](#for-developers) 4. [Using `startup.sh`](#using-startupsh) - [Basic Usage](#basic-usage) @@ -16,7 +19,7 @@ TT-Studio enables rapid deployment of TT Inference servers locally and is optimi 5. [Documentation](#documentation) - [Frontend Documentation](#frontend-documentation) - [Backend API Documentation](#backend-api-documentation) - - [Running Llama3.1-70B in TT-Studio](#running-llama31-70b-in-tt-studio) + - [Running vLLM Models in TT-Studio]) --- @@ -36,8 +39,11 @@ To set up TT-Studio: git clone https://github.com/tenstorrent/tt-studio.git cd tt-studio ``` +2. **Choose and Set Up the Model**: -2. **Run the Startup Script**: + Select your desired model and configure its corresponding weights by following the instructions in [HowToRun_vLLM_Models.md](./HowToRun_vLLM_Models.md). + +3. **Run the Startup Script**: Run the `startup.sh` script: @@ -47,16 +53,16 @@ To set up TT-Studio: #### See this [section](#command-line-options) for more information on command-line arguments available within the startup script. -3. **Access the Application**: +4. **Access the Application**: The app will be available at [http://localhost:3000](http://localhost:3000). -4. **Cleanup**: +5. **Cleanup**: - To stop and remove Docker services, run: ```bash ./startup.sh --cleanup ``` -5. Running on a Remote Machine +6. Running on a Remote Machine To forward traffic between your local machine and a remote server, enabling you to access the frontend application in your local browser, follow these steps: @@ -70,28 +76,60 @@ To set up TT-Studio: > β οΈ **Note**: To use Tenstorrent hardware, during the run of `startup.sh` script, select "yes" when prompted to mount hardware. This will automatically configure the necessary settings, eliminating manual edits to docker compose.yml. --- -### For Developers +## Running in Development Mode + +Developers can control and run the app directly via `docker compose`, keeping this running in a terminal allows for hot reload of the frontend app. + +1. **Start the Application**: -Developers can control and run the app directly via `docker compose`, keeping this running in a terminal allows for hot reload of the frontend app. For any backend changes its advisable to re restart the services. + Navigate to the project directory and start the application: -1. **Run in Development Mode**: + ```bash + cd tt-studio/app + docker compose up --build + ``` - ```bash - cd tt-studio/app - docker compose up --build - ``` + Alternatively, run the backend and frontend servers interactively: -2. **Stop the Services**: + ```bash + docker compose up + ``` + + To force a rebuild of Docker images: + + ```bash + docker compose up --build + ``` - ```bash - docker compose down - ``` +2. 
**Hot Reload & Debugging**: + + #### Frontend + - The frontend supports hot reloading when running inside the `docker compose` environment. + - Ensure that the required lines (**71-73**) in `docker-compose.yml` are uncommented. + + #### Backend + - Local files in `./api` are mounted to `/api` within the container for development. + - Code changes trigger an automatic rebuild and redeployment of the Django server. + - To manually start the Django development server: + + ```bash + ./manage.py runserver 0.0.0.0:8000 + ``` + +3. **Stopping the Services**: + + To shut down the application and remove running containers: + + ```bash + docker compose down + ``` -3. **Using the Mock vLLM Model**: - - For local testing, you can use the `Mock vLLM` model, which spits out random set of characters back . Instructions to run it are [here](HowToRun_vLLM_Models.md) +4. **Using the Mock vLLM Model**: + - For local testing, you can use the `Mock vLLM` model, which generates a random set of characters as output. + - Instructions to run it are available [here](./HowToRun_vLLM_Models.md). -4. **Running on a Machine with Tenstorrent Hardware**: +5. **Running on a Machine with Tenstorrent Hardware**: To run TT-Studio on a device with Tenstorrent hardware, you need to uncomment specific lines in the `app/docker-compose.yml` file. Follow these steps: @@ -159,7 +197,7 @@ If a Tenstorrent device (`/dev/tenstorrent`) is detected, the script will prompt - **Backend API Documentation**: [app/api/README.md](app/api/README.md) Information on the backend API, powered by Django Rest Framework, including available endpoints and integration details. -- **Running vLLM Llama3.1-70B and vLLM Mock Model(s) in TT-Studio**: [HowToRun_vLLM_Models.md](HowToRun_vLLM_Models.md) +- **Running vLLM Model(s) and Mock vLLM Model in TT-Studio**: [HowToRun_vLLM_Models.md](HowToRun_vLLM_Models.md) Step-by-step instructions on how to configure and run the vLLM model(s) using TT-Studio. - **Contribution Guide**: [CONTRIBUTING.md](CONTRIBUTING.md) diff --git a/app/.env.default b/app/.env.default index db81d8fe..e450fabf 100644 --- a/app/.env.default +++ b/app/.env.default @@ -6,3 +6,4 @@ VLLM_LLAMA31_ENV_FILE="" # SECURITY WARNING: keep these secret in production! 
JWT_SECRET=test-secret-456
DJANGO_SECRET_KEY=django-insecure-default
+TAVILY_API_KEY=
\ No newline at end of file
diff --git a/app/api/agent_control/Agent_flow.png b/app/api/agent_control/Agent_flow.png
new file mode 100644
index 00000000..6f931533
Binary files /dev/null and b/app/api/agent_control/Agent_flow.png differ
diff --git a/app/api/agent_control/Dockerfile b/app/api/agent_control/Dockerfile
new file mode 100644
index 00000000..4120b347
--- /dev/null
+++ b/app/api/agent_control/Dockerfile
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# SPDX-FileCopyrightText: © 2025 Tenstorrent AI ULC
+
+# Install Ubuntu base image
+FROM ubuntu:20.04
+ENV TZ=America/Los_Angeles
+ARG DEBIAN_FRONTEND=noninteractive
+
+# Update the package repository and install some default tools
+RUN apt-get update && apt-get install -y \
+    vim \
+    nano \
+    software-properties-common \
+    git \
+    htop \
+    screen \
+    tmux \
+    unzip \
+    zip \
+    curl \
+    wget
+
+# Add the deadsnakes PPA for newer Python versions
+RUN add-apt-repository ppa:deadsnakes/ppa -y && apt-get update
+# Install the specific version of Python
+RUN apt-get install -y python3.11 python3.11-venv python3.11-dev
+
+# Set Python3.11 as the default Python3
+RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.11 1
+
+# Ensure pip is installed and upgrade it
+RUN python3 -m ensurepip --upgrade && \
+    python3 -m pip install --upgrade pip setuptools wheel
+
+# Verify the Python version
+RUN python3 --version
+
+# Set the working directory
+WORKDIR /app
+
+# Install Python dependencies
+COPY ./requirements.txt /app
+RUN pip install -r requirements.txt --no-cache-dir
+
+# Copy files (optional)
+COPY . /app
+
+# Command to run when the container starts
+CMD ["/bin/bash"]
diff --git a/app/api/agent_control/README.md b/app/api/agent_control/README.md
new file mode 100644
index 00000000..4c23408c
--- /dev/null
+++ b/app/api/agent_control/README.md
@@ -0,0 +1,13 @@
+# AI Agent
+
+TT-Studio currently supports a search agent. To use the search agent, pull the following image from the GitHub Container Registry (GHCR).
+
+```bash
+docker pull ghcr.io/tenstorrent/tt-studio/agent_image:v1
+```
+
+You will also need to create and add your [Tavily API key](https://tavily.com/) to your `.env` file.
+
+How the agent works is depicted in the visual below.
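+
+For a quick smoke test once both the LLM and agent containers are running, the sketch below streams a reply from the agent's `/poll_requests` endpoint (defined in `agent.py` in this change). This is a hedged example rather than official tooling: it assumes the paired LLM container was deployed on host port 8001, in which case the agent container is named `ai_agent_container_p8001` and listens on port 8201, and it must be run from a container attached to the `tt_studio_network` Docker network (for example, the TT-Studio backend container). Adjust the container name and port to match your deployment.
+
+```python
+# Minimal sketch: stream the search agent's answer over the internal Docker network.
+# Assumption: LLM on host port 8001 -> agent container "ai_agent_container_p8001" on port 8201.
+import requests
+
+payload = {"message": "Search for the latest Tenstorrent news.", "thread_id": "example-thread-1"}
+url = "http://ai_agent_container_p8001:8201/poll_requests"
+
+with requests.post(url, json=payload, stream=True, timeout=None) as resp:
+    resp.raise_for_status()
+    # The endpoint returns plain-text chunks of the agent's final answer.
+    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
+        print(chunk, end="", flush=True)
+```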
+ + diff --git a/app/api/agent_control/agent.py b/app/api/agent_control/agent.py new file mode 100644 index 00000000..fb632a21 --- /dev/null +++ b/app/api/agent_control/agent.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# SPDX-FileCopyrightText: Β© 2025 Tenstorrent AI ULC + +from custom_llm import CustomLLM +from utils import poll_requests, setup_executer +from code_tool import CodeInterpreterFunctionTool +from langchain.memory import ConversationBufferMemory +from langchain_community.tools.tavily_search import TavilySearchResults +import os +import jwt +import json +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from fastapi.responses import StreamingResponse + + + +app = FastAPI() +json_payload = json.loads('{"team_id": "tenstorrent", "token_id":"debug-test"}') +jwt_secret = os.getenv("JWT_SECRET") +encoded_jwt = jwt.encode(json_payload, jwt_secret, algorithm="HS256") + +class RequestPayload(BaseModel): + message: str + thread_id: str + + +llm_container_name = os.getenv("LLM_CONTAINER_NAME") +llm = CustomLLM(server_url=f"http://{llm_container_name}:7000/v1/chat/completions", encoded_jwt=encoded_jwt, + streaming=True) +memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True) +os.environ["TAVILY_API_KEY"] = os.getenv("TAVILY_API_KEY") + +search = TavilySearchResults( + max_results=2, + include_answer=True, + include_raw_content=True) +tools = [search] +agent_executer = setup_executer(llm, memory, tools) + +@app.post("/poll_requests") +async def handle_requests(payload: RequestPayload): + config = {"configurable": {"thread_id": payload.thread_id}} + try: + # use await to prevent handle_requests from blocking, allow other tasks to execute + return StreamingResponse(poll_requests(agent_executer, config, tools, memory, payload.message), media_type="text/plain") + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +# health check +@app.get("/") +def read_root(): + return {"message": "Server is running"} diff --git a/app/api/agent_control/code_tool.py b/app/api/agent_control/code_tool.py new file mode 100644 index 00000000..51fa1a89 --- /dev/null +++ b/app/api/agent_control/code_tool.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# SPDX-FileCopyrightText: Β© 2025 Tenstorrent AI ULC + +from pydantic import BaseModel, Field +from typing import List, Sequence, Tuple, Any +from langchain_core.messages import BaseMessage +from langchain.agents.output_parsers.tools import ToolAgentAction +from e2b_code_interpreter import Sandbox +from langchain_core.tools import Tool +from langchain_core.messages import ToolMessage + + +import os +import json + +def format_to_tool_messages( + intermediate_steps: Sequence[Tuple[ToolAgentAction, dict]], +) -> List[BaseMessage]: + messages = [] + for agent_action, observation in intermediate_steps: + if agent_action.tool == CodeInterpreterFunctionTool.tool_name: + new_messages = CodeInterpreterFunctionTool.format_to_tool_message( + agent_action, + observation, + ) + messages.extend([new for new in new_messages if new not in messages]) + else: + # Handle other tools + print("Not handling tool: ", agent_action.tool) + + return messages + +class LangchainCodeInterpreterToolInput(BaseModel): + code: str = Field(description="Python code to execute.") + + +class CodeInterpreterFunctionTool: + """ + This class calls arbitrary code against a Python Jupyter notebook. + It requires an E2B_API_KEY to create a sandbox. 
+ """ + + tool_name: str = "code_interpreter" + + def __init__(self): + # Instantiate the E2B sandbox - this is a long lived object + # that's pinging E2B cloud to keep the sandbox alive. + if "E2B_API_KEY" not in os.environ: + raise Exception( + "Code Interpreter tool called while E2B_API_KEY environment variable is not set. Please get your E2B api key here https://e2b.dev/docs and set the E2B_API_KEY environment variable." + ) + self.code_interpreter = Sandbox(timeout=1800) + + def call(self, parameters: dict, **kwargs: Any): + code = parameters.get("code", "") + if code.startswith("```"): + code = code[3:] + if code.endswith("[DONE]"): + # TODO: check if this needs to be parsed + pass + if code.endswith("```"): + code = code[:-3] + elif code.endswith("```\n"): + code = code[:-4] + print(f"***Code Interpreting...\n{code}\n====") + execution = self.code_interpreter.run_code(code) + return { + "results": execution.results, + "stdout": execution.logs.stdout, + "stderr": execution.logs.stderr, + "error": execution.error, + } + + def close(self): + self.code_interpreter.kill() + + # langchain does not return a dict as a parameter, only a code string + def langchain_call(self, code: str): + return self.call({"code": code}) + + def to_langchain_tool(self) -> Tool: + tool = Tool( + name=self.tool_name, + description="Execute python code in a Jupyter notebook cell and returns any rich data (eg charts), stdout, stderr, and error.", + func=self.langchain_call, + ) + tool.args_schema = LangchainCodeInterpreterToolInput + return tool + + @staticmethod + def format_to_tool_message( + agent_action: ToolAgentAction, + observation: dict, + ) -> List[BaseMessage]: + """ + Format the output of the CodeInterpreter tool to be returned as a ToolMessage. + """ + new_messages = list(agent_action.message_log) + + # TODO: Add info about the results for the LLM + content = json.dumps( + {k: v for k, v in observation.items() if k not in ("results")}, indent=2 + ) + print(observation, agent_action, content) + new_messages.append( + ToolMessage(content=content, tool_call_id=agent_action.tool_call_id) + ) + + return new_messages \ No newline at end of file diff --git a/app/api/agent_control/custom_llm.py b/app/api/agent_control/custom_llm.py new file mode 100644 index 00000000..b0e5970f --- /dev/null +++ b/app/api/agent_control/custom_llm.py @@ -0,0 +1,194 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# SPDX-FileCopyrightText: Β© 2025 Tenstorrent AI ULC + +from pydantic.v1 import BaseModel +from typing import ( + List, + Sequence, + Any, + Optional, + Iterator, + Union, + Dict, + Type, + Callable, + Literal +) + +from langchain_core.language_models import BaseChatModel, LanguageModelInput +from langchain_core.messages import AIMessageChunk, BaseMessage +from langchain_core.callbacks.manager import CallbackManagerForLLMRun +from langchain_core.outputs import ChatGenerationChunk, ChatResult +from langchain_core.tools import BaseTool +from langchain_core.runnables import Runnable +from langchain_core.utils.function_calling import convert_to_openai_tool +from langchain.callbacks.streaming_stdout_final_only import FinalStreamingStdOutCallbackHandler +import requests +import json +import os + + +class CustomLLM(BaseChatModel): + server_url: str + encoded_jwt: str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Override 
the _generate method to implement the chat model logic. + + This can be a call to an API, a call to a local model, or any other + implementation that generates a response to the input prompt. + + Args: + messages: the prompt composed of a list of messages. + stop: a list of strings on which the model should stop generating. + If generation stops due to a stop token, the stop token itself + SHOULD BE INCLUDED as part of the output. This is not enforced + across models right now, but it's a good practice to follow since + it makes it much easier to parse the output of the model + downstream and understand why generation stopped. + run_manager: A run manager with callbacks for the LLM. + """ + pass + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = FinalStreamingStdOutCallbackHandler(), + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + """Stream the output of the model. + + This method should be implemented if the model can generate output + in a streaming fashion. If the model does not support streaming, + do not implement it. In that case streaming requests will be automatically + handled by the _generate method. + + Args: + messages: the prompt composed of a list of messages. + stop: a list of strings on which the model should stop generating. + If generation stops due to a stop token, the stop token itself + SHOULD BE INCLUDED as part of the output. This is not enforced + across models right now, but it's a good practice to follow since + it makes it much easier to parse the output of the model + downstream and understand why generation stopped. + run_manager: A run manager with callbacks for the LLM. + """ + last_message = messages[-1] # take most recent message as input to chat + filled_template = str(last_message.content) + + # code to strucuture template into format llama 3.1 70b chat/completions endpoint exepcts + end_of_template_substring = "Begin!" 
+ position = filled_template.find(end_of_template_substring) + template = "" + user_content = "" + if position != -1: + template = filled_template[:position + len(end_of_template_substring)] + user_content = filled_template[position + len(end_of_template_substring):] + content_position = user_content.find("New input:") + if content_position != -1: + user_content = user_content[content_position:] + # message format for llama 3.1 70b chat endpoint + message_payload = [{"role": "system", "content": template}, + {"role": "user", "content": user_content}] + + + headers = {"Authorization": f"Bearer {self.encoded_jwt}"} + hf_model_path = os.getenv("HF_MODEL_PATH") + json_data = { + "model": hf_model_path, + "messages": message_payload, + "temperature": 1, + "top_k": 20, + "top_p": 0.9, + "max_tokens": 512, + "stream": True, + "stop": ["<|eot_id|>"], + } + with requests.post( + self.server_url, json=json_data, headers=headers, stream=True, timeout=None + ) as response: + for chunk in response.iter_content(chunk_size=None, decode_unicode=True): + new_chunk = chunk[len("data: "):] + new_chunk = new_chunk.strip() + if new_chunk == "[DONE]": + # Yield [DONE] to signal that streaming is complete + new_chunk = ChatGenerationChunk(message=AIMessageChunk(content=new_chunk)) + yield new_chunk + else: + new_chunk = json.loads(new_chunk) + # print(new_chunk) + new_chunk = new_chunk["choices"][0] + # below format is used for v1/completions endpoint + # new_chunk = ChatGenerationChunk(message=AIMessageChunk(content=new_chunk["text"])) + # below format is used for v1/chat/completions endpoint + new_chunk = ChatGenerationChunk(message=AIMessageChunk(content=new_chunk["delta"]["content"])) + # if run_manager: + # This is optional in newer versions of LangChain + # The on_llm_new_token will be called automatically + # run_manager.on_llm_new_token(new_chunk.text, chunk=new_chunk) + # print(f"RUN MANAGER: {run_manager.last_tokens_stripped}") + yield new_chunk + + + def bind_tools( + self, + tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]], + *, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "any", "none"], bool] + ] = None, + **kwargs: Any, + ) -> Runnable[LanguageModelInput, BaseMessage]: + """Bind tool-like objects to this chat model. + + Args: + tools: A list of tool definitions to bind to this chat model. + Supports any tool definition handled by + :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`. + tool_choice: Which tool to require the model to call. + Must be the name of the single provided function, + "auto" to automatically determine which function to call + with the option to not call any function, "any" to enforce that some + function is called, or a dict of the form: + {"type": "function", "function": {"name": <>}}. + **kwargs: Any additional parameters to pass to the + :class:`~langchain.runnable.Runnable` constructor. + """ + formatted_tools = [convert_to_openai_tool(tool) for tool in tools] + if tool_choice is not None and tool_choice: + if tool_choice == "any": + tool_choice = "required" + if isinstance(tool_choice, str) and ( + tool_choice not in ("auto", "none", "required") + ): + tool_choice = {"type": "function", "function": {"name": tool_choice}} + if isinstance(tool_choice, bool): + if len(tools) > 1: + raise ValueError( + "tool_choice can only be True when there is one tool. Received " + f"{len(tools)} tools." 
+ ) + tool_name = formatted_tools[0]["function"]["name"] + tool_choice = { + "type": "function", + "function": {"name": tool_name}, + } + + kwargs["tool_choice"] = tool_choice + return super().bind(tools=formatted_tools, **kwargs) + + @property + def _llm_type(self) -> str: + """Get the type of language model used by this chat model. Used for logging purposes only.""" + return "custom" \ No newline at end of file diff --git a/app/api/agent_control/requirements.txt b/app/api/agent_control/requirements.txt new file mode 100644 index 00000000..e1ebef60 --- /dev/null +++ b/app/api/agent_control/requirements.txt @@ -0,0 +1,22 @@ +e2b==1.0.5 +e2b-code-interpreter==1.0.0 +langchain==0.3.11 +langchain-community==0.3.11 +langchain-core==0.3.24 +langchain-experimental==0.3.3 +langchain-text-splitters==0.3.2 +langgraph==0.2.56 +langgraph-checkpoint==2.0.8 +langgraph-sdk==0.1.43 +langsmith==0.1.147 +pydantic==2.10.3 +pydantic-settings==2.6.1 +pydantic_core==2.27.1 +regex==2024.11.6 +requests==2.32.3 +tavily-python==0.5.0 +typing-inspect==0.9.0 +typing_extensions==4.12.2 +uvicorn +fastapi +pyjwt \ No newline at end of file diff --git a/app/api/agent_control/utils.py b/app/api/agent_control/utils.py new file mode 100644 index 00000000..e584a201 --- /dev/null +++ b/app/api/agent_control/utils.py @@ -0,0 +1,80 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# SPDX-FileCopyrightText: Β© 2025 Tenstorrent AI ULC + +from langchain.agents import AgentExecutor, create_react_agent +from langchain import hub + +async def poll_requests(agent_executor, config, tools, memory, message): + complete_output = "" # Initialize an empty string to accumulate output + chat_history = memory.buffer_as_messages + final_answer = False + mainstring = "Final Answer: " + possible_substrings = await gen_substrings(mainstring) + first_final_response = False + async for event in agent_executor.astream_events( + {"input": message, "chat_history": chat_history}, version="v2", config=config +): + kind = event["event"] + if kind == "on_chain_start": + if ( + event["name"] == "Agent" + ): # Was assigned when creating the agent with `.with_config({"run_name": "Agent"})` + print( + f"Starting agent: {event['name']} with input: {event['data'].get('input')}" + ) + # pass + elif kind == "on_chain_end": + if ( + event["name"] == "Agent" + ): # Was assigned when creating the agent with `.with_config({"run_name": "Agent"})` + print() + print("--") + print( + f"Done agent: {event['name']} with output: {event['data'].get('output')['output']}" + ) + if kind == "on_chat_model_stream": + content = event["data"]["chunk"].content + complete_output += content + if "Final Answer: " in complete_output: + final_answer = True + complete_output = "" + first_final_response =True + if final_answer and first_final_response: + for substring in possible_substrings: + if substring in content: + content = content.replace(substring, "", 1) + first_final_response = False + + if content and final_answer: + yield content + print(content, end="|", flush=True) + # if final_ans_recieved and content.strip().endswith("[DONE]"): + # break + elif kind == "on_tool_start": + print("--") + print( + f"Starting tool: {event['name']} with inputs: {event['data'].get('input')}" + ) + elif kind == "on_tool_end": + print(f"Done tool: {event['name']}") + print(f"Tool output was: {event['data'].get('output')}") + print("--") + + +async def gen_substrings(string_to_check): + return [string_to_check[i:j] for i in range(len(string_to_check)) for j in range(len(string_to_check))] + +def 
setup_executer(llm, memory, tools): + prompt = hub.pull("hwchase17/react-chat") + agent = create_react_agent(llm, tools, prompt) + agent_executor = AgentExecutor( + agent=agent, + tools=tools, + max_iterations=100, + memory=memory, + return_intermediate_steps=True, + handle_parsing_errors=True + ) + + return agent_executor diff --git a/app/api/docker_control/docker_utils.py b/app/api/docker_control/docker_utils.py index 163f4ce5..9fbf2307 100644 --- a/app/api/docker_control/docker_utils.py +++ b/app/api/docker_control/docker_utils.py @@ -66,7 +66,30 @@ def run_container(impl, weights_id): except docker.errors.ContainerError as e: return {"status": "error", "message": str(e)} - +def run_agent_container(container_name, port_bindings, impl): + # runs agent container after associated llm container runs + run_kwargs = copy.deepcopy(impl.docker_config) + host_agent_port = get_host_agent_port() + llm_host_port = list(port_bindings.values())[0] # port that llm is using for naming convention (for easier removal later) + run_kwargs = { + 'name': f'ai_agent_container_p{llm_host_port}', # Container name + 'network': 'tt_studio_network', # Docker network + 'ports': {'8080/tcp': host_agent_port}, # Mapping container port 8080 to host port (host port dependent on LLM port) + 'environment': { + 'TAVILY_API_KEY': os.getenv('TAVILY_API_KEY'), # found in env file + 'LLM_CONTAINER_NAME': container_name, + 'JWT_SECRET': run_kwargs["environment"]['JWT_SECRET'], + 'HF_MODEL_PATH': run_kwargs["environment"]["HF_MODEL_PATH"] + }, # Set the environment variables + 'detach': True, # Run the container in detached mode +} + container = client.containers.run( + 'agent_image:v1', + f"uvicorn agent:app --reload --host 0.0.0.0 --port {host_agent_port}", + auto_remove=True, + **run_kwargs +) + def stop_container(container_id): """Stop a specific docker container""" try: @@ -121,6 +144,28 @@ def get_host_port(impl): logger.warning("Could not find an unused port in block: 8001-8100") return None +def get_host_agent_port(): + # used fixed block of ports starting at 8101 for agents + agent_containers = get_agent_containers() + port_mappings = get_port_mappings(agent_containers) + used_host_agent_ports = get_used_host_ports(port_mappings) + logger.info(f"used_host_agent_ports={used_host_agent_ports}") + BASE_AGENT_PORT = 8201 + for port in range(BASE_AGENT_PORT, BASE_AGENT_PORT+100): + if str(port) not in used_host_agent_ports: + return port + logger.warning("Could not find an unused port in block: 8201-8300") + +def get_agent_containers(): + """ + get all containers used by an ai agent + """ + running_containers = client.containers.list() + agent_containers = [] + for container in running_containers: + if "ai_agent_container" in container.name: + agent_containers.append(container) + return agent_containers def get_managed_containers(): """get containers configured in model_config.py for LLM-studio management""" diff --git a/app/api/docker_control/views.py b/app/api/docker_control/views.py index 05e40f7f..55fb9681 100644 --- a/app/api/docker_control/views.py +++ b/app/api/docker_control/views.py @@ -8,9 +8,11 @@ from rest_framework.views import APIView from rest_framework.response import Response +import docker from .forms import DockerForm from .docker_utils import ( run_container, + run_agent_container, stop_container, get_container_status, perform_reset, @@ -30,6 +32,18 @@ def post(self, request, *args, **kwargs): container_id = request.data.get("container_id") logger.info(f"Received request to stop container with ID: 
{container_id}") + # Find agent container + client = docker.from_env() + container_name = client.containers.get(container_id).name + last_underscore_index = container_name.rfind('_') + llm_host_port = container_name[last_underscore_index + 1:] + + agent_container_name = f"ai_agent_container_{llm_host_port}" + all_containers = client.containers.list(all=True) + for container in all_containers: + if container.name == agent_container_name: # if the agent corresponding agent container is found + stop_container(container.id) # remove the agent container + # Stop the container stop_response = stop_container(container_id) logger.info(f"Stop response: {stop_response}") @@ -100,6 +114,7 @@ def post(self, request, *args, **kwargs): weights_id = request.data.get("weights_id") impl = model_implmentations[impl_id] response = run_container(impl, weights_id) + run_agent_container(response["container_name"], response["port_bindings"], impl) # run agent container that maps to appropriate LLM container return Response(response, status=status.HTTP_201_CREATED) else: return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) diff --git a/app/api/model_control/model_utils.py b/app/api/model_control/model_utils.py index 28a390dd..496a22fc 100644 --- a/app/api/model_control/model_utils.py +++ b/app/api/model_control/model_utils.py @@ -8,6 +8,7 @@ import requests import jwt +import json from django.core.cache import caches @@ -48,6 +49,49 @@ def health_check(url, json_data, timeout=5): logger.error(f"Health check failed: {str(e)}") return False, str(e) +def stream_response_from_agent_api(url, json_data): + try: + new_json_data = {} + new_json_data["thread_id"] = json_data["thread_id"] + new_json_data["message"] = json_data["messages"][-1]["content"] + headers = {"Content-Type": "application/json"} + + logger.info(f"stream_response_from_agent_api headers:={headers}") + logger.info(f"stream_response_from_agent_api json_data:={new_json_data}") + logger.info(f"using agent thread id: {new_json_data["thread_id"]}") + logger.info(f"POST URL: {url}") + logger.info(f"POST Headers: {headers}") + logger.info(f"POST Data: {json.dumps(new_json_data, indent=2)}") + + + + with requests.post( + url, json=new_json_data, headers=headers, stream=True, timeout=None + ) as response: + logger.info(f"stream_response_from_external_api response:={response}") + response.raise_for_status() + logger.info(f"response.headers:={response.headers}") + logger.info(f"response.encoding:={response.encoding}") + # only allow HTTP 1.1 chunked encoding + # assert response.headers.get("transfer-encoding") == "chunked" + + # Stream chunks + for chunk in response.iter_content(chunk_size=None, decode_unicode=True): + json_chunk = {} + logger.info(f"stream_response_from_external_api chunk:={chunk}") + if chunk == "[DONE]": + yield "data: " + chunk + "\n" + else: + json_chunk["choices"] = [{"index": 0, "delta": {"content": chunk}}] + json_chunk = json.dumps(json_chunk) + string = "data: " + json_chunk + logger.info(f"streaming json object: {string}") + yield "data: " + json_chunk + "\n" + logger.info("stream_response_from_external done") + + except requests.RequestException as e: + logger.error(f"RequestException: {str(e)}") + yield f"error: {str(e)}" def stream_response_from_external_api(url, json_data): logger.info(f"stream_response_from_external_api to: url={url}") diff --git a/app/api/model_control/urls.py b/app/api/model_control/urls.py index b7c4fb18..2e41d2bd 100644 --- a/app/api/model_control/urls.py +++ b/app/api/model_control/urls.py @@ 
-8,6 +8,7 @@ urlpatterns = [ path("inference/", views.InferenceView.as_view()), + path("agent/", views.AgentView.as_view()), path("deployed/", views.DeployedModelsView.as_view()), path("model_weights/", views.ModelWeightsView.as_view()), path("object-detection/", views.ObjectDetectionInferenceView.as_view()), diff --git a/app/api/model_control/views.py b/app/api/model_control/views.py index 7123819e..00c27fcb 100644 --- a/app/api/model_control/views.py +++ b/app/api/model_control/views.py @@ -18,6 +18,7 @@ encoded_jwt, get_deploy_cache, stream_response_from_external_api, + stream_response_from_agent_api, health_check, ) from shared_config.model_config import model_implmentations @@ -43,6 +44,30 @@ def post(self, request, *args, **kwargs): return StreamingHttpResponse(response_stream, content_type="text/plain") else: return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + +class AgentView(APIView): + def post(self, request, *agrs, **kwargs): + logger.info(f"URL '/agent/' accessed via POST method by {request.META['REMOTE_ADDR']}") + data = request.data + logger.info(f"AgentView data:={data}") + serializer = InferenceSerializer(data=data) + if serializer.is_valid(): + deploy_id = data.pop("deploy_id") + logger.info(f"Deploy ID: {deploy_id}") + deploy = get_deploy_cache()[deploy_id] + colon_idx = deploy["internal_url"].rfind(":") + underscore_idx = deploy["internal_url"].rfind("_") + llm_host_port = deploy["internal_url"][underscore_idx + 2: colon_idx] # add 2 to remove the p + # agent port on host is 200 + the llm host port + internal_url = f"http://ai_agent_container_p{llm_host_port}:{int(llm_host_port) + 200}/poll_requests" + logger.info(f"internal_url:= {internal_url}") + logger.info(f"using vllm model:= {deploy["model_impl"].model_name}") + data["model"] = deploy["model_impl"].hf_model_id + logger.info(f"Using internal url: {internal_url}") + response_stream = stream_response_from_agent_api(internal_url, data) + return StreamingHttpResponse(response_stream, content_type="text/plain") + else: + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) class ModelHealthView(APIView): def get(self, request, *args, **kwargs): diff --git a/app/docker-compose.yml b/app/docker-compose.yml index e2ffb8ce..cd340188 100644 --- a/app/docker-compose.yml +++ b/app/docker-compose.yml @@ -39,6 +39,7 @@ services: - INTERNAL_PERSISTENT_STORAGE_VOLUME - BACKEND_API_HOSTNAME - JWT_SECRET + - TAVILY_API_KEY volumes: # mounting docker unix socket allows for backend container to run docker cmds - /var/run/docker.sock:/var/run/docker.sock diff --git a/app/frontend/package-lock.json b/app/frontend/package-lock.json index 722d6af7..1b4ddc36 100644 --- a/app/frontend/package-lock.json +++ b/app/frontend/package-lock.json @@ -40,6 +40,7 @@ "html-react-parser": "^5.1.18", "lucide-react": "^0.460.0", "mini-svg-data-uri": "^1.4.4", + "pdfjs-dist": "^4.10.38", "re-resizable": "^6.10.3", "react": "^18.3.1", "react-code-blocks": "^0.1.6", @@ -997,6 +998,188 @@ "@llm-ui/react": "0.13.3" } }, + "node_modules/@napi-rs/canvas": { + "version": "0.1.67", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-0.1.67.tgz", + "integrity": "sha512-VA4Khm/5Kg2bQGx3jXotTC4MloOG8b1Ung80exafUK0k5u6yJmIz3Q2iXeeWZs5weV+LQOEB+CPKsYwEYaGAjw==", + "license": "MIT", + "optional": true, + "engines": { + "node": ">= 10" + }, + "optionalDependencies": { + "@napi-rs/canvas-android-arm64": "0.1.67", + "@napi-rs/canvas-darwin-arm64": "0.1.67", + "@napi-rs/canvas-darwin-x64": "0.1.67", + 
"@napi-rs/canvas-linux-arm-gnueabihf": "0.1.67", + "@napi-rs/canvas-linux-arm64-gnu": "0.1.67", + "@napi-rs/canvas-linux-arm64-musl": "0.1.67", + "@napi-rs/canvas-linux-riscv64-gnu": "0.1.67", + "@napi-rs/canvas-linux-x64-gnu": "0.1.67", + "@napi-rs/canvas-linux-x64-musl": "0.1.67", + "@napi-rs/canvas-win32-x64-msvc": "0.1.67" + } + }, + "node_modules/@napi-rs/canvas-android-arm64": { + "version": "0.1.67", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-android-arm64/-/canvas-android-arm64-0.1.67.tgz", + "integrity": "sha512-W+3DFG5h0WU8Vqqb3W5fNmm5/TPH5ECZRinQDK4CAKFSUkc4iZcDwrmyFG9sB4KdHazf1mFVHCpEeVMO6Mk6Zg==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "android" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-darwin-arm64": { + "version": "0.1.67", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-arm64/-/canvas-darwin-arm64-0.1.67.tgz", + "integrity": "sha512-xzrv7QboI47yhIHR5P5u/9KGswokuOKLiKSukr1Ku03RRJxP6lGuVtrAZAgdRg7F9FsuF2REf2yK53YVb6pMlA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-darwin-x64": { + "version": "0.1.67", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-darwin-x64/-/canvas-darwin-x64-0.1.67.tgz", + "integrity": "sha512-SNk9lYBr84N0gW8MZ2IrjygFtbFBILr3SEqMdHzHHuph20SQmssFvJGPZwSSCMEyKAvyqhogbmlew0te5Z4w9Q==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "darwin" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-arm-gnueabihf": { + "version": "0.1.67", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm-gnueabihf/-/canvas-linux-arm-gnueabihf-0.1.67.tgz", + "integrity": "sha512-qmBlSvUpl567bzH8tNXi82u5FrL4d0qINqd6K9O7GWGGGFmKMJdrgi2/SW3wwCTxqHBasIDdVWc4KSJfwyaoDQ==", + "cpu": [ + "arm" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-arm64-gnu": { + "version": "0.1.67", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-gnu/-/canvas-linux-arm64-gnu-0.1.67.tgz", + "integrity": "sha512-k3nAPQefkMeFuJ65Rqdnx92KX1JXQhEKjjWeKsCJB+7sIBgQUWtHo9c3etfVLv5pkWJJDFi/Zc2soNkH3E8dRA==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-arm64-musl": { + "version": "0.1.67", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-arm64-musl/-/canvas-linux-arm64-musl-0.1.67.tgz", + "integrity": "sha512-lZwHWR1cCP408l86n3Qbs3X1oFeAYMjJIQvQl1VMZh6wo5PfI+jaZSKBUOd8x44TnVllX9yhLY9unNRztk/sUQ==", + "cpu": [ + "arm64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-riscv64-gnu": { + "version": "0.1.67", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-riscv64-gnu/-/canvas-linux-riscv64-gnu-0.1.67.tgz", + "integrity": "sha512-PdBC9p6bLHA1W3OdA0vTHj701SB/kioGQ1uCFBRMs5KBCaMLb/H4aNi8uaIUIEvBWnxeAjoNcLU7//q0FxEosw==", + "cpu": [ + "riscv64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-x64-gnu": { + "version": "0.1.67", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-gnu/-/canvas-linux-x64-gnu-0.1.67.tgz", + 
"integrity": "sha512-kJJX6eWzjipL/LdKOWCJctc88e5yzuXri8+s0V/lN06OwuLGW62TWS3lvi8qlUrGMOfRGabSWWlB4omhASSB8w==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-linux-x64-musl": { + "version": "0.1.67", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-linux-x64-musl/-/canvas-linux-x64-musl-0.1.67.tgz", + "integrity": "sha512-jLKiPWGeN6ZzhnaLG7ex7eexsiHJ1mdtPK1qKvETIcu45dApMXyUIHvdL6XWB5gFFtj5ScHzLUxv1vkfPZsoxA==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "linux" + ], + "engines": { + "node": ">= 10" + } + }, + "node_modules/@napi-rs/canvas-win32-x64-msvc": { + "version": "0.1.67", + "resolved": "https://registry.npmjs.org/@napi-rs/canvas-win32-x64-msvc/-/canvas-win32-x64-msvc-0.1.67.tgz", + "integrity": "sha512-K/JmkOFbc4iRZYUqJhj0jwqfHA/wNQEmTiGNsgZ6d59yF/IBNp5T0D5eg3B8ghjI8GxDYCiSJ6DNX8mC3Oh2EQ==", + "cpu": [ + "x64" + ], + "license": "MIT", + "optional": true, + "os": [ + "win32" + ], + "engines": { + "node": ">= 10" + } + }, "node_modules/@nodelib/fs.scandir": { "version": "2.1.5", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", @@ -8256,6 +8439,18 @@ "url": "https://github.com/sponsors/isaacs" } }, + "node_modules/pdfjs-dist": { + "version": "4.10.38", + "resolved": "https://registry.npmjs.org/pdfjs-dist/-/pdfjs-dist-4.10.38.tgz", + "integrity": "sha512-/Y3fcFrXEAsMjJXeL9J8+ZG9U01LbuWaYypvDW2ycW1jL269L3js3DVBjDJ0Up9Np1uqDXsDrRihHANhZOlwdQ==", + "license": "Apache-2.0", + "engines": { + "node": ">=20" + }, + "optionalDependencies": { + "@napi-rs/canvas": "^0.1.65" + } + }, "node_modules/picocolors": { "version": "1.1.1", "resolved": "https://registry.npmjs.org/picocolors/-/picocolors-1.1.1.tgz", diff --git a/app/frontend/package.json b/app/frontend/package.json index 8b4c2a36..cef1c51c 100644 --- a/app/frontend/package.json +++ b/app/frontend/package.json @@ -44,6 +44,7 @@ "html-react-parser": "^5.1.18", "lucide-react": "^0.460.0", "mini-svg-data-uri": "^1.4.4", + "pdfjs-dist": "^4.10.38", "re-resizable": "^6.10.3", "react": "^18.3.1", "react-code-blocks": "^0.1.6", diff --git a/app/frontend/src/components/chatui/ChatComponent.tsx b/app/frontend/src/components/chatui/ChatComponent.tsx index 373e2b5d..93bb4330 100644 --- a/app/frontend/src/components/chatui/ChatComponent.tsx +++ b/app/frontend/src/components/chatui/ChatComponent.tsx @@ -23,6 +23,7 @@ import type { import { runInference } from "./runInference"; import { v4 as uuidv4 } from "uuid"; import { usePersistentState } from "./usePersistentState"; +import { threadId } from "worker_threads"; export default function ChatComponent() { const [files, setFiles] = useState([]); @@ -50,6 +51,7 @@ export default function ChatComponent() { >(null); const [isListening, setIsListening] = useState(false); const [isHistoryPanelOpen, setIsHistoryPanelOpen] = useState(true); + const [isAgentSelected, setIsAgentSelected] = useState(false); useEffect(() => { if (location.state) { @@ -69,6 +71,24 @@ export default function ChatComponent() { loadModels(); }, [location.state]); + useEffect(() => { + const currentThread = chatThreads[currentThreadIndex]; + if (currentThread && currentThread.length > 0) { + const messagesWithRag = currentThread + .filter((msg) => msg.sender === "user" && msg.ragDatasource) + .reverse(); + + if (messagesWithRag.length > 0) { + const mostRecentRag = messagesWithRag[0].ragDatasource; + setRagDatasource(mostRecentRag); + } else { + 
setRagDatasource(undefined); + } + } else { + setRagDatasource(undefined); + } + }, [currentThreadIndex, chatThreads]); + useEffect(() => { const handleResize = () => { if (window.innerWidth < 768) { @@ -111,11 +131,13 @@ export default function ChatComponent() { : msg ); } else { + // Store ragDatasource in the user message const userMessage: ChatMessage = { id: uuidv4(), sender: "user", text: textInput, files: files, + ragDatasource: ragDatasource, // Store the RAG datasource with the message }; updatedChatHistory = [ ...(chatThreads[currentThreadIndex] || []), @@ -164,7 +186,9 @@ export default function ChatComponent() { return newThreads; }); }, - setIsStreaming + setIsStreaming, + isAgentSelected, + currentThreadIndex ); setTextInput(""); @@ -202,6 +226,9 @@ export default function ChatComponent() { ); if (!userMessage) return; + // Get the RAG datasource from the user message if available + const messageRagDatasource = userMessage.ragDatasource || ragDatasource; + setReRenderingMessageId(messageId); setIsStreaming(true); @@ -213,7 +240,7 @@ export default function ChatComponent() { await runInference( inferenceRequest, - ragDatasource, + messageRagDatasource, chatThreads[currentThreadIndex] || [], (newHistory) => { setChatThreads((prevThreads) => { @@ -241,7 +268,9 @@ export default function ChatComponent() { return newThreads; }); }, - setIsStreaming + setIsStreaming, + isAgentSelected, + currentThreadIndex ); setReRenderingMessageId(null); @@ -338,7 +367,10 @@ export default function ChatComponent() { ({ + id: model.containerID || "", + name: model.modelName || "", + }))} setModelID={setModelID} setModelName={setModelName} ragDataSources={ragDataSources} @@ -346,6 +378,8 @@ export default function ChatComponent() { setRagDatasource={setRagDatasource} isHistoryPanelOpen={isHistoryPanelOpen} setIsHistoryPanelOpen={setIsHistoryPanelOpen} + isAgentSelected={isAgentSelected} // Pass the state down + setIsAgentSelected={setIsAgentSelected} // Pass the setter down /> void; onContinue: (messageId: string) => void; reRenderingMessageId: string | null; + ragDatasource: + | { + id: string; + name: string; + metadata?: { + created_at?: string; + embedding_func_name?: string; + last_uploaded_document?: string; + }; + } + | undefined; } -const isImageFile = ( - file: FileData -): file is FileData & { type: "image_url" } => file.type === "image_url"; +const RagPill: React.FC<{ + ragDatasource: { + id: string; + name: string; + metadata?: { + created_at?: string; + embedding_func_name?: string; + last_uploaded_document?: string; + }; + }; +}> = ({ ragDatasource }) => ( + + + {ragDatasource.name} + {ragDatasource.metadata?.last_uploaded_document && ( + + Β· {ragDatasource.metadata.last_uploaded_document} + + )} + +); + +interface FileViewerDialogProps { + file: { url: string; name: string; isImage: boolean } | null; + onClose: () => void; +} + +const FileViewerDialog: React.FC = ({ + file, + onClose, +}) => { + if (!file) return null; + + return ( + + + + + + + {file.name} + + + + + + + + + {file.isImage ? 
( + + ) : ( + + + Preview not available + + Download File + + + )} + + + + ); +}; const ChatHistory: React.FC = ({ chatHistory = [], @@ -34,14 +117,19 @@ const ChatHistory: React.FC = ({ onReRender, onContinue, reRenderingMessageId, + ragDatasource, }) => { + console.log("ChatHistory component rendered", ragDatasource); const viewportRef = useRef(null); const [isScrollButtonVisible, setIsScrollButtonVisible] = useState(false); const lastMessageRef = useRef(null); - const [enlargedImage, setEnlargedImage] = useState(null); - const [minimizedImages, setMinimizedImages] = useState>( - new Set() - ); + + const [minimizedFiles, setMinimizedFiles] = useState>(new Set()); + const [selectedFile, setSelectedFile] = useState<{ + url: string; + name: string; + isImage: boolean; + } | null>(null); const scrollToBottom = useCallback(() => { if (viewportRef.current) { @@ -82,18 +170,31 @@ const ChatHistory: React.FC = ({ } }, [isStreaming, scrollToBottom]); - const toggleMinimizeImage = useCallback((imageUrl: string) => { - setMinimizedImages((prev) => { + const toggleMinimizeFile = useCallback((fileId: string) => { + setMinimizedFiles((prev) => { const newSet = new Set(prev); - if (newSet.has(imageUrl)) { - newSet.delete(imageUrl); + if (newSet.has(fileId)) { + newSet.delete(fileId); } else { - newSet.add(imageUrl); + newSet.add(fileId); } return newSet; }); }, []); + const handleFileClick = useCallback((fileUrl: string, fileName: string) => { + const imageExtensions = ["jpg", "jpeg", "png", "gif", "webp", "svg"]; + const extension = fileName.split(".").pop()?.toLowerCase() || ""; + const isImage = + imageExtensions.includes(extension) || fileUrl.startsWith("data:image/"); + + setSelectedFile({ + url: fileUrl, + name: fileName, + isImage, + }); + }, []); + return ( {chatHistory.length === 0 ? ( @@ -186,71 +287,6 @@ const ChatHistory: React.FC = ({ )} - {message.text && - message.files && - message.files.length > 0 && ( - - )} - {message.files && message.files.length > 0 && ( - - - Attached Images: - - - {message.files.map( - (file, index) => - isImageFile(file) && ( - - {!minimizedImages.has( - file.image_url?.url || "" - ) ? ( - - setEnlargedImage( - file.image_url?.url || - "/placeholder.svg" - ) - } - /> - ) : ( - - )} - - - {file.name} - - - toggleMinimizeImage( - file.image_url?.url || "" - ) - } - className="ml-2 text-gray-300 hover:text-white flex-shrink-0" - > - {minimizedImages.has( - file.image_url?.url || "" - ) ? 
( - - ) : ( - - )} - - - - ) - )} - - - )} )} @@ -262,6 +298,18 @@ const ChatHistory: React.FC = ({ isReRendering={reRenderingMessageId === message.id} isStreaming={isStreaming} inferenceStats={message.inferenceStats} + messageContent={message.text} + /> + )} + {message.ragDatasource && ( + + )} + {message.files && message.files.length > 0 && ( + )} @@ -287,26 +335,11 @@ const ChatHistory: React.FC = ({ )} - setEnlargedImage(null)} - > - - - - - - - - - - - - + + setSelectedFile(null)} + /> ); }; diff --git a/app/frontend/src/components/chatui/FileDisplay.tsx b/app/frontend/src/components/chatui/FileDisplay.tsx new file mode 100644 index 00000000..da61df61 --- /dev/null +++ b/app/frontend/src/components/chatui/FileDisplay.tsx @@ -0,0 +1,244 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Β© 2025 Tenstorrent AI ULC + +import type React from "react"; +import type { FileData } from "./types"; +import { + FileText, + Image, + File, + Maximize2, + Minimize2, + FileCode2, + FileIcon as FilePdf, + FileJson, + FileType2, + FileSpreadsheet, + FileArchive, + FileVideo, + FileAudio, + FileImage, +} from "lucide-react"; +import { + Tooltip, + TooltipContent, + TooltipProvider, + TooltipTrigger, +} from "../ui/tooltip"; + +interface FileDisplayProps { + files: FileData[]; + minimizedFiles: Set; + toggleMinimizeFile: (fileId: string) => void; + onFileClick: (fileUrl: string, fileName: string) => void; +} + +const isImageFile = ( + file: FileData +): file is FileData & { type: "image_url" } => file.type === "image_url"; + +const getFileExtension = (filename: string): string => { + const parts = filename.split("."); + return parts.length > 1 ? parts[parts.length - 1].toLowerCase() : ""; +}; + +const getFileIcon = (file: FileData) => { + if (isImageFile(file)) + return ; + + const extension = getFileExtension(file.name); + switch (extension) { + // Documents + case "pdf": + return ; + case "doc": + case "docx": + return ; + case "xls": + case "xlsx": + case "csv": + return ; + + // Code files + case "js": + case "jsx": + case "ts": + case "tsx": + return ; + case "py": + return ; + case "java": + return ; + case "cpp": + case "c": + case "h": + return ; + case "rs": + return ; + + // Data files + case "json": + return ; + case "xml": + return ; + case "yaml": + case "yml": + return ; + + // Archives + case "zip": + case "rar": + case "7z": + case "tar": + case "gz": + return ; + + // Media + case "mp4": + case "mov": + case "avi": + return ; + case "mp3": + case "wav": + case "ogg": + return ; + + // Text + case "txt": + case "md": + case "log": + return ; + + default: + return ; + } +}; + +const FileDisplay: React.FC = ({ + files, + minimizedFiles, + toggleMinimizeFile, + onFileClick, +}) => { + if (!files || files.length === 0) return null; + + console.log("Files:", files); + const imageFiles = files.filter(isImageFile); + const otherFiles = files.filter((file) => !isImageFile(file)); + const allFiles = [...imageFiles, ...otherFiles]; + + return ( + + + + {allFiles.map((file, index) => { + const fileId = isImageFile(file) + ? file.image_url?.url || file.id || index.toString() + : file.url || file.id || index.toString(); + const isMinimized = isImageFile(file) + ? minimizedFiles.has(fileId) + : true; // Files always start minimized + + if (isImageFile(file)) { + return ( + + {!isMinimized ? 
( + + + onFileClick(file.image_url?.url || "", file.name) + } + /> + + + toggleMinimizeFile(fileId)} + className="absolute top-2 right-2 p-1 rounded-full bg-black/40 hover:bg-black/60 text-white opacity-0 group-hover:opacity-100 transition-opacity duration-200" + > + + + + + Minimize image + + + + ) : ( + + toggleMinimizeFile(fileId)} + > + + + + + + + + + Maximize image + + + + + {file.name} + + + )} + + ); + } + + return ( + + onFileClick(file.url || "", file.name)} + > + {getFileIcon(file)} + + + + {file.size && ( + + {formatFileSize(file.size)} + + )} + + + + {file.size && ( + + {formatFileSize(file.size)} + + )} + + + + + {file.name} + + + ); + })} + + + + ); +}; + +// Utility function to format file size +const formatFileSize = (bytes: number): string => { + if (bytes < 1024) return bytes + " B"; + else if (bytes < 1048576) return (bytes / 1024).toFixed(1) + " KB"; + else return (bytes / 1048576).toFixed(1) + " MB"; +}; + +export default FileDisplay; diff --git a/app/frontend/src/components/chatui/Header.tsx b/app/frontend/src/components/chatui/Header.tsx index dcd01e8c..f4bc413d 100644 --- a/app/frontend/src/components/chatui/Header.tsx +++ b/app/frontend/src/components/chatui/Header.tsx @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Β© 2024 Tenstorrent AI ULC import React from "react"; +import { useState } from "react"; import { Breadcrumb, BreadcrumbEllipsis, @@ -36,20 +37,25 @@ interface HeaderProps { modelName: string | null; modelsDeployed: { id: string; name: string }[]; setModelID: (id: string) => void; - setModelName: (name: string) => void; + setModelName: (name: string | null) => void; ragDataSources: RagDataSource[]; ragDatasource: RagDataSource | undefined; setRagDatasource: (datasource: RagDataSource | undefined) => void; isHistoryPanelOpen: boolean; setIsHistoryPanelOpen: (isOpen: boolean) => void; -} + isAgentSelected: boolean; + setIsAgentSelected: (value: boolean) => void; +} interface RagDataSource { id: string; name: string; - metadata: Record; + metadata?: { + created_at?: string; + embedding_func_name?: string; + last_uploaded_document?: string; + }; } - const ModelSelector = React.forwardRef< HTMLButtonElement, { @@ -101,6 +107,21 @@ const ForwardedSelect = React.forwardRef< ForwardedSelect.displayName = "ForwardedSelect"; +const ForwardedAISelect = React.forwardRef< + HTMLButtonElement, + React.ComponentPropsWithoutRef +>((props, ref) => ( + + + + + {props.children} + +)); + export default function Header({ modelName, modelsDeployed, @@ -111,7 +132,21 @@ export default function Header({ setRagDatasource, isHistoryPanelOpen, setIsHistoryPanelOpen, + isAgentSelected, + setIsAgentSelected }: HeaderProps) { + const [selectedAIAgent, setSelectedAIAgent] = useState(null); + + // Handle the AI agent selection change + const handleAgentSelection = (value: string) => { + if (value === "remove") { + setSelectedAIAgent(""); // Clear the selected agent + setIsAgentSelected(false); // Set to false if agent is removed + } else { + setSelectedAIAgent(value); // Set the selected agent + setIsAgentSelected(true); // Set to true if an agent is selected + } + }; return ( @@ -181,6 +216,7 @@ export default function Header({ + @@ -192,7 +228,7 @@ export default function Header({ setRagDatasource(undefined); } else { const dataSource = ragDataSources.find( - (rds) => rds.name === v, + (rds) => rds.name === v ); if (dataSource) { setRagDatasource(dataSource); @@ -233,7 +269,61 @@ export default function Header({ - + + + + + + + { + // if (v === 
"remove") { + // handleAgentSelection(""); // Clear selection if "remove" is chosen + // } else { + // handleAgentSelection(v); // Set selected AI Agent + // } + // }} + > + + + Search Agent + + + {/* Add more AI agents as SelectItems here if needed */} + {/* + Another AI Agent + */} + + {selectedAIAgent && ( + + + + Remove AI Agent + + + )} + + + + + {selectedAIAgent ? "Change or remove AI agent" : "Select AI Agent"} + + + + + + ); } diff --git a/app/frontend/src/components/chatui/InputArea.tsx b/app/frontend/src/components/chatui/InputArea.tsx index f5f759ee..9bf81c70 100644 --- a/app/frontend/src/components/chatui/InputArea.tsx +++ b/app/frontend/src/components/chatui/InputArea.tsx @@ -6,10 +6,20 @@ import { Button } from "../ui/button"; import { Paperclip, Send, X, File } from "lucide-react"; import { VoiceInput } from "./VoiceInput"; import { FileUpload } from "../ui/file-upload"; -import { isImageFile, validateFile, encodeFile } from "./fileUtils"; +import { isImageFile, validateFile, encodeFile, isTextFile } from "./fileUtils"; import { cn } from "../../lib/utils"; -import type { FileData, InputAreaProps } from "./types"; +import type { InputAreaProps } from "./types"; import { customToast } from "../CustomToaster"; +import { + AlertDialog, + AlertDialogAction, + AlertDialogCancel, + AlertDialogContent, + AlertDialogDescription, + AlertDialogFooter, + AlertDialogHeader, + AlertDialogTitle, +} from "../ui/alert-dialog"; import { Tooltip, TooltipContent, @@ -32,6 +42,9 @@ export default function InputArea({ const [isDragging, setIsDragging] = useState(false); const [showProgressBar, setShowProgressBar] = useState(false); const [showErrorIndicator, setShowErrorIndicator] = useState(false); + const [showReplaceDialog, setShowReplaceDialog] = useState(false); + const [pendingImageFile, setPendingImageFile] = useState(null); + const [isFocused, setIsFocused] = useState(false); useEffect(() => { if (textareaRef.current && !isStreaming) { @@ -43,7 +56,7 @@ export default function InputArea({ if (textareaRef.current) { adjustTextareaHeight(); } - }, [textareaRef]); + }, []); const adjustTextareaHeight = () => { if (textareaRef.current) { @@ -54,6 +67,7 @@ export default function InputArea({ const handleTextAreaInput = (e: React.ChangeEvent) => { setTextInput(e.target.value); + adjustTextareaHeight(); }; const handleKeyPress = (e: React.KeyboardEvent) => { @@ -67,56 +81,129 @@ export default function InputArea({ setTextInput((prevText) => prevText + (prevText ? 
" " : "") + transcript); }; + const processFile = useCallback(async (file: File) => { + try { + setShowProgressBar(true); + + const validation = validateFile(file); + if (!validation.valid) { + throw new Error(validation.error); + } + + const base64 = await encodeFile(file, true); + if (isImageFile(file)) { + return { + type: "image_url" as const, + image_url: { url: `data:${file.type};base64,${base64}` }, + name: file.name, + }; + } + + return { + type: "text" as const, + text: base64, + name: file.name, + }; + } catch (error) { + console.error("File processing error:", error); + throw error; + } + }, []); + const handleFileUpload = useCallback( async (uploadedFiles: File[]) => { try { - setIsDragging(true); + setIsDragging(false); setShowProgressBar(true); - const encodedFiles = await Promise.all( - uploadedFiles.map(async (file) => { - const validation = validateFile(file); - if (!validation.valid) { - throw new Error(validation.error); - } + const imageFiles = uploadedFiles.filter(isImageFile); + const textFiles = uploadedFiles.filter(isTextFile); + + // Handle image files + if (imageFiles.length > 0) { + const existingImages = files.filter((f) => f.type === "image_url"); + if (existingImages.length > 0) { + setPendingImageFile(imageFiles[0]); + setShowReplaceDialog(true); - const base64 = await encodeFile(file, true); - if (isImageFile(file)) { - return { - type: "image_url" as const, - image_url: { url: `data:${file.type};base64,${base64}` }, - name: file.name, - }; + // Process text files if any + if (textFiles.length > 0) { + const encodedTextFiles = await Promise.all( + textFiles.map(processFile) + ); + setFiles((prevFiles) => [...prevFiles, ...encodedTextFiles]); + customToast.success( + `Successfully uploaded ${textFiles.length} text file(s)!` + ); } - console.log("Encoded file:", base64); - return { - type: "text" as const, - text: base64, - name: file.name, - }; - }) - ); - setFiles( - (prevFiles: FileData[]) => - [...prevFiles, ...encodedFiles] as FileData[] - ); - customToast.success( - `Successfully uploaded ${uploadedFiles.length} file(s)!` - ); + return; + } + // No existing image, process the first image file + const encodedImage = await processFile(imageFiles[0]); + const encodedTextFiles = await Promise.all( + textFiles.map(processFile) + ); + + setFiles((prevFiles) => [ + ...prevFiles, + encodedImage, + ...encodedTextFiles, + ]); + customToast.success( + `Successfully uploaded ${imageFiles.length > 1 ? "1 image (extras ignored)" : "1 image"}${ + textFiles.length > 0 + ? ` and ${textFiles.length} text file(s)` + : "" + }!` + ); + } else if (textFiles.length > 0) { + // Only text files + const encodedFiles = await Promise.all(textFiles.map(processFile)); + setFiles((prevFiles) => [...prevFiles, ...encodedFiles]); + customToast.success( + `Successfully uploaded ${textFiles.length} text file(s)!` + ); + } } catch (error) { console.error("File upload error:", error); - customToast.error("Failed to upload file(s). Please try again."); + customToast.error( + error instanceof Error + ? error.message + : "Failed to upload file(s). Please try again." 
+ ); setShowErrorIndicator(true); setTimeout(() => setShowErrorIndicator(false), 3000); } finally { - setIsDragging(false); + setShowProgressBar(false); setIsFileUploadOpen(false); - setTimeout(() => setShowProgressBar(false), 1000); } }, - [setFiles] + [files, processFile, setFiles] ); + const handleReplaceConfirm = async () => { + if (pendingImageFile) { + try { + const encodedImage = await processFile(pendingImageFile); + setFiles((prevFiles) => [ + ...prevFiles.filter((f) => f.type !== "image_url"), + encodedImage, + ]); + customToast.success("Image replaced successfully!"); + } catch (error) { + console.error("Error replacing image:", error); + customToast.error("Failed to replace image. Please try again."); + } + setPendingImageFile(null); + } + setShowReplaceDialog(false); + }; + + const handleReplaceCancel = () => { + setPendingImageFile(null); + setShowReplaceDialog(false); + }; + const removeFile = (index: number) => { setFiles((prevFiles) => prevFiles.filter((_, i) => i !== index)); customToast.success("File removed successfully!"); @@ -139,150 +226,197 @@ export default function InputArea({ if (files) handleFileUpload(Array.from(files)); }; - return ( - - {isDragging && ( - - - - Drop files to upload - Release to add files - - - )} + useEffect(() => { + adjustTextareaHeight(); + window.addEventListener("resize", adjustTextareaHeight); + return () => window.removeEventListener("resize", adjustTextareaHeight); + }, [textInput]); - - {/* File preview section */} - {files.length > 0 && ( - <> - - {files.map((file, index) => ( - - - {file.type === "image_url" ? ( - - ) : ( - - )} - - - {file.name} - - removeFile(index)} - aria-label="Remove file" - > - - + return ( + <> + + + + Replace Existing Image? + + You can only have one image at a time. Do you want to replace the + existing image with the new one? + {pendingImageFile && ( + + New image: {pendingImageFile.name} - ))} + )} + + + + + Cancel + + + Replace + + + + + + + {isDragging && ( + + + + + Drop files to upload + + + Limited to one image, multiple text files allowed + - - > + )} - {/* Main text input area */} - - - {/* Control buttons */} - - - - - - setIsFileUploadOpen((prev) => !prev)} - aria-label="Attach files" + + {/* File preview section */} + {files.length > 0 && ( + <> + + {files.map((file, index) => ( + - - - - - Attach files or drag and drop - - - - - - - - {" "} - {/* Wrap VoiceInput in a div for tooltip positioning */} - + + {file.type === "image_url" ? 
( + + ) : ( + + )} + + + {file.name} + + removeFile(index)} + aria-label="Remove file" + > + + - - - Voice input - - - + ))} + + + > + )} + + {/* Main text input area */} + setIsFocused(true)} + onBlur={() => setIsFocused(false)} + /> + + {/* Control buttons */} + + + + + + setIsFileUploadOpen((prev) => !prev)} + aria-label="Attach files" + > + + + + + Attach files (1 image max) + + + + + + + + + + + + Voice input + + + + + handleInference(textInput, files)} + disabled={ + isStreaming || (!textInput.trim() && files.length === 0) + } + className="bg-[#7C68FA] hover:bg-[#7C68FA]/80 text-white px-4 py-2 rounded-lg flex items-center gap-2 transition-colors duration-300" + aria-label="Send message" + > + Generate + + - handleInference(textInput, files)} - disabled={isStreaming || (!textInput.trim() && files.length === 0)} - className="bg-[#7C68FA] hover:bg-[#7C68FA]/80 text-white px-4 py-2 rounded-lg flex items-center gap-2 transition-colors duration-300" - aria-label="Send message" - > - Generate - - + + {/* Preserved streaming indicator */} + {isStreaming && ( + + + + )} - {/* Preserved streaming indicator */} - {isStreaming && ( - - - + {showProgressBar && ( + + )} + {showErrorIndicator && ( + + )} + {isFileUploadOpen && ( + setIsFileUploadOpen(false)} + /> )} - - {showProgressBar && ( - - )} - {showErrorIndicator && ( - - )} - {isFileUploadOpen && ( - setIsFileUploadOpen(false)} - /> - )} - + > ); } diff --git a/app/frontend/src/components/chatui/MessageActions.tsx b/app/frontend/src/components/chatui/MessageActions.tsx index ef8191fe..14767043 100644 --- a/app/frontend/src/components/chatui/MessageActions.tsx +++ b/app/frontend/src/components/chatui/MessageActions.tsx @@ -1,11 +1,13 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Β© 2024 Tenstorrent AI ULC -import React from "react"; + +import type React from "react"; +import { useState, useEffect } from "react"; import { Button } from "../ui/button"; import { Clipboard, ThumbsUp, ThumbsDown } from "lucide-react"; import CustomToaster, { customToast } from "../CustomToaster"; import InferenceStats from "./InferenceStats"; -import { InferenceStats as InferenceStatsType } from "./types"; +import type { InferenceStats as InferenceStatsType } from "./types"; interface MessageActionsProps { messageId: string; @@ -14,6 +16,7 @@ interface MessageActionsProps { isReRendering: boolean; isStreaming: boolean; inferenceStats?: InferenceStatsType; + messageContent?: string; } const MessageActions: React.FC = ({ @@ -23,10 +26,29 @@ const MessageActions: React.FC = ({ isReRendering, isStreaming, inferenceStats, + messageContent, }) => { - const handleCopy = () => { - // Implement copy logic here - customToast.success("Message copied to clipboard"); + const [completeMessage, setCompleteMessage] = useState( + messageContent || "" + ); + + // Update the complete message when streaming finishes + useEffect(() => { + if (!isStreaming && messageContent) { + setCompleteMessage(messageContent); + } + }, [isStreaming, messageContent]); + + const handleCopy = async () => { + try { + if (completeMessage) { + await navigator.clipboard.writeText(completeMessage); + customToast.success("Message copied to clipboard"); + } + } catch (err) { + console.error("Failed to copy text: ", err); + customToast.error("Failed to copy message"); + } }; const handleThumbsUp = () => { @@ -49,6 +71,7 @@ const MessageActions: React.FC = ({ size="icon" onClick={handleCopy} className="h-8 w-8 p-0" + disabled={isStreaming} > Copy message diff --git 
a/app/frontend/src/components/chatui/fileUtils.tsx b/app/frontend/src/components/chatui/fileUtils.tsx index affeb1b9..748de1ab 100644 --- a/app/frontend/src/components/chatui/fileUtils.tsx +++ b/app/frontend/src/components/chatui/fileUtils.tsx @@ -1,97 +1,218 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Β© 2025 Tenstorrent AI ULC -import { customToast } from "../CustomToaster" +import { customToast } from "../CustomToaster"; -export const encodeFile = (file: File, base64Encoded = true): Promise => { +// File extensions mapping for code files +const codeFileExtensions = new Set([ + // Web + ".js", + ".jsx", + ".ts", + ".tsx", + ".html", + ".css", + ".scss", + ".sass", + // Backend + ".py", + ".java", + ".cpp", + ".c", + ".hpp", + ".h", + ".cs", + ".php", + ".rb", + // Other languages + ".go", + ".rs", + ".swift", + ".kt", + ".scala", + ".lua", + ".r", + ".pl", + ".sh", + // Config/Data + ".json", + ".yaml", + ".yml", + ".toml", + ".xml", + // Documentation + ".md", + ".mdx", + ".txt", + ".csv", +]); + +const supportedMimeTypes = { + images: ["image/png", "image/jpeg", "image/webp"], + textFiles: ["text/plain", "text/markdown", "text/x-markdown", "text/mdx"], + codeFiles: [ + "text/javascript", + "application/javascript", + "text/typescript", + "text/x-python", + "text/x-java", + "text/x-c++src", + "text/x-typescript", + ], +}; + +const getFileExtension = (filename: string): string => { + return filename + .slice(((filename.lastIndexOf(".") - 1) >>> 0) + 2) + .toLowerCase(); +}; + +export const encodeFile = ( + file: File, + base64Encoded = true +): Promise => { return new Promise((resolve, reject) => { - console.log("Starting file encoding process...") - console.log(`File name: ${file.name}, Size: ${file.size} bytes, Type: ${file.type}`) - console.log(`Encoding mode: ${base64Encoded ? "Base64" : "Raw binary"}`) + console.log("Starting file encoding process..."); + console.log( + `File name: ${file.name}, Size: ${file.size} bytes, Type: ${file.type}` + ); + console.log(`Encoding mode: ${base64Encoded ? "Base64" : "Raw binary"}`); if (!file) { - console.error("No file provided") - reject(new Error("No file provided")) - return + console.error("No file provided"); + reject(new Error("No file provided")); + return; } - const reader = new FileReader() + const reader = new FileReader(); reader.onload = () => { - console.log("File read successfully") - const result = reader.result as string + console.log("File read successfully"); + const result = reader.result as string; + const extension = getFileExtension(file.name); + // Handle text and code files + if ( + supportedMimeTypes.textFiles.includes(file.type) || + supportedMimeTypes.codeFiles.includes(file.type) || + codeFileExtensions.has(`.${extension}`) + ) { + console.log(`Text/code content length: ${result.length} chars`); + resolve(result); + customToast.success(`File ${file.name} processed successfully! 
π`); + return; + } + + // Original image handling if (base64Encoded) { // Extract only the base64 data without the data URI prefix - const base64Data = result.split(",")[1] - console.log(`Base64 encoded data (first 50 chars): ${base64Data.substring(0, 50)}...`) - resolve(base64Data) - customToast.success(`File name: ${file.name}, uploaded sucessfully!π`) + const base64Data = result.split(",")[1]; + console.log( + `Base64 encoded data (first 50 chars): ${base64Data.substring(0, 50)}...` + ); + resolve(base64Data); + customToast.success(`File name: ${file.name}, uploaded sucessfully!π`); } else { // Return raw binary data - console.log(`Raw binary data length: ${result.length} bytes`) - resolve(result) + console.log(`Raw binary data length: ${result.length} bytes`); + resolve(result); } - } + }; reader.onerror = (error) => { - customToast.error(`Error uploading file: ${file.name} only supports PNG, JPEG, and WebP images.`) - console.error("File reading error:", error) - reject(new Error("Failed to read file")) - } + customToast.error( + `Error uploading file: ${file.name} only supports PNG, JPEG, and WebP images.` + ); + console.error("File reading error:", error); + reject(new Error("Failed to read file")); + }; reader.onabort = () => { - console.warn("File reading aborted") - reject(new Error("File reading aborted")) - } + console.warn("File reading aborted"); + reject(new Error("File reading aborted")); + }; try { - if (base64Encoded) { - console.log("Reading file as Data URL...") - reader.readAsDataURL(file) + const extension = getFileExtension(file.name); + if ( + supportedMimeTypes.textFiles.includes(file.type) || + supportedMimeTypes.codeFiles.includes(file.type) || + codeFileExtensions.has(`.${extension}`) + ) { + console.log("Reading file as text..."); + reader.readAsText(file); + } else if (base64Encoded) { + console.log("Reading file as Data URL..."); + reader.readAsDataURL(file); } else { - console.log("Reading file as ArrayBuffer...") - reader.readAsArrayBuffer(file) + console.log("Reading file as ArrayBuffer..."); + reader.readAsArrayBuffer(file); } } catch (error) { - console.error("Error during file reading:", error) - reject(new Error(`Failed to read file: ${error instanceof Error ? error.message : "Unknown error"}`)) + console.error("Error during file reading:", error); + reject( + new Error( + `Failed to read file: ${error instanceof Error ? 
error.message : "Unknown error"}` + ) + ); } - }) -} + }); +}; export const isImageFile = (file: File): boolean => { - const supportedMimeTypes = ["image/png", "image/jpeg", "image/webp"] - const result = supportedMimeTypes.includes(file.type) - console.log(`File type check: ${file.type} - Is supported image: ${result}`) - return result -} + const result = supportedMimeTypes.images.includes(file.type); + console.log(`File type check: ${file.type} - Is supported image: ${result}`); + return result; +}; -export const validateFile = (file: File, maxSizeMB = 10): { valid: boolean; error?: string } => { - console.log(`Validating file: ${file.name}`) +export const isTextFile = (file: File): boolean => { + const extension = getFileExtension(file.name); + const result = + supportedMimeTypes.textFiles.includes(file.type) || + supportedMimeTypes.codeFiles.includes(file.type) || + codeFileExtensions.has(`.${extension}`); + console.log( + `File type check: ${file.type} - Is supported text/code file: ${result}` + ); + return result; +}; + +export const validateFile = ( + file: File, + maxSizeMB = 10 +): { valid: boolean; error?: string } => { + console.log(`Validating file: ${file.name}`); if (!file) { - console.error("No file provided for validation") - return { valid: false, error: "No file provided" } + console.error("No file provided for validation"); + return { valid: false, error: "No file provided" }; } - const maxSizeBytes = maxSizeMB * 1024 * 1024 + const maxSizeBytes = maxSizeMB * 1024 * 1024; if (file.size > maxSizeBytes) { - console.warn(`File size (${file.size} bytes) exceeds limit of ${maxSizeBytes} bytes`) + console.warn( + `File size (${file.size} bytes) exceeds limit of ${maxSizeBytes} bytes` + ); return { valid: false, error: `File size exceeds ${maxSizeMB}MB limit (${(file.size / 1024 / 1024).toFixed(2)}MB)`, - } + }; } - if (!isImageFile(file)) { - console.warn(`Unsupported file type: ${file.type}`) + const extension = getFileExtension(file.name); + const isSupported = + isImageFile(file) || + supportedMimeTypes.textFiles.includes(file.type) || + supportedMimeTypes.codeFiles.includes(file.type) || + codeFileExtensions.has(`.${extension}`); + + if (!isSupported) { + console.warn(`Unsupported file type: ${file.type}`); return { valid: false, - error: `Unsupported file type. Only PNG, JPEG,and WebP images are allowed.`, - } + error: `Unsupported file type. 
Only PNG, JPEG, WebP images, and text/code files are allowed.`, + }; } - console.log("File validation passed") - return { valid: true } -} - + console.log("File validation passed"); + return { valid: true }; +}; diff --git a/app/frontend/src/components/chatui/processUploadedFiles.tsx b/app/frontend/src/components/chatui/processUploadedFiles.tsx new file mode 100644 index 00000000..9a37f5e8 --- /dev/null +++ b/app/frontend/src/components/chatui/processUploadedFiles.tsx @@ -0,0 +1,42 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Β© 2025 Tenstorrent AI ULC + +import type { FileData } from "./types"; + +export function processUploadedFiles(files: FileData[]): FileData { + if (!files || files.length === 0) { + return {} as FileData; + } + + if (files.length === 1 || files[0].type !== "text") { + return files[0]; + } + const textFiles = files.filter((file) => file.type === "text" && file.text); + + if (textFiles.length === 0) { + return files[0]; + } + + const combinedText = textFiles.map((file) => file.text).join("\n\n"); + console.log(`Processed ${textFiles.length} text files into single content`); + + return { + type: "text", + text: combinedText, + name: `combined_${textFiles.length}_files.txt`, + } as FileData; +} + +// Usage example: +/* +const files = request.files +console.log("Uploaded files:", files) +const file = processUploadedFiles(files) + +// Now you can use 'file' as before +if (file.type === "text" && file.text) { + // Handle text content +} else if (file.image_url?.url || file) { + // Handle image +} +*/ diff --git a/app/frontend/src/components/chatui/runInference.ts b/app/frontend/src/components/chatui/runInference.ts index 125eb4d4..663aaf4a 100644 --- a/app/frontend/src/components/chatui/runInference.ts +++ b/app/frontend/src/components/chatui/runInference.ts @@ -11,13 +11,16 @@ import { getRagContext } from "./getRagContext"; import { generatePrompt } from "./templateRenderer"; import { v4 as uuidv4 } from "uuid"; import type React from "react"; +import { processUploadedFiles } from "./processUploadedFiles"; export const runInference = async ( request: InferenceRequest, ragDatasource: RagDataSource | undefined, chatHistory: ChatMessage[], setChatHistory: React.Dispatch>, - setIsStreaming: React.Dispatch> + setIsStreaming: React.Dispatch>, + isAgentSelected: boolean, + threadId: number ) => { try { setIsStreaming(true); @@ -34,24 +37,56 @@ export const runInference = async ( let messages; if (request.files && request.files.length > 0) { - console.log( - "Files detected, using image_url message structure", - request.files[0].image_url?.url - ); - messages = [ - { - role: "user", - content: [ - { type: "text", text: request.text || "What's in this image?" 
}, - { - type: "image_url", - image_url: { - url: request.files[0].image_url?.url || request.files[0], + const file = processUploadedFiles(request.files); + console.log("Processed file:", file); + + if (file.type === "text" && file.text) { + // Handle text file by treating its content as RAG context + console.log("Text file detected, processing as RAG context"); + const textContent = file.text; + console.log("Text content:", textContent); + + // Create a RAG context from the text file content + const fileRagContext = { + documents: [textContent], + }; + + // Merge with existing RAG context if any + if (ragContext) { + ragContext.documents = [ + ...ragContext.documents, + ...fileRagContext.documents, + ]; + } else { + ragContext = fileRagContext; + } + + // Process with RAG context + console.log("Processing with combined RAG context:", ragContext); + messages = generatePrompt( + chatHistory.map((msg) => ({ sender: msg.sender, text: msg.text })), + ragContext + ); + } else if (file.image_url?.url || file) { + console.log( + "Image file detected, using image_url message structure", + file.image_url?.url + ); + messages = [ + { + role: "user", + content: [ + { type: "text", text: request.text || "What's in this image?" }, + { + type: "image_url", + image_url: { + url: file.image_url?.url || file, + }, }, - }, - ], - }, - ]; + ], + }, + ]; + } } else if ( request.text && request.text.includes("https://") && @@ -95,8 +130,12 @@ export const runInference = async ( } console.log("Generated messages:", messages); + console.log("Thread ID: ", threadId); - const API_URL = import.meta.env.VITE_API_URL || "/models-api/inference/"; + const API_URL = isAgentSelected + ? import.meta.env.VITE_SPECIAL_API_URL || "/models-api/agent/" + : import.meta.env.VITE_API_URL || "/models-api/inference/"; + const AUTH_TOKEN = import.meta.env.VITE_AUTH_TOKEN || ""; const headers: Record = { @@ -106,16 +145,35 @@ export const runInference = async ( headers["Authorization"] = `Bearer ${AUTH_TOKEN}`; } - const requestBody = { - deploy_id: request.deploy_id, - messages: messages, - max_tokens: 512, - stream: true, - stream_options: { - include_usage: true, - }, - }; + let requestBody; + let threadIdStr = threadId.toString(); + if (!isAgentSelected) { + requestBody = { + deploy_id: request.deploy_id, + // model: "meta-llama/Llama-3.1-70B-Instruct", + messages: messages, + max_tokens: 512, + stream: true, + stream_options: { + include_usage: true, + }, + }; + } + else { + requestBody = { + deploy_id: request.deploy_id, + // model: "meta-llama/Llama-3.1-70B-Instruct", + messages: messages, + max_tokens: 512, + stream: true, + stream_options: { + include_usage: true, + }, + thread_id: threadIdStr, // Add thread_id to the request body + }; + } + console.log( "Sending request to model:", JSON.stringify(requestBody, null, 2) @@ -148,7 +206,6 @@ export const runInference = async ( let inferenceStats: InferenceStats | undefined; if (reader) { - // eslint-disable-next-line no-constant-condition while (true) { const { done, value } = await reader.read(); @@ -177,7 +234,8 @@ export const runInference = async ( try { const jsonData = JSON.parse(trimmedLine.slice(5)); - // Handle statistics separately after [DONE] + if (!isAgentSelected) { + // // Handle statistics separately after [DONE] if (jsonData.ttft && jsonData.tpot) { inferenceStats = { user_ttft_s: jsonData.ttft, @@ -187,10 +245,9 @@ export const runInference = async ( context_length: jsonData.context_length, }; console.log("Final Inference Stats received:", inferenceStats); - 
continue; // Skip processing this chunk as part of the generated text + continue; } - - // Handle the generated text + } const content = jsonData.choices[0]?.delta?.content || ""; if (content) { accumulatedText += content; @@ -215,7 +272,6 @@ export const runInference = async ( console.log("Inference stream ended."); setIsStreaming(false); - // Update chat history with inference stats after streaming is fully completed if (inferenceStats) { console.log( "Updating chat history with inference stats:", diff --git a/app/frontend/src/components/chatui/types.ts b/app/frontend/src/components/chatui/types.ts index df73343b..f7c4fa89 100644 --- a/app/frontend/src/components/chatui/types.ts +++ b/app/frontend/src/components/chatui/types.ts @@ -1,35 +1,71 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Β© 2024 Tenstorrent AI ULC +// File and Media Types +export interface ImageUrl { + url: string; + detail?: string; +} -export interface InputAreaProps { - textInput: string - setTextInput: React.Dispatch> - handleInference: (input: string, files: FileData[]) => void - isStreaming: boolean - isListening: boolean - setIsListening: (isListening: boolean) => void - files: FileData[] - setFiles: React.Dispatch> +export interface FileData { + id?: string; + name: string; + type: "text" | "image_url" | "document" | "audio" | "video"; + size?: number; + created_at?: string; + blob?: Blob; + url?: string; + mime_type?: string; + duration?: number; + thumbnail_url?: string; + + // Type-specific fields + text?: string; + image_url?: ImageUrl; + document_url?: string; + audio_url?: string; + video_url?: string; } + +// Chat and Message Types export interface ChatMessage { id: string; sender: "user" | "assistant"; text: string; files?: FileData[]; inferenceStats?: InferenceStats; -} -export interface FileData { - url: any; - type: "text" | "image_url"; - text?: string; - image_url?: { - url: string; - }; + ragDatasource?: RagDataSource; +} + +export type MessageContent = + | { + type: "text"; + text: string; + } + | { + type: "image_url"; + image_url: { + url: string; + }; + }; + +export type InferenceMessage = { + role: "user" | "assistant"; + content: MessageContent[]; +}; + +// RAG Types +export interface RagDataSource { + id: string; name: string; - blob?: Blob; + metadata?: { + created_at?: string; + embedding_func_name?: string; + last_uploaded_document?: string; + }; } +// Inference Types export interface InferenceRequest { deploy_id: string; text: string; @@ -37,40 +73,32 @@ export interface InferenceRequest { files?: FileData[]; } -export interface FileData { - type: "text" | "image_url"; - text?: string; - image_url?: { - url: string; - }; - name: string; - blob?: Blob; -} - -export interface RagDataSource { - id: string; - name: string; - metadata: Record; -} - -export interface ChatMessage { - id: string; - sender: "user" | "assistant"; - text: string; - inferenceStats?: InferenceStats; -} - -export interface Model { - id: string; - name: string; -} - export interface InferenceStats { - user_ttft_s: number; // Time to First Token in seconds - user_tpot: number; // Time Per Output Token in seconds - tokens_decoded: number; // Number of tokens decoded - tokens_prefilled: number; // Number of tokens prefilled - context_length: number; // Context length + user_ttft_s?: number; + user_tpot?: number; + tokens_decoded?: number; + tokens_prefilled?: number; + context_length?: number; + startTime?: string; + endTime?: string; + totalDuration?: number; + tokensPerSecond?: number; + 
promptTokens?: number; + completionTokens?: number; + totalTokens?: number; + total_time_ms?: number; +} + +// Component Props Types +export interface InputAreaProps { + textInput: string; + setTextInput: React.Dispatch>; + handleInference: (input: string, files: FileData[]) => void; + isStreaming: boolean; + isListening: boolean; + setIsListening: (isListening: boolean) => void; + files: FileData[]; + setFiles: React.Dispatch>; } export interface InferenceStatsProps { @@ -82,7 +110,26 @@ export interface StreamingMessageProps { isStreamFinished: boolean; } -// Voice input types +export interface HistoryPanelProps { + chatHistory: ChatMessage[][]; + onSelectThread: (index: number) => void; + onDeleteThread: (index: number) => void; + onCreateNewThread: () => void; +} + +export interface FileDisplayProps { + files: FileData[]; + minimizedFiles: Set; + toggleMinimizeFile: (fileId: string) => void; + onFileClick: (fileUrl: string, fileName: string) => void; +} + +export interface FileViewerDialogProps { + file: { url: string; name: string; isImage: boolean } | null; + onClose: () => void; +} + +// Voice Input Types export interface SpeechRecognitionAlternative { transcript: string; confidence: number; @@ -126,13 +173,19 @@ export interface VoiceInputProps { setIsListening: (isListening: boolean) => void; } -export interface HistoryPanelProps { - chatHistory: ChatMessage[][]; - onSelectThread: (index: number) => void; - onDeleteThread: (index: number) => void; - onCreateNewThread: () => void; +// Model Types +export interface Model { + id?: string; + containerID?: string; + name?: string; + modelName?: string; + modelSize?: string; + baseModel?: string; + task?: string; + status?: string; } +// Global Type Declarations declare global { interface Window { SpeechRecognition: new () => SpeechRecognition;
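**Reviewer note on the new agent path (not part of the diff).** The changes above wire a second inference route through the stack: `urls.py` adds `/agent/`, `AgentView` proxies the request to a per-model agent container, and `runInference.ts` switches the frontend to `/models-api/agent/` (adding a `thread_id` field) whenever an agent is selected in the header. `AgentView` derives the agent container address from the deployed model's `internal_url`: it takes the host port that follows `_p` in the hostname and adds 200, so a hypothetical `internal_url` of `http://llm_container_p8001:7000` would map to `http://ai_agent_container_p8001:8201/poll_requests`. The sketch below shows how a client could exercise the new endpoint; the payload fields and the `data:`-prefixed streaming format are taken from `runInference.ts`, but the helper itself, its name, and its error handling are illustrative assumptions rather than code from this PR.

```typescript
// Illustrative only: a standalone client for the new /agent/ route.
// Field names mirror the request body built in runInference.ts.

interface AgentChatMessage {
  role: "user" | "assistant";
  content: string;
}

async function streamFromAgent(
  deployId: string,
  messages: AgentChatMessage[],
  threadId: number,
  onToken: (token: string) => void,
): Promise<void> {
  const body = {
    deploy_id: deployId,
    messages,
    max_tokens: 512,
    stream: true,
    stream_options: { include_usage: true },
    // Stringified thread id, sent only on the agent path so the agent
    // container can keep per-thread state.
    thread_id: threadId.toString(),
  };

  const response = await fetch("/models-api/agent/", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(body),
  });
  if (!response.ok || !response.body) {
    throw new Error(`Agent request failed: ${response.status}`);
  }

  const reader = response.body.getReader();
  const decoder = new TextDecoder();

  // The response streams "data: {...}" lines; forward each delta's content.
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    for (const line of decoder.decode(value, { stream: true }).split("\n")) {
      const trimmed = line.trim();
      if (!trimmed.startsWith("data:") || trimmed.includes("[DONE]")) continue;
      try {
        const chunk = JSON.parse(trimmed.slice(5));
        const content = chunk.choices?.[0]?.delta?.content ?? "";
        if (content) onToken(content);
      } catch {
        // A chunk may be split across reads; a real client would buffer partial lines.
      }
    }
  }
}
```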