10
10
import pandas as pd
11
11
from datasets import Dataset , DatasetDict , Split
12
12
13
- from .. .data_generators .hf_data_generator .hf_dataset import apply_cehrbert_dataset_mapping
14
- from .. .data_generators .hf_data_generator .hf_dataset_mapping import MedToCehrBertDatasetMapping , birth_codes
15
- from .. .data_generators .hf_data_generator .meds_to_cehrbert_conversion_rules .meds_to_cehrbert_base import (
13
+ from cehrbert .data_generators .hf_data_generator .hf_dataset import apply_cehrbert_dataset_mapping
14
+ from cehrbert .data_generators .hf_data_generator .hf_dataset_mapping import MedToCehrBertDatasetMapping , birth_codes
15
+ from cehrbert .data_generators .hf_data_generator .meds_to_cehrbert_conversion_rules .meds_to_cehrbert_base import (
16
16
MedsToCehrBertConversion ,
17
17
)
18
- from .. .med_extension .schema_extension import CehrBertPatient , Event , Visit
19
- from .. .runners .hf_runner_argument_dataclass import DataTrainingArguments , MedsToCehrBertConversionType
18
+ from cehrbert .med_extension .schema_extension import CehrBertPatient , Event , Visit
19
+ from cehrbert .runners .hf_runner_argument_dataclass import DataTrainingArguments , MedsToCehrBertConversionType
20
20
21
21
UNKNOWN_VALUE = "Unknown"
22
22
DEFAULT_ED_CONCEPT_ID = "9203"
@@ -39,19 +39,48 @@ def get_meds_to_cehrbert_conversion_cls(
39
39
raise RuntimeError (f"{ meds_to_cehrbert_conversion_type } is not a valid MedsToCehrBertConversionType" )
40
40
41
41
42
- def get_patient_split (meds_reader_db_path : str ) -> Dict [str , List [int ]]:
43
- patient_split = pd .read_parquet (os .path .join (meds_reader_db_path , "metadata/patient_splits .parquet" ))
44
- result = {str (group ): records ["patient_id " ].tolist () for group , records in patient_split .groupby ("split" )}
42
+ def get_subject_split (meds_reader_db_path : str ) -> Dict [str , List [int ]]:
43
+ patient_split = pd .read_parquet (os .path .join (meds_reader_db_path , "metadata/subject_splits .parquet" ))
44
+ result = {str (group ): records ["subject_id " ].tolist () for group , records in patient_split .groupby ("split" )}
45
45
return result
46
46
47
47
48
48
class PatientBlock :
49
+ """
50
+ Represents a block of medical events for a single patient visit, including.
51
+
52
+ inferred visit type and various admission and discharge statuses.
53
+
54
+ Attributes:
55
+ visit_id (int): The unique ID of the visit.
56
+ events (List[meds_reader.Event]): A list of medical events associated with this visit.
57
+ min_time (datetime): The earliest event time in the visit.
58
+ max_time (datetime): The latest event time in the visit.
59
+ conversion (MedsToCehrBertConversion): Conversion object for mapping event codes to CEHR-BERT.
60
+ has_ed_admission (bool): Whether the visit includes an emergency department (ED) admission event.
61
+ has_admission (bool): Whether the visit includes an admission event.
62
+ has_discharge (bool): Whether the visit includes a discharge event.
63
+ visit_type (str): The inferred type of visit, such as inpatient, ED, or outpatient.
64
+ """
65
+
49
66
def __init__ (
50
67
self ,
51
68
events : List [meds_reader .Event ],
52
69
visit_id : int ,
53
70
conversion : MedsToCehrBertConversion ,
54
71
):
72
+ """
73
+ Initializes a PatientBlock instance, inferring the visit type based on the events and caching.
74
+
75
+ admission and discharge status.
76
+
77
+ Args:
78
+ events (List[meds_reader.Event]): The medical events associated with the visit.
79
+ visit_id (int): The unique ID of the visit.
80
+ conversion (MedsToCehrBertConversion): Conversion object for mapping event codes to CEHR-BERT.
81
+
82
+ Attributes are initialized to store visit metadata and calculate admission/discharge statuses.
83
+ """
55
84
self .visit_id = visit_id
56
85
self .events = events
57
86
self .min_time = events [0 ].time
@@ -73,28 +102,51 @@ def __init__(
73
102
self .visit_type = DEFAULT_OUTPATIENT_CONCEPT_ID
74
103
75
104
def _has_ed_admission (self ) -> bool :
76
- """Make this configurable in the future."""
105
+ """
106
+ Determines if the visit includes an emergency department (ED) admission event.
107
+
108
+ Returns:
109
+ bool: True if an ED admission event is found, False otherwise.
110
+ """
77
111
for event in self .events :
78
112
for matching_rule in self .conversion .get_ed_admission_matching_rules ():
79
113
if re .match (matching_rule , event .code ):
80
114
return True
81
115
return False
82
116
83
117
def _has_admission (self ) -> bool :
118
+ """
119
+ Determines if the visit includes a hospital admission event.
120
+
121
+ Returns:
122
+ bool: True if an admission event is found, False otherwise.
123
+ """
84
124
for event in self .events :
85
125
for matching_rule in self .conversion .get_admission_matching_rules ():
86
126
if re .match (matching_rule , event .code ):
87
127
return True
88
128
return False
89
129
90
130
def _has_discharge (self ) -> bool :
131
+ """
132
+ Determines if the visit includes a discharge event.
133
+
134
+ Returns:
135
+ bool: True if a discharge event is found, False otherwise.
136
+ """
91
137
for event in self .events :
92
138
for matching_rule in self .conversion .get_discharge_matching_rules ():
93
139
if re .match (matching_rule , event .code ):
94
140
return True
95
141
return False
96
142
97
143
def get_discharge_facility (self ) -> Optional [str ]:
144
+ """
145
+ Extracts the discharge facility code from the discharge event, if present.
146
+
147
+ Returns:
148
+ Optional[str]: The sanitized discharge facility code, or None if no discharge event is found.
149
+ """
98
150
if self ._has_discharge ():
99
151
for event in self .events :
100
152
for matching_rule in self .conversion .get_discharge_matching_rules ():
@@ -105,12 +157,22 @@ def get_discharge_facility(self) -> Optional[str]:
105
157
return None
106
158
107
159
def _convert_event (self , event ) -> List [Event ]:
160
+ """
161
+ Converts a medical event into a list of CEHR-BERT-compatible events, potentially parsing.
162
+
163
+ numeric values from text-based events.
164
+
165
+ Args:
166
+ event (meds_reader.Event): The medical event to be converted.
167
+
168
+ Returns:
169
+ List[Event]: A list of converted events, possibly numeric, based on the original event's code and value.
170
+ """
108
171
code = event .code
109
172
time = getattr (event , "time" , None )
110
173
text_value = getattr (event , "text_value" , None )
111
174
numeric_value = getattr (event , "numeric_value" , None )
112
- # We try to parse the numeric values from the text value, in other words,
113
- # we try to construct numeric events from the event with a text value
175
+
114
176
if numeric_value is None and text_value is not None :
115
177
conversion_rule = self .conversion .get_text_event_to_numeric_events_rule (code )
116
178
if conversion_rule :
@@ -140,14 +202,20 @@ def _convert_event(self, event) -> List[Event]:
140
202
]
141
203
142
204
def get_meds_events (self ) -> Iterable [Event ]:
205
+ """
206
+ Retrieves all medication events for the visit, converting each raw event if necessary.
207
+
208
+ Returns:
209
+ Iterable[Event]: A list of CEHR-BERT-compatible medication events for the visit.
210
+ """
143
211
events = []
144
212
for e in self .events :
145
213
events .extend (self ._convert_event (e ))
146
214
return events
147
215
148
216
149
217
def convert_one_patient (
150
- patient : meds_reader .Patient ,
218
+ patient : meds_reader .Subject ,
151
219
conversion : MedsToCehrBertConversion ,
152
220
default_visit_id : int = 1 ,
153
221
prediction_time : datetime = None ,
@@ -296,10 +364,10 @@ def convert_one_patient(
296
364
age_at_index -= 1
297
365
298
366
# birth_datetime can not be None
299
- assert birth_datetime is not None , f"patient_id: { patient .patient_id } does not have a valid birth_datetime"
367
+ assert birth_datetime is not None , f"patient_id: { patient .subject_id } does not have a valid birth_datetime"
300
368
301
369
return CehrBertPatient (
302
- patient_id = patient .patient_id ,
370
+ patient_id = patient .subject_id ,
303
371
birth_datetime = birth_datetime ,
304
372
visits = visits ,
305
373
race = race if race else UNKNOWN_VALUE ,
@@ -346,7 +414,7 @@ def _meds_to_cehrbert_generator(
346
414
) -> CehrBertPatient :
347
415
conversion = get_meds_to_cehrbert_conversion_cls (meds_to_cehrbert_conversion_type )
348
416
for shard in shards :
349
- with meds_reader .PatientDatabase (path_to_db ) as patient_database :
417
+ with meds_reader .SubjectDatabase (path_to_db ) as patient_database :
350
418
for patient_id , prediction_time , label in shard :
351
419
patient = patient_database [patient_id ]
352
420
yield convert_one_patient (patient , conversion , default_visit_id , prediction_time , label )
@@ -363,20 +431,21 @@ def _create_cehrbert_data_from_meds(
363
431
if data_args .cohort_folder :
364
432
cohort = pd .read_parquet (os .path .join (data_args .cohort_folder , split ))
365
433
for cohort_row in cohort .itertuples ():
366
- patient_id = cohort_row .patient_id
434
+ subject_id = cohort_row .subject_id
367
435
prediction_time = cohort_row .prediction_time
368
436
label = int (cohort_row .boolean_value )
369
- batches .append ((patient_id , prediction_time , label ))
437
+ batches .append ((subject_id , prediction_time , label ))
370
438
else :
371
- patient_split = get_patient_split (data_args .data_folder )
372
- for patient_id in patient_split [split ]:
373
- batches .append ((patient_id , None , None ))
439
+ patient_split = get_subject_split (data_args .data_folder )
440
+ for subject_id in patient_split [split ]:
441
+ batches .append ((subject_id , None , None ))
374
442
375
443
split_batches = np .array_split (np .asarray (batches ), data_args .preprocessing_num_workers )
376
444
batch_func = functools .partial (
377
445
_meds_to_cehrbert_generator ,
378
446
path_to_db = data_args .data_folder ,
379
447
default_visit_id = default_visit_id ,
448
+ meds_to_cehrbert_conversion_type = data_args .meds_to_cehrbert_conversion_type ,
380
449
)
381
450
dataset = Dataset .from_generator (
382
451
batch_func ,
0 commit comments