Skip to content

Commit

Permalink
Assign operators
Browse files Browse the repository at this point in the history
  • Loading branch information
Joe Fioti authored and Joe Fioti committed Dec 26, 2023
1 parent 33b7f09 commit 941a8b9
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 6 deletions.
5 changes: 2 additions & 3 deletions examples/llama/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ where
type Output = GraphTensor<Sh>;

fn forward(&self, input: GraphTensor<Sh>) -> Self::Output {
let gate = input.matmul(self.gate_proj.permute());
let gate = gate.swish();
let gate = input.matmul(self.gate_proj.permute()).swish();
let up = input.matmul(self.up_proj.permute()) * gate;
up.matmul(self.down_proj.permute())
}
Expand Down Expand Up @@ -218,7 +217,7 @@ impl<
let mut w = q
.matmul(k.permute())
.mul((HEAD_DIM as f64).sqrt().recip() as f32);
// We don't need to mask on a kv cached pass
// We only mask on a non-kv cache pass
if cache.is_none() {
let attention_mask = self.k_proj.graph().triu::<CurSeq, TotSeq>(1) * f16::MIN.to_f32();
w = w + attention_mask.expand();
Expand Down
41 changes: 38 additions & 3 deletions src/hl_ops/binary.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
use crate::op;
use crate::prelude::*;
use std::ops::AddAssign;
use std::ops::DivAssign;
use std::ops::MulAssign;
use std::ops::RemAssign;
use std::ops::SubAssign;
use std::ops::{Add, Div, Mul, Neg, Rem, Sub};

impl<S: Shape> Add<GraphTensor<S>> for GraphTensor<S> {
impl<S: Shape> Add for GraphTensor<S> {
type Output = GraphTensor<S>;

fn add(mut self, mut rhs: GraphTensor<S>) -> Self::Output {
Expand All @@ -19,15 +24,27 @@ impl<S: Shape> Add<GraphTensor<S>> for GraphTensor<S> {
}
}

impl<S: Shape> Sub<GraphTensor<S>> for GraphTensor<S> {
/// Compound addition (`a += b`) for two graph tensors sharing the shape type `S`.
///
/// Delegates to the `Add` impl for `GraphTensor<S>` and stores the result back
/// into `self`.
impl<S: Shape> AddAssign for GraphTensor<S> {
    fn add_assign(&mut self, rhs: Self) {
        // Build the sum through the `+` operator first, then overwrite the left-hand side.
        let sum = *self + rhs;
        *self = sum;
    }
}

/// Subtraction (`a - b`) for two graph tensors sharing the shape type `S`.
///
/// There is no dedicated subtraction path here: the right-hand side is negated
/// and the result is added to the left-hand side.
impl<S: Shape> Sub for GraphTensor<S> {
    type Output = GraphTensor<S>;

    fn sub(self, rhs: GraphTensor<S>) -> Self::Output {
        // a - b is expressed as a + (-b).
        let negated = -rhs;
        self + negated
    }
}

impl<S: Shape> Mul<GraphTensor<S>> for GraphTensor<S> {
/// Compound subtraction (`a -= b`) for two graph tensors sharing the shape type `S`.
///
/// Delegates to the `Sub` impl for `GraphTensor<S>` and stores the result back
/// into `self`.
impl<S: Shape> SubAssign for GraphTensor<S> {
    fn sub_assign(&mut self, rhs: Self) {
        // Build the difference through the `-` operator first, then overwrite the left-hand side.
        let difference = *self - rhs;
        *self = difference;
    }
}

impl<S: Shape> Mul for GraphTensor<S> {
type Output = GraphTensor<S>;

fn mul(mut self, mut rhs: GraphTensor<S>) -> Self::Output {
Expand All @@ -44,6 +61,12 @@ impl<S: Shape> Mul<GraphTensor<S>> for GraphTensor<S> {
}
}

/// Compound multiplication (`a *= b`) for two graph tensors sharing the shape type `S`.
///
/// Delegates to the `Mul` impl for `GraphTensor<S>` and stores the result back
/// into `self`.
impl<S: Shape> MulAssign for GraphTensor<S> {
    fn mul_assign(&mut self, rhs: Self) {
        // Build the product through the `*` operator first, then overwrite the left-hand side.
        let product = *self * rhs;
        *self = product;
    }
}

#[allow(clippy::suspicious_arithmetic_impl)]
impl<S: Shape> Div<GraphTensor<S>> for GraphTensor<S> {
type Output = GraphTensor<S>;
Expand All @@ -53,6 +76,12 @@ impl<S: Shape> Div<GraphTensor<S>> for GraphTensor<S> {
}
}

/// Compound division (`a /= b`) for two graph tensors sharing the shape type `S`.
///
/// Delegates to the `Div` impl for `GraphTensor<S>` and stores the result back
/// into `self`.
impl<S: Shape> DivAssign for GraphTensor<S> {
    fn div_assign(&mut self, rhs: Self) {
        // Build the quotient through the `/` operator first, then overwrite the left-hand side.
        let quotient = *self / rhs;
        *self = quotient;
    }
}

impl<S: Shape> Rem<GraphTensor<S>> for GraphTensor<S> {
type Output = GraphTensor<S>;

Expand All @@ -70,6 +99,12 @@ impl<S: Shape> Rem<GraphTensor<S>> for GraphTensor<S> {
}
}

/// Compound remainder (`a %= b`) for two graph tensors sharing the shape type `S`.
///
/// Delegates to the `Rem` impl for `GraphTensor<S>` and stores the result back
/// into `self`.
impl<S: Shape> RemAssign for GraphTensor<S> {
    fn rem_assign(&mut self, rhs: Self) {
        // Build the remainder through the `%` operator first, then overwrite the left-hand side.
        let remainder = *self % rhs;
        *self = remainder;
    }
}

impl<S: Shape> Neg for GraphTensor<S> {
type Output = GraphTensor<S>;

Expand Down

0 comments on commit 941a8b9

Please sign in to comment.