
Commit c16c72c

pytholic and rasbt authored

[fix][1760] Added fix for the missing context key issue in dolly! (#1766)
Co-authored-by: Sebastian Raschka <mail@sebastianraschka.com>

1 parent ae798ab, commit c16c72c

6 files changed: +64 -7 lines

litgpt/data/dolly.py (2 additions, 2 deletions)

@@ -71,6 +71,6 @@ def setup(self, stage: str = "") -> None:
 
 
 def _transform(item: dict) -> dict:
-    item["input"] = item.pop("context")
-    item["output"] = item.pop("response")
+    item["input"] = item.get("context", "")
+    item["output"] = item.get("response", "")
     return item
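This is the substance of the fix: `dict.pop` removes the key from the item, so a record missing `context` (or a second pass over an already-transformed item) raised a `KeyError`, while `dict.get` leaves the item intact and falls back to an empty string. A minimal standalone sketch, not part of the commit, contrasting the two behaviors:

# Hypothetical sketch: old pop-based transform vs. new get-based transform.
old_item = {"context": "some context", "response": "some response"}
new_item = {"context": "some context", "response": "some response"}

def transform_pop(item: dict) -> dict:  # old behavior
    item["input"] = item.pop("context")    # removes the key from the item
    item["output"] = item.pop("response")
    return item

def transform_get(item: dict) -> dict:  # new behavior
    item["input"] = item.get("context", "")    # keeps the key; "" fallback
    item["output"] = item.get("response", "")
    return item

transform_pop(old_item)
try:
    transform_pop(old_item)  # second pass: "context" is gone -> KeyError
except KeyError as err:
    print(f"KeyError: {err}")

transform_get(new_item)
transform_get(new_item)  # idempotent: no error, original keys preserved
print(new_item["input"], new_item["output"])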

litgpt/finetune/adapter.py (1 addition, 1 deletion)

@@ -382,7 +382,7 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
         model.clear_kv_cache()
         model.train()
         output = tokenizer.decode(output)
-        fabric.print(output)
+        fabric.print(f"{output}\n")
     else:
         print(
             f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) "

litgpt/finetune/adapter_v2.py (1 addition, 1 deletion)

@@ -378,7 +378,7 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
         model.clear_kv_cache()
         model.train()
         output = tokenizer.decode(output)
-        fabric.print(output)
+        fabric.print(f"{output}\n")
     else:
         print(
             f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) "

litgpt/finetune/full.py (1 addition, 1 deletion)

@@ -354,7 +354,7 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
         model.clear_kv_cache()
         model.train()
         output = tokenizer.decode(output)
-        fabric.print(output)
+        fabric.print(f"{output}\n")
     else:
         print(
             f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) "

litgpt/finetune/lora.py (1 addition, 1 deletion)

@@ -410,7 +410,7 @@ def generate_example(fabric: L.Fabric, model: GPT, tokenizer: Tokenizer, eval: E
         model.clear_kv_cache()
         model.train()
         output = tokenizer.decode(output)
-        fabric.print(output)
+        fabric.print(f"{output}\n")
     else:
         print(
             f"Length of encoded instruction ({len(encoded)}) and eval.max_new_tokens ({eval.max_new_tokens}) "

tests/data/test_dolly.py (58 additions, 1 deletion)

@@ -5,7 +5,12 @@
 
 
 def test_dolly(mock_tokenizer, dolly_path):
-    dolly = Dolly(val_split_fraction=0.5, download_dir=dolly_path.parent, file_name=dolly_path.name, num_workers=0)
+    dolly = Dolly(
+        val_split_fraction=0.5,
+        download_dir=dolly_path.parent,
+        file_name=dolly_path.name,
+        num_workers=0,
+    )
     assert isinstance(dolly.prompt_style, AlpacaPromptStyle)
     dolly.connect(mock_tokenizer, batch_size=2, max_seq_length=10)
     dolly.prepare_data()
@@ -29,3 +34,55 @@ def test_dolly(mock_tokenizer, dolly_path):
 
     # has attributes from super class `LightningDataModule`
     assert dolly.prepare_data_per_node
+
+
+def test_dolly_missing_keys(mock_tokenizer, dolly_path):
+    """
+    Notes
+    -----
+    - Added only for the dolly dataset.
+
+    References
+    ----------
+    - Reference issue: https://github.com/Lightning-AI/litgpt/issues/1760
+
+    Methodology
+    -----------
+    - Simulate the original behavior by popping the `context` key.
+    - Run the dataloader, which applies `_transform`.
+    - Previously this raised a missing-`context`-key error because the key was popped.
+    - Now the `get` method is used, so the key(s) are not removed.
+    """
+
+    dolly = Dolly(
+        val_split_fraction=0.5,
+        download_dir=dolly_path.parent,
+        file_name=dolly_path.name,
+        num_workers=0,
+    )
+    dolly.connect(mock_tokenizer, batch_size=2, max_seq_length=10)
+    dolly.prepare_data()
+    dolly.setup()
+
+    # check that the datasets were created without errors
+    assert dolly.train_dataset is not None
+    assert dolly.test_dataset is not None
+
+    # verify that the transform function handled missing keys correctly
+    for dataset in [dolly.train_dataset, dolly.test_dataset]:
+        for item in dataset.data:
+            assert "context" in item
+            assert "response" in item
+            assert isinstance(item["context"], str)
+            assert isinstance(item["response"], str)
+            # drop the `context` and `response` keys to simulate
+            # the behavior of the original issue with `item.pop`
+            item.pop("context")
+            item.pop("response")
+
+    # check that we can iterate through the dataloader without errors;
+    # the previous approach would raise a KeyError here since the keys were already popped
+    train_dataloader = dolly.train_dataloader()
+    train_batch = next(iter(train_dataloader))
+    assert "input_ids" in train_batch
+    assert "labels" in train_batch
