
Reinforcement Learning from Human Feedback (RLHF) examples: Direct Preference Optimization (DPO) #513

Open
danilopeixoto opened this issue Mar 1, 2024 · 8 comments · May be fixed by #1279
Labels
enhancement New feature or request

Comments


danilopeixoto commented Mar 1, 2024

Introduce a Reinforcement Learning from Human Feedback (RLHF) example, such as the Direct Preference Optimization (DPO) method.

Paper

Direct Preference Optimization: Your Language Model is Secretly a Reward Model

Notes

Direct Preference Optimization (DPO): A Simplified Explanation by João Lages

Implementation examples

Possible MLX implementation

Policy and reference log probabilities:

import mlx.core as mx
import mlx.nn as nn

def get_batched_logps(model, inputs, targets):
    logits, _ = model(inputs)
    logits = logits.astype(mx.float32)

    # Mask out padding tokens (assumes 0 is the padding id).
    loss_mask = targets != 0
    # Log-probability of each target token under the model.
    per_token_logps = mx.take_along_axis(nn.log_softmax(logits), targets[..., None], axis=2).squeeze(2)

    # Sum over the sequence and split the batch into (chosen, rejected) halves.
    return tuple((per_token_logps * loss_mask).sum(-1).split(2))
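The split(2) above assumes the chosen and rejected sequences are packed into a single batch, chosen first and rejected second. A minimal packing sketch, assuming hypothetical pre-tokenized, zero-padded arrays chosen_tokens and rejected_tokens:

import mlx.core as mx

# Hypothetical zero-padded token arrays of shape (batch, seq_len).
chosen_tokens = mx.array([[1, 5, 9, 2, 0, 0], [1, 7, 3, 4, 2, 0]])
rejected_tokens = mx.array([[1, 5, 8, 2, 0, 0], [1, 7, 6, 6, 2, 0]])

# Stack chosen then rejected so split(2) in get_batched_logps separates them again.
tokens = mx.concatenate([chosen_tokens, rejected_tokens], axis=0)

# Standard next-token shift; padding (0) is masked out by loss_mask above.
inputs, targets = tokens[:, :-1], tokens[:, 1:]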

Loss:

def dpo_loss(model, beta, label_smoothing, reference_chosen_logps, reference_rejected_logps, inputs, targets):
    chosen_logps, rejected_logps = get_batched_logps(model, inputs, targets)

    # Log-ratios of chosen vs. rejected completions under the policy and the reference model.
    pi_logratios = chosen_logps - rejected_logps
    reference_logratios = reference_chosen_logps - reference_rejected_logps

    # DPO logistic loss, optionally label-smoothed (conservative DPO).
    logits = pi_logratios - reference_logratios
    losses = -nn.log_sigmoid(beta * logits) * (1.0 - label_smoothing) - nn.log_sigmoid(-beta * logits) * label_smoothing

    # Implicit rewards and diagnostics.
    chosen_rewards = beta * (chosen_logps - reference_chosen_logps)
    rejected_rewards = beta * (rejected_logps - reference_rejected_logps)
    reward_accuracies = (chosen_rewards > rejected_rewards).astype(mx.float32)
    reward_margins = chosen_rewards - rejected_rewards

    # Number of non-padding tokens in the batch (assumes 0 is the padding id).
    ntoks = (inputs != 0).sum()

    return (
        losses.mean(),
        chosen_rewards.mean(),
        rejected_rewards.mean(),
        reward_accuracies.mean(),
        reward_margins.mean(),
        ntoks,
    )
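
For reference, the losses above correspond to the label-smoothed (conservative) DPO objective from the paper, with ε = label_smoothing and z equal to logits in the code:

$$\mathcal{L}_{\text{DPO}} = -(1 - \epsilon)\,\log\sigma(\beta z) - \epsilon\,\log\sigma(-\beta z), \qquad z = \log\frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \log\frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}$$

Setting ε = 0 recovers the standard DPO loss.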

Beta: The temperature parameter for the DPO loss, typically set in the range 0.1 to 0.5. When beta equals 0, the reference model is ignored.

Label smoothing: This parameter controls the conservativeness of the DPO loss, assuming that preference labels are noisy and may be flipped with probability label_smoothing.

Note that label_smoothing > 0 defines the conservative DPO loss.
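
Putting these pieces together, a minimal training-step sketch (the names policy, reference, optimizer, and the inputs/targets batch are assumptions, not part of the proposal above):

import mlx.core as mx
import mlx.nn as nn

beta, label_smoothing = 0.1, 0.0

# Reference log-probabilities are computed once per batch with the frozen reference model.
reference_chosen_logps, reference_rejected_logps = map(
    mx.stop_gradient, get_batched_logps(reference, inputs, targets)
)

def loss_fn(model):
    # Only the mean loss is differentiated; the remaining values are diagnostics.
    loss, *_ = dpo_loss(
        model, beta, label_smoothing,
        reference_chosen_logps, reference_rejected_logps,
        inputs, targets,
    )
    return loss

loss_and_grad = nn.value_and_grad(policy, loss_fn)
loss, grads = loss_and_grad(policy)
optimizer.update(policy, grads)
mx.eval(policy.parameters(), optimizer.state)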


awni commented Mar 1, 2024

@danilopeixoto I've been thinking about having this in MLX LM recently. Any interest in sending a PR?

It might make sense to do it after we have a more manageable config (#503), but that should be landed soon!

awni added the enhancement label Mar 1, 2024

awni commented Mar 1, 2024

To be more concrete, I'm envisioning you just set the loss in the config, e.g. cross_entropy or dpo.
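
A hedged sketch of what that dispatch could look like (the loss config key, the get_loss_fn helper, and default_loss as a stand-in for the trainer's existing cross-entropy loss are assumptions, not existing mlx_lm API):

# Hypothetical mapping from the config's loss name to a loss function.
LOSSES = {
    "cross_entropy": default_loss,
    "dpo": dpo_loss,
}

def get_loss_fn(config):
    # config would be the parsed YAML/CLI training configuration.
    return LOSSES[config.get("loss", "cross_entropy")]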

@ivanfioravanti
Contributor

This would be an awesome addition to mlx_examples! 🔥

@N8python
Contributor

I'm very, very excited for this! I don't have the technical expertise to implement DPO directly but would love to help in other ways (config, code cleanup) if necessary!


lin72h commented Mar 27, 2024

That makes MLX really useful for production, not just a research tool!

awni mentioned this issue Apr 10, 2024
@kishoretvk

+500 waiting for this


developerlin commented May 16, 2024

Waiting for this. When will DPO training be supported?

anupamme added a commit to anupamme/mlx-examples that referenced this issue Feb 12, 2025
Fixes ml-explore#513

Implement the Direct Preference Optimization (DPO) method as a Reinforcement Learning from Human Feedback (RLHF) example.

* **Add DPO Functions**: Add `get_batched_logps` and `dpo_loss` functions to `llms/mlx_lm/utils.py` for DPO implementation.
* **Update Training Logic**: Update `llms/mlx_lm/tuner/trainer.py` to include DPO-specific training logic, including a new `dpo_loss` function and condition to check for DPO loss in the training loop.
* **Add Configuration Options**: Add configuration options for DPO in `llms/mlx_lm/examples/lora_config.yaml`.
* **Update Documentation**: Update `llms/mlx_lm/README.md` to include instructions for using DPO.
* **Add Unit Tests**: Add `llms/tests/test_dpo.py` with unit tests for `get_batched_logps`, `dpo_loss`, and DPO-specific training logic.

---

For more details, open the [Copilot Workspace session](https://copilot-workspace.githubnext.com/ml-explore/mlx-examples/issues/513?shareId=XXXX-XXXX-XXXX-XXXX).
anupamme linked a pull request Feb 12, 2025 that will close this issue
@Goekdeniz-Guelmez
Contributor

#1233 #1210 #1209
