Skip to content

Commit 5a8a162

Browse files
committed
update ans
1 parent 00380dc commit 5a8a162

File tree

1 file changed

+52
-2
lines changed

1 file changed

+52
-2
lines changed

examples/atlas/sc_similarity_examples/sim_query_atlas.py

+52-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import argparse
2+
import os
3+
import re
24
from pathlib import Path
35

46
import pandas as pd
7+
import yaml
58

69
atlas_datasets = [
710
"01209dce-3575-4bed-b1df-129f57fbc031", "055ca631-6ffb-40de-815e-b931e10718c0",
@@ -20,6 +23,7 @@
2023

2124
from get_result_web import get_sweep_url, spilt_web
2225

26+
from dance import logger
2327
from dance.utils import try_import
2428

2529
file_root = str(Path(__file__).resolve().parent.parent)
@@ -85,8 +89,30 @@ def is_match(config_str):
8589
]
8690

8791

92+
def is_matching_dict(yaml_str, target_dict):
    """Check whether a YAML pipeline description matches a flat config dict.

    The YAML text is parsed with ``yaml.safe_load`` and is expected to yield a
    sequence of mappings, each carrying a ``type`` entry and (for non-``misc``
    entries) a ``target`` entry. Every entry whose ``type`` is not ``"misc"``
    contributes one item ``"pipeline.<position>.<type>" -> target`` to the
    expected dictionary. ``<position>`` is the entry's index in the full
    sequence, so skipped ``misc`` entries still consume an index.

    Args:
        yaml_str: YAML document describing the pipeline steps.
        target_dict: Dictionary to compare against.

    Returns:
        ``True`` when the dictionary built from the YAML equals ``target_dict``
        exactly (same keys, same values), otherwise ``False``.

    Raises:
        KeyError: If an entry lacks ``type``, or a non-``misc`` entry lacks
            ``target`` (assuming entries are mappings).
    """
    yaml_config = yaml.safe_load(yaml_str)
    # An empty YAML document parses to None; treat it as a pipeline with no
    # steps instead of failing inside enumerate().
    if yaml_config is None:
        yaml_config = []

    # Build the expected flat mapping, skipping "misc" entries. The filter is
    # evaluated before the key/value expressions, so "target" is never read
    # for a skipped entry.
    expected_dict = {
        f"pipeline.{i}.{item['type']}": item["target"]
        for i, item in enumerate(yaml_config)
        if item["type"] != "misc"
    }

    # Exact dictionary equality: extra or missing keys on either side fail.
    return expected_dict == target_dict
108+
109+
88110
def get_ans(query_dataset, method):
89-
data = pd.read_csv(f"{file_root}/tuning/{method}/{query_dataset}/results/atlas/best_test_acc.csv")
111+
result_path = f"{file_root}/tuning/{method}/{query_dataset}/results/atlas/best_test_acc.csv"
112+
if not os.path.exists(result_path):
113+
logger.warning(f"{result_path} not exists")
114+
return None
115+
data = pd.read_csv(result_path)
90116
sweep_url = get_sweep_url(data)
91117
_, _, sweep_id = spilt_web(sweep_url)
92118
sweep = wandb.Api().sweep(f"{entity}/{project}/{sweep_id}")
@@ -97,11 +123,35 @@ def get_ans(query_dataset, method):
97123
return ans
98124

99125

126+
def get_ans_from_cache(query_dataset, method, metric="test_acc"):
    """Look up cached results for ``query_dataset`` using each atlas dataset's best pipeline.

    Step 1: read the step-2 sweep URL recorded for ``query_dataset`` in the
    ``method`` column of ``conf_data`` and open that sweep through the wandb API.
    Step 2: for every atlas dataset, take the best pipeline YAML stored in the
    ``f"{method}_method"`` column, find the single run of the sweep whose
    config matches it (see ``is_matching_dict``) and record that run's metric.

    Args:
        query_dataset: Value of the ``Dataset_id`` column identifying the query dataset.
        method: Method name; used as a ``conf_data`` column name and as the
            row label of the returned frame.
        metric: Key read from the matched run's summary.
            NOTE(review): ``"test_acc"`` is an assumption based on the
            ``best_test_acc.csv`` naming used elsewhere in this file; confirm.

    Returns:
        A one-row ``DataFrame`` indexed by ``method`` with one
        ``"<atlas_dataset>_from_cache"`` column per atlas dataset (cells stay
        NaN where nothing could be looked up), or ``None`` when the sweep for
        ``query_dataset`` cannot be located.

    Raises:
        ValueError: If more than one run of the sweep matches a best pipeline.
    """
    query_rows = conf_data[conf_data["Dataset_id"] == query_dataset]
    if query_rows.empty:
        logger.warning(f"{query_dataset} not found in conf_data")
        return None

    # Boolean indexing yields a Series; re.search needs the scalar cell value.
    sweep_cell = query_rows[method].iloc[0]
    url_match = re.search(r"step2:([^|]+)", sweep_cell) if isinstance(sweep_cell, str) else None
    if url_match is None:
        logger.warning(f"no step2 sweep url found for {query_dataset} / {method}")
        return None

    _, _, sweep_id = spilt_web(url_match.group(1))
    sweep = wandb.Api().sweep(f"{entity}/{project}/{sweep_id}")
    # Materialise the runs once; they are scanned again for every atlas dataset.
    runs = list(sweep.runs)

    # index must be a collection: a bare string is rejected by pandas.
    ans = pd.DataFrame(index=[method], columns=[f"{atlas_dataset}_from_cache" for atlas_dataset in atlas_datasets])

    for atlas_dataset in atlas_datasets:
        column = f"{atlas_dataset}_from_cache"
        atlas_rows = conf_data[conf_data["Dataset_id"] == atlas_dataset]
        if atlas_rows.empty:
            logger.warning(f"{atlas_dataset} not found in conf_data")
            continue

        best_yaml = atlas_rows[f"{method}_method"].iloc[0]
        if not isinstance(best_yaml, str):
            logger.warning(f"no best pipeline recorded for {atlas_dataset} / {method}")
            continue

        matched_runs = [run for run in runs if is_matching_dict(best_yaml, run.config)]
        if len(matched_runs) > 1:
            # A best pipeline is expected to identify exactly one run.
            raise ValueError(f"{len(matched_runs)} runs in sweep {sweep_id} match the best pipeline of {atlas_dataset}")
        if not matched_runs:
            # Every atlas dataset is expected to have a value; a gap needs manual checking.
            logger.warning(f"no run in sweep {sweep_id} matches the best pipeline of {atlas_dataset}")
            continue

        value = matched_runs[0].summary.get(metric)
        if value is None:
            logger.warning(f"matched run for {atlas_dataset} has no {metric} in its summary")
            continue
        ans.loc[method, column] = value

    return ans
144+
145+
100146
# Module-level container for collected results.
ans_all = {}

# Command-line options: the annotation methods to process and the worksheet
# (tissue) of the configuration workbook to load.
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
    "--methods",
    nargs="+",
    default=["cta_actinn", "cta_scdeepsort", "cta_singlecellnet", "cta_celltypist"],
)
parser.add_argument("--tissue", type=str)
args = parser.parse_args()
methods, tissue = args.methods, args.tissue

# The tissue name selects the worksheet. NOTE: when --tissue is omitted,
# sheet_name is None and pandas returns a dict of all sheets rather than a
# single DataFrame.
conf_data = pd.read_excel("Cell Type Annotation Atlas.xlsx", sheet_name=tissue)
105155
if __name__ == "__main__":
106156
for query_dataset in query_datasets:
107157
ans = []

0 commit comments

Comments
 (0)