# normalize_per_cell must always be selected, because n_counts is required downstream
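# Sweep preprocessing pipelines (assembled via step2_config) with wandb and train
# ScDCC on the 10X PBMC clustering dataset, reporting the mean ARI across seeds.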
import os
from typing import Any, Callable, Dict, Tuple

import numpy as np
import torch
from step2_config import get_transforms, log_in_wandb, setStep2

from dance import logger
from dance.datasets.singlemodality import ClusteringDataset
from dance.modules.single_modality.clustering.scdcc import ScDCC
from dance.transforms.misc import Compose, SetConfig
from dance.transforms.preprocess import generate_random_pair
from dance.utils import set_seed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@log_in_wandb(config=None)
def train(config):
    aris = []
    for seed in range(config.seed, config.seed + config.num_runs):
        set_seed(seed)

        # Load data and perform necessary preprocessing
        dataloader = ClusteringDataset("./test_automl/data", "10X_PBMC")

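        # get_transforms assembles the preprocessing pipeline from the swept config
        # values and returns None when the sampled combination is not usable.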
        transforms = get_transforms(config=config, set_data_config=False, save_raw=True)
        if ("normalize" not in config.keys() or config.normalize != "normalize_total") or transforms is None:
            logger.warning("Skipping this configuration: a valid pipeline with the normalize_total step is required")
            return {"scores": 0}
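        # Expose three feature channels downstream: the processed matrix ("X"), the
        # saved raw counts ("raw_X"), and the per-cell total counts ("n_counts")
        # stored in .obs; "Group" carries the ground-truth cluster labels.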
        transforms.append(
            SetConfig({
                "feature_channel": [None, None, "n_counts"],
                "feature_channel_type": ["X", "raw_X", "obs"],
                "label_channel": "Group"
            }))
        preprocessing_pipeline = Compose(*transforms, log_level="INFO")
        data = dataloader.load_data(transform=preprocessing_pipeline, cache=config.cache)

        # inputs: x, x_raw, n_counts
        inputs, y = data.get_train_data()
        n_clusters = len(np.unique(y))
        in_dim = inputs[0].shape[1]

        # Choose which cells' labels may be used to generate pairwise constraints
        if not os.path.exists(config.label_cells_files):
            indx = np.arange(len(y))
            np.random.shuffle(indx)
            label_cell_indx = indx[:int(np.ceil(config.label_cells * len(y)))]
        else:
            label_cell_indx = np.loadtxt(config.label_cells_files, dtype=int)  # np.int was removed in NumPy >= 1.24
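
        # ScDCC performs constrained clustering: must-link pairs join cells with the
        # same label, cannot-link pairs join cells with different labels, and both
        # act as soft constraints in the clustering loss.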
        if config.n_pairwise > 0:
            ml_ind1, ml_ind2, cl_ind1, cl_ind2, error_num = generate_random_pair(y, label_cell_indx, config.n_pairwise,
                                                                                 config.n_pairwise_error)
            print("Must-link pairs: %d" % ml_ind1.shape[0])
            print("Cannot-link pairs: %d" % cl_ind1.shape[0])
            print("Number of error pairs: %d" % error_num)
        else:
            ml_ind1, ml_ind2, cl_ind1, cl_ind2 = np.array([]), np.array([]), np.array([]), np.array([])

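        # Build and train the model. The decoder mirrors the encoder
        # (decodeLayer=encodeLayer[::-1]); fit() first pretrains the autoencoder (the
        # pt_* arguments), then runs the constrained clustering stage, refreshing the
        # cluster assignments every update_interval and stopping once the fraction of
        # reassigned cells drops below tol.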
        model = ScDCC(input_dim=in_dim, z_dim=config.z_dim, n_clusters=n_clusters, encodeLayer=config.encodeLayer,
                      decodeLayer=config.encodeLayer[::-1], sigma=config.sigma, gamma=config.gamma,
                      ml_weight=config.ml_weight, cl_weight=config.cl_weight, device=device,
                      pretrain_path=f"scdcc_{config.dataset}_pre.pkl")
        model.fit(inputs, y, lr=config.lr, batch_size=config.batch_size, epochs=config.epochs, ml_ind1=ml_ind1,
                  ml_ind2=ml_ind2, cl_ind1=cl_ind1, cl_ind2=cl_ind2, update_interval=config.update_interval,
                  tol=config.tol, pt_batch_size=config.batch_size, pt_lr=config.pretrain_lr,
                  pt_epochs=config.pretrain_epochs)

        # Evaluate model predictions with the adjusted Rand index (ARI)
        score = model.score(None, y)
        print(f"{score=:.4f}")
        aris.append(score)

    print('scdcc')
    print(config.dataset)
    print(f'aris: {aris}')
    print(f'aris mean: {np.mean(aris)} +/- {np.std(aris)}')
    return {"scores": np.mean(aris)}


def startSweep(parameters_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Callable[..., Any]]:
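    """Attach the fixed ScDCC hyperparameters and build the wandb sweep config.

    The incoming ``parameters_dict`` carries the preprocessing choices being swept;
    every entry added here is held constant across runs.
    """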
    parameters_dict.update({
        'seed': {
            'value': 0
        },
        'num_runs': {
            'value': 1
        },
        'cache': {
            'value': True
        },
        'label_cells_files': {
            'value': 'label_10X_PBMC.txt'
        },
        'label_cells': {
            'value': 0.1
        },
        'n_pairwise': {
            'value': 0
        },
        'n_pairwise_error': {
            'value': 0
        },
        'z_dim': {
            'value': 32
        },
        'encodeLayer': {
            'value': [256, 64]
        },
        'sigma': {
            'value': 2.5
        },
        'gamma': {
            'value': 1.0
        },
        'ml_weight': {
            'value': 1.0
        },
        'cl_weight': {
            'value': 1.0
        },
        'update_interval': {
            'value': 1.0
        },
        'tol': {
            'value': 0.00001
        },
        'ae_weights': {
            'value': None
        },
        'ae_weight_file': {
            'value': "AE_weights.pth.tar"
        },
        # train() also reads the entries below, but the original sweep never defined
        # them; the values here are assumed defaults, not tuned settings.
        'dataset': {
            'value': '10X_PBMC'
        },
        'lr': {
            'value': 0.01
        },
        'batch_size': {
            'value': 256
        },
        'epochs': {
            'value': 100
        },
        'pretrain_lr': {
            'value': 0.001
        },
        'pretrain_epochs': {
            'value': 50
        },
    })

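    # Grid search maximizing the "scores" metric (mean ARI) returned by train()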
    sweep_config = {
        'method': 'grid',
        'parameters': parameters_dict,
        'metric': {'name': 'scores', 'goal': 'maximize'},
    }
    return sweep_config, train  # Return the sweep configuration and the training function


if __name__ == "__main__":
    # Enumerate every combination of the selected preprocessing steps and launch
    # one sweep per combination.
    function_list = setStep2(startSweep, original_list=["gene_filter", "cell_filter", "normalize"])
    for func in function_list:
        func()