@@ -5,20 +5,29 @@
 # mypy: ignore-errors
 import math
 from dataclasses import dataclass
+from typing import Optional
 
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
 from typing_extensions import Self
 
+from lit_llama.utils import find_multiple
+
+
 @dataclass
 class LLaMAConfig:
     block_size: int = 2048
     vocab_size: int = 32000
+    padded_vocab_size: Optional[int] = None
     n_layer: int = 32
     n_head: int = 32
     n_embd: int = 4096
 
+    def __post_init__(self):
+        if self.padded_vocab_size is None:
+            self.padded_vocab_size = find_multiple(self.vocab_size, 64)
+
     @classmethod
     def from_name(cls, name: str) -> Self:
         return cls(**llama_configs[name])
@@ -35,14 +44,13 @@ def from_name(cls, name: str) -> Self:
 class LLaMA(nn.Module):
     def __init__(self, config: LLaMAConfig) -> None:
         super().__init__()
-        assert config.vocab_size is not None
-        assert config.block_size is not None
+        assert config.padded_vocab_size is not None
         self.config = config
 
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
         self.transformer = nn.ModuleDict(
             dict(
-                wte=nn.Embedding(config.vocab_size, config.n_embd),
+                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                 ln_f=RMSNorm(config.n_embd),
             )
@@ -103,7 +111,7 @@ def __init__(self, config: LLaMAConfig) -> None:
         self.n_head = config.n_head
         self.n_embd = config.n_embd
         self.block_size = config.block_size
-        self.rope_cache = None
+        self.rope_cache: Optional[torch.Tensor] = None
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
@@ -150,9 +158,7 @@ def __init__(self, config: LLaMAConfig) -> None:
         super().__init__()
         hidden_dim = 4 * config.n_embd
         n_hidden = int(2 * hidden_dim / 3)
-        N = 256
-        # ensure n_hidden is multiple of N
-        n_hidden = ((n_hidden - 1) // N) * N + N
+        n_hidden = find_multiple(n_hidden, 256)
 
         self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
         self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
@@ -210,6 +216,7 @@ def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torc
         cache = cache.half()
     return cache
 
+
 def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
     x = x.transpose(1, 2)
 
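Note: the helper itself is not part of this diff; only its import from lit_llama.utils is shown. A minimal sketch of a round-up-to-multiple helper that matches how it is used here (an assumption, not the library's actual source) would be:

# Hypothetical sketch of find_multiple, assuming it rounds n up to the nearest
# multiple of k; the real lit_llama.utils.find_multiple may differ in details.
def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

assert find_multiple(32000, 64) == 32000  # default vocab_size is already 64-aligned
assert find_multiple(32001, 64) == 32064  # a non-aligned vocabulary would be padded up

With the default config, padded_vocab_size stays at 32000, so the Embedding and lm_head shapes only change for tokenizers whose vocabulary is not a multiple of 64.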
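The MLP hunk is a pure refactor: assuming the helper behaves as sketched above, the new call reproduces the removed rounding expression. For the default n_embd = 4096:

n_hidden = int(2 * (4 * 4096) / 3)         # 10922
old = ((n_hidden - 1) // 256) * 256 + 256  # 11008, removed expression
new = find_multiple(n_hidden, 256)         # 11008, new helper call
assert old == new == 11008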