diff --git a/Cargo.toml b/Cargo.toml index 8d0becb9..52086fe2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "crates/ratchet-core", + "crates/ratchet-downloader", "crates/ratchet-integration-tests", "crates/ratchet-loader", "crates/ratchet-models", @@ -28,6 +29,7 @@ derive-new = "0.6.0" log = "0.4.20" thiserror = "1.0.56" byteorder = "1.5.0" +wasm-bindgen-test = "0.3.34" [workspace.dev-dependencies] hf-hub = "0.3.0" diff --git a/crates/ratchet-core/Cargo.toml b/crates/ratchet-core/Cargo.toml index 97b1619b..cf0b67ff 100644 --- a/crates/ratchet-core/Cargo.toml +++ b/crates/ratchet-core/Cargo.toml @@ -31,6 +31,7 @@ glam = "0.25.0" pollster = "0.3.0" futures-intrusive = "0.5.0" anyhow = "1.0.79" +getrandom = { version = "0.2", features = ["js"] } # Needed for wasm support in `num` trait num = "0.4.1" rand_distr = { version = "0.4.3", optional = true } rand = { version = "0.8.4", optional = true } diff --git a/crates/ratchet-core/src/gpu/device.rs b/crates/ratchet-core/src/gpu/device.rs index fefe081e..0531e666 100644 --- a/crates/ratchet-core/src/gpu/device.rs +++ b/crates/ratchet-core/src/gpu/device.rs @@ -51,7 +51,7 @@ impl PartialEq for WgpuDevice { impl WgpuDevice { pub async fn new() -> Result { #[cfg(target_arch = "wasm32")] - let adapter = Self::select_adapter().await; + let adapter = Self::select_adapter().await?; #[cfg(not(target_arch = "wasm32"))] let adapter = Self::select_adapter()?; @@ -106,7 +106,7 @@ impl WgpuDevice { } #[cfg(target_arch = "wasm32")] - async fn select_adapter() -> Adapter { + async fn select_adapter() -> Result { let instance = wgpu::Instance::default(); let backends = wgpu::util::backend_bits_from_env().unwrap_or(wgpu::Backends::PRIMARY); instance @@ -116,10 +116,10 @@ impl WgpuDevice { force_fallback_adapter: false, }) .await - .map_err(|e| { - log::error!("Failed to create device: {:?}", e); - e - })? 
+ .ok_or({ + log::error!("Failed to request adapter."); + DeviceError::AdapterRequestFailed + }) } #[cfg(not(target_arch = "wasm32"))] diff --git a/crates/ratchet-core/src/kernels.rs b/crates/ratchet-core/src/kernels.rs index 92408394..45908dfc 100644 --- a/crates/ratchet-core/src/kernels.rs +++ b/crates/ratchet-core/src/kernels.rs @@ -7,19 +7,19 @@ lazy_static! { m.insert( "qgemm_vec4", include_str!( - "/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/qgemm_vec4.wgsl" + "/Users/janschulte/code/ratchet/crates/ratchet-core/kernels/qgemm_vec4.wgsl" ), ); m.insert( "sgemm_scalar", include_str!( - "/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/sgemm_scalar.wgsl" + "/Users/janschulte/code/ratchet/crates/ratchet-core/kernels/sgemm_scalar.wgsl" ), ); m.insert( "add_scalar", include_str!( - "/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/add_scalar.wgsl" + "/Users/janschulte/code/ratchet/crates/ratchet-core/kernels/add_scalar.wgsl" ), ); m diff --git a/crates/ratchet-downloader/Cargo.toml b/crates/ratchet-downloader/Cargo.toml new file mode 100644 index 00000000..33fd6bcb --- /dev/null +++ b/crates/ratchet-downloader/Cargo.toml @@ -0,0 +1,41 @@ +[package] +name = "ratchet-downloader" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[dependencies] +ratchet-loader = { path = "../ratchet-loader" } +wasm-bindgen = "0.2.84" +wasm-bindgen-futures = "0.4.39" +js-sys = "0.3.64" +gloo = "0.11.0" +wasm-streams = "0.4.0" +futures-util = { version = "^0.3.28", features = ["io", "sink"] } +winnow = "0.5.34" +circular = "0.3.0" +anyhow.workspace = true + +[dependencies.web-sys] +features = [ + 'console', + 'Headers', + 'Request', + 'RequestInit', + 'RequestMode', + 'Response', + 'ReadableStream', + 'ReadableStreamGetReaderOptions', + 'ReadableStreamReaderMode', + 'Window', + 'Navigator', + 'StorageManager', + 'CacheStorage' +] +version = "0.3.64" + +[dev-dependencies] 
+wasm-bindgen-test.workspace = true + +[lib] +crate-type = ["cdylib", "rlib"] diff --git a/crates/ratchet-downloader/src/fetch.rs b/crates/ratchet-downloader/src/fetch.rs new file mode 100644 index 00000000..bfa51da4 --- /dev/null +++ b/crates/ratchet-downloader/src/fetch.rs @@ -0,0 +1,35 @@ +use js_sys::{ArrayBuffer, Uint8Array, JSON}; + +use wasm_bindgen::{prelude::*, JsValue}; +use wasm_bindgen_futures::JsFuture; +use web_sys::{Request, RequestInit, RequestMode, Response}; + +fn to_error(value: JsValue) -> JsError { + JsError::new( + JSON::stringify(&value) + .map(|js_string| { + js_string + .as_string() + .unwrap_or(String::from("An unknown error occurred.")) + }) + .unwrap_or(String::from("An unknown error occurred.")) + .as_str(), + ) +} +pub(crate) async fn fetch(url: &str) -> Result { + let mut opts = RequestInit::new(); + opts.method("GET"); + opts.mode(RequestMode::Cors); + + let request = Request::new_with_str_and_init(&url, &opts).map_err(to_error)?; + + let window = web_sys::window().unwrap(); + let resp_value = JsFuture::from(window.fetch_with_request(&request)) + .await + .map_err(to_error)?; + + assert!(resp_value.is_instance_of::()); + let resp: Response = resp_value.dyn_into().unwrap(); + + Ok(resp) +} diff --git a/crates/ratchet-downloader/src/huggingface/mod.rs b/crates/ratchet-downloader/src/huggingface/mod.rs new file mode 100644 index 00000000..c426b23e --- /dev/null +++ b/crates/ratchet-downloader/src/huggingface/mod.rs @@ -0,0 +1 @@ +pub mod repo; diff --git a/crates/ratchet-downloader/src/huggingface/repo.rs b/crates/ratchet-downloader/src/huggingface/repo.rs new file mode 100644 index 00000000..70a8c16c --- /dev/null +++ b/crates/ratchet-downloader/src/huggingface/repo.rs @@ -0,0 +1,7 @@ +pub struct Repo { + pub id: String, + pub revision: String, + pub repo_type: String, +} + +impl Repo {} diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs new file mode 100644 index 00000000..477856e1 --- /dev/null 
+++ b/crates/ratchet-downloader/src/lib.rs @@ -0,0 +1,560 @@ +use js_sys::Uint8Array; +#[cfg(test)] +use wasm_bindgen_test::{wasm_bindgen_test, wasm_bindgen_test_configure}; + +use futures_util::{AsyncReadExt, StreamExt}; +use gloo::console::{debug, error as log_error}; +use js_sys::JsString; +use wasm_bindgen::{prelude::*, JsCast, JsValue}; +use wasm_bindgen_futures::JsFuture; +use wasm_streams::ReadableStream; +use web_sys::{console, ReadableStreamGetReaderOptions, ReadableStreamReaderMode}; +use winnow::{binary::bits::bytes, prelude::*, stream::Stream, Bytes, Partial}; +use winnow::{binary::u32, binary::u64, combinator::preceded, Parser}; + +mod fetch; +pub mod huggingface; + +#[cfg(test)] +wasm_bindgen_test_configure!(run_in_browser); + +#[wasm_bindgen] +pub fn js_error(message: String) -> JsError { + JsError::new(message.as_str()) +} + +type BytesStream<'i> = Partial<&'i Bytes>; + +pub struct Model { + url: String, +} + +impl Model { + fn from_hf(repo_id: String) -> Self { + Self { + url: format!("https://huggingface.co/{}/resolve/main", repo_id), + } + } + + fn from_hf_with_revision(repo_id: String, revision: String) -> Self { + Self { + url: format!("https://huggingface.co/{repo_id}/resolve/{revision}"), + } + } + + fn from_custom(url: String) -> Self { + Self { url } + } + + async fn open_stream(&self, file_name: String) -> Result<(), JsError> { + let file_url = format!("{}/{}", self.url, file_name); + let response = fetch::fetch(file_url.as_str()).await?; + + let raw_body = response + .body() + .ok_or(js_error(format!("Failed to load {}", file_name)))?; + + let mut body = ReadableStream::from_raw(raw_body); + let reader = body.get_byob_reader(); + let mut async_read = reader.into_async_read(); + + let mut buf = [0u8; 100]; + let result = async_read.read_exact(&mut buf).await?; + + let mut test = BytesStream::new(Bytes::new(&buf)); + + let g1 = &test.next_token(); + let g2 = &test.next_token(); + let u = &test.next_token(); + let f = &test.next_token(); + 
debug!("Done!:", format!("{:?}{:?}{:?}{:?}", g1, g2, u, f)); + + Ok(()) + } +} + +mod gguf { + use std::io::Seek; + + use crate::BytesStream; + use winnow::binary::{u32, u64, u8, Endianness}; + + use anyhow::anyhow; + use winnow::combinator::fail; + use winnow::error::Needed; + use winnow::error::{AddContext, ContextError, ErrMode, StrContext}; + use winnow::prelude; + use winnow::stream::{Offset, Stream}; + use winnow::token::take; + use winnow::Parser; + + #[derive(Clone, Debug)] + pub struct MetadataKv { + pub key: String, + pub metadata_value: MetadataValue, + } + #[derive(Clone, Debug)] + pub struct Header { + pub version: u32, + pub tensor_count: u64, + pub metadata_kv: Vec, + } + + #[derive(Clone, Debug)] + pub struct TensorInfo { + pub name: String, + pub dimensions: Vec, + pub ggml_type: GgmlType, + pub offset: u64, + } + + #[derive(Clone, Debug)] + pub enum MetadataValueType { + GgufMetadataValueTypeUint8, + GgufMetadataValueTypeInt8, + GgufMetadataValueTypeUint16, + GgufMetadataValueTypeInt16, + GgufMetadataValueTypeUint32, + GgufMetadataValueTypeInt32, + GgufMetadataValueTypeFloat32, + GgufMetadataValueTypeBool, + GgufMetadataValueTypeString, + GgufMetadataValueTypeArray, + GgufMetadataValueTypeUint64, + GgufMetadataValueTypeInt64, + GgufMetadataValueTypeFloat64, + } + + #[derive(Clone, Debug)] + pub enum MetadataValue { + GgufMetadataValueUint8(u8), + GgufMetadataValueInt8(i8), + GgufMetadataValueUint16(u16), + GgufMetadataValueInt16(i16), + GgufMetadataValueUint32(u32), + GgufMetadataValueInt32(i32), + GgufMetadataValueFloat32(f32), + GgufMetadataValueBool(bool), + GgufMetadataValueString(String), + GgufMetadataValueArray(Vec), + GgufMetadataValueUint64(u64), + GgufMetadataValueInt64(i64), + GgufMetadataValueFloat64(f64), + } + + #[derive(Clone, Debug)] + pub enum GgmlType { + GgmlTypeF32, + GgmlTypeF16, + GgmlTypeQ4_0, + GgmlTypeQ4_1, + GgmlTypeQ5_0, + GgmlTypeQ5_1, + GgmlTypeQ8_0, + GgmlTypeQ8_1, + // k-quantizations + GgmlTypeQ2K, + GgmlTypeQ3K, + 
GgmlTypeQ4K, + GgmlTypeQ5K, + GgmlTypeQ6K, + GgmlTypeQ8K, + GgmlTypeI8, + GgmlTypeI16, + GgmlTypeI32, + GgmlTypeCount, + } + + #[inline] + fn parse_magic_number(input: &mut BytesStream) -> winnow::PResult<()> { + // [TODO] Fix endianness + (71, 71, 85, 70).parse_next(input).map(|_magic_number| ()) + } + + #[inline] + fn parse_version(input: &mut BytesStream) -> winnow::PResult { + u32(Endianness::Little).parse_next(input) + } + + #[inline] + fn parse_tensor_count(input: &mut BytesStream) -> winnow::PResult { + u64(Endianness::Little).parse_next(input) + } + + fn parse_metadata_value_array(input: &mut BytesStream) -> winnow::PResult { + (parse_metadata_value_type, u64(Endianness::Little)) + .flat_map(|(metadata_value_type, length)| { + winnow::combinator::repeat( + length as usize, + parse_metadata_value(metadata_value_type), + ) + }) + .parse_next(input) + .map(MetadataValue::GgufMetadataValueArray) + } + + fn parse_metadata_value_type(input: &mut BytesStream) -> winnow::PResult { + u32(Endianness::Little) + .parse_next(input) + .and_then(|metadata_value_type| match metadata_value_type { + 0 => Ok(MetadataValueType::GgufMetadataValueTypeUint8), + 1 => Ok(MetadataValueType::GgufMetadataValueTypeInt8), + 2 => Ok(MetadataValueType::GgufMetadataValueTypeUint16), + 3 => Ok(MetadataValueType::GgufMetadataValueTypeInt16), + 4 => Ok(MetadataValueType::GgufMetadataValueTypeUint32), + 5 => Ok(MetadataValueType::GgufMetadataValueTypeInt32), + 6 => Ok(MetadataValueType::GgufMetadataValueTypeFloat32), + 7 => Ok(MetadataValueType::GgufMetadataValueTypeBool), + 8 => Ok(MetadataValueType::GgufMetadataValueTypeString), + 9 => Ok(MetadataValueType::GgufMetadataValueTypeArray), + 10 => Ok(MetadataValueType::GgufMetadataValueTypeUint64), + 11 => Ok(MetadataValueType::GgufMetadataValueTypeInt64), + 12 => Ok(MetadataValueType::GgufMetadataValueTypeFloat64), + other => Err(cut_error(input, "Unknown metadata value type.")), + }) + } + + #[inline] + fn parse_metadata_value<'i>( + 
metadata_value_type: MetadataValueType, + ) -> impl Parser, MetadataValue, ContextError> { + move |input: &mut BytesStream| match metadata_value_type { + MetadataValueType::GgufMetadataValueTypeUint8 => winnow::binary::u8 + .map(MetadataValue::GgufMetadataValueUint8) + .parse_next(input), + + MetadataValueType::GgufMetadataValueTypeInt8 => winnow::binary::i8 + .map(MetadataValue::GgufMetadataValueInt8) + .parse_next(input), + MetadataValueType::GgufMetadataValueTypeUint16 => { + winnow::binary::u16(Endianness::Little) + .map(MetadataValue::GgufMetadataValueUint16) + .parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeInt16 => { + winnow::binary::i16(Endianness::Little) + .map(MetadataValue::GgufMetadataValueInt16) + .parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeUint32 => { + winnow::binary::u32(Endianness::Little) + .map(MetadataValue::GgufMetadataValueUint32) + .parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeInt32 => { + winnow::binary::i32(Endianness::Little) + .map(MetadataValue::GgufMetadataValueInt32) + .parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeFloat32 => { + winnow::binary::f32(Endianness::Little) + .map(MetadataValue::GgufMetadataValueFloat32) + .parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeBool => winnow::binary::i8 + .map(|b| { + if b == 0 { + MetadataValue::GgufMetadataValueBool(true) + } else { + MetadataValue::GgufMetadataValueBool(false) + } + }) + .parse_next(input), + MetadataValueType::GgufMetadataValueTypeString => parse_string + .map(MetadataValue::GgufMetadataValueString) + .parse_next(input), + MetadataValueType::GgufMetadataValueTypeArray => { + parse_metadata_value_array.parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeUint64 => { + winnow::binary::u64(Endianness::Little) + .map(MetadataValue::GgufMetadataValueUint64) + .parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeInt64 => { + winnow::binary::i64(Endianness::Little) + 
.map(MetadataValue::GgufMetadataValueInt64) + .parse_next(input) + } + MetadataValueType::GgufMetadataValueTypeFloat64 => { + winnow::binary::f64(Endianness::Little) + .map(MetadataValue::GgufMetadataValueFloat64) + .parse_next(input) + } + } + } + + fn parse_metadata_value_single(input: &mut BytesStream) -> winnow::PResult { + parse_metadata_value_type + .flat_map(|metadata_value_type| parse_metadata_value(metadata_value_type)) + .parse_next(input) + } + + #[inline] + fn parse_metadata_kv_count(input: &mut BytesStream) -> winnow::PResult { + u64(Endianness::Little).parse_next(input) + } + + fn parse_string(input: &mut BytesStream) -> winnow::PResult { + u64(Endianness::Little) + .flat_map(|count| take(count)) + .parse_next(input) + .and_then(|bytes| { + String::from_utf8(bytes.to_vec()).map_err(|err| { + let error_msg = "Failed to parse string"; + cut_error(input, error_msg) + }) + }) + } + + fn parse_ggml_type(input: &mut BytesStream) -> winnow::PResult { + u32(Endianness::Little) + .parse_next(input) + .and_then(|metadata_value_type| match metadata_value_type { + 0 => Ok(GgmlType::GgmlTypeF32), + 1 => Ok(GgmlType::GgmlTypeF16), + 2 => Ok(GgmlType::GgmlTypeQ4_0), + 3 => Ok(GgmlType::GgmlTypeQ4_1), + // 4 & 5 have been removed + 6 => Ok(GgmlType::GgmlTypeQ5_0), + 7 => Ok(GgmlType::GgmlTypeQ5_1), + 8 => Ok(GgmlType::GgmlTypeQ8_0), + 9 => Ok(GgmlType::GgmlTypeQ8_1), + // k-quantizations + 10 => Ok(GgmlType::GgmlTypeQ2K), + 11 => Ok(GgmlType::GgmlTypeQ3K), + 12 => Ok(GgmlType::GgmlTypeQ4K), + 13 => Ok(GgmlType::GgmlTypeQ5K), + 14 => Ok(GgmlType::GgmlTypeQ6K), + 15 => Ok(GgmlType::GgmlTypeQ8K), + 16 => Ok(GgmlType::GgmlTypeI8), + 17 => Ok(GgmlType::GgmlTypeI16), + 18 => Ok(GgmlType::GgmlTypeI32), + 19 => Ok(GgmlType::GgmlTypeCount), + other => Err(cut_error(input, "Unknown metadata value type.")), + }) + } + + fn parse_tensor_info(input: &mut BytesStream) -> winnow::PResult { + (parse_string, u32(Endianness::Little)) + .flat_map(|(name, n_dimensions)| { + let 
dimensions_parser = + winnow::combinator::repeat(n_dimensions as usize, u64(Endianness::Little)); + + (dimensions_parser, parse_ggml_type, u64(Endianness::Little)).map( + move |(dimensions, ggml_type, offset)| TensorInfo { + name: name.clone(), + dimensions, + ggml_type, + offset, + }, + ) + }) + .parse_next(input) + } + + fn cut_error( + input: &mut winnow::Partial<&winnow::Bytes>, + error_msg: &'static str, + ) -> ErrMode { + println!("Error: {}", error_msg); + ErrMode::Cut(ContextError::new().add_context(input, StrContext::Label(error_msg))) + } + + #[inline] + fn parse_metadata_kv<'i>( + metadata_kv_count: u64, + ) -> impl Parser, MetadataKv, ContextError> { + move |input: &mut BytesStream| { + (parse_string, parse_metadata_value_single) + .parse_next(input) + .map(|(key, metadata_value)| MetadataKv { + key, + metadata_value, + }) + } + } + + fn parse_padding<'i>(padding: u64) -> impl Parser, (), ContextError> { + move |input: &mut BytesStream| { + winnow::combinator::repeat(padding as usize, u8) + .parse_next(input) + .map(|_: Vec| ()) + } + } + + pub fn parse_header(input: &mut BytesStream) -> winnow::PResult
/// Rounds `offset` up to the next multiple of `alignment`.
///
/// An `offset` that is already a multiple of `alignment` is returned unchanged.
/// An `alignment` of `0` is treated as "no alignment requirement" and also
/// returns `offset` unchanged, rather than panicking on a modulo by zero.
fn align_offset(alignment: u64, offset: u64) -> u64 {
    if alignment == 0 {
        return offset;
    }
    match offset % alignment {
        0 => offset,
        remainder => offset + (alignment - remainder),
    }
}
result: anyhow::Result = Err(anyhow!("Failed to read file.",)); + 'outer: loop { + if buffer.available_space() == 0 { + buffer.grow(buffer_growth_factor * buffer.capacity()); + } + let read = file.read(buffer.space())?; + + if read == 0 { + println!("Read 0"); + // Should be EOF since we always make sure there is `available_space` + assert_ne!(buffer.available_space(), 0); + assert_eq!( + buffer.available_data(), + 0, + "leftover data: {}", + String::from_utf8_lossy(buffer.data()) + ); + break 'outer; + } + buffer.fill(read); + + 'inner: loop { + let input = BytesStream::new(winnow::Bytes::new(buffer.data())); + + let parser_result = parser.parse_peek(input); + match parser_result { + Ok((remainder, parser_output)) => { + // Tell the buffer how much we read + let consumed = remainder.offset_from(&input); + buffer.consume(consumed); + result = Ok(parser_output); + break 'outer; + } + Err(ErrMode::Backtrack(e)) => { + let pos = buffer.position(); + return Err(anyhow::format_err!(e.to_string())); + } + Err(ErrMode::Cut(e)) => { + return Err(anyhow::format_err!(e.to_string())); + } + Err(ErrMode::Incomplete(_)) => { + let new_capacity = buffer_growth_factor * buffer.capacity(); + buffer.grow(new_capacity); + break 'inner; + } + } + } + } + result + } + + pub fn to_std_error( + error: winnow::error::ErrMode, + ) -> std::io::Error { + match error { + ErrMode::Backtrack(_) => std::io::Error::new(std::io::ErrorKind::Other, "Backtrack"), + ErrMode::Cut(_) => std::io::Error::new(std::io::ErrorKind::Other, "Cut"), + ErrMode::Incomplete(_) => std::io::Error::new(std::io::ErrorKind::Other, "Needed"), + } + } +} + +#[cfg(test)] +mod tests { + + use crate::{gguf, BytesStream}; + + #[test] + fn test_load_gguf() -> anyhow::Result<()> { + let file = std::fs::File::open("./test-data/TheBloke_TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")?; + + let result = gguf::load_gguf(file); + + match result { + Ok((header, tensor_info)) => { + assert_eq!(header.version, 3); 
# Count lines of code in ratchet-core, leaving out the kernels.
line-count:
    cd ./crates/ratchet-core && scc -irs --exclude-file kernels

# Install Python 3.10.6 through pyenv with shared libraries enabled.
install-pyo3:
    env PYTHON_CONFIGURE_OPTS="--enable-shared" pyenv install --verbose 3.10.6
    echo "Please add PYO3_PYTHON to your .bashrc or .zshrc"

# Build CRATE for the web target in release mode into target/pkg/CRATE.
wasm CRATE:
    RUSTFLAGS=--cfg=web_sys_unstable_apis wasm-pack build --target web -d `pwd`/target/pkg/{{CRATE}} --out-name {{CRATE}} ./crates/{{CRATE}} --release

# Run CRATE's wasm tests in Chrome.
wasm-test CRATE:
    RUSTFLAGS="--cfg=web_sys_unstable_apis -Z threads=8" wasm-pack test --chrome `pwd`/crates/{{CRATE}}

# Run CRATE's wasm tests in headless Chrome.
wasm-test-headless CRATE:
    RUSTFLAGS="--cfg=web_sys_unstable_apis -Z threads=8" wasm-pack test --chrome --headless `pwd`/crates/{{CRATE}}