Skip to content

Commit 20baac2

Browse files
committed
Allow multi-proc env to be rendered.
1 parent 9d4ad1c commit 20baac2

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

lib/RLTrader.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -213,23 +213,28 @@ def test(self, model_epoch: int = 0, should_render: bool = True):
         del train_provider

-        test_env = SubprocVecEnv([make_env(test_provider, i) for i in range(self.n_envs)])
+        test_env = DummyVecEnv([make_env(test_provider, i) for i in range(1)])

         model_path = path.join('data', 'agents', f'{self.study_name}__{model_epoch}.pkl')
         model = self.Model.load(model_path, env=test_env)

         self.logger.info(f'Testing model ({self.study_name}__{model_epoch})')

+        zero_completed_obs = np.zeros((self.n_envs,) + test_env.observation_space.shape)
+        zero_completed_obs[0, :] = test_env.reset()
+
         state = None
-        obs, rewards = test_env.reset(), []
+        rewards = []

         for _ in range(len(test_provider.data_frame)):
-            action, state = model.predict(obs, state=state)
+            action, state = model.predict(zero_completed_obs, state=state)
             obs, reward, _, __ = test_env.step(action)

+            zero_completed_obs[0, :] = obs
+
             rewards.append(reward)

-            if should_render and self.n_envs == 1:
+            if should_render:
                 test_env.render(mode='human')

             self.logger.info(

optimize.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ def optimize_code(params):


 if __name__ == '__main__':
-    n_process = multiprocessing.cpu_count() - 4
+    n_process = multiprocessing.cpu_count()
     params = {}

     processes = []

0 commit comments

Comments
 (0)