@@ -44,7 +44,7 @@ def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None:
44
44
self .model = model_cls (self .config .model , * args , ** kwargs )
45
45
self .fc = nn .Linear (config .model .hidden_size * 2 ,
46
46
config .model .hidden_size ,
47
- bias = False )
47
+ bias = getattr ( self . config , "bias" , False ) )
48
48
49
49
self .orig_vocab_size = config .vocab_size
50
50
self .truncated_vocab_size = config .truncated_vocab_size
@@ -136,10 +136,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
136
136
if self .config .truncated_vocab_size < self .config .vocab_size :
137
137
self .token_map = nn .Parameter (loaded_weight ,
138
138
requires_grad = False )
139
- elif name .startswith ("fc." ):
139
+ elif name .startswith ("fc.weight" ):
140
140
weight_loader = getattr (self .fc .weight , "weight_loader" ,
141
141
default_weight_loader )
142
142
weight_loader (self .fc .weight , loaded_weight )
143
+ elif name .startswith ("fc.bias" ):
144
+ if self .fc .bias is not None :
145
+ weight_loader = getattr (self .fc .bias , "weight_loader" ,
146
+ default_weight_loader )
147
+ weight_loader (self .fc .bias , loaded_weight )
148
+ else :
149
+ raise ValueError ("Found bias in the loaded weights "
150
+ "but the model config doesn't have bias" )
143
151
elif name .startswith ("model.lm_head." ) or name .startswith (
144
152
"model.model." ):
145
153
model_weights [name .split ("model." , 1 )[- 1 ]] = loaded_weight
0 commit comments