Commit 9e3bea8: Small changes

Joe Fioti authored and committed Dec 26, 2023
1 parent 941a8b9 commit 9e3bea8
Showing 5 changed files with 24 additions and 22 deletions.
12 changes: 3 additions & 9 deletions examples/llama/model.rs
@@ -199,16 +199,10 @@ impl<
             // Add KV cache
             let k = cache
                 .0
-                .contiguous()
-                .concat_along::<(Batch, Const<NUM_HEADS>, TotSeq, Const<HEAD_DIM>), Axis<2>, _>(
-                    k.contiguous(),
-                );
+                .concat_along::<(Batch, Const<NUM_HEADS>, TotSeq, Const<HEAD_DIM>), Axis<2>, _>(k);
             let v = cache
                 .1
-                .contiguous()
-                .concat_along::<(Batch, Const<NUM_HEADS>, TotSeq, Const<HEAD_DIM>), Axis<2>, _>(
-                    v.contiguous(),
-                );
+                .concat_along::<(Batch, Const<NUM_HEADS>, TotSeq, Const<HEAD_DIM>), Axis<2>, _>(v);
             (k, v)
         } else {
             (k.realize(), v.realize())
@@ -220,7 +214,7 @@ impl<
         // We only mask on a non-kv cache pass
         if cache.is_none() {
            let attention_mask = self.k_proj.graph().triu::<CurSeq, TotSeq>(1) * f16::MIN.to_f32();
-            w = w + attention_mask.expand();
+            w += attention_mask.expand();
        }
        w = w.softmax::<3>();
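The second hunk replaces w = w + attention_mask.expand() with the compound form w += attention_mask.expand(). A minimal standalone sketch of why the two forms are equivalent, using a hypothetical Tensor stand-in rather than luminal's actual GraphTensor: for a Copy handle type, AddAssign can simply delegate to Add.

// Illustrative sketch only, not luminal's code: a Copy "tensor handle"
// whose += reuses its + implementation.
use std::ops::{Add, AddAssign};

#[derive(Clone, Copy, Debug, PartialEq)]
struct Tensor(f32); // stand-in for a node handle in a lazy graph

impl Add for Tensor {
    type Output = Tensor;
    fn add(self, rhs: Tensor) -> Tensor {
        Tensor(self.0 + rhs.0)
    }
}

impl AddAssign for Tensor {
    fn add_assign(&mut self, rhs: Tensor) {
        *self = *self + rhs; // delegate to the Add impl
    }
}

fn main() {
    let mut w = Tensor(1.0);
    w += Tensor(2.0);
    assert_eq!(w, Tensor(3.0));
}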
10 changes: 9 additions & 1 deletion src/core/shape/symbolic.rs
@@ -66,11 +66,19 @@
 }

 /// A symbolic expression
-#[derive(Clone, Default)]
+#[derive(Clone)]
 pub struct GenericExpression<S: ExpressionStorage> {
     pub terms: S,
 }

+impl<S: ExpressionStorage> Default for GenericExpression<S> {
+    fn default() -> Self {
+        let mut s = S::default();
+        s.push(Term::Num(0));
+        Self { terms: s }
+    }
+}
+
 impl<S: Copy + ExpressionStorage> Copy for GenericExpression<S> {}

 impl<S: PartialEq + ExpressionStorage> PartialEq for GenericExpression<S> {
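With derive(Default), GenericExpression would default to an empty term list, which is not a well-formed expression. The manual impl seeds the storage with Term::Num(0), so a default expression denotes the constant 0; the movement.rs change below relies on exactly this to build zero padding. A standalone sketch of the idea with simplified stand-in types (not the crate's real Term or storage):

// Illustrative sketch: "default" means the constant 0, not an empty list.
#[derive(Clone, Debug, PartialEq)]
#[allow(dead_code)] // Var is unused in this tiny demo
enum Term {
    Num(i32),
    Var(char),
}

#[derive(Clone, Debug, PartialEq)]
struct Expression {
    terms: Vec<Term>,
}

impl Default for Expression {
    fn default() -> Self {
        Expression { terms: vec![Term::Num(0)] } // seed with the constant 0
    }
}

fn main() {
    assert_eq!(Expression::default().terms, vec![Term::Num(0)]);
}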
10 changes: 1 addition & 9 deletions src/hl_ops/binary.rs
@@ -5,7 +5,7 @@ use std::ops::DivAssign;
 use std::ops::MulAssign;
 use std::ops::RemAssign;
 use std::ops::SubAssign;
-use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
+use std::ops::{Add, Div, Mul, Rem, Sub};

 impl<S: Shape> Add for GraphTensor<S> {
     type Output = GraphTensor<S>;
@@ -105,14 +105,6 @@ impl<S: Shape> RemAssign for GraphTensor<S> {
     }
 }

-impl<S: Shape> Neg for GraphTensor<S> {
-    type Output = GraphTensor<S>;
-
-    fn neg(self) -> Self::Output {
-        self * -1.0
-    }
-}
-
 impl<S: Shape> Add<f32> for GraphTensor<S> {
     type Output = GraphTensor<S>;

4 changes: 2 additions & 2 deletions src/hl_ops/movement.rs
@@ -190,9 +190,9 @@ impl<S: Shape> GraphTensor<S> {
     ) -> GraphTensor<Dst> {
         let dim = Ax::as_array()[0] as usize;
         // Create padding
-        let mut a_padding = self.shape.padding;
+        let mut a_padding = vec![(Expression::default(), Expression::default()); self.shape.len()];
         a_padding[dim].1 = rhs.shape.shape()[dim].clone().into();
-        let mut b_padding = rhs.shape.padding;
+        let mut b_padding = vec![(Expression::default(), Expression::default()); rhs.shape.len()];
         b_padding[dim].0 = self.shape.shape()[dim].clone().into();
         // Pad and add
         self.pad(&a_padding) + rhs.pad(&b_padding)
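This concat_along implements concatenation as pad-and-add: the left operand is zero-padded after the concat axis by the right operand's extent, the right operand is zero-padded before it, and the two padded tensors are summed. The change builds fresh all-zero padding vectors (via the new Expression::default(), which now means 0) instead of reusing the shape trackers' existing padding state. A standalone sketch of the pad-and-add trick on plain 1-D vectors, illustrative only and not the crate's API:

// Pad `a` with zeros on the right, pad `b` with zeros on the left, then add
// elementwise; the zero regions never overlap the real data.
fn concat_by_padding(a: &[f32], b: &[f32]) -> Vec<f32> {
    let total = a.len() + b.len();
    let mut padded_a = a.to_vec();
    padded_a.resize(total, 0.0); // b.len() zeros after `a`
    let mut padded_b = vec![0.0; a.len()]; // a.len() zeros before `b`
    padded_b.extend_from_slice(b);
    padded_a.iter().zip(&padded_b).map(|(x, y)| x + y).collect()
}

fn main() {
    assert_eq!(concat_by_padding(&[1.0, 2.0], &[3.0]), vec![1.0, 2.0, 3.0]);
}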
10 changes: 9 additions & 1 deletion src/hl_ops/unary.rs
@@ -1,7 +1,15 @@
-use std::ops::{Add, Mul};
+use std::ops::{Add, Mul, Neg};

 use crate::{op, prelude::*};

+impl<S: Shape> Neg for GraphTensor<S> {
+    type Output = GraphTensor<S>;
+
+    fn neg(self) -> Self::Output {
+        self * -1.0
+    }
+}
+
 impl<S: Shape> GraphTensor<S> {
     /// Base 2 log
     pub fn log_2(self) -> GraphTensor<S> {
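The Neg impl removed from binary.rs above lands here: negation takes a single operand, so unary.rs is its natural home, and it lowers to multiplication by -1.0 rather than introducing a dedicated negate op. A standalone sketch of the same pattern with a simplified stand-in type:

// Illustrative sketch: unary negation expressed via an existing scalar
// multiply, so `-x` reuses the mul path instead of a separate op.
use std::ops::{Mul, Neg};

#[derive(Clone, Copy, Debug, PartialEq)]
struct Tensor(f32);

impl Mul<f32> for Tensor {
    type Output = Tensor;
    fn mul(self, rhs: f32) -> Tensor {
        Tensor(self.0 * rhs)
    }
}

impl Neg for Tensor {
    type Output = Tensor;
    fn neg(self) -> Tensor {
        self * -1.0 // reuse scalar multiply
    }
}

fn main() {
    assert_eq!(-Tensor(2.0), Tensor(-2.0));
}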
