|
14 | 14 | from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
15 | 15 | MergedColumnParallelLinear,
|
16 | 16 | QKVParallelLinear,
|
| 17 | + ReplicatedLinear, |
17 | 18 | RowParallelLinear)
|
18 | 19 | from vllm.model_executor.layers.quantization import QuantizationConfig
|
19 | 20 | from vllm.model_executor.layers.rotary_embedding import get_rope
|
@@ -188,17 +189,14 @@ class NomicRouter(nn.Module):
|
188 | 189 | def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int):
|
189 | 190 | super().__init__()
|
190 | 191 | self.moe_top_k = moe_top_k
|
191 |
| - |
192 |
| - self.layer = nn.Linear(hidden_size, moe_num_experts, bias=False) |
| 192 | + self.layer = ReplicatedLinear(hidden_size, moe_num_experts, bias=False) |
193 | 193 |
|
194 | 194 | def forward(
|
195 | 195 | self, x: torch.Tensor
|
196 | 196 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
|
197 |
| - weights = self.layer(x.view(-1, |
198 |
| - x.shape[-1])).softmax(dim=-1, |
199 |
| - dtype=torch.float32) |
| 197 | + weights = self.layer(x.view(-1, x.shape[-1]))[0].softmax( |
| 198 | + dim=-1, dtype=torch.float32) |
200 | 199 | top_weights, top_experts = torch.topk(weights, self.moe_top_k, dim=-1)
|
201 |
| - |
202 | 200 | weights = weights.to(x.dtype)
|
203 | 201 | top_weights = top_weights.to(x.dtype)
|
204 | 202 | return weights, top_weights, top_experts # type: ignore
|
@@ -293,7 +291,6 @@ def __init__(self, config: PretrainedConfig):
|
293 | 291 | def forward(self, x: torch.Tensor):
|
294 | 292 | weights, top_weights, top_experts = self.router(x)
|
295 | 293 | out = self.experts(x, weights, top_weights, top_experts)
|
296 |
| - |
297 | 294 | return out
|
298 | 295 |
|
299 | 296 |
|
@@ -517,5 +514,6 @@ def _build_model(self,
|
517 | 514 | bias=config.qkv_proj_bias,
|
518 | 515 | rotary_kwargs=rotary_kwargs)
|
519 | 516 |
|
520 |
| - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): |
521 |
| - self.model.load_weights(weights) |
| 517 | + def load_weights(self, weights: Iterable[Tuple[str, |
| 518 | + torch.Tensor]]) -> Set[str]: |
| 519 | + return self.model.load_weights(weights) |
0 commit comments