Skip to content

Commit ea8a389

Browse files
sohamparikh and sumitd2
authored and committed
[Bugfix] load fc bias from config for eagle (vllm-project#8790)
Signed-off-by: Sumit Dubey <sumit.dubey2@ibm.com>
1 parent 01fc989 commit ea8a389

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

vllm/model_executor/models/eagle.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None:
4444
self.model = model_cls(self.config.model, *args, **kwargs)
4545
self.fc = nn.Linear(config.model.hidden_size * 2,
4646
config.model.hidden_size,
47-
bias=False)
47+
bias=getattr(self.config, "bias", False))
4848

4949
self.orig_vocab_size = config.vocab_size
5050
self.truncated_vocab_size = config.truncated_vocab_size
@@ -136,10 +136,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
136136
if self.config.truncated_vocab_size < self.config.vocab_size:
137137
self.token_map = nn.Parameter(loaded_weight,
138138
requires_grad=False)
139-
elif name.startswith("fc."):
139+
elif name.startswith("fc.weight"):
140140
weight_loader = getattr(self.fc.weight, "weight_loader",
141141
default_weight_loader)
142142
weight_loader(self.fc.weight, loaded_weight)
143+
elif name.startswith("fc.bias"):
144+
if self.fc.bias is not None:
145+
weight_loader = getattr(self.fc.bias, "weight_loader",
146+
default_weight_loader)
147+
weight_loader(self.fc.bias, loaded_weight)
148+
else:
149+
raise ValueError("Found bias in the loaded weights "
150+
"but the model config doesn't have bias")
143151
elif name.startswith("model.lm_head.") or name.startswith(
144152
"model.model."):
145153
model_weights[name.split("model.", 1)[-1]] = loaded_weight

0 commit comments

Comments (0)