Skip to content

Commit a4e430e

Browse files
authored
add override of upstream fix for multi-gpu orpo (axolotl-ai-cloud#2440)
* add override of upstream fix * override batch loss metrics for CPO/Simpo as well
1 parent 6cdcb8d commit a4e430e

File tree

1 file changed

+148
-0
lines changed
  • src/axolotl/core/trainers

1 file changed

+148
-0
lines changed

src/axolotl/core/trainers/trl.py

+148
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Module for TRL PPO trainer"""
22

3+
from typing import Literal, Union
4+
35
import torch
46
from tqdm import tqdm
57
from trl import (
@@ -79,6 +81,78 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
7981

8082
tag_names = ["axolotl", "orpo"]
8183

84+
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
    # TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
    metrics = {}

    outputs = self.concatenated_forward(model, batch)
    (
        chosen_logps,
        rejected_logps,
        chosen_logits,
        rejected_logits,
        nll_loss,
    ) = outputs[:5]
    if self.aux_loss_enabled:
        aux_loss = outputs[5]

    losses, chosen_rewards, rejected_rewards, log_odds_ratio, log_odds_chosen = (
        self.odds_ratio_loss(chosen_logps, rejected_logps)
    )
    # full ORPO loss: NLL term minus the mean odds-ratio term
    loss = nll_loss - losses.mean()

    accuracies = (chosen_rewards > rejected_rewards).float()

    # gather across processes before reducing, so logged metrics reflect
    # the whole (multi-GPU) batch rather than a single rank's shard
    gather = self.accelerator.gather_for_metrics
    prefix = "eval_" if train_eval == "eval" else ""
    metrics[f"{prefix}rewards/chosen"] = gather(chosen_rewards).mean()
    metrics[f"{prefix}rewards/rejected"] = gather(rejected_rewards).mean()
    metrics[f"{prefix}rewards/accuracies"] = gather(accuracies).mean()
    metrics[f"{prefix}rewards/margins"] = gather(
        chosen_rewards - rejected_rewards
    ).mean()
    metrics[f"{prefix}logps/rejected"] = gather(rejected_logps).detach().mean()
    metrics[f"{prefix}logps/chosen"] = gather(chosen_logps).detach().mean()
    metrics[f"{prefix}logits/rejected"] = gather(
        rejected_logits.detach().mean()
    ).mean()
    metrics[f"{prefix}logits/chosen"] = gather(
        chosen_logits.detach().mean()
    ).mean()
    metrics[f"{prefix}nll_loss"] = gather(nll_loss).detach().mean()
    metrics[f"{prefix}log_odds_ratio"] = gather(log_odds_ratio).detach().mean()
    metrics[f"{prefix}log_odds_chosen"] = gather(log_odds_chosen).detach().mean()

    # scalarize every metric tensor so the trainer logs plain floats
    for key, value in metrics.items():
        metrics[key] = value.item()

    if self.aux_loss_enabled:
        loss += self.aux_loss_coef * aux_loss

    return loss, metrics
155+
82156

83157
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
84158
"""
@@ -95,6 +169,80 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
95169

96170
tag_names = ["axolotl", "cpo"]
97171

172+
def get_batch_loss_metrics(
    self,
    model,
    batch: dict[str, Union[list, torch.LongTensor]],
    train_eval: Literal["train", "eval"] = "train",
):
    """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
    metrics = {}

    outputs = self.concatenated_forward(model, batch)
    (
        chosen_logps,
        rejected_logps,
        chosen_logits,
        rejected_logits,
        nll_loss,
    ) = outputs[:5]
    if self.aux_loss_enabled:
        aux_loss = outputs[5]

    losses, chosen_rewards, rejected_rewards = self.cpo_loss(
        chosen_logps,
        rejected_logps,
    )

    # combined CPO objective: mean preference loss plus the alpha-weighted NLL term
    loss = losses.mean() + self.cpo_alpha * nll_loss
    accuracies = (chosen_rewards > rejected_rewards).float()

    # gather across processes before reducing, so logged metrics reflect
    # the whole (multi-GPU) batch rather than a single rank's shard
    gather = self.accelerator.gather_for_metrics
    prefix = "eval_" if train_eval == "eval" else ""
    metrics[f"{prefix}rewards/chosen"] = gather(chosen_rewards).mean().item()
    metrics[f"{prefix}rewards/rejected"] = gather(rejected_rewards).mean().item()
    metrics[f"{prefix}rewards/accuracies"] = gather(accuracies).mean().item()
    metrics[f"{prefix}rewards/margins"] = (
        gather(chosen_rewards - rejected_rewards).mean().item()
    )
    metrics[f"{prefix}logps/rejected"] = (
        gather(rejected_logps).detach().mean().item()
    )
    metrics[f"{prefix}logps/chosen"] = gather(chosen_logps).detach().mean().item()
    metrics[f"{prefix}logits/rejected"] = (
        gather(rejected_logits.detach().mean()).mean().item()
    )
    metrics[f"{prefix}logits/chosen"] = (
        gather(chosen_logits.detach().mean()).mean().item()
    )
    metrics[f"{prefix}nll_loss"] = gather(nll_loss).detach().mean().item()

    if self.aux_loss_enabled:
        loss += self.aux_loss_coef * aux_loss

    return loss, metrics
245+
98246

99247
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
100248
"""

0 commit comments

Comments
 (0)