
Commit 9471709

Merge pull request #4 from rajammanabrolu/classifier
Adding the ability to train classifiers.
2 parents a8db3c7 + b2373f2 commit 9471709

File tree

12 files changed (+416, -24 lines)

compose_rl/data/preference_data.py

Lines changed: 5 additions & 3 deletions
@@ -300,13 +300,15 @@ def __getitem__(self, idx: int) -> dict[str, Any]:
             idx (int): the index where we fetch the data in the StreamingDataset.
         """
         sample = super().__getitem__(idx)
-        text = self._read_binary_tokenized_sample(sample, 'text')
-        label = self._read_binary_tokenized_sample(sample, 'label')
+        text = self._read_binary_tokenized_sample(sample, 'input')
+        label = torch.from_numpy(np.frombuffer(sample['label'], dtype=np.uint8))
+        # This needs to be a float tensor for BCE
+        label = label.to(torch.float32)
 
         text_len = len(text)
 
         return {
             'text': text,
-            'label': label,
+            'labels': label,
             'text_len': torch.Tensor([text_len]).to(torch.int64),
         }
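The key change here is that classifier labels are stored as raw bytes rather than as a tokenized field, and must come back as float tensors for the BCE loss. A minimal sketch of that round trip (the sample dict is fabricated; the .copy() is added only to avoid the read-only-buffer warning torch emits on np.frombuffer output):

import numpy as np
import torch

# Fabricated sample: the label was serialized with np.asarray(...).tobytes()
# at dataset-creation time (see unified_tokenize_dataset.py below).
sample = {'label': np.asarray([1], dtype=np.uint8).tobytes()}

# Decode the bytes back into a tensor, as the updated __getitem__ does.
label = torch.from_numpy(np.frombuffer(sample['label'], dtype=np.uint8).copy())

# BCE-with-logits expects float targets, hence the cast before batching.
label = label.to(torch.float32)
print(label)  # tensor([1.])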

compose_rl/metrics/reward_model_metrics.py

Lines changed: 61 additions & 0 deletions
@@ -40,3 +40,64 @@ def compute(self):
         assert isinstance(self.correct, Tensor)
         assert isinstance(self.total, Tensor)
         return self.correct / self.total
+
+
+class BinaryRewardClassificationAccuracy(Metric):
+    """Binary classification accuracy metric.
+
+    Computes the accuracy of a classifier by thresholding sigmoid probabilities
+    derived from the logits and comparing them against ground-truth labels.
+    Multi-class support is still a TODO.
+    """
+
+    # Make torchmetrics call update only once
+    full_state_update = False
+
+    def __init__(
+        self,
+        threshold: float = 0.5,
+        dist_sync_on_step: bool = False,
+        **kwargs: Any,
+    ):
+        """Initialize the metric.
+
+        Args:
+            threshold: Decision threshold for binary classification.
+            dist_sync_on_step: Synchronize metric state across processes.
+        """
+        super().__init__(dist_sync_on_step=dist_sync_on_step)
+        self.threshold = threshold
+
+        self.add_state(
+            'correct',
+            default=torch.tensor(0.),
+            dist_reduce_fx='sum',
+        )
+        self.add_state('total', default=torch.tensor(0.), dist_reduce_fx='sum')
+
+    def update(self, batch: dict, output_logits: torch.Tensor):
+        """Update state with predictions and targets.
+
+        Args:
+            batch: Dictionary containing 'output_scores' and 'labels'.
+            output_logits: Unused; present for interface compatibility.
+        """
+        del output_logits
+        logits = batch['output_scores']
+        targets = batch['labels'].squeeze(-1)
+        assert logits.shape[0] == targets.shape[0], 'Batch sizes must match'
+
+        # TODO (raj): Handle multi-class classification with logging
+        probs = torch.sigmoid(logits.squeeze())
+        predictions = (probs > self.threshold).long()
+
+        self.correct += (predictions == targets).sum().detach().cpu()
+        self.total += targets.shape[0]
+
+    def compute(self):
+        """Compute the accuracy."""
+        assert isinstance(self.correct, Tensor)
+        assert isinstance(self.total, Tensor)
+        return self.correct / self.total
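The metric is normally driven by Composer during evaluation, but it can be exercised directly. A minimal sketch with fabricated tensors (the 'output_scores'/'labels' batch keys and the (batch, 1) shapes follow classifier_forward in model_methods.py below):

import torch
from compose_rl.metrics.reward_model_metrics import BinaryRewardClassificationAccuracy

metric = BinaryRewardClassificationAccuracy(threshold=0.5)

# Fabricated eval batch: raw scores (logits) and float labels, both of shape (batch, 1).
batch = {
    'output_scores': torch.tensor([[2.0], [-1.5], [-0.3]]),
    'labels': torch.tensor([[1.0], [0.0], [0.0]]),
}

metric.update(batch, output_logits=None)  # output_logits is unused and deleted
print(metric.compute())  # tensor(1.) -- all three thresholded predictions match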

compose_rl/reward_learning/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,7 @@
 )
 from compose_rl.reward_learning.inference_model import InferenceRewardModel
 from compose_rl.reward_learning.model import (
+    ComposerHFClassifierRewardModel,
     ComposerHFPairwiseRewardModel,
     ComposerMPTPairwiseRewardModel,
 )
@@ -32,6 +33,7 @@
     'RewardModel',
     'ComposerMPTPairwiseRewardModel',
     'ComposerHFPairwiseRewardModel',
+    'ComposerHFClassifierRewardModel',
     'InferenceRewardModel',
     'BadGenerationEndReward',
     'IncreasingNumbersReward',
compose_rl/reward_learning/model.py

Lines changed: 73 additions & 14 deletions
@@ -4,15 +4,18 @@
 """Reward Model Composer Implementation."""
 
 import logging
-from typing import Any, Mapping, MutableMapping, Optional, Union
+from typing import Any, Mapping, MutableMapping, Optional
 
 import torch
 from llmfoundry.models import ComposerMPTCausalLM
 
 from compose_rl.reward_learning.base_reward import RewardModel, Tokenizer
 from compose_rl.reward_learning.hf_utils import SequenceClassifierOutput
 from compose_rl.reward_learning.model_methods import (
+    ClassifierRewardEnum,
     PairwiseRewardEnum,
+    classifier_forward,
+    classifier_loss,
     pairwise_forward,
     pairwise_loss,
 )
@@ -62,24 +65,14 @@ def __init__(
             **kwargs,
         )
 
-    def forward(
-        self,
-        batch: MutableMapping,
-    ) -> Union[dict[str, torch.Tensor], torch.Tensor]:
+    def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]:
         is_inference = batch.get('is_inference', False)
         if is_inference:
-            scores = self.model(
+            return self.model(
                 input_ids=batch['input_ids'],
                 attention_mask=batch['attention_mask'],
                 return_lm_logits=self.return_lm_logits,
             ).scores
-            if self.min_threshold is not None and self.max_threshold is not None:
-                scores: torch.Tensor = torch.clamp(
-                    scores,
-                    min=self.min_threshold,
-                    max=self.max_threshold,
-                )
-            return scores
         else:
             return pairwise_forward(
                 model=self.model,
@@ -93,7 +86,7 @@ def eval_forward(
         self,
         batch: MutableMapping,
         outputs: Optional[SequenceClassifierOutput] = None,
-    ) -> Union[dict[str, torch.Tensor], torch.Tensor]:
+    ) -> dict[str, torch.Tensor]:
         return outputs if outputs is not None else self.forward(batch)
 
     def loss(self, outputs: SequenceClassifierOutput,
@@ -105,6 +98,72 @@ def loss(self, outputs: SequenceClassifierOutput,
         )
 
 
+class ComposerHFClassifierRewardModel(
+    ComposerHFSequenceClassification,
+    RewardModel,
+):
+
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        use_train_metrics: bool = True,
+        additional_train_metrics: Optional[list] = None,
+        additional_eval_metrics: Optional[list] = None,
+        loss_type: str = 'bce',
+        return_lm_logits: bool = False,
+        return_last: bool = True,
+        **kwargs: Any,
+    ):
+        self.loss_type = ClassifierRewardEnum(loss_type)
+        self.return_lm_logits = return_lm_logits
+        self.return_last = return_last
+
+        config_overrides = {
+            'return_logits': return_lm_logits,
+        }
+
+        if 'config_overrides' in kwargs:
+            config_overrides.update(kwargs.pop('config_overrides'))
+
+        self.min_threshold = kwargs.pop('min_threshold', None)
+        self.max_threshold = kwargs.pop('max_threshold', None)
+
+        super().__init__(
+            tokenizer=tokenizer,
+            use_train_metrics=use_train_metrics,
+            additional_train_metrics=additional_train_metrics,
+            additional_eval_metrics=additional_eval_metrics,
+            config_overrides=config_overrides,
+            **kwargs,
+        )
+
+    def forward(self, batch: MutableMapping) -> dict[str, torch.Tensor]:
+        ret_val = classifier_forward(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            batch=batch,
+            return_last=self.return_last,
+            return_lm_logits=self.return_lm_logits,
+        )
+
+        return ret_val
+
+    def eval_forward(
+        self,
+        batch: MutableMapping,
+        outputs: Optional[SequenceClassifierOutput] = None,
+    ) -> dict[str, torch.Tensor]:
+        return outputs if outputs is not None else self.forward(batch)
+
+    def loss(self, outputs: SequenceClassifierOutput,
+             batch: Mapping) -> dict[str, torch.Tensor]:
+        return classifier_loss(
+            outputs,
+            batch,
+            self.loss_type,
+        )
+
+
 class ComposerMPTPairwiseRewardModel(ComposerMPTCausalLM, RewardModel):
     """MPT model wrapper for Pairwise/BT reward model."""
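For orientation, here is a sketch of the batch this model's forward expects and the call chain through the new helpers. The tensors are fabricated; the key names come from preference_data.py and classifier_forward, while the collator that assembles 'text_attention_mask' is assumed rather than shown in this diff:

import torch

# Fabricated classifier batch (keys per preference_data.py / classifier_forward):
#   text:                (batch, seq_len) token ids
#   text_attention_mask: (batch, seq_len) mask, assumed to be built by the collator
#   text_len:            (batch, 1) true sequence lengths
#   labels:              (batch, 1) float targets for BCE
batch = {
    'text': torch.randint(0, 32_000, (2, 16)),
    'text_attention_mask': torch.ones(2, 16, dtype=torch.long),
    'text_len': torch.tensor([[16], [12]], dtype=torch.int64),
    'labels': torch.tensor([[1.0], [0.0]]),
}

# Hypothetical usage -- construction kwargs beyond those visible in this diff are omitted:
# model = ComposerHFClassifierRewardModel(tokenizer=tokenizer, loss_type='bce', ...)
# outputs = model(batch)                  # -> {'output_scores': (2, 1), 'labels': (2, 1)}
# loss_dict = model.loss(outputs, batch)  # -> {'total': BCE loss}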

compose_rl/reward_learning/model_methods.py

Lines changed: 69 additions & 0 deletions
@@ -34,6 +34,10 @@ class PairwiseRewardEnum(Enum):
     BELLMAN_EURUS = 'bellman_eurus'
 
 
+class ClassifierRewardEnum(Enum):
+    BCE = 'bce'
+
+
 def pairwise_forward(
     model: nn.Module,
     tokenizer: Tokenizer,
@@ -162,6 +166,40 @@ def pairwise_forward(
     return outputs
 
 
+def classifier_forward(
+    model: nn.Module,
+    tokenizer: Tokenizer,
+    batch: MutableMapping,
+    policy_model_config: Optional[PretrainedConfig] = None,
+    use_attention_sequence_id: bool = False,
+    return_last: bool = True,
+    return_lm_logits: bool = False,
+) -> dict[str, torch.Tensor]:
+
+    model_output = model(
+        batch['text'],
+        attention_mask=batch['text_attention_mask'],
+        return_lm_logits=return_lm_logits,
+    )
+
+    output_scores = model_output.scores
+    if return_last:
+        # Expected Shape: (Batch Size, 1)
+        output_scores = torch.gather(
+            output_scores,
+            dim=1,
+            index=batch['text_len'].view(-1, 1) - 1,
+        )
+
+    # We need to add the labels here to compute metrics
+    outputs: dict[str, torch.Tensor] = {
+        'output_scores': output_scores,
+        'labels': batch['labels'],
+    }
+
+    return outputs
+
+
 def pairwise_loss(
     outputs: SequenceClassifierOutput,
     batch: Mapping,
@@ -219,3 +257,34 @@ def pairwise_loss(
     loss_dict['total'] = losses
 
     return loss_dict
+
+
+def classifier_loss(
+    outputs: SequenceClassifierOutput,
+    batch: Mapping,
+    loss_type: ClassifierRewardEnum,
+) -> dict[str, torch.Tensor]:
+    """Computes the classifier loss.
+
+    Given precomputed values this will compute the specified classifier loss.
+
+    Args:
+        outputs (SequenceClassifierOutput): Outputs from forwarding the model over the batch.
+        batch (Mapping): Input batch of data.
+        loss_type (ClassifierRewardEnum): Loss type to compute (e.g. bce).
+    """
+    output_scores = outputs['output_scores']
+
+    if loss_type == ClassifierRewardEnum.BCE:
+        loss = F.binary_cross_entropy_with_logits(
+            output_scores,
+            batch['labels'],
+        )
+    else:
+        raise NotImplementedError(f'Loss type: {loss_type} is not supported.')
+
+    loss_dict = {
+        'total': loss,
+    }
+
+    return loss_dict
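To make the return_last indexing and the BCE loss concrete, here is a minimal sketch with fabricated tensors. It assumes model_output.scores is a (batch, seq_len) per-token score tensor, which is what the torch.gather call above implies:

import torch
import torch.nn.functional as F

# Fabricated per-token reward scores, shape (batch, seq_len).
per_token_scores = torch.tensor([
    [0.1, 0.4, 2.0, 0.0],    # true length 3 -> last real token at index 2
    [-0.2, -1.5, 0.0, 0.0],  # true length 2 -> last real token at index 1
])
text_len = torch.tensor([[3], [2]], dtype=torch.int64)

# return_last=True: take the score at the final non-padding position -> shape (batch, 1).
output_scores = torch.gather(per_token_scores, dim=1, index=text_len.view(-1, 1) - 1)
# tensor([[ 2.0000], [-1.5000]])

# classifier_loss with ClassifierRewardEnum.BCE: float labels of shape (batch, 1).
labels = torch.tensor([[1.0], [0.0]])
loss = F.binary_cross_entropy_with_logits(output_scores, labels)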

pyproject.toml

Lines changed: 2 additions & 0 deletions
@@ -38,6 +38,7 @@ mpt_dpo_lm = "compose_rl.dpo:ComposerDPOLM"
 hf_dpo_lm = "compose_rl.dpo:ComposerHFDPOLM"
 mpt_pairwise_rm = "compose_rl.reward_learning:ComposerMPTPairwiseRewardModel"
 hf_pairwise_rm = "compose_rl.reward_learning:ComposerHFPairwiseRewardModel"
+hf_classifier_rm = "compose_rl.reward_learning:ComposerHFClassifierRewardModel"
 mpt_ppo_lm = "compose_rl.ppo:ComposerMosaicPolicy"
 hf_ppo_lm = "compose_rl.ppo:ComposerHFPolicyModel"
 
@@ -52,6 +53,7 @@ ppo = "compose_rl.ppo:PPOCallback"
 
 [project.entry-points."llmfoundry_metrics"]
 pairwise_rm_accuracy = "compose_rl.metrics.reward_model_metrics:PairwiseRewardClassificationAccuracy"
+classifier_accuracy = "compose_rl.metrics.reward_model_metrics:BinaryRewardClassificationAccuracy"
 
 # iSort
 [tool.isort]
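These entry points are how an llm-foundry-based training config can refer to the new model (hf_classifier_rm) and metric (classifier_accuracy) by name. A minimal sketch of inspecting the registered metrics at runtime, using only the standard library; it assumes Python 3.10+ and that compose-rl is installed:

from importlib.metadata import entry_points

# The group name matches the [project.entry-points."llmfoundry_metrics"] table above.
for ep in entry_points(group='llmfoundry_metrics'):
    print(ep.name, '->', ep.value)
# e.g. classifier_accuracy -> compose_rl.metrics.reward_model_metrics:BinaryRewardClassificationAccuracy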

scripts/data/unified_tokenize_dataset.py

Lines changed: 29 additions & 1 deletion
@@ -56,6 +56,8 @@ def __iter__(self) -> Iterator[dict[str, bytes]]:
                 result = self._process_single_prompt_sample(sample)
                 if result is not None:
                     yield result
+            elif self.dataset_type == 'classifier':
+                yield self._process_classifier_sample(sample)
 
     def _process_preference_sample(self, sample: Any):
         """Process a preference sample.
@@ -104,6 +106,28 @@ def _process_single_prompt_sample(self, sample: Any):
 
         return {'prompt': np.asarray(encoded_prompt).tobytes()}
 
+    def _process_classifier_sample(self, sample: Any):
+        """Process a dummy classifier sample.
+
+        Args:
+            sample (Any): a sample from the dataset
+        """
+        messages = [{
+            'role': 'user',
+            'content': 'This is a test',
+        }]
+        encoded_prompt = self.tokenizer.apply_chat_template(
+            messages,
+            tokenize=True,
+        )
+
+        label = np.random.randint(0, 2, size=(1,))
+
+        return {
+            'input': np.asarray(encoded_prompt).tobytes(),
+            'label': np.asarray(label).tobytes(),
+        }
+
 
 def main(
     dataset_name: str,
@@ -123,6 +147,10 @@ def main(
         'single_prompt': {
            'prompt': 'bytes',
         },
+        'classifier': {
+            'input': 'bytes',
+            'label': 'bytes',
+        },
     }[dataset_type]
 
     tokenizer = AutoTokenizer.from_pretrained(
@@ -185,7 +213,7 @@
     parser.add_argument(
         '--dataset_type',
         type=str,
-        choices=['preference', 'single_prompt'],
+        choices=['preference', 'single_prompt', 'classifier'],
        required=True,
         help='Type of dataset to process',
     )
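A short sketch of the serialize/deserialize round trip for these classifier samples: the writer side mirrors _process_classifier_sample, the reader side mirrors the updated __getitem__ in preference_data.py. The explicit uint8 cast is an addition here so the stored bytes match the dtype=np.uint8 read; the dummy processor above serializes whatever integer dtype np.random.randint returns by default:

import numpy as np
import torch

# Writer side: fabricated token ids plus a 0/1 label, both stored as raw bytes.
encoded_prompt = [101, 2023, 2003, 1037, 3231, 102]
label = np.random.randint(0, 2, size=(1,)).astype(np.uint8)  # cast added for consistency
sample = {
    'input': np.asarray(encoded_prompt).tobytes(),
    'label': label.tobytes(),
}

# Reader side: recover the label as a float tensor, ready for BCE.
decoded = torch.from_numpy(np.frombuffer(sample['label'], dtype=np.uint8).copy())
decoded = decoded.to(torch.float32)
print(decoded.shape)  # torch.Size([1])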
