Skip to content

Commit 396bdf8

Browse files
committed
Divergence heatmap tweak: more s.f. where figure can fit it
1 parent 728bb1b commit 396bdf8

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

src/evaluating_rewards/analysis/plot_divergence_heatmap.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
"""CLI script to plot heatmap of divergence between pairs of reward models."""
1616

17+
import functools
1718
import itertools
1819
import os
1920
from typing import Any, Iterable, Mapping, Optional
@@ -189,6 +190,7 @@ def hopper():
189190
for prefix, suffix in itertools.product(activities, MUJOCO_STANDARD_ORDER)
190191
]
191192
heatmap_kwargs["after_plot"] = horizontal_ticks
193+
heatmap_kwargs["fmt"] = functools.partial(visualize.short_e, precision=0)
192194
del activities
193195

194196

src/evaluating_rewards/analysis/visualize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def compute_mask(
349349
return res
350350

351351

352-
short_fmt = functools.partial(short_e, precision=0)
352+
short_fmt = functools.partial(short_e, precision=1)
353353

354354

355355
def compact_heatmaps(

0 commit comments

Comments
 (0)