Commit b606c1f

Add rliable plots (#176)
* Start integrating rliable
* Add proba of improvement plot
* Make rliable optional and update normalization
* Update doc
* Add sample efficiency plot
* Avoid for loop and prevent potential bug
* Update titles and warn user
* Ensure backward compat
* Fix backward compat
1 parent 75afd65 commit b606c1f

7 files changed (+326, -2 lines)

CHANGELOG.md (+2, -1)

@@ -1,9 +1,10 @@
-## Release 1.2.1a0 (WIP)
+## Release 1.2.1a4 (WIP)
 
 ### Breaking Changes
 - Upgrade to panda-gym 1.1.1
 
 ### New Features
+- Added support for using rliable for performance comparison
 
 ### Bug fixes
 - Fix training with Dict obs and channel last images

README.md (+24)

@@ -79,6 +79,30 @@ Plot evaluation reward curve for TQC, SAC and TD3 on the HalfCheetah and Ant PyB
 python scripts/all_plots.py -a sac td3 tqc --env HalfCheetah Ant -f rl-trained-agents/
 ```
 
+## Plot with the rliable library
+
+The RL zoo integrates some features of the [rliable](https://agarwl.github.io/rliable/) library.
+
+First, you need to install [rliable](https://github.com/google-research/rliable).
+
+Note: Python 3.7+ is required in that case.
+
+Then export your results to a file using the `all_plots.py` script (see above):
+```
+python scripts/all_plots.py -a sac td3 tqc --env Half Ant -f logs/ -o logs/offpolicy
+```
+
+You can now use the `plot_from_file.py` script with the `--rliable`, `--versus` and `--iqm` arguments:
+```
+python scripts/plot_from_file.py -i logs/offpolicy.pkl --skip-timesteps --rliable --versus -l SAC TD3 TQC
+```
+
+Note: you may need to edit `plot_from_file.py`, in particular the `env_key_to_env_id` dictionary,
+and `scripts/score_normalization.py`, which stores the min and max score for each environment.
+
+Remark: plotting with the `--rliable` option is usually slow, as confidence intervals need to be computed using bootstrap sampling.
+
+
 ## Custom Environment
 
 The easiest way to add support for a custom environment is to edit `utils/import_envs.py` and register your environment here. Then, you need to add a section for it in the hyperparameters file (`hyperparams/algo.yml`).
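The README note above refers to `scripts/score_normalization.py`, whose contents are not part of this diff. As a rough illustration only, here is a hypothetical sketch of a min/max-based `normalize_score(score, env_id)` helper matching the import used in `plot_from_file.py`; the dictionary values below are placeholders, not the reference scores shipped with the zoo.

```python
# Hypothetical sketch of a score_normalization.py helper; the real file in the
# repo stores its own min/max reference scores, which are not shown in this commit.
import numpy as np

# Placeholder reference scores: replace with the values for your environments.
MIN_MAX_SCORES = {
    "HalfCheetahBulletEnv-v0": (-300.0, 3000.0),
    "AntBulletEnv-v0": (0.0, 3500.0),
}


def normalize_score(score: np.ndarray, env_id: str) -> np.ndarray:
    """Min-max normalize scores: the min reference maps to 0 and the max to 1."""
    min_score, max_score = MIN_MAX_SCORES[env_id]
    return (np.asarray(score) - min_score) / (max_score - min_score)
```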

requirements.txt (+2)

@@ -14,3 +14,5 @@ cloudpickle>=1.5.0
 atari-py==0.2.6
 plotly
 panda-gym>=1.1.1
+# rliable requires python 3.7+
+# rliable>=1.0.5

scripts/all_plots.py (+2)

@@ -195,6 +195,7 @@
             "std_error": std_error,
             "last_evals": last_evals,
             "std_error_last_eval": std_error_last_eval,
+            "mean_per_eval": mean_per_eval,
         }
 
         plt.plot(timesteps / divider, mean_, label=f"{algo}-{args.labels[folder_idx]}", linewidth=3)
@@ -203,6 +204,7 @@
     plt.legend()
 
 
+# Markdown Table
 writer = pytablewriter.MarkdownTableWriter()
 writer.table_name = "results_table"
 
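For reference, here is a hedged sketch of what one entry of the exported results dictionary could look like after this change. It is inferred only from the keys visible in this hunk and from how `plot_from_file.py` consumes them; the surrounding key layout and the exact array shapes are assumptions.

```python
# Hedged sketch of one results entry written by all_plots.py, as suggested by this hunk.
# Shapes are assumptions: plot_from_file.py slices "mean_per_eval" as [:, :min_runs],
# so its second axis is assumed to index runs.
import numpy as np

n_evals, n_runs = 20, 3
mean_per_eval = np.random.rand(n_evals, n_runs)  # mean score at each evaluation, per run

entry = {
    "std_error": np.random.rand(n_evals),         # standard error of the mean per evaluation
    "last_evals": mean_per_eval[-1],               # final evaluation score of each run
    "std_error_last_eval": mean_per_eval[-1].std() / np.sqrt(n_runs),
    "mean_per_eval": mean_per_eval,                # new key added by this commit
}
```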
scripts/plot_from_file.py (+185)

@@ -1,12 +1,22 @@
 import argparse
+import itertools
 import pickle
+import warnings
 
 import numpy as np
 import pandas as pd
 import pytablewriter
 import seaborn
 from matplotlib import pyplot as plt
 
+try:
+    from rliable import library as rly  # pytype: disable=import-error
+    from rliable import metrics, plot_utils  # pytype: disable=import-error
+except ImportError:
+    rly = None
+
+from score_normalization import normalize_score
+
 
 # From https://github.com/mwaskom/seaborn/blob/master/seaborn/categorical.py
 def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5):
@@ -42,6 +52,10 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5):
 parser.add_argument("--fontsize", help="Font size", type=int, default=14)
 parser.add_argument("-l", "--labels", help="Custom labels", type=str, nargs="+")
 parser.add_argument("-b", "--boxplot", help="Enable boxplot", action="store_true", default=False)
+parser.add_argument("-r", "--rliable", help="Enable rliable plots", action="store_true", default=False)
+parser.add_argument("-vs", "--versus", help="Enable probability of improvement plot", action="store_true", default=False)
+parser.add_argument("-iqm", "--iqm", help="Enable IQM sample efficiency plot", action="store_true", default=False)
+parser.add_argument("-ci", "--ci-size", help="Confidence interval size (for rliable)", type=float, default=0.95)
 parser.add_argument("-latex", "--latex", help="Enable latex support", action="store_true", default=False)
 parser.add_argument("--merge", help="Merge with other results files", nargs="+", default=[], type=str)
 
@@ -132,21 +146,192 @@ def restyle_boxplot(artist_dict, color, gray="#222222", linewidth=1, fliersize=5):
 
 # Convert to pandas dataframe, in order to use seaborn
 labels_df, envs_df, scores = [], [], []
+# Post-process to use it with rliable
+# algo: (n_runs, n_envs)
+normalized_score_dict = {}
+# algo: (n_runs, n_envs, n_eval)
+all_eval_normalized_scores_dict = {}
+# Convert env key to env id for normalization
+env_key_to_env_id = {
+    "Half": "HalfCheetahBulletEnv-v0",
+    "Ant": "AntBulletEnv-v0",
+    "Hopper": "HopperBulletEnv-v0",
+    "Walker": "Walker2DBulletEnv-v0",
+}
+# Backward compat
+skip_all_algos_dict = False
+
 for key in keys:
+    algo_scores, all_algo_scores = [], []
     for env in envs:
         if isinstance(results[env][key]["last_evals"], (np.float32, np.float64)):
             # Not enough timesteps
             print(f"Skipping {env}-{key}")
             continue
+
         for score in results[env][key]["last_evals"]:
             labels_df.append(labels[key])
             # convert to int if needed
             # labels_df.append(int(labels[key]))
             envs_df.append(env)
             scores.append(score)
 
+        algo_scores.append(results[env][key]["last_evals"])
+
+        # Backward compat: mean_per_eval key may not be present
+        if "mean_per_eval" in results[env][key]:
+            all_algo_scores.append(results[env][key]["mean_per_eval"])
+        else:
+            skip_all_algos_dict = True
+
+        # Normalize score, env key must match env_id
+        if env in env_key_to_env_id:
+            algo_scores[-1] = normalize_score(algo_scores[-1], env_key_to_env_id[env])
+            if not skip_all_algos_dict:
+                all_algo_scores[-1] = normalize_score(all_algo_scores[-1], env_key_to_env_id[env])
+        elif env not in env_key_to_env_id and args.rliable:
+            warnings.warn(f"{env} not found for normalizing scores, you should update `env_key_to_env_id`")
+
+    # Truncate to convert to matrix
+    min_runs = min([len(algo_score) for algo_score in algo_scores])
+    if min_runs > 0:
+        algo_scores = [algo_score[:min_runs] for algo_score in algo_scores]
+        # shape: (n_envs, n_runs) -> (n_runs, n_envs)
+        normalized_score_dict[labels[key]] = np.array(algo_scores).T
+        if not skip_all_algos_dict:
+            all_algo_scores = [all_algo_score[:, :min_runs] for all_algo_score in all_algo_scores]
+            # (n_envs, n_eval, n_runs) -> (n_runs, n_envs, n_eval)
+            all_eval_normalized_scores_dict[labels[key]] = np.array(all_algo_scores).transpose((2, 0, 1))
+
 data_frame = pd.DataFrame(data=dict(Method=labels_df, Environment=envs_df, Score=scores))
 
+# Rliable plots, see https://github.com/google-research/rliable
+if args.rliable:
+
+    if rly is None:
+        raise ImportError("You must install rliable package to use this feature. Note: Python 3.7+ is required in that case.")
+
+    print("Computing bootstrap CI ...")
+    algorithms = list(labels.values())
+    # Scores as a dictionary mapping algorithms to their normalized
+    # score matrices, each of which is of size `(num_runs x num_envs)`.
+
+    aggregate_func = lambda x: np.array(  # noqa: E731
+        [
+            metrics.aggregate_median(x),
+            metrics.aggregate_iqm(x),
+            metrics.aggregate_mean(x),
+            metrics.aggregate_optimality_gap(x),
+        ]
+    )
+    aggregate_scores, aggregate_interval_estimates = rly.get_interval_estimates(
+        normalized_score_dict,
+        aggregate_func,
+        # Default was 50000
+        reps=2000,  # Number of bootstrap replications.
+        confidence_interval_size=args.ci_size,  # Coverage of confidence interval. Defaults to 95%.
+    )
+
+    fig, axes = plot_utils.plot_interval_estimates(
+        aggregate_scores,
+        aggregate_interval_estimates,
+        metric_names=["Median", "IQM", "Mean", "Optimality Gap"],
+        algorithms=algorithms,
+        xlabel="Normalized Score",
+        xlabel_y_coordinate=0.02,
+        subfigure_width=5,
+        row_height=1,
+        max_ticks=4,
+        interval_height=0.6,
+    )
+    fig.canvas.manager.set_window_title("Rliable metrics")
+    # Adjust margin to see the x label
+    plt.tight_layout()
+    plt.subplots_adjust(bottom=0.2)
+
+    # Performance profiles
+    # Normalized score thresholds
+    normalized_score_thresholds = np.linspace(0.0, 1.5, 50)
+    score_distributions, score_distributions_cis = rly.create_performance_profile(
+        normalized_score_dict,
+        normalized_score_thresholds,
+        reps=2000,
+        confidence_interval_size=args.ci_size,
+    )
+    # Plot score distributions
+    fig, ax = plt.subplots(ncols=1, figsize=(7, 5))
+    plot_utils.plot_performance_profiles(
+        score_distributions,
+        normalized_score_thresholds,
+        performance_profile_cis=score_distributions_cis,
+        colors=dict(zip(algorithms, seaborn.color_palette("colorblind"))),
+        xlabel=r"Normalized Score $(\tau)$",
+        ax=ax,
+    )
+    fig.canvas.manager.set_window_title("Performance profiles")
+    plt.legend()
+
+    # Probability of improvement
+    # Scores as a dictionary containing pairs of normalized score
+    # matrices for pairs of algorithms we want to compare
+    algorithm_pairs_keys = itertools.combinations(algorithms, 2)
+    # algorithm_pairs = {.. , 'x,y': (score_x, score_y), ..}
+    algorithm_pairs = {}
+    for algo1, algo2 in algorithm_pairs_keys:
+        algorithm_pairs[f"{algo1}, {algo2}"] = (normalized_score_dict[algo1], normalized_score_dict[algo2])
+
+    if args.versus:
+        average_probabilities, average_prob_cis = rly.get_interval_estimates(
+            algorithm_pairs,
+            metrics.probability_of_improvement,
+            reps=1000,  # Default was 50000
+            confidence_interval_size=args.ci_size,
+        )
+        plot_utils.plot_probability_of_improvement(
+            average_probabilities,
+            average_prob_cis,
+            figsize=(10, 8),
+            interval_height=0.6,
+        )
+        plt.gcf().canvas.manager.set_window_title("Probability of Improvement")
+        plt.tight_layout()
+
+    if args.iqm:
+        # Load scores as a dictionary mapping algorithms to their normalized
+        # score matrices across all evaluations, each of which is of size
+        # `(n_runs, n_envs, n_eval)` where scores are recorded every n steps.
+        # Only compute CI for 1/4 of the evaluations and keep the first and last eval
+        downsample_factor = 4
+        n_evals = all_eval_normalized_scores_dict[algorithms[0]].shape[-1]
+        eval_indices = np.arange(n_evals - 1)[::downsample_factor]
+        eval_indices = np.concatenate((eval_indices, [n_evals - 1]))
+        eval_indices_scores_dict = {
+            algorithm: score[:, :, eval_indices] for algorithm, score in all_eval_normalized_scores_dict.items()
+        }
+        iqm = lambda scores: np.array(  # noqa: E731
+            [metrics.aggregate_iqm(scores[..., eval_idx]) for eval_idx in range(scores.shape[-1])]
+        )
+        iqm_scores, iqm_cis = rly.get_interval_estimates(
+            eval_indices_scores_dict,
+            iqm,
+            reps=2000,
+            confidence_interval_size=args.ci_size,
+        )
+        plot_utils.plot_sample_efficiency_curve(
+            eval_indices + 1,
+            iqm_scores,
+            iqm_cis,
+            algorithms=algorithms,
+            # TODO: convert to timesteps using the timesteps
+            xlabel=r"Number of Evaluations",
+            ylabel="IQM Normalized Score",
+        )
+        plt.gcf().canvas.manager.set_window_title("IQM Normalized Score - Sample Efficiency Curve")
+        plt.legend()
+        plt.tight_layout()
+
+    plt.show()
+
 # Plot final results with env as x axis
 plt.figure("Sensitivity plot", figsize=args.figsize)
 plt.title("Sensitivity plot", fontsize=args.fontsize)
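To see what the `--rliable` code path computes without running the full script, here is a minimal, self-contained sketch that reuses the same rliable calls as the diff above (`rly.get_interval_estimates`, `metrics.aggregate_iqm`, `metrics.probability_of_improvement`), with random data standing in for real normalized scores and a reduced number of bootstrap replications.

```python
# Minimal standalone sketch of the rliable calls used in plot_from_file.py,
# with random scores standing in for real results (assumes rliable is installed).
import numpy as np
from rliable import library as rly
from rliable import metrics

# algo -> normalized score matrix of shape (n_runs, n_envs)
normalized_score_dict = {
    "SAC": np.random.rand(10, 4),
    "TD3": np.random.rand(10, 4),
}

# Point estimates and bootstrap confidence intervals for IQM and mean
aggregate_func = lambda scores: np.array(
    [metrics.aggregate_iqm(scores), metrics.aggregate_mean(scores)]
)
aggregate_scores, aggregate_cis = rly.get_interval_estimates(
    normalized_score_dict,
    aggregate_func,
    reps=500,  # far fewer bootstrap replications than the script, just for a quick look
)
print(aggregate_scores["SAC"], aggregate_cis["SAC"])

# Probability that SAC improves over TD3, as in the --versus plot
probs, prob_cis = rly.get_interval_estimates(
    {"SAC, TD3": (normalized_score_dict["SAC"], normalized_score_dict["TD3"])},
    metrics.probability_of_improvement,
    reps=500,
)
print(probs, prob_cis)
```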
