1
1
"""Module for TRL PPO trainer"""
2
2
3
+ from typing import Literal , Union
4
+
3
5
import torch
4
6
from tqdm import tqdm
5
7
from trl import (
@@ -79,6 +81,78 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
79
81
80
82
tag_names = ["axolotl" , "orpo" ]
81
83
84
+ def get_batch_loss_metrics (
85
+ self ,
86
+ model ,
87
+ batch : dict [str , Union [list , torch .LongTensor ]],
88
+ train_eval : Literal ["train" , "eval" ] = "train" ,
89
+ ):
90
+ """Compute the ORPO loss and other metrics for the given batch of inputs for train or test."""
91
+
92
+ # TODO remove once https://github.com/huggingface/trl/pull/3069 is included in a trl release
93
+
94
+ metrics = {}
95
+
96
+ forward_output = self .concatenated_forward (model , batch )
97
+ (
98
+ policy_chosen_logps ,
99
+ policy_rejected_logps ,
100
+ policy_chosen_logits ,
101
+ policy_rejected_logits ,
102
+ policy_nll_loss ,
103
+ ) = forward_output [:5 ]
104
+ if self .aux_loss_enabled :
105
+ aux_loss = forward_output [5 ]
106
+
107
+ losses , chosen_rewards , rejected_rewards , log_odds_ratio , log_odds_chosen = (
108
+ self .odds_ratio_loss (policy_chosen_logps , policy_rejected_logps )
109
+ )
110
+ # full ORPO loss
111
+ loss = policy_nll_loss - losses .mean ()
112
+
113
+ reward_accuracies = (chosen_rewards > rejected_rewards ).float ()
114
+
115
+ prefix = "eval_" if train_eval == "eval" else ""
116
+ metrics [f"{ prefix } rewards/chosen" ] = self .accelerator .gather_for_metrics (
117
+ chosen_rewards
118
+ ).mean ()
119
+ metrics [f"{ prefix } rewards/rejected" ] = self .accelerator .gather_for_metrics (
120
+ rejected_rewards
121
+ ).mean ()
122
+ metrics [f"{ prefix } rewards/accuracies" ] = self .accelerator .gather_for_metrics (
123
+ reward_accuracies
124
+ ).mean ()
125
+ metrics [f"{ prefix } rewards/margins" ] = self .accelerator .gather_for_metrics (
126
+ chosen_rewards - rejected_rewards
127
+ ).mean ()
128
+ metrics [f"{ prefix } logps/rejected" ] = (
129
+ self .accelerator .gather_for_metrics (policy_rejected_logps ).detach ().mean ()
130
+ )
131
+ metrics [f"{ prefix } logps/chosen" ] = (
132
+ self .accelerator .gather_for_metrics (policy_chosen_logps ).detach ().mean ()
133
+ )
134
+ metrics [f"{ prefix } logits/rejected" ] = self .accelerator .gather_for_metrics (
135
+ policy_rejected_logits .detach ().mean ()
136
+ ).mean ()
137
+ metrics [f"{ prefix } logits/chosen" ] = self .accelerator .gather_for_metrics (
138
+ policy_chosen_logits .detach ().mean ()
139
+ ).mean ()
140
+ metrics [f"{ prefix } nll_loss" ] = (
141
+ self .accelerator .gather_for_metrics (policy_nll_loss ).detach ().mean ()
142
+ )
143
+ metrics [f"{ prefix } log_odds_ratio" ] = (
144
+ self .accelerator .gather_for_metrics (log_odds_ratio ).detach ().mean ()
145
+ )
146
+ metrics [f"{ prefix } log_odds_chosen" ] = (
147
+ self .accelerator .gather_for_metrics (log_odds_chosen ).detach ().mean ()
148
+ )
149
+ for k , v in metrics .items ():
150
+ metrics [k ] = v .item ()
151
+ if self .aux_loss_enabled :
152
+ loss += self .aux_loss_coef * aux_loss
153
+
154
+ return loss , metrics
155
+
82
156
83
157
class AxolotlKTOTrainer (SchedulerMixin , KTOTrainer ):
84
158
"""
@@ -95,6 +169,80 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
95
169
96
170
tag_names = ["axolotl" , "cpo" ]
97
171
172
+ def get_batch_loss_metrics (
173
+ self ,
174
+ model ,
175
+ batch : dict [str , Union [list , torch .LongTensor ]],
176
+ train_eval : Literal ["train" , "eval" ] = "train" ,
177
+ ):
178
+ """Compute the CPO loss and other metrics for the given batch of inputs for train or test."""
179
+ metrics = {}
180
+
181
+ forward_output = self .concatenated_forward (model , batch )
182
+ (
183
+ policy_chosen_logps ,
184
+ policy_rejected_logps ,
185
+ policy_chosen_logits ,
186
+ policy_rejected_logits ,
187
+ policy_nll_loss ,
188
+ ) = forward_output [:5 ]
189
+ if self .aux_loss_enabled :
190
+ aux_loss = forward_output [5 ]
191
+
192
+ losses , chosen_rewards , rejected_rewards = self .cpo_loss (
193
+ policy_chosen_logps ,
194
+ policy_rejected_logps ,
195
+ )
196
+
197
+ loss = losses .mean () + self .cpo_alpha * policy_nll_loss
198
+ reward_accuracies = (chosen_rewards > rejected_rewards ).float ()
199
+
200
+ prefix = "eval_" if train_eval == "eval" else ""
201
+ metrics [f"{ prefix } rewards/chosen" ] = (
202
+ self .accelerator .gather_for_metrics (chosen_rewards ).mean ().item ()
203
+ )
204
+ metrics [f"{ prefix } rewards/rejected" ] = (
205
+ self .accelerator .gather_for_metrics (rejected_rewards ).mean ().item ()
206
+ )
207
+ metrics [f"{ prefix } rewards/accuracies" ] = (
208
+ self .accelerator .gather_for_metrics (reward_accuracies ).mean ().item ()
209
+ )
210
+ metrics [f"{ prefix } rewards/margins" ] = (
211
+ self .accelerator .gather_for_metrics (chosen_rewards - rejected_rewards )
212
+ .mean ()
213
+ .item ()
214
+ )
215
+ metrics [f"{ prefix } logps/rejected" ] = (
216
+ self .accelerator .gather_for_metrics (policy_rejected_logps )
217
+ .detach ()
218
+ .mean ()
219
+ .item ()
220
+ )
221
+ metrics [f"{ prefix } logps/chosen" ] = (
222
+ self .accelerator .gather_for_metrics (policy_chosen_logps )
223
+ .detach ()
224
+ .mean ()
225
+ .item ()
226
+ )
227
+ metrics [f"{ prefix } logits/rejected" ] = (
228
+ self .accelerator .gather_for_metrics (policy_rejected_logits .detach ().mean ())
229
+ .mean ()
230
+ .item ()
231
+ )
232
+ metrics [f"{ prefix } logits/chosen" ] = (
233
+ self .accelerator .gather_for_metrics (policy_chosen_logits .detach ().mean ())
234
+ .mean ()
235
+ .item ()
236
+ )
237
+ metrics [f"{ prefix } nll_loss" ] = (
238
+ self .accelerator .gather_for_metrics (policy_nll_loss ).detach ().mean ().item ()
239
+ )
240
+
241
+ if self .aux_loss_enabled :
242
+ loss += self .aux_loss_coef * aux_loss
243
+
244
+ return loss , metrics
245
+
98
246
99
247
class AxolotlRewardTrainer (SchedulerMixin , RewardTrainer ):
100
248
"""
0 commit comments