Skip to content

Commit 924df37

Browse files
committed
Fix a slew of bugs.
1 parent d1d19de commit 924df37

File tree

6 files changed

+19
-12
lines changed

6 files changed

+19
-12
lines changed

cli.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def run_optimize(args, logger):
2020
from lib.RLTrader import RLTrader
2121

2222
trader = RLTrader(**vars(args), logger=logger, reward_strategy=reward_strategy)
23-
trader.optimize(n_trials=args.trials, n_prune_evals_per_trial=args.prune_evals, n_tests_per_eval=args.eval_tests)
23+
trader.optimize(n_trials=args.trials)
2424

2525

2626
if __name__ == '__main__':

data/params.db

28 KB
Binary file not shown.

lib/RLTrader.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -185,9 +185,9 @@ def optimize_params(self, trial, n_prune_evals_per_trial: int = 2, n_tests_per_e
185185

186186
return -1 * last_reward
187187

188-
def optimize(self, n_trials: int = 20, **optimize_params):
188+
def optimize(self, n_trials: int = 20):
189189
try:
190-
self.optuna_study.optimize(self.optimize_params, n_trials=n_trials, n_jobs=1, **optimize_params)
190+
self.optuna_study.optimize(self.optimize_params, n_trials=n_trials, n_jobs=1)
191191
except KeyboardInterrupt:
192192
pass
193193

@@ -278,7 +278,7 @@ def test(self, model_epoch: int = 0, render_env: bool = True, render_report: boo
278278
if done:
279279
net_worths = pd.DataFrame({
280280
'Date': info[0]['timestamps'],
281-
'Balance': info[0]['networths'],
281+
'Balance': info[0]['net_worths'],
282282
})
283283

284284
net_worths.set_index('Date', drop=True, inplace=True)

lib/env/TradingEnv.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def _get_trade(self, action: int):
9494

9595
def _take_action(self, action: int):
9696
amount_asset_to_buy, amount_asset_to_sell = self._get_trade(action)
97+
9798
asset_bought, asset_sold, purchase_cost, sale_revenue = self.trade_strategy.trade(buy_amount=amount_asset_to_buy,
9899
sell_amount=amount_asset_to_sell,
99100
balance=self.balance,
@@ -104,15 +105,20 @@ def _take_action(self, action: int):
104105
self.asset_held += asset_bought
105106
self.balance -= purchase_cost
106107

107-
self.trades.append({'step': self.current_step, 'amount': asset_bought,
108-
'total': purchase_cost, 'type': 'buy'})
108+
self.trades.append({'step': self.current_step,
109+
'amount': asset_bought,
110+
'total': purchase_cost,
111+
'type': 'buy'})
109112
elif asset_sold:
110113
self.asset_held -= asset_sold
111114
self.balance += sale_revenue
115+
112116
self.reward_strategy.reset_reward()
113117

114-
self.trades.append({'step': self.current_step, 'amount': asset_sold,
115-
'total': sale_revenue, 'type': 'sell'})
118+
self.trades.append({'step': self.current_step,
119+
'amount': asset_sold,
120+
'total': sale_revenue,
121+
'type': 'sell'})
116122

117123
current_net_worth = round(self.balance + self.asset_held * self._current_price(), self.base_precision)
118124
self.net_worths.append(current_net_worth)
@@ -132,7 +138,7 @@ def _done(self):
132138

133139
def _reward(self):
134140
reward = self.reward_strategy.get_reward(current_step=self.current_step,
135-
current_price=self._current_price(),
141+
current_price=self._current_price,
136142
observations=self.observations,
137143
account_history=self.account_history,
138144
net_worths=self.net_worths)
@@ -214,7 +220,8 @@ def step(self, action):
214220
obs = self._next_observation()
215221
reward = self._reward()
216222
done = self._done()
217-
return obs, reward, done, {'networths': self.net_worths, 'timestamps': self.timestamps}
223+
224+
return obs, reward, done, {'net_worths': self.net_worths, 'timestamps': self.timestamps}
218225

219226
def render(self, mode='human'):
220227
if mode == 'system':

lib/env/reward/WeightedUnrealizedProfit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,6 @@ def get_reward(self,
3636
if account_history['asset_sold'].values[-1] > 0:
3737
reward = self.calc_reward(account_history['sale_revenue'].values[-1])
3838
else:
39-
reward = self.calc_reward(account_history['asset_held'].values[-1] * current_price)
39+
reward = self.calc_reward(account_history['asset_held'].values[-1] * current_price())
4040

4141
return reward

lib/env/trade/SimulatedTradeStrategy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def trade(self,
3030
commission = self.commissionPercent / 100
3131
slippage = np.random.uniform(0, self.maxSlippagePercent) / 100
3232

33-
asset_bought, asset_sold, purchase_cost, sale_revenue = 0, 0, 0, 0
33+
asset_bought, asset_sold, purchase_cost, sale_revenue = buy_amount, sell_amount, 0, 0
3434

3535
if buy_amount > 0 and balance >= self.min_cost_limit:
3636
price_adjustment = (1 + commission) * (1 + slippage)

0 commit comments

Comments (0)