Skip to content

Commit a2b8610

Browse files
committed
Fix TF hanging issue in multiprocessing pools.
1 parent e8c8eab commit a2b8610

File tree

2 files changed

+22
-29
lines changed

2 files changed

+22
-29
lines changed

cli.py

+15-17
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
2-
import multiprocessing
2+
3+
from multiprocessing.pool import ThreadPool
34

45
from lib.RLTrader import RLTrader
56
from lib.cli.RLTraderCLI import RLTraderCLI
@@ -11,32 +12,29 @@
1112
args = trader_cli.get_args()
1213

1314

14-
def run_concurrent_optimize():
15-
trader = RLTrader(**vars(args))
16-
trader.optimize(args.trials)
17-
15+
def run_optimize(params):
16+
trader_args, logger = params
1817

19-
def concurrent_optimize():
20-
processes = []
21-
for i in range(args.parallel_jobs):
22-
processes.append(multiprocessing.Process(target=run_concurrent_optimize, args=()))
18+
trader = RLTrader(**vars(trader_args), logger=logger)
19+
trader.optimize(trader_args.trials)
2320

24-
print(processes)
2521

26-
for p in processes:
27-
p.start()
22+
def optimize_concurrent(trader_args, logger):
23+
n_processes = trader_args.parallel_jobs
2824

29-
for p in processes:
30-
p.join()
25+
opt_pool = ThreadPool(processes=n_processes)
26+
opt_pool.map(run_optimize, [((trader_args, logger)) for _ in range(n_processes)])
3127

3228

3329
if __name__ == '__main__':
3430
logger = init_logger(__name__, show_debug=args.debug)
35-
trader = RLTrader(**vars(args), logger=logger)
3631

3732
if args.command == 'optimize':
38-
concurrent_optimize()
39-
elif args.command == 'train':
33+
optimize_concurrent(args, logger)
34+
35+
trader = RLTrader(**vars(args), logger=logger)
36+
37+
if args.command == 'train':
4038
trader.train(n_epochs=args.epochs)
4139
elif args.command == 'test':
4240
trader.test(model_epoch=args.model_epoch, should_render=args.no_render)

optimize.py

+7-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
import multiprocessing
1+
import os
22
import numpy as np
33

4+
from multiprocessing.pool import ThreadPool
5+
46
from lib.RLTrader import RLTrader
57

68
np.warnings.filterwarnings('ignore')
@@ -12,18 +14,11 @@ def optimize_code(params):
1214

1315

1416
if __name__ == '__main__':
15-
n_process = multiprocessing.cpu_count()
16-
params = {'n_envs': n_process}
17-
18-
processes = []
19-
for i in range(n_process):
20-
processes.append(multiprocessing.Process(target=optimize_code, args=(params,)))
21-
22-
for p in processes:
23-
p.start()
17+
n_processes = 6 # os.cpu_count()
18+
params = {'n_envs': n_processes}
2419

25-
for p in processes:
26-
p.join()
20+
opt_pool = ThreadPool(processes=n_processes)
21+
opt_pool.map(optimize_code, [params for _ in range(n_processes)])
2722

2823
trader = RLTrader(**params)
2924
trader.train(test_trained_model=True, render_trained_model=True)

0 commit comments

Comments
 (0)