Skip to content

Commit 4232f9d

Browse files
npitaraffin
andauthored
Rename the observations variable in the evaluation util to avoid shadowing (#1288)
* Rename the observations variable in the evaluation util to avoid shadowing This enables a callback in evaluate_policy to have access to the observation vector that is fed to the environment step function, which is currently shadowed by the output observation. * Update changelog * Add test * Move assignment outside of the loop --------- Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
1 parent 84f5511 commit 4232f9d

File tree

4 files changed

+37
-3
lines changed

4 files changed

+37
-3
lines changed

docs/misc/changelog.rst

+29-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,34 @@
33
Changelog
44
==========
55

6+
Release 1.8.1a0 (WIP)
7+
--------------------------
8+
9+
Breaking Changes:
10+
^^^^^^^^^^^^^^^^^
11+
- Renamed environment output observations in ``evaluate_policy`` to prevent shadowing the input observations during callbacks (@npit)
12+
13+
New Features:
14+
^^^^^^^^^^^^^
15+
16+
`SB3-Contrib`_
17+
^^^^^^^^^^^^^^
18+
19+
`RL Zoo`_
20+
^^^^^^^^^
21+
22+
Bug Fixes:
23+
^^^^^^^^^^
24+
25+
Deprecations:
26+
^^^^^^^^^^^^^
27+
28+
Others:
29+
^^^^^^^
30+
31+
Documentation:
32+
^^^^^^^^^^^^^^
33+
634

735
Release 1.8.0 (2023-04-07)
836
--------------------------
@@ -1271,4 +1299,4 @@ And all the contributors:
12711299
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
12721300
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi
12731301
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
1274-
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher
1302+
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit

stable_baselines3/common/evaluation.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def evaluate_policy(
8686
episode_starts = np.ones((env.num_envs,), dtype=bool)
8787
while (episode_counts < episode_count_targets).any():
8888
actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
89-
observations, rewards, dones, infos = env.step(actions)
89+
new_observations, rewards, dones, infos = env.step(actions)
9090
current_rewards += rewards
9191
current_lengths += 1
9292
for i in range(n_envs):
@@ -120,6 +120,8 @@ def evaluate_policy(
120120
current_rewards[i] = 0
121121
current_lengths[i] = 0
122122

123+
observations = new_observations
124+
123125
if render:
124126
env.render()
125127

stable_baselines3/version.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.8.0
1+
1.8.1a0

tests/test_utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,10 @@ def test_evaluate_policy(direct_policy: bool):
183183

184184
def dummy_callback(locals_, _globals):
185185
locals_["model"].n_callback_calls += 1
186+
assert "observations" in locals_
187+
assert "new_observations" in locals_
188+
assert locals_["new_observations"] is not locals_["observations"]
189+
assert not np.allclose(locals_["new_observations"], locals_["observations"])
186190

187191
assert model.policy is not None
188192
policy = model.policy if direct_policy else model

0 commit comments

Comments
 (0)