import argparse
+import itertools
import pickle
+import warnings

import numpy as np
import pandas as pd
import pytablewriter
import seaborn
from matplotlib import pyplot as plt

+try:
+    from rliable import library as rly  # pytype: disable=import-error
+    from rliable import metrics, plot_utils  # pytype: disable=import-error
+except ImportError:
+    rly = None
+
+from score_normalization import normalize_score
+

# From https://github.com/mwaskom/seaborn/blob/master/seaborn/categorical.py
def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5):
@@ -42,6 +52,10 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5
parser.add_argument("--fontsize", help="Font size", type=int, default=14)
parser.add_argument("-l", "--labels", help="Custom labels", type=str, nargs="+")
parser.add_argument("-b", "--boxplot", help="Enable boxplot", action="store_true", default=False)
+parser.add_argument("-r", "--rliable", help="Enable rliable plots", action="store_true", default=False)
+parser.add_argument("-vs", "--versus", help="Enable probability of improvement plot", action="store_true", default=False)
+parser.add_argument("-iqm", "--iqm", help="Enable IQM sample efficiency plot", action="store_true", default=False)
+parser.add_argument("-ci", "--ci-size", help="Confidence interval size (for rliable)", type=float, default=0.95)
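+# Example (hypothetical invocation; the input results-file arguments are omitted here, as
+# they depend on how this script is called): passing `--rliable --iqm --versus -ci 0.95`
+# enables the rliable aggregate plots, the IQM sample-efficiency curve and the
+# probability-of-improvement plot.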
parser.add_argument("-latex", "--latex", help="Enable latex support", action="store_true", default=False)
parser.add_argument("--merge", help="Merge with other results files", nargs="+", default=[], type=str)

@@ -132,21 +146,192 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5

# Convert to pandas dataframe, in order to use seaborn
labels_df, envs_df, scores = [], [], []
+# Post-process the results so they can be used with rliable
+# Maps algo label -> normalized scores of shape (n_runs, n_envs)
+normalized_score_dict = {}
+# Maps algo label -> normalized scores of shape (n_runs, n_envs, n_eval)
+all_eval_normalized_scores_dict = {}
+# Convert env key to env id for score normalization
+env_key_to_env_id = {
+    "Half": "HalfCheetahBulletEnv-v0",
+    "Ant": "AntBulletEnv-v0",
+    "Hopper": "HopperBulletEnv-v0",
+    "Walker": "Walker2DBulletEnv-v0",
+}
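+# NOTE: `normalize_score` (imported from `score_normalization`) is assumed here to min-max
+# normalize raw episodic returns with per-environment reference scores, roughly:
+#   normalized = (raw - ref_min[env_id]) / (ref_max[env_id] - ref_min[env_id])
+# so that scores are comparable across environments before aggregation.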
+# Backward compat
+skip_all_algos_dict = False
+
for key in keys:
+    algo_scores, all_algo_scores = [], []
    for env in envs:
        if isinstance(results[env][key]["last_evals"], (np.float32, np.float64)):
            # Not enough timesteps
            print(f"Skipping {env}-{key}")
            continue
+
        for score in results[env][key]["last_evals"]:
            labels_df.append(labels[key])
            # convert to int if needed
            # labels_df.append(int(labels[key]))
            envs_df.append(env)
            scores.append(score)

+        algo_scores.append(results[env][key]["last_evals"])
+
+        # Backward compat: mean_per_eval key may not be present
+        if "mean_per_eval" in results[env][key]:
+            all_algo_scores.append(results[env][key]["mean_per_eval"])
+        else:
+            skip_all_algos_dict = True
+
+        # Normalize the score; the env key must match an entry in `env_key_to_env_id`
+        if env in env_key_to_env_id:
+            algo_scores[-1] = normalize_score(algo_scores[-1], env_key_to_env_id[env])
+            if not skip_all_algos_dict:
+                all_algo_scores[-1] = normalize_score(all_algo_scores[-1], env_key_to_env_id[env])
+        elif env not in env_key_to_env_id and args.rliable:
+            warnings.warn(f"{env} not found for normalizing scores, you should update `env_key_to_env_id`")
+
+    # Truncate to the smallest number of runs so the scores can be stacked into a matrix
+    min_runs = min([len(algo_score) for algo_score in algo_scores])
+    if min_runs > 0:
+        algo_scores = [algo_score[:min_runs] for algo_score in algo_scores]
+        # shape: (n_envs, n_runs) -> (n_runs, n_envs)
+        normalized_score_dict[labels[key]] = np.array(algo_scores).T
+        if not skip_all_algos_dict:
+            all_algo_scores = [all_algo_score[:, :min_runs] for all_algo_score in all_algo_scores]
+            # shape: (n_envs, n_eval, n_runs) -> (n_runs, n_envs, n_eval)
+            all_eval_normalized_scores_dict[labels[key]] = np.array(all_algo_scores).transpose((2, 0, 1))
+
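+# After this loop, `normalized_score_dict` maps each algorithm label to an array of shape
+# (n_runs, n_envs) and `all_eval_normalized_scores_dict` (when `mean_per_eval` is available)
+# maps each label to an array of shape (n_runs, n_envs, n_eval); these are the inputs used
+# by rliable below.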
data_frame = pd.DataFrame(data=dict(Method=labels_df, Environment=envs_df, Score=scores))

+# Rliable plots, see https://github.com/google-research/rliable
+if args.rliable:
+
+    if rly is None:
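+        # rliable can typically be installed with `pip install rliable`
+        # (see https://github.com/google-research/rliable); it requires Python 3.7+.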
+        raise ImportError("You must install the rliable package to use this feature. Note: Python 3.7+ is required in that case.")
+
+    print("Computing bootstrap CI ...")
+    algorithms = list(labels.values())
+    # Scores as a dictionary mapping algorithms to their normalized
+    # score matrices, each of which is of size `(num_runs x num_envs)`.
+
+    aggregate_func = lambda x: np.array(  # noqa: E731
+        [
+            metrics.aggregate_median(x),
+            metrics.aggregate_iqm(x),
+            metrics.aggregate_mean(x),
+            metrics.aggregate_optimality_gap(x),
+        ]
+    )
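+    # IQM is the mean of the middle 50% of runs (more robust to outliers than the plain mean);
+    # the optimality gap measures how far scores fall below a target normalized score
+    # (1.0 by default in rliable). See Agarwal et al. (2021), "Deep RL at the Edge of the
+    # Statistical Precipice".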
+    aggregate_scores, aggregate_interval_estimates = rly.get_interval_estimates(
+        normalized_score_dict,
+        aggregate_func,
+        # Default was 50000
+        reps=2000,  # Number of bootstrap replications.
+        confidence_interval_size=args.ci_size,  # Coverage of confidence interval. Defaults to 95%.
+    )
+
+    fig, axes = plot_utils.plot_interval_estimates(
+        aggregate_scores,
+        aggregate_interval_estimates,
+        metric_names=["Median", "IQM", "Mean", "Optimality Gap"],
+        algorithms=algorithms,
+        xlabel="Normalized Score",
+        xlabel_y_coordinate=0.02,
+        subfigure_width=5,
+        row_height=1,
+        max_ticks=4,
+        interval_height=0.6,
+    )
+    fig.canvas.manager.set_window_title("Rliable metrics")
+    # Adjust margin to see the x label
+    plt.tight_layout()
+    plt.subplots_adjust(bottom=0.2)
+
+    # Performance profiles
+    # Normalized score thresholds
+    normalized_score_thresholds = np.linspace(0.0, 1.5, 50)
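+    # A performance profile reports, for each threshold tau in this range, the fraction of
+    # runs (pooled across environments) whose normalized score exceeds tau.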
+    score_distributions, score_distributions_cis = rly.create_performance_profile(
+        normalized_score_dict,
+        normalized_score_thresholds,
+        reps=2000,
+        confidence_interval_size=args.ci_size,
+    )
+    # Plot score distributions
+    fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
+    plot_utils.plot_performance_profiles(
+        score_distributions,
+        normalized_score_thresholds,
+        performance_profile_cis=score_distributions_cis,
+        colors=dict(zip(algorithms, seaborn.color_palette("colorblind"))),
+        xlabel=r"Normalized Score $(\tau)$",
+        ax=ax,
+    )
+    fig.canvas.manager.set_window_title("Performance profiles")
+    plt.legend()
+
+    # Probability of improvement
+    # Scores as a dictionary containing pairs of normalized score
+    # matrices for pairs of algorithms we want to compare
+    algorithm_pairs_keys = itertools.combinations(algorithms, 2)
+    # algorithm_pairs = {.. , 'x,y': (score_x, score_y), ..}
+    algorithm_pairs = {}
+    for algo1, algo2 in algorithm_pairs_keys:
+        algorithm_pairs[f"{algo1}, {algo2}"] = (normalized_score_dict[algo1], normalized_score_dict[algo2])
+
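+    # Probability of improvement P(X > Y): the probability that algorithm X outperforms
+    # algorithm Y on a randomly selected task (and run), see the rliable documentation.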
+    if args.versus:
+        average_probabilities, average_prob_cis = rly.get_interval_estimates(
+            algorithm_pairs,
+            metrics.probability_of_improvement,
+            reps=1000,  # Default was 50000
+            confidence_interval_size=args.ci_size,
+        )
+        plot_utils.plot_probability_of_improvement(
+            average_probabilities,
+            average_prob_cis,
+            figsize=(10, 8),
+            interval_height=0.6,
+        )
+        plt.gcf().canvas.manager.set_window_title("Probability of Improvement")
+        plt.tight_layout()
+
+    if args.iqm:
+        # Load scores as a dictionary mapping algorithms to their normalized
+        # score matrices across all evaluations, each of which is of size
+        # `(n_runs, n_envs, n_eval)` where scores are recorded every n steps.
+        # Only compute CI for 1/4 of the evaluations and keep the first and last eval
+        downsample_factor = 4
+        n_evals = all_eval_normalized_scores_dict[algorithms[0]].shape[-1]
+        eval_indices = np.arange(n_evals - 1)[::downsample_factor]
+        eval_indices = np.concatenate((eval_indices, [n_evals - 1]))
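+        # e.g. with n_evals = 20 and downsample_factor = 4, this keeps indices [0, 4, 8, 12, 16, 19]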
+        eval_indices_scores_dict = {
+            algorithm: score[:, :, eval_indices] for algorithm, score in all_eval_normalized_scores_dict.items()
+        }
+        iqm = lambda scores: np.array(  # noqa: E731
+            [metrics.aggregate_iqm(scores[..., eval_idx]) for eval_idx in range(scores.shape[-1])]
+        )
+        iqm_scores, iqm_cis = rly.get_interval_estimates(
+            eval_indices_scores_dict,
+            iqm,
+            reps=2000,
+            confidence_interval_size=args.ci_size,
+        )
+        plot_utils.plot_sample_efficiency_curve(
+            eval_indices + 1,
+            iqm_scores,
+            iqm_cis,
+            algorithms=algorithms,
+            # TODO: convert evaluation indices to timesteps (using the recorded timesteps)
+            xlabel=r"Number of Evaluations",
+            ylabel="IQM Normalized Score",
+        )
+        plt.gcf().canvas.manager.set_window_title("IQM Normalized Score - Sample Efficiency Curve")
+        plt.legend()
+        plt.tight_layout()
+
+    plt.show()
+
# Plot final results with env as x axis
plt.figure("Sensitivity plot", figsize=args.figsize)
plt.title("Sensitivity plot", fontsize=args.fontsize)