Skip to content

Commit

Permalink
update step2 cluster
Browse files (browse the repository at this point in the history)
  • Loading branch information
xingzhongyu committed Jan 20, 2024
1 parent ba32644 commit 56e1291
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 12 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ wandb

test_automl/data
test_automl/test.py
*.pkl
2 changes: 1 addition & 1 deletion test_automl/fun2code.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

#TODO register more functions
fun2code_dict = {
"normalize_total": AnnDataTransform(sc.pp.normalize_total, target_sum=1e4),
"normalize_total": AnnDataTransform(sc.pp.normalize_total, target_sum=1e4, key_added="n_counts"),
"log1p": AnnDataTransform(sc.pp.log1p, base=2),
"scaleFeature": ScaleFeature(split_names="ALL", mode="standardize"),
"scTransform": ScTransformR(mirror_index=1),
Expand Down
27 changes: 20 additions & 7 deletions test_automl/step2_clustering_scdcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@
from step2_config import get_transforms, log_in_wandb, setStep2

from dance import logger
from dance.datasets.singlemodality import CellTypeAnnotationDataset, ClusteringDataset
from dance.modules.single_modality.cell_type_annotation.actinn import ACTINN
from dance.datasets.singlemodality import ClusteringDataset
from dance.modules.single_modality.clustering.scdcc import ScDCC
from dance.transforms.misc import Compose, SetConfig
from dance.transforms.preprocess import generate_random_pair
Expand All @@ -22,9 +21,9 @@ def train(config):
aris = []
for seed in range(config.seed, config.seed + config.num_runs):
set_seed(seed)

dataset = "10X_PBMC"
# Load data and perform necessary preprocessing
dataloader = ClusteringDataset("./test_automl/data", "10X_PBMC")
dataloader = ClusteringDataset("./test_automl/data", dataset=dataset)

transforms = get_transforms(config=config, set_data_config=False, save_raw=True)
if ("normalize" not in config.keys() or config.normalize != "normalize_total") or transforms is None:
Expand Down Expand Up @@ -64,8 +63,8 @@ def train(config):
# Build and train model
model = ScDCC(input_dim=in_dim, z_dim=config.z_dim, n_clusters=n_clusters, encodeLayer=config.encodeLayer,
decodeLayer=config.encodeLayer[::-1], sigma=config.sigma, gamma=config.gamma,
ml_weight=config.ml_weight, cl_weight=config.ml_weight, device=config.device,
pretrain_path=f"scdcc_{config.dataset}_pre.pkl")
ml_weight=config.ml_weight, cl_weight=config.ml_weight, device=device,
pretrain_path=f"scdcc_{dataset}_pre.pkl")
model.fit(inputs, y, lr=config.lr, batch_size=config.batch_size, epochs=config.epochs, ml_ind1=ml_ind1,
ml_ind2=ml_ind2, cl_ind1=cl_ind1, cl_ind2=cl_ind2, update_interval=config.update_interval,
tol=config.tol, pt_batch_size=config.batch_size, pt_lr=config.pretrain_lr,
Expand All @@ -77,7 +76,6 @@ def train(config):
aris.append(score)

print('scdcc')
print(config.dataset)
print(f'aris: {aris}')
print(f'aris: {np.mean(aris)} +/- {np.std(aris)}')
return ({"scores": np.mean(aris)})
Expand Down Expand Up @@ -118,6 +116,12 @@ def startSweep(parameters_dict) -> Tuple[Dict[str, Any], Callable[..., Any]]:
'gamma': {
'value': 1.0
},
'lr': {
'value': 0.01
},
'pretrain_lr': {
'value': 0.001
},
'ml_weight': {
'value': 1.0
},
Expand All @@ -135,6 +139,15 @@ def startSweep(parameters_dict) -> Tuple[Dict[str, Any], Callable[..., Any]]:
},
'ae_weight_file': {
'value': "AE_weights.pth.tar"
},
'pretrain_epochs': {
'value': 50
},
'epochs': {
'value': 500
},
'batch_size': {
'value': 256
}
})

Expand Down
2 changes: 1 addition & 1 deletion test_automl/step2_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import functools
from itertools import combinations

import wandb
from fun2code import fun2code_dict

import wandb
from dance.transforms.misc import Compose, SetConfig

#TODO register more functions and add more examples
Expand Down
8 changes: 5 additions & 3 deletions test_automl/step3_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@

import optuna
import scanpy as sc
import wandb
from fun2code import fun2code_dict
from optuna.integration.wandb import WeightsAndBiasesCallback

import wandb
from dance.transforms.cell_feature import CellPCA, CellSVD, WeightedFeaturePCA
from dance.transforms.filter import FilterGenesPercentile, FilterGenesRegression
from dance.transforms.interface import AnnDataTransform
Expand Down Expand Up @@ -108,11 +108,13 @@ def normalize_total(method_name: str, trial: optuna.Trial):
max_fraction = trial.suggest_float(method_name + "max_fraction", 0.04, 0.1)
return AnnDataTransform(sc.pp.normalize_total,
target_sum=trial.suggest_categorical(method_name + "target_sum", [1e4, 1e5, 1e6]),
exclude_highly_expressed=exclude_highly_expressed, max_fraction=max_fraction)
exclude_highly_expressed=exclude_highly_expressed, max_fraction=max_fraction,
key_added="n_counts")
else:
return AnnDataTransform(sc.pp.normalize_total,
target_sum=trial.suggest_categorical(method_name + "target_sum", [1e4, 1e5, 1e6]),
exclude_highly_expressed=exclude_highly_expressed, max_fraction=max_fraction)
exclude_highly_expressed=exclude_highly_expressed, max_fraction=max_fraction,
key_added="n_counts")


@set_method_name
Expand Down

0 comments on commit 56e1291

Please sign in to comment.