
Commit b9cb834

robertgshaw2-redhat authored and LeiWang1999 committed
[ Misc ] Remove separate bias add (vllm-project#6353)
Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent 8732503 commit b9cb834


vllm/model_executor/layers/linear.py
Lines changed: 3 additions & 15 deletions
@@ -109,15 +109,7 @@ def apply(self,
 
 
 class UnquantizedLinearMethod(LinearMethodBase):
-    """Linear method without quantization.
-
-    Args:
-        separate_bias_add: If true, add bias separately after matrix
-                           multiplication.
-    """
-
-    def __init__(self, separate_bias_add: bool = False):
-        self.separate_bias_add = separate_bias_add
+    """Linear method without quantization."""
 
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
@@ -136,12 +128,8 @@ def apply(self,
               layer: torch.nn.Module,
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        weight = layer.weight
-        if self.separate_bias_add:
-            if bias is not None:
-                return F.linear(x, weight) + bias
-            return F.linear(x, weight)
-        return F.linear(x, weight, bias)
+
+        return F.linear(x, layer.weight, bias)
 
 
 class LinearBase(torch.nn.Module):
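
For context (not part of the commit page itself), here is a minimal sketch of why the separate bias-add branch was safe to drop: torch.nn.functional.linear already applies the bias inside the same call, so the fused form matches a matmul followed by an explicit bias add. The tensor shapes below are made up for illustration.

# Sketch (not from the commit): verify that F.linear with a bias argument
# equals a matmul followed by a separate bias add, i.e. the removed
# separate_bias_add path was redundant. Shapes are illustrative.
import torch
import torch.nn.functional as F

x = torch.randn(4, 16)        # 4 inputs with 16 features each
weight = torch.randn(32, 16)  # out_features x in_features, as F.linear expects
bias = torch.randn(32)

fused = F.linear(x, weight, bias)      # single call, the path kept by this commit
separate = F.linear(x, weight) + bias  # the removed separate-bias-add path

assert torch.allclose(fused, separate, atol=1e-6)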
