diff --git a/.gitignore b/.gitignore index 3471a1ae..db77ec21 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,4 @@ wandb test_automl/data test_automl/test.py +*.pkl diff --git a/test_automl/fun2code.py b/test_automl/fun2code.py index 7108dec4..01e4c657 100644 --- a/test_automl/fun2code.py +++ b/test_automl/fun2code.py @@ -8,7 +8,7 @@ #TODO register more functions fun2code_dict = { - "normalize_total": AnnDataTransform(sc.pp.normalize_total, target_sum=1e4), + "normalize_total": AnnDataTransform(sc.pp.normalize_total, target_sum=1e4, key_added="n_counts"), "log1p": AnnDataTransform(sc.pp.log1p, base=2), "scaleFeature": ScaleFeature(split_names="ALL", mode="standardize"), "scTransform": ScTransformR(mirror_index=1), diff --git a/test_automl/step2_clustering_scdcc.py b/test_automl/step2_clustering_scdcc.py index d83e482b..325ac0fc 100644 --- a/test_automl/step2_clustering_scdcc.py +++ b/test_automl/step2_clustering_scdcc.py @@ -7,8 +7,7 @@ from step2_config import get_transforms, log_in_wandb, setStep2 from dance import logger -from dance.datasets.singlemodality import CellTypeAnnotationDataset, ClusteringDataset -from dance.modules.single_modality.cell_type_annotation.actinn import ACTINN +from dance.datasets.singlemodality import ClusteringDataset from dance.modules.single_modality.clustering.scdcc import ScDCC from dance.transforms.misc import Compose, SetConfig from dance.transforms.preprocess import generate_random_pair @@ -22,9 +21,9 @@ def train(config): aris = [] for seed in range(config.seed, config.seed + config.num_runs): set_seed(seed) - + dataset = "10X_PBMC" # Load data and perform necessary preprocessing - dataloader = ClusteringDataset("./test_automl/data", "10X_PBMC") + dataloader = ClusteringDataset("./test_automl/data", dataset=dataset) transforms = get_transforms(config=config, set_data_config=False, save_raw=True) if ("normalize" not in config.keys() or config.normalize != "normalize_total") or transforms is None: @@ -64,8 +63,8 @@ def train(config): # Build and train moodel model = ScDCC(input_dim=in_dim, z_dim=config.z_dim, n_clusters=n_clusters, encodeLayer=config.encodeLayer, decodeLayer=config.encodeLayer[::-1], sigma=config.sigma, gamma=config.gamma, - ml_weight=config.ml_weight, cl_weight=config.ml_weight, device=config.device, - pretrain_path=f"scdcc_{config.dataset}_pre.pkl") + ml_weight=config.ml_weight, cl_weight=config.ml_weight, device=device, + pretrain_path=f"scdcc_{dataset}_pre.pkl") model.fit(inputs, y, lr=config.lr, batch_size=config.batch_size, epochs=config.epochs, ml_ind1=ml_ind1, ml_ind2=ml_ind2, cl_ind1=cl_ind1, cl_ind2=cl_ind2, update_interval=config.update_interval, tol=config.tol, pt_batch_size=config.batch_size, pt_lr=config.pretrain_lr, @@ -77,7 +76,6 @@ def train(config): aris.append(score) print('scdcc') - print(config.dataset) print(f'aris: {aris}') print(f'aris: {np.mean(aris)} +/- {np.std(aris)}') return ({"scores": np.mean(aris)}) @@ -118,6 +116,12 @@ def startSweep(parameters_dict) -> Tuple[Dict[str, Any], Callable[..., Any]]: 'gamma': { 'value': 1.0 }, + 'lr': { + 'value': 0.01 + }, + 'pretrain_lr': { + 'value': 0.001 + }, 'ml_weight': { 'value': 1.0 }, @@ -135,6 +139,15 @@ def startSweep(parameters_dict) -> Tuple[Dict[str, Any], Callable[..., Any]]: }, 'ae_weight_file': { 'value': "AE_weights.pth.tar" + }, + 'pretrain_epochs': { + 'value': 50 + }, + 'epochs': { + 'value': 500 + }, + 'batch_size': { + 'value': 256 } }) diff --git a/test_automl/step2_config.py b/test_automl/step2_config.py index b6dd3d18..aa0407ba 100644 --- a/test_automl/step2_config.py +++ b/test_automl/step2_config.py @@ -1,9 +1,9 @@ import functools from itertools import combinations -import wandb from fun2code import fun2code_dict +import wandb from dance.transforms.misc import Compose, SetConfig #TODO register more functions and add more examples diff --git a/test_automl/step3_config.py b/test_automl/step3_config.py index 055bbb9f..edd361e6 100644 --- a/test_automl/step3_config.py +++ b/test_automl/step3_config.py @@ -3,10 +3,10 @@ import optuna import scanpy as sc -import wandb from fun2code import fun2code_dict from optuna.integration.wandb import WeightsAndBiasesCallback +import wandb from dance.transforms.cell_feature import CellPCA, CellSVD, WeightedFeaturePCA from dance.transforms.filter import FilterGenesPercentile, FilterGenesRegression from dance.transforms.interface import AnnDataTransform @@ -108,11 +108,13 @@ def normalize_total(method_name: str, trial: optuna.Trial): max_fraction = trial.suggest_float(method_name + "max_fraction", 0.04, 0.1) return AnnDataTransform(sc.pp.normalize_total, target_sum=trial.suggest_categorical(method_name + "target_sum", [1e4, 1e5, 1e6]), - exclude_highly_expressed=exclude_highly_expressed, max_fraction=max_fraction) + exclude_highly_expressed=exclude_highly_expressed, max_fraction=max_fraction, + key_added="n_counts") else: return AnnDataTransform(sc.pp.normalize_total, target_sum=trial.suggest_categorical(method_name + "target_sum", [1e4, 1e5, 1e6]), - exclude_highly_expressed=exclude_highly_expressed, max_fraction=max_fraction) + exclude_highly_expressed=exclude_highly_expressed, max_fraction=max_fraction, + key_added="n_counts") @set_method_name