From a94a8391cfb3ba9957f0c0f7ba13f2e468a3246c Mon Sep 17 00:00:00 2001 From: Tony Hung Date: Mon, 3 Aug 2020 14:55:06 -0400 Subject: [PATCH] Update train.py Fixed crash when trying to unpack values from pack_padded_sequence function --- train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 94e4960cd..29b51bae3 100644 --- a/train.py +++ b/train.py @@ -176,8 +176,8 @@ def train(train_loader, encoder, decoder, criterion, encoder_optimizer, decoder_ # Remove timesteps that we didn't decode at, or are pads # pack_padded_sequence is an easy trick to do this - scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True) - targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True) + scores, *_ = pack_padded_sequence(scores, decode_lengths, batch_first=True) + targets, *_ = pack_padded_sequence(targets, decode_lengths, batch_first=True) # Calculate loss loss = criterion(scores, targets) @@ -267,8 +267,8 @@ def validate(val_loader, encoder, decoder, criterion): # Remove timesteps that we didn't decode at, or are pads # pack_padded_sequence is an easy trick to do this scores_copy = scores.clone() - scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True) - targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True) + scores, *_ = pack_padded_sequence(scores, decode_lengths, batch_first=True) + targets, *_ = pack_padded_sequence(targets, decode_lengths, batch_first=True) # Calculate loss loss = criterion(scores, targets)