Commit 7b26d35: Add SmolLM2 (#1848)

Co-authored-by: Andrei-Aksionov <aksionau.andrei@gmail.com>
Co-authored-by: Andrei-Aksionov <58434077+Andrei-Aksionov@users.noreply.github.com>

1 parent: 972dee4

7 files changed: 157 additions, 3 deletions

README.md
Lines changed: 1 addition & 0 deletions

```diff
@@ -140,6 +140,7 @@ Every model is written from scratch to maximize performance and remove layers of
 | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) |
 | Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) |
 | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
+| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
 | Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
 | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
 | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
```

litgpt/config.py
Lines changed: 75 additions & 1 deletion

```diff
@@ -2134,10 +2134,10 @@ def norm_class(self) -> Type:
 
 configs.extend(qwq)
 
+
 #############
 # Salamandra
 #############
-
 salamandra = [
     # https://huggingface.co/BSC-LT/salamandra-2b-instruct/blob/main/config.json
     dict(
@@ -2189,4 +2189,78 @@ def norm_class(self) -> Type:
     configs.append(copy)
 
 
+###############
+# SmolLM2
+###############
+smollm2 = [
+    # https://huggingface.co/HuggingFaceTB/SmolLM2-135M/blob/main/config.json
+    dict(
+        name="SmolLM2-135M{}",
+        hf_config=dict(org="HuggingFaceTB", name="SmolLM2-135M{}"),
+        block_size=8192,
+        vocab_size=49152,
+        padded_vocab_size=49152,
+        n_layer=30,
+        n_head=9,
+        n_embd=576,
+        n_query_groups=3,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=1536,
+        rope_base=100000,
+        norm_eps=1e-5,
+    ),
+    # https://huggingface.co/HuggingFaceTB/SmolLM2-360M/blob/main/config.json
+    dict(
+        name="SmolLM2-360M{}",
+        hf_config=dict(org="HuggingFaceTB", name="SmolLM2-360M{}"),
+        block_size=8192,
+        vocab_size=49152,
+        padded_vocab_size=49152,
+        n_layer=32,
+        n_head=15,
+        n_embd=960,
+        n_query_groups=5,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=2560,
+        rope_base=100000,
+        norm_eps=1e-5,
+    ),
+    # https://huggingface.co/HuggingFaceTB/SmolLM2-1.7B/blob/main/config.json
+    dict(
+        name="SmolLM2-1.7B{}",
+        hf_config=dict(org="HuggingFaceTB", name="SmolLM2-1.7B{}"),
+        block_size=8192,
+        vocab_size=49152,
+        padded_vocab_size=49152,
+        n_layer=24,
+        n_head=32,
+        n_embd=2048,
+        n_query_groups=32,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=8192,
+        rope_base=130000,
+        norm_eps=1e-5,
+    ),
+]
+
+for c in smollm2:
+    for kind in ("", "-Instruct"):
+        copy = deepcopy(c)
+        copy["name"] = c["name"].format(kind)
+        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
+        configs.append(copy)
+
+
 name_to_config = {config["name"]: config for config in configs}
```
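The `{}` placeholder in each `name` is expanded by the loop above into a base and an `-Instruct` variant, so both resolve through `name_to_config`. A minimal sketch of loading one of the new configs, assuming LitGPT's public `Config.from_name` API:

```python
from litgpt import Config

# The expansion loop registers both the base and the instruct-tuned variants.
base = Config.from_name("SmolLM2-135M")
instruct = Config.from_name("SmolLM2-135M-Instruct")

print(base.n_layer, base.n_head, base.n_query_groups)  # 30 9 3
print(instruct.hf_config["name"])  # SmolLM2-135M-Instruct
```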

litgpt/prompts.py
Lines changed: 9 additions & 0 deletions

```diff
@@ -300,6 +300,12 @@ def apply(self, prompt: str, **kwargs: str) -> str:
         return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
 
 
+class SmolLM2(PromptStyle):
+    def apply(self, prompt: str, **kwargs: str) -> str:
+        system_message = "You are a helpful AI assistant named SmolLM, trained by Hugging Face"
+        return f"<|im_start|>system\n{system_message}<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+
+
 # Maps prompt style names to PromptStyle classes
 prompt_styles: Dict[str, Type[PromptStyle]] = {
     # Dataset-specific prompt styles
@@ -326,6 +332,7 @@ def apply(self, prompt: str, **kwargs: str) -> str:
     "qwen2.5": Qwen2_5,
     "qwen2.5-math": Qwen2_5_Math,
     "qwq": QwQ,
+    "smollm2": SmolLM2,
     "salamandra": Salamandra,
 }
 
@@ -371,6 +378,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
         return Qwen2_5()
     if re.search(r"QwQ-.*", model_name):
         return QwQ()
+    if re.search(r"SmolLM2.*-Instruct", model_name):
+        return SmolLM2()
     if re.search(r"salamandra-.*-instruct", model_name):
         return Salamandra()
     return Default()
```
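To make the ChatML wrapping concrete, here is what the new style emits, reconstructed directly from the `apply` method above:

```python
from litgpt.prompts import SmolLM2

style = SmolLM2()
print(style.apply("What is the capital of France?"))
# <|im_start|>system
# You are a helpful AI assistant named SmolLM, trained by Hugging Face<|im_end|>
# <|im_start|>user
# What is the capital of France?<|im_end|>
# <|im_start|>assistant
```

Note that `model_name_to_prompt_style` only routes `-Instruct` checkpoints to this style; base SmolLM2 models fall through to `Default()`.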

litgpt/scripts/download.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -131,7 +131,7 @@ def find_weight_files(repo_id: str, access_token: Optional[str]) -> Tuple[List[s
     with gated_repo_catcher(repo_id, access_token):
         info = repo_info(repo_id, token=access_token)
     filenames = [f.rfilename for f in info.siblings]
-    bins = list(filter_repo_objects(items=filenames, allow_patterns=["*.bin*"]))
+    bins = list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"]))
     safetensors = list(filter_repo_objects(items=filenames, allow_patterns=["*.safetensors*"]))
     return bins, safetensors
```
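The tightened glob matters because a repo can ship auxiliary `*.bin` files that are not weight shards, and `*model*.bin*` keeps only checkpoint files. A self-contained sketch using `huggingface_hub`'s `filter_repo_objects` with hypothetical filenames:

```python
from huggingface_hub.utils import filter_repo_objects

# Hypothetical repo listing; only genuine weight shards should survive the filter.
filenames = [
    "model.bin",
    "pytorch_model-00001-of-00002.bin",
    "training_args.bin",  # matched by the old "*.bin*" pattern, excluded by the new one
    "model.safetensors",
]

print(list(filter_repo_objects(items=filenames, allow_patterns=["*model*.bin*"])))
# ['model.bin', 'pytorch_model-00001-of-00002.bin']
```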

litgpt/tokenizer.py
Lines changed: 3 additions & 1 deletion

```diff
@@ -87,7 +87,7 @@ def token_to_id(self, token: str) -> int:
             raise ValueError(f"token {token!r} not found in the collection.")
         return id_
 
-    def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
+    def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
         if not (tokenizer_config_path := checkpoint_dir / "tokenizer_config.json").is_file():
             return False
         with open(tokenizer_config_path, encoding="utf-8") as fp:
@@ -96,6 +96,8 @@ def check_if_bos_token_used(self, checkpoint_dir: Path) -> bool:
         # `PreTrainedTokenizerFast`
         if checkpoint_dir.stem.startswith(("Meta-Llama-3", "Llama-3")):
             return True
+        if checkpoint_dir.stem.startswith("SmolLM2") and checkpoint_dir.name.endswith("Instruct"):
+            return True
         if "add_bos_token" in config:
             return config["add_bos_token"]
         # if `add_bos_token` isn't in the config file, but LLaMA tokenizer is used - return True.
```
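The new branch keys off the checkpoint directory name, so only the `-Instruct` variants of SmolLM2 get a BOS token prepended. A small illustration of how the path test behaves, using hypothetical checkpoint directories:

```python
from pathlib import Path

# Mirrors the new check: the stem must start with "SmolLM2" and the
# directory name must end with "Instruct".
for d in ("SmolLM2-135M-Instruct", "SmolLM2-135M"):
    p = Path("checkpoints/HuggingFaceTB") / d
    print(p.name, p.stem.startswith("SmolLM2") and p.name.endswith("Instruct"))
# SmolLM2-135M-Instruct True
# SmolLM2-135M False
```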

tests/test_model.py
Lines changed: 61 additions & 0 deletions

```diff
@@ -852,6 +852,7 @@ def test_against_original_qwen_2_5(model_name, device, dtype):
     theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
     torch.testing.assert_close(ours_y, theirs_y)
 
+
 @torch.inference_mode()
 @pytest.mark.parametrize("model_name", ("salamandra-2b", "salamandra-7b"))
 @pytest.mark.parametrize(
@@ -910,6 +911,66 @@ def test_against_original_salamandra(model_name, device, dtype):
     ours_y = ours_model(x)
     theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
     torch.testing.assert_close(ours_y, theirs_y)
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-360M", "SmolLM2-1.7B"))
+@pytest.mark.parametrize(
+    ("device", "dtype"),
+    [
+        (torch.device("cpu"), torch.float32),
+        pytest.param(
+            torch.device("cuda"),
+            torch.float16,
+            marks=[
+                # the reference does softmax upscaled to fp32 during attention. additionally, the final layernorm input
+                # is slightly different
+                pytest.mark.xfail(raises=AssertionError, strict=False),
+                RunIf(min_cuda_gpus=1),
+            ],
+        ),
+    ],
+)
+def test_against_original_smollm2(model_name, device, dtype):
+    torch.set_default_dtype(dtype)
+
+    ours_config = Config.from_name(
+        model_name,
+        padded_vocab_size=10000,
+        n_layer=2,
+        n_head=8,
+        n_embd=32,
+        n_query_groups=2,
+        intermediate_size=86,
+    )
+    T = 5
+    theirs_config = LlamaConfig(
+        vocab_size=ours_config.padded_vocab_size,
+        hidden_size=ours_config.n_embd,
+        num_attention_heads=ours_config.n_head,
+        num_hidden_layers=ours_config.n_layer,
+        intermediate_size=ours_config.intermediate_size,
+        max_position_embeddings=T,
+        rms_norm_eps=ours_config.norm_eps,
+        num_key_value_heads=ours_config.n_query_groups,
+        rope_theta=ours_config.rope_base,
+        attention_bias=ours_config.bias,
+    )
+    assert ours_config.intermediate_size == theirs_config.intermediate_size
+
+    theirs_model = LlamaForCausalLM(theirs_config).to(device)
+    theirs_state_dict = theirs_model.state_dict()
+    state_dict = {}
+    copy_weights_hf_llama(ours_config, {}, state_dict, theirs_state_dict)
+    ours_model = GPT(ours_config).to(device)
+    ours_model.load_state_dict(state_dict)
+
+    # test end to end
+    x = torch.tensor([[9856, 23, 491, 1536, 304]], dtype=torch.int32, device=device)
+    assert x.size(1) == T
+    ours_y = ours_model(x)
+    theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
+    torch.testing.assert_close(ours_y, theirs_y)
 
 
 @RunIf(dynamo=True)
```
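SmolLM2 follows the Llama architecture, which is why the parity test builds a shrunken config and compares logits against Hugging Face's `LlamaForCausalLM` via `copy_weights_hf_llama`. To run just these tests locally, a sketch using `pytest.main`, equivalent to `pytest tests/test_model.py -k smollm2`:

```python
import pytest

# Select only the new SmolLM2 parity tests; the CUDA/fp16 cases are xfail-marked.
raise SystemExit(pytest.main(["tests/test_model.py", "-k", "smollm2", "-v"]))
```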

tutorials/download_model_weights.md
Lines changed: 7 additions & 0 deletions

```diff
@@ -40,6 +40,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
 | Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) |
 | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
 | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
+| SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
 | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
 | Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
 | StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
@@ -122,6 +123,12 @@ google/gemma-2b-it
 google/gemma-7b
 google/gemma-7b-it
 h2oai/h2o-danube2-1.8b-chat
+HuggingFaceTB/SmolLM2-135M
+HuggingFaceTB/SmolLM2-135M-Instruct
+HuggingFaceTB/SmolLM2-360M
+HuggingFaceTB/SmolLM2-360M-Instruct
+HuggingFaceTB/SmolLM2-1.7B
+HuggingFaceTB/SmolLM2-1.7B-Instruct
 lmsys/longchat-13b-16k
 lmsys/longchat-7b-16k
 lmsys/vicuna-13b-v1.3
```
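With the checkpoints listed, any of the new repo IDs can be downloaded and used like the other supported models. A minimal sketch with LitGPT's Python API, assuming the `LLM.load` interface from the project README:

```python
from litgpt import LLM

# Downloads HuggingFaceTB/SmolLM2-135M-Instruct on first use, then runs a prompt.
llm = LLM.load("HuggingFaceTB/SmolLM2-135M-Instruct")
print(llm.generate("What do Llamas eat?"))
```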
