Skip to content

Commit c8410e0

Browse files
authored
Updated meds_reader to 0.1.9 (#51)
* Updated meds_reader to 0.1.9
* Switched back from RuntimeError to Exception for loading the meds data
* Changed patient_id to subject_id in meds_utils.py file
* Used fully qualified names for the imports instead of using their relative paths
* Added the missing positional parameter for _create_cehrbert_data_from_meds
* Restored the AttType cehr_bert that's automatically updated by the code formatter
* Restored the enum type and MedsToCehrBertConversion class check
1 parent 3d09645 commit c8410e0

File tree

6 files changed

+96
-27
lines changed

6 files changed

+96
-27
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ dependencies = [
3434
"femr==0.2.0",
3535
"Jinja2==3.1.3",
3636
"meds==0.3.3",
37-
"meds_reader==0.1.1",
37+
"meds_reader==0.1.9",
3838
"networkx==3.2.1",
3939
"numpy==1.24.3",
4040
"packaging==23.2",

src/cehrbert/data_generators/hf_data_generator/hf_dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
from datasets import Dataset, DatasetDict, IterableDataset, IterableDatasetDict
44

5-
from ...data_generators.hf_data_generator.hf_dataset_mapping import (
5+
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import (
66
DatasetMapping,
77
HFFineTuningMapping,
88
HFTokenizationMapping,
99
SortPatientSequenceMapping,
1010
)
11-
from ...models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
12-
from ...runners.hf_runner_argument_dataclass import DataTrainingArguments
11+
from cehrbert.models.hf_models.tokenization_hf_cehrbert import CehrBertTokenizer
12+
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments
1313

1414
CEHRBERT_COLUMNS = [
1515
"concept_ids",

src/cehrbert/data_generators/hf_data_generator/meds_to_cehrbert_conversion_rules/meds_to_cehrbert_micmic4.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import re
22
from typing import List
33

4-
from ....data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import (
4+
from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import (
55
EventConversionRule,
66
MedsToCehrBertConversion,
77
)

src/cehrbert/data_generators/hf_data_generator/meds_utils.py

Lines changed: 89 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
import pandas as pd
1111
from datasets import Dataset, DatasetDict, Split
1212

13-
from ...data_generators.hf_data_generator.hf_dataset import apply_cehrbert_dataset_mapping
14-
from ...data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping, birth_codes
15-
from ...data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import (
13+
from cehrbert.data_generators.hf_data_generator.hf_dataset import apply_cehrbert_dataset_mapping
14+
from cehrbert.data_generators.hf_data_generator.hf_dataset_mapping import MedToCehrBertDatasetMapping, birth_codes
15+
from cehrbert.data_generators.hf_data_generator.meds_to_cehrbert_conversion_rules.meds_to_cehrbert_base import (
1616
MedsToCehrBertConversion,
1717
)
18-
from ...med_extension.schema_extension import CehrBertPatient, Event, Visit
19-
from ...runners.hf_runner_argument_dataclass import DataTrainingArguments, MedsToCehrBertConversionType
18+
from cehrbert.med_extension.schema_extension import CehrBertPatient, Event, Visit
19+
from cehrbert.runners.hf_runner_argument_dataclass import DataTrainingArguments, MedsToCehrBertConversionType
2020

2121
UNKNOWN_VALUE = "Unknown"
2222
DEFAULT_ED_CONCEPT_ID = "9203"
@@ -39,19 +39,48 @@ def get_meds_to_cehrbert_conversion_cls(
3939
raise RuntimeError(f"{meds_to_cehrbert_conversion_type} is not a valid MedsToCehrBertConversionType")
4040

4141

42-
def get_patient_split(meds_reader_db_path: str) -> Dict[str, List[int]]:
43-
patient_split = pd.read_parquet(os.path.join(meds_reader_db_path, "metadata/patient_splits.parquet"))
44-
result = {str(group): records["patient_id"].tolist() for group, records in patient_split.groupby("split")}
42+
def get_subject_split(meds_reader_db_path: str) -> Dict[str, List[int]]:
43+
patient_split = pd.read_parquet(os.path.join(meds_reader_db_path, "metadata/subject_splits.parquet"))
44+
result = {str(group): records["subject_id"].tolist() for group, records in patient_split.groupby("split")}
4545
return result
4646

4747

4848
class PatientBlock:
49+
"""
50+
Represents a block of medical events for a single patient visit.
51+
52+
Includes the inferred visit type and various admission and discharge statuses.
53+
54+
Attributes:
55+
visit_id (int): The unique ID of the visit.
56+
events (List[meds_reader.Event]): A list of medical events associated with this visit.
57+
min_time (datetime): The earliest event time in the visit.
58+
max_time (datetime): The latest event time in the visit.
59+
conversion (MedsToCehrBertConversion): Conversion object for mapping event codes to CEHR-BERT.
60+
has_ed_admission (bool): Whether the visit includes an emergency department (ED) admission event.
61+
has_admission (bool): Whether the visit includes an admission event.
62+
has_discharge (bool): Whether the visit includes a discharge event.
63+
visit_type (str): The inferred type of visit, such as inpatient, ED, or outpatient.
64+
"""
65+
4966
def __init__(
5067
self,
5168
events: List[meds_reader.Event],
5269
visit_id: int,
5370
conversion: MedsToCehrBertConversion,
5471
):
72+
"""
73+
Initializes a PatientBlock instance.
74+
75+
Infers the visit type based on the events and caches the admission and discharge statuses.
76+
77+
Args:
78+
events (List[meds_reader.Event]): The medical events associated with the visit.
79+
visit_id (int): The unique ID of the visit.
80+
conversion (MedsToCehrBertConversion): Conversion object for mapping event codes to CEHR-BERT.
81+
82+
Attributes are initialized to store visit metadata and calculate admission/discharge statuses.
83+
"""
5584
self.visit_id = visit_id
5685
self.events = events
5786
self.min_time = events[0].time
@@ -73,28 +102,51 @@ def __init__(
73102
self.visit_type = DEFAULT_OUTPATIENT_CONCEPT_ID
74103

75104
def _has_ed_admission(self) -> bool:
76-
"""Make this configurable in the future."""
105+
"""
106+
Determines if the visit includes an emergency department (ED) admission event.
107+
108+
Returns:
109+
bool: True if an ED admission event is found, False otherwise.
110+
"""
77111
for event in self.events:
78112
for matching_rule in self.conversion.get_ed_admission_matching_rules():
79113
if re.match(matching_rule, event.code):
80114
return True
81115
return False
82116

83117
def _has_admission(self) -> bool:
118+
"""
119+
Determines if the visit includes a hospital admission event.
120+
121+
Returns:
122+
bool: True if an admission event is found, False otherwise.
123+
"""
84124
for event in self.events:
85125
for matching_rule in self.conversion.get_admission_matching_rules():
86126
if re.match(matching_rule, event.code):
87127
return True
88128
return False
89129

90130
def _has_discharge(self) -> bool:
131+
"""
132+
Determines if the visit includes a discharge event.
133+
134+
Returns:
135+
bool: True if a discharge event is found, False otherwise.
136+
"""
91137
for event in self.events:
92138
for matching_rule in self.conversion.get_discharge_matching_rules():
93139
if re.match(matching_rule, event.code):
94140
return True
95141
return False
96142

97143
def get_discharge_facility(self) -> Optional[str]:
144+
"""
145+
Extracts the discharge facility code from the discharge event, if present.
146+
147+
Returns:
148+
Optional[str]: The sanitized discharge facility code, or None if no discharge event is found.
149+
"""
98150
if self._has_discharge():
99151
for event in self.events:
100152
for matching_rule in self.conversion.get_discharge_matching_rules():
@@ -105,12 +157,22 @@ def get_discharge_facility(self) -> Optional[str]:
105157
return None
106158

107159
def _convert_event(self, event) -> List[Event]:
160+
"""
161+
Converts a medical event into a list of CEHR-BERT-compatible events.
162+
163+
May parse numeric values from text-based events.
164+
165+
Args:
166+
event (meds_reader.Event): The medical event to be converted.
167+
168+
Returns:
169+
List[Event]: A list of converted events, possibly numeric, based on the original event's code and value.
170+
"""
108171
code = event.code
109172
time = getattr(event, "time", None)
110173
text_value = getattr(event, "text_value", None)
111174
numeric_value = getattr(event, "numeric_value", None)
112-
# We try to parse the numeric values from the text value, in other words,
113-
# we try to construct numeric events from the event with a text value
175+
114176
if numeric_value is None and text_value is not None:
115177
conversion_rule = self.conversion.get_text_event_to_numeric_events_rule(code)
116178
if conversion_rule:
@@ -140,14 +202,20 @@ def _convert_event(self, event) -> List[Event]:
140202
]
141203

142204
def get_meds_events(self) -> Iterable[Event]:
205+
"""
206+
Retrieves all medical events for the visit, converting each raw event if necessary.
207+
208+
Returns:
209+
Iterable[Event]: A list of CEHR-BERT-compatible events for the visit.
210+
"""
143211
events = []
144212
for e in self.events:
145213
events.extend(self._convert_event(e))
146214
return events
147215

148216

149217
def convert_one_patient(
150-
patient: meds_reader.Patient,
218+
patient: meds_reader.Subject,
151219
conversion: MedsToCehrBertConversion,
152220
default_visit_id: int = 1,
153221
prediction_time: datetime = None,
@@ -296,10 +364,10 @@ def convert_one_patient(
296364
age_at_index -= 1
297365

298366
# birth_datetime can not be None
299-
assert birth_datetime is not None, f"patient_id: {patient.patient_id} does not have a valid birth_datetime"
367+
assert birth_datetime is not None, f"patient_id: {patient.subject_id} does not have a valid birth_datetime"
300368

301369
return CehrBertPatient(
302-
patient_id=patient.patient_id,
370+
patient_id=patient.subject_id,
303371
birth_datetime=birth_datetime,
304372
visits=visits,
305373
race=race if race else UNKNOWN_VALUE,
@@ -346,7 +414,7 @@ def _meds_to_cehrbert_generator(
346414
) -> CehrBertPatient:
347415
conversion = get_meds_to_cehrbert_conversion_cls(meds_to_cehrbert_conversion_type)
348416
for shard in shards:
349-
with meds_reader.PatientDatabase(path_to_db) as patient_database:
417+
with meds_reader.SubjectDatabase(path_to_db) as patient_database:
350418
for patient_id, prediction_time, label in shard:
351419
patient = patient_database[patient_id]
352420
yield convert_one_patient(patient, conversion, default_visit_id, prediction_time, label)
@@ -363,20 +431,21 @@ def _create_cehrbert_data_from_meds(
363431
if data_args.cohort_folder:
364432
cohort = pd.read_parquet(os.path.join(data_args.cohort_folder, split))
365433
for cohort_row in cohort.itertuples():
366-
patient_id = cohort_row.patient_id
434+
subject_id = cohort_row.subject_id
367435
prediction_time = cohort_row.prediction_time
368436
label = int(cohort_row.boolean_value)
369-
batches.append((patient_id, prediction_time, label))
437+
batches.append((subject_id, prediction_time, label))
370438
else:
371-
patient_split = get_patient_split(data_args.data_folder)
372-
for patient_id in patient_split[split]:
373-
batches.append((patient_id, None, None))
439+
patient_split = get_subject_split(data_args.data_folder)
440+
for subject_id in patient_split[split]:
441+
batches.append((subject_id, None, None))
374442

375443
split_batches = np.array_split(np.asarray(batches), data_args.preprocessing_num_workers)
376444
batch_func = functools.partial(
377445
_meds_to_cehrbert_generator,
378446
path_to_db=data_args.data_folder,
379447
default_visit_id=default_visit_id,
448+
meds_to_cehrbert_conversion_type=data_args.meds_to_cehrbert_conversion_type,
380449
)
381450
dataset = Dataset.from_generator(
382451
batch_func,

src/cehrbert/runners/hf_cehrbert_pretrain_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def main():
179179
dataset = load_from_disk(meds_extension_path)
180180
if data_args.streaming:
181181
dataset = dataset.to_iterable_dataset(num_shards=training_args.dataloader_num_workers)
182-
except RuntimeError as e:
182+
except FileNotFoundError as e:
183183
LOG.exception(e)
184184
dataset = create_dataset_from_meds_reader(data_args, is_pretraining=True)
185185
if not data_args.streaming:

src/cehrbert/spark_apps/decorators/patient_event_decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class AttType(Enum):
1616
DAY = "day"
1717
WEEK = "week"
1818
MONTH = "month"
19-
CEHR_BERT = "cehrbert"
19+
CEHR_BERT = "cehr_bert"
2020
MIX = "mix"
2121
NONE = "none"
2222

0 commit comments

Comments
 (0)