Skip to content

Commit

Permalink
Finish tensor info loading
Browse files (browse the repository at this point in the history)
  • Loading branch information
sigma-andex committed Jan 27, 2024
1 parent f3abaeb commit b95ac40
Showing 1 changed file with 30 additions and 21 deletions.
51 changes: 30 additions & 21 deletions crates/ratchet-downloader/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#![feature(seek_stream_len)]
use js_sys::Uint8Array;
#[cfg(test)]
use wasm_bindgen_test::{wasm_bindgen_test, wasm_bindgen_test_configure};
Expand Down Expand Up @@ -77,6 +78,8 @@ impl Model {
}

mod gguf {
use std::io::Seek;

use crate::BytesStream;
use winnow::binary::{u32, u64, u8, Endianness};

Expand All @@ -101,6 +104,7 @@ mod gguf {
pub metadata_kv: Vec<MetadataKv>,
}

#[derive(Clone, Debug)]
pub struct TensorInfo {
pub name: String,
pub dimensions: Vec<u64>,
Expand Down Expand Up @@ -396,31 +400,37 @@ mod gguf {
.parse_next(input)
}

// NOTE(review): this span is a rendered diff with the +/- markers stripped, so it
// appears to contain both the pre-change and post-change versions of `load_gguf`
// interleaved. The first signature (returning only `Header`) and the
// `let mut parser` / `let res` / `res` lines look like the removed side; the second
// signature and everything from `let header = ...` onward look like the added side.
// Confirm against the actual commit before treating either half as current code.
pub fn load_gguf(mut file: std::fs::File) -> anyhow::Result<Header> {
// Loads a GGUF file: parses the header, then one `TensorInfo` per
// `header.tensor_count`, and returns both. Any parse or read failure is
// propagated through `?`.
pub fn load_gguf(mut file: std::fs::File) -> anyhow::Result<(Header, Vec<TensorInfo>)> {
// Initial capacity handed to the circular buffer.
let buffer_size = 1_000_000;
// NOTE(review): `min_buffer_growth` is not referenced anywhere else in the visible
// lines — possibly dead; verify before removing.
let min_buffer_growth = 1_000_000;
// Multiplier forwarded to `parse_with_buffer`, which grows the buffer to
// `buffer_growth_factor * buffer.capacity()` when no space is available.
let buffer_growth_factor = 2;
let mut buffer = circular::Buffer::with_capacity(buffer_size);

let mut parser = parse_header;

let res = parse_with_buffer(file, buffer, parser, buffer_growth_factor)?;
res
// The file and buffer are passed by mutable reference so the same pair is reused
// for the header parse and for every tensor-info parse that follows.
let header = parse_with_buffer(&mut file, &mut buffer, parse_header, buffer_growth_factor)?;
let mut tensor_infos: Vec<TensorInfo> = vec![];
// NOTE(review): the loop index `i` is unused in the body; it only drives the
// repetition count.
for i in 0..header.tensor_count {
let tensor_info = parse_with_buffer(
&mut file,
&mut buffer,
parse_tensor_info,
buffer_growth_factor,
)?;
tensor_infos.push(tensor_info);
}
Ok((header, tensor_infos))
}

fn parse_with_buffer(
mut file: std::fs::File,
mut buffer: circular::Buffer,
mut parser: fn(
&mut winnow::Partial<&winnow::Bytes>,
) -> Result<Header, ErrMode<ContextError>>,
fn parse_with_buffer<O>(
file: &mut std::fs::File,
buffer: &mut circular::Buffer,
mut parser: fn(&mut winnow::Partial<&winnow::Bytes>) -> Result<O, ErrMode<ContextError>>,
buffer_growth_factor: usize,
) -> Result<Result<Header, anyhow::Error>, anyhow::Error> {
) -> anyhow::Result<O> {
use std::io::Read;
let mut result: anyhow::Result<Header> = Err(anyhow!(
"An unknown error occurred while parsing the header.",
));
let mut result: anyhow::Result<O> = Err(anyhow!("Failed to read file.",));
'outer: loop {
if buffer.available_space() == 0 {
buffer.grow(buffer_growth_factor * buffer.capacity());
}
let read = file.read(buffer.space())?;

if read == 0 {
Expand Down Expand Up @@ -464,8 +474,7 @@ mod gguf {
}
}
}
let res = result;
Ok(res)
result
}

pub fn to_std_error(
Expand All @@ -485,14 +494,14 @@ mod tests {
use crate::{gguf, BytesStream};

#[test]
fn test_parse_header() -> anyhow::Result<()> {
fn test_load_gguf() -> anyhow::Result<()> {
let mut file = std::fs::File::open("./test-data/TheBloke_TinyLlama-1.1B-Chat-v1.0-GGUF/tinyllama-1.1b-chat-v1.0.Q4_K_M.gguf")?;

let result = gguf::load_gguf(file);

match result {
Ok(header) => {
// println!("{:#?}", header.metadata_kv);
Ok((header, tensor_info)) => {
println!("{:#?}", tensor_info);
assert_eq!(header.version, 3);
assert_eq!(header.tensor_count, 201)
}
Expand Down

0 comments on commit b95ac40

Please sign in to comment.