Skip to content

Commit 96fe34a

Browse files
committed
step2 cluster
1 parent 057fea8 commit 96fe34a

9 files changed

+218
-25
lines changed

test_automl/fun2code.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from dance.transforms.cell_feature import CellPCA, CellSVD, WeightedFeaturePCA
44
from dance.transforms.filter import FilterGenesPercentile, FilterGenesRegression
55
from dance.transforms.interface import AnnDataTransform
6+
from dance.transforms.misc import SaveRaw
67
from dance.transforms.normalize import ScaleFeature, ScTransformR
78

89
#TODO register more functions
@@ -19,5 +20,6 @@
1920
"cell_svd": CellSVD(),
2021
"cell_weighted_pca": WeightedFeaturePCA(split_name="train"),
2122
"cell_pca": CellPCA(),
22-
# "filter_cell_by_count":AnnDataTransform(sc.pp.filter_cells,min_genes=1)
23+
"filter_cell_by_count": AnnDataTransform(sc.pp.filter_cells, min_genes=1),
24+
"save_raw": SaveRaw()
2325
} #funcion 2 code

test_automl/step2_cell_type_annotation_actinn_example.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
import numpy as np
44
import torch
5-
from step2_config import get_preprocessing_pipeline, log_in_wandb, setStep2
5+
from step2_config import get_transforms, log_in_wandb, setStep2
66

7+
from dance import logger
78
from dance.datasets.singlemodality import CellTypeAnnotationDataset
89
from dance.modules.single_modality.cell_type_annotation.actinn import ACTINN
10+
from dance.transforms.misc import Compose
911
from dance.utils import set_seed
1012

1113
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -15,9 +17,11 @@
1517
def train(config):
1618

1719
model = ACTINN(hidden_dims=config.hidden_dims, lambd=config.lambd, device=device)
18-
preprocessing_pipeline = get_preprocessing_pipeline(config=config)
19-
if preprocessing_pipeline is None:
20+
transforms = get_transforms(config=config)
21+
if transforms is None:
22+
logger.warning("skip transforms")
2023
return {"scores": 0}
24+
preprocessing_pipeline = Compose(*transforms, log_level="INFO")
2125
train_dataset = [753, 3285]
2226
test_dataset = [2695]
2327
tissue = "Brain"
@@ -75,6 +79,6 @@ def startSweep(parameters_dict) -> Tuple[Dict[str, Any], Callable[..., Any]]:
7579

7680
if __name__ == "__main__":
    # get_function_combinations: setStep2 returns the callables to run for the
    # chosen preprocessing stages; each one is invoked in turn.
    sweep_launchers = setStep2(startSweep, original_list=["normalize", "gene_filter", "gene_dim_reduction"])
    for launch in sweep_launchers:
        launch()

test_automl/step2_clustering_scdcc.py

+153
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# normalize_per_cell must always be selected, because n_counts is required
2+
import os
3+
from typing import Any, Callable, Dict, Tuple
4+
5+
import numpy as np
6+
import torch
7+
from step2_config import get_transforms, log_in_wandb, setStep2
8+
9+
from dance import logger
10+
from dance.datasets.singlemodality import CellTypeAnnotationDataset, ClusteringDataset
11+
from dance.modules.single_modality.cell_type_annotation.actinn import ACTINN
12+
from dance.modules.single_modality.clustering.scdcc import ScDCC
13+
from dance.transforms.misc import Compose, SetConfig
14+
from dance.transforms.preprocess import generate_random_pair
15+
from dance.utils import set_seed
16+
17+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18+
19+
20+
@log_in_wandb(config=None)
def train(config):
    """Train and score ScDCC on the 10X_PBMC clustering dataset for one sweep config.

    One run is executed per seed in ``range(config.seed, config.seed + config.num_runs)``.
    Each run builds the preprocessing pipeline from ``config``, loads the data,
    optionally generates pairwise constraints, fits the model and records its score.

    Args:
        config: Sweep configuration; read via attribute access and ``config.keys()``.

    Returns:
        dict: ``{"scores": <mean score over runs>}``, or ``{"scores": 0}`` when the
        selected preprocessing combination is skipped (no ``normalize_total``
        normalization, or ``get_transforms`` returned ``None``).

    NOTE(review): ``config.device``, ``config.dataset``, ``config.lr``,
    ``config.batch_size``, ``config.epochs``, ``config.pretrain_lr`` and
    ``config.pretrain_epochs`` are read below but are not among the values set by
    ``startSweep`` in this file -- confirm they are supplied elsewhere.
    """
    aris = []
    for seed in range(config.seed, config.seed + config.num_runs):
        set_seed(seed)

        # Load data and perform necessary preprocessing
        dataloader = ClusteringDataset("./test_automl/data", "10X_PBMC")

        transforms = get_transforms(config=config, set_data_config=False, save_raw=True)
        if ("normalize" not in config.keys() or config.normalize != "normalize_total") or transforms is None:
            logger.warning("skip transforms")
            return {"scores": 0}
        transforms.append(
            SetConfig({
                "feature_channel": [None, None, "n_counts"],
                "feature_channel_type": ["X", "raw_X", "obs"],
                "label_channel": "Group"
            }))
        preprocessing_pipeline = Compose(*transforms, log_level="INFO")
        data = dataloader.load_data(transform=preprocessing_pipeline, cache=config.cache)

        # inputs: x, x_raw, n_counts
        inputs, y = data.get_train_data()
        n_clusters = len(np.unique(y))
        in_dim = inputs[0].shape[1]

        # Generate random pairs
        if not os.path.exists(config.label_cells_files):
            indx = np.arange(len(y))
            np.random.shuffle(indx)
            label_cell_indx = indx[0:int(np.ceil(config.label_cells * len(y)))]
        else:
            # The ``np.int`` alias was removed from NumPy; the builtin ``int`` is the
            # equivalent dtype.
            label_cell_indx = np.loadtxt(config.label_cells_files, dtype=int)

        if config.n_pairwise > 0:
            ml_ind1, ml_ind2, cl_ind1, cl_ind2, error_num = generate_random_pair(y, label_cell_indx, config.n_pairwise,
                                                                                 config.n_pairwise_error)
            print("Must link paris: %d" % ml_ind1.shape[0])
            print("Cannot link paris: %d" % cl_ind1.shape[0])
            print("Number of error pairs: %d" % error_num)
        else:
            ml_ind1, ml_ind2, cl_ind1, cl_ind2 = np.array([]), np.array([]), np.array([]), np.array([])

        # Build and train model
        # The decoder mirrors the encoder; cl_weight now uses its own sweep value
        # instead of reusing ml_weight.
        model = ScDCC(input_dim=in_dim, z_dim=config.z_dim, n_clusters=n_clusters, encodeLayer=config.encodeLayer,
                      decodeLayer=config.encodeLayer[::-1], sigma=config.sigma, gamma=config.gamma,
                      ml_weight=config.ml_weight, cl_weight=config.cl_weight, device=config.device,
                      pretrain_path=f"scdcc_{config.dataset}_pre.pkl")
        model.fit(inputs, y, lr=config.lr, batch_size=config.batch_size, epochs=config.epochs, ml_ind1=ml_ind1,
                  ml_ind2=ml_ind2, cl_ind1=cl_ind1, cl_ind2=cl_ind2, update_interval=config.update_interval,
                  tol=config.tol, pt_batch_size=config.batch_size, pt_lr=config.pretrain_lr,
                  pt_epochs=config.pretrain_epochs)

        # Evaluate model predictions
        score = model.score(None, y)
        print(f"{score=:.4f}")
        aris.append(score)

    print('scdcc')
    print(config.dataset)
    print(f'aris: {aris}')
    print(f'aris: {np.mean(aris)} +/- {np.std(aris)}')
    return {"scores": np.mean(aris)}
84+
85+
86+
def startSweep(parameters_dict) -> Tuple[Dict[str, Any], Callable[..., Any]]:
    """Add the fixed ScDCC settings to ``parameters_dict`` and build the sweep config.

    Args:
        parameters_dict: Sweep parameter mapping; updated in place with one
            ``{"value": ...}`` entry per fixed setting listed below.

    Returns:
        A ``(sweep_config, train)`` tuple: a grid-search configuration that
        maximizes the ``scores`` metric, and the training function to run.
    """
    fixed_settings = {
        'seed': 0,
        'num_runs': 1,
        'cache': True,
        'label_cells_files': 'label_10X_PBMC.txt',
        'label_cells': 0.1,
        'n_pairwise': 0,
        'n_pairwise_error': 0,
        'z_dim': 32,
        'encodeLayer': [256, 64],
        'sigma': 2.5,
        'gamma': 1.0,
        'ml_weight': 1.0,
        'cl_weight': 1.0,
        'update_interval': 1.0,
        'tol': 0.00001,
        'ae_weights': None,
        'ae_weight_file': "AE_weights.pth.tar",
    }
    # Each fixed setting is wrapped as a single-value sweep parameter.
    parameters_dict.update({name: {'value': setting} for name, setting in fixed_settings.items()})

    sweep_config = {
        'method': 'grid',
        'parameters': parameters_dict,
        'metric': {
            'name': 'scores',
            'goal': 'maximize'
        },
    }
    return sweep_config, train  # Return function configuration and training function
147+
148+
149+
if __name__ == "__main__":
    # get_function_combinations: setStep2 returns the callables to run for the
    # chosen preprocessing stages; each one is invoked in turn.
    sweep_launchers = setStep2(startSweep, original_list=["gene_filter", "cell_filter", "normalize"])
    for launch in sweep_launchers:
        launch()

test_automl/step2_config.py

+17-9
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import functools
22
from itertools import combinations
33

4-
import wandb
54
from fun2code import fun2code_dict
65

6+
import wandb
77
from dance.transforms.misc import Compose, SetConfig
88

99
#TODO register more functions and add more examples
@@ -17,6 +17,9 @@
1717
},
1818
"gene_dim_reduction": {
1919
"values": ["cell_svd", "cell_weighted_pca", "cell_pca"]
20+
},
21+
"cell_filter": {
22+
"values": ["filter_cell_by_count"]
2023
}
2124
} #Functions registered in the preprocessing process
2225

@@ -25,29 +28,34 @@ def getFunConfig(selected_keys=None):
2528
"""Get the config that needs to be optimized and the number of rounds."""
2629
global pipline2fun_dict
2730
pipline2fun_dict_subset = {key: pipline2fun_dict[key] for key in selected_keys}
31+
print(pipline2fun_dict)
2832
count = 1
2933
for _, pipline_values in pipline2fun_dict_subset.items():
3034
count *= len(pipline_values['values'])
3135
return pipline2fun_dict_subset, count
3236

3337

34-
def get_preprocessing_pipeline(config=None):
38+
def get_transforms(config=None, set_data_config=True, save_raw=False):
    """Build the ordered list of preprocessing transforms selected by ``config``.

    Transforms are appended in this order: gene filter, cell filter, optional
    raw-data snapshot, normalization, gene dimension reduction, and finally an
    optional ``SetConfig``.

    Args:
        config: Sweep configuration exposing ``keys()`` and attribute access for
            the stage names (``normalize``, ``gene_filter``, ``cell_filter``,
            ``gene_dim_reduction``).
        set_data_config: When true, append a ``SetConfig`` whose label channel is
            ``cell_type`` and whose feature channel is the name of the selected
            dimension-reduction transform, if one is selected.
        save_raw: When true, insert the ``save_raw`` transform after the filters
            and before normalization.

    Returns:
        list | None: The transforms, or ``None`` when ``highly_variable_genes`` is
        selected without ``log1p`` normalization.

    NOTE(review): the gene filter is applied before normalization here, so
    ``highly_variable_genes`` would run ahead of ``log1p`` even though the guard
    requires ``log1p`` -- confirm this ordering is intended.
    """
    selected = config.keys()

    wants_hvg = "gene_filter" in selected and config.gene_filter == "highly_variable_genes"
    has_log1p = "normalize" in selected and config.normalize == "log1p"
    if wants_hvg and not has_log1p:
        return None

    pipeline = []
    for stage in ("gene_filter", "cell_filter"):
        if stage in selected:
            pipeline.append(fun2code_dict[getattr(config, stage)])
    if save_raw:
        pipeline.append(fun2code_dict["save_raw"])
    for stage in ("normalize", "gene_dim_reduction"):
        if stage in selected:
            pipeline.append(fun2code_dict[getattr(config, stage)])

    if not set_data_config:
        return pipeline

    data_config = {"label_channel": "cell_type"}
    if "gene_dim_reduction" in selected:
        data_config["feature_channel"] = fun2code_dict[config.gene_dim_reduction].name
    pipeline.append(SetConfig(data_config))
    return pipeline
5159

5260

5361
def sweepDecorator(selected_keys=None, project="pytorch-cell_type_annotation_ACTINN"):

test_automl/step2_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def test_get_preprocessing_pipeline():
    # Placeholder: a test may not be needed here, because the code consists
    # mainly of decorator functions.
    return None

test_automl/step3_cell_type_annotation_actinn_example.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
import numpy as np
22
import optuna
33
import torch
4-
from step3_config import get_optimizer, get_preprocessing_pipeline
4+
from step3_config import get_optimizer, get_transforms
55

6+
from dance import logger
67
from dance.datasets.singlemodality import CellTypeAnnotationDataset
78
from dance.modules.single_modality.cell_type_annotation.actinn import ACTINN
9+
from dance.transforms.misc import Compose
810
from dance.utils import set_seed
911

1012
fun_list = ["log1p", "filter_gene_by_count"]
@@ -30,7 +32,11 @@ def objective(trial):
3032
species = "mouse"
3133
dataloader = CellTypeAnnotationDataset(train_dataset=train_dataset, test_dataset=test_dataset, tissue=tissue,
3234
species=species, data_dir="./test_automl/data")
33-
preprocessing_pipeline = get_preprocessing_pipeline(trial=trial, fun_list=fun_list)
35+
transforms = get_transforms(trial=trial, fun_list=fun_list)
36+
if transforms is None:
37+
logger.warning("skip transforms")
38+
return {"scores": 0}
39+
preprocessing_pipeline = Compose(*transforms, log_level="INFO")
3440
data = dataloader.load_data(transform=preprocessing_pipeline, cache=True)
3541

3642
# Obtain training and testing data

test_automl/step3_clustering_scdcc.py

Whitespace-only changes.

test_automl/step3_config.py

+25-9
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
import optuna
55
import scanpy as sc
6-
import wandb
76
from fun2code import fun2code_dict
87
from optuna.integration.wandb import WeightsAndBiasesCallback
98

9+
import wandb
1010
from dance.transforms.cell_feature import CellPCA, CellSVD, WeightedFeaturePCA
1111
from dance.transforms.filter import FilterGenesPercentile, FilterGenesRegression
1212
from dance.transforms.interface import AnnDataTransform
@@ -115,6 +115,20 @@ def normalize_total(method_name: str, trial: optuna.Trial):
115115
exclude_highly_expressed=exclude_highly_expressed, max_fraction=max_fraction)
116116

117117

118+
@set_method_name
def filter_cell_by_count(method_name: str, trial: optuna.Trial):
    """Suggest a ``sc.pp.filter_cells`` threshold and wrap it as a transform.

    The trial first picks which threshold keyword to use, then an integer value
    for it from that keyword's range.

    Args:
        method_name: Prefix for the names of the suggested trial parameters.
        trial: Optuna trial used to draw the suggestions.

    Returns:
        An ``AnnDataTransform`` calling ``sc.pp.filter_cells`` with the chosen
        keyword set to the suggested value.
    """
    # Threshold keyword -> (low, high) bounds passed to suggest_int.
    bounds_by_keyword = {
        'min_counts': (2, 10),
        'min_genes': (2, 10),
        'max_counts': (500, 1000),
        'max_genes': (500, 1000),
    }
    keyword = trial.suggest_categorical(method_name + "method", list(bounds_by_keyword))
    low, high = bounds_by_keyword[keyword]
    threshold = trial.suggest_int(method_name + "num", low, high)
    return AnnDataTransform(sc.pp.filter_cells, **{keyword: threshold})
130+
131+
118132
# # Get all functions in the current file
119133
# functions = [(name,obj) for name, obj in inspect.getmembers(
120134
# sys.modules[__name__]) if inspect.isfunction(obj)]
@@ -127,20 +141,22 @@ def normalize_total(method_name: str, trial: optuna.Trial):
127141
# setattr(__name__, name, set_method_name(function))
128142

129143

130-
def get_preprocessing_pipeline(trial, fun_list):
144+
def get_transforms(trial, fun_list, set_data_config=True):
    """Build the list of preprocessing transforms named in ``fun_list``.

    Args:
        trial: Optuna trial passed to each preprocessing factory.
        fun_list: Ordered names of the preprocessing factories defined in this
            module.
        set_data_config: When true, append a ``SetConfig`` whose label channel is
            ``cell_type`` and whose feature channel is the name of the selected
            cell-feature transform, if any.

    Returns:
        list | None: The instantiated transforms, or ``None`` when
        ``highly_variable_genes`` is requested without ``log1p`` earlier in
        ``fun_list``.
    """
    # Validate before instantiating anything so that an unsupported combination
    # does not spend trial suggestions.
    if "highly_variable_genes" in fun_list:
        hvg_position = fun_list.index("highly_variable_genes")
        if "log1p" not in fun_list[:hvg_position]:
            return None

    transforms = []
    for f_str in fun_list:
        # NOTE(review): eval() executes arbitrary expressions; fun_list must only
        # ever contain trusted, hard-coded factory names.
        fun_i = eval(f_str)
        transforms.append(fun_i(trial))

    if set_data_config:
        data_config = {"label_channel": "cell_type"}
        feature_transforms = {"cell_svd", "cell_weighted_pca", "cell_pca"}
        selected_features = [f_str for f_str in fun_list if f_str in feature_transforms]
        if selected_features:
            # A set cannot be used as a dict key, so look up by name.
            # NOTE(review): the last listed feature transform is assumed to define
            # the feature channel when several are given -- confirm.
            data_config["feature_channel"] = fun2code_dict[selected_features[-1]].name
        transforms.append(SetConfig(data_config))
    return transforms
144160

145161

146162
def log_in_wandb(wandbc=None):

test_automl/step3_test.py

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
def test_get_preprocessing_pipeline():
    # Placeholder: a test may not be needed here, because the code consists
    # mainly of decorator functions.
    return None

0 commit comments

Comments
 (0)