Skip to content

Commit

Permalink
Add support for sample_weights in CompileLoss (keras-team#370)
Browse files Browse the repository at this point in the history
* Add support for sample_weights in CompileLoss

* is not None
  • Loading branch information
ianstenbit authored Jun 17, 2023
1 parent 2ae265c commit d953688
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 6 deletions.
24 changes: 20 additions & 4 deletions keras_core/trainers/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,19 +548,35 @@ def build(self, y_true, y_pred):
self.flat_loss_weights = flat_loss_weights
self.built = True

def call(self, y_true, y_pred):
def __call__(self, y_true, y_pred, sample_weight=None):
with ops.name_scope(self.name):
return self.call(y_true, y_pred, sample_weight)

def call(self, y_true, y_pred, sample_weight=None):
if not self.built:
self.build(y_true, y_pred)

y_true = nest.flatten(y_true)
y_pred = nest.flatten(y_pred)

if sample_weight is not None:
sample_weight = nest.flatten(sample_weight)
else:
sample_weight = [None for _ in y_true]

loss_values = []

for loss, y_t, y_p, w in zip(
self.flat_losses, y_true, y_pred, self.flat_loss_weights
for loss, y_t, y_p, loss_weight, sample_weight in zip(
self.flat_losses,
y_true,
y_pred,
self.flat_loss_weights,
sample_weight,
):
if loss:
value = w * ops.cast(loss(y_t, y_p), dtype=backend.floatx())
value = loss_weight * ops.cast(
loss(y_t, y_p, sample_weight), dtype=backend.floatx()
)
loss_values.append(value)
if loss_values:
total_loss = sum(loss_values)
Expand Down
8 changes: 6 additions & 2 deletions keras_core/trainers/compile_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,10 @@ def test_dict_output_case(self, broadcast):
"a": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]),
"b": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]),
}
sample_weight = {
"a": np.array([1.0, 2.0, 3.0]),
"b": np.array([3.0, 2.0, 1.0]),
}
compile_loss.build(y_true, y_pred)
value = compile_loss(y_true, y_pred)
self.assertAllClose(value, 0.953333, atol=1e-5)
value = compile_loss(y_true, y_pred, sample_weight)
self.assertAllClose(value, 1.266666, atol=1e-5)

0 comments on commit d953688

Please sign in to comment.