Skip to content

Commit e7a6db8

Browse files
authored
Merge pull request freqtrade#10173 from freqtrade/fix/mutable_defaults
Fix mutable defaults, enable bugbear ruff rule also for freqAI code
2 parents cdb7fa8 + 717f17a commit e7a6db8

7 files changed

+29
-20
lines changed

freqtrade/freqai/RL/BaseEnvironment.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,20 @@ class BaseEnvironment(gym.Env):
4646

4747
def __init__(
4848
self,
49-
df: DataFrame = DataFrame(),
50-
prices: DataFrame = DataFrame(),
51-
reward_kwargs: dict = {},
49+
*,
50+
df: DataFrame,
51+
prices: DataFrame,
52+
reward_kwargs: dict,
5253
window_size=10,
5354
starting_point=True,
5455
id: str = "baseenv-1", # noqa: A002
5556
seed: int = 1,
56-
config: dict = {},
57+
config: dict,
5758
live: bool = False,
5859
fee: float = 0.0015,
5960
can_short: bool = False,
6061
pair: str = "",
61-
df_raw: DataFrame = DataFrame(),
62+
df_raw: DataFrame,
6263
):
6364
"""
6465
Initializes the training/eval environment.

freqtrade/freqai/RL/BaseReinforcementLearningModel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def make_env(
488488
seed: int,
489489
train_df: DataFrame,
490490
price: DataFrame,
491-
env_info: dict[str, Any] = {},
491+
env_info: dict[str, Any],
492492
) -> Callable:
493493
"""
494494
Utility function for multiprocessed env.

freqtrade/freqai/data_kitchen.py

+12-8
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def filter_features(
214214
self,
215215
unfiltered_df: DataFrame,
216216
training_feature_list: list,
217-
label_list: list = list(),
217+
label_list: list | None = None,
218218
training_filter: bool = True,
219219
) -> tuple[DataFrame, DataFrame]:
220220
"""
@@ -244,7 +244,7 @@ def filter_features(
244244
# we don't care about total row number (total no. datapoints) in training, we only care
245245
# about removing any row with NaNs
246246
# if labels has multiple columns (user wants to train multiple modelEs), we detect here
247-
labels = unfiltered_df.filter(label_list, axis=1)
247+
labels = unfiltered_df.filter(label_list or [], axis=1)
248248
drop_index_labels = pd.isnull(labels).any(axis=1)
249249
drop_index_labels = (
250250
drop_index_labels.replace(True, 1).replace(False, 0).infer_objects(copy=False)
@@ -654,8 +654,8 @@ def get_pair_data_for_features(
654654
pair: str,
655655
tf: str,
656656
strategy: IStrategy,
657-
corr_dataframes: dict = {},
658-
base_dataframes: dict = {},
657+
corr_dataframes: dict,
658+
base_dataframes: dict,
659659
is_corr_pairs: bool = False,
660660
) -> DataFrame:
661661
"""
@@ -773,10 +773,10 @@ def populate_features(
773773
def use_strategy_to_populate_indicators( # noqa: C901
774774
self,
775775
strategy: IStrategy,
776-
corr_dataframes: dict = {},
777-
base_dataframes: dict = {},
776+
corr_dataframes: dict[str, DataFrame] | None = None,
777+
base_dataframes: dict[str, dict[str, DataFrame]] | None = None,
778778
pair: str = "",
779-
prediction_dataframe: DataFrame = pd.DataFrame(),
779+
prediction_dataframe: DataFrame | None = None,
780780
do_corr_pairs: bool = True,
781781
) -> DataFrame:
782782
"""
@@ -793,6 +793,10 @@ def use_strategy_to_populate_indicators( # noqa: C901
793793
:return:
794794
dataframe: DataFrame = dataframe containing populated indicators
795795
"""
796+
if not corr_dataframes:
797+
corr_dataframes = {}
798+
if not base_dataframes:
799+
base_dataframes = {}
796800

797801
# check if the user is using the deprecated populate_any_indicators function
798802
new_version = inspect.getsource(strategy.populate_any_indicators) == (
@@ -822,7 +826,7 @@ def use_strategy_to_populate_indicators( # noqa: C901
822826
if tf not in corr_dataframes[p]:
823827
corr_dataframes[p][tf] = pd.DataFrame()
824828

825-
if not prediction_dataframe.empty:
829+
if prediction_dataframe is not None and not prediction_dataframe.empty:
826830
dataframe = prediction_dataframe.copy()
827831
base_dataframes[self.config["timeframe"]] = dataframe.copy()
828832
else:

freqtrade/freqai/freqai_interface.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ def extract_data_and_train_model(
618618
)
619619

620620
unfiltered_dataframe = dk.use_strategy_to_populate_indicators(
621-
strategy, corr_dataframes, base_dataframes, pair
621+
strategy, corr_dataframes=corr_dataframes, base_dataframes=base_dataframes, pair=pair
622622
)
623623

624624
trained_timestamp = new_trained_timerange.stopts

freqtrade/freqai/torch/PyTorchModelTrainer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(
2525
criterion: nn.Module,
2626
device: str,
2727
data_convertor: PyTorchDataConvertor,
28-
model_meta_data: dict[str, Any] = {},
28+
model_meta_data: dict[str, Any] | None = None,
2929
window_size: int = 1,
3030
tb_logger: Any = None,
3131
**kwargs,
@@ -45,6 +45,8 @@ def __init__(
4545
:param n_epochs: The maximum number batches to use for evaluation.
4646
:param batch_size: The size of the batches to use during training.
4747
"""
48+
if model_meta_data is None:
49+
model_meta_data = {}
4850
self.model = model
4951
self.optimizer = optimizer
5052
self.criterion = criterion

pyproject.toml

-2
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,6 @@ max-complexity = 12
287287
[tool.ruff.lint.per-file-ignores]
288288
"freqtrade/freqai/**/*.py" = [
289289
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
290-
"B006", # Bugbear - mutable default argument
291-
"B008", # bugbear - Do not perform function calls in argument defaults
292290
]
293291
"tests/**/*.py" = [
294292
"S101", # allow assert in tests

tests/freqai/test_freqai_datakitchen.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,9 @@ def test_get_pair_data_for_features_with_prealoaded_data(mocker, freqai_conf):
150150
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
151151

152152
_, base_df = freqai.dd.get_base_and_corr_dataframes(timerange, "LTC/BTC", freqai.dk)
153-
df = freqai.dk.get_pair_data_for_features("LTC/BTC", "5m", strategy, base_dataframes=base_df)
153+
df = freqai.dk.get_pair_data_for_features(
154+
"LTC/BTC", "5m", strategy, {}, base_dataframes=base_df
155+
)
154156

155157
assert df is base_df["5m"]
156158
assert not df.empty
@@ -170,7 +172,9 @@ def test_get_pair_data_for_features_without_preloaded_data(mocker, freqai_conf):
170172
freqai.dd.load_all_pair_histories(timerange, freqai.dk)
171173

172174
base_df = {"5m": pd.DataFrame()}
173-
df = freqai.dk.get_pair_data_for_features("LTC/BTC", "5m", strategy, base_dataframes=base_df)
175+
df = freqai.dk.get_pair_data_for_features(
176+
"LTC/BTC", "5m", strategy, {}, base_dataframes=base_df
177+
)
174178

175179
assert df is not base_df["5m"]
176180
assert not df.empty

0 commit comments

Comments
 (0)