@@ -270,38 +270,47 @@ def __init__(
     ) -> None:
         super().__init__()
         self.config = config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
         self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
+        self.rope_theta = getattr(config, "rope_theta", 10000)
+        self.rope_scaling = getattr(config, "rope_scaling", None)
+        self.max_position_embeddings = getattr(config,
+                                               "max_position_embeddings", 8192)
+        self._init_attn_block()
+        self._init_ffn_block()
+
+    def _init_attn_block(self):
+        self.input_layernorm = RMSNorm(self.config.hidden_size,
+                                       eps=self.config.rms_norm_eps)
         self.self_attn = MiniCPMAttention(
             hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
-            quant_config=quant_config,
+            num_heads=self.config.num_attention_heads,
+            num_kv_heads=self.config.num_key_value_heads,
+            rope_theta=self.rope_theta,
+            rope_scaling=self.rope_scaling,
+            max_position_embeddings=self.max_position_embeddings,
+            cache_config=self.cache_config,
+            quant_config=self.quant_config,
         )
+
+    def _init_ffn_block(self):
+        self.post_attention_layernorm = RMSNorm(self.config.hidden_size,
+                                                eps=self.config.rms_norm_eps)
         self.num_experts = getattr(self.config, "num_experts", 0)
         if self.num_experts == 0:
             self.mlp = MiniCPMMLP(
                 hidden_size=self.hidden_size,
-                intermediate_size=config.intermediate_size,
-                hidden_act=config.hidden_act,
-                quant_config=quant_config,
+                intermediate_size=self.config.intermediate_size,
+                hidden_act=self.config.hidden_act,
+                quant_config=self.quant_config,
             )
         else:
-            self.mlp = MiniCPMMoE(num_experts=config.num_experts,
-                                  top_k=config.num_experts_per_tok,
-                                  hidden_size=config.hidden_size,
-                                  intermediate_size=config.intermediate_size)
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
+            self.mlp = MiniCPMMoE(
+                num_experts=self.config.num_experts,
+                top_k=self.config.num_experts_per_tok,
+                hidden_size=self.config.hidden_size,
+                intermediate_size=self.config.intermediate_size)
 
     def forward(
         self,
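
Splitting the constructor into `_init_attn_block` and `_init_ffn_block` turns these sections into overridable hooks, so a derived decoder layer can swap out the attention module without re-implementing the rest of `__init__`. A minimal sketch of that usage, assuming a hypothetical `CustomAttention` module with the same constructor signature as `MiniCPMAttention` (the names `CustomDecoderLayer` and `CustomAttention` are illustrative, not part of this change):

class CustomDecoderLayer(MiniCPMDecoderLayer):
    # Hypothetical subclass: only the attention block is replaced; the FFN
    # block and the rest of __init__ are inherited unchanged.
    def _init_attn_block(self):
        self.input_layernorm = RMSNorm(self.config.hidden_size,
                                       eps=self.config.rms_norm_eps)
        self.self_attn = CustomAttention(  # assumed drop-in replacement
            hidden_size=self.hidden_size,
            num_heads=self.config.num_attention_heads,
            num_kv_heads=self.config.num_key_value_heads,
            rope_theta=self.rope_theta,
            rope_scaling=self.rope_scaling,
            max_position_embeddings=self.max_position_embeddings,
            cache_config=self.cache_config,
            quant_config=self.quant_config,
        )
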
@@ -344,6 +353,8 @@ def __init__(
     ) -> None:
         super().__init__()
         self.config = config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
         self.padding_idx = config.pad_token_id
         lora_vocab = (lora_config.lora_extra_vocab_size *
                       (lora_config.max_loras or 1)) if lora_config else 0
@@ -354,11 +365,15 @@ def __init__(
             config.hidden_size,
             org_num_embeddings=config.vocab_size,
         )
+        self._init_layers()
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def _init_layers(self):
         self.layers = nn.ModuleList([
-            MiniCPMDecoderLayer(config, cache_config, quant_config)
-            for _ in range(config.num_hidden_layers)
+            MiniCPMDecoderLayer(self.config, self.cache_config,
+                                self.quant_config)
+            for _ in range(self.config.num_hidden_layers)
         ])
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         embedding = self.embed_tokens(input_ids)
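
Similarly, `_init_layers` isolates construction of the decoder stack, so a model variant only has to override this one method to build the stack from a different layer class. A hedged sketch under the same assumptions as above (`CustomDecoderLayer` is the hypothetical subclass from the previous example):

class CustomModel(MiniCPMModel):
    # Hypothetical subclass: builds the stack from the custom decoder layer;
    # embeddings, the final norm, and forward() are inherited unchanged.
    def _init_layers(self):
        self.layers = nn.ModuleList([
            CustomDecoderLayer(self.config, self.cache_config,
                               self.quant_config)
            for _ in range(self.config.num_hidden_layers)
        ])
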
@@ -431,13 +446,11 @@ def __init__(
 
         self.config = config
         self.lora_config = lora_config
+        self.cache_config = cache_config
+        self.quant_config = quant_config
 
         self.num_experts = getattr(self.config, "num_experts", 0)
-        self.quant_config = quant_config
-        self.model = MiniCPMModel(config,
-                                  cache_config,
-                                  quant_config,
-                                  lora_config=lora_config)
+        self._init_model()
         unpadded_vocab_size = config.vocab_size
         if lora_config:
             unpadded_vocab_size += lora_config.lora_extra_vocab_size
@@ -458,6 +471,12 @@ def __init__(
                                                 config.vocab_size)
         self.sampler = Sampler()
 
+    def _init_model(self):
+        self.model = MiniCPMModel(config=self.config,
+                                  cache_config=self.cache_config,
+                                  quant_config=self.quant_config,
+                                  lora_config=self.lora_config)
+
     def forward(
         self,
         input_ids: torch.Tensor,
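
Finally, `_init_model` gives the top-level class a single override point for the backbone, so a derived causal-LM class can reuse the lm_head, logits-processor, and sampler setup unchanged. A sketch, again with hypothetical names carried over from the examples above:

class CustomForCausalLM(MiniCPMForCausalLM):
    # Hypothetical subclass: plugs in the custom backbone defined above.
    def _init_model(self):
        self.model = CustomModel(config=self.config,
                                 cache_config=self.cache_config,
                                 quant_config=self.quant_config,
                                 lora_config=self.lora_config)
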