Skip to content

Commit 9d4ad1c

Browse files
committed
Replace deco with standard multiprocessing pool
1 parent afacd02 commit 9d4ad1c

6 files changed

+26
-71
lines changed

README.md

+5-5
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,17 @@ python ./cli.py --help
6363
or simply run the project with default options:
6464

6565
```bash
66-
python ./cli.py optimize-train-test
66+
python ./cli.py optimize
6767
```
6868

6969
If you have a standard set of configs you want to run the trader against, you can specify a config file to load configuration from. Rename config/config.ini.dist to config/config.ini and run
7070

7171
```bash
72-
python ./cli.py --from-config config/config.ini optimize-train-test
72+
python ./cli.py --from-config config/config.ini optimize
7373
```
7474

7575
```bash
76-
python ./cli.py optimize-train-test
76+
python ./cli.py optimize
7777
```
7878

7979
### Testing with vagrant
@@ -92,7 +92,7 @@ Note: With vagrant you cannot take full advantage of your GPU, so it is mainly for
9292
If you want to run everything within a docker container, then just use:
9393

9494
```bash
95-
./run-with-docker (cpu|gpu) (yes|no) optimize-train-test
95+
./run-with-docker (cpu|gpu) (yes|no) optimize
9696
```
9797

9898
- cpu - start the container using CPU requirements
@@ -101,7 +101,7 @@ If you want to run everything within a docker container, then just use:
101101
Note: in case using yes as second argument, use
102102

103103
```bash
104-
python ./ cli.py --params-db-path "postgres://rl_trader:rl_trader@localhost" optimize-train-test
104+
python ./ cli.py --params-db-path "postgres://rl_trader:rl_trader@localhost" optimize
105105
```
106106

107107
The database and it's data are pesisted under `data/postgres` locally.

cli.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
2-
3-
from deco import concurrent
2+
import multiprocessing
43

54
from lib.RLTrader import RLTrader
65
from lib.cli.RLTraderCLI import RLTraderCLI
@@ -12,27 +11,34 @@
1211
args = trader_cli.get_args()
1312

1413

15-
@concurrent(processes=args.parallel_jobs)
16-
def run_concurrent_optimize(trader: RLTrader, args):
17-
trader.optimize(args.trials, args.trials, args.parallel_jobs)
14+
def run_concurrent_optimize():
15+
trader = RLTrader(**vars(args))
16+
trader.optimize(args.trials)
17+
18+
19+
def concurrent_optimize():
20+
processes = []
21+
for i in range(args.parallel_jobs):
22+
processes.append(multiprocessing.Process(target=run_concurrent_optimize, args=()))
23+
24+
print(processes)
25+
26+
for p in processes:
27+
p.start()
28+
29+
for p in processes:
30+
p.join()
1831

1932

2033
if __name__ == '__main__':
2134
logger = init_logger(__name__, show_debug=args.debug)
2235
trader = RLTrader(**vars(args), logger=logger)
2336

2437
if args.command == 'optimize':
25-
run_concurrent_optimize(trader, args)
38+
concurrent_optimize()
2639
elif args.command == 'train':
2740
trader.train(n_epochs=args.epochs)
2841
elif args.command == 'test':
2942
trader.test(model_epoch=args.model_epoch, should_render=args.no_render)
30-
elif args.command == 'optimize-train-test':
31-
run_concurrent_optimize(trader, args)
32-
trader.train(
33-
n_epochs=args.train_epochs,
34-
test_trained_model=args.no_test,
35-
render_trained_model=args.no_render
36-
)
3743
elif args.command == 'update-static-data':
3844
download_data_async()

lib/RLTrader.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
from os import path
66
from typing import Dict
77

8-
from deco import concurrent
98
from stable_baselines.common.base_class import BaseRLModel
10-
from stable_baselines.common.policies import BasePolicy, MlpPolicy
9+
from stable_baselines.common.policies import BasePolicy, MlpLnLstmPolicy
1110
from stable_baselines.common.vec_env import DummyVecEnv, SubprocVecEnv
1211
from stable_baselines.common import set_global_seeds
1312
from stable_baselines import PPO2
@@ -31,7 +30,7 @@ class RLTrader:
3130
data_provider = None
3231
study_name = None
3332

34-
def __init__(self, modelClass: BaseRLModel = PPO2, policyClass: BasePolicy = MlpPolicy, exchange_args: Dict = {}, **kwargs):
33+
def __init__(self, modelClass: BaseRLModel = PPO2, policyClass: BasePolicy = MlpLnLstmPolicy, exchange_args: Dict = {}, **kwargs):
3534
self.logger = kwargs.get('logger', init_logger(__name__, show_debug=kwargs.get('show_debug', True)))
3635

3736
self.Model = modelClass
@@ -162,7 +161,6 @@ def optimize_params(self, trial, n_prune_evals_per_trial: int = 2, n_tests_per_e
162161

163162
return -1 * last_reward
164163

165-
@concurrent
166164
def optimize(self, n_trials: int = 100, n_parallel_jobs: int = 1, *optimize_params):
167165
try:
168166
self.optuna_study.optimize(

lib/cli/RLTraderCLI.py

-6
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,6 @@ def __init__(self):
4444

4545
subparsers = self.parser.add_subparsers(help='Command', dest="command")
4646

47-
opt_train_test_parser = subparsers.add_parser('optimize-train-test', description='Optimize train and test')
48-
opt_train_test_parser.add_argument('--trials', type=int, default=20, help='Number of trials')
49-
opt_train_test_parser.add_argument('--train-epochs', type=int, default=10, help='Train for how many epochs')
50-
opt_train_test_parser.add_argument('--no-render', action='store_false', help='Should render the model')
51-
opt_train_test_parser.add_argument('--no-test', action='store_false', help='Should test the model')
52-
5347
optimize_parser = subparsers.add_parser('optimize', description='Optimize model parameters')
5448
optimize_parser.add_argument('--trials', type=int, default=1, help='Number of trials')
5549

requirements.base.txt

-1
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,4 @@ statsmodels==0.10.0rc2
1010
empyrical
1111
ccxt
1212
psycopg2
13-
deco
1413
configparser

update_data.py

-42
This file was deleted.

0 commit comments

Comments
 (0)