diff --git a/train.py b/train.py index 94e4960cd..29b51bae3 100644 --- a/train.py +++ b/train.py @@ -176,8 +176,8 @@ def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_ # Remove timesteps that we didn't decode at, or are pads # pack_padded_sequence is an easy trick to do this - scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True) - targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True) + scores, *_ = pack_padded_sequence(scores, decode_lengths, batch_first=True) + targets, *_ = pack_padded_sequence(targets, decode_lengths, batch_first=True) # Calculate loss loss = criterion(scores, targets) @@ -267,8 +267,8 @@ def validate(val_loader, encoder, decoder, criterion): # Remove timesteps that we didn't decode at, or are pads # pack_padded_sequence is an easy trick to do this scores_copy = scores.clone() - scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True) - targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True) + scores, *_ = pack_padded_sequence(scores, decode_lengths, batch_first=True) + targets, *_ = pack_padded_sequence(targets, decode_lengths, batch_first=True) # Calculate loss loss = criterion(scores, targets)