Skip to content

Commit fa6d0ae

Browse files
authored
Used the existing subject_split to split the cohort automatically instead of doing it manually (#63)
1 parent 1517c3d commit fa6d0ae

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/cehrbert/data_generators/hf_data_generator/meds_utils.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -436,8 +436,12 @@ def _create_cehrbert_data_from_meds(
436 436       assert split in ["held_out", "train", "tuning"]
437 437       batches = []
438 438       if data_args.cohort_folder:
439     -         cohort = pd.read_parquet(os.path.join(os.path.expanduser(data_args.cohort_folder), split))
440     -         for cohort_row in cohort.itertuples():
    439 +         # Load the entire cohort
    440 +         cohort = pd.read_parquet(os.path.expanduser(data_args.cohort_folder))
    441 +         patient_split = get_subject_split(os.path.expanduser(data_args.data_folder))
    442 +         subject_ids = patient_split[split]
    443 +         cohort_split = cohort[cohort.subject_id.isin(subject_ids)]
    444 +         for cohort_row in cohort_split.itertuples():
441 445               subject_id = cohort_row.subject_id
442 446               prediction_time = cohort_row.prediction_time
443 447               label = int(cohort_row.boolean_value)

0 commit comments

Comments
 (0)