From f04866755fe79089f8e34688d01c8c75cfb56a7b Mon Sep 17 00:00:00 2001 From: travolin Date: Mon, 25 Nov 2024 11:14:35 -0800 Subject: [PATCH] Update model to use CUDA when available --- crates/spyglass-llm/src/bin.rs | 11 ++++++++--- crates/spyglass-llm/src/model.rs | 8 ++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/crates/spyglass-llm/src/bin.rs b/crates/spyglass-llm/src/bin.rs index 75675dae9..932e0a1ec 100644 --- a/crates/spyglass-llm/src/bin.rs +++ b/crates/spyglass-llm/src/bin.rs @@ -51,9 +51,14 @@ pub async fn main() -> Result<(), anyhow::Error> { } }); - let mut client = - LlmClient::new("assets/models/llm/llama3/Llama-3.2-3B-Instruct.Q5_K_M.gguf".into())?; - client.chat(&prompt, Some(tx)).await?; + match LlmClient::new("assets/models/llm/llama3/Llama-3.2-3B-Instruct.Q5_K_M.gguf".into()) { + Ok(mut client) => { + client.chat(&prompt, Some(tx)).await?; + } + Err(error) => { + log::error!("Error loading model {error}"); + } + } Ok(()) } diff --git a/crates/spyglass-llm/src/model.rs b/crates/spyglass-llm/src/model.rs index 01310ec05..a06f9c784 100644 --- a/crates/spyglass-llm/src/model.rs +++ b/crates/spyglass-llm/src/model.rs @@ -39,6 +39,14 @@ impl LLMModel { candle::Device::Cpu } } + } else if candle::utils::cuda_is_available() { + match Device::new_cuda(0) { + Ok(dev) => dev, + Err(err) => { + log::warn!("Using CPU fallback. Unable to create CudaDevice: {err}"); + candle::Device::Cpu + } + } } else { candle::Device::Cpu };