|
3 | 3 | from typing import get_args
|
4 | 4 |
|
5 | 5 | import wandb
|
| 6 | +from sklearn.random_projection import GaussianRandomProjection |
6 | 7 |
|
7 | 8 | from dance import logger
|
8 | 9 | from dance.datasets.singlemodality import CellTypeAnnotationDataset
|
9 | 10 | from dance.modules.single_modality.cell_type_annotation.svm import SVM
|
10 | 11 | from dance.pipeline import PipelinePlaner
|
| 12 | +from dance.registry import register_preprocessor |
| 13 | +from dance.transforms.base import BaseTransform |
11 | 14 | from dance.typing import LogLevel
|
12 | 15 | from dance.utils import set_seed
|
13 | 16 |
|
| 17 | + |
| 18 | +@register_preprocessor("feature", "cell") # NOTE: register any custom preprocessing function to be used for tuning |
| 19 | +class GaussRandProjFeature(BaseTransform): |
| 20 | + """Custom preprocessing to extract cell feature via Gaussian random projection.""" |
| 21 | + |
| 22 | + _DISPLAY_ATTRS = ("n_components", "eps") |
| 23 | + |
| 24 | + def __init__(self, n_components: int = 400, eps: float = 0.1, **kwargs): |
| 25 | + super().__init__(**kwargs) |
| 26 | + self.n_components = n_components |
| 27 | + self.eps = eps |
| 28 | + |
| 29 | + def __call__(self, data): |
| 30 | + feat = data.get_feature(return_type="numpy") |
| 31 | + grp = GaussianRandomProjection(n_components=self.n_components, eps=self.eps) |
| 32 | + |
| 33 | + self.logger.info(f"Start generateing cell feature via Gaussian random projection (d={self.n_components}).") |
| 34 | + data.data.obsm[self.out] = grp.fit_transform(feat) |
| 35 | + |
| 36 | + return data |
| 37 | + |
| 38 | + |
14 | 39 | if __name__ == "__main__":
|
15 | 40 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
16 | 41 | parser.add_argument("--cache", action="store_true", help="Cache processed data.")
|
|
0 commit comments