diff --git a/examples/train_math_net/Cargo.toml b/examples/train_math_net/Cargo.toml
index 98fc83f4..a17b9b07 100644
--- a/examples/train_math_net/Cargo.toml
+++ b/examples/train_math_net/Cargo.toml
@@ -7,10 +7,12 @@ edition = "2021"
 
 [features]
 metal = ["dep:luminal_metal"]
+cuda = ["dep:luminal_cuda"]
 
 [dependencies]
 luminal = {path="../.."}
 luminal_training = {path="../../crates/luminal_training"}
 luminal_nn = {path="../../crates/luminal_nn"}
 rand = "0.8.5"
-luminal_metal = {path="../../crates/luminal_metal", optional=true}
\ No newline at end of file
+luminal_metal = {path="../../crates/luminal_metal", optional=true}
+luminal_cuda = {path="../../crates/luminal_cuda", optional=true}
\ No newline at end of file
diff --git a/examples/train_math_net/src/main.rs b/examples/train_math_net/src/main.rs
index 8e083903..34a9394e 100644
--- a/examples/train_math_net/src/main.rs
+++ b/examples/train_math_net/src/main.rs
@@ -19,8 +19,11 @@ fn main() {
     let mut weights = params(&model);
     let grads = cx.compile(Autograd::new(&weights, loss), ());
     let (mut new_weights, lr) = sgd_on_graph(&mut cx, &weights, &grads);
+    cx.keep_tensors(&new_weights);
+    cx.keep_tensors(&weights);
     lr.set(1e-1);
 
+    #[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
     cx.compile(
         GenericCompiler::default(),
         (
@@ -46,6 +49,19 @@ fn main() {
         ),
     );
 
+    #[cfg(feature = "cuda")]
+    cx.compile(
+        luminal_cuda::CudaCompiler::<f32>::default(),
+        (
+            &mut input,
+            &mut target,
+            &mut loss,
+            &mut output,
+            &mut weights,
+            &mut new_weights,
+        ),
+    );
+
     let mut rng = thread_rng();
     let (mut loss_avg, mut acc_avg) = (ExponentialAverage::new(1.0), ExponentialAverage::new(0.0));
     let mut iter = 0;
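
Usage note (not part of the patch; a sketch of how the new feature gate is expected to be exercised): with the cuda feature declared in Cargo.toml, running the example with no features, e.g. cargo run --release from examples/train_math_net, should still compile the graph with GenericCompiler, since that call is now gated on neither metal nor cuda being enabled, while cargo run --release --features cuda should route compilation through luminal_cuda::CudaCompiler instead.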