Skip to content

Commit b37be22

Browse files
committed
Unify behavior of self.losses in compute_loss across backends.
1 parent 04cad40 commit b37be22

File tree

3 files changed: +35 additions, −4 deletions

3 files changed: +35 additions, −4 deletions

keras/backend/jax/trainer.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ def compute_loss_and_updates(
5151
return_losses=True,
5252
**kwargs,
5353
)
54+
if losses:
55+
# Make forward pass losses available to compute_loss.
56+
self._losses_override.clear()
57+
self._losses_override = losses
5458

5559
loss, variables = self.stateless_compute_loss(
5660
trainable_variables,
@@ -61,14 +65,12 @@ def compute_loss_and_updates(
6165
y_pred=y_pred,
6266
sample_weight=sample_weight,
6367
)
68+
if losses:
69+
self._losses_override.clear()
6470
(trainable_variables, non_trainable_variables, metrics_variables) = (
6571
variables
6672
)
6773

68-
# Sum forward pass losses
69-
if losses:
70-
loss += ops.sum(losses)
71-
7274
# Handle loss scaling
7375
unscaled_loss = loss
7476
if training and self.optimizer is not None:

keras/layers/layer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,7 @@ def __init__(
274274
self._trainable = trainable
275275
self._losses = []
276276
self._loss_ids = set()
277+
self._losses_override = []
277278

278279
self._call_signature = inspect.signature(self.call)
279280
call_signature_parameters = [
@@ -1091,6 +1092,8 @@ def _get_regularization_losses(self):
10911092
@property
10921093
def losses(self):
10931094
"""List of scalar losses from `add_loss`, regularizers and sublayers."""
1095+
if self._losses_override:
1096+
return self._losses_override
10941097
losses = self._get_own_losses()
10951098
for layer in self._flatten_layers(include_self=False):
10961099
losses.extend(layer._get_own_losses())

keras/trainers/trainer_test.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1319,3 +1319,29 @@ def compute_loss(
13191319
self.assertAlmostEqual(
13201320
history.history["custom"][0], history.history["loss"][0] * 4
13211321
)
1322+
1323+
@pytest.mark.requires_trainable_backend
1324+
def test_fwd_pass_loss_presence_in_compute_loss(self):
1325+
1326+
class MyModel(keras.Model):
1327+
def __init__(self):
1328+
super().__init__()
1329+
self.custom_metric = keras.metrics.Mean(name="custom")
1330+
self.dense = keras.layers.Dense(2, activity_regularizer="l2")
1331+
1332+
def call(self, x):
1333+
return self.dense(x)
1334+
1335+
def compute_loss(
1336+
self, x=None, y=None, y_pred=None, sample_weight=None
1337+
):
1338+
loss = super().compute_loss(x, y, y_pred, sample_weight)
1339+
self.custom_metric.update_state(sum(self.losses))
1340+
return loss
1341+
1342+
model = MyModel()
1343+
model.compile(optimizer="sgd", loss="mse")
1344+
x = np.ones((32, 4))
1345+
y = np.ones((32, 2)) * 2
1346+
history = model.fit(x, y)
1347+
self.assertGreater(history.history["custom"][0], 0.0)

0 commit comments

Comments (0)