Skip to content

Commit

Permalink
Merge branch 'jafioti:main' into feature-conv3d
Browse files Browse the repository at this point in the history
  • Loading branch information
NewBornRustacean authored May 2, 2024
2 parents aa939cb + bb97aab commit 9f615d3
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 1 deletion.
4 changes: 3 additions & 1 deletion examples/train_math_net/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@ edition = "2021"

[features]
metal = ["dep:luminal_metal"]
cuda = ["dep:luminal_cuda"]

[dependencies]
luminal = {path="../.."}
luminal_training = {path="../../crates/luminal_training"}
luminal_nn = {path="../../crates/luminal_nn"}
rand = "0.8.5"
luminal_metal = {path="../../crates/luminal_metal", optional=true}
luminal_metal = {path="../../crates/luminal_metal", optional=true}
luminal_cuda = {path="../../crates/luminal_cuda", optional=true}
16 changes: 16 additions & 0 deletions examples/train_math_net/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ fn main() {
let mut weights = params(&model);
let grads = cx.compile(Autograd::new(&weights, loss), ());
let (mut new_weights, lr) = sgd_on_graph(&mut cx, &weights, &grads);
cx.keep_tensors(&new_weights);
cx.keep_tensors(&weights);
lr.set(1e-1);

#[cfg(all(not(feature = "metal"), not(feature = "cuda")))]
cx.compile(
GenericCompiler::default(),
(
Expand All @@ -46,6 +49,19 @@ fn main() {
),
);

#[cfg(feature = "cuda")]
cx.compile(
luminal_cuda::CudaCompiler::<f32>::default(),
(
&mut input,
&mut target,
&mut loss,
&mut output,
&mut weights,
&mut new_weights,
),
);

let mut rng = thread_rng();
let (mut loss_avg, mut acc_avg) = (ExponentialAverage::new(1.0), ExponentialAverage::new(0.0));
let mut iter = 0;
Expand Down

0 comments on commit 9f615d3

Please sign in to comment.