Skip to content

Commit aba9fb6

Browse files
committedJan 30, 2024
add example for custom registered preprocessing func
1 parent 3022ae3 commit aba9fb6

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed
 

‎examples/tuning/cta_svm/main.py

+25
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,39 @@
33
from typing import get_args
44

55
import wandb
6+
from sklearn.random_projection import GaussianRandomProjection
67

78
from dance import logger
89
from dance.datasets.singlemodality import CellTypeAnnotationDataset
910
from dance.modules.single_modality.cell_type_annotation.svm import SVM
1011
from dance.pipeline import PipelinePlaner
12+
from dance.registry import register_preprocessor
13+
from dance.transforms.base import BaseTransform
1114
from dance.typing import LogLevel
1215
from dance.utils import set_seed
1316

17+
18+
@register_preprocessor("feature", "cell") # NOTE: register any custom preprocessing function to be used for tuning
19+
class GaussRandProjFeature(BaseTransform):
20+
"""Custom preprocessing to extract cell feature via Gaussian random projection."""
21+
22+
_DISPLAY_ATTRS = ("n_components", "eps")
23+
24+
def __init__(self, n_components: int = 400, eps: float = 0.1, **kwargs):
25+
super().__init__(**kwargs)
26+
self.n_components = n_components
27+
self.eps = eps
28+
29+
def __call__(self, data):
30+
feat = data.get_feature(return_type="numpy")
31+
grp = GaussianRandomProjection(n_components=self.n_components, eps=self.eps)
32+
33+
self.logger.info(f"Start generateing cell feature via Gaussian random projection (d={self.n_components}).")
34+
data.data.obsm[self.out] = grp.fit_transform(feat)
35+
36+
return data
37+
38+
1439
if __name__ == "__main__":
1540
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
1641
parser.add_argument("--cache", action="store_true", help="Cache processed data.")

‎examples/tuning/cta_svm/pipeline_tuning_config.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@ pipeline:
66
- WeightedFeaturePCA
77
- CellPCA
88
- CellSVD
9+
- GaussRandProjFeature # Registered custom preprocessing func
910
params:
1011
n_components: 400
1112
out: feature.cell
13+
log_level: INFO
1214
default_params:
1315
WeightedFeaturePCA:
1416
split_name: train

0 commit comments

Comments
 (0)