1
1
from __future__ import annotations
2
2
3
+ from functools import singledispatch
3
4
from typing import TYPE_CHECKING
4
5
5
6
import numpy as np
6
7
import sklearn .preprocessing as sklearn_pp
7
8
8
- from ehrapy ._compat import is_dask_array
9
+ from ehrapy ._compat import _raise_array_type_not_implemented
9
10
10
11
try :
12
+ import dask .array as da
11
13
import dask_ml .preprocessing as daskml_pp
14
+
15
+ DASK_AVAILABLE = True
12
16
except ImportError :
13
17
daskml_pp = None
18
+ DASK_AVAILABLE = False
19
+
14
20
15
21
from ehrapy .anndata .anndata_ext import (
16
22
assert_numeric_vars ,
@@ -69,6 +75,23 @@ def _scale_func_group(
69
75
return None
70
76
71
77
78
@singledispatch
def _scale_norm_function(arr, **kwargs):
    """Select the standard-scaling callable for the array type of ``arr``.

    This is the fallback used when no implementation is registered for
    ``type(arr)``. It accepts ``**kwargs`` so that a call such as
    ``_scale_norm_function(x, with_mean=False)`` on an unsupported type reaches
    the dedicated error helper instead of failing with an unrelated
    ``TypeError`` about unexpected keyword arguments.

    Args:
        arr: The data matrix; only its type is used for dispatch.
        **kwargs: Scaler options; ignored by the fallback.
    """
    # The helper is imported from ehrapy._compat and is expected to raise.
    _raise_array_type_not_implemented(_scale_norm_function, type(arr))


@_scale_norm_function.register
def _(arr: np.ndarray, **kwargs):
    """Return the bound ``fit_transform`` of a new ``sklearn.preprocessing.StandardScaler``.

    Args:
        arr: NumPy array; used for dispatch only, it is not transformed here.
        **kwargs: Forwarded to the ``StandardScaler`` constructor.
    """
    return sklearn_pp.StandardScaler(**kwargs).fit_transform


if DASK_AVAILABLE:

    @_scale_norm_function.register
    def _(arr: da.Array, **kwargs):
        """Return the bound ``fit_transform`` of a new ``dask_ml.preprocessing.StandardScaler``.

        Args:
            arr: Dask array; used for dispatch only, it is not transformed here.
            **kwargs: Forwarded to the ``StandardScaler`` constructor.
        """
        return daskml_pp.StandardScaler(**kwargs).fit_transform
93
+
94
+
72
95
def scale_norm (
73
96
adata : AnnData ,
74
97
vars : str | Sequence [str ] | None = None ,
@@ -98,10 +121,7 @@ def scale_norm(
98
121
>>> adata_norm = ep.pp.scale_norm(adata, copy=True)
99
122
"""
100
123
101
- if is_dask_array (adata .X ):
102
- scale_func = daskml_pp .StandardScaler (** kwargs ).fit_transform
103
- else :
104
- scale_func = sklearn_pp .StandardScaler (** kwargs ).fit_transform
124
+ scale_func = _scale_norm_function (adata .X , ** kwargs )
105
125
106
126
return _scale_func_group (
107
127
adata = adata ,
@@ -113,6 +133,23 @@ def scale_norm(
113
133
)
114
134
115
135
136
@singledispatch
def _minmax_norm_function(arr, **kwargs):
    """Select the min-max-scaling callable for the array type of ``arr``.

    This is the fallback used when no implementation is registered for
    ``type(arr)``. It accepts ``**kwargs`` because the caller always forwards
    its scaler options; without it, an unsupported array type combined with
    any option would fail with a ``TypeError`` about unexpected keyword
    arguments instead of reaching the dedicated error helper.

    Args:
        arr: The data matrix; only its type is used for dispatch.
        **kwargs: Scaler options; ignored by the fallback.
    """
    # The helper is imported from ehrapy._compat and is expected to raise.
    _raise_array_type_not_implemented(_minmax_norm_function, type(arr))


@_minmax_norm_function.register
def _(arr: np.ndarray, **kwargs):
    """Return the bound ``fit_transform`` of a new ``sklearn.preprocessing.MinMaxScaler``.

    Args:
        arr: NumPy array; used for dispatch only, it is not transformed here.
        **kwargs: Forwarded to the ``MinMaxScaler`` constructor.
    """
    return sklearn_pp.MinMaxScaler(**kwargs).fit_transform


if DASK_AVAILABLE:

    @_minmax_norm_function.register
    def _(arr: da.Array, **kwargs):
        """Return the bound ``fit_transform`` of a new ``dask_ml.preprocessing.MinMaxScaler``.

        Args:
            arr: Dask array; used for dispatch only, it is not transformed here.
            **kwargs: Forwarded to the ``MinMaxScaler`` constructor.
        """
        return daskml_pp.MinMaxScaler(**kwargs).fit_transform
151
+
152
+
116
153
def minmax_norm (
117
154
adata : AnnData ,
118
155
vars : str | Sequence [str ] | None = None ,
@@ -143,10 +180,7 @@ def minmax_norm(
143
180
>>> adata_norm = ep.pp.minmax_norm(adata, copy=True)
144
181
"""
145
182
146
- if is_dask_array (adata .X ):
147
- scale_func = daskml_pp .MinMaxScaler (** kwargs ).fit_transform
148
- else :
149
- scale_func = sklearn_pp .MinMaxScaler (** kwargs ).fit_transform
183
+ scale_func = _minmax_norm_function (adata .X , ** kwargs )
150
184
151
185
return _scale_func_group (
152
186
adata = adata ,
@@ -158,6 +192,16 @@ def minmax_norm(
158
192
)
159
193
160
194
195
@singledispatch
def _maxabs_norm_function(arr):
    """Select the max-abs-scaling callable for the array type of ``arr``.

    This is the fallback used when no implementation is registered for
    ``type(arr)``. Only a NumPy implementation is registered in this section,
    so any other array type ends up here.

    Args:
        arr: The data matrix; only its type is used for dispatch.
    """
    # Pass this dispatcher (not the standard-scaling one) so the error helper
    # is told which function actually lacks an implementation.
    _raise_array_type_not_implemented(_maxabs_norm_function, type(arr))


@_maxabs_norm_function.register
def _(arr: np.ndarray):
    """Return the bound ``fit_transform`` of a new ``sklearn.preprocessing.MaxAbsScaler``.

    Args:
        arr: NumPy array; used for dispatch only, it is not transformed here.
    """
    return sklearn_pp.MaxAbsScaler().fit_transform
203
+
204
+
161
205
def maxabs_norm (
162
206
adata : AnnData ,
163
207
vars : str | Sequence [str ] | None = None ,
@@ -184,10 +228,8 @@ def maxabs_norm(
184
228
>>> adata = ep.dt.mimic_2(encoded=True)
185
229
>>> adata_norm = ep.pp.maxabs_norm(adata, copy=True)
186
230
"""
187
- if is_dask_array (adata .X ):
188
- raise NotImplementedError ("MaxAbsScaler is not implemented in dask_ml." )
189
- else :
190
- scale_func = sklearn_pp .MaxAbsScaler ().fit_transform
231
+
232
+ scale_func = _maxabs_norm_function (adata .X )
191
233
192
234
return _scale_func_group (
193
235
adata = adata ,
@@ -199,6 +241,23 @@ def maxabs_norm(
199
241
)
200
242
201
243
244
@singledispatch
def _robust_scale_norm_function(arr, **kwargs):
    """Pick the robust-scaling callable matching the array type of ``arr``.

    Fallback for array types without a registered implementation: the
    unsupported type is handed to the shared error helper.

    Args:
        arr: The data matrix; only its type is used for dispatch.
        **kwargs: Scaler options; unused by the fallback.
    """
    unsupported_type = type(arr)
    _raise_array_type_not_implemented(_robust_scale_norm_function, unsupported_type)


@_robust_scale_norm_function.register(np.ndarray)
def _(arr: np.ndarray, **kwargs):
    """Build a ``sklearn.preprocessing.RobustScaler`` and return its bound ``fit_transform``.

    Args:
        arr: NumPy array; used for dispatch only.
        **kwargs: Forwarded to the ``RobustScaler`` constructor.
    """
    scaler = sklearn_pp.RobustScaler(**kwargs)
    return scaler.fit_transform


if DASK_AVAILABLE:

    @_robust_scale_norm_function.register(da.Array)
    def _(arr: da.Array, **kwargs):
        """Build a ``dask_ml.preprocessing.RobustScaler`` and return its bound ``fit_transform``.

        Args:
            arr: Dask array; used for dispatch only.
            **kwargs: Forwarded to the ``RobustScaler`` constructor.
        """
        scaler = daskml_pp.RobustScaler(**kwargs)
        return scaler.fit_transform
259
+
260
+
202
261
def robust_scale_norm (
203
262
adata : AnnData ,
204
263
vars : str | Sequence [str ] | None = None ,
@@ -229,10 +288,8 @@ def robust_scale_norm(
229
288
>>> adata = ep.dt.mimic_2(encoded=True)
230
289
>>> adata_norm = ep.pp.robust_scale_norm(adata, copy=True)
231
290
"""
232
- if is_dask_array (adata .X ):
233
- scale_func = daskml_pp .RobustScaler (** kwargs ).fit_transform
234
- else :
235
- scale_func = sklearn_pp .RobustScaler (** kwargs ).fit_transform
291
+
292
+ scale_func = _robust_scale_norm_function (adata .X , ** kwargs )
236
293
237
294
return _scale_func_group (
238
295
adata = adata ,
@@ -244,6 +301,23 @@ def robust_scale_norm(
244
301
)
245
302
246
303
304
@singledispatch
def _quantile_norm_function(arr, **kwargs):
    """Select the quantile-transform callable for the array type of ``arr``.

    This is the fallback used when no implementation is registered for
    ``type(arr)``. It accepts ``**kwargs`` because the caller always forwards
    its transformer options; without it, an unsupported array type combined
    with any option would fail with a ``TypeError`` about unexpected keyword
    arguments instead of reaching the dedicated error helper.

    Args:
        arr: The data matrix; only its type is used for dispatch.
        **kwargs: Transformer options; ignored by the fallback.
    """
    # The helper is imported from ehrapy._compat and is expected to raise.
    _raise_array_type_not_implemented(_quantile_norm_function, type(arr))


@_quantile_norm_function.register
def _(arr: np.ndarray, **kwargs):
    """Return the bound ``fit_transform`` of a new ``sklearn.preprocessing.QuantileTransformer``.

    Args:
        arr: NumPy array; used for dispatch only, it is not transformed here.
        **kwargs: Forwarded to the ``QuantileTransformer`` constructor.
    """
    return sklearn_pp.QuantileTransformer(**kwargs).fit_transform


if DASK_AVAILABLE:

    @_quantile_norm_function.register
    def _(arr: da.Array, **kwargs):
        """Return the bound ``fit_transform`` of a new ``dask_ml.preprocessing.QuantileTransformer``.

        Args:
            arr: Dask array; used for dispatch only, it is not transformed here.
            **kwargs: Forwarded to the ``QuantileTransformer`` constructor.
        """
        return daskml_pp.QuantileTransformer(**kwargs).fit_transform
319
+
320
+
247
321
def quantile_norm (
248
322
adata : AnnData ,
249
323
vars : str | Sequence [str ] | None = None ,
@@ -273,10 +347,8 @@ def quantile_norm(
273
347
>>> adata = ep.dt.mimic_2(encoded=True)
274
348
>>> adata_norm = ep.pp.quantile_norm(adata, copy=True)
275
349
"""
276
- if is_dask_array (adata .X ):
277
- scale_func = daskml_pp .QuantileTransformer (** kwargs ).fit_transform
278
- else :
279
- scale_func = sklearn_pp .QuantileTransformer (** kwargs ).fit_transform
350
+
351
+ scale_func = _quantile_norm_function (adata .X , ** kwargs )
280
352
281
353
return _scale_func_group (
282
354
adata = adata ,
@@ -288,6 +360,16 @@ def quantile_norm(
288
360
)
289
361
290
362
363
@singledispatch
def _power_norm_function(arr, **kwargs):
    """Pick the power-transform callable matching the array type of ``arr``.

    Fallback for array types without a registered implementation: the
    unsupported type is handed to the shared error helper. Only a NumPy
    implementation is registered in this section.

    Args:
        arr: The data matrix; only its type is used for dispatch.
        **kwargs: Transformer options; unused by the fallback.
    """
    unsupported_type = type(arr)
    _raise_array_type_not_implemented(_power_norm_function, unsupported_type)


@_power_norm_function.register(np.ndarray)
def _(arr: np.ndarray, **kwargs):
    """Build a ``sklearn.preprocessing.PowerTransformer`` and return its bound ``fit_transform``.

    Args:
        arr: NumPy array; used for dispatch only.
        **kwargs: Forwarded to the ``PowerTransformer`` constructor.
    """
    transformer = sklearn_pp.PowerTransformer(**kwargs)
    return transformer.fit_transform
371
+
372
+
291
373
def power_norm (
292
374
adata : AnnData ,
293
375
vars : str | Sequence [str ] | None = None ,
@@ -317,10 +399,8 @@ def power_norm(
317
399
>>> adata = ep.dt.mimic_2(encoded=True)
318
400
>>> adata_norm = ep.pp.power_norm(adata, copy=True)
319
401
"""
320
- if is_dask_array (adata .X ):
321
- raise NotImplementedError ("dask-ml has no PowerTransformer, this is only available in scikit-learn" )
322
- else :
323
- scale_func = sklearn_pp .PowerTransformer (** kwargs ).fit_transform
402
+
403
+ scale_func = _power_norm_function (adata .X , ** kwargs )
324
404
325
405
return _scale_func_group (
326
406
adata = adata ,
0 commit comments