Skip to content

Commit 4c4db5f

Browse files
aGuyLearning, eroell, Zethson
authored
Datatype Support in Quality Control and Impute (#865)
* Enhancement: Add Dask support for explicit imputation * Enhancement: Add Dask support for quality control metrics and imputation tests * Fix test for imputation to handle Dask arrays without raising errors * Refactor quality control metrics functions to streamline computation and improve readability * added expected error * Remove unused Dask import from quality control module * simplify missing value computation * Rename parameter 'arr' to 'mtx' in _compute_obs_metrics no longer creates copy * daskify qc_metrics * Add fixture for array types and update imputation tests for dask arrays * Refactor _compute_var_metrics to prevent modification of the original data matrix and add a test for encoding mode integrity * Add parameterized tests for array types in miceforest imputation * Update missing values handling to include array type in error message and refine parameterized tests for miceforest imputation * Fix array type handling in missing values computation and update test for miceforest imputation * Implement array type handling in load_dataframe function and update tests for miceforest imputation * Remove parameterization for array types in miceforest numerical data imputation test * Update tests/preprocessing/test_quality_control.py Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> * Update tests/preprocessing/test_quality_control.py Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> * revert deepcopy changes * Fix test to ensure original matrix is not modified after encoding * Remove unused parameters from observation and variable metrics computation functions * Add sparse.csr_matrix to explicit impute array types test case * Parameterize quality control metrics tests to support multiple array types * Remove unused imports from test_quality_control.py * encode blocks dask function * Add pytest fixtures for observation and variable data in tests * Update tests/preprocessing/test_quality_control.py Co-authored-by: 
Eljas Roellin <65244425+eroell@users.noreply.github.com> * Update tests/preprocessing/test_quality_control.py Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com> * support dask explicit impute all object types --------- Co-authored-by: eroell <eljas.roellin@ikmail.com> Co-authored-by: Lukas Heumos <lukas.heumos@posteo.net> Co-authored-by: Eljas Roellin <65244425+eroell@users.noreply.github.com>
1 parent 324a978 commit 4c4db5f

File tree

5 files changed

+277
-97
lines changed

5 files changed

+277
-97
lines changed

ehrapy/preprocessing/_imputation.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import warnings
44
from collections.abc import Iterable
5+
from functools import singledispatch
56
from typing import TYPE_CHECKING, Literal
67

78
import numpy as np
@@ -11,7 +12,7 @@
1112
from sklearn.impute import SimpleImputer
1213

1314
from ehrapy import settings
14-
from ehrapy._compat import _check_module_importable
15+
from ehrapy._compat import _check_module_importable, _raise_array_type_not_implemented
1516
from ehrapy._progress import spinner
1617
from ehrapy.anndata import check_feature_types
1718
from ehrapy.anndata.anndata_ext import (
@@ -23,6 +24,13 @@
2324
if TYPE_CHECKING:
2425
from anndata import AnnData
2526

27+
try:
28+
import dask.array as da
29+
30+
DASK_AVAILABLE = True
31+
except ImportError:
32+
DASK_AVAILABLE = False
33+
2634

2735
@spinner("Performing explicit impute")
2836
def explicit_impute(
@@ -76,7 +84,9 @@ def explicit_impute(
7684
imputation_value = _extract_impute_value(replacement, column_name)
7785
# only replace if an explicit value got passed or could be extracted from replacement
7886
if imputation_value:
79-
_replace_explicit(adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings)
87+
adata.X[:, idx : idx + 1] = _replace_explicit(
88+
adata.X[:, idx : idx + 1], imputation_value, impute_empty_strings
89+
)
8090
else:
8191
logger.warning(f"No replace value passed and found for var [not bold green]{column_name}.")
8292
else:
@@ -87,13 +97,33 @@ def explicit_impute(
8797
return adata if copy else None
8898

8999

90-
def _replace_explicit(arr: np.ndarray, replacement: str | int, impute_empty_strings: bool) -> None:
100+
@singledispatch
101+
def _replace_explicit(arr, replacement: str | int, impute_empty_strings: bool) -> None:
102+
_raise_array_type_not_implemented(_replace_explicit, type(arr))
103+
104+
105+
@_replace_explicit.register
106+
def _(arr: np.ndarray, replacement: str | int, impute_empty_strings: bool) -> np.ndarray:
91107
"""Replace one column or whole X with a value where missing values are stored."""
92108
if not impute_empty_strings: # pragma: no cover
93109
impute_conditions = pd.isnull(arr)
94110
else:
95111
impute_conditions = np.logical_or(pd.isnull(arr), arr == "")
96112
arr[impute_conditions] = replacement
113+
return arr
114+
115+
116+
if DASK_AVAILABLE:

    @_replace_explicit.register(da.Array)
    def _(arr: da.Array, replacement: str | int, impute_empty_strings: bool) -> da.Array:
        """Replace one column or whole X with a value where missing values are stored."""
        # Mirrors the NumPy implementation using dask's equivalents so the
        # computation stays lazy until the caller materializes the result.
        if not impute_empty_strings:  # pragma: no cover
            impute_conditions = da.isnull(arr)
        else:
            # Also treat empty strings as missing when requested.
            impute_conditions = da.logical_or(da.isnull(arr), arr == "")
        # Boolean-mask assignment on the dask array; the (modified) array is
        # returned so the caller can write it back into adata.X.
        arr[impute_conditions] = replacement
        return arr
97127

98128

99129
def _extract_impute_value(replacement: dict[str, str | int], column_name: str) -> str | int | None:
@@ -469,12 +499,22 @@ def mice_forest_impute(
469499
return adata if copy else None
470500

471501

502+
@singledispatch
def load_dataframe(arr, columns, index):
    """Convert a supported array type into a :class:`pandas.DataFrame`.

    Unsupported array types raise via ``_raise_array_type_not_implemented``.
    """
    _raise_array_type_not_implemented(load_dataframe, type(arr))


@load_dataframe.register
def _(arr: np.ndarray, columns, index):
    """Build the DataFrame directly from an in-memory NumPy array."""
    return pd.DataFrame(data=arr, index=index, columns=columns)
510+
511+
472512
def _miceforest_impute(
473513
adata, var_names, save_all_iterations_data, random_state, inplace, iterations, variable_parameters, verbose
474514
) -> None:
475515
import miceforest as mf
476516

477-
data_df = pd.DataFrame(adata.X, columns=adata.var_names, index=adata.obs_names)
517+
data_df = load_dataframe(adata.X, columns=adata.var_names, index=adata.obs_names)
478518
data_df = data_df.apply(pd.to_numeric, errors="coerce")
479519

480520
if isinstance(var_names, Iterable) and all(isinstance(item, str) for item in var_names):

ehrapy/preprocessing/_quality_control.py

Lines changed: 60 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import copy
4+
from functools import singledispatch
45
from pathlib import Path
56
from typing import TYPE_CHECKING, Literal
67

@@ -9,6 +10,7 @@
910
from lamin_utils import logger
1011
from thefuzz import process
1112

13+
from ehrapy._compat import _raise_array_type_not_implemented
1214
from ehrapy.anndata import anndata_to_df
1315
from ehrapy.preprocessing._encoding import _get_encoded_features
1416

@@ -17,6 +19,13 @@
1719

1820
from anndata import AnnData
1921

22+
try:
23+
import dask.array as da
24+
25+
DASK_AVAILABLE = True
26+
except ImportError:
27+
DASK_AVAILABLE = False
28+
2029

2130
def qc_metrics(
2231
adata: AnnData, qc_vars: Collection[str] = (), layer: str = None
@@ -55,55 +64,57 @@ def qc_metrics(
5564
>>> obs_qc, var_qc = ep.pp.qc_metrics(adata)
5665
>>> obs_qc["missing_values_pct"].plot(kind="hist", bins=20)
5766
"""
58-
obs_metrics = _obs_qc_metrics(adata, layer, qc_vars)
59-
var_metrics = _var_qc_metrics(adata, layer)
6067

61-
adata.obs[obs_metrics.columns] = obs_metrics
68+
mtx = adata.X if layer is None else adata.layers[layer]
69+
var_metrics = _compute_var_metrics(mtx, adata)
70+
obs_metrics = _compute_obs_metrics(mtx, adata, qc_vars=qc_vars, log1p=True)
71+
6272
adata.var[var_metrics.columns] = var_metrics
73+
adata.obs[obs_metrics.columns] = obs_metrics
6374

6475
return obs_metrics, var_metrics
6576

6677

67-
def _missing_values(
68-
arr: np.ndarray, mode: Literal["abs", "pct"] = "abs", df_type: Literal["obs", "var"] = "obs"
69-
) -> np.ndarray:
70-
"""Calculates the absolute or relative amount of missing values.
78+
@singledispatch
79+
def _compute_missing_values(mtx, axis):
80+
_raise_array_type_not_implemented(_compute_missing_values, type(mtx))
7181

72-
Args:
73-
arr: Numpy array containing a data row which is a subset of X (mtx).
74-
mode: Whether to calculate absolute or percentage of missing values.
75-
df_type: Whether to calculate the proportions for obs or var. One of 'obs' or 'var'.
7682

77-
Returns:
78-
Absolute or relative amount of missing values.
79-
"""
80-
num_missing = pd.isnull(arr).sum()
81-
if mode == "abs":
82-
return num_missing
83-
elif mode == "pct":
84-
total_elements = arr.shape[0] if df_type == "obs" else len(arr)
85-
return (num_missing / total_elements) * 100
83+
@_compute_missing_values.register
84+
def _(mtx: np.ndarray, axis) -> np.ndarray:
85+
return pd.isnull(mtx).sum(axis)
86+
8687

88+
if DASK_AVAILABLE:

    @_compute_missing_values.register
    def _(mtx: da.Array, axis) -> np.ndarray:
        # Dask implementation: build the lazy null-mask, sum along the axis,
        # then call .compute() so callers receive a concrete NumPy result
        # (matching the eager NumPy implementation above).
        return da.isnull(mtx).sum(axis).compute()
93+
94+
95+
def _compute_obs_metrics(
96+
mtx,
97+
adata: AnnData,
98+
*,
99+
qc_vars: Collection[str] = (),
100+
log1p: bool = True,
101+
):
91102
"""Calculates quality control metrics for observations.
92103
93104
See :func:`~ehrapy.preprocessing._quality_control.calculate_qc_metrics` for a list of calculated metrics.
94105
95106
Args:
107+
mtx: Data array.
96108
adata: Annotated data matrix.
97-
layer: Layer containing the actual data matrix.
98109
qc_vars: A list of previously calculated QC metrics to calculate summary statistics for.
99110
log1p: Whether to apply log1p normalization for the QC metrics. Only used with parameter 'qc_vars'.
100111
101112
Returns:
102113
A Pandas DataFrame with the calculated metrics.
103114
"""
115+
104116
obs_metrics = pd.DataFrame(index=adata.obs_names)
105117
var_metrics = pd.DataFrame(index=adata.var_names)
106-
mtx = adata.X if layer is None else adata.layers[layer]
107118

108119
if "encoding_mode" in adata.var:
109120
for original_values_categorical in _get_encoded_features(adata):
@@ -120,8 +131,8 @@ def _obs_qc_metrics(
120131
)
121132
)
122133

123-
obs_metrics["missing_values_abs"] = np.apply_along_axis(_missing_values, 1, mtx, mode="abs")
124-
obs_metrics["missing_values_pct"] = np.apply_along_axis(_missing_values, 1, mtx, mode="pct", df_type="obs")
134+
obs_metrics["missing_values_abs"] = _compute_missing_values(mtx, axis=1)
135+
obs_metrics["missing_values_pct"] = (obs_metrics["missing_values_abs"] / mtx.shape[1]) * 100
125136

126137
# Specific QC metrics
127138
for qc_var in qc_vars:
@@ -136,10 +147,19 @@ def _obs_qc_metrics(
136147
return obs_metrics
137148

138149

139-
def _var_qc_metrics(adata: AnnData, layer: str | None = None) -> pd.DataFrame:
140-
var_metrics = pd.DataFrame(index=adata.var_names)
141-
mtx = adata.X if layer is None else adata.layers[layer]
150+
def _compute_var_metrics(
151+
mtx,
152+
adata: AnnData,
153+
):
154+
"""Compute variable metrics for quality control.
155+
156+
Args:
157+
mtx: Data array.
158+
adata: Annotated data matrix.
159+
"""
160+
142161
categorical_indices = np.ndarray([0], dtype=int)
162+
var_metrics = pd.DataFrame(index=adata.var_names)
143163

144164
if "encoding_mode" in adata.var.keys():
145165
for original_values_categorical in _get_encoded_features(adata):
@@ -157,32 +177,35 @@ def _var_qc_metrics(adata: AnnData, layer: str | None = None) -> pd.DataFrame:
157177
mtx[:, index].shape[1],
158178
)
159179
categorical_indices = np.concatenate([categorical_indices, index])
180+
160181
non_categorical_indices = np.ones(mtx.shape[1], dtype=bool)
161182
non_categorical_indices[categorical_indices] = False
162-
var_metrics["missing_values_abs"] = np.apply_along_axis(_missing_values, 0, mtx, mode="abs")
163-
var_metrics["missing_values_pct"] = np.apply_along_axis(_missing_values, 0, mtx, mode="pct", df_type="var")
183+
184+
var_metrics["missing_values_abs"] = _compute_missing_values(mtx, axis=0)
185+
var_metrics["missing_values_pct"] = (var_metrics["missing_values_abs"] / mtx.shape[0]) * 100
164186

165187
var_metrics["mean"] = np.nan
166188
var_metrics["median"] = np.nan
167189
var_metrics["standard_deviation"] = np.nan
168190
var_metrics["min"] = np.nan
169191
var_metrics["max"] = np.nan
192+
var_metrics["iqr_outliers"] = np.nan
170193

171194
try:
172195
var_metrics.loc[non_categorical_indices, "mean"] = np.nanmean(
173-
np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0
196+
mtx[:, non_categorical_indices].astype(np.float64), axis=0
174197
)
175198
var_metrics.loc[non_categorical_indices, "median"] = np.nanmedian(
176-
np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0
199+
mtx[:, non_categorical_indices].astype(np.float64), axis=0
177200
)
178201
var_metrics.loc[non_categorical_indices, "standard_deviation"] = np.nanstd(
179-
np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0
202+
mtx[:, non_categorical_indices].astype(np.float64), axis=0
180203
)
181204
var_metrics.loc[non_categorical_indices, "min"] = np.nanmin(
182-
np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0
205+
mtx[:, non_categorical_indices].astype(np.float64), axis=0
183206
)
184207
var_metrics.loc[non_categorical_indices, "max"] = np.nanmax(
185-
np.array(mtx[:, non_categorical_indices], dtype=np.float64), axis=0
208+
mtx[:, non_categorical_indices].astype(np.float64), axis=0
186209
)
187210

188211
# Calculate IQR and define IQR outliers

tests/conftest.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import TYPE_CHECKING
55

66
import numpy as np
7+
import pandas as pd
78
import pytest
89
from anndata import AnnData
910
from matplotlib.testing.compare import compare_images
@@ -29,6 +30,54 @@ def rng():
2930
return np.random.default_rng(seed=42)
3031

3132

33+
@pytest.fixture
def obs_data():
    """Per-observation (patient) metadata shared by QC and imputation tests."""
    return dict(
        disease=["cancer", "tumor"],
        country=["Germany", "switzerland"],
        sex=["male", "female"],
    )
40+
41+
42+
@pytest.fixture
def var_data():
    """Per-variable (feature) metadata shared by QC and imputation tests."""
    return dict(
        alive=["yes", "no", "maybe"],
        hospital=["hospital 1", "hospital 2", "hospital 1"],
        crazy=["yes", "yes", "yes"],
    )
49+
50+
51+
@pytest.fixture
def missing_values_adata(obs_data, var_data):
    """AnnData whose X contains NaNs, for missing-value QC tests."""
    X = np.array([[0.21, np.nan, 41.42], [np.nan, np.nan, 7.234]], dtype=np.float32)
    return AnnData(
        X=X,
        obs=pd.DataFrame(data=obs_data),
        var=pd.DataFrame(data=var_data, index=["Acetaminophen", "hospital", "crazy"]),
    )
58+
59+
60+
@pytest.fixture
def lab_measurements_simple_adata(obs_data, var_data):
    """AnnData of three lab measurements for two patients (no layers)."""
    return AnnData(
        X=np.array([[73, 0.02, 1.00], [148, 0.25, 3.55]], dtype=np.float32),
        obs=pd.DataFrame(data=obs_data),
        var=pd.DataFrame(data=var_data, index=["Acetaminophen", "Acetoacetic acid", "Beryllium, toxic"]),
    )
68+
69+
70+
@pytest.fixture
def lab_measurements_layer_adata(obs_data, var_data):
    """Same lab-measurement AnnData, with X duplicated into a ``layer_copy`` layer."""
    measurements = np.array([[73, 0.02, 1.00], [148, 0.25, 3.55]], dtype=np.float32)
    return AnnData(
        X=measurements,
        obs=pd.DataFrame(data=obs_data),
        var=pd.DataFrame(data=var_data, index=["Acetaminophen", "Acetoacetic acid", "Beryllium, toxic"]),
        layers={"layer_copy": measurements},
    )
79+
80+
3281
@pytest.fixture
3382
def mimic_2():
3483
adata = ep.dt.mimic_2()
@@ -152,10 +201,10 @@ def asarray(a):
152201
return np.asarray(a)
153202

154203

155-
def as_dense_dask_array(a, chunk_size=1000):
    """Convert ``a`` to a dense dask array chunked by ``chunk_size``."""
    import dask.array as da

    dense = da.from_array(a, chunks=chunk_size)
    return dense
159208

160209

161210
ARRAY_TYPES = (asarray, as_dense_dask_array)

0 commit comments

Comments
 (0)