Skip to content

Commit

Permalink
Whisper model loads
Browse files Browse the repository at this point in the history
  • Loading branch information
jafioti committed May 11, 2024
1 parent 25f4247 commit 92bf747
Show file tree
Hide file tree
Showing 13 changed files with 115,423 additions and 139 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ rustc-hash = "1.1.0"
uuid = { version = "1.7.0", features = ["v4"] }
as-any = "0.3.1"
egg = "0.9.5"
symbolic_expressions = "5.0.3"

[dev-dependencies]
dfdx = { version = "0.13", features = ["f16"] }
Expand Down
12 changes: 7 additions & 5 deletions crates/luminal_metal/src/elementwise_fusion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,17 +378,18 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
.insert(i, Regex::new(&format!(r"input{i}([^0-9]|$)")).unwrap());
input_regexes.get(&i).unwrap()
};
let (ind, val) = (ind_exp.clone().simplify(), val_exp.clone().simplify());
*subexp = re
.replace_all(
subexp,
&if *val_exp != true {
&if val != true {
format!(
"({} != 0 ? (float)input{i}[{}] : 0.0)$1",
expr_to_metal_string(val_exp),
expr_to_metal_string(ind_exp)
expr_to_metal_string(&val),
expr_to_metal_string(&ind)
)
} else {
format!("(float)input{i}[{}]$1", expr_to_metal_string(ind_exp))
format!("(float)input{i}[{}]$1", expr_to_metal_string(&ind))
},
)
.to_string();
Expand All @@ -407,7 +408,8 @@ impl<T: MetalFloat> Compiler for ElementwiseFusionCompiler<T> {
)
},
)
.0;
.0
.simplify();
if val_exp != true {
*subexp = format!(
"(({} != 0) ? {subexp} : 0.0)",
Expand Down
42 changes: 40 additions & 2 deletions crates/luminal_nn/src/convolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ pub struct Conv1D<
const DILATION: usize = 0,
> {
pub weight: GraphTensor<R3<CH_OUT, CH_IN, KERNEL>>,
pub bias: Option<GraphTensor<R1<CH_OUT>>>,
}

impl<
Expand All @@ -30,6 +31,35 @@ impl<
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
),
bias: None,
}
}
}

impl<
const CH_IN: usize,
const CH_OUT: usize,
const KERNEL: usize,
const STRIDE: usize,
const DILATION: usize,
> Conv1D<CH_IN, CH_OUT, KERNEL, STRIDE, DILATION>
{
pub fn initialize_bias(cx: &mut Graph) -> Self {
// Init weight as uniform(-1, 1)
let mut rng = thread_rng();
Self {
weight: cx.named_tensor("Weight").set(
(0..(CH_IN * CH_OUT * KERNEL))
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
),
bias: Some(
cx.named_tensor("Bias").set(
(0..CH_OUT)
.map(|_| rng.gen_range(-1_f32..1_f32))
.collect::<Vec<_>>(),
),
),
}
}
}
Expand All @@ -44,6 +74,9 @@ impl<
{
fn serialize(&self, s: &mut luminal::module::Serializer) {
s.tensor("weight", self.weight);
if let Some(bias) = self.bias {
s.tensor("bias", bias);
}
}
}

Expand Down Expand Up @@ -127,7 +160,8 @@ impl<
PhantomData<DimOut>,
),
) -> Self::Output {
self.weight
let mut o = self
.weight
.dyn_reshape::<(Const<CH_OUT>, Dyn<'-'>)>(vec![CH_OUT.into(), (CH_IN * KERNEL).into()])
.expand::<(Batch1, Batch2, Const<CH_OUT>, Dyn<'-'>), _>()
.matmul(
Expand All @@ -142,7 +176,11 @@ impl<
(CH_IN * KERNEL).into(),
DimOut::size(),
]),
)
);
if let Some(b) = self.bias {
o += b.expand();
}
o
}
}

Expand Down
7 changes: 7 additions & 0 deletions examples/whisper/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,10 @@ luminal_nn = {path="../../crates/luminal_nn"}
luminal_metal = { path = "../../crates/luminal_metal" }
num-traits = "0.2.18"
num_cpus = "1.16.0"
byteorder = "1.5.0"
memmap2 = "0.9.4"
tokenizers = "0.15.2"
itertools = "0.12.1"
metal-rs = { version = "0.27.0", package = "metal", features = [
"mps",
] }
7 changes: 7 additions & 0 deletions examples/whisper/setup/setup.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/usr/bin/env bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

echo "Downloading Model and Tokenizer..."
curl --location https://huggingface.co/openai/whisper-tiny/resolve/main/tokenizer.json?download=true --output $SCRIPT_DIR/tokenizer.json
curl --location https://huggingface.co/FL33TW00D-HF/whisper-tiny/resolve/main/tiny_f32.gguf?download=true --output $SCRIPT_DIR/whisper-tiny.gguf
echo "Done!"
Loading

0 comments on commit 92bf747

Please sign in to comment.