From d95368890b97425b02bf2c913db49c3e88b387ad Mon Sep 17 00:00:00 2001 From: Ian Stenbit <3072903+ianstenbit@users.noreply.github.com> Date: Sat, 17 Jun 2023 11:51:55 -0600 Subject: [PATCH] Add support for sample_weights in CompileLoss (#370) * Add support for sample_weights in CompileLoss * is not None --- keras_core/trainers/compile_utils.py | 24 +++++++++++++++++++---- keras_core/trainers/compile_utils_test.py | 8 ++++++-- 2 files changed, 26 insertions(+), 6 deletions(-) diff --git a/keras_core/trainers/compile_utils.py b/keras_core/trainers/compile_utils.py index 647e09ec94..6a60e09188 100644 --- a/keras_core/trainers/compile_utils.py +++ b/keras_core/trainers/compile_utils.py @@ -548,19 +548,35 @@ def build(self, y_true, y_pred): self.flat_loss_weights = flat_loss_weights self.built = True - def call(self, y_true, y_pred): + def __call__(self, y_true, y_pred, sample_weight=None): + with ops.name_scope(self.name): + return self.call(y_true, y_pred, sample_weight) + + def call(self, y_true, y_pred, sample_weight=None): if not self.built: self.build(y_true, y_pred) y_true = nest.flatten(y_true) y_pred = nest.flatten(y_pred) + + if sample_weight is not None: + sample_weight = nest.flatten(sample_weight) + else: + sample_weight = [None for _ in y_true] + loss_values = [] - for loss, y_t, y_p, w in zip( - self.flat_losses, y_true, y_pred, self.flat_loss_weights + for loss, y_t, y_p, loss_weight, sample_weight in zip( + self.flat_losses, + y_true, + y_pred, + self.flat_loss_weights, + sample_weight, ): if loss: - value = w * ops.cast(loss(y_t, y_p), dtype=backend.floatx()) + value = loss_weight * ops.cast( + loss(y_t, y_p, sample_weight), dtype=backend.floatx() + ) loss_values.append(value) if loss_values: total_loss = sum(loss_values) diff --git a/keras_core/trainers/compile_utils_test.py b/keras_core/trainers/compile_utils_test.py index 0fe4085458..95e4827993 100644 --- a/keras_core/trainers/compile_utils_test.py +++ b/keras_core/trainers/compile_utils_test.py @@ -270,6 +270,10 @@ def test_dict_output_case(self, broadcast): "a": np.array([[1.2, 1.1], [1.0, 0.9], [0.8, 0.7]]), "b": np.array([[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]), } + sample_weight = { + "a": np.array([1.0, 2.0, 3.0]), + "b": np.array([3.0, 2.0, 1.0]), + } compile_loss.build(y_true, y_pred) - value = compile_loss(y_true, y_pred) - self.assertAllClose(value, 0.953333, atol=1e-5) + value = compile_loss(y_true, y_pred, sample_weight) + self.assertAllClose(value, 1.266666, atol=1e-5)