Skip to content

Commit 7562c04

Browse files
committed
Use sigmoid for single-label classification
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
1 parent 66e63e8 commit 7562c04

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

vllm/model_executor/layers/pooler.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,16 @@ def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
242242

243243
if self.softmax:
244244
if isinstance(pooled_data, list):
245-
pooled_data = [F.softmax(data, dim=-1) for data in pooled_data]
245+
pooled_data = [
246+
F.softmax(data, dim=-1)
247+
if data.shape[-1] >= 2 else F.sigmoid(data)
248+
for data in pooled_data
249+
]
246250
else:
247-
pooled_data = F.softmax(pooled_data, dim=-1)
251+
if pooled_data.shape[-1] >= 2:
252+
pooled_data = F.softmax(pooled_data, dim=-1)
253+
else:
254+
pooled_data = F.sigmoid(pooled_data)
248255

249256
return pooled_data
250257

0 commit comments

Comments
 (0)