diff --git a/src/core/graph.rs b/src/core/graph.rs index 4cde2576..e8556379 100644 --- a/src/core/graph.rs +++ b/src/core/graph.rs @@ -16,7 +16,7 @@ use colored::Colorize; use itertools::Itertools; use petgraph::{graph::NodeIndex, stable_graph::StableGraph, visit::EdgeRef, Direction}; -use super::compiler_utils::{ToIds, ToIdsMut}; +use super::compiler_utils::ToIdsMut; pub type MainGraph = StableGraph, Dependency>; diff --git a/src/nn/embedding.rs b/src/nn/embedding.rs index e5f67ee0..0de6c97c 100644 --- a/src/nn/embedding.rs +++ b/src/nn/embedding.rs @@ -1,5 +1,3 @@ -use rand::{thread_rng, Rng}; - use crate::prelude::*; pub struct Embedding { @@ -8,17 +6,9 @@ pub struct Embedding { impl InitModule for Embedding { fn initialize(cx: &mut Graph) -> Self { - let s = Self { + Self { weight: cx.named_tensor("Embedding Weight"), - }; - // Init weight as uniform(-1, 1) - let mut rng = thread_rng(); - s.weight.set( - (0..(A * B)) - .map(|_| rng.gen_range(-1_f32..1_f32)) - .collect::>(), - ); - s + } } }