|
8 | 8 | from dance.settings import DANCEDIR
|
9 | 9 |
|
10 | 10 | sys.path.append(str(DANCEDIR))
|
11 |
| -from examples.atlas.get_result_web import check_exist, check_identical_strings, spilt_web |
| 11 | +from examples.atlas.get_result_web import check_exist, check_identical_strings, spilt_web, write_ans |
12 | 12 |
|
13 | 13 |
|
14 | 14 | # 测试 check_identical_strings 函数
|
@@ -90,72 +90,62 @@ def mock_settings(tmp_path, monkeypatch):
|
90 | 90 |
|
91 | 91 |
|
92 | 92 | def test_write_ans(mock_settings):
|
93 |
| - # 使用mock_settings而不是创建新的临时目录 |
94 | 93 | sweep_results_dir = mock_settings / "sweep_results"
|
95 | 94 | sweep_results_dir.mkdir(parents=True)
|
| 95 | + output_file = sweep_results_dir / "heart_ans.csv" |
96 | 96 |
|
97 |
| - # 创建测试数据 |
| 97 | + # 创建初始数据 |
98 | 98 | existing_data = pd.DataFrame({
|
99 |
| - 'Dataset_id': ['dataset1', 'dataset2', 'dataset3'], |
100 |
| - 'method1': ['url1', 'url2', 'url3'], |
101 |
| - 'method1_best_yaml': ['yaml1', 'yaml2', 'yaml3'], |
102 |
| - 'method1_best_res': [0.8, 0.9, 0.7] |
| 99 | + 'Dataset_id': ['dataset1', 'dataset2'], |
| 100 | + 'cta_actinn': ['url1', 'url2'], |
| 101 | + 'cta_actinn_best_yaml': ['yaml1', 'yaml2'], |
| 102 | + 'cta_actinn_best_res': [0.8, 0.7] |
103 | 103 | })
|
| 104 | + existing_data.to_csv(output_file) |
104 | 105 |
|
| 106 | + # 测试数据:包含较低分数和较高分数的情况 |
105 | 107 | new_data = pd.DataFrame({
|
106 |
| - 'Dataset_id': ['dataset2', 'dataset3', 'dataset4'], # 部分重叠的数据 |
107 |
| - 'method1': ['url2_new', 'url3_new', 'url4'], |
108 |
| - 'method1_best_yaml': ['yaml2_new', 'yaml3_new', 'yaml4'], |
109 |
| - 'method1_best_res': [0.9, 0.7, 0.85] # dataset2和dataset3的结果与现有数据相同 |
| 108 | + 'Dataset_id': ['dataset1', 'dataset2'], |
| 109 | + 'cta_actinn': ['url1_new', 'url2_new'], |
| 110 | + 'cta_actinn_best_yaml': ['yaml1_new', 'yaml2_new'], |
| 111 | + 'cta_actinn_best_res': [0.9, 0.6] # dataset1更高分数,dataset2更低分数 |
110 | 112 | })
|
111 | 113 |
|
112 |
| - # 写入现有数据 |
113 |
| - output_file = sweep_results_dir / "heart_ans.csv" |
114 |
| - existing_data.to_csv(output_file) |
115 |
| - |
116 |
| - # 测试写入新数据 |
117 |
| - from examples.atlas.get_result_web import write_ans |
118 | 114 | write_ans("heart", new_data, output_file)
|
119 | 115 |
|
120 |
| - # 读取合并后的结果 |
121 |
| - merged_df = pd.read_csv(output_file, index_col=0) |
122 |
| - |
123 | 116 | # 验证结果
|
124 |
| - assert len(merged_df) == 4 # 应该有4个唯一的Dataset_id |
125 |
| - assert 'dataset4' in merged_df.index # 新数据被添加 |
126 |
| - assert merged_df.loc['dataset2', 'method1'] == 'url2_new' # 更新了已存在的数据 |
127 |
| - |
128 |
| - # 测试结果冲突的情况 |
129 |
| - conflicting_data = pd.DataFrame({ |
130 |
| - 'Dataset_id': ['dataset1'], |
131 |
| - 'method1': ['url1_new'], |
132 |
| - 'method1_best_yaml': ['yaml1_new'], |
133 |
| - 'method1_best_res': [0.95] # 不同的结果值 |
134 |
| - }) |
| 117 | + result_df = pd.read_csv(output_file) |
| 118 | + |
| 119 | + # 验证高分数更新成功 |
| 120 | + dataset1_row = result_df[result_df['Dataset_id'] == 'dataset1'].iloc[0] |
| 121 | + assert dataset1_row['cta_actinn_best_res'] == 0.9 |
| 122 | + assert dataset1_row['cta_actinn'] == 'url1_new' |
| 123 | + assert dataset1_row['cta_actinn_best_yaml'] == 'yaml1_new' |
135 | 124 |
|
136 |
| - # 验证冲突数据会引发异常 |
137 |
| - with pytest.raises(ValueError, match="结果冲突"): |
138 |
| - write_ans("heart", conflicting_data) |
| 125 | + # 验证低分数保持不变 |
| 126 | + dataset2_row = result_df[result_df['Dataset_id'] == 'dataset2'].iloc[0] |
| 127 | + assert dataset2_row['cta_actinn_best_res'] == 0.7 |
| 128 | + assert dataset2_row['cta_actinn'] == 'url2' |
| 129 | + assert dataset2_row['cta_actinn_best_yaml'] == 'yaml2' |
139 | 130 |
|
140 | 131 |
|
141 | 132 | # 测试完全新的数据写入(文件不存在的情况)
|
142 | 133 | def test_write_ans_new_file(mock_settings):
|
143 | 134 | # 使用mock_settings而不是创建新的临时目录
|
144 | 135 | sweep_results_dir = mock_settings / "sweep_results"
|
145 | 136 | sweep_results_dir.mkdir(parents=True)
|
| 137 | + output_file = sweep_results_dir / "new_heart_ans.csv" |
146 | 138 |
|
147 | 139 | new_data = pd.DataFrame({
|
148 | 140 | 'Dataset_id': ['dataset1', 'dataset2'],
|
149 |
| - 'method1': ['url1', 'url2'], |
150 |
| - 'method1_best_yaml': ['yaml1', 'yaml2'], |
151 |
| - 'method1_best_res': [0.8, 0.9] |
| 141 | + 'cta_actinn': ['url1', 'url2'], |
| 142 | + 'cta_actinn_best_yaml': ['yaml1', 'yaml2'], |
| 143 | + 'cta_actinn_best_res': [0.8, 0.9] |
152 | 144 | })
|
153 | 145 |
|
154 | 146 | # 测试写入新文件
|
155 |
| - from examples.atlas.get_result_web import write_ans |
156 | 147 |
|
157 | 148 | # 验证文件被创建并包含正确的数据
|
158 |
| - output_file = sweep_results_dir / "heart_ans.csv" |
159 | 149 | write_ans("heart", new_data, output_file)
|
160 | 150 | assert output_file.exists()
|
161 | 151 |
|
|
0 commit comments