From cd1b6c46c71f4a02c18ce45010877fc771627df6 Mon Sep 17 00:00:00 2001 From: Saurav Panda Date: Sat, 8 Feb 2025 13:45:39 -0800 Subject: [PATCH 1/2] fix: tts on webgpu --- package.json | 4 +- src/config/models/transformers-models.json | 15 +- src/config/models/types.ts | 1 + src/core/llm/index.ts | 19 + src/engines/mlc-engine-wrapper.ts | 15 +- src/engines/transformer-engine-wrapper.ts | 51 ++- src/engines/tts-engine.ts | 13 +- src/libs/transformers/backends/onnx.ts | 16 +- .../base/image_processors_utils.ts | 2 +- .../transformers/base/processing_utils.ts | 11 + src/libs/transformers/env.ts | 2 +- src/libs/transformers/generation/streamers.ts | 9 +- src/libs/transformers/models.ts | 381 ++++++++++-------- .../models/auto/image_processing_auto.ts | 17 +- .../models/auto/processing_auto.ts | 4 +- src/libs/transformers/models/processors.ts | 2 +- src/libs/transformers/ops/registry.ts | 6 +- src/libs/transformers/tokenizers.ts | 10 +- src/libs/transformers/utils/hub.ts | 2 +- src/libs/transformers/utils/tensor.ts | 42 +- 20 files changed, 401 insertions(+), 221 deletions(-) diff --git a/package.json b/package.json index 39a4b6f..9453a9c 100644 --- a/package.json +++ b/package.json @@ -36,9 +36,9 @@ }, "homepage": "https://github.com/sauravpanda/BrowserAI#readme", "dependencies": { - "@huggingface/jinja": "^0.3.2", + "@huggingface/jinja": "^0.3.3", "@mlc-ai/web-llm": "^0.2.78", - "onnxruntime-web": "1.21.0-dev.20250114-228dd16893", + "onnxruntime-web": "1.21.0-dev.20250206-d981b153d3", "phonemizer": "^1.2.1" }, "publishConfig": { diff --git a/src/config/models/transformers-models.json b/src/config/models/transformers-models.json index 772e7e4..fb56a5f 100644 --- a/src/config/models/transformers-models.json +++ b/src/config/models/transformers-models.json @@ -24,12 +24,23 @@ "engine": "transformers", "modelName": "kokoro-tts", "modelType": "text-to-speech", - "repo": "onnx-community/Kokoro-82M-ONNX", + "repo": "onnx-community/Kokoro-82M-v1.0-ONNX", "pipeline": "text-to-speech", - "defaultQuantization": "q4", + "defaultQuantization": "fp32", "quantizations": ["q4", "q8", "fp32", "fp16", "q4f16"], "defaultParams": { "language": "en" } + }, + "janus-1.3b": { + "engine": "transformers", + "modelName": "Janus-1.3B-ONNX", + "modelType": "multimodal", + "repo": "onnx-community/Janus-1.3B-ONNX", + "pipeline": "multimodal", + "defaultQuantization": "q4f16", + "quantizations": ["q4", "q8", "fp32", "fp16", "q4f16"], + "defaultParams": { + } } } \ No newline at end of file diff --git a/src/config/models/types.ts b/src/config/models/types.ts index 95c2454..03f5712 100644 --- a/src/config/models/types.ts +++ b/src/config/models/types.ts @@ -18,6 +18,7 @@ export type ModelType = | 'sentiment-analysis' | 'feature-extraction' | 'automatic-speech-recognition' + | 'multimodal' | 'text-to-speech'; export interface MLCConfig extends BaseModelConfig { diff --git a/src/core/llm/index.ts b/src/core/llm/index.ts index da657be..8d87ac6 100644 --- a/src/core/llm/index.ts +++ b/src/core/llm/index.ts @@ -179,4 +179,23 @@ export class BrowserAI { throw error; } } + + async generateImage(text: string, options: Record = {}): Promise { + if (!this.modelIdentifier) { + throw new Error('No model loaded. 
Please call loadModel first.'); + } + + if (this.currentModel?.modelType !== 'multimodal') { + throw new Error('Current model does not support multimodal inputs.'); + } + + if (this.engine instanceof TransformersEngineWrapper) { + const response = await this.engine.generateImage({ + text: text as string, + }, options); + return response; + } + + throw new Error('Current engine does not support multimodal generation'); + } } diff --git a/src/engines/mlc-engine-wrapper.ts b/src/engines/mlc-engine-wrapper.ts index 543b70b..83e343b 100644 --- a/src/engines/mlc-engine-wrapper.ts +++ b/src/engines/mlc-engine-wrapper.ts @@ -47,10 +47,14 @@ export class MLCEngineWrapper { throw new Error('MLC Engine not initialized.'); } - let messages = Array.isArray(input) ? input : []; + // Initialize messages array regardless of input type + let messages: Record[] = []; - // If input is a string, construct messages array - if (typeof input === 'string') { + // If input is an array, use it directly + if (Array.isArray(input)) { + messages = input; + } else if (typeof input === 'string') { + // If input is a string, construct messages array if (options.system_prompt) { messages.push({ role: 'system', content: options.system_prompt }); } @@ -65,9 +69,10 @@ export class MLCEngineWrapper { options.presence_penalty = options.presence_penalty || 0.5; if (options.stream) { options.stream_options = { include_usage: true }; - return this.mlcEngine.chat.completions.create({ messages, ...options }); + return this.mlcEngine.chat.completions.create({ messages: messages as any, ...options }); } - const result = await this.mlcEngine.chat.completions.create({ messages, ...options }); + console.log(messages); + const result = await this.mlcEngine.chat.completions.create({ messages: messages as any, ...options }); return result.choices[0].message.content; } diff --git a/src/engines/transformer-engine-wrapper.ts b/src/engines/transformer-engine-wrapper.ts index ef76aa7..ed5b27d 100644 --- a/src/engines/transformer-engine-wrapper.ts +++ b/src/engines/transformer-engine-wrapper.ts @@ -11,6 +11,7 @@ import { } from '../libs/transformers/transformers'; import { ModelConfig } from '../config/models/types'; import { TTSEngine } from './tts-engine'; +import { AutoProcessor, MultiModalityCausalLM } from '../libs/transformers/transformers'; export class TransformersEngineWrapper { private transformersPipeline: @@ -24,6 +25,8 @@ export class TransformersEngineWrapper { | null = null; private modelType: string | null = null; private ttsEngine: TTSEngine | null = null; + private imageProcessor: any | null = null; + private multimodalModel: any | null = null; constructor() { this.transformersPipeline = null; @@ -46,7 +49,7 @@ export class TransformersEngineWrapper { // Configure pipeline options with proper worker settings const pipelineOptions = { - progress_callback: options.onProgress, + progress_callback: options.onProgress || (() => {}), ...options }; @@ -59,6 +62,17 @@ export class TransformersEngineWrapper { return; // Exit early for TTS models } + // Initialize image processor for multimodal models + if (modelConfig.modelType === 'multimodal') { + options.device = "webgpu"; + // console.log('Loading multimodal model...'); + this.imageProcessor = await AutoProcessor.from_pretrained(modelConfig.repo, pipelineOptions); + // console.log('Image processor loaded'); + this.multimodalModel = await MultiModalityCausalLM.from_pretrained(modelConfig.repo, pipelineOptions); + // console.log('Multimodal model loaded'); + return; + } + // For 
non-TTS models, create the appropriate pipeline const pipelineType = modelConfig.pipeline as PipelineType; this.transformersPipeline = await pipeline(pipelineType, modelConfig.repo, pipelineOptions); @@ -168,4 +182,39 @@ export class TransformersEngineWrapper { throw new Error('Feature extraction pipeline not initialized.'); } } + + async generateImage(input: { text: string }, options: any = {}) { + if (this.modelType !== 'multimodal') { + throw new Error('Multimodal model not initialized.'); + } + + if (!this.imageProcessor || !this.multimodalModel) { + throw new Error('Image processor or multimodal model not initialized.'); + } + + try { + const conversation = [{ 'role': 'user', 'content': input.text }]; + + // Process the input text with the image processor + const inputs = await this.imageProcessor(conversation, { + chat_template: "text_to_image", + ...options + }); + + // Generate the image + const num_image_tokens = this.imageProcessor.num_image_tokens; + + const outputs = await this.multimodalModel.generate({ + ...inputs, + min_new_tokens: num_image_tokens, + max_new_tokens: num_image_tokens, + do_sample: true, + }); + + return outputs; + } catch (error) { + console.error('Error generating image:', error); + throw error; + } + } } diff --git a/src/engines/tts-engine.ts b/src/engines/tts-engine.ts index ab54613..0f70f9d 100644 --- a/src/engines/tts-engine.ts +++ b/src/engines/tts-engine.ts @@ -1,4 +1,4 @@ -import { StyleTextToSpeech2Model, AutoTokenizer, RawAudio, Tensor } from "../libs/transformers/transformers"; +import { StyleTextToSpeech2Model, AutoTokenizer, Tensor } from "../libs/transformers/transformers"; import { ModelConfig } from '../config/models/types'; import { phonemize } from "../libs/transformers/utils/phonemize"; import { getVoiceData, VOICES } from "../libs/transformers/utils/voices"; @@ -17,10 +17,11 @@ export class TTSEngine { } async loadModel(modelConfig: ModelConfig, options: any = {}) { + // console.log('Loading TTS model... 
', modelConfig.repo, options); try { this.model = await StyleTextToSpeech2Model.from_pretrained(modelConfig.repo, { progress_callback: options.onProgress, - dtype: options.dtype || "q4", + dtype: options.dtype || "fp32", device: "webgpu", }); @@ -47,7 +48,7 @@ export class TTSEngine { } try { - const language = voice.at(0); // "a" or "b" + const language = (voice.at(0)); // "a" or "b" const phonemes = await phonemize(text, language); // console.log('Phonemes:', phonemes); // Debug log @@ -56,10 +57,10 @@ export class TTSEngine { }); // Select voice style based on number of input tokens - const num_tokens = Math.max( + const num_tokens = Math.min(Math.max( input_ids.dims.at(-1) - 2, // Without padding 0, - ); + ), 509); // Load voice style const data = await getVoiceData(voice); @@ -69,7 +70,7 @@ export class TTSEngine { // Prepare model inputs const inputs = { - input_ids, + input_ids: input_ids, style: new Tensor("float32", voiceData, [1, STYLE_DIM]), speed: new Tensor("float32", [speed], [1]), }; diff --git a/src/libs/transformers/backends/onnx.ts b/src/libs/transformers/backends/onnx.ts index ef6930e..a04fdc1 100644 --- a/src/libs/transformers/backends/onnx.ts +++ b/src/libs/transformers/backends/onnx.ts @@ -51,17 +51,23 @@ const supportedDevices: string[] = []; /** @type {ONNXExecutionProviders[]} */ let defaultDevices: string[] = []; -// Simplified initialization - removed ORT_SYMBOL check since we're only using web runtime + + +// Then add WebNN support if (apis.IS_WEBNN_AVAILABLE) { - supportedDevices.push('webnn-npu', 'webnn-gpu', 'webnn-cpu', 'webnn'); + supportedDevices.push('webnn-gpu', 'webnn-cpu'); } +// Add WebGPU as an option if available if (apis.IS_WEBGPU_AVAILABLE) { - supportedDevices.push('webgpu'); + // Add WASM as fallback when WebGPU fails + supportedDevices.push('wasm', 'webgpu'); } -supportedDevices.push('wasm'); -defaultDevices = ['wasm']; +// Remove the previous "Always keep WASM as fallback" section since we've integrated it above +if (defaultDevices.length === 0) { + defaultDevices = ['wasm']; +} const InferenceSession = ONNX.InferenceSession; diff --git a/src/libs/transformers/base/image_processors_utils.ts b/src/libs/transformers/base/image_processors_utils.ts index 837cbfe..1d63e7c 100644 --- a/src/libs/transformers/base/image_processors_utils.ts +++ b/src/libs/transformers/base/image_processors_utils.ts @@ -61,7 +61,7 @@ function enforce_size_divisibility(size: [number, number], divisor: number) { * @param {number[]} arr The coordinate for the center of the box and its width, height dimensions (center_x, center_y, width, height) * @returns {number[]} The coodinates for the top-left and bottom-right corners of the box (top_left_x, top_left_y, bottom_right_x, bottom_right_y) */ -function center_to_corners_format(centerX: number, centerY: number, width: number, height: number) { +export function center_to_corners_format(centerX: number, centerY: number, width: number, height: number) { return [centerX - width / 2, centerY - height / 2, centerX + width / 2, centerY + height / 2]; } diff --git a/src/libs/transformers/base/processing_utils.ts b/src/libs/transformers/base/processing_utils.ts index 23375a3..583d95a 100644 --- a/src/libs/transformers/base/processing_utils.ts +++ b/src/libs/transformers/base/processing_utils.ts @@ -99,6 +99,17 @@ export class Processor extends Callable { return this.tokenizer.batch_decode(...args); } + /** + * @param {Parameters} args + * @returns {ReturnType} + */ + decode(...args: any[]) { + if (!this.tokenizer) { + throw new 
Error('Unable to decode without a tokenizer.'); + } + return this.tokenizer.decode(...args); + } + /** * Calls the feature_extractor function with the given input. * @param {any} input The input to extract features from. diff --git a/src/libs/transformers/env.ts b/src/libs/transformers/env.ts index 17cf9b2..bc31d49 100644 --- a/src/libs/transformers/env.ts +++ b/src/libs/transformers/env.ts @@ -25,7 +25,7 @@ import path from 'path'; import url from 'url'; -const VERSION = '3.3.1'; +const VERSION = '3.3.3'; // Check if various APIs are available (depends on environment) const IS_BROWSER_ENV = typeof window !== 'undefined' && typeof window.document !== 'undefined'; diff --git a/src/libs/transformers/generation/streamers.ts b/src/libs/transformers/generation/streamers.ts index 931ced7..9c69a11 100644 --- a/src/libs/transformers/generation/streamers.ts +++ b/src/libs/transformers/generation/streamers.ts @@ -42,7 +42,7 @@ export class TextStreamer extends BaseStreamer { tokenizer: PreTrainedTokenizer; skip_prompt: boolean; callback_function: (x: string) => void; - token_callback_function: (x: bigint[]) => void; + token_callback_function: any; decode_kwargs: any; print_len: number; next_tokens_are_prompt: boolean; @@ -54,6 +54,7 @@ export class TextStreamer extends BaseStreamer { skip_prompt = false, callback_function = null, token_callback_function = null, + skip_special_tokens = true, decode_kwargs = {}, ...kwargs } = {}, @@ -63,7 +64,8 @@ export class TextStreamer extends BaseStreamer { this.skip_prompt = skip_prompt; this.callback_function = callback_function ?? stdout_write; this.token_callback_function = token_callback_function; - this.decode_kwargs = { ...decode_kwargs, ...kwargs }; + this.decode_kwargs = { skip_special_tokens, ...decode_kwargs, ...kwargs }; + // variables used in the streaming process this.token_cache = []; @@ -186,9 +188,10 @@ export class WhisperTextStreamer extends TextStreamer { ) { super(tokenizer, { skip_prompt, + skip_special_tokens, callback_function, token_callback_function, - decode_kwargs: { skip_special_tokens, ...decode_kwargs }, + decode_kwargs, }); this.timestamp_begin = tokenizer.timestamp_begin; diff --git a/src/libs/transformers/models.ts b/src/libs/transformers/models.ts index 9726d92..2d47e5f 100644 --- a/src/libs/transformers/models.ts +++ b/src/libs/transformers/models.ts @@ -220,7 +220,7 @@ async function getSession(pretrained_model_name_or_path: string, fileName: strin } else if (selectedDevice.startsWith('webnn') && !session_options.freeDimensionOverrides) { console.warn( 'WebNN does not currently support dynamic shapes and requires `free_dimension_overrides` to be set in config.json as a field within "transformers.js_config". 
' + - 'When `free_dimension_overrides` is not set, you may experience significant performance degradation.', + 'When `free_dimension_overrides` is not set, you may experience significant performance degradation.', ); } @@ -387,6 +387,12 @@ function validateInputs(session: any, inputs: any) { * @private */ async function sessionRun(session: any, inputs: any) { + console.log(`Running session ${session.path || 'unknown'}:`, { + inputNames: session.inputNames, + inputShapes: Object.fromEntries( + Object.entries(inputs).map(([k, v]) => [k, (v as any).dims]) + ) + }); const checkedInputs = validateInputs(session, inputs); try { // pass the original ort tensor @@ -394,6 +400,9 @@ async function sessionRun(session: any, inputs: any) { Object.entries(checkedInputs).map(([k, v]) => [k, (v as Tensor).ort_tensor]), ); let output = await session.run(ortFeed); + console.log('Session run successful:', { + outputNames: Object.keys(output) + }); output = replaceTensors(output); return output; } catch (e) { @@ -411,6 +420,15 @@ async function sessionRun(session: any, inputs: any) { ]), ); + console.error('Session run failed:', { + e, + session: session.path, + inputNames: session.inputNames, + inputShapes: Object.fromEntries( + Object.entries(inputs).map(([k, v]) => [k, (v as any).dims]) + ) + }); + // This usually occurs when the inputs are of the wrong type. console.error(`An error occurred during model execution: "${e}".`); console.error('Inputs given to model:', formatted); @@ -527,11 +545,28 @@ async function encoderForward(self: any, model_inputs: any) { if (session.inputNames.includes('token_type_ids') && !encoderFeeds.token_type_ids) { // Assign default `token_type_ids` (all zeroes) to the `encoderFeeds` if the model expects it, // but they weren't created by the tokenizer. - encoderFeeds.token_type_ids = new Tensor( - 'int64', - new BigInt64Array(encoderFeeds.input_ids.data.length), - encoderFeeds.input_ids.dims, - ); + // encoderFeeds.token_type_ids = new Tensor( + // 'int64', + // new BigInt64Array(encoderFeeds.input_ids.data.length), + // encoderFeeds.input_ids.dims, + // ); + + if (!encoderFeeds.input_ids) { + throw new Error('Both `input_ids` and `token_type_ids` are missing in the model inputs.'); + } + // Assign default `token_type_ids` (all zeroes) to the `encoderFeeds` if the model expects it, + // but they weren't created by the tokenizer. + encoderFeeds.token_type_ids = zeros_like(encoderFeeds.input_ids); + } + + if (session.inputNames.includes('pixel_mask') && !encoderFeeds.pixel_mask) { + if (!encoderFeeds.pixel_values) { + throw new Error('Both `pixel_values` and `pixel_mask` are missing in the model inputs.'); + } + // Assign default `pixel_mask` (all ones) to the `encoderFeeds` if the model expects it, + // but they weren't created by the processor. + const dims = encoderFeeds.pixel_values.dims; + encoderFeeds.pixel_mask = ones([dims[0], dims[2], dims[3]]); } // console.log("encoder forward running here: ", session, encoderFeeds) return await sessionRun(session, encoderFeeds); @@ -547,6 +582,10 @@ async function encoderForward(self: any, model_inputs: any) { async function decoderForward(self: any, model_inputs: any, is_encoder_decoder = false) { const session = self.sessions[is_encoder_decoder ? 
'decoder_model_merged' : 'model']; + // Add debugging + console.log('Decoder inputs:', model_inputs); + console.log('Session config:', session.config); + const { past_key_values, ...new_model_inputs } = model_inputs; if (session.inputNames.includes('use_cache_branch')) { @@ -1561,7 +1600,7 @@ export class PreTrainedModel extends Callable { if (inputs) { throw new Error( '`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. ' + - 'Make sure to either pass {inputs} or {input_name}=...', + 'Make sure to either pass {inputs} or {input_name}=...', ); } } else { @@ -1760,7 +1799,7 @@ export class PreTrainedModel extends Callable { let input_ids_length = input_ids.dims.at(-1); if (generation_config && (generation_config as any).max_new_tokens !== null) { - (generation_config as any).max_length = input_ids_length + (generation_config as any).max_new_tokens; + (generation_config as any).max_length = input_ids_length + (generation_config as any).max_new_tokens; } // input_ids_length = model_inputs[model_input_name].dims.at(1); @@ -2006,7 +2045,7 @@ export class PreTrainedModel extends Callable { if (!this.config.num_image_tokens) { console.warn( 'The number of image tokens was not set in the model configuration. ' + - `Setting it to the number of features detected by the vision encoder (${features.dims[1]}).`, + `Setting it to the number of features detected by the vision encoder (${features.dims[1]}).`, ); // @ts-expect-error TS2339 this.config.num_image_tokens = features.dims[1]; @@ -2022,7 +2061,7 @@ export class PreTrainedModel extends Callable { ////////////////////////////////////////////////// // Base model output class -export class ModelOutput {} +export class ModelOutput { } /** * Base class for model's outputs, with potential hidden states and attentions. @@ -2056,12 +2095,12 @@ export class BaseModelOutput extends ModelOutput { ////////////////////////////////////////////////// // Audio Spectrogram Transformer (AST) models -export class ASTPreTrainedModel extends PreTrainedModel {} +export class ASTPreTrainedModel extends PreTrainedModel { } /** * The bare AST Model transformer outputting raw hidden-states without any specific head on top. */ -export class ASTModel extends ASTPreTrainedModel {} +export class ASTModel extends ASTPreTrainedModel { } ////////////////////////////////////////////////// @@ -2082,7 +2121,7 @@ export class WhisperPreTrainedModel extends PreTrainedModel { /** * WhisperModel class for training Whisper models without a language model head. */ -export class WhisperModel extends WhisperPreTrainedModel {} +export class WhisperModel extends WhisperPreTrainedModel { } /** * WhisperForConditionalGeneration class for generating conditional outputs from Whisper models. @@ -2160,7 +2199,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { */ async generate({ inputs = null, - generation_config= null, + generation_config = null, logits_processor = null, stopping_criteria = null, @@ -2191,7 +2230,7 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { if (!generation_config.alignment_heads) { throw new Error( 'Model generation config has no `alignment_heads`, token-level timestamps not available. 
' + - 'See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config.', + 'See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config.', ); } @@ -2243,13 +2282,13 @@ export class WhisperForConditionalGeneration extends WhisperPreTrainedModel { if (!generate_outputs.cross_attentions) { throw new Error( 'Model outputs must contain cross attentions to extract timestamps. ' + - 'This is most likely because the model was not exported with `output_attentions=True`.', + 'This is most likely because the model was not exported with `output_attentions=True`.', ); } if (num_frames == null) { console.warn( '`num_frames` has not been set, meaning the entire audio will be analyzed. ' + - 'This may lead to inaccurate token-level timestamps for short audios (< 30 seconds).', + 'This may lead to inaccurate token-level timestamps for short audios (< 30 seconds).', ); } @@ -2360,9 +2399,9 @@ export class MoonshinePreTrainedModel extends PreTrainedModel { /** * MoonshineModel class for training Moonshine models without a language model head. */ -export class MoonshineModel extends MoonshinePreTrainedModel {} +export class MoonshineModel extends MoonshinePreTrainedModel { } -export class MoonshineForConditionalGeneration extends MoonshinePreTrainedModel {} +export class MoonshineForConditionalGeneration extends MoonshinePreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// @@ -2421,8 +2460,8 @@ export class LlavaForConditionalGeneration extends LlavaPreTrainedModel { } ////////////////////////////////////////////////// -export class LlavaOnevisionForConditionalGeneration extends LlavaForConditionalGeneration {} // NOTE: extends LlavaForConditionalGeneration -export class Moondream1ForConditionalGeneration extends LlavaForConditionalGeneration {} // NOTE: extends LlavaForConditionalGeneration +export class LlavaOnevisionForConditionalGeneration extends LlavaForConditionalGeneration { } // NOTE: extends LlavaForConditionalGeneration +export class Moondream1ForConditionalGeneration extends LlavaForConditionalGeneration { } // NOTE: extends LlavaForConditionalGeneration export class Florence2PreTrainedModel extends PreTrainedModel { forward_params = [ @@ -2688,7 +2727,7 @@ export class Phi3VForCausalLM extends Phi3VPreTrainedModel { } ////////////////////////////////////////////////// -export class CLIPPreTrainedModel extends PreTrainedModel {} +export class CLIPPreTrainedModel extends PreTrainedModel { } /** * CLIP Text and Vision Model with a projection layers on top @@ -2733,7 +2772,7 @@ export class CLIPPreTrainedModel extends PreTrainedModel {} * // } * ``` */ -export class CLIPModel extends CLIPPreTrainedModel {} +export class CLIPModel extends CLIPPreTrainedModel { } /** * The text model from CLIP without any head or projection on top. 
@@ -2840,7 +2879,7 @@ export class CLIPVisionModelWithProjection extends CLIPPreTrainedModel { ////////////////////////////////////////////////// // SigLIP models -export class SiglipPreTrainedModel extends PreTrainedModel {} +export class SiglipPreTrainedModel extends PreTrainedModel { } /** * SigLIP Text and Vision Model with a projection layers on top @@ -2885,7 +2924,7 @@ export class SiglipPreTrainedModel extends PreTrainedModel {} * // } * ``` */ -export class SiglipModel extends SiglipPreTrainedModel {} +export class SiglipModel extends SiglipPreTrainedModel { } /** * The text model from SigLIP without any head or projection on top. @@ -2962,14 +3001,14 @@ export class SiglipVisionModel extends CLIPPreTrainedModel { } ////////////////////////////////////////////////// // ChineseCLIP models -export class ChineseCLIPPreTrainedModel extends PreTrainedModel {} +export class ChineseCLIPPreTrainedModel extends PreTrainedModel { } -export class ChineseCLIPModel extends ChineseCLIPPreTrainedModel {} +export class ChineseCLIPModel extends ChineseCLIPPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // JinaCLIP models -export class JinaCLIPPreTrainedModel extends PreTrainedModel {} +export class JinaCLIPPreTrainedModel extends PreTrainedModel { } export class JinaCLIPModel extends JinaCLIPPreTrainedModel { async forward(model_inputs: any) { @@ -3036,14 +3075,14 @@ export class JinaCLIPVisionModel extends JinaCLIPPreTrainedModel { ////////////////////////////////////////////////// // GPT2 models -export class GPT2PreTrainedModel extends PreTrainedModel {} +export class GPT2PreTrainedModel extends PreTrainedModel { } -export class GPT2Model extends GPT2PreTrainedModel {} +export class GPT2Model extends GPT2PreTrainedModel { } /** * GPT-2 language model head on top of the GPT-2 base model. This model is suitable for text generation tasks. */ -export class GPT2LMHeadModel extends GPT2PreTrainedModel {} +export class GPT2LMHeadModel extends GPT2PreTrainedModel { } // export class GPT2ForSequenceClassification extends GPT2PreTrainedModel { // TODO // } @@ -3051,65 +3090,65 @@ export class GPT2LMHeadModel extends GPT2PreTrainedModel {} ////////////////////////////////////////////////// // JAIS models -export class JAISPreTrainedModel extends PreTrainedModel {} +export class JAISPreTrainedModel extends PreTrainedModel { } /** * The bare JAIS Model transformer outputting raw hidden-states without any specific head on top. */ -export class JAISModel extends JAISPreTrainedModel {} +export class JAISModel extends JAISPreTrainedModel { } /** * The JAIS Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings). 
*/ -export class JAISLMHeadModel extends JAISPreTrainedModel {} +export class JAISLMHeadModel extends JAISPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // GPTNeo models -export class GPTNeoPreTrainedModel extends PreTrainedModel {} -export class GPTNeoModel extends GPTNeoPreTrainedModel {} +export class GPTNeoPreTrainedModel extends PreTrainedModel { } +export class GPTNeoModel extends GPTNeoPreTrainedModel { } -export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel {} +export class GPTNeoForCausalLM extends GPTNeoPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // GPTNeoX models -export class GPTNeoXPreTrainedModel extends PreTrainedModel {} -export class GPTNeoXModel extends GPTNeoXPreTrainedModel {} +export class GPTNeoXPreTrainedModel extends PreTrainedModel { } +export class GPTNeoXModel extends GPTNeoXPreTrainedModel { } -export class GPTNeoXForCausalLM extends GPTNeoXPreTrainedModel {} +export class GPTNeoXForCausalLM extends GPTNeoXPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // GPT-J models -export class GPTJPreTrainedModel extends PreTrainedModel {} +export class GPTJPreTrainedModel extends PreTrainedModel { } -export class GPTJModel extends GPTJPreTrainedModel {} +export class GPTJModel extends GPTJPreTrainedModel { } -export class GPTJForCausalLM extends GPTJPreTrainedModel {} +export class GPTJForCausalLM extends GPTJPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // GPTBigCode models -export class GPTBigCodePreTrainedModel extends PreTrainedModel {} +export class GPTBigCodePreTrainedModel extends PreTrainedModel { } -export class GPTBigCodeModel extends GPTBigCodePreTrainedModel {} +export class GPTBigCodeModel extends GPTBigCodePreTrainedModel { } -export class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel {} +export class GPTBigCodeForCausalLM extends GPTBigCodePreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // CodeGen models -export class CodeGenPreTrainedModel extends PreTrainedModel {} +export class CodeGenPreTrainedModel extends PreTrainedModel { } /** * CodeGenModel is a class representing a code generation model without a language model head. */ -export class CodeGenModel extends CodeGenPreTrainedModel {} +export class CodeGenModel extends CodeGenPreTrainedModel { } /** * CodeGenForCausalLM is a class that represents a code generation model based on the GPT-2 architecture. It extends the `CodeGenPreTrainedModel` class. */ -export class CodeGenForCausalLM extends CodeGenPreTrainedModel {} +export class CodeGenForCausalLM extends CodeGenPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// @@ -3118,48 +3157,48 @@ export class CodeGenForCausalLM extends CodeGenPreTrainedModel {} /** * The bare LLama Model outputting raw hidden-states without any specific head on top. */ -export class LlamaPreTrainedModel extends PreTrainedModel {} +export class LlamaPreTrainedModel extends PreTrainedModel { } /** * The bare LLaMA Model outputting raw hidden-states without any specific head on top. 
*/ -export class LlamaModel extends LlamaPreTrainedModel {} +export class LlamaModel extends LlamaPreTrainedModel { } -export class LlamaForCausalLM extends LlamaPreTrainedModel {} +export class LlamaForCausalLM extends LlamaPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // EXAONE models -export class ExaonePreTrainedModel extends PreTrainedModel {} -export class ExaoneModel extends ExaonePreTrainedModel {} -export class ExaoneForCausalLM extends ExaonePreTrainedModel {} +export class ExaonePreTrainedModel extends PreTrainedModel { } +export class ExaoneModel extends ExaonePreTrainedModel { } +export class ExaoneForCausalLM extends ExaonePreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // MobileLLM models -export class MobileLLMPreTrainedModel extends PreTrainedModel {} -export class MobileLLMModel extends MobileLLMPreTrainedModel {} -export class MobileLLMForCausalLM extends MobileLLMPreTrainedModel {} +export class MobileLLMPreTrainedModel extends PreTrainedModel { } +export class MobileLLMModel extends MobileLLMPreTrainedModel { } +export class MobileLLMForCausalLM extends MobileLLMPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // OLMo models -export class OlmoPreTrainedModel extends PreTrainedModel {} -export class OlmoModel extends OlmoPreTrainedModel {} -export class OlmoForCausalLM extends OlmoPreTrainedModel {} +export class OlmoPreTrainedModel extends PreTrainedModel { } +export class OlmoModel extends OlmoPreTrainedModel { } +export class OlmoForCausalLM extends OlmoPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // OLMo2 models -export class Olmo2PreTrainedModel extends PreTrainedModel {} -export class Olmo2Model extends Olmo2PreTrainedModel {} -export class Olmo2ForCausalLM extends Olmo2PreTrainedModel {} +export class Olmo2PreTrainedModel extends PreTrainedModel { } +export class Olmo2Model extends Olmo2PreTrainedModel { } +export class Olmo2ForCausalLM extends Olmo2PreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // Granite models -export class GranitePreTrainedModel extends PreTrainedModel {} -export class GraniteModel extends GranitePreTrainedModel {} -export class GraniteForCausalLM extends GranitePreTrainedModel {} +export class GranitePreTrainedModel extends PreTrainedModel { } +export class GraniteModel extends GranitePreTrainedModel { } +export class GraniteForCausalLM extends GranitePreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// @@ -3168,10 +3207,10 @@ export class GraniteForCausalLM extends GranitePreTrainedModel {} /** * The bare Cohere Model outputting raw hidden-states without any specific head on top. 
*/ -export class CoherePreTrainedModel extends PreTrainedModel {} -export class CohereModel extends CoherePreTrainedModel {} +export class CoherePreTrainedModel extends PreTrainedModel { } +export class CohereModel extends CoherePreTrainedModel { } -export class CohereForCausalLM extends CoherePreTrainedModel {} +export class CohereForCausalLM extends CoherePreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// @@ -3180,13 +3219,13 @@ export class CohereForCausalLM extends CoherePreTrainedModel {} /** * The bare Gemma Model outputting raw hidden-states without any specific head on top. */ -export class GemmaPreTrainedModel extends PreTrainedModel {} +export class GemmaPreTrainedModel extends PreTrainedModel { } /** * The bare Gemma Model outputting raw hidden-states without any specific head on top. */ -export class GemmaModel extends GemmaPreTrainedModel {} +export class GemmaModel extends GemmaPreTrainedModel { } -export class GemmaForCausalLM extends GemmaPreTrainedModel {} +export class GemmaForCausalLM extends GemmaPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// @@ -3195,20 +3234,20 @@ export class GemmaForCausalLM extends GemmaPreTrainedModel {} /** * The bare Gemma2 Model outputting raw hidden-states without any specific head on top. */ -export class Gemma2PreTrainedModel extends PreTrainedModel {} +export class Gemma2PreTrainedModel extends PreTrainedModel { } /** * The bare Gemma2 Model outputting raw hidden-states without any specific head on top. */ -export class Gemma2Model extends Gemma2PreTrainedModel {} +export class Gemma2Model extends Gemma2PreTrainedModel { } -export class Gemma2ForCausalLM extends Gemma2PreTrainedModel {} +export class Gemma2ForCausalLM extends Gemma2PreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// -export class OpenELMPreTrainedModel extends PreTrainedModel {} -export class OpenELMModel extends OpenELMPreTrainedModel {} +export class OpenELMPreTrainedModel extends PreTrainedModel { } +export class OpenELMModel extends OpenELMPreTrainedModel { } -export class OpenELMForCausalLM extends OpenELMPreTrainedModel {} +export class OpenELMForCausalLM extends OpenELMPreTrainedModel { } ////////////////////////////////////////////////// // Qwen2 models @@ -3216,36 +3255,36 @@ export class OpenELMForCausalLM extends OpenELMPreTrainedModel {} /** * The bare Qwen2 Model outputting raw hidden-states without any specific head on top. */ -export class Qwen2PreTrainedModel extends PreTrainedModel {} +export class Qwen2PreTrainedModel extends PreTrainedModel { } /** * The bare Qwen2 Model outputting raw hidden-states without any specific head on top. */ -export class Qwen2Model extends Qwen2PreTrainedModel {} +export class Qwen2Model extends Qwen2PreTrainedModel { } -export class Qwen2ForCausalLM extends Qwen2PreTrainedModel {} +export class Qwen2ForCausalLM extends Qwen2PreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // Phi models -export class PhiPreTrainedModel extends PreTrainedModel {} +export class PhiPreTrainedModel extends PreTrainedModel { } /** * The bare Phi Model outputting raw hidden-states without any specific head on top. 
*/ -export class PhiModel extends PhiPreTrainedModel {} +export class PhiModel extends PhiPreTrainedModel { } -export class PhiForCausalLM extends PhiPreTrainedModel {} +export class PhiForCausalLM extends PhiPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // Phi3 models -export class Phi3PreTrainedModel extends PreTrainedModel {} +export class Phi3PreTrainedModel extends PreTrainedModel { } /** * The bare Phi3 Model outputting raw hidden-states without any specific head on top. */ -export class Phi3Model extends Phi3PreTrainedModel {} +export class Phi3Model extends Phi3PreTrainedModel { } -export class Phi3ForCausalLM extends Phi3PreTrainedModel {} +export class Phi3ForCausalLM extends Phi3PreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// @@ -3253,70 +3292,70 @@ export class Phi3ForCausalLM extends Phi3PreTrainedModel {} /** * The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings). */ -export class BloomPreTrainedModel extends PreTrainedModel {} +export class BloomPreTrainedModel extends PreTrainedModel { } /** * The bare Bloom Model transformer outputting raw hidden-states without any specific head on top. */ -export class BloomModel extends BloomPreTrainedModel {} +export class BloomModel extends BloomPreTrainedModel { } /** * The Bloom Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings). */ -export class BloomForCausalLM extends BloomPreTrainedModel {} +export class BloomForCausalLM extends BloomPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // MPT models -export class MptPreTrainedModel extends PreTrainedModel {} +export class MptPreTrainedModel extends PreTrainedModel { } /** * The bare Mpt Model transformer outputting raw hidden-states without any specific head on top. */ -export class MptModel extends MptPreTrainedModel {} +export class MptModel extends MptPreTrainedModel { } /** * The MPT Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings). */ -export class MptForCausalLM extends MptPreTrainedModel {} +export class MptForCausalLM extends MptPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // OPT models -export class OPTPreTrainedModel extends PreTrainedModel {} +export class OPTPreTrainedModel extends PreTrainedModel { } /** * The bare OPT Model outputting raw hidden-states without any specific head on top. */ -export class OPTModel extends OPTPreTrainedModel {} +export class OPTModel extends OPTPreTrainedModel { } /** * The OPT Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings). */ -export class OPTForCausalLM extends OPTPreTrainedModel {} +export class OPTForCausalLM extends OPTPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// -export class VitPosePreTrainedModel extends PreTrainedModel {} +export class VitPosePreTrainedModel extends PreTrainedModel { } /** * The VitPose model with a pose estimation head on top. 
*/ -export class VitPoseForPoseEstimation extends VitPosePreTrainedModel {} +export class VitPoseForPoseEstimation extends VitPosePreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// -export class ViTMAEPreTrainedModel extends PreTrainedModel {} -export class ViTMAEModel extends ViTMAEPreTrainedModel {} +export class ViTMAEPreTrainedModel extends PreTrainedModel { } +export class ViTMAEModel extends ViTMAEPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// -export class GroupViTPreTrainedModel extends PreTrainedModel {} -export class GroupViTModel extends GroupViTPreTrainedModel {} +export class GroupViTPreTrainedModel extends PreTrainedModel { } +export class GroupViTModel extends GroupViTPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// -export class VitMattePreTrainedModel extends PreTrainedModel {} +export class VitMattePreTrainedModel extends PreTrainedModel { } /** * ViTMatte framework leveraging any vision backbone e.g. for ADE20k, CityScapes. @@ -3380,12 +3419,12 @@ export class VitMatteForImageMatting extends VitMattePreTrainedModel { ////////////////////////////////////////////////// ////////////////////////////////////////////////// -export class Swin2SRPreTrainedModel extends PreTrainedModel {} +export class Swin2SRPreTrainedModel extends PreTrainedModel { } /** * The bare Swin2SR Model transformer outputting raw hidden-states without any specific head on top. */ -export class Swin2SRModel extends Swin2SRPreTrainedModel {} +export class Swin2SRModel extends Swin2SRPreTrainedModel { } /** * Swin2SR Model transformer with an upsampler head on top for image super resolution and restoration. @@ -3419,12 +3458,12 @@ export class Swin2SRModel extends Swin2SRPreTrainedModel {} * // } * ``` */ -export class Swin2SRForImageSuperResolution extends Swin2SRPreTrainedModel {} +export class Swin2SRForImageSuperResolution extends Swin2SRPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// -export class SamPreTrainedModel extends PreTrainedModel {} +export class SamPreTrainedModel extends PreTrainedModel { } /** * Segment Anything Model (SAM) for generating segmentation masks, given an input image @@ -3574,7 +3613,7 @@ export class SamImageSegmentationOutput extends ModelOutput { ////////////////////////////////////////////////// // Wav2Vec2 models -export class Wav2Vec2PreTrainedModel extends PreTrainedModel {} +export class Wav2Vec2PreTrainedModel extends PreTrainedModel { } /** * The bare Wav2Vec2 Model transformer outputting raw hidden-states without any specific head on top. @@ -3602,7 +3641,7 @@ export class Wav2Vec2PreTrainedModel extends PreTrainedModel {} * // } * ``` */ -export class Wav2Vec2Model extends Wav2Vec2PreTrainedModel {} +export class Wav2Vec2Model extends Wav2Vec2PreTrainedModel { } export class Wav2Vec2ForCTC extends Wav2Vec2PreTrainedModel { /** @@ -3619,12 +3658,12 @@ export class Wav2Vec2ForCTC extends Wav2Vec2PreTrainedModel { ////////////////////////////////////////////////// // PyAnnote models -export class PyAnnotePreTrainedModel extends PreTrainedModel {} +export class PyAnnotePreTrainedModel extends PreTrainedModel { } /** * The bare PyAnnote Model transformer outputting raw hidden-states without any specific head on top. 
*/ -export class PyAnnoteModel extends PyAnnotePreTrainedModel {} +export class PyAnnoteModel extends PyAnnotePreTrainedModel { } /** * PyAnnote Model with a frame classification head on top for tasks like Speaker Diarization. @@ -3697,18 +3736,18 @@ export class PyAnnoteForAudioFrameClassification extends PyAnnotePreTrainedModel ////////////////////////////////////////////////// // WeSpeakerResNet models -export class WeSpeakerResNetPreTrainedModel extends PreTrainedModel {} -export class WeSpeakerResNetModel extends WeSpeakerResNetPreTrainedModel {} +export class WeSpeakerResNetPreTrainedModel extends PreTrainedModel { } +export class WeSpeakerResNetModel extends WeSpeakerResNetPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // UniSpeech models -export class UniSpeechPreTrainedModel extends PreTrainedModel {} +export class UniSpeechPreTrainedModel extends PreTrainedModel { } /** * The bare UniSpeech Model transformer outputting raw hidden-states without any specific head on top. */ -export class UniSpeechModel extends UniSpeechPreTrainedModel {} +export class UniSpeechModel extends UniSpeechPreTrainedModel { } /** * UniSpeech Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). @@ -3728,12 +3767,12 @@ export class UniSpeechForCTC extends UniSpeechPreTrainedModel { ////////////////////////////////////////////////// // UniSpeechSat models -export class UniSpeechSatPreTrainedModel extends PreTrainedModel {} +export class UniSpeechSatPreTrainedModel extends PreTrainedModel { } /** * The bare UniSpeechSat Model transformer outputting raw hidden-states without any specific head on top. */ -export class UniSpeechSatModel extends UniSpeechSatPreTrainedModel {} +export class UniSpeechSatModel extends UniSpeechSatPreTrainedModel { } /** * UniSpeechSat Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). @@ -3766,12 +3805,12 @@ export class UniSpeechSatForAudioFrameClassification extends UniSpeechSatPreTrai ////////////////////////////////////////////////// // Wav2Vec2Bert models -export class Wav2Vec2BertPreTrainedModel extends PreTrainedModel {} +export class Wav2Vec2BertPreTrainedModel extends PreTrainedModel { } /** * The bare Wav2Vec2Bert Model transformer outputting raw hidden-states without any specific head on top. */ -export class Wav2Vec2BertModel extends Wav2Vec2BertPreTrainedModel {} +export class Wav2Vec2BertModel extends Wav2Vec2BertPreTrainedModel { } /** * Wav2Vec2Bert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). @@ -3791,7 +3830,7 @@ export class Wav2Vec2BertForCTC extends Wav2Vec2BertPreTrainedModel { ////////////////////////////////////////////////// // Hubert models -export class HubertPreTrainedModel extends PreTrainedModel {} +export class HubertPreTrainedModel extends PreTrainedModel { } /** * The bare Hubert Model transformer outputting raw hidden-states without any specific head on top. @@ -3819,7 +3858,7 @@ export class HubertPreTrainedModel extends PreTrainedModel {} * // } * ``` */ -export class HubertModel extends Wav2Vec2PreTrainedModel {} +export class HubertModel extends Wav2Vec2PreTrainedModel { } /** * Hubert Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). 
@@ -3842,7 +3881,7 @@ export class HubertForCTC extends Wav2Vec2PreTrainedModel { /** * An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. */ -export class WavLMPreTrainedModel extends PreTrainedModel {} +export class WavLMPreTrainedModel extends PreTrainedModel { } /** * The bare WavLM Model transformer outputting raw hidden-states without any specific head on top. @@ -3870,7 +3909,7 @@ export class WavLMPreTrainedModel extends PreTrainedModel {} * // } * ``` */ -export class WavLMModel extends WavLMPreTrainedModel {} +export class WavLMModel extends WavLMPreTrainedModel { } /** * WavLM Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). @@ -3982,12 +4021,12 @@ export class WavLMForAudioFrameClassification extends WavLMPreTrainedModel { /** * An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. */ -export class SpeechT5PreTrainedModel extends PreTrainedModel {} +export class SpeechT5PreTrainedModel extends PreTrainedModel { } /** * The bare SpeechT5 Encoder-Decoder Model outputting raw hidden-states without any specific pre- or post-nets. */ -export class SpeechT5Model extends SpeechT5PreTrainedModel {} +export class SpeechT5Model extends SpeechT5PreTrainedModel { } /** * SpeechT5 Model with a speech encoder and a text decoder. @@ -4029,7 +4068,7 @@ export class SpeechT5Model extends SpeechT5PreTrainedModel {} * // } * ``` */ -export class SpeechT5ForSpeechToText extends SpeechT5PreTrainedModel {} +export class SpeechT5ForSpeechToText extends SpeechT5PreTrainedModel { } /** * SpeechT5 Model with a text encoder and a speech decoder. @@ -4143,12 +4182,12 @@ export class SpeechT5HifiGan extends PreTrainedModel { ////////////////////////////////////////////////// // TrOCR models -export class TrOCRPreTrainedModel extends PreTrainedModel {} +export class TrOCRPreTrainedModel extends PreTrainedModel { } /** * The TrOCR Decoder with a language modeling head. */ -export class TrOCRForCausalLM extends TrOCRPreTrainedModel {} +export class TrOCRForCausalLM extends TrOCRPreTrainedModel { } ////////////////////////////////////////////////// @@ -4157,11 +4196,11 @@ export class TrOCRForCausalLM extends TrOCRPreTrainedModel {} /** * The bare Mistral Model outputting raw hidden-states without any specific head on top. */ -export class MistralPreTrainedModel extends PreTrainedModel {} +export class MistralPreTrainedModel extends PreTrainedModel { } -export class MistralModel extends MistralPreTrainedModel {} +export class MistralModel extends MistralPreTrainedModel { } -export class MistralForCausalLM extends MistralPreTrainedModel {} +export class MistralForCausalLM extends MistralPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// @@ -4169,11 +4208,11 @@ export class MistralForCausalLM extends MistralPreTrainedModel {} /** * The bare Starcoder2 Model outputting raw hidden-states without any specific head on top. 
*/ -export class Starcoder2PreTrainedModel extends PreTrainedModel {} +export class Starcoder2PreTrainedModel extends PreTrainedModel { } -export class Starcoder2Model extends Starcoder2PreTrainedModel {} +export class Starcoder2Model extends Starcoder2PreTrainedModel { } -export class Starcoder2ForCausalLM extends Starcoder2PreTrainedModel {} +export class Starcoder2ForCausalLM extends Starcoder2PreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// @@ -4181,18 +4220,18 @@ export class Starcoder2ForCausalLM extends Starcoder2PreTrainedModel {} /** * The bare Falcon Model outputting raw hidden-states without any specific head on top. */ -export class FalconPreTrainedModel extends PreTrainedModel {} +export class FalconPreTrainedModel extends PreTrainedModel { } -export class FalconModel extends FalconPreTrainedModel {} +export class FalconModel extends FalconPreTrainedModel { } -export class FalconForCausalLM extends FalconPreTrainedModel {} +export class FalconForCausalLM extends FalconPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // CLAP models -export class ClapPreTrainedModel extends PreTrainedModel {} +export class ClapPreTrainedModel extends PreTrainedModel { } -export class ClapModel extends ClapPreTrainedModel {} +export class ClapModel extends ClapPreTrainedModel { } /** * CLAP Text Model with a projection layer on top (a linear layer on top of the pooled output). @@ -4271,7 +4310,7 @@ export class ClapAudioModelWithProjection extends ClapPreTrainedModel { ////////////////////////////////////////////////// // VITS models -export class VitsPreTrainedModel extends PreTrainedModel {} +export class VitsPreTrainedModel extends PreTrainedModel { } /** * The complete VITS model, for text-to-speech synthesis. @@ -4311,32 +4350,32 @@ export class VitsModel extends VitsPreTrainedModel { ////////////////////////////////////////////////// // StableLm models -export class StableLmPreTrainedModel extends PreTrainedModel {} +export class StableLmPreTrainedModel extends PreTrainedModel { } /** * The bare StableLm Model transformer outputting raw hidden-states without any specific head on top. */ -export class StableLmModel extends StableLmPreTrainedModel {} +export class StableLmModel extends StableLmPreTrainedModel { } /** * StableLm Model with a `language modeling` head on top for Causal Language Modeling (with past). */ -export class StableLmForCausalLM extends StableLmPreTrainedModel {} +export class StableLmForCausalLM extends StableLmPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// // Musicgen models -export class MusicgenPreTrainedModel extends PreTrainedModel {} +export class MusicgenPreTrainedModel extends PreTrainedModel { } /** * The bare Musicgen decoder model outputting raw hidden-states without any specific head on top. */ -export class MusicgenModel extends MusicgenPreTrainedModel {} +export class MusicgenModel extends MusicgenPreTrainedModel { } /** * The MusicGen decoder model with a language modelling head on top. 
*/ -export class MusicgenForCausalLM extends MusicgenPreTrainedModel {} +export class MusicgenForCausalLM extends MusicgenPreTrainedModel { } /** * The composite MusicGen model with a text encoder, audio encoder and Musicgen decoder, @@ -4462,17 +4501,17 @@ export class MusicgenForConditionalGeneration extends PreTrainedModel { ////////////////////////////////////////////////// // Decision Transformer models -export class DecisionTransformerPreTrainedModel extends PreTrainedModel {} +export class DecisionTransformerPreTrainedModel extends PreTrainedModel { } /** * The model builds upon the GPT2 architecture to perform autoregressive prediction of actions in an offline RL setting. * Refer to the paper for more details: https://arxiv.org/abs/2106.01345 */ -export class DecisionTransformerModel extends DecisionTransformerPreTrainedModel {} +export class DecisionTransformerModel extends DecisionTransformerPreTrainedModel { } ////////////////////////////////////////////////// -export class MultiModalityPreTrainedModel extends PreTrainedModel {} +export class MultiModalityPreTrainedModel extends PreTrainedModel { } export class MultiModalityCausalLM extends MultiModalityPreTrainedModel { forward_params = [ // prepare_inputs_embeds @@ -4598,7 +4637,7 @@ export class MgpstrModelOutput extends ModelOutput { } } -export class MgpstrPreTrainedModel extends PreTrainedModel {} +export class MgpstrPreTrainedModel extends PreTrainedModel { } /** * MGP-STR Model transformer with three classification heads on top @@ -4615,17 +4654,17 @@ export class MgpstrForSceneTextRecognition extends MgpstrPreTrainedModel { ////////////////////////////////////////////////// // PatchTST Transformer models -export class PatchTSTPreTrainedModel extends PreTrainedModel {} +export class PatchTSTPreTrainedModel extends PreTrainedModel { } /** * The bare PatchTST Model outputting raw hidden-states without any specific head. */ -export class PatchTSTModel extends PatchTSTPreTrainedModel {} +export class PatchTSTModel extends PatchTSTPreTrainedModel { } /** * The PatchTST for prediction model. */ -export class PatchTSTForPrediction extends PatchTSTPreTrainedModel {} +export class PatchTSTForPrediction extends PatchTSTPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// @@ -4636,17 +4675,17 @@ export class StyleTextToSpeech2Model extends StyleTextToSpeech2PreTrainedModel { ////////////////////////////////////////////////// // PatchTSMixer Transformer models -export class PatchTSMixerPreTrainedModel extends PreTrainedModel {} +export class PatchTSMixerPreTrainedModel extends PreTrainedModel { } /** * The bare PatchTSMixer Model outputting raw hidden-states without any specific head. */ -export class PatchTSMixerModel extends PatchTSMixerPreTrainedModel {} +export class PatchTSMixerModel extends PatchTSMixerPreTrainedModel { } /** * The PatchTSMixer for prediction model. */ -export class PatchTSMixerForPrediction extends PatchTSMixerPreTrainedModel {} +export class PatchTSMixerForPrediction extends PatchTSMixerPreTrainedModel { } ////////////////////////////////////////////////// ////////////////////////////////////////////////// @@ -4658,18 +4697,18 @@ export class PatchTSMixerForPrediction extends PatchTSMixerPreTrainedModel {} * which is used to instantiate pretrained models. 
*/ interface ModelOptions { - config?: PretrainedConfig | null; - cache_dir?: string | null; - local_files_only?: boolean; - revision?: string; - model_file_name?: string | null; - subfolder?: string; - device?: string | null; - dtype?: string | null; - use_external_data_format?: boolean | null; - session_options?: any; - progress_callback?: ProgressCallback | null; - } + config?: PretrainedConfig | null; + cache_dir?: string | null; + local_files_only?: boolean; + revision?: string; + model_file_name?: string | null; + subfolder?: string; + device?: string | null; + dtype?: string | null; + use_external_data_format?: boolean | null; + session_options?: any; + progress_callback?: ProgressCallback | null; +} export class PretrainedMixin { /** * Mapping from model type to model class. @@ -4859,7 +4898,7 @@ const MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES = new Map([ ['llava_onevision', ['LlavaOnevisionForConditionalGeneration', LlavaOnevisionForConditionalGeneration]], ['moondream1', ['Moondream1ForConditionalGeneration', Moondream1ForConditionalGeneration]], ['florence2', ['Florence2ForConditionalGeneration', Florence2ForConditionalGeneration]], -// ['qwen2-vl', ['Qwen2VLForConditionalGeneration', Qwen2VLForConditionalGeneration]], + // ['qwen2-vl', ['Qwen2VLForConditionalGeneration', Qwen2VLForConditionalGeneration]], ['idefics3', ['Idefics3ForConditionalGeneration', Idefics3ForConditionalGeneration]], ['paligemma', ['PaliGemmaForConditionalGeneration', PaliGemmaForConditionalGeneration]], ]); @@ -5104,7 +5143,7 @@ export class SequenceClassifierOutput extends ModelOutput { */ logits: Tensor; attentions: Record; - constructor({ logits, ...attentions }: { logits: Tensor; [key: string]: Tensor }) { + constructor({ logits, ...attentions }: { logits: Tensor;[key: string]: Tensor }) { super(); this.logits = logits; this.attentions = attentions; diff --git a/src/libs/transformers/models/auto/image_processing_auto.ts b/src/libs/transformers/models/auto/image_processing_auto.ts index a7c11df..98f8ec1 100644 --- a/src/libs/transformers/models/auto/image_processing_auto.ts +++ b/src/libs/transformers/models/auto/image_processing_auto.ts @@ -1,6 +1,13 @@ import { GITHUB_ISSUE_URL, IMAGE_PROCESSOR_NAME } from '../../utils/constants'; import { getModelJSON } from '../../utils/hub'; import { ImageProcessor } from '../../base/image_processors_utils'; +import { VLMImageProcessor } from '../janus/image_processing_janus'; + +// Map of processor types to their implementations +const PROCESSOR_MAPPING = { + 'VLMImageProcessor': VLMImageProcessor, + 'ImageProcessor': ImageProcessor, +}; export class AutoImageProcessor { /** @type {typeof ImageProcessor.from_pretrained} */ @@ -9,15 +16,15 @@ export class AutoImageProcessor { // Determine image processor class const key = preprocessorConfig.image_processor_type ?? preprocessorConfig.feature_extractor_type; - let image_processor_class: typeof ImageProcessor; + let image_processor_class = ImageProcessor; - if (key !== undefined) { - // Only log a warning if the class is not found and the key is set. + if (key && key in PROCESSOR_MAPPING) { + image_processor_class = PROCESSOR_MAPPING[key as keyof typeof PROCESSOR_MAPPING]; + } else if (key !== undefined) { console.warn( - `Image processor type '${key}' not found, assuming base ImageProcessor. Please report this at ${GITHUB_ISSUE_URL}.`, + `Image processor type '${key}' not found, assuming base ImageProcessor. 
Please report this at ${GITHUB_ISSUE_URL}.` ); } - image_processor_class = ImageProcessor; // Instantiate image processor return new image_processor_class(preprocessorConfig); diff --git a/src/libs/transformers/models/auto/processing_auto.ts b/src/libs/transformers/models/auto/processing_auto.ts index c7f6fcf..efac49a 100644 --- a/src/libs/transformers/models/auto/processing_auto.ts +++ b/src/libs/transformers/models/auto/processing_auto.ts @@ -36,9 +36,9 @@ import * as AllFeatureExtractors from '../feature_extractors.js'; * // } * ``` */ -export class AutoProcessor { +export class AutoProcessor extends Processor { /** @type {typeof Processor.from_pretrained} */ - static async from_pretrained(pretrained_model_name_or_path: string, options = {}) { + static async from_pretrained(pretrained_model_name_or_path: string, options = {}): Promise { // TODO: first check for processor.json const preprocessorConfig = await getModelJSON(pretrained_model_name_or_path, IMAGE_PROCESSOR_NAME, true, options); diff --git a/src/libs/transformers/models/processors.ts b/src/libs/transformers/models/processors.ts index 4a653e2..96b104f 100644 --- a/src/libs/transformers/models/processors.ts +++ b/src/libs/transformers/models/processors.ts @@ -1,7 +1,7 @@ export * from './florence2/processing_florence2'; +export * from './idefics3/processing_idefics3'; export * from './mgp_str/processing_mgp_str'; export * from './moonshine/processing_moonshine'; -export * from './idefics3/processing_idefics3'; export * from './janus/processing_janus'; export * from './jina_clip/processing_jina_clip'; export * from './phi3_v/processing_phi3_v'; diff --git a/src/libs/transformers/ops/registry.ts b/src/libs/transformers/ops/registry.ts index c693554..fa039cb 100644 --- a/src/libs/transformers/ops/registry.ts +++ b/src/libs/transformers/ops/registry.ts @@ -18,12 +18,16 @@ const wrap = async ( names: string | string[], ) => { const session = await createInferenceSession(new Uint8Array(session_bytes), session_options, {}); + + let chain = Promise.resolve(); + return /** @type {any} */ async (/** @type {Record} */ inputs: Record) => { const proxied = isONNXProxy(); const ortFeed = Object.fromEntries( Object.entries(inputs).map(([k, v]) => [k, (proxied ? 
v.clone() : v).ort_tensor]), ); - const outputs = await session.run(ortFeed as any); + // When running in-browser via WASM, we need to chain calls to session.run to avoid "Error: Session already started" + const outputs = await (chain.then(() => session.run(ortFeed as any))); if (Array.isArray(names)) { return names.map((n) => new Tensor(outputs[n])); diff --git a/src/libs/transformers/tokenizers.ts b/src/libs/transformers/tokenizers.ts index 9e5e068..57fbd4f 100644 --- a/src/libs/transformers/tokenizers.ts +++ b/src/libs/transformers/tokenizers.ts @@ -1,6 +1,6 @@ import { Callable } from './utils/generic'; -import { reverseDictionary, escapeRegExp, isIntegralNumber, mergeArrays, len } from './utils/core'; +import { reverseDictionary, escapeRegExp, isIntegralNumber, mergeArrays, len, ProgressCallback } from './utils/core'; import { getModelJSON } from './utils/hub'; @@ -12,6 +12,7 @@ import { PriorityQueue, TokenLattice, CharTrie } from './utils/data-structures'; import { Template } from '@huggingface/jinja'; import { PretrainedOptions } from './utils/hub'; import { WHISPER_LANGUAGE_MAPPING } from './models/whisper/common_whisper'; +import { PretrainedConfig } from './configs'; type TokenizerProperties = { legacy?: boolean; @@ -4456,6 +4457,13 @@ export class AutoTokenizer { local_files_only = false, revision = 'main', legacy = undefined, + }: { + progress_callback?: null | ProgressCallback; + config?: null | PretrainedConfig; + cache_dir?: null | string; + local_files_only?: boolean; + revision?: string; + legacy?: boolean; } = {}, ) { const [tokenizerJSON, tokenizerConfig] = await loadTokenizer(pretrained_model_name_or_path, { diff --git a/src/libs/transformers/utils/hub.ts b/src/libs/transformers/utils/hub.ts index 42b5721..6fee2d6 100755 --- a/src/libs/transformers/utils/hub.ts +++ b/src/libs/transformers/utils/hub.ts @@ -9,7 +9,7 @@ import path from 'path'; import { env } from '../env.js'; import { dispatchCallback, ProgressCallback } from './core'; -import { PretrainedConfig } from '../configs.js'; +import { PretrainedConfig } from '../configs'; export interface PretrainedOptions { progress_callback?: null | ProgressCallback; config?: null | PretrainedConfig; diff --git a/src/libs/transformers/utils/tensor.ts b/src/libs/transformers/utils/tensor.ts index 9844b1f..7964fbd 100644 --- a/src/libs/transformers/utils/tensor.ts +++ b/src/libs/transformers/utils/tensor.ts @@ -97,12 +97,28 @@ export class Tensor { if (isONNXTensor(args[0])) { this.ort_tensor = /** @type {ONNXTensor} */ args[0]; } else { - // Create new tensor - this.ort_tensor = new ONNXTensor( - args[0] as keyof typeof DataTypeMap, - args[1] as Exclude, - args[2], - ); + // Add debugging + // console.log('Creating new tensor:', { + // type: args[0], + // dataLength: (args[1] as DataArray).length, + // dims: args[2] + // }); + + try { + this.ort_tensor = new ONNXTensor( + args[0] as keyof typeof DataTypeMap, + args[1] as Exclude, + args[2], + ); + } catch (error) { + console.error('Failed to create ONNXTensor:', { + error, + type: args[0], + dataLength: (args[1] as DataArray).length, + dims: args[2] + }); + throw error; + } } return new Proxy(this, { @@ -571,10 +587,10 @@ export class Tensor { * NOTE: The returned tensor shares the storage with the input tensor, so changing the contents of one will change the contents of the other. * If you would like a copy, use `tensor.clone()` before squeezing. * - * @param {number} [dim=null] If given, the input will be squeezed only in the specified dimensions. 
+ * @param {number|number[]} [dim=null] If given, the input will be squeezed only in the specified dimensions. * @returns {Tensor} The squeezed tensor */ - squeeze(dim = null) { + squeeze(dim: number | number[] | null = null) { return new Tensor(this.type, this.data, calc_squeeze_dims(this.dims, dim)); } @@ -750,7 +766,7 @@ export class Tensor { const val = min(this.data as any)[0]; return new Tensor(this.type, [val], [/* scalar */]); } - const [type, result, resultDims] = reduce_helper((a, b) => Math.min(a, b), this, dim, keepdim, Infinity); + const [type, result, resultDims] = reduce_helper((a: any, b: any) => Math.min(a, b), this, dim, keepdim, Infinity); return new Tensor(type, result, resultDims); } max(dim = null, keepdim = false) { @@ -759,7 +775,7 @@ export class Tensor { const val = max(this.data as any)[0]; return new Tensor(this.type, [val], [/* scalar */]); } - const [type, result, resultDims] = reduce_helper((a, b) => Math.max(a, b), this, dim, keepdim, -Infinity); + const [type, result, resultDims] = reduce_helper((a: any, b: any) => Math.max(a, b), this, dim, keepdim, -Infinity); return new Tensor(type, result, resultDims); } @@ -1258,7 +1274,7 @@ export function stack(tensors: Tensor[], dim: number = 0) { * @param {boolean} keepdim whether the output tensor has dim retained or not. * @returns {[DataType, any, number[]]} The reduced tensor data. */ -function reduce_helper(callbackfn: (previousValue: number, currentValue: number, currentIndex?: number| null, resultIndex?: number| null) => number, input: Tensor, dim: number | null = null, keepdim = false, initialValue: number | null = null) { +function reduce_helper(callbackfn: any, input: Tensor, dim: number | number[] | null = null, keepdim = false, initialValue: number | null = null) { const inputData = input.data; const inputDims = input.dims; @@ -1329,7 +1345,7 @@ export function std_mean(input: Tensor, dim: number | null = null, correction = const meanTensorData = meanTensor.data; // Compute squared sum - const [type, result, resultDims] = reduce_helper((a: number, b: number, i: any, j: any) => a + (b - meanTensorData[j]) ** 2, input, dim, keepdim); + const [type, result, resultDims] = reduce_helper((a: any, b: any, i: any, j: any) => a + (b - meanTensorData[j]) ** 2, input, dim, keepdim); // Square root of the squared sum for (let i = 0; i < result.length; ++i) { @@ -1369,7 +1385,7 @@ export function mean(input: Tensor, dim: number | null = null, keepdim: boolean // Negative indexing dim = safeIndex(dim, inputDims.length); - const [type, result, resultDims] = reduce_helper((a: number, b: number) => a + b, input, dim, keepdim); + const [type, result, resultDims] = reduce_helper((a: any, b: any) => a + b, input, dim, keepdim); return new Tensor(type, result, resultDims); } From 18d3709648bc25b409871dc7cbd6c00016b9852d Mon Sep 17 00:00:00 2001 From: Saurav Panda Date: Sat, 8 Feb 2025 13:46:50 -0800 Subject: [PATCH 2/2] new version with tts updates --- examples/tts-demo/package.json | 2 +- package.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/tts-demo/package.json b/examples/tts-demo/package.json index 87f5fc0..42a3eb5 100644 --- a/examples/tts-demo/package.json +++ b/examples/tts-demo/package.json @@ -10,7 +10,7 @@ "preview": "vite preview" }, "dependencies": { - "@browserai/browserai": "^1.0.18", + "@browserai/browserai": "^1.0.24", "@emotion/styled": "^11.14.0", "react": "^18.3.1", "react-dom": "^18.3.1" diff --git a/package.json b/package.json index 9453a9c..a8697d4 100644 --- 
a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@browserai/browserai", - "version": "1.0.23", + "version": "1.0.24", "private": false, "description": "A library for running AI models directly in the browser", "main": "dist/index.js",
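
For illustration, here is a minimal standalone sketch of the promise-chaining idea behind the ops/registry.ts change above: on the WASM backend, ONNX Runtime Web raises "Error: Session already started" if session.run is invoked again while a previous call on the same session is still in flight, so runs have to be queued. The makeSerializedRunner helper and the InferenceSession shape below are assumptions made for this sketch, not code from the patch; note that the chain has to be reassigned on every call so that later callers actually wait on earlier ones.

    // Sketch: queue run() calls on a single ONNX Runtime Web session (WASM backend).
    type Feeds = Record<string, unknown>;

    interface InferenceSession {
      run(feeds: Feeds): Promise<Record<string, unknown>>;
    }

    function makeSerializedRunner(session: InferenceSession) {
      // Tail of the shared call chain; every call appends itself here.
      let chain: Promise<unknown> = Promise.resolve();

      return (feeds: Feeds) => {
        // Start this run only after the previous one has settled.
        const result = chain.then(() => session.run(feeds));
        // Reassign the chain so the next caller waits on this run;
        // swallow errors so a failed run does not poison the queue.
        chain = result.catch(() => {});
        return result;
      };
    }

With this shape, two overlapping calls on the same session resolve in submission order instead of racing and tripping the "already started" error.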
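
Similarly, a small usage sketch of the widened Tensor.squeeze signature from utils/tensor.ts; the import path and the example shape are assumptions for illustration, and only the number | number[] | null dim argument comes from the patch.

    import { Tensor } from './src/libs/transformers/utils/tensor';

    // A [1, 3, 1, 2] float tensor; 1 * 3 * 1 * 2 = 6 elements.
    const t = new Tensor('float32', new Float32Array(6), [1, 3, 1, 2]);

    t.squeeze();       // remove every size-1 dimension         -> dims [3, 2]
    t.squeeze(0);      // remove only dimension 0               -> dims [3, 1, 2]
    t.squeeze([0, 2]); // remove dimensions 0 and 2 in one call -> dims [3, 2]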