diff --git a/inferno/extensions/criteria/core.py b/inferno/extensions/criteria/core.py index 733f4332..0e7eb861 100755 --- a/inferno/extensions/criteria/core.py +++ b/inferno/extensions/criteria/core.py @@ -8,7 +8,7 @@ class Criteria(nn.Module): """Aggregate multiple criteria to one.""" - def __init__(self, *criteria): + def __init__(self, *criteria, weights=None): super(Criteria, self).__init__() if len(criteria) == 1 and isinstance(criteria[0], (list, tuple)): criteria = list(criteria[0]) @@ -19,6 +19,12 @@ def __init__(self, *criteria): "Criterion must be a torch module." self.criteria = criteria + if not weights: + weights = (1,) * len(criteria) + assert len(weights) == len(criteria), \ + "weight must be given for every criterion" + self.weights = weights + def forward(self, prediction, target): assert isinstance(prediction, (list, tuple)), \ "`prediction` must be a list or a tuple, got {} instead."\ @@ -30,8 +36,9 @@ def forward(self, prediction, target): "Number of predictions must equal the number of targets. " \ "Got {} predictions but {} targets.".format(len(prediction), len(target)) # Compute losses - losses = [criterion(prediction, target) - for _prediction, _target, criterion in zip(prediction, target, self.criteria)] + losses = [weight * criterion(_prediction, _target) + for weight, _prediction, _target, criterion + in zip(self.weights, prediction, target, self.criteria)] # Aggegate losses loss = reduce(lambda x, y: x + y, losses) # Done