
Commit 0b89395

Add initial support for NemotronForCausalLM.
1 parent a977c11 commit 0b89395


5 files changed (+605, -6 lines)

convert_hf_to_gguf.py (+129)

@@ -3395,6 +3395,135 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         name = name.removeprefix("transformer.")
         return [(self.map_tensor_name(name), data_torch)]
 
+
+@Model.register("NemotronForCausalLM")
+class Nemotron4Model(Model):
+    model_arch = gguf.MODEL_ARCH.NEMOTRON4
+
+    def set_vocab(self):
+        # to avoid TypeError: Descriptors cannot be created directly
+        # exception when importing sentencepiece_model_pb2
+        os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+        from sentencepiece import SentencePieceProcessor
+        from sentencepiece import sentencepiece_model_pb2 as model
+
+        tokenizer_path = self.dir_model / 'tokenizer.model'
+
+        if not tokenizer_path.is_file():
+            raise FileNotFoundError(f"File not found: {tokenizer_path}")
+
+        sentencepiece_model = model.ModelProto()  # pyright: ignore[reportAttributeAccessIssue]
+        sentencepiece_model.ParseFromString(open(tokenizer_path, "rb").read())
+
+        assert sentencepiece_model.trainer_spec.model_type == 2  # BPE
+
+        add_prefix = sentencepiece_model.normalizer_spec.add_dummy_prefix
+        remove_whitespaces = sentencepiece_model.normalizer_spec.remove_extra_whitespaces
+
+        tokenizer = SentencePieceProcessor()
+        tokenizer.LoadFromFile(str(tokenizer_path))
+
+        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+
+        tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
+        scores: list[float] = [-10000.0] * vocab_size
+        toktypes: list[int] = [SentencePieceTokenTypes.UNKNOWN] * vocab_size
+
+        for token_id in range(tokenizer.vocab_size()):
+            piece = tokenizer.IdToPiece(token_id)
+            text = piece.encode("utf-8")
+            score = tokenizer.GetScore(token_id)
+
+            toktype = SentencePieceTokenTypes.NORMAL
+            if tokenizer.IsUnknown(token_id):
+                toktype = SentencePieceTokenTypes.UNKNOWN
+            elif tokenizer.IsControl(token_id):
+                toktype = SentencePieceTokenTypes.CONTROL
+            elif tokenizer.IsUnused(token_id):
+                toktype = SentencePieceTokenTypes.UNUSED
+            elif tokenizer.IsByte(token_id):
+                toktype = SentencePieceTokenTypes.BYTE
+
+            tokens[token_id] = text
+            scores[token_id] = score
+            toktypes[token_id] = toktype
+
+        added_tokens_file = self.dir_model / 'added_tokens.json'
+        if added_tokens_file.is_file():
+            with open(added_tokens_file, "r", encoding="utf-8") as f:
+                added_tokens_json = json.load(f)
+                for key in added_tokens_json:
+                    token_id = added_tokens_json[key]
+                    if (token_id >= vocab_size):
+                        logger.warning(f'ignore token {token_id}: id is out of range, max={vocab_size - 1}')
+                        continue
+
+                    tokens[token_id] = key.encode("utf-8")
+                    scores[token_id] = -1000.0
+                    toktypes[token_id] = SentencePieceTokenTypes.USER_DEFINED
+
+        if vocab_size > len(tokens):
+            pad_count = vocab_size - len(tokens)
+            logger.debug(f"Padding vocab with {pad_count} token(s) - [PAD1] through [PAD{pad_count}]")
+            for i in range(1, pad_count + 1):
+                tokens.append(bytes(f"[PAD{i}]", encoding="utf-8"))
+                scores.append(-1000.0)
+                toktypes.append(SentencePieceTokenTypes.UNUSED)
+
+        self.gguf_writer.add_tokenizer_model("nemotron")
+        self.gguf_writer.add_tokenizer_pre("default")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_add_space_prefix(add_prefix)
+        self.gguf_writer.add_remove_extra_whitespaces(remove_whitespaces)
+
+        special_vocab = gguf.SpecialVocab(
+            self.dir_model, n_vocab=len(tokens),
+            special_token_types = ('bos', 'eos', 'eot')
+        )
+        special_vocab._set_special_token("eot", 5)  # <extra_id_1>
+        special_vocab.add_to_gguf(self.gguf_writer)
+
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(False)
+
+    def set_gguf_parameters(self):
+        self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
+        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
+        self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+        self.gguf_writer.add_rope_dimension_count(
+            int(self.hparams["partial_rotary_factor"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
+        )
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
+        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
+        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
+        self.gguf_writer.add_file_type(self.ftype)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        if name.endswith(".layer_norm.weight") or name == "final_layernorm.weight":
+            logger.info(f"Adding 1.0 to {name} tensor data, see NeMo zero_centered_gamma documentation")
+            data_torch = data_torch + 1.0
+        if name.endswith(".linear_qkv.weight"):
+            n_head = self.find_hparam(["num_attention_heads"])
+            n_head_kv = self.find_hparam(["num_key_value_heads"])
+            head_dim = self.hparams["hidden_size"] // n_head
+
+            qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head)
+            q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
+            k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
+            v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
+            data_torch = torch.cat((q, k, v)).reshape_as(data_torch)
+
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 ###### CONVERSION LOGIC ######
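Note (not part of the commit itself): the linear_qkv branch above undoes the NeMo/Megatron fused-QKV layout, in which each of the n_head_kv groups stores its query heads followed by one K and one V head, and regroups the rows into the contiguous Q | K | V order the GGUF tensors expect; the +1.0 added to the *.layer_norm.weight tensors follows the NeMo zero_centered_gamma convention referenced in the log message, where the stored weight is gamma - 1. A minimal, self-contained sketch of the reshuffle with made-up toy sizes (8 heads, 4 KV heads, hidden size 64; not Nemotron's real hyperparameters):

import torch

# Toy sizes for illustration only - not the real Nemotron-4 hyperparameters.
n_head, n_head_kv, hidden_size = 8, 4, 64
head_dim = hidden_size // n_head                   # 8
rows = (n_head + 2 * n_head_kv) * head_dim         # 128 rows in the fused QKV weight

# Fused NeMo/Megatron-style weight: for each of the n_head_kv groups the rows are
# [q_0 .. q_{n_head // n_head_kv - 1}, k, v], each block head_dim rows tall.
data_torch = torch.arange(rows * hidden_size, dtype=torch.float32).reshape(rows, hidden_size)

# The same reshuffle as Nemotron4Model.modify_tensors above:
qkv = data_torch.view(n_head_kv, n_head // n_head_kv + 2, head_dim, hidden_size)
q = qkv[:, :-2].reshape(n_head * head_dim, hidden_size)      # all query heads
k = qkv[:, [-2]].reshape(n_head_kv * head_dim, hidden_size)  # one K head per group
v = qkv[:, [-1]].reshape(n_head_kv * head_dim, hidden_size)  # one V head per group
out = torch.cat((q, k, v)).reshape_as(data_torch)

print(out.shape)  # torch.Size([128, 64]): same shape, rows regrouped into Q | K | V

Running this only confirms that the output keeps the original (rows, hidden_size) shape while the rows move from the grouped q/k/v interleaving to one Q block, one K block and one V block.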

gguf-py/gguf/constants.py (+13)

@@ -166,6 +166,7 @@ class MODEL_ARCH(IntEnum):
     BITNET = auto()
     T5 = auto()
     JAIS = auto()
+    NEMOTRON4 = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -293,6 +294,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.BITNET: "bitnet",
     MODEL_ARCH.T5: "t5",
     MODEL_ARCH.JAIS: "jais",
+    MODEL_ARCH.NEMOTRON4: "nemotron4",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -996,6 +998,17 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_GATE,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.NEMOTRON4: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }

gguf-py/gguf/tensor_mapping.py (+7)

@@ -75,6 +75,7 @@ class TensorNameMap:
            "transformer.rms_norm",      # Grok
            "encoder.final_layernorm",   # chatglm
            "transformer.norm",          # openelm
+           "final_layernorm",           # nemotron4
        ),
 
        # Rope frequencies
@@ -107,6 +108,7 @@
            "transformer.blocks.{bid}.norm_attn_norm.norm_1",  # dbrx
            "encoder.layers.{bid}.input_layernorm",            # chatglm
            "transformer.layers.{bid}.attn_norm",              # openelm
+           "model.layers.{bid}.self_attention.linear_qkv.layer_norm"  # nemotron4
        ),
 
        # Attention norm 2
@@ -131,6 +133,7 @@
            "model.layers.{bid}.self_attn.qkv_proj",                 # phi3
            "encoder.layers.{bid}.self_attention.query_key_value",   # chatglm
            "transformer.layers.{bid}.attn.qkv_proj",                # openelm
+           "model.layers.{bid}.self_attention.linear_qkv"           # nemotron4
        ),
 
        # Attention query
@@ -190,6 +193,7 @@
            "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",  # dbrx
            "encoder.layers.{bid}.self_attention.dense",              # chatglm
            "transformer.layers.{bid}.attn.out_proj",                 # openelm
+           "model.layers.{bid}.self_attention.linear_proj",          # nemotron4
        ),
 
        # Attention output norm
@@ -227,6 +231,7 @@
            "transformer.decoder_layer.{bid}.rms_norm_2",       # Grok
            "encoder.layers.{bid}.post_attention_layernorm",    # chatglm
            "transformer.layers.{bid}.ffn_norm",                # openelm
+           "model.layers.{bid}.mlp.linear_fc1.layer_norm",     # nemotron4
        ),
 
        # Post feed-forward norm
@@ -277,6 +282,7 @@
            "encoder.layer.{bid}.mlp.gated_layers_v",    # jina-bert-v2
            "model.layers.{bid}.residual_mlp.w3",        # arctic
            "encoder.layers.{bid}.mlp.dense_h_to_4h",    # chatglm
+           "model.layers.{bid}.mlp.linear_fc1",         # nemotron4
        ),
 
        MODEL_TENSOR.FFN_UP_EXP: (
@@ -347,6 +353,7 @@
            "model.layers.{bid}.residual_mlp.w2",        # arctic
            "encoder.layer.{bid}.mlp.down_layer",        # jina-bert-v2
            "encoder.layers.{bid}.mlp.dense_4h_to_h",    # chatglm
+           "model.layers.{bid}.mlp.linear_fc2",         # nemotron4
        ),
 
        MODEL_TENSOR.FFN_DOWN_EXP: (
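The new "# nemotron4" entries can be exercised directly from gguf-py as a quick sanity check. This is a sketch, not part of the commit; it assumes the patched gguf package (with the NEMOTRON4 arch from constants.py above) is on the import path, and the block count of 32 is an arbitrary example value:

import gguf

# Assumes gguf-py from this commit is importable (MODEL_ARCH.NEMOTRON4 exists);
# the block count is only needed to expand the {bid} templates.
tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.NEMOTRON4, 32)

# Checkpoint-side names from the new "# nemotron4" entries resolve to GGUF tensor names:
print(tmap.get_name("model.layers.0.self_attention.linear_qkv.weight", try_suffixes=(".weight", ".bias")))  # blk.0.attn_qkv.weight
print(tmap.get_name("model.layers.0.mlp.linear_fc2.weight", try_suffixes=(".weight", ".bias")))             # blk.0.ffn_down.weight
print(tmap.get_name("final_layernorm.weight", try_suffixes=(".weight", ".bias")))                           # output_norm.weight

This mirrors what Model.map_tensor_name() relies on during conversion in convert_hf_to_gguf.py.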

include/llama.h (+1)

@@ -68,6 +68,7 @@ extern "C" {
         LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE
         LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece
         LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram
+        LLAMA_VOCAB_TYPE_NTN = 5, // Nemotron tokenizer based on SentencePiece BPE
     };
 
     // pre-tokenization types
