Skip to content

Commit ba8dda7

Browse files
committed
Fix rendering of multi-process test environment
1 parent 8e7de98 commit ba8dda7

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

lib/RLTrader.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -230,22 +230,24 @@ def test(self, model_epoch: int = 0, should_render: bool = True):
230230

231231
del train_provider
232232

233-
test_env = DummyVecEnv([make_env(test_provider, i) for i in range(1)])
233+
init_envs = DummyVecEnv([make_env(test_provider) for _ in range(self.n_envs)])
234234

235235
model_path = path.join('data', 'agents', f'{self.study_name}__{model_epoch}.pkl')
236-
model = self.Model.load(model_path, env=test_env)
236+
model = self.Model.load(model_path, env=init_envs)
237+
238+
test_env = DummyVecEnv([make_env(test_provider) for _ in range(1)])
237239

238240
self.logger.info(f'Testing model ({self.study_name}__{model_epoch})')
239241

240-
zero_completed_obs = np.zeros((self.n_envs,) + test_env.observation_space.shape)
242+
zero_completed_obs = np.zeros((self.n_envs,) + init_envs.observation_space.shape)
241243
zero_completed_obs[0, :] = test_env.reset()
242244

243245
state = None
244246
rewards = []
245247

246248
for _ in range(len(test_provider.data_frame)):
247249
action, state = model.predict(zero_completed_obs, state=state)
248-
obs, reward, _, __ = test_env.step([action])
250+
obs, reward, _, __ = test_env.step([action[0]])
249251

250252
zero_completed_obs[0, :] = obs
251253

0 commit comments

Comments
 (0)