Commit 8875f4c

Automatic vocabulary size padding (Lightning-AI#223)
1 parent ef49654 commit 8875f4c

File tree

3 files changed: +30 -8 lines changed

  lit_llama/model.py   +15 -8
  lit_llama/utils.py   +6
  tests/test_utils.py  +9

lit_llama/model.py

+15 -8

@@ -5,20 +5,29 @@
 # mypy: ignore-errors
 import math
 from dataclasses import dataclass
+from typing import Optional

 import torch
 import torch.nn as nn
 from torch.nn import functional as F
 from typing_extensions import Self

+from lit_llama.utils import find_multiple
+
+
 @dataclass
 class LLaMAConfig:
     block_size: int = 2048
     vocab_size: int = 32000
+    padded_vocab_size: Optional[int] = None
     n_layer: int = 32
     n_head: int = 32
     n_embd: int = 4096

+    def __post_init__(self):
+        if self.padded_vocab_size is None:
+            self.padded_vocab_size = find_multiple(self.vocab_size, 64)
+
     @classmethod
     def from_name(cls, name: str) -> Self:
         return cls(**llama_configs[name])
@@ -35,14 +44,13 @@ def from_name(cls, name: str) -> Self:
 class LLaMA(nn.Module):
     def __init__(self, config: LLaMAConfig) -> None:
         super().__init__()
-        assert config.vocab_size is not None
-        assert config.block_size is not None
+        assert config.padded_vocab_size is not None
         self.config = config

-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+        self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False)
         self.transformer = nn.ModuleDict(
             dict(
-                wte=nn.Embedding(config.vocab_size, config.n_embd),
+                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                 h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                 ln_f=RMSNorm(config.n_embd),
             )
@@ -103,7 +111,7 @@ def __init__(self, config: LLaMAConfig) -> None:
         self.n_head = config.n_head
         self.n_embd = config.n_embd
         self.block_size = config.block_size
-        self.rope_cache = None
+        self.rope_cache: Optional[torch.Tensor] = None

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
@@ -150,9 +158,7 @@ def __init__(self, config: LLaMAConfig) -> None:
         super().__init__()
         hidden_dim = 4 * config.n_embd
         n_hidden = int(2 * hidden_dim / 3)
-        N = 256
-        # ensure n_hidden is multiple of N
-        n_hidden = ((n_hidden - 1) // N) * N + N
+        n_hidden = find_multiple(n_hidden, 256)

         self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
         self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
@@ -210,6 +216,7 @@ def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torc
         cache = cache.half()
     return cache

+
 def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
     x = x.transpose(1, 2)


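Taken together, the model.py changes size both the token embedding and the lm_head to padded_vocab_size, which __post_init__ rounds up to the next multiple of 64. A minimal sketch of that behaviour, keeping only the padding-related fields and inlining find_multiple so the snippet is self-contained (32001 is an illustrative, non-default vocabulary size):

from dataclasses import dataclass
from typing import Optional


def find_multiple(n: int, k: int) -> int:
    # round n up to the nearest multiple of k (same helper as in lit_llama/utils.py)
    if n % k == 0:
        return n
    return n + k - (n % k)


@dataclass
class LLaMAConfig:
    vocab_size: int = 32000
    padded_vocab_size: Optional[int] = None

    def __post_init__(self):
        if self.padded_vocab_size is None:
            self.padded_vocab_size = find_multiple(self.vocab_size, 64)


# 32000 is already a multiple of 64, so the default config is unchanged
assert LLaMAConfig().padded_vocab_size == 32000
# an odd tokenizer size gets rounded up: 32001 -> 32064
assert LLaMAConfig(vocab_size=32001).padded_vocab_size == 32064
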
lit_llama/utils.py

+6

@@ -31,6 +31,12 @@ def llama_model_lookup(checkpoint: dict) -> str:
     return llama_model_sizes[embedding_size]


+def find_multiple(n: int, k: int) -> int:
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
 def save_model_checkpoint(fabric, model, file_path):
     """Handles boilerplate logic for retrieving and saving the state_dict.


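Why the formula works: when n is not already a multiple of k, n % k is how far n overshoots the previous multiple, so adding k - (n % k) lands exactly on the next one. A quick worked check, assuming the helper is importable as added above:

from lit_llama.utils import find_multiple

# 17 % 5 == 2, so 17 + (5 - 2) == 20, the next multiple of 5
assert find_multiple(17, 5) == 20
# 32000 is already a multiple of 64, so it is returned unchanged
assert find_multiple(32000, 64) == 32000
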
tests/test_utils.py

+9

@@ -45,3 +45,12 @@ def test_lazy_load_subclass(lit_llama):
         actual = sd_lazy[k]
         expected = sd[k]
         torch.testing.assert_close(actual._load_tensor(), expected)
+
+
+def test_find_multiple(lit_llama):
+    from lit_llama.utils import find_multiple
+
+    assert find_multiple(17, 5) == 20
+    assert find_multiple(30, 7) == 35
+    assert find_multiple(10, 2) == 10
+    assert find_multiple(5, 10) == 10

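The new test covers both branches of find_multiple, including the k > n case (5 rounds up to 10). It can be run on its own with `pytest tests/test_utils.py -k test_find_multiple`, assuming pytest is installed and the repository's test conftest provides the lit_llama fixture used in the signature.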