1
1
import argparse
2
+ import os
3
+ import re
2
4
from pathlib import Path
3
5
4
6
import pandas as pd
7
+ import yaml
5
8
6
9
atlas_datasets = [
7
10
"01209dce-3575-4bed-b1df-129f57fbc031" , "055ca631-6ffb-40de-815e-b931e10718c0" ,
20
23
21
24
from get_result_web import get_sweep_url , spilt_web
22
25
26
+ from dance import logger
23
27
from dance .utils import try_import
24
28
25
29
file_root = str (Path (__file__ ).resolve ().parent .parent )
@@ -85,8 +89,30 @@ def is_match(config_str):
85
89
]
86
90
87
91
92
def is_matching_dict(yaml_str, target_dict):
    """Check whether a YAML pipeline description matches a flat config dict.

    The YAML text is parsed into a sequence of step mappings. Every step whose
    ``type`` is not ``"misc"`` contributes one entry of the form
    ``"pipeline.<position>.<type>" -> <target>``, where ``<position>`` is the
    step's index in the full sequence (``misc`` steps are skipped but still
    occupy an index). The resulting dict is compared with ``target_dict``.

    Args:
        yaml_str: YAML text describing the pipeline steps; each step is
            expected to provide ``type`` and ``target`` keys.
        target_dict: Flat dict to compare against.

    Returns:
        ``True`` when the dict built from the YAML equals ``target_dict``,
        ``False`` otherwise.
    """
    # Parse the YAML string.
    steps = yaml.safe_load(yaml_str)
    if steps is None:
        # Empty or null YAML parses to None; treat it as a pipeline with no
        # steps instead of failing inside enumerate().
        steps = []

    # Build the expected flat dict, skipping steps of type "misc".
    expected_dict = {
        f"pipeline.{position}.{step['type']}": step['target']
        for position, step in enumerate(steps)
        if step['type'] != 'misc'
    }

    # Plain equality: keys and values must match exactly.
    return expected_dict == target_dict
108
+
109
+
88
110
def get_ans (query_dataset , method ):
89
- data = pd .read_csv (f"{ file_root } /tuning/{ method } /{ query_dataset } /results/atlas/best_test_acc.csv" )
111
+ result_path = f"{ file_root } /tuning/{ method } /{ query_dataset } /results/atlas/best_test_acc.csv"
112
+ if not os .path .exists (result_path ):
113
+ logger .warning (f"{ result_path } not exists" )
114
+ return None
115
+ data = pd .read_csv (result_path )
90
116
sweep_url = get_sweep_url (data )
91
117
_ , _ , sweep_id = spilt_web (sweep_url )
92
118
sweep = wandb .Api ().sweep (f"{ entity } /{ project } /{ sweep_id } " )
@@ -97,11 +123,35 @@ def get_ans(query_dataset, method):
97
123
return ans
98
124
99
125
126
def get_ans_from_cache(query_dataset, method):
    """Find, in the query dataset's step2 sweep, the runs matching each atlas dataset's best pipeline.

    Steps:
      1. Read the query dataset's ``method`` cell from ``conf_data`` and
         extract the step2 sweep URL from it.
      2. For every atlas dataset, read its ``<method>_method`` cell (YAML text
         of the best pipeline) and search the sweep for a run whose config
         matches it. Every atlas dataset is expected to have a value; missing
         ones are logged so they can be checked.

    Args:
        query_dataset: Value looked up in the ``Dataset_id`` column of
            ``conf_data``.
        method: Column name in ``conf_data`` holding the sweep links; also
            used as the row label of the result.

    Returns:
        A one-row ``DataFrame`` indexed by ``method`` with one
        ``"<atlas_dataset>_from_cache"`` column per atlas dataset.
        NOTE(review): the cells are not populated yet (see the TODO below).

    Raises:
        ValueError: If the query dataset has no row in ``conf_data`` or its
            cell contains no ``step2:`` link.
    """
    # A bare string is not a valid index; wrap it so the frame has one row.
    ans = pd.DataFrame(index=[method],
                       columns=[f"{atlas_dataset}_from_cache" for atlas_dataset in atlas_datasets])

    # The boolean-mask selection yields a Series; re.search needs the scalar text.
    query_cells = conf_data.loc[conf_data["Dataset_id"] == query_dataset, method]
    if query_cells.empty:
        raise ValueError(f"{query_dataset} not found in conf_data")
    step2_link = re.search(r"step2:([^|]+)", str(query_cells.iloc[0]))
    if step2_link is None:
        raise ValueError(f"no step2 sweep url for {query_dataset} / {method}")
    sweep_url = step2_link.group(1)

    _, _, sweep_id = spilt_web(sweep_url)
    sweep = wandb.Api().sweep(f"{entity}/{project}/{sweep_id}")
    # Materialize once so the sweep is not re-iterated for every atlas dataset.
    runs = list(sweep.runs)

    for atlas_dataset in atlas_datasets:
        yaml_cells = conf_data.loc[conf_data["Dataset_id"] == atlas_dataset, f"{method}_method"]
        if yaml_cells.empty or pd.isna(yaml_cells.iloc[0]):
            logger.warning(f"no best {method} pipeline recorded for {atlas_dataset}")
            continue
        best_yaml = yaml_cells.iloc[0]

        match_run = None
        for run in runs:
            if is_matching_dict(best_yaml, run.config):
                if match_run is not None:
                    # More than one run shares this config; the later one wins.
                    logger.warning(f"multiple runs match the best pipeline of {atlas_dataset}")
                match_run = run

        if match_run is None:
            logger.warning(f"no run in sweep {sweep_id} matches the best pipeline of {atlas_dataset}")
            continue

        # TODO: write the matched run's accuracy into
        # ans.loc[method, f"{atlas_dataset}_from_cache"]; the metric key is
        # not visible in this part of the file.

    return ans
144
+
145
+
100
146
ans_all = {}
101
147
parser = argparse .ArgumentParser (formatter_class = argparse .ArgumentDefaultsHelpFormatter )
102
- parser .add_argument ("--methods" , default = ["cta_actinn" , "cta_scdeepsort" ], nargs = "+" )
148
+ parser .add_argument ("--methods" , default = ["cta_actinn" , "cta_scdeepsort" , "cta_singlecellnet" , "cta_celltypist" ],
149
+ nargs = "+" )
150
+ parser .add_argument ("--tissue" , type = str )
103
151
args = parser .parse_args ()
104
152
methods = args .methods
153
+ tissue = args .tissue
154
+ conf_data = pd .read_excel ("Cell Type Annotation Atlas.xlsx" , sheet_name = tissue )
105
155
if __name__ == "__main__" :
106
156
for query_dataset in query_datasets :
107
157
ans = []
0 commit comments