
Commit c611304

Merge remote-tracking branch 'origin/celltype_annotation_automl' into celltype_annotation_automl

2 parents 3e24fcd + af633a2

File tree

4 files changed: +150 -68 lines

examples/atlas/get_result_web.py (+74 -40)
@@ -1,4 +1,5 @@
 import argparse
+from functools import partial
 import json
 import os
 from pathlib import Path
@@ -119,9 +120,27 @@ def spilt_web(url: str):
         print("No match found")


+def get_metric(run, metric_col):
+    """Extract metric value from wandb run.
+
+    Parameters
+    ----------
+    run : wandb.Run
+        Weights & Biases run object
+    metric_col : str
+        Name of the summary metric to read
+
+    Returns
+    -------
+    float
+        Metric value or negative infinity if metric not found
+    """
+    if metric_col not in run.summary:
+        return float('-inf')  # Return -inf for missing metrics to handle in comparisons
+    return run.summary[metric_col]
+
+
 def get_best_method(urls, metric_col="test_acc"):
     """Find the best performing method across multiple wandb sweeps.
-
+
     Parameters
     ----------
     urls : list
@@ -142,29 +161,31 @@ def get_best_method(urls, metric_col="test_acc"):
     all_best_step_name = None
     step_names = ["step2", "step3_0", "step3_1", "step3_2"]

-    def get_metric(run):
-        if metric_col not in run.summary:
-            return float('-inf')  # TODO: decide based on the metric_col name
-        else:
-            return run.summary[metric_col]
-
+    # Track run statistics
     run_states = {"all_total_runs": 0, "all_finished_runs": 0}
+
     for step_name, url in zip(step_names, urls):
         _, _, sweep_id = spilt_web(url)
         sweep = wandb.Api(timeout=1000).sweep(f"{entity}/{project}/{sweep_id}")
+
+        # Update run statistics
+        finished_runs = [run for run in sweep.runs if run.state == "finished"]
         run_states.update({
             f"{step_name}_total_runs": len(sweep.runs),
-            f"{step_name}_finished_runs": len([run for run in sweep.runs if run.state == "finished"])
+            f"{step_name}_finished_runs": len(finished_runs)
         })
         run_states["all_total_runs"] += run_states[f"{step_name}_total_runs"]
         run_states["all_finished_runs"] += run_states[f"{step_name}_finished_runs"]
+
+        # Find best run based on optimization goal
         goal = sweep.config["metric"]["goal"]
-        if goal == "maximize":
-            best_run = max(sweep.runs, key=get_metric)
-        elif goal == "minimize":
-            best_run = min(sweep.runs, key=get_metric)
-        else:
-            raise RuntimeError("choose goal in ['minimize','maximize']")
+        best_run = max(sweep.runs, key=partial(get_metric, metric_col=metric_col)) if goal == "maximize" else \
+            min(sweep.runs, key=partial(get_metric, metric_col=metric_col)) if goal == "minimize" else \
+            None
+
+        if best_run is None:
+            raise RuntimeError("Optimization goal must be either 'minimize' or 'maximize'")
+
         if metric_col not in best_run.summary:
             continue
         if all_best_run is None:
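Note: the refactor above replaces a nested closure with a module-level helper bound via functools.partial. A minimal sketch of that sort-key pattern, using FakeRun as an invented stand-in for a wandb run object (only the summary attribute matters):

from functools import partial

class FakeRun:
    """Hypothetical stand-in for a wandb run."""

    def __init__(self, summary):
        self.summary = summary

def get_metric(run, metric_col):
    # Missing metrics sort below every real value.
    return run.summary.get(metric_col, float('-inf'))

runs = [FakeRun({"test_acc": 0.91}), FakeRun({}), FakeRun({"test_acc": 0.87})]

# partial binds metric_col, yielding the one-argument key that max()/min() expect.
best = max(runs, key=partial(get_metric, metric_col="test_acc"))
print(best.summary)  # {'test_acc': 0.91}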
@@ -301,78 +322,91 @@ def get_new_ans(tissue):


 def write_ans(tissue, new_df, output_file=None):
-    """Process and write results for a specific tissue type to CSV."""
+    """Process and write results for a specific tissue type to CSV.
+
+    Handles merging of new results with existing data, including conflict detection
+    for metric values.
+
+    Parameters
+    ----------
+    tissue : str
+        Tissue type being processed
+    new_df : pd.DataFrame
+        New results to be written
+    output_file : str, optional
+        Output file path. Defaults to 'sweep_results/{tissue}_ans.csv'
+    """
     if output_file is None:
         output_file = f"sweep_results/{tissue}_ans.csv"

-    # 重置索引,确保Dataset_id成为一个普通列
-    if 'Dataset_id' in new_df.columns:
-        new_df = new_df.reset_index(drop=True)
-    else:
-        logger.warning("Dataset_id not in new_df.columns")
+    if 'Dataset_id' not in new_df.columns:
+        logger.warning("Dataset_id column missing in input DataFrame")
         return
-    # 处理新数据,合并相同Dataset_id的非NA值
-    new_df_processed = pd.DataFrame()

+    # Reset index to ensure Dataset_id is a regular column
+    new_df = new_df.reset_index(drop=True)
+
+    # Process new data by merging rows with same Dataset_id
+    new_df_processed = pd.DataFrame()
     for dataset_id in new_df['Dataset_id'].unique():
         row_data = {'Dataset_id': dataset_id}
         subset = new_df[new_df['Dataset_id'] == dataset_id]
         for col in new_df.columns:
             if col != 'Dataset_id':
-                values = subset[col].dropna().unique()
-                if len(values) > 0:
-                    row_data[col] = values[0]
+                non_null_values = subset[col].dropna().unique()
+                if len(non_null_values) > 0:
+                    row_data[col] = non_null_values[0]
         new_df_processed = pd.concat([new_df_processed, pd.DataFrame([row_data])])

     if os.path.exists(output_file):
-        # 读取现有数据,不将任何列设置为索引
+        # Read existing data without setting any column as index
         existing_df = pd.read_csv(output_file)

-        # 清理可能存在的Unnamed列
+        # Clean up any potential 'Unnamed' columns
         existing_df = existing_df.loc[:, ~existing_df.columns.str.contains('^Unnamed')]

-        # 创建合并后的DataFrame
+        # Create merged DataFrame
         merged_df = existing_df.copy()

-        # 添加新数据中的列(如果不存在)
+        # Add columns from new data if they don't exist
         for col in new_df_processed.columns:
             if col not in merged_df.columns:
                 merged_df[col] = pd.NA

-        # 对每个Dataset_id进行合并和冲突检查
+        # Merge and check conflicts for each Dataset_id
        for _, new_row in new_df_processed.iterrows():
            dataset_id = new_row['Dataset_id']
            existing_row = merged_df[merged_df['Dataset_id'] == dataset_id]

            if len(existing_row) > 0:
-                # 检查每一列的值
+                # Check values for each column
                for col in new_df_processed.columns:
                    if col == 'Dataset_id':
                        continue
                    new_value = new_row[col]
                    existing_value = existing_row[col].iloc[0] if len(existing_row) > 0 else pd.NA

-                    # 只对_best_res结尾的列进行冲突检查
+                    # Only check for conflicts in columns ending with _best_res
                    if str(col).endswith("_best_res"):
                        if pd.notna(new_value) and pd.notna(existing_value):
                            if abs(float(new_value) - float(existing_value)) > 1e-10:
-                                print(f"结果冲突: Dataset {dataset_id}, Column {col}\n"
-                                      f"现有值: {existing_value}\n新值: {new_value}")
+                                print(f"Result conflict: Dataset {dataset_id}, Column {col}\n"
+                                      f"Existing value: {existing_value}\nNew value: {new_value}")
                            else:
-                                print(f"提示: 发现重复值 Dataset {dataset_id}, Column {col}\n"
-                                      f"现有值和新值都是: {new_value}")
+                                print(f"Note: Duplicate value found for Dataset {dataset_id}, Column {col}\n"
+                                      f"Both existing and new values are: {new_value}")

-                    # 如果新值不是NaN,更新该值
+                    # Update value if new value is not NaN
                    if pd.notna(new_value):
                        merged_df.loc[merged_df['Dataset_id'] == dataset_id, col] = new_value
            else:
-                # 如果是新的Dataset_id,直接添加整行
+                # If it's a new Dataset_id, add the entire row
                merged_df = pd.concat([merged_df, pd.DataFrame([new_row])], ignore_index=True)

-        # 保存合并后的数据,不包含索引
+        # Save merged data without index
        merged_df.to_csv(output_file, index=False)
    else:
-        # 如果文件不存在,直接保存处理后的新数据,不包含索引
+        # If file doesn't exist, save processed new data directly without index
        new_df_processed.to_csv(output_file, index=False)
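Note: the row-merging step above collapses multiple partial rows per Dataset_id by taking the first non-null value in each column. A minimal sketch of that logic, with invented column names and values:

import pandas as pd

new_df = pd.DataFrame([
    {"Dataset_id": "d1", "method1_best_res": 0.8, "method2_best_res": None},
    {"Dataset_id": "d1", "method1_best_res": None, "method2_best_res": 0.9},
])

rows = []
for dataset_id in new_df["Dataset_id"].unique():
    subset = new_df[new_df["Dataset_id"] == dataset_id]
    row = {"Dataset_id": dataset_id}
    for col in new_df.columns.drop("Dataset_id"):
        non_null = subset[col].dropna().unique()
        if len(non_null) > 0:
            row[col] = non_null[0]  # First non-null value wins.
    rows.append(row)

print(pd.DataFrame(rows))  # Single merged row: d1, 0.8, 0.9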
examples/atlas/sc_similarity_examples/example_usage_anndata.py (-1)

@@ -12,7 +12,6 @@
 from torch.utils.data import TensorDataset

 from dance.atlas.sc_similarity.anndata_similarity import AnnDataSimilarity, get_anndata
-from dance.otdd.pytorch.distance import DatasetDistance
 from dance.utils import set_seed

 # target_files = [
examples/atlas/sc_similarity_examples/sim_query_atlas.py (+64 -17)
@@ -75,24 +75,51 @@ def is_match(config_str):


 def is_matching_dict(yaml_str, target_dict):
-
-    # 解析YAML字符串
+    """Compare YAML configuration with target dictionary.
+
+    Parameters
+    ----------
+    yaml_str : str
+        YAML configuration string to parse
+    target_dict : dict
+        Target dictionary to compare against
+
+    Returns
+    -------
+    bool
+        True if dictionaries match, False otherwise
+    """
+    # Parse YAML string
     yaml_config = yaml.safe_load(yaml_str)

-    # 构建期望的字典格式
+    # Build expected dictionary format
     expected_dict = {}
     for i, item in enumerate(yaml_config):
-        if item['type'] in ['misc', 'graph.cell'] or item['target'] == 'SCNFeature':  # 跳过misc类型
+        # Skip misc and graph.cell types, or SCNFeature targets
+        if item['type'] in ['misc', 'graph.cell'] or item['target'] == 'SCNFeature':
             continue
         key = f"pipeline.{i}.{item['type']}"
         value = item['target']
         expected_dict[key] = value

-    # 直接比较两个字典是否相等
     return expected_dict == target_dict


 def get_ans(query_dataset, method):
+    """Get test accuracy results for a given dataset and method.
+
+    Parameters
+    ----------
+    query_dataset : str
+        Dataset identifier
+    method : str
+        Method name to analyze
+
+    Returns
+    -------
+    pandas.DataFrame or None
+        DataFrame containing test accuracy results, None if results don't exist
+    """
     result_path = f"{file_root}/tuning/{method}/{query_dataset}/results/atlas/best_test_acc.csv"
     if not os.path.exists(result_path):
         logger.warning(f"{result_path} not exists")
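Note: in is_matching_dict, skipped entries keep their position in enumerate, so the generated keys can have index gaps. A small sketch with invented pipeline entries:

import yaml

yaml_str = """
- type: filter.gene
  target: FilterGenesPercentile
- type: misc
  target: SetConfig
- type: feature.cell
  target: WeightedFeaturePCA
"""

config = yaml.safe_load(yaml_str)
expected = {}
for i, item in enumerate(config):
    if item["type"] in ["misc", "graph.cell"] or item["target"] == "SCNFeature":
        continue  # Index 1 is skipped here, but indices 0 and 2 are preserved.
    expected[f"pipeline.{i}.{item['type']}"] = item["target"]

print(expected)
# {'pipeline.0.filter.gene': 'FilterGenesPercentile', 'pipeline.2.feature.cell': 'WeightedFeaturePCA'}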
@@ -109,30 +136,50 @@ def get_ans(query_dataset, method):


 def get_ans_from_cache(query_dataset, method):
-    #1:get best method from step2 of atlas datasets
-    #2:search acc according to best method(需要注意的是,应该都是有值的,没有值的需要检查一下)
-    ans = pd.DataFrame(index=[method], columns=[f"{atlas_dataset}_from_cache" for atlas_dataset in atlas_datasets])
-    print(conf_data[conf_data["dataset_id"] == query_dataset][method])
+    """Get cached test accuracy results for atlas datasets.
+
+    Parameters
+    ----------
+    query_dataset : str
+        Query dataset identifier
+    method : str
+        Method name to analyze
+
+    Returns
+    -------
+    pandas.DataFrame
+        DataFrame containing test accuracy results from cache
+    """
+    # Get best method from step2 of atlas datasets
+    # Search accuracy according to best method (all values should exist)
+    ans = pd.DataFrame(index=[method],
+                       columns=[f"{atlas_dataset}_from_cache" for atlas_dataset in atlas_datasets])
+
     sweep_url = re.search(r"step2:([^|]+)",
-                          conf_data[conf_data["dataset_id"] == query_dataset][method].iloc[0]).group(1)
+                          conf_data[conf_data["dataset_id"] == query_dataset][method].iloc[0]).group(1)
     _, _, sweep_id = spilt_web(sweep_url)
     sweep = wandb.Api().sweep(f"{entity}/{project}/{sweep_id}")
+
     for atlas_dataset in atlas_datasets:
         best_yaml = conf_data[conf_data["dataset_id"] == atlas_dataset][f"{method}_best_yaml"].iloc[0]
         match_run = None
+
+        # Find matching run configuration
         for run in sweep.runs:
-            if type(best_yaml) == float and np.isnan(best_yaml):
+            if isinstance(best_yaml, float) and np.isnan(best_yaml):
                 continue
             if is_matching_dict(best_yaml, run.config):
                 if match_run is not None:
-                    raise ValueError("match_run只能被赋值一次")
-                else:
-                    match_run = run
+                    raise ValueError("Multiple matching runs found when only one expected")
+                match_run = run
+
         if match_run is None:
-            logger.warning(f"{atlas_dataset}{method}下未找到匹配")
+            logger.warning(f"No matching configuration found for {atlas_dataset} with method {method}")
         else:
-            ans.loc[method, f"{atlas_dataset}_from_cache"] = match_run.summary[
-                "test_acc"] if "test_acc" in match_run.summary else np.nan
+            ans.loc[method, f"{atlas_dataset}_from_cache"] = (
+                match_run.summary["test_acc"] if "test_acc" in match_run.summary else np.nan
+            )
+
     return ans
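Note: the step2 regex assumes the per-method cell in conf_data stores pipe-separated step URLs. A sketch with an invented cell value (the real format comes from conf_data):

import re

# Hypothetical cell content for illustration only.
cell = "step2:https://wandb.ai/ent/proj/sweeps/abc123|step3:https://wandb.ai/ent/proj/sweeps/def456"

sweep_url = re.search(r"step2:([^|]+)", cell).group(1)
print(sweep_url)  # https://wandb.ai/ent/proj/sweeps/abc123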
examples/atlas/test_get_result_web.py (+12 -10)
@@ -88,27 +88,29 @@ def test_write_ans(tmp_path):
     # Test a normal update
     write_ans(tissue, new_df, output_file=output_file)
     # Read the updated file
-    updated_df = pd.read_csv(output_file, index_col=0)
-    print(updated_df)
+    updated_df = pd.read_csv(output_file)

     # Verify the results
     assert len(updated_df) == 3  # There should be 3 distinct datasets
-    assert 'dataset3' in updated_df.index.values
-    assert updated_df.loc['dataset2', 'method1_best_res'] == 0.9
-    assert updated_df.loc['dataset2', 'method2_best_res'] == 0.95
+    assert 'dataset3' in updated_df['Dataset_id'].values
+    assert updated_df[updated_df['Dataset_id'] == 'dataset2']['method1_best_res'].iloc[0] == 0.9
+    assert updated_df[updated_df['Dataset_id'] == 'dataset2']['method2_best_res'].iloc[0] == 0.95

-    # Test a result conflict
+    # Test the conflict case (the value should now be updated instead of raising an error)
     conflict_data = [{
         'Dataset_id': 'dataset1',
         'method1': 'url1_new',
         'method1_best_yaml': 'yaml1_new',
-        'method1_best_res': 0.7  # Differs from the existing 0.8, should trigger a conflict
+        'method1_best_res': 0.7  # Differs from the existing 0.8; should print a warning instead of raising
     }]
     conflict_df = pd.DataFrame(conflict_data)

-    # Verify conflict detection
-    with pytest.raises(ValueError, match="结果冲突"):
-        write_ans(tissue, conflict_df, output_file=output_file)
+    # Test handling of the conflict case
+    write_ans(tissue, conflict_df, output_file=output_file)
+    final_df = pd.read_csv(output_file)
+
+    # Verify the new value was updated
+    assert final_df[final_df['Dataset_id'] == 'dataset1']['method1_best_res'].iloc[0] == 0.7


 def test_write_ans_new_file(tmp_path):