Titanet-Large: Compute EER in every epoch #12881
@nithinraok On the other hand, if I set …
@stevehuang52 do you have a sample manifest showing how the test manifest should look when …
Example of a line in the manifest file when …:
When …
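The exact manifest example from the thread was lost in extraction. Below is a hypothetical sketch of what a paired-trial manifest for speaker verification could look like: one JSON object per line, with two audio paths and a binary label (1 = same speaker, 0 = different speaker). The field names (`audio_filepath`, `audio_filepath2`, `duration`, `duration2`, `label`) are assumptions for illustration, not confirmed by the thread.

```python
import json

# Hypothetical paired-trial entries; field names are an assumption.
trials = [
    {"audio_filepath": "/data/spk1_a.wav", "duration": 2.5,
     "audio_filepath2": "/data/spk1_b.wav", "duration2": 3.0, "label": 1},
    {"audio_filepath": "/data/spk1_a.wav", "duration": 2.5,
     "audio_filepath2": "/data/spk2_a.wav", "duration2": 2.8, "label": 0},
]

# Write one JSON object per line (JSONL), the usual manifest layout.
with open("test_manifest.json", "w") as f:
    for t in trials:
        f.write(json.dumps(t) + "\n")
```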
Thanks @stevehuang52, I will try it.
Yes, you can pass them as a list of manifests: …
Great. And from the …
Yes, you can directly specify them in the YAML file.
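For reference, specifying multiple test manifests in the YAML file might look like the following. This is a sketch; the exact key path (`model.test_ds.manifest_filepath`) is an assumption based on common NeMo config layout, so check it against your own titanet-large.yaml.

```yaml
model:
  test_ds:
    manifest_filepath:   # a list of manifests instead of a single path
      - /data/test_manifest_1.json
      - /data/test_manifest_2.json
```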
Hi @stevehuang52, it seems OK now, but I get this error: …
It looks like a shape-incompatibility issue. Do you think the manifest file is correct? Are the label shapes correct? Do I need to convert them to one-hot vectors?
On the other hand, if I set …
CUDA OOM also depends on the length of each of your audio samples. Keep them <= 3 sec.
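Since the OOM discussion hinges on audio lengths, here is a small helper for collecting duration statistics over a set of WAV files before training. It uses only the standard-library `wave` module and assumes uncompressed PCM WAV input; it is a diagnostic sketch, not part of the NeMo codebase.

```python
import wave

def wav_duration(path):
    """Return the duration of a PCM WAV file in seconds."""
    with wave.open(path, "rb") as w:
        return w.getnframes() / float(w.getframerate())

def duration_stats(paths):
    """Return (min, mean, max) duration in seconds over a list of WAV files."""
    durs = [wav_duration(p) for p in paths]
    return min(durs), sum(durs) / len(durs), max(durs)
```

Running this over the training manifest quickly shows whether any clips exceed the ~3 s budget suggested above.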
Hi @ukemamaster, regarding the macro-accuracy error: there is a bug in the model code, which will be fixed by this PR: #12908. Regarding the OOM error, that is probably due to the lengths of your audio clips; could you please share statistics on your audio lengths? It would also help to know where in the code the OOM occurred, since the classification layer may take a lot of GPU memory if you have a huge number of speakers and long audio. As @nithinraok suggests, we normally use audio shorter than 3 s during training with …
Hi @stevehuang52, after incorporating the PR I still get the same error. Printing the … If I re-initialize the …
It proceeds to training correctly, but is it safe to do that? The method looks like:

```python
def pair_evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'):
    audio_signal_1, audio_signal_len_1, audio_signal_2, audio_signal_len_2, labels, _ = batch
    _, audio_emb1 = self.forward(input_signal=audio_signal_1, input_signal_length=audio_signal_len_1)
    _, audio_emb2 = self.forward(input_signal=audio_signal_2, input_signal_length=audio_signal_len_2)

    # convert binary labels to -1, 1
    loss_labels = (labels.float() - 0.5) * 2
    cosine_sim = torch.cosine_similarity(audio_emb1, audio_emb2)
    loss_value = torch.nn.functional.mse_loss(cosine_sim, loss_labels)

    logits = torch.stack([1 - cosine_sim, cosine_sim], dim=-1)
    acc_top_k = self._accuracy(logits=logits, labels=labels)

    # ####### re-initialize self._macro_accuracy #######
    self._macro_accuracy = Accuracy(
        num_classes=2, top_k=1, average='macro', task='multiclass'
    ).to(labels.get_device())

    correct_counts, total_counts = self._accuracy.correct_counts_k, self._accuracy.total_counts_k
    self._macro_accuracy.update(preds=logits, target=labels)
    stats = self._macro_accuracy._final_state()

    output = {
        f'{tag}_loss': loss_value,
        f'{tag}_correct_counts': correct_counts,
        f'{tag}_total_counts': total_counts,
        f'{tag}_acc_micro_top_k': acc_top_k,
        f'{tag}_acc_macro_stats': stats,
        f"{tag}_scores": cosine_sim,
        f"{tag}_labels": labels,
    }

    if tag == 'val':
        if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
            self.validation_step_outputs[dataloader_idx].append(output)
        else:
            self.validation_step_outputs.append(output)
    else:
        if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
            self.test_step_outputs[dataloader_idx].append(output)
        else:
            self.test_step_outputs.append(output)

    return output
```
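To make the two small transformations in that method concrete, here is a minimal pure-Python sketch (no torch) of the label mapping and the two-class logit construction; the function names are mine, not from the NeMo source.

```python
def to_pm_one(label):
    # Map a binary label {0, 1} to {-1.0, +1.0}, mirroring
    # (labels.float() - 0.5) * 2 from the method above.
    return (float(label) - 0.5) * 2

def pair_logits(cosine_sim):
    # Build two-class scores [score(different), score(same)] from a
    # cosine similarity, mirroring torch.stack([1 - sim, sim], dim=-1).
    return [1 - cosine_sim, cosine_sim]
```

With labels mapped to ±1, minimizing the MSE against the cosine similarity pushes same-speaker pairs toward similarity 1 and different-speaker pairs toward -1.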
Hi @ukemamaster, re-initializing the metric in the validation step is probably fine in this particular case. Meanwhile, I just updated the PR to use a separate metric class …
@stevehuang52 With the updated PR the training goes fine. One question regarding the EER: in the logs, the …
@ukemamaster In the paired-eval case, the val_loss is the MSE loss between the predicted cosine similarity and the ground-truth labels converted to -1 and 1. For saving checkpoints based on the EER value, you need to set … If you only need to monitor the EER but still save checkpoints based on …
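For illustration, checkpointing on a validation EER metric in a NeMo-style config might look like the sketch below. The `exp_manager.checkpoint_callback_params` block is standard in NeMo configs, but the metric name `val_eer` is an assumption; use whatever name your model actually logs.

```yaml
exp_manager:
  checkpoint_callback_params:
    monitor: val_eer   # assumed metric name; match what your model logs
    mode: min          # lower EER is better, so save the minimum
```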
Thanks @stevehuang52.
Hi @nithinraok, to compute the test EER after every epoch, I have to set `is_audio_pair: true` in the titanet-large.yaml file. What should the corresponding test data manifest file look like?