Skip to content

Commit

Permalink
removed embedding init
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Fioti authored and Joe Fioti committed Jan 4, 2024
1 parent 9b81ef2 commit e7c78e9
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/core/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use colored::Colorize;
use itertools::Itertools;
use petgraph::{graph::NodeIndex, stable_graph::StableGraph, visit::EdgeRef, Direction};

use super::compiler_utils::{ToIds, ToIdsMut};
use super::compiler_utils::ToIdsMut;

pub type MainGraph = StableGraph<Box<dyn Operator>, Dependency>;

Expand Down
14 changes: 2 additions & 12 deletions src/nn/embedding.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use rand::{thread_rng, Rng};

use crate::prelude::*;

pub struct Embedding<const N: usize, const DIM: usize> {
Expand All @@ -8,17 +6,9 @@ pub struct Embedding<const N: usize, const DIM: usize> {

impl<const A: usize, const B: usize> InitModule for Embedding<A, B> {
fn initialize(cx: &mut Graph) -> Self {
let s = Self {
Self {
weight: cx.named_tensor("Embedding Weight"),
};
// Init weight as uniform(-1, 1)
let mut rng = thread_rng();
s.weight.set(
(0..(A * B))
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
);
s
}
}
}

Expand Down

0 comments on commit e7c78e9

Please sign in to comment.