Skip to content

Commit 5b95dd3

Browse files
committed
fix
1 parent 2cae829 commit 5b95dd3

File tree

1 file changed

+7
-9
lines changed

1 file changed

+7
-9
lines changed

vllm/model_executor/models/nomic.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
1515
MergedColumnParallelLinear,
1616
QKVParallelLinear,
17+
ReplicatedLinear,
1718
RowParallelLinear)
1819
from vllm.model_executor.layers.quantization import QuantizationConfig
1920
from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -188,17 +189,14 @@ class NomicRouter(nn.Module):
188189
def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int):
189190
super().__init__()
190191
self.moe_top_k = moe_top_k
191-
192-
self.layer = nn.Linear(hidden_size, moe_num_experts, bias=False)
192+
self.layer = ReplicatedLinear(hidden_size, moe_num_experts, bias=False)
193193

194194
def forward(
195195
self, x: torch.Tensor
196196
) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
197-
weights = self.layer(x.view(-1,
198-
x.shape[-1])).softmax(dim=-1,
199-
dtype=torch.float32)
197+
weights = self.layer(x.view(-1, x.shape[-1]))[0].softmax(
198+
dim=-1, dtype=torch.float32)
200199
top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
201-
202200
weights = weights.to(x.dtype)
203201
top_weights = top_weights.to(x.dtype)
204202
return weights, top_weights, top_experts # type: ignore
@@ -293,7 +291,6 @@ def __init__(self, config: PretrainedConfig):
293291
def forward(self, x: torch.Tensor):
294292
weights, top_weights, top_experts = self.router(x)
295293
out = self.experts(x, weights, top_weights, top_experts)
296-
297294
return out
298295

299296

@@ -517,5 +514,6 @@ def _build_model(self,
517514
bias=config.qkv_proj_bias,
518515
rotary_kwargs=rotary_kwargs)
519516

520-
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
521-
self.model.load_weights(weights)
517+
def load_weights(self, weights: Iterable[Tuple[str,
518+
torch.Tensor]]) -> Set[str]:
519+
return self.model.load_weights(weights)

0 commit comments

Comments
 (0)