Fix #218
- OpenELM had optional layers that were always created
- see #214
davidkoski committed Mar 7, 2025
1 parent 06c825a commit 8d30ead
Showing 1 changed file with 13 additions and 15 deletions.
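
The change follows the usual MLX Swift pattern for conditional submodules: declare them as optionals and create them only when the configuration asks for them, so their parameters never enter the module tree when unused. Below is a minimal, self-contained sketch of that pattern as it relates to the diff that follows; the TiedHead type and its parameter names are hypothetical illustrations, not code from this commit:

import MLX
import MLXNN

// Hypothetical example of the pattern applied in this commit.
class TiedHead: Module {
    @ModuleInfo(key: "embed_tokens") var embedTokens: Embedding

    // Optional AND conditionally created. If the Linear were always
    // instantiated (the bug fixed here), its parameters would sit in the
    // module tree even when the checkpoint carries no lm_head weights.
    @ModuleInfo(key: "lm_head") var lmHead: Linear?

    init(vocabularySize: Int, dimensions: Int, shareInputOutputLayers: Bool) {
        self._embedTokens.wrappedValue = Embedding(
            embeddingCount: vocabularySize, dimensions: dimensions)
        if !shareInputOutputLayers {
            self._lmHead.wrappedValue = Linear(dimensions, vocabularySize, bias: false)
        }
    }

    func callAsFunction(_ hidden: MLXArray) -> MLXArray {
        if let lmHead {
            return lmHead(hidden)
        } else {
            // Tied weights: reuse the embedding matrix as the output projection.
            return matmul(hidden, embedTokens.weight.T)
        }
    }
}

With this shape the optional itself records whether the layer exists, so the call site can use if let instead of consulting a separate configuration flag.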
28 changes: 13 additions & 15 deletions Libraries/MLXLLM/Models/OpenELM.swift
@@ -27,7 +27,6 @@ func makeDivisible(_ v: Float, divisor: Int = 8, minValue: Float? = nil) -> Int
 }
 
 private class MultiHeadCausalAttention: Module {
-    var args: OpenElmConfiguration
     let scale: Float
     let heads: Int
     let headDim: Int
@@ -36,18 +35,17 @@ private class MultiHeadCausalAttention: Module {
     @ModuleInfo(key: "qkv_proj") var qkvProj: Linear
     @ModuleInfo(key: "out_proj") var outProj: Linear
 
-    @ModuleInfo(key: "q_norm") var qNorm: RMSNorm
-    @ModuleInfo(key: "k_norm") var kNorm: RMSNorm
+    @ModuleInfo(key: "q_norm") var qNorm: RMSNorm?
+    @ModuleInfo(key: "k_norm") var kNorm: RMSNorm?
 
     let rope: RoPE
 
     public init(_ args: OpenElmConfiguration, layerId: Int) {
-        self.args = args
         self.headDim = args.headDimensions
         let modelDim = args.modelDim
 
-        self.heads = self.args.numQueryHeads[layerId]
-        self.kvHeads = self.args.kvHeads[layerId]
+        self.heads = args.numQueryHeads[layerId]
+        self.kvHeads = args.kvHeads[layerId]
         self.scale = pow(Float(headDim), -0.5)
 
         let opSize = (heads + (kvHeads * 2)) * headDim
@@ -74,7 +72,7 @@ private class MultiHeadCausalAttention: Module {
         var keys = qkvSplit[1]
         var values = qkvSplit[2]
 
-        if args.normalizeQkProjections {
+        if let qNorm, let kNorm {
             queries = qNorm(queries)
             keys = kNorm(keys)
         }
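
For the new check (if let qNorm, let kNorm) to be equivalent to the old flag test, the remainder of init, which is elided from the hunks shown here, must create the two norms under the same condition. The following sketch of that guard is an assumption keyed off args.normalizeQkProjections, not part of the displayed diff:

// Assumed shape of the conditional creation elsewhere in init
// (not visible in the hunks shown here):
if args.normalizeQkProjections {
    self._qNorm.wrappedValue = RMSNorm(dimensions: headDim)
    self._kNorm.wrappedValue = RMSNorm(dimensions: headDim)
}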
@@ -181,27 +179,27 @@ public class OpenELMModel: Module, LLMModel, KVCacheDimensionProvider {
     public let vocabularySize: Int
     public let kvHeads: [Int]
 
-    let shareInputOutputLayers: Bool
     let transformer: OpenELMModelInner
 
-    @ModuleInfo(key: "lm_head") var lmHead: Linear
+    @ModuleInfo(key: "lm_head") var lmHead: Linear?
 
     public init(_ args: OpenElmConfiguration) {
         self.vocabularySize = args.vocabularySize
         self.kvHeads = args.kvHeads
 
         self.transformer = OpenELMModelInner(args)
-        self.shareInputOutputLayers = args.shareInputOutputLayers
-        self._lmHead.wrappedValue = Linear(
-            args.numTransformerLayers, args.vocabularySize, bias: false)
+        if !args.shareInputOutputLayers {
+            self._lmHead.wrappedValue = Linear(
+                args.numTransformerLayers, args.vocabularySize, bias: false)
+        }
     }
 
     public func callAsFunction(_ inputs: MLXArray, cache: [KVCache]?) -> MLXArray {
         var out = transformer(inputs, cache: cache)
-        if shareInputOutputLayers {
-            out = matmul(out, transformer.embedTokens.weight.T)
-        } else {
+        if let lmHead {
             out = lmHead(out)
+        } else {
+            out = matmul(out, transformer.embedTokens.weight.T)
         }
 
         return out
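
Why the unconditional creation was a problem (see #214): a parameter that exists in the model but not in the checkpoint makes weight loading complain or leaves the layer uninitialized, which is presumably the failure behind #218. The difference shows up directly in the parameter tree; a usage sketch against the hypothetical TiedHead above, with made-up sizes:

// Usage sketch (hypothetical sizes). With tied weights the module tree
// contains no lm_head entry at all, so there is nothing extra to load.
let tied = TiedHead(vocabularySize: 32_000, dimensions: 1_280, shareInputOutputLayers: true)
print(tied.parameters().flattened().map { $0.0 })
// expected: ["embed_tokens.weight"]

let untied = TiedHead(vocabularySize: 32_000, dimensions: 1_280, shareInputOutputLayers: false)
print(untied.parameters().flattened().map { $0.0 })
// expected: ["embed_tokens.weight", "lm_head.weight"]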
