We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 54159f9 commit 659ce64Copy full SHA for 659ce64
src/cehrbert/models/hf_models/hf_cehrbert.py
@@ -528,7 +528,7 @@ def forward(
528
if labels is not None:
529
# Skip the MLM predictions for label concepts if include_value_prediction is disabled
530
if not self.config.include_value_prediction:
531
- labels = torch.where(concept_value_masks, -100, labels)
+ labels = torch.where(concept_value_masks.to(torch.bool), -100, labels)
532
loss_fct = nn.CrossEntropyLoss()
533
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
534
total_loss = masked_lm_loss
0 commit comments