diff --git a/crates/ratchet-core/src/ops/norm/mod.rs b/crates/ratchet-core/src/ops/norm/mod.rs index 9ed488ec..8acda16f 100644 --- a/crates/ratchet-core/src/ops/norm/mod.rs +++ b/crates/ratchet-core/src/ops/norm/mod.rs @@ -413,7 +413,6 @@ mod tests { let ln_prg = r#" import torch import torch.nn.functional as F - def layer_norm(input, scale, bias): (input, scale, bias) = (torch.from_numpy(input), torch.from_numpy(scale), torch.from_numpy(bias)) return F.layer_norm(input, (input.shape[-1],), weight=scale, bias=bias).numpy() @@ -427,7 +426,6 @@ def manual_rms_norm(input, scale): input = input * torch.rsqrt(variance + 1e-5) return (scale * input).numpy() "#; - let prg = match var { NormVariant::LayerNorm => ln_prg, NormVariant::RMSNorm => rms_prg, @@ -489,7 +487,7 @@ def manual_rms_norm(input, scale): #[test] fn debug_norm() { - let device = Device::request_device(DeviceRequest::GPU).unwrap(); + let device = Device::request_device(DeviceRequest::CPU).unwrap(); let prob = NormProblem { var: NormVariant::LayerNorm, B: 2, @@ -501,9 +499,14 @@ def manual_rms_norm(input, scale): } #[proptest(cases = 64)] - fn test_norm(prob: NormProblem) { + fn test_norm_gpu(prob: NormProblem) { let device = Device::request_device(DeviceRequest::GPU).unwrap(); - println!("prob = {:#?}", prob); + run_norm_trial(&device, prob).unwrap(); + } + + #[proptest(cases = 64)] + fn test_norm_cpu(prob: NormProblem) { + let device = Device::request_device(DeviceRequest::CPU).unwrap(); run_norm_trial(&device, prob).unwrap(); } } diff --git a/crates/ratchet-core/src/ops/unary.rs b/crates/ratchet-core/src/ops/unary.rs index e8818477..346e053e 100644 --- a/crates/ratchet-core/src/ops/unary.rs +++ b/crates/ratchet-core/src/ops/unary.rs @@ -409,7 +409,7 @@ def {}(a): } fn run_unary_trial(prob: UnaryProblem, device: Device) -> anyhow::Result<()> { - let UnaryProblem { op, B, M, N } = prob; + let UnaryProblem { op, B, M, N: _ } = prob; let a = Tensor::randn::(shape![B, M], Device::CPU); let args = match op {