From d370c4a7448156fa90226b7062a74e2dc14c071a Mon Sep 17 00:00:00 2001
From: Ivar Flakstad <69173633+ivarflakstad@users.noreply.github.com>
Date: Wed, 20 Nov 2024 13:11:58 +0100
Subject: [PATCH] chore: update imports

---
 crates/ratchet-core/src/cpu/binary.rs  |  5 +----
 crates/ratchet-core/src/cpu/mod.rs     |  9 +++------
 crates/ratchet-core/src/cpu/norm.rs    | 22 ++++++++--------------
 crates/ratchet-core/src/cpu/reindex.rs |  5 +----
 crates/ratchet-core/src/cpu/rope.rs    |  1 -
 crates/ratchet-core/src/cpu/utils.rs   |  6 +++---
 6 files changed, 16 insertions(+), 32 deletions(-)

diff --git a/crates/ratchet-core/src/cpu/binary.rs b/crates/ratchet-core/src/cpu/binary.rs
index e90a1d37..146d80f4 100644
--- a/crates/ratchet-core/src/cpu/binary.rs
+++ b/crates/ratchet-core/src/cpu/binary.rs
@@ -1,8 +1,5 @@
 use crate::cpu::cpu_store_result;
-use crate::{
-    Binary, BinaryOp, CPUOperation, DType, OpGuards, Operation, OperationError, RVec, StorageView,
-    Tensor, TensorDType,
-};
+use crate::{Binary, BinaryOp, CPUOperation, DType, OperationError, Tensor, TensorDType};
 use core::marker::PhantomData;
 use half::{bf16, f16};
 use num_traits::NumOps;
diff --git a/crates/ratchet-core/src/cpu/mod.rs b/crates/ratchet-core/src/cpu/mod.rs
index c5d8daf0..d07af976 100644
--- a/crates/ratchet-core/src/cpu/mod.rs
+++ b/crates/ratchet-core/src/cpu/mod.rs
@@ -6,17 +6,14 @@ pub mod rope;
 mod unary;
 mod utils;
 
-use crate::cpu::unary::unary_apply_fn;
 use crate::{
-    dequantize, Binary, BinaryOp, CPUBuffer, Cast, Concat, DType, IndexSelect, InvariantError,
-    LazyOp, OpGuards, Operation, OperationError, RVec, Shape, Storage, StorageView, Strides,
-    Tensor, TensorDType, Unary, UnaryOp,
+    dequantize, Cast, Concat, DType, IndexSelect, InvariantError, LazyOp, Operation,
+    OperationError, RVec, Shape, Tensor, TensorDType,
 };
 use anyhow::anyhow;
-use core::marker::PhantomData;
 use half::{bf16, f16};
-use num_traits::Float;
 use rope::cpu_rope;
+use unary::unary_apply_fn;
 use utils::cpu_store_result;
 
 pub fn apply_operation(op: LazyOp, dst: Tensor) -> Result<Tensor, OperationError> {
diff --git a/crates/ratchet-core/src/cpu/norm.rs b/crates/ratchet-core/src/cpu/norm.rs
index f28ee408..0a930d20 100644
--- a/crates/ratchet-core/src/cpu/norm.rs
+++ b/crates/ratchet-core/src/cpu/norm.rs
@@ -5,12 +5,12 @@ use crate::cpu::utils::cpu_store_result;
 use crate::reindex::broadcast_vector;
 use crate::{
     shape, CPUOperation, DType, GroupNorm, InvariantError, Norm, NormOp, OperationError, Shape,
-    Strides, Tensor, TensorDType,
+    Tensor, TensorDType,
 };
 use core::iter::Sum;
 use half::{bf16, f16};
 use num::Float;
-use num_traits::{AsPrimitive, NumOps};
+use num_traits::NumOps;
 
 impl CPUOperation for NormOp {
     fn apply_cpu(&self, dst: Tensor) -> Result<Tensor, OperationError> {
@@ -77,7 +77,7 @@ where
     let mean_dim = shape.numel() / shape[dim];
     let mut result = vec![T::zero(); mean_dim];
     let step = src.len() / mean_dim;
-    let n = T::from((step as f32)).unwrap();
+    let n = T::from(step as f32).unwrap();
 
     (0..src.len())
         .step_by(step)
@@ -100,8 +100,8 @@ where
 {
     let src_shape = input.shape();
     let rank = input.rank();
-    let M = src_shape[rank - 2];
     let N = src_shape[rank - 1];
+    let norm_shape = shape!(N);
 
     let input = input.to_vec::<T>()?;
     let scale = scale.to_vec::<T>()?;
@@ -110,12 +110,9 @@ where
         None => None,
     };
 
-    let dst_shape = dst.shape();
-    let mut result = vec![T::zero(); dst.shape().numel()];
-
     let mut x = input.clone();
-    let mut mu = mean(&x, src_shape, rank - 1);
+    let mu = mean(&x, src_shape, rank - 1);
 
     let mut mu2 = mu.clone();
     square(&mut mu2);
     let mut x2 = input.clone();
@@ -136,11 +133,11 @@ where
     broadcast_vector(&x2, &mut v);
     mul(&mut x, &v);
 
-    let scale_b = broadcast(&scale, &shape!(N), src_shape);
+    let scale_b = broadcast(&scale, &norm_shape, src_shape);
     mul(&mut x, &scale_b);
 
     if let Some(bias) = bias {
-        let bias_b = broadcast(&bias, &shape!(N), src_shape);
+        let bias_b = broadcast(&bias, &norm_shape, src_shape);
         add(&mut x, &bias_b);
     }
 
@@ -191,14 +188,11 @@ where
 {
     let src_shape = input.shape();
     let rank = input.rank();
-    let M = src_shape[rank - 2];
     let N = src_shape[rank - 1];
 
     let mut x = input.to_vec::<T>()?;
     let scale = scale.to_vec::<T>()?;
 
-    let dst_shape = dst.shape();
-
     let mut x2 = x.clone();
     square(&mut x2);
     let mut x2 = mean(&x2, src_shape, rank - 1);
@@ -218,7 +212,7 @@ where
     Ok(())
 }
 
-fn apply_group_norm(n: &GroupNorm, dst: Tensor) -> Result<Tensor, OperationError> {
+fn apply_group_norm(_n: &GroupNorm, dst: Tensor) -> Result<Tensor, OperationError> {
     //let result = norm(&b.src.to_vec::()?, b.src.shape(), b.to());
     //cpu_store_result(&dst, &result);
     Ok(dst)
diff --git a/crates/ratchet-core/src/cpu/reindex.rs b/crates/ratchet-core/src/cpu/reindex.rs
index 863578c0..9cb79d3f 100644
--- a/crates/ratchet-core/src/cpu/reindex.rs
+++ b/crates/ratchet-core/src/cpu/reindex.rs
@@ -1,7 +1,4 @@
-use super::utils::{
-    cpu_store_result, TensorIterator,
-    TensorIterator::{Contiguous, Strided},
-};
+use super::utils::cpu_store_result;
 use crate::{
     Broadcast, CPUOperation, DType, OperationError, Reindex, Shape, Slice, Strides, Tensor,
     TensorDType,
diff --git a/crates/ratchet-core/src/cpu/rope.rs b/crates/ratchet-core/src/cpu/rope.rs
index 0ba36bf8..a1d407f0 100644
--- a/crates/ratchet-core/src/cpu/rope.rs
+++ b/crates/ratchet-core/src/cpu/rope.rs
@@ -3,7 +3,6 @@ use crate::{
     cpu::{cpu_store_result, gemm::gemm, reindex::slice},
     shape, DType, OperationError, RoPE, Shape, Strides, Tensor,
 };
-use anyhow::anyhow;
 
 pub fn cpu_rope(op: RoPE, dst: Tensor) -> Result<Tensor, OperationError> {
     match op.input().dt() {
diff --git a/crates/ratchet-core/src/cpu/utils.rs b/crates/ratchet-core/src/cpu/utils.rs
index f1f565ff..e3874291 100644
--- a/crates/ratchet-core/src/cpu/utils.rs
+++ b/crates/ratchet-core/src/cpu/utils.rs
@@ -1,5 +1,5 @@
 use crate::{CPUBuffer, Shape, Storage, Strides, Tensor};
-use bytemuck::{Contiguous, NoUninit};
+use bytemuck::NoUninit;
 use std::ops::Range;
 
 pub fn cpu_store_result<T: NoUninit>(dst: &Tensor, data: &[T]) {
@@ -137,11 +137,11 @@ impl<'a> From<(&'a Shape, &'a Strides, usize)> for StridedIterator<'a> {
 #[cfg(test)]
 mod tests {
     use proptest::prelude::*;
-    use test_strategy::{proptest, Arbitrary};
+    use test_strategy::proptest;
 
     use crate::{shape, Shape, Strides};
 
-    use super::{StridedIterator, TensorIterator};
+    use super::TensorIterator;
 
     #[derive(Debug)]
     struct IterProblem {