diff --git a/Cargo.toml b/Cargo.toml index 8d0becb9..52086fe2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] members = [ "crates/ratchet-core", + "crates/ratchet-downloader", "crates/ratchet-integration-tests", "crates/ratchet-loader", "crates/ratchet-models", @@ -28,6 +29,7 @@ derive-new = "0.6.0" log = "0.4.20" thiserror = "1.0.56" byteorder = "1.5.0" +wasm-bindgen-test = "0.3.34" [workspace.dev-dependencies] hf-hub = "0.3.0" diff --git a/crates/ratchet-core/Cargo.toml b/crates/ratchet-core/Cargo.toml index 97b1619b..cf0b67ff 100644 --- a/crates/ratchet-core/Cargo.toml +++ b/crates/ratchet-core/Cargo.toml @@ -31,6 +31,7 @@ glam = "0.25.0" pollster = "0.3.0" futures-intrusive = "0.5.0" anyhow = "1.0.79" +getrandom = { version = "0.2", features = ["js"] } # Needed for wasm support in `num` trait num = "0.4.1" rand_distr = { version = "0.4.3", optional = true } rand = { version = "0.8.4", optional = true } diff --git a/crates/ratchet-core/src/gpu/device.rs b/crates/ratchet-core/src/gpu/device.rs index fefe081e..0531e666 100644 --- a/crates/ratchet-core/src/gpu/device.rs +++ b/crates/ratchet-core/src/gpu/device.rs @@ -51,7 +51,7 @@ impl PartialEq for WgpuDevice { impl WgpuDevice { pub async fn new() -> Result { #[cfg(target_arch = "wasm32")] - let adapter = Self::select_adapter().await; + let adapter = Self::select_adapter().await?; #[cfg(not(target_arch = "wasm32"))] let adapter = Self::select_adapter()?; @@ -106,7 +106,7 @@ impl WgpuDevice { } #[cfg(target_arch = "wasm32")] - async fn select_adapter() -> Adapter { + async fn select_adapter() -> Result { let instance = wgpu::Instance::default(); let backends = wgpu::util::backend_bits_from_env().unwrap_or(wgpu::Backends::PRIMARY); instance @@ -116,10 +116,10 @@ impl WgpuDevice { force_fallback_adapter: false, }) .await - .map_err(|e| { - log::error!("Failed to create device: {:?}", e); - e - })? + .ok_or({ + log::error!("Failed to request adapter."); + DeviceError::AdapterRequestFailed + }) } #[cfg(not(target_arch = "wasm32"))] diff --git a/crates/ratchet-downloader/Cargo.toml b/crates/ratchet-downloader/Cargo.toml new file mode 100644 index 00000000..df9e5075 --- /dev/null +++ b/crates/ratchet-downloader/Cargo.toml @@ -0,0 +1,34 @@ +[package] +name = "ratchet-downloader" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[dependencies] +ratchet-loader = { path = "../ratchet-loader" } +wasm-bindgen = "0.2.84" +wasm-bindgen-futures = "0.4.39" +js-sys = "0.3.64" +gloo = "0.11.0" +reqwest = "0.11.23" + +[dependencies.web-sys] +features = [ + 'console', + 'Headers', + 'Request', + 'RequestInit', + 'RequestMode', + 'Response', + 'Window', + 'Navigator', + 'StorageManager', + 'CacheStorage' +] +version = "0.3.64" + +[dev-dependencies] +wasm-bindgen-test.workspace = true + +[lib] +crate-type = ["cdylib", "rlib"] diff --git a/crates/ratchet-downloader/src/fetch.rs b/crates/ratchet-downloader/src/fetch.rs new file mode 100644 index 00000000..bfa51da4 --- /dev/null +++ b/crates/ratchet-downloader/src/fetch.rs @@ -0,0 +1,35 @@ +use js_sys::{ArrayBuffer, Uint8Array, JSON}; + +use wasm_bindgen::{prelude::*, JsValue}; +use wasm_bindgen_futures::JsFuture; +use web_sys::{Request, RequestInit, RequestMode, Response}; + +fn to_error(value: JsValue) -> JsError { + JsError::new( + JSON::stringify(&value) + .map(|js_string| { + js_string + .as_string() + .unwrap_or(String::from("An unknown error occurred.")) + }) + .unwrap_or(String::from("An unknown error occurred.")) + .as_str(), + ) +} +pub(crate) async fn fetch(url: &str) -> Result { + let mut opts = RequestInit::new(); + opts.method("GET"); + opts.mode(RequestMode::Cors); + + let request = Request::new_with_str_and_init(&url, &opts).map_err(to_error)?; + + let window = web_sys::window().unwrap(); + let resp_value = JsFuture::from(window.fetch_with_request(&request)) + .await + .map_err(to_error)?; + + assert!(resp_value.is_instance_of::()); + let resp: Response = resp_value.dyn_into().unwrap(); + + Ok(resp) +} diff --git a/crates/ratchet-downloader/src/huggingface/mod.rs b/crates/ratchet-downloader/src/huggingface/mod.rs new file mode 100644 index 00000000..c426b23e --- /dev/null +++ b/crates/ratchet-downloader/src/huggingface/mod.rs @@ -0,0 +1 @@ +pub mod repo; diff --git a/crates/ratchet-downloader/src/huggingface/repo.rs b/crates/ratchet-downloader/src/huggingface/repo.rs new file mode 100644 index 00000000..70a8c16c --- /dev/null +++ b/crates/ratchet-downloader/src/huggingface/repo.rs @@ -0,0 +1,7 @@ +pub struct Repo { + pub id: String, + pub revision: String, + pub repo_type: String, +} + +impl Repo {} diff --git a/crates/ratchet-downloader/src/lib.rs b/crates/ratchet-downloader/src/lib.rs new file mode 100644 index 00000000..7297f570 --- /dev/null +++ b/crates/ratchet-downloader/src/lib.rs @@ -0,0 +1,61 @@ +#[cfg(test)] +use wasm_bindgen_test::{wasm_bindgen_test, wasm_bindgen_test_configure}; + +use gloo::console::error as log_error; +use wasm_bindgen::{prelude::*, JsValue}; + +mod fetch; +pub mod huggingface; + +#[cfg(test)] +wasm_bindgen_test_configure!(run_in_browser); + +pub struct Model { + url: String, +} + +impl Model { + fn from_hf(repo_id: String) -> Self { + Self { + url: format!("https://huggingface.co/{}/resolve/main", repo_id), + } + } + + fn from_hf_with_revision(repo_id: String, revision: String) -> Self { + Self { + url: format!("https://huggingface.co/{repo_id}/resolve/{revision}"), + } + } + + fn from_custom(url: String) -> Self { + Self { url } + } + + async fn get(&self, file_name: String) -> Result<(), JsError> { + let file_url = format!("{}/{}", self.url, file_name); + // let response = fetch::fetch(file_url.as_str()).await?; + + let res = reqwest::Client::new() + .get(file_url) + // .header("Accept", "application/vnd.github.v3+json") + .send() + .await?; + Ok(()) + } +} + +#[cfg(test)] +#[wasm_bindgen_test] +async fn pass() -> Result<(), JsValue> { + use js_sys::JsString; + + let model = Model::from_hf("jantxu/ratchet-test".to_string()); + let file = model + .get("model.safetensors".to_string()) + .await + .map_err(|err| { + log_error!(err); + JsString::from("Failed to download file") + }); + Ok(()) +} diff --git a/crates/ratchet-integration-tests/Cargo.toml b/crates/ratchet-integration-tests/Cargo.toml index 3c44d58d..b3826de9 100644 --- a/crates/ratchet-integration-tests/Cargo.toml +++ b/crates/ratchet-integration-tests/Cargo.toml @@ -6,4 +6,4 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dev-dependencies] -wasm-bindgen-test = "0.3.34" +wasm-bindgen-test.workspace = true diff --git a/justfile b/justfile index 99357e51..025f505d 100644 --- a/justfile +++ b/justfile @@ -1,2 +1,9 @@ line-count: - cd ./crates/ratchet-core && scc -irs --exclude-file kernels + cd ./crates/ratchet-core && scc -irs --exclude-file kernels +install-pyo3: + env PYTHON_CONFIGURE_OPTS="--enable-shared" pyenv install --verbose 3.10.6 + echo "Please PYO3_PYTHON to your .bashrc or .zshrc" +wasm CRATE: + RUSTFLAGS=--cfg=web_sys_unstable_apis wasm-pack build --target web -d `pwd`/target/pkg/{{CRATE}} --out-name {{CRATE}} ./crates/{{CRATE}} --release +wasm-test CRATE: + RUSTFLAGS=--cfg=web_sys_unstable_apis wasm-pack test --chrome `pwd`/crates/{{CRATE}}