Skip to content

Commit 15c6315

Browse files
eroell, pre-commit-ci[bot], and Zethson
authored
initial suggestions of array type handling on example of normalization methods (#835)
* initial suggestions of array type checks on example of scale_norm * singledispatch normalization functions and test them * try dask import * doc build fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * DRY Signed-off-by: zethson <lukas.heumos@posteo.net> * Fix tests Signed-off-by: zethson <lukas.heumos@posteo.net> * Fix tests Signed-off-by: zethson <lukas.heumos@posteo.net> * Simplify tests Signed-off-by: zethson <lukas.heumos@posteo.net> --------- Signed-off-by: zethson <lukas.heumos@posteo.net> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Lukas Heumos <lukas.heumos@posteo.net>
1 parent 8e00c45 commit 15c6315

File tree

5 files changed

+277
-92
lines changed

5 files changed

+277
-92
lines changed

.github/workflows/run_notebooks.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
name: Run Notebooks
22

33
on:
4-
- push
54
- pull_request
65

76
jobs:

ehrapy/_compat.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Since we might check whether an object is an instance of dask.array.Array
22
# without requiring dask installed in the environment.
3-
# This would become obsolete should dask become a requirement for ehrapy
4-
3+
from collections.abc import Callable
54

65
try:
76
import dask.array as da
@@ -11,6 +10,12 @@
1110
DASK_AVAILABLE = False
1211

1312

13+
def _raise_array_type_not_implemented(func: Callable, type_: type) -> NotImplementedError:
14+
raise NotImplementedError(
15+
f"{func.__name__} does not support array type {type_}. Must be of type {func.registry.keys()}." # type: ignore
16+
)
17+
18+
1419
def is_dask_array(array):
1520
if DASK_AVAILABLE:
1621
return isinstance(array, da.Array)

ehrapy/preprocessing/_normalization.py

Lines changed: 105 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
from __future__ import annotations
22

3+
from functools import singledispatch
34
from typing import TYPE_CHECKING
45

56
import numpy as np
67
import sklearn.preprocessing as sklearn_pp
78

8-
from ehrapy._compat import is_dask_array
9+
from ehrapy._compat import _raise_array_type_not_implemented
910

1011
try:
12+
import dask.array as da
1113
import dask_ml.preprocessing as daskml_pp
14+
15+
DASK_AVAILABLE = True
1216
except ImportError:
1317
daskml_pp = None
18+
DASK_AVAILABLE = False
19+
1420

1521
from ehrapy.anndata.anndata_ext import (
1622
assert_numeric_vars,
@@ -69,6 +75,23 @@ def _scale_func_group(
6975
return None
7076

7177

78+
@singledispatch
79+
def _scale_norm_function(arr):
80+
_raise_array_type_not_implemented(_scale_norm_function, type(arr))
81+
82+
83+
@_scale_norm_function.register
84+
def _(arr: np.ndarray, **kwargs):
85+
return sklearn_pp.StandardScaler(**kwargs).fit_transform
86+
87+
88+
if DASK_AVAILABLE:
89+
90+
@_scale_norm_function.register
91+
def _(arr: da.Array, **kwargs):
92+
return daskml_pp.StandardScaler(**kwargs).fit_transform
93+
94+
7295
def scale_norm(
7396
adata: AnnData,
7497
vars: str | Sequence[str] | None = None,
@@ -98,10 +121,7 @@ def scale_norm(
98121
>>> adata_norm = ep.pp.scale_norm(adata, copy=True)
99122
"""
100123

101-
if is_dask_array(adata.X):
102-
scale_func = daskml_pp.StandardScaler(**kwargs).fit_transform
103-
else:
104-
scale_func = sklearn_pp.StandardScaler(**kwargs).fit_transform
124+
scale_func = _scale_norm_function(adata.X, **kwargs)
105125

106126
return _scale_func_group(
107127
adata=adata,
@@ -113,6 +133,23 @@ def scale_norm(
113133
)
114134

115135

136+
@singledispatch
137+
def _minmax_norm_function(arr):
138+
_raise_array_type_not_implemented(_minmax_norm_function, type(arr))
139+
140+
141+
@_minmax_norm_function.register
142+
def _(arr: np.ndarray, **kwargs):
143+
return sklearn_pp.MinMaxScaler(**kwargs).fit_transform
144+
145+
146+
if DASK_AVAILABLE:
147+
148+
@_minmax_norm_function.register
149+
def _(arr: da.Array, **kwargs):
150+
return daskml_pp.MinMaxScaler(**kwargs).fit_transform
151+
152+
116153
def minmax_norm(
117154
adata: AnnData,
118155
vars: str | Sequence[str] | None = None,
@@ -143,10 +180,7 @@ def minmax_norm(
143180
>>> adata_norm = ep.pp.minmax_norm(adata, copy=True)
144181
"""
145182

146-
if is_dask_array(adata.X):
147-
scale_func = daskml_pp.MinMaxScaler(**kwargs).fit_transform
148-
else:
149-
scale_func = sklearn_pp.MinMaxScaler(**kwargs).fit_transform
183+
scale_func = _minmax_norm_function(adata.X, **kwargs)
150184

151185
return _scale_func_group(
152186
adata=adata,
@@ -158,6 +192,16 @@ def minmax_norm(
158192
)
159193

160194

195+
@singledispatch
196+
def _maxabs_norm_function(arr):
197+
_raise_array_type_not_implemented(_scale_norm_function, type(arr))
198+
199+
200+
@_maxabs_norm_function.register
201+
def _(arr: np.ndarray):
202+
return sklearn_pp.MaxAbsScaler().fit_transform
203+
204+
161205
def maxabs_norm(
162206
adata: AnnData,
163207
vars: str | Sequence[str] | None = None,
@@ -184,10 +228,8 @@ def maxabs_norm(
184228
>>> adata = ep.dt.mimic_2(encoded=True)
185229
>>> adata_norm = ep.pp.maxabs_norm(adata, copy=True)
186230
"""
187-
if is_dask_array(adata.X):
188-
raise NotImplementedError("MaxAbsScaler is not implemented in dask_ml.")
189-
else:
190-
scale_func = sklearn_pp.MaxAbsScaler().fit_transform
231+
232+
scale_func = _maxabs_norm_function(adata.X)
191233

192234
return _scale_func_group(
193235
adata=adata,
@@ -199,6 +241,23 @@ def maxabs_norm(
199241
)
200242

201243

244+
@singledispatch
245+
def _robust_scale_norm_function(arr, **kwargs):
246+
_raise_array_type_not_implemented(_robust_scale_norm_function, type(arr))
247+
248+
249+
@_robust_scale_norm_function.register
250+
def _(arr: np.ndarray, **kwargs):
251+
return sklearn_pp.RobustScaler(**kwargs).fit_transform
252+
253+
254+
if DASK_AVAILABLE:
255+
256+
@_robust_scale_norm_function.register
257+
def _(arr: da.Array, **kwargs):
258+
return daskml_pp.RobustScaler(**kwargs).fit_transform
259+
260+
202261
def robust_scale_norm(
203262
adata: AnnData,
204263
vars: str | Sequence[str] | None = None,
@@ -229,10 +288,8 @@ def robust_scale_norm(
229288
>>> adata = ep.dt.mimic_2(encoded=True)
230289
>>> adata_norm = ep.pp.robust_scale_norm(adata, copy=True)
231290
"""
232-
if is_dask_array(adata.X):
233-
scale_func = daskml_pp.RobustScaler(**kwargs).fit_transform
234-
else:
235-
scale_func = sklearn_pp.RobustScaler(**kwargs).fit_transform
291+
292+
scale_func = _robust_scale_norm_function(adata.X, **kwargs)
236293

237294
return _scale_func_group(
238295
adata=adata,
@@ -244,6 +301,23 @@ def robust_scale_norm(
244301
)
245302

246303

304+
@singledispatch
305+
def _quantile_norm_function(arr):
306+
_raise_array_type_not_implemented(_quantile_norm_function, type(arr))
307+
308+
309+
@_quantile_norm_function.register
310+
def _(arr: np.ndarray, **kwargs):
311+
return sklearn_pp.QuantileTransformer(**kwargs).fit_transform
312+
313+
314+
if DASK_AVAILABLE:
315+
316+
@_quantile_norm_function.register
317+
def _(arr: da.Array, **kwargs):
318+
return daskml_pp.QuantileTransformer(**kwargs).fit_transform
319+
320+
247321
def quantile_norm(
248322
adata: AnnData,
249323
vars: str | Sequence[str] | None = None,
@@ -273,10 +347,8 @@ def quantile_norm(
273347
>>> adata = ep.dt.mimic_2(encoded=True)
274348
>>> adata_norm = ep.pp.quantile_norm(adata, copy=True)
275349
"""
276-
if is_dask_array(adata.X):
277-
scale_func = daskml_pp.QuantileTransformer(**kwargs).fit_transform
278-
else:
279-
scale_func = sklearn_pp.QuantileTransformer(**kwargs).fit_transform
350+
351+
scale_func = _quantile_norm_function(adata.X, **kwargs)
280352

281353
return _scale_func_group(
282354
adata=adata,
@@ -288,6 +360,16 @@ def quantile_norm(
288360
)
289361

290362

363+
@singledispatch
364+
def _power_norm_function(arr, **kwargs):
365+
_raise_array_type_not_implemented(_power_norm_function, type(arr))
366+
367+
368+
@_power_norm_function.register
369+
def _(arr: np.ndarray, **kwargs):
370+
return sklearn_pp.PowerTransformer(**kwargs).fit_transform
371+
372+
291373
def power_norm(
292374
adata: AnnData,
293375
vars: str | Sequence[str] | None = None,
@@ -317,10 +399,8 @@ def power_norm(
317399
>>> adata = ep.dt.mimic_2(encoded=True)
318400
>>> adata_norm = ep.pp.power_norm(adata, copy=True)
319401
"""
320-
if is_dask_array(adata.X):
321-
raise NotImplementedError("dask-ml has no PowerTransformer, this is only available in scikit-learn")
322-
else:
323-
scale_func = sklearn_pp.PowerTransformer(**kwargs).fit_transform
402+
403+
scale_func = _power_norm_function(adata.X, **kwargs)
324404

325405
return _scale_func_group(
326406
adata=adata,

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ medcat = [
7272
"medcat",
7373
]
7474
dask = [
75-
"dask",
75+
"anndata[dask]",
7676
"dask-ml",
7777
]
7878
dev = [
@@ -136,7 +136,8 @@ filterwarnings = [
136136
"ignore:`flavor='seurat_v3'` expects raw count data, but non-integers were found:UserWarning",
137137
"ignore:All-NaN slice encountered:RuntimeWarning",
138138
"ignore:Observation names are not unique. To make them unique, call `.obs_names_make_unique`.:UserWarning",
139-
"ignore:Trying to modify attribute .var of view"
139+
"ignore:Trying to modify attribute `.var` of view, initializing view as actual.:anndata.ImplicitModificationWarning",
140+
"ignore:Transforming to str index.:anndata.ImplicitModificationWarning:"
140141
]
141142
minversion = 6.0
142143
norecursedirs = [ '.*', 'build', 'dist', '*.egg', 'data', '__pycache__']

0 commit comments

Comments (0)