1
1
import argparse
2
+ from functools import partial
2
3
import json
3
4
import os
4
5
from pathlib import Path
@@ -119,9 +120,27 @@ def spilt_web(url: str):
119
120
print ("No match found" )
120
121
121
122
123
def get_metric(run, metric_col, missing=float("-inf")):
    """Extract a scalar metric value from a wandb run summary.

    Parameters
    ----------
    run : wandb.Run
        Weights & Biases run object; only its ``summary`` mapping is read.
    metric_col : str
        Name of the metric to look up in ``run.summary``.
    missing : float, optional
        Value returned when the metric is absent. Defaults to ``-inf`` so
        metric-less runs lose under a ``max()`` comparison.
        NOTE(review): under a "minimize" goal, ``-inf`` makes a metric-less
        run look *best* to ``min()`` — minimizing callers should pass
        ``missing=float("inf")``.

    Returns
    -------
    float
        The recorded metric value, or *missing* if the metric is absent.
    """
    if metric_col not in run.summary:
        return missing
    return run.summary[metric_col]
122
141
def get_best_method (urls , metric_col = "test_acc" ):
123
142
"""Find the best performing method across multiple wandb sweeps.
124
-
143
+
125
144
Parameters
126
145
----------
127
146
urls : list
@@ -142,29 +161,31 @@ def get_best_method(urls, metric_col="test_acc"):
142
161
all_best_step_name = None
143
162
step_names = ["step2" , "step3_0" , "step3_1" , "step3_2" ]
144
163
145
- def get_metric (run ):
146
- if metric_col not in run .summary :
147
- return float ('-inf' ) #TODO 根据metric_col的名称来判断
148
- else :
149
- return run .summary [metric_col ]
150
-
164
+ # Track run statistics
151
165
run_states = {"all_total_runs" : 0 , "all_finished_runs" : 0 }
166
+
152
167
for step_name , url in zip (step_names , urls ):
153
168
_ , _ , sweep_id = spilt_web (url )
154
169
sweep = wandb .Api (timeout = 1000 ).sweep (f"{ entity } /{ project } /{ sweep_id } " )
170
+
171
+ # Update run statistics
172
+ finished_runs = [run for run in sweep .runs if run .state == "finished" ]
155
173
run_states .update ({
156
174
f"{ step_name } _total_runs" : len (sweep .runs ),
157
- f"{ step_name } _finished_runs" : len ([ run for run in sweep . runs if run . state == "finished" ] )
175
+ f"{ step_name } _finished_runs" : len (finished_runs )
158
176
})
159
177
run_states ["all_total_runs" ] += run_states [f"{ step_name } _total_runs" ]
160
178
run_states ["all_finished_runs" ] += run_states [f"{ step_name } _finished_runs" ]
179
+
180
+ # Find best run based on optimization goal
161
181
goal = sweep .config ["metric" ]["goal" ]
162
- if goal == "maximize" :
163
- best_run = max (sweep .runs , key = get_metric )
164
- elif goal == "minimize" :
165
- best_run = min (sweep .runs , key = get_metric )
166
- else :
167
- raise RuntimeError ("choose goal in ['minimize','maximize']" )
182
+ best_run = max (sweep .runs , key = partial (get_metric , metric_col = metric_col )) if goal == "maximize" else \
183
+ min (sweep .runs , key = partial (get_metric , metric_col = metric_col )) if goal == "minimize" else \
184
+ None
185
+
186
+ if best_run is None :
187
+ raise RuntimeError ("Optimization goal must be either 'minimize' or 'maximize'" )
188
+
168
189
if metric_col not in best_run .summary :
169
190
continue
170
191
if all_best_run is None :
@@ -301,78 +322,91 @@ def get_new_ans(tissue):
301
322
302
323
303
324
def write_ans(tissue, new_df, output_file=None):
    """Merge new sweep results for a tissue type into a CSV file on disk.

    Rows of *new_df* sharing a ``Dataset_id`` are collapsed (first non-null
    value per column wins). If *output_file* already exists, the collapsed
    rows are merged into it: columns ending with ``_best_res`` are compared
    and conflicts reported, non-null new values overwrite existing ones,
    and unseen ``Dataset_id`` rows are appended.

    Parameters
    ----------
    tissue : str
        Tissue type being processed; used to build the default output path.
    new_df : pd.DataFrame
        New results to be written. Must contain a ``Dataset_id`` column.
    output_file : str, optional
        Output file path. Defaults to ``sweep_results/{tissue}_ans.csv``.
    """
    if output_file is None:
        output_file = f"sweep_results/{tissue}_ans.csv"

    if 'Dataset_id' not in new_df.columns:
        logger.warning("Dataset_id column missing in input DataFrame")
        return

    # Reset index to ensure Dataset_id is a regular column
    new_df = new_df.reset_index(drop=True)

    # Collapse rows sharing a Dataset_id: keep the first non-null value per
    # column. Collect dicts and build the frame once — concatenating inside
    # the loop is quadratic.
    collapsed_rows = []
    for dataset_id in new_df['Dataset_id'].unique():
        row_data = {'Dataset_id': dataset_id}
        subset = new_df[new_df['Dataset_id'] == dataset_id]
        for col in new_df.columns:
            if col == 'Dataset_id':
                continue
            non_null_values = subset[col].dropna().unique()
            if len(non_null_values) > 0:
                row_data[col] = non_null_values[0]
        collapsed_rows.append(row_data)
    new_df_processed = pd.DataFrame(collapsed_rows)

    if not os.path.exists(output_file):
        # No existing file: save processed new data directly, without index
        new_df_processed.to_csv(output_file, index=False)
        return

    # Read existing data without setting any column as index
    existing_df = pd.read_csv(output_file)

    # Clean up any potential 'Unnamed' columns left by earlier index writes
    existing_df = existing_df.loc[:, ~existing_df.columns.str.contains('^Unnamed')]

    merged_df = existing_df.copy()

    # Add columns from new data if they don't exist yet
    for col in new_df_processed.columns:
        if col not in merged_df.columns:
            merged_df[col] = pd.NA

    # Merge and check conflicts for each Dataset_id
    for _, new_row in new_df_processed.iterrows():
        dataset_id = new_row['Dataset_id']
        existing_row = merged_df[merged_df['Dataset_id'] == dataset_id]

        if len(existing_row) > 0:
            # Check values for each column of the incoming row
            for col in new_df_processed.columns:
                if col == 'Dataset_id':
                    continue
                new_value = new_row[col]
                # existing_row is non-empty here, so .iloc[0] is safe
                existing_value = existing_row[col].iloc[0]

                # Only check for conflicts in columns ending with _best_res
                if str(col).endswith("_best_res"):
                    if pd.notna(new_value) and pd.notna(existing_value):
                        if abs(float(new_value) - float(existing_value)) > 1e-10:
                            print(f"Result conflict : Dataset {dataset_id}, Column {col}\n"
                                  f"Existing value : {existing_value}\nNew value : {new_value}")
                        else:
                            print(f"Note: Duplicate value found for Dataset {dataset_id}, Column {col}\n"
                                  f"Both existing and new values are : {new_value}")

                # Update value if new value is not NaN
                if pd.notna(new_value):
                    merged_df.loc[merged_df['Dataset_id'] == dataset_id, col] = new_value
        else:
            # If it's a new Dataset_id, add the entire row
            merged_df = pd.concat([merged_df, pd.DataFrame([new_row])], ignore_index=True)

    # Save merged data without index
    merged_df.to_csv(output_file, index=False)
377
411
378
412
0 commit comments