Skip to content

Commit 8ffbda0

Browse files
committed
chore: lint
1 parent 3800b89 commit 8ffbda0

File tree

4 files changed

+17
-12
lines changed

4 files changed

+17
-12
lines changed

src/axolotl/core/trainers/base.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -464,9 +464,9 @@ def get_train_dataloader(self) -> DataLoader:
464464
"pin_memory": self.args.dataloader_pin_memory,
465465
}
466466
if self.args.dataloader_prefetch_factor:
467-
dataloader_params["prefetch_factor"] = (
468-
self.args.dataloader_prefetch_factor
469-
)
467+
dataloader_params[
468+
"prefetch_factor"
469+
] = self.args.dataloader_prefetch_factor
470470

471471
sampler = self._get_train_sampler()
472472
if isinstance(sampler, BatchSampler):
@@ -511,9 +511,9 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoa
511511
"pin_memory": self.args.dataloader_pin_memory,
512512
}
513513
if self.args.dataloader_prefetch_factor:
514-
dataloader_params["prefetch_factor"] = (
515-
self.args.dataloader_prefetch_factor
516-
)
514+
dataloader_params[
515+
"prefetch_factor"
516+
] = self.args.dataloader_prefetch_factor
517517

518518
if isinstance(eval_sampler, BatchSampler):
519519
dataloader_params["batch_sampler"] = eval_sampler

src/axolotl/core/training_args.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ class AxolotlTrainingMixins:
3434
)
3535
sample_packing_sequentially: bool = field(
3636
default=False,
37-
metadata={"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."},
37+
metadata={
38+
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
39+
},
3840
)
3941
multipack_real_batches: bool = field(
4042
default=False,

src/axolotl/utils/samplers/multipack.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
124124

125125
# First, do sequential packing into bins
126126
all_bins = []
127-
current_bin = [0 for i in range(0)] # numba hint
127+
current_bin = [0 for i in range(0)] # numba hint
128128
remaining_capacity = c
129129

130130
for idx, size in enumerate(lengths):
@@ -190,8 +190,9 @@ def __init__(
190190
self.len_across_ranks = None
191191

192192
if self.sequential and not isinstance(sampler, SequentialSampler):
193-
LOG.warn("using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?")
194-
193+
LOG.warn(
194+
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
195+
)
195196

196197
def set_epoch(self, epoch: int):
197198
self.epoch = epoch

tests/test_packed_batch_sampler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@ class TestBatchedSamplerPacking:
3535
)
3636
@pytest.mark.parametrize("max_seq_length", [4096, 512])
3737
@pytest.mark.parametrize("sequential", [True, False])
38-
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length, sequential):
38+
def test_packing(
39+
self, batch_size, num_workers, tokenizer, max_seq_length, sequential
40+
):
3941
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
4042

4143
dataset = load_dataset(
@@ -71,7 +73,7 @@ def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length, seque
7173
batch_max_len=max_seq_length,
7274
group_size=100000,
7375
bin_size=200,
74-
sequential=sequential
76+
sequential=sequential,
7577
)
7678

7779
loader = DataLoader(

0 commit comments

Comments
 (0)