
Commit 3392270

Jan Philipp Harries (jphme) and Jan Philipp Harries authored
experimental llama 2 chat support (axolotl-ai-cloud#296)
* experimental llama 2 chat support
* few small fixes
* llama2_chat
* small fix to follow original implementation
* small fixes and added fixtures/tests
* fix - mixed up inference and finetuning conversations
* args - small fix
* small fix
* small adjustment and warning
* fix with pre-commit

---------

Co-authored-by: Jan Philipp Harries <jpdus@users.noreply.github.com>
1 parent bb53a16 commit 3392270

File tree

4 files changed: +292 -2 lines changed
src/axolotl/prompt_strategies/llama2_chat.py

+205
@@ -0,0 +1,205 @@
"""
Prompt Strategy for finetuning Llama2 chat models
see also https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L213 for a reference implementation.

This implementation is based on the Vicuna PR and the fastchat repo, see also:
https://github.com/lm-sys/FastChat/blob/cdd7730686cb1bf9ae2b768ee171bdf7d1ff04f3/fastchat/conversation.py#L847

Use dataset type: "llama2_chat" in config.yml to use this prompt style.

E.g. in the config.yml:
```
datasets:
  - path: llama_finetune_train.jsonl
    type: llama2_chat
```

The dataset itself should look like this:
```
{'conversations':[{"from": "human", "value": "Who are you?"}, {"from": "gpt", "value": "I am Vicuna"},...]}
```
in a jsonl file. The first message should be from the human, the second from gpt.
For a custom system message, the first "from" can be "system" (followed by alternating "human" and "gpt" turns).

Important: Don't use "special_tokens:" in your config.yml if you are not sure what you are doing!
"""
import logging
from dataclasses import dataclass, field
from typing import Generator, List, Sequence

from axolotl.prompt_tokenizers import PromptTokenizingStrategy
from axolotl.prompters import IGNORE_TOKEN_ID


@dataclass
class Llama2ChatConversation:
    """A class that manages prompt templates and keeps all conversation history.
    copied from https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py"""

    name: str = "llama2"
    # The system prompt
    system: str = (
        "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
    )
    roles: Sequence[str] = ("[INST]", "[/INST]")
    messages: List[List[str]] = field(default_factory=list)
    offset: int = 0
    sep = " "
    sep2 = " </s><s>"
    stop_token_ids = [2]

    def get_prompt(self) -> str:
        """Get the prompt for generation."""
        seps = [self.sep, self.sep2]
        ret = ""
        for i, (role, message) in enumerate(self.messages):
            if (i == len(self.messages) - 1) and (role == self.roles[0]):
                # last message is from user (due to length),
                # return prompt without it for training
                return ret
            if i == 0:
                ret += self.system + message.strip()
            else:
                ret += role + " " + message.strip() + seps[i % 2]
        return ret

    def append_message(self, role: str, message: str):
        """Append a new message."""
        self.messages.append([role, message])


class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
    """
    Tokenizing strategy for ShareGPT prompts.
    adapted from https://github.com/lm-sys/FastChat/blob/main/fastchat/train/train.py
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.sequence_len = 4096
        self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
        # https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json

    def tokenize_prompt(self, prompt):
        conv = next(self.prompter.build_prompt(prompt))
        conversation_str = conv.get_prompt()

        # Tokenize conversations
        input_ids = self.tokenizer(
            conversation_str,
            return_tensors="pt",
            padding="max_length",
            max_length=self.sequence_len,
            truncation=True,
        ).input_ids[0]
        target = input_ids.clone()

        # Mask targets. Only compute loss on the assistant outputs.
        sep = conv.roles[1]

        total_len = int(target.ne(self.tokenizer.pad_token_id).sum())

        turns = conversation_str.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_TOKEN_ID
        for turn in turns:
            if turn == "":
                break
            turn_len = len(self.tokenizer(turn).input_ids)

            parts = turn.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            # "-1" is hardcoded for the LLaMA tokenizer to make the offset correct.
            instruction_len = len(self.tokenizer(parts[0]).input_ids) - 1

            # Ignore the user instructions
            target[cur_len - 1 : cur_len + instruction_len] = IGNORE_TOKEN_ID
            cur_len += turn_len + 2  # due to length of role token

        target[cur_len:] = IGNORE_TOKEN_ID

        if cur_len < self.sequence_len:
            if cur_len != total_len:
                target[:] = IGNORE_TOKEN_ID
                logging.warning(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).tolist()
        input_ids = input_ids.tolist()
        target = target.tolist()
        # this is a fix for the tokenizer which tokenizes [ differently with eos tokens and
        # follows the original llama implementation
        for i in range(2, total_len - 2):
            if input_ids[i] == 29961:
                input_ids[i] = 518
            if target[i] == 29961:
                target[i] = 518
        return {
            "input_ids": input_ids,
            "labels": target,
            "attention_mask": attention_mask,
        }


class Llama2ChatPrompter:  # pylint: disable=too-few-public-methods
    """
    A prompter that generates prompts for Llama2 models.
    """

    system_prompt = (
        "[INST] <<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. "
        "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. "
        "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
        "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. "
        "If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
    )

    def build_prompt(self, source) -> Generator[Llama2ChatConversation, None, None]:
        # see https://github.com/lm-sys/FastChat/blob/da0641e567cf93756b0978ab5a6b092e96f06240/fastchat/train/train.py#L78
        source = source["conversations"]  # fix data structure for datasets

        # if system prompt provided, use it
        if source[0]["from"] == "system":
            system = f"[INST] <<SYS>>\n{source[0]['value']}\n<</SYS>>\n\n"
            source = source[1:]
        else:
            system = self.system_prompt

        conv = Llama2ChatConversation(system=system)

        if len(source) < 2:
            # If there isn't a back and forth conversation, ignore it
            # also happens on the data splitting leaving empty conversations
            raise IndexError

        roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

        if roles[source[0]["from"]] != conv.roles[0]:
            # Skip the first one if it is not from human
            source = source[1:]

        conv.messages = []  # pylint: disable=R0801
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2]
            if sentence["value"]:
                conv.append_message(role, sentence["value"])
        yield conv


def load(tokenizer, cfg) -> LLama2ChatTokenizingStrategy:
    return LLama2ChatTokenizingStrategy(
        Llama2ChatPrompter(),
        tokenizer,
        cfg.train_on_inputs,
        cfg.sequence_len,
    )

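For orientation, here is a minimal sketch (not part of the commit) of how the new prompter flattens a ShareGPT-style conversation into a single Llama 2 chat training string; it assumes the file above lands at src/axolotl/prompt_strategies/llama2_chat.py so the module is importable as shown:

```
from axolotl.prompt_strategies.llama2_chat import Llama2ChatPrompter

# A toy two-turn sample in the jsonl schema described in the module docstring.
sample = {
    "conversations": [
        {"from": "human", "value": "Who are you?"},
        {"from": "gpt", "value": "I am Vicuna"},
    ]
}

prompter = Llama2ChatPrompter()
conv = next(prompter.build_prompt(sample))  # build_prompt is a generator yielding one conversation
# Prints roughly: "[INST] <<SYS>>\n<system prompt>\n<</SYS>>\n\nWho are you?[/INST] I am Vicuna </s><s>"
print(conv.get_prompt())
```

LLama2ChatTokenizingStrategy.tokenize_prompt then tokenizes this string and masks everything except the assistant turns with IGNORE_TOKEN_ID, so loss is only computed on model outputs.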
src/axolotl/utils/data.py

+1 -1

@@ -378,7 +378,7 @@ def load_prepare_datasets(
             [
                 d
                 for d in dataset
-                if len(d["input_ids"]) < cfg.sequence_len
+                if len(d["input_ids"]) <= cfg.sequence_len
                 and len(d["input_ids"]) > 0
                 and len(d["input_ids"]) == len(d["attention_mask"])
                 and len(d["input_ids"]) == len(d["labels"])

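The `<` to `<=` change above only matters for samples that tokenize to exactly cfg.sequence_len tokens; a hypothetical illustration (not part of the commit) of the boundary behavior:

```
# Hypothetical values for illustration: with cfg.sequence_len = 4096, a sample of
# exactly 4096 tokens was dropped by the strict "<" filter and is now kept by "<=".
sequence_len = 4096
sample = {"input_ids": list(range(sequence_len))}

kept_before = len(sample["input_ids"]) < sequence_len   # False -> previously filtered out
kept_after = len(sample["input_ids"]) <= sequence_len   # True  -> now retained
assert not kept_before and kept_after
```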
tests/fixtures/conversation.tokenized_llama2chat.json

+1
Large diffs are not rendered by default.

tests/test_prompt_tokenizers.py

+85 -1

@@ -4,13 +4,17 @@
 import unittest
 from pathlib import Path
 
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, LlamaTokenizer
 
 from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
 from axolotl.prompt_strategies.alpaca_w_system import (
     InstructionWSystemPromptTokenizingStrategy,
     SystemDataPrompter,
 )
+from axolotl.prompt_strategies.llama2_chat import (
+    Llama2ChatPrompter,
+    LLama2ChatTokenizingStrategy,
+)
 from axolotl.prompt_tokenizers import (
     AlpacaPromptTokenizingStrategy,
     ShareGPTPromptTokenizingStrategy,
@@ -135,5 +139,85 @@ def test_system_alpaca(self):
         assert example["input_ids"][9] == 11889  # USER
 
 
+class Llama2ChatTokenizationTest(unittest.TestCase):
+    """
+    Test class for prompt tokenization strategies with sys prompt from the dataset
+    """
+
+    def setUp(self) -> None:
+        # pylint: disable=duplicate-code
+        self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
+        # workaround because official Meta repos are not open
+
+    def test_llama2_chat_integration(self):
+        with open(
+            Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
+        ) as fin:
+            data = fin.read()
+            conversation = json.loads(data)
+        with open(
+            Path(__file__).parent / "fixtures/conversation.tokenized_llama2chat.json",
+            encoding="utf-8",
+        ) as fin:
+            data = fin.read()
+            tokenized_conversation = json.loads(data)
+        prompter = Llama2ChatPrompter()
+        strat = LLama2ChatTokenizingStrategy(
+            prompter,
+            self.tokenizer,
+            False,
+            4096,
+        )
+        example = strat.tokenize_prompt(conversation)
+        for fields in ["input_ids", "attention_mask", "labels"]:
+            self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
+            self.assertEqual(example[fields], tokenized_conversation[fields])
+
+    def compare_with_transformers_integration(self):
+        # this needs transformers >= v4.31.0
+        from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
+        from transformers.pipelines.conversational import Conversation
+
+        # from transformers.models.llama.tokenization_llama import DEFAULT_SYSTEM_PROMPT
+        # broken as of 23/7/20
+        # see https://github.com/huggingface/transformers/pull/24935
+        # pylint: disable=C0103
+        DEFAULT_SYSTEM_PROMPT = """\
+You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
+
+If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""
+        with open(
+            Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
+        ) as fin:
+            data = fin.read()
+            conversation = json.loads(data)
+        with open(
+            Path(__file__).parent / "fixtures/conversation.tokenized_llama2chat.json",
+            encoding="utf-8",
+        ) as fin:
+            data = fin.read()
+            tokenized_conversation = json.loads(data)
+
+        user_input = []
+        answers = []
+        for msg in conversation["conversations"]:
+            if msg["from"] == "human":
+                user_input.append(msg["value"])
+            else:
+                answers.append(msg["value"])
+        hf_conf = Conversation(
+            text=user_input[-1],
+            past_user_inputs=[B_SYS + DEFAULT_SYSTEM_PROMPT + E_SYS + user_input[0]]
+            + user_input[1:-1],
+            generated_responses=answers,
+        )
+        # pylint: disable=W0212
+        hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf)
+
+        self.assertEqual(
+            hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)]
+        )
+
+
 if __name__ == "__main__":
     unittest.main()

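As a quick way to sanity-check what the fixture-based test asserts, a small sketch (not part of the commit) that counts the label positions the strategy actually trains on; it assumes IGNORE_TOKEN_ID (commonly -100) marks the masked user/system tokens, as in the strategy above:

```
from axolotl.prompters import IGNORE_TOKEN_ID


def trained_token_count(example: dict) -> int:
    """Count label positions that are not masked and therefore contribute to the loss."""
    return sum(1 for label in example["labels"] if label != IGNORE_TOKEN_ID)


# e.g. with example = strat.tokenize_prompt(conversation) as in test_llama2_chat_integration,
# this should roughly equal the number of assistant tokens plus turn separators.
```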
0 commit comments