Skip to content

Commit

Permalink
reinterpret entire array at once
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Fioti authored and Joe Fioti committed Jan 4, 2024
1 parent 88ed1de commit 5bc2477
Showing 1 changed file with 5 additions and 16 deletions.
21 changes: 5 additions & 16 deletions src/core/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,10 @@ impl Loader for SafeTensorLoader {
{
// Convert to fp32
let data: Vec<f32> = match tensor_view.dtype() {
Dtype::F32 => tensor_view
.data()
.chunks_exact(4)
.map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]))
.collect(),
Dtype::F32 => {
unsafe { std::mem::transmute::<_, &[f32]>(tensor_view.data()) }
.to_vec()
}
Dtype::F16 => tensor_view
.data()
.chunks_exact(2)
Expand Down Expand Up @@ -199,18 +198,8 @@ impl<'data> View for &'data Tensor {

impl<'a> std::convert::From<safetensors::tensor::TensorView<'a>> for Tensor {
fn from(value: safetensors::tensor::TensorView<'a>) -> Self {
let chunked = value.data().chunks_exact(std::mem::size_of::<f32>());

Tensor {
data: Box::new(
chunked
.map(|chunk| unsafe {
std::mem::transmute::<[u8; 4], f32>([
chunk[0], chunk[1], chunk[2], chunk[3],
])
})
.collect::<Vec<f32>>(),
),
data: Box::new(unsafe { std::mem::transmute::<_, &'a [f32]>(value.data()) }.to_vec()),
}
}
}
Expand Down

0 comments on commit 5bc2477

Please sign in to comment.