Skip to content

Commit

Permalink
Merge pull request #95 from Cloud-Code-AI/46-kokoro-tts-doesnt-work-o…
Browse files Browse the repository at this point in the history
…n-webgpu

46 kokoro tts doesnt work on webgpu
  • Loading branch information
sauravpanda authored Feb 8, 2025
2 parents 12f52ec + 18d3709 commit 562c262
Show file tree
Hide file tree
Showing 21 changed files with 403 additions and 223 deletions.
2 changes: 1 addition & 1 deletion examples/tts-demo/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"preview": "vite preview"
},
"dependencies": {
"@browserai/browserai": "^1.0.18",
"@browserai/browserai": "^1.0.24",
"@emotion/styled": "^11.14.0",
"react": "^18.3.1",
"react-dom": "^18.3.1"
Expand Down
6 changes: 3 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@browserai/browserai",
"version": "1.0.23",
"version": "1.0.24",
"private": false,
"description": "A library for running AI models directly in the browser",
"main": "dist/index.js",
Expand Down Expand Up @@ -36,9 +36,9 @@
},
"homepage": "https://github.com/sauravpanda/BrowserAI#readme",
"dependencies": {
"@huggingface/jinja": "^0.3.2",
"@huggingface/jinja": "^0.3.3",
"@mlc-ai/web-llm": "^0.2.78",
"onnxruntime-web": "1.21.0-dev.20250114-228dd16893",
"onnxruntime-web": "1.21.0-dev.20250206-d981b153d3",
"phonemizer": "^1.2.1"
},
"publishConfig": {
Expand Down
15 changes: 13 additions & 2 deletions src/config/models/transformers-models.json
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,23 @@
"engine": "transformers",
"modelName": "kokoro-tts",
"modelType": "text-to-speech",
"repo": "onnx-community/Kokoro-82M-ONNX",
"repo": "onnx-community/Kokoro-82M-v1.0-ONNX",
"pipeline": "text-to-speech",
"defaultQuantization": "q4",
"defaultQuantization": "fp32",
"quantizations": ["q4", "q8", "fp32", "fp16", "q4f16"],
"defaultParams": {
"language": "en"
}
},
"janus-1.3b": {
"engine": "transformers",
"modelName": "Janus-1.3B-ONNX",
"modelType": "multimodal",
"repo": "onnx-community/Janus-1.3B-ONNX",
"pipeline": "multimodal",
"defaultQuantization": "q4f16",
"quantizations": ["q4", "q8", "fp32", "fp16", "q4f16"],
"defaultParams": {
}
}
}
1 change: 1 addition & 0 deletions src/config/models/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export type ModelType =
| 'sentiment-analysis'
| 'feature-extraction'
| 'automatic-speech-recognition'
| 'multimodal'
| 'text-to-speech';

export interface MLCConfig extends BaseModelConfig {
Expand Down
19 changes: 19 additions & 0 deletions src/core/llm/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -179,4 +179,23 @@ export class BrowserAI {
throw error;
}
}

async generateImage(text: string, options: Record<string, unknown> = {}): Promise<string> {
if (!this.modelIdentifier) {
throw new Error('No model loaded. Please call loadModel first.');
}

if (this.currentModel?.modelType !== 'multimodal') {
throw new Error('Current model does not support multimodal inputs.');
}

if (this.engine instanceof TransformersEngineWrapper) {
const response = await this.engine.generateImage({
text: text as string,
}, options);
return response;
}

throw new Error('Current engine does not support multimodal generation');
}
}
15 changes: 10 additions & 5 deletions src/engines/mlc-engine-wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,14 @@ export class MLCEngineWrapper {
throw new Error('MLC Engine not initialized.');
}

let messages = Array.isArray(input) ? input : [];
// Initialize messages array regardless of input type
let messages: Record<string, any>[] = [];

// If input is a string, construct messages array
if (typeof input === 'string') {
// If input is an array, use it directly
if (Array.isArray(input)) {
messages = input;
} else if (typeof input === 'string') {
// If input is a string, construct messages array
if (options.system_prompt) {
messages.push({ role: 'system', content: options.system_prompt });
}
Expand All @@ -65,9 +69,10 @@ export class MLCEngineWrapper {
options.presence_penalty = options.presence_penalty || 0.5;
if (options.stream) {
options.stream_options = { include_usage: true };
return this.mlcEngine.chat.completions.create({ messages, ...options });
return this.mlcEngine.chat.completions.create({ messages: messages as any, ...options });
}
const result = await this.mlcEngine.chat.completions.create({ messages, ...options });
console.log(messages);
const result = await this.mlcEngine.chat.completions.create({ messages: messages as any, ...options });
return result.choices[0].message.content;
}

Expand Down
51 changes: 50 additions & 1 deletion src/engines/transformer-engine-wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import {
} from '../libs/transformers/transformers';
import { ModelConfig } from '../config/models/types';
import { TTSEngine } from './tts-engine';
import { AutoProcessor, MultiModalityCausalLM } from '../libs/transformers/transformers';

export class TransformersEngineWrapper {
private transformersPipeline:
Expand All @@ -24,6 +25,8 @@ export class TransformersEngineWrapper {
| null = null;
private modelType: string | null = null;
private ttsEngine: TTSEngine | null = null;
private imageProcessor: any | null = null;
private multimodalModel: any | null = null;

constructor() {
this.transformersPipeline = null;
Expand All @@ -46,7 +49,7 @@ export class TransformersEngineWrapper {

// Configure pipeline options with proper worker settings
const pipelineOptions = {
progress_callback: options.onProgress,
progress_callback: options.onProgress || (() => {}),
...options
};

Expand All @@ -59,6 +62,17 @@ export class TransformersEngineWrapper {
return; // Exit early for TTS models
}

// Initialize image processor for multimodal models
if (modelConfig.modelType === 'multimodal') {
options.device = "webgpu";
// console.log('Loading multimodal model...');
this.imageProcessor = await AutoProcessor.from_pretrained(modelConfig.repo, pipelineOptions);
// console.log('Image processor loaded');
this.multimodalModel = await MultiModalityCausalLM.from_pretrained(modelConfig.repo, pipelineOptions);
// console.log('Multimodal model loaded');
return;
}

// For non-TTS models, create the appropriate pipeline
const pipelineType = modelConfig.pipeline as PipelineType;
this.transformersPipeline = await pipeline(pipelineType, modelConfig.repo, pipelineOptions);
Expand Down Expand Up @@ -168,4 +182,39 @@ export class TransformersEngineWrapper {
throw new Error('Feature extraction pipeline not initialized.');
}
}

async generateImage(input: { text: string }, options: any = {}) {
if (this.modelType !== 'multimodal') {
throw new Error('Multimodal model not initialized.');
}

if (!this.imageProcessor || !this.multimodalModel) {
throw new Error('Image processor or multimodal model not initialized.');
}

try {
const conversation = [{ 'role': 'user', 'content': input.text }];

// Process the input text with the image processor
const inputs = await this.imageProcessor(conversation, {
chat_template: "text_to_image",
...options
});

// Generate the image
const num_image_tokens = this.imageProcessor.num_image_tokens;

const outputs = await this.multimodalModel.generate({
...inputs,
min_new_tokens: num_image_tokens,
max_new_tokens: num_image_tokens,
do_sample: true,
});

return outputs;
} catch (error) {
console.error('Error generating image:', error);
throw error;
}
}
}
13 changes: 7 additions & 6 deletions src/engines/tts-engine.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { StyleTextToSpeech2Model, AutoTokenizer, RawAudio, Tensor } from "../libs/transformers/transformers";
import { StyleTextToSpeech2Model, AutoTokenizer, Tensor } from "../libs/transformers/transformers";
import { ModelConfig } from '../config/models/types';
import { phonemize } from "../libs/transformers/utils/phonemize";
import { getVoiceData, VOICES } from "../libs/transformers/utils/voices";
Expand All @@ -17,10 +17,11 @@ export class TTSEngine {
}

async loadModel(modelConfig: ModelConfig, options: any = {}) {
// console.log('Loading TTS model... ', modelConfig.repo, options);
try {
this.model = await StyleTextToSpeech2Model.from_pretrained(modelConfig.repo, {
progress_callback: options.onProgress,
dtype: options.dtype || "q4",
dtype: options.dtype || "fp32",
device: "webgpu",
});

Expand All @@ -47,7 +48,7 @@ export class TTSEngine {
}

try {
const language = voice.at(0); // "a" or "b"
const language = (voice.at(0)); // "a" or "b"
const phonemes = await phonemize(text, language);
// console.log('Phonemes:', phonemes); // Debug log

Expand All @@ -56,10 +57,10 @@ export class TTSEngine {
});

// Select voice style based on number of input tokens
const num_tokens = Math.max(
const num_tokens = Math.min(Math.max(
input_ids.dims.at(-1) - 2, // Without padding
0,
);
), 509);

// Load voice style
const data = await getVoiceData(voice);
Expand All @@ -69,7 +70,7 @@ export class TTSEngine {

// Prepare model inputs
const inputs = {
input_ids,
input_ids: input_ids,
style: new Tensor("float32", voiceData, [1, STYLE_DIM]),
speed: new Tensor("float32", [speed], [1]),
};
Expand Down
16 changes: 11 additions & 5 deletions src/libs/transformers/backends/onnx.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,23 @@ const supportedDevices: string[] = [];

/** @type {ONNXExecutionProviders[]} */
let defaultDevices: string[] = [];
// Simplified initialization - removed ORT_SYMBOL check since we're only using web runtime


// Then add WebNN support
if (apis.IS_WEBNN_AVAILABLE) {
supportedDevices.push('webnn-npu', 'webnn-gpu', 'webnn-cpu', 'webnn');
supportedDevices.push('webnn-gpu', 'webnn-cpu');
}

// Add WebGPU as an option if available
if (apis.IS_WEBGPU_AVAILABLE) {
supportedDevices.push('webgpu');
// Add WASM as fallback when WebGPU fails
supportedDevices.push('wasm', 'webgpu');
}

supportedDevices.push('wasm');
defaultDevices = ['wasm'];
// Remove the previous "Always keep WASM as fallback" section since we've integrated it above
if (defaultDevices.length === 0) {
defaultDevices = ['wasm'];
}

const InferenceSession = ONNX.InferenceSession;

Expand Down
2 changes: 1 addition & 1 deletion src/libs/transformers/base/image_processors_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ function enforce_size_divisibility(size: [number, number], divisor: number) {
* @param {number[]} arr The coordinate for the center of the box and its width, height dimensions (center_x, center_y, width, height)
* @returns {number[]} The coodinates for the top-left and bottom-right corners of the box (top_left_x, top_left_y, bottom_right_x, bottom_right_y)
*/
function center_to_corners_format(centerX: number, centerY: number, width: number, height: number) {
export function center_to_corners_format(centerX: number, centerY: number, width: number, height: number) {
return [centerX - width / 2, centerY - height / 2, centerX + width / 2, centerY + height / 2];
}

Expand Down
11 changes: 11 additions & 0 deletions src/libs/transformers/base/processing_utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,17 @@ export class Processor extends Callable {
return this.tokenizer.batch_decode(...args);
}

/**
* @param {Parameters<PreTrainedTokenizer['decode']>} args
* @returns {ReturnType<PreTrainedTokenizer['decode']>}
*/
decode(...args: any[]) {
if (!this.tokenizer) {
throw new Error('Unable to decode without a tokenizer.');
}
return this.tokenizer.decode(...args);
}

/**
* Calls the feature_extractor function with the given input.
* @param {any} input The input to extract features from.
Expand Down
2 changes: 1 addition & 1 deletion src/libs/transformers/env.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import path from 'path';
import url from 'url';

const VERSION = '3.3.1';
const VERSION = '3.3.3';

// Check if various APIs are available (depends on environment)
const IS_BROWSER_ENV = typeof window !== 'undefined' && typeof window.document !== 'undefined';
Expand Down
9 changes: 6 additions & 3 deletions src/libs/transformers/generation/streamers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ export class TextStreamer extends BaseStreamer {
tokenizer: PreTrainedTokenizer;
skip_prompt: boolean;
callback_function: (x: string) => void;
token_callback_function: (x: bigint[]) => void;
token_callback_function: any;
decode_kwargs: any;
print_len: number;
next_tokens_are_prompt: boolean;
Expand All @@ -54,6 +54,7 @@ export class TextStreamer extends BaseStreamer {
skip_prompt = false,
callback_function = null,
token_callback_function = null,
skip_special_tokens = true,
decode_kwargs = {},
...kwargs
} = {},
Expand All @@ -63,7 +64,8 @@ export class TextStreamer extends BaseStreamer {
this.skip_prompt = skip_prompt;
this.callback_function = callback_function ?? stdout_write;
this.token_callback_function = token_callback_function;
this.decode_kwargs = { ...decode_kwargs, ...kwargs };
this.decode_kwargs = { skip_special_tokens, ...decode_kwargs, ...kwargs };


// variables used in the streaming process
this.token_cache = [];
Expand Down Expand Up @@ -186,9 +188,10 @@ export class WhisperTextStreamer extends TextStreamer {
) {
super(tokenizer, {
skip_prompt,
skip_special_tokens,
callback_function,
token_callback_function,
decode_kwargs: { skip_special_tokens, ...decode_kwargs },
decode_kwargs,
});
this.timestamp_begin = tokenizer.timestamp_begin;

Expand Down
Loading

0 comments on commit 562c262

Please sign in to comment.