Skip to content

Commit ffc84b2

Browse files
committed
Merge remote-tracking branch 'origin/celltype_annotation_automl' into celltype_annotation_automl
2 parents c611304 + 0ee8d40 commit ffc84b2

File tree

3 files changed

+33
-30
lines changed

3 files changed

+33
-30
lines changed

examples/atlas/get_result_web.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
2-
from functools import partial
32
import json
43
import os
4+
from functools import partial
55
from pathlib import Path
66

77
import numpy as np
@@ -120,18 +120,19 @@ def spilt_web(url: str):
120120
print("No match found")
121121

122122

123-
def get_metric(run,metric_col):
123+
def get_metric(run, metric_col):
124124
"""Extract metric value from wandb run.
125-
125+
126126
Parameters
127127
----------
128128
run : wandb.Run
129129
Weights & Biases run object
130-
130+
131131
Returns
132132
-------
133133
float
134134
Metric value or negative infinity if metric not found
135+
135136
"""
136137
if metric_col not in run.summary:
137138
return float('-inf') # Return -inf for missing metrics to handle in comparisons
@@ -140,7 +141,7 @@ def get_metric(run,metric_col):
140141

141142
def get_best_method(urls, metric_col="test_acc"):
142143
"""Find the best performing method across multiple wandb sweeps.
143-
144+
144145
Parameters
145146
----------
146147
urls : list
@@ -163,11 +164,11 @@ def get_best_method(urls, metric_col="test_acc"):
163164

164165
# Track run statistics
165166
run_states = {"all_total_runs": 0, "all_finished_runs": 0}
166-
167+
167168
for step_name, url in zip(step_names, urls):
168169
_, _, sweep_id = spilt_web(url)
169170
sweep = wandb.Api(timeout=1000).sweep(f"{entity}/{project}/{sweep_id}")
170-
171+
171172
# Update run statistics
172173
finished_runs = [run for run in sweep.runs if run.state == "finished"]
173174
run_states.update({
@@ -182,10 +183,10 @@ def get_best_method(urls, metric_col="test_acc"):
182183
best_run = max(sweep.runs, key=partial(get_metric, metric_col=metric_col)) if goal == "maximize" else \
183184
min(sweep.runs, key=partial(get_metric, metric_col=metric_col)) if goal == "minimize" else \
184185
None
185-
186+
186187
if best_run is None:
187188
raise RuntimeError("Optimization goal must be either 'minimize' or 'maximize'")
188-
189+
189190
if metric_col not in best_run.summary:
190191
continue
191192
if all_best_run is None:
@@ -323,10 +324,10 @@ def get_new_ans(tissue):
323324

324325
def write_ans(tissue, new_df, output_file=None):
325326
"""Process and write results for a specific tissue type to CSV.
326-
327+
327328
Handles merging of new results with existing data, including conflict detection
328329
for metric values.
329-
330+
330331
Parameters
331332
----------
332333
tissue : str
@@ -335,6 +336,7 @@ def write_ans(tissue, new_df, output_file=None):
335336
New results to be written
336337
output_file : str, optional
337338
Output file path. Defaults to 'sweep_results/{tissue}_ans.csv'
339+
338340
"""
339341
if output_file is None:
340342
output_file = f"sweep_results/{tissue}_ans.csv"
@@ -345,7 +347,7 @@ def write_ans(tissue, new_df, output_file=None):
345347

346348
# Reset index to ensure Dataset_id is a regular column
347349
new_df = new_df.reset_index(drop=True)
348-
350+
349351
# Process new data by merging rows with same Dataset_id
350352
new_df_processed = pd.DataFrame()
351353
for dataset_id in new_df['Dataset_id'].unique():

examples/atlas/sc_similarity_examples/sim_query_atlas.py

+18-17
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,19 @@ def is_match(config_str):
7676

7777
def is_matching_dict(yaml_str, target_dict):
7878
"""Compare YAML configuration with target dictionary.
79-
79+
8080
Parameters
8181
----------
8282
yaml_str : str
8383
YAML configuration string to parse
8484
target_dict : dict
8585
Target dictionary to compare against
86-
86+
8787
Returns
8888
-------
8989
bool
9090
True if dictionaries match, False otherwise
91+
9192
"""
9293
# Parse YAML string
9394
yaml_config = yaml.safe_load(yaml_str)
@@ -107,18 +108,19 @@ def is_matching_dict(yaml_str, target_dict):
107108

108109
def get_ans(query_dataset, method):
109110
"""Get test accuracy results for a given dataset and method.
110-
111+
111112
Parameters
112113
----------
113114
query_dataset : str
114115
Dataset identifier
115116
method : str
116117
Method name to analyze
117-
118+
118119
Returns
119120
-------
120121
pandas.DataFrame or None
121122
DataFrame containing test accuracy results, None if results don't exist
123+
122124
"""
123125
result_path = f"{file_root}/tuning/{method}/{query_dataset}/results/atlas/best_test_acc.csv"
124126
if not os.path.exists(result_path):
@@ -137,33 +139,33 @@ def get_ans(query_dataset, method):
137139

138140
def get_ans_from_cache(query_dataset, method):
139141
"""Get cached test accuracy results for atlas datasets.
140-
142+
141143
Parameters
142144
----------
143145
query_dataset : str
144146
Query dataset identifier
145147
method : str
146148
Method name to analyze
147-
149+
148150
Returns
149151
-------
150152
pandas.DataFrame
151153
DataFrame containing test accuracy results from cache
154+
152155
"""
153156
# Get best method from step2 of atlas datasets
154157
# Search accuracy according to best method (all values should exist)
155-
ans = pd.DataFrame(index=[method],
156-
columns=[f"{atlas_dataset}_from_cache" for atlas_dataset in atlas_datasets])
157-
158+
ans = pd.DataFrame(index=[method], columns=[f"{atlas_dataset}_from_cache" for atlas_dataset in atlas_datasets])
159+
158160
sweep_url = re.search(r"step2:([^|]+)",
159-
conf_data[conf_data["dataset_id"] == query_dataset][method].iloc[0]).group(1)
161+
conf_data[conf_data["dataset_id"] == query_dataset][method].iloc[0]).group(1)
160162
_, _, sweep_id = spilt_web(sweep_url)
161163
sweep = wandb.Api().sweep(f"{entity}/{project}/{sweep_id}")
162-
164+
163165
for atlas_dataset in atlas_datasets:
164166
best_yaml = conf_data[conf_data["dataset_id"] == atlas_dataset][f"{method}_best_yaml"].iloc[0]
165167
match_run = None
166-
168+
167169
# Find matching run configuration
168170
for run in sweep.runs:
169171
if isinstance(best_yaml, float) and np.isnan(best_yaml):
@@ -172,14 +174,13 @@ def get_ans_from_cache(query_dataset, method):
172174
if match_run is not None:
173175
raise ValueError("Multiple matching runs found when only one expected")
174176
match_run = run
175-
177+
176178
if match_run is None:
177179
logger.warning(f"No matching configuration found for {atlas_dataset} with method {method}")
178180
else:
179-
ans.loc[method, f"{atlas_dataset}_from_cache"] = (
180-
match_run.summary["test_acc"] if "test_acc" in match_run.summary else np.nan
181-
)
182-
181+
ans.loc[method, f"{atlas_dataset}_from_cache"] = (match_run.summary["test_acc"]
182+
if "test_acc" in match_run.summary else np.nan)
183+
183184
return ans
184185

185186

examples/atlas/test_get_result_web.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_write_ans(tmp_path):
108108
# 测试冲突情况的处理
109109
write_ans(tissue, conflict_df, output_file=output_file)
110110
final_df = pd.read_csv(output_file)
111-
111+
112112
# 验证新值被更新
113113
assert final_df[final_df['Dataset_id'] == 'dataset1']['method1_best_res'].iloc[0] == 0.7
114114

0 commit comments

Comments
 (0)