Skip to content

Commit

Permalink
Fix omniglot double normalization issue (#201)
Browse files Browse the repository at this point in the history
  • Loading branch information
steremma authored Nov 11, 2020
1 parent c402014 commit 9af5ee7
Showing 1 changed file with 0 additions and 2 deletions.
2 changes: 0 additions & 2 deletions examples/vision/maml_omniglot.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,11 @@ def fast_adapt(batch, learner, loss, adaptation_steps, shots, ways, device):
# Adapt the model
for step in range(adaptation_steps):
train_error = loss(learner(adaptation_data), adaptation_labels)
train_error /= len(adaptation_data)
learner.adapt(train_error)

# Evaluate the adapted model
predictions = learner(evaluation_data)
valid_error = loss(predictions, evaluation_labels)
valid_error /= len(evaluation_data)
valid_accuracy = accuracy(predictions, evaluation_labels)
return valid_error, valid_accuracy

Expand Down

0 comments on commit 9af5ee7

Please sign in to comment.