diff --git a/src/core/serialization.rs b/src/core/serialization.rs index 62b5a5da..619daba6 100644 --- a/src/core/serialization.rs +++ b/src/core/serialization.rs @@ -120,11 +120,10 @@ impl Loader for SafeTensorLoader { { // Convert to fp32 let data: Vec = match tensor_view.dtype() { - Dtype::F32 => tensor_view - .data() - .chunks_exact(4) - .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]])) - .collect(), + Dtype::F32 => { + unsafe { std::mem::transmute::<_, &[f32]>(tensor_view.data()) } + .to_vec() + } Dtype::F16 => tensor_view .data() .chunks_exact(2) @@ -199,18 +198,8 @@ impl<'data> View for &'data Tensor { impl<'a> std::convert::From> for Tensor { fn from(value: safetensors::tensor::TensorView<'a>) -> Self { - let chunked = value.data().chunks_exact(std::mem::size_of::()); - Tensor { - data: Box::new( - chunked - .map(|chunk| unsafe { - std::mem::transmute::<[u8; 4], f32>([ - chunk[0], chunk[1], chunk[2], chunk[3], - ]) - }) - .collect::>(), - ), + data: Box::new(unsafe { std::mem::transmute::<_, &'a [f32]>(value.data()) }.to_vec()), } } }