|
1 | 1 | import numpy as np
|
2 |
| -import multiprocessing |
| 2 | + |
| 3 | +from multiprocessing.pool import ThreadPool |
3 | 4 |
|
4 | 5 | from lib.RLTrader import RLTrader
|
5 | 6 | from lib.cli.RLTraderCLI import RLTraderCLI
|
|
11 | 12 | args = trader_cli.get_args()
|
12 | 13 |
|
13 | 14 |
|
14 |
| -def run_concurrent_optimize(): |
15 |
| - trader = RLTrader(**vars(args)) |
16 |
| - trader.optimize(args.trials) |
17 |
| - |
| 15 | +def run_optimize(params): |
| 16 | + trader_args, logger = params |
18 | 17 |
|
19 |
| -def concurrent_optimize(): |
20 |
| - processes = [] |
21 |
| - for i in range(args.parallel_jobs): |
22 |
| - processes.append(multiprocessing.Process(target=run_concurrent_optimize, args=())) |
| 18 | + trader = RLTrader(**vars(trader_args), logger=logger) |
| 19 | + trader.optimize(trader_args.trials) |
23 | 20 |
|
24 |
| - print(processes) |
25 | 21 |
|
26 |
| - for p in processes: |
27 |
| - p.start() |
| 22 | +def optimize_concurrent(trader_args, logger): |
| 23 | + n_processes = trader_args.parallel_jobs |
28 | 24 |
|
29 |
| - for p in processes: |
30 |
| - p.join() |
| 25 | + opt_pool = ThreadPool(processes=n_processes) |
| 26 | + opt_pool.map(run_optimize, [((trader_args, logger)) for _ in range(n_processes)]) |
31 | 27 |
|
32 | 28 |
|
33 | 29 | if __name__ == '__main__':
|
34 | 30 | logger = init_logger(__name__, show_debug=args.debug)
|
35 |
| - trader = RLTrader(**vars(args), logger=logger) |
36 | 31 |
|
37 | 32 | if args.command == 'optimize':
|
38 |
| - concurrent_optimize() |
39 |
| - elif args.command == 'train': |
| 33 | + optimize_concurrent(args, logger) |
| 34 | + |
| 35 | + trader = RLTrader(**vars(args), logger=logger) |
| 36 | + |
| 37 | + if args.command == 'train': |
40 | 38 | trader.train(n_epochs=args.epochs)
|
41 | 39 | elif args.command == 'test':
|
42 | 40 | trader.test(model_epoch=args.model_epoch, should_render=args.no_render)
|
|
0 commit comments