-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Joe Fioti
authored and
Joe Fioti
committed
Jan 5, 2024
1 parent
a23e536
commit 9aaff41
Showing
7 changed files
with
87 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,4 +6,5 @@ Cargo.lock | |
*.npx | ||
*.npz | ||
/**/llama-7b-hf | ||
/**/mistral-7b-hf | ||
/**/mistral-7b-hf | ||
/**/setup_weights/target |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
use std::fs::File; | ||
|
||
use luminal::{op::Function, prelude::*}; | ||
use memmap2::MmapOptions; | ||
use metal_rs::{Device, MTLResourceOptions}; | ||
use safetensors::SafeTensors; | ||
|
||
/// Weight loader backed by a list of safetensors files.
///
/// The loader only stores the file paths; the files themselves are opened
/// later, when the `Loader` implementation for this type wires each weight of
/// a model to a graph node.
///
/// Original author's note: loads the model in the same way dfdx-llama does.
#[derive(Debug, Clone)]
pub struct MetalFp16SafetensorsLoader {
    /// Files searched, in the order given, for each weight.
    paths: Vec<String>,
}

impl MetalFp16SafetensorsLoader {
    /// Creates a loader over `paths`.
    ///
    /// Every element is converted with [`ToString`], so `&str`, `String` and
    /// any other `ToString` type are accepted. No file is touched here; an
    /// empty slice yields a loader with no files to search.
    pub fn new<S: ToString>(paths: &[S]) -> Self {
        Self {
            paths: paths.iter().map(ToString::to_string).collect(),
        }
    }
}
|
||
impl Loader for MetalFp16SafetensorsLoader { | ||
fn load<M: SerializeModule>(self, model: &M, graph: &mut Graph) { | ||
for (weight_name, node_index) in state_dict(model) { | ||
if let Some(loading_node) = graph | ||
.graph | ||
.node_weight_mut(node_index) | ||
.and_then(|op| op.as_any_mut().downcast_mut::<Function>()) | ||
{ | ||
let file_paths = self.paths.clone(); | ||
loading_node.1 = Box::new(move |_| { | ||
for file_path in file_paths.iter() { | ||
let file = File::open(file_path).unwrap(); | ||
let buffer = unsafe { MmapOptions::new().map(&file).unwrap() }; | ||
let safetensors = SafeTensors::deserialize(&buffer).unwrap(); | ||
|
||
if let Ok(tensor_view) = safetensors.tensor(&weight_name.replace('/', ".")) | ||
{ | ||
let buffer = Device::system_default() | ||
.unwrap() | ||
.new_buffer_with_bytes_no_copy( | ||
tensor_view.data().as_ptr() as *const _, | ||
tensor_view.data().len() as u64, | ||
MTLResourceOptions::StorageModeShared, | ||
None, | ||
); | ||
return vec![Tensor { | ||
data: Box::new(buffer), | ||
}]; | ||
} | ||
} | ||
|
||
panic!("Tensor \"{weight_name}\" not found in files"); | ||
}); | ||
} | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters