Skip to content

Commit 659ce64

Browse files
committed
convert concept_value_masks to torch.bool before using it in torch.where
1 parent 54159f9 commit 659ce64

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

Diff for: src/cehrbert/models/hf_models/hf_cehrbert.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ def forward(
528528
if labels is not None:
529529
# Skip the MLM predictions for label concepts if include_value_prediction is disabled
530530
if not self.config.include_value_prediction:
531-
labels = torch.where(concept_value_masks, -100, labels)
531+
labels = torch.where(concept_value_masks.to(torch.bool), -100, labels)
532532
loss_fct = nn.CrossEntropyLoss()
533533
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
534534
total_loss = masked_lm_loss

0 commit comments

Comments
 (0)