Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
- do not fail parameter update validation for "invalid" keys (e.g. _freqs)
  • Loading branch information
davidkoski committed Mar 7, 2025
1 parent 0e40fa0 commit d93271b
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 6 deletions.
24 changes: 18 additions & 6 deletions Source/MLXNN/Module.swift
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,11 @@ open class Module {
p.update(newArray)

case (.value(.parameters(let p)), .none):
throw UpdateError.keyNotFound(base: describeType(self), key: key)
if Self.parameterIsValid(key) {
throw UpdateError.keyNotFound(base: describeType(self), key: key)
} else {
// ignore it -- this isn't a parameter that requires update
}

case (.array(let array), .array(let values)):
for (i, (arrayItem, valueItem)) in zip(array, values).enumerated() {
Expand Down Expand Up @@ -930,6 +934,14 @@ public protocol UnaryLayer {

extension Module {

/// Return `true` if the given parameter name is valid -- should be considered for
/// validation, enumeration, etc.
///
/// Specifically this will filter out parameters with keys starting with `_`.
static public func parameterIsValid(_ key: String) -> Bool {
!key.hasPrefix("_")
}

/// Filter that will accept all values.
///
/// ### See Also
Expand Down Expand Up @@ -962,8 +974,8 @@ extension Module {
static public let filterValidParameters: @Sendable (Module, String, ModuleItem) -> Bool = {
(module: Module, key: String, item: ModuleItem) in
switch item {
case .array, .dictionary: !key.hasPrefix("_")
case .value(.parameters), .value(.module): !key.hasPrefix("_")
case .array, .dictionary: parameterIsValid(key)
case .value(.parameters), .value(.module): parameterIsValid(key)
default: false
}
}
Expand All @@ -978,8 +990,8 @@ extension Module {
static public let filterLocalParameters: @Sendable (Module, String, ModuleItem) -> Bool = {
(module: Module, key: String, item: ModuleItem) in
switch item {
case .array, .dictionary: !key.hasPrefix("_")
case .value(.parameters): !key.hasPrefix("_")
case .array, .dictionary: parameterIsValid(key)
case .value(.parameters): parameterIsValid(key)
default: false
}
}
Expand All @@ -996,7 +1008,7 @@ extension Module {
(module: Module, key: String, item: ModuleItem) in
switch item {
case .array, .dictionary, .value(.parameters), .value(.module):
!key.hasPrefix("_") && !module.noGrad.contains(key)
parameterIsValid(key) && !module.noGrad.contains(key)
default: false
}
}
Expand Down
23 changes: 23 additions & 0 deletions Tests/MLXTests/ModuleTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -770,4 +770,27 @@ class ModuleTests: XCTestCase {
XCTAssertTrue(pm.mlp.0 is QuantizedLinear)
XCTAssertTrue(pm.mlp.2 is QuantizedLinear)
}

func testModulesWithMLXArrayProperties() throws {
// https://github.com/ml-explore/mlx-swift-examples/issues/218

class SuScaledRotaryEmbedding: Module {
let _freqs: MLXArray

override init() {
_freqs = MLXArray(7)
}
}

let rope = SuScaledRotaryEmbedding()

// no parameters
XCTAssertEqual(rope.parameters().count, 0)

// but it can see the _freqs property
XCTAssertEqual(rope.items().count, 1)

// this should not throw because _freqs is not considered
try rope.update(parameters: .init(), verify: .all)
}
}

0 comments on commit d93271b

Please sign in to comment.