Skip to content

Commit

Permalink
feat: add config for optional parameters in a chat message (#2260)
Browse files Browse the repository at this point in the history
* feat: add config for optional parameters in a chat message

* chore: cleanup

* chore: fix nits and add light docs

* docs: update docs/dataset-formats/conversation.qmd

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>

* feat: configurable message mappings, jinja template analyzer

* chore: handle bradley terry

* docs: update docs

* refactor: change order of mappings, improve message transform

* refactor: make chat awware of property mappings

* chore: remove .python-version

* chore: revert change

* chore: add dataset validation to tests where appropriate

* chore: add dataset validation to tests where appropriate

* chore: clean up handling of ds_cfg

* chore: recursively serialize config

* make sure to use the return value from validate_config

* DefaultDict pickle/unpickle fix

* fix super call for override

* refactor: message fields

* chore: empty commit

* tests: validate config before using

* chore: add config validation to all e2e tests

* chore: add unneeded logging

* chore: add missed config validation

* chore: pass field_messages to prompter

* test: fix borked test

* chore: remove uninteded file

* chore: add deprecation warning and update chat_datasets script

* chore: lint

* refactor: message fields

* feat: update axolotlinputconfig and test_models

- add configdict import in axolotl/utils/config/models/input/v0_4_1/__init__.py
- remove unnecessary line breaks in sftdataset, dpodataset, ktodataset, stepwisesuperviseddataset classes
- update model_dump method in axolotlinputconfig to exclude none values
- correct typo in test_models.py comment

* feat: simplify dpodataset and ktodataset classes in config models

removed several optional fields from dpodataset and ktodataset classes in axolotl/utils/config/models/input/v0_4_1. this simplifies the configuration subsets for these datasets.

* feat: improve readability and structure in dataset configuration models

this commit enhances the readability and structure of the dataset configuration models in the `axolotl/utils/config/models/input/v0_4_1` module. it removes unused `configdict` import and adds line breaks to separate class definitions for better clarity. additionally, a minor documentation fix is included to ensure a newline at the end of the `stepwise_supervised.qmd` file.

* feat: change log level from info to debug in chattemplatestrategy

* feat(prompt_strategies): refactor chattemplateprompter and chattemplatestrategy

- Make `chat_template` a required parameter in `ChatTemplatePrompter` constructor
- Add default value for `message_property_mappings` in `ChatTemplatePrompter` constructor
- Add `messages_array_name` property to `ChatTemplatePrompter`
- Change `processor` type to Optional in `ChatTemplatePrompter`
- Add TypeError check for `processor` in `ChatTemplatePrompter.build_prompt`
- Remove `_messages` property from `ChatTemplateStrategy`
- Make `prompter` a required parameter and add type hint in `ChatTemplateStrategy` constructor
- Remove `messages` getter and setter from `ChatTemplateStrategy`
- Use `prompter.messages_array_name` in `ChatTemplateStrategy.get_conversation_thread`
- Remove condition to set `messages` field in `load` function

* feat(tests/utils): ignore type check in load_model call in test_models.py

* feat: improve type handling and test structure in chat templates

- Add return type hint for `get_chat_template` function in `chat_templates.py`
- Remove unnecessary assignment of `strategy.messages` in several test cases
- Add `messages_array_name` parameter to various test configurations in `test_chat_templates.py` and `test_chat_templates_advanced.py`
- Remove redundant `strategy.messages` assignment in `test_chat_templates_advanced.py`

* feat(axolotl): enhance chat strategy with datasetconfig support

This commit introduces support for DatasetConfig in the ChatTemplateStrategy. It also refines the strategy loader to handle different types of ds_cfg inputs and improves the clarity of the code by formatting and reordering. The key changes include:

- Importing Union from typing and BaseModel from pydantic.
- Adding DatasetConfig as an optional type for ds_cfg in StrategyLoader.
- Adjusting the handling of ds_cfg in StrategyLoader to account for BaseModel instances.
- Refactoring the prompter_params and strategy_params for better readability.
- Changing the reference from prompt[self.messages] to prompt[self.prompter.messages_array_name] in the is_prompt_batched method.

* feat: update message handling in btchattemplatestrategy

* Replace `self.messages` with direct string references to "chosen_messages" and "rejected_messages"
* Append system, user, and assistant content directly to "chosen_messages" and "rejected_messages"
* Add a new attribute "messages_array_name" to the `load` function parameters
* Remove the conditional attribute assignment for "field_messages" in the `load` function

* feat: add config validation in test_kd.py

- Import `validate_config` from `axolotl.utils.config`
- Validate the configuration in `test_llama_kd` and another function in `TestKnowledgeDistillation` class

* feat: enhance config validation and capabilities handling

* Import `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
* Update `validate_config` function to create `KTODataset` and `SFTDataset` instances using `dict(ds_cfg)`
* Replace `capabilities` and `env_capabilities` with instances of `GPUCapabilities` and `EnvCapabilities` respectively in `AxolotlConfigWCapabilities` model dump

* feat: update config validation in axolotl utils

- Remove import of `EnvCapabilities` and `GPUCapabilities` from `axolotl.utils.config.models.internals`
- Update `validate_config` function to use `capabilities` and `env_capabilities` directly instead of creating new instances of `GPUCapabilities` and `EnvCapabilities`

* feat: refactor strategyloader in chat_template.py

- Extracted the creation of strategy parameters into a separate function, `_get_strategy_params(cfg, dataset_config)`
- Created a new function, `_get_strategy_cls()`, to obtain the strategy class
- Replaced `ChatTemplateStrategy` with `strategy_cls` for strategy instantiation

* trigger CI

* chore: revert dataset config changes for kto/dpo

* subject: refactor: rename 'messages_array_name' to 'field_messages'

Body:
- Renamed 'messages_array_name' to 'field_messages' in 'ChatTemplatePrompter' class and its usages in 'chat_template.py'
- Updated 'load' function in 'bradley_terry/chat_template.py' to reflect the change
- Adjusted 'get_chat_template_msg_variables' and 'get_message_vars' methods in 'jinja_template_analyzer.py' to use the new variable name
- Modified 'StrategyLoader' in 'chat_template.py' to use 'field_messages'
- Updated tests in 'test_chat_templates.py' and 'test_chat_templates_advanced.py' to use 'field_messages' instead of 'messages_array_name'

* feat: refactor prompt strategies and update config models

* Remove redundant 'return None' in `axolotl/prompt_strategies/__init__.py`
* Simplify message handling in `axolotl/prompt_strategies/bradley_terry/chat_template.py` by using a single 'messages' list instead of separate 'chosen_messages' and 'rejected_messages' lists
* Update default 'message_property_mappings' in `axolotl/prompt_strategies/bradley_terry/chat_template.py`
* Add 'field_messages' field to `axolotl/utils/config/models/input/v0_4_1/__init__.py` configuration model

* chore: remove unused input

* chore: remove redundant type ignore

* fix: remove old configs and update examples

* fix: type check

* fix: remove loading old config in ChatMessage

* fix: update faq with potential new undefinederror

* fix: add debug if property mapped is not found

* chore: improve explanation for unmapped properties

* fix: update docs with new config

* chore: add note for deprecation config and del old config from dict

---------

Co-authored-by: NanoCode012 <kevinvong@rocketmail.com>
Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: NanoCode012 <nano@axolotl.ai>
  • Loading branch information
4 people authored Feb 18, 2025
1 parent 3aac3b1 commit b194e17
Show file tree
Hide file tree
Showing 51 changed files with 1,190 additions and 230 deletions.
17 changes: 13 additions & 4 deletions docs/config.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,19 @@ datasets:

# Key containing the messages (default: "messages")
field_messages: messages
# Key for role in each message (default: "role")
message_field_role: role
# Key for content in each message (default: "content")
message_field_content: content

# Mapping of properties from the input dataset to the chat template.
# (default: message_property_mappings={'role':'role', 'content':'content'})
# If a property exists in the template but not in this mapping, the system will attempt
# to load it directly from the message using the property name as the key.
# Example: In the mapping below, 'from' is loaded from input dataset and used as 'role',
# while 'value' is loaded and used as 'content' in the chat template.
message_property_mappings:
role: from
content: value
# ...

message_property_mappings:

# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
roles:
Expand Down
15 changes: 9 additions & 6 deletions docs/dataset-formats/conversation.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ datasets:
type: chat_template

field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value

# new (if setting a new chat_template like chatml, gemma, etc)
chat_template: chatml
Expand All @@ -52,8 +53,9 @@ datasets:
type: chat_template

field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value
```
We recommend checking the below examples for other usecases.
Expand Down Expand Up @@ -138,8 +140,9 @@ datasets:
type: chat_template
chat_template: tokenizer_default
field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value
roles_to_train: []
train_on_eos: turn
message_field_training: train
Expand Down
24 changes: 13 additions & 11 deletions docs/dataset-formats/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ A flow chart is as follows:

4. Is your dataset in an "instruct" format, containing `{ instruction, response }`? If yes, check [Instruction Dataset](#instruction-dataset)

If you went through the flow chart and did not find one that matches, it is recommended to preprocess your dataset into one of the above or create a Github Discussion.
If you went through the flow chart and did not find one that matches, it is recommended to preprocess your dataset into one of the above or create a thread on Github Discussion.

::: {.callout-tip}
You can mix and match within each approach or across approaches to train a model on a variety of datasets.
Expand Down Expand Up @@ -289,9 +289,10 @@ If your dataset format is different, here are the keys you should check (with th
```yaml
datasets:
...
field_messages: messages
message_field_role: role
message_field_content: content
field_messages: messages # this should point to the key containing the list of conversations
message_property_mappings: # this is a mapping from keys in your dataset to keys in chat_template
role: role
content: content
```

In some `chat_templates` (e.g. [Gemma](https://huggingface.co/google/gemma-2b-it/blob/main/tokenizer_config.json#L1507)), the roles are hardcoded to `user` and `assistant`. Consequently, you may find it necessary to map the roles in your dataset to these above. We currently have some defaults that should work for common datasets, but if you get a `KeyError`, it would be necessary to add mapping for your roles. Here is an example of how it would look like:
Expand Down Expand Up @@ -348,13 +349,14 @@ datasets:
- path: A.jsonl
type: chat_template
# step 1
# step 1
chat_template: chatml
# step 2
field_messages: messages
message_field_role: role
message_field_content: content
# step 2
field_messages: messages
message_property_mappings:
role: role
content: content
roles:
assistant:
Expand All @@ -365,8 +367,8 @@ datasets:
- human
- user
# step 3
roles_to_train: ["assistant"]
# step 3
roles_to_train: ["assistant"]
train_on_eos: "turn"
special_tokens:
Expand Down
4 changes: 4 additions & 0 deletions docs/faq.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ description: Frequently asked questions
**Q: The codes is stuck on saving preprocessed datasets.**

> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

> A: This means that the property mapping for the stated attribute does not exist when building `chat_template` prompt. For example, if `no attribute 'content'`, please check you have added the correct mapping for `content` under `message_property_mappings`.
5 changes: 3 additions & 2 deletions docs/rlhf.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,9 @@ datasets:
field_messages: "messages"
field_chosen: "chosen"
field_rejected: "rejected"
message_field_role: "role"
message_field_content: "content"
message_property_mappings:
role: role
content: content
roles:
user: ["user"]
assistant: ["assistant"]
Expand Down
5 changes: 3 additions & 2 deletions examples/deepseek-v2/qlora-fsdp-2_5.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ datasets:
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value

dataset_prepared_path: last_run_prepared
val_set_size: 0.0
Expand Down
5 changes: 3 additions & 2 deletions examples/gemma2/qlora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@ datasets:
type: chat_template
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value

val_set_size: 0.0
output_dir: ./outputs/out
Expand Down
5 changes: 3 additions & 2 deletions examples/jamba/qlora_fsdp_large.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ datasets:
type: chat_template
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value

dataset_prepared_path: last_run_prepared
val_set_size: 0.0
Expand Down
5 changes: 3 additions & 2 deletions examples/llama-3/fft-8b-liger-fsdp.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ datasets:
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
message_property_mappings:
role: from
content: value

dataset_prepared_path: last_run_prepared
val_set_size: 0.02
Expand Down
5 changes: 3 additions & 2 deletions examples/llama-3/instruct-dpo-lora-8b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
message_property_mappings:
role: role
content: content
roles:
system:
- system
Expand Down
5 changes: 3 additions & 2 deletions examples/llama-3/instruct-lora-8b.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_field_role: role
message_field_content: content
message_property_mappings:
role: role
content: content
roles:
user:
- user
Expand Down
10 changes: 6 additions & 4 deletions examples/llama-3/lora-1b-deduplicate-dpo.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
message_property_mappings:
role: role
content: content
roles:
system:
- system
Expand All @@ -31,8 +32,9 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
message_property_mappings:
role: role
content: content
roles:
system:
- system
Expand Down
5 changes: 3 additions & 2 deletions examples/mistral/mistral-dpo-qlora.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
message_property_mappings:
role: role
content: content

dataset_prepared_path:
val_set_size: 0.05
Expand Down
5 changes: 3 additions & 2 deletions examples/phi/lora-3.5.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
field_messages: messages
message_field_role: role
message_field_content: content
message_property_mappings:
role: role
content: content
roles:
user:
- user
Expand Down
5 changes: 3 additions & 2 deletions examples/qwen2/dpo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ datasets:
field_messages: conversation
field_chosen: chosen
field_rejected: rejected
message_field_role: role
message_field_content: content
message_property_mappings:
role: role
content: content
roles:
system:
- system
Expand Down
15 changes: 7 additions & 8 deletions scripts/chat_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,26 @@ def parse_dataset(dataset=None, split="train"):
ds_cfg["field_messages"] = field_messages

message_fields = features[field_messages][0].keys()
message_field_role = None

message_property_mappings = {"role": None, "content": None}
for key in ["from", "role"]:
if key in message_fields:
message_field_role = key
message_property_mappings["role"] = key
break
if not message_field_role:
if not message_property_mappings["role"]:
raise ValueError(
f'No role field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_role"] = message_field_role

message_field_content = None
for key in ["content", "text", "value"]:
if key in message_fields:
message_field_content = key
message_property_mappings["content"] = key
break
if not message_field_content:
if not message_property_mappings["content"]:
raise ValueError(
f'No content field found in messages: {", ".join(message_fields)}'
)
ds_cfg["message_field_content"] = message_field_content
ds_cfg["message_property_mappings"] = message_property_mappings

print(yaml.dump({"datasets": [ds_cfg]}))

Expand Down
2 changes: 1 addition & 1 deletion src/axolotl/prompt_strategies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
load_kwargs["ds_cfg"] = ds_cfg
if "processor" in sig.parameters:
load_kwargs["processor"] = processor

return func(tokenizer, cfg, **load_kwargs)
except ModuleNotFoundError:
return None
except Exception as exc: # pylint: disable=broad-exception-caught
LOG.error(f"Failed to load prompt strategy `{strategy}`: {str(exc)}")
raise exc
return None
36 changes: 15 additions & 21 deletions src/axolotl/prompt_strategies/bradley_terry/chat_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,12 @@ def _tokenize_single_prompt(self, prompt):

max_length = self.prompter.max_length

self.messages = "chosen_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
prompt["messages"] = []
if prompt["system"]:
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
prompt["messages"].append({"role": "system", "content": prompt["system"]})
prompt["messages"].append({"role": "user", "content": prompt["input"]})
prompt["messages"].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super()._tokenize_single_prompt(prompt)

if len(chosen_tokenized["input_ids"]) > max_length:
Expand All @@ -55,17 +52,12 @@ def _tokenize_single_prompt(self, prompt):
:max_length
]

self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
prompt["messages"] = []
if prompt["system"]:
prompt[self.messages].append(
{"role": "system", "content": prompt["system"]}
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
prompt["messages"].append({"role": "system", "content": prompt["system"]})
prompt["messages"].append({"role": "user", "content": prompt["input"]})
prompt["messages"].append({"role": "assistant", "content": prompt["rejected"]})
rejected_tokenized = super()._tokenize_single_prompt(prompt)

if len(rejected_tokenized["input_ids"]) > max_length:
Expand Down Expand Up @@ -99,8 +91,13 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_property_mappings": ds_cfg.get(
"message_property_mappings",
{
"role": "role",
"content": "content",
},
),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail", None
Expand All @@ -124,7 +121,4 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)

if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]

return strategy
Loading

0 comments on commit b194e17

Please sign in to comment.