File tree Expand file tree Collapse file tree 1 file changed +6
-2
lines changed
src/cehrbert/data_generators/hf_data_generator Expand file tree Collapse file tree 1 file changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -436,8 +436,12 @@ def _create_cehrbert_data_from_meds(
436
436
assert split in ["held_out" , "train" , "tuning" ]
437
437
batches = []
438
438
if data_args .cohort_folder :
439
- cohort = pd .read_parquet (os .path .join (os .path .expanduser (data_args .cohort_folder ), split ))
440
- for cohort_row in cohort .itertuples ():
439
+ # Load the entire cohort
440
+ cohort = pd .read_parquet (os .path .expanduser (data_args .cohort_folder ))
441
+ patient_split = get_subject_split (os .path .expanduser (data_args .data_folder ))
442
+ subject_ids = patient_split [split ]
443
+ cohort_split = cohort [cohort .subject_id .isin (subject_ids )]
444
+ for cohort_row in cohort_split .itertuples ():
441
445
subject_id = cohort_row .subject_id
442
446
prediction_time = cohort_row .prediction_time
443
447
label = int (cohort_row .boolean_value )
You can’t perform that action at this time.
0 commit comments