Skip to content

Commit 1f242f4

Browse files
committed
update test
1 parent 8037b1c commit 1f242f4

File tree

1 file changed

+29
-39
lines changed

1 file changed

+29
-39
lines changed

tests/test_get_result_web.py

+29-39
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from dance.settings import DANCEDIR
99

1010
sys.path.append(str(DANCEDIR))
11-
from examples.atlas.get_result_web import check_exist, check_identical_strings, spilt_web
11+
from examples.atlas.get_result_web import check_exist, check_identical_strings, spilt_web, write_ans
1212

1313

1414
# 测试 check_identical_strings 函数
@@ -90,72 +90,62 @@ def mock_settings(tmp_path, monkeypatch):
9090

9191

9292
def test_write_ans(mock_settings):
93-
# 使用mock_settings而不是创建新的临时目录
9493
sweep_results_dir = mock_settings / "sweep_results"
9594
sweep_results_dir.mkdir(parents=True)
95+
output_file = sweep_results_dir / "heart_ans.csv"
9696

97-
# 创建测试数据
97+
# 创建初始数据
9898
existing_data = pd.DataFrame({
99-
'Dataset_id': ['dataset1', 'dataset2', 'dataset3'],
100-
'method1': ['url1', 'url2', 'url3'],
101-
'method1_best_yaml': ['yaml1', 'yaml2', 'yaml3'],
102-
'method1_best_res': [0.8, 0.9, 0.7]
99+
'Dataset_id': ['dataset1', 'dataset2'],
100+
'cta_actinn': ['url1', 'url2'],
101+
'cta_actinn_best_yaml': ['yaml1', 'yaml2'],
102+
'cta_actinn_best_res': [0.8, 0.7]
103103
})
104+
existing_data.to_csv(output_file)
104105

106+
# 测试数据:包含较低分数和较高分数的情况
105107
new_data = pd.DataFrame({
106-
'Dataset_id': ['dataset2', 'dataset3', 'dataset4'], # 部分重叠的数据
107-
'method1': ['url2_new', 'url3_new', 'url4'],
108-
'method1_best_yaml': ['yaml2_new', 'yaml3_new', 'yaml4'],
109-
'method1_best_res': [0.9, 0.7, 0.85] # dataset2和dataset3的结果与现有数据相同
108+
'Dataset_id': ['dataset1', 'dataset2'],
109+
'cta_actinn': ['url1_new', 'url2_new'],
110+
'cta_actinn_best_yaml': ['yaml1_new', 'yaml2_new'],
111+
'cta_actinn_best_res': [0.9, 0.6] # dataset1更高分数,dataset2更低分数
110112
})
111113

112-
# 写入现有数据
113-
output_file = sweep_results_dir / "heart_ans.csv"
114-
existing_data.to_csv(output_file)
115-
116-
# 测试写入新数据
117-
from examples.atlas.get_result_web import write_ans
118114
write_ans("heart", new_data, output_file)
119115

120-
# 读取合并后的结果
121-
merged_df = pd.read_csv(output_file, index_col=0)
122-
123116
# 验证结果
124-
assert len(merged_df) == 4 # 应该有4个唯一的Dataset_id
125-
assert 'dataset4' in merged_df.index # 新数据被添加
126-
assert merged_df.loc['dataset2', 'method1'] == 'url2_new' # 更新了已存在的数据
127-
128-
# 测试结果冲突的情况
129-
conflicting_data = pd.DataFrame({
130-
'Dataset_id': ['dataset1'],
131-
'method1': ['url1_new'],
132-
'method1_best_yaml': ['yaml1_new'],
133-
'method1_best_res': [0.95] # 不同的结果值
134-
})
117+
result_df = pd.read_csv(output_file)
118+
119+
# 验证高分数更新成功
120+
dataset1_row = result_df[result_df['Dataset_id'] == 'dataset1'].iloc[0]
121+
assert dataset1_row['cta_actinn_best_res'] == 0.9
122+
assert dataset1_row['cta_actinn'] == 'url1_new'
123+
assert dataset1_row['cta_actinn_best_yaml'] == 'yaml1_new'
135124

136-
# 验证冲突数据会引发异常
137-
with pytest.raises(ValueError, match="结果冲突"):
138-
write_ans("heart", conflicting_data)
125+
# 验证低分数保持不变
126+
dataset2_row = result_df[result_df['Dataset_id'] == 'dataset2'].iloc[0]
127+
assert dataset2_row['cta_actinn_best_res'] == 0.7
128+
assert dataset2_row['cta_actinn'] == 'url2'
129+
assert dataset2_row['cta_actinn_best_yaml'] == 'yaml2'
139130

140131

141132
# 测试完全新的数据写入(文件不存在的情况)
142133
def test_write_ans_new_file(mock_settings):
143134
# 使用mock_settings而不是创建新的临时目录
144135
sweep_results_dir = mock_settings / "sweep_results"
145136
sweep_results_dir.mkdir(parents=True)
137+
output_file = sweep_results_dir / "new_heart_ans.csv"
146138

147139
new_data = pd.DataFrame({
148140
'Dataset_id': ['dataset1', 'dataset2'],
149-
'method1': ['url1', 'url2'],
150-
'method1_best_yaml': ['yaml1', 'yaml2'],
151-
'method1_best_res': [0.8, 0.9]
141+
'cta_actinn': ['url1', 'url2'],
142+
'cta_actinn_best_yaml': ['yaml1', 'yaml2'],
143+
'cta_actinn_best_res': [0.8, 0.9]
152144
})
153145

154146
# 测试写入新文件
155-
from examples.atlas.get_result_web import write_ans
156147

157148
# 验证文件被创建并包含正确的数据
158-
output_file = sweep_results_dir / "heart_ans.csv"
159149
write_ans("heart", new_data, output_file)
160150
assert output_file.exists()
161151

0 commit comments

Comments
 (0)