
Commit d65ca6f

ko3n1g and suiyoubi authored
Fix Exporter for baichuan and chatglm (#13095) (#13126)
Signed-off-by: Ao Tang <aot@nvidia.com>
Co-authored-by: Ao Tang <aot@nvidia.com>
1 parent c2bf5f0 commit d65ca6f

2 files changed: +6 −2 lines changed

nemo/collections/llm/gpt/model/baichuan.py: +3 −1

@@ -224,7 +224,7 @@ class HFBaichuan2Exporter(io.ModelConnector[Baichuan2Model, "AutoModelForCausalLM"]):
     BaichuanForCausalLM format, including weight mapping and configuration translation.
     """
 
-    def init(self, dtype=torch.bfloat16, model_name="baichuan-inc/Baichuan2-7B-Base") -> "AutoModelForCausalLM":
+    def init(self, dtype=torch.bfloat16, model_name=None) -> "AutoModelForCausalLM":
         """
         Initialize a HF BaichuanForCausalLM instance.
 
@@ -237,6 +237,8 @@ def init(self, dtype=torch.bfloat16, model_name="baichuan-inc/Baichuan2-7B-Base"
         from transformers import AutoModelForCausalLM
         from transformers.modeling_utils import no_init_weights
 
+        if model_name is None:
+            model_name = "baichuan-inc/Baichuan2-7B-Base"
         with no_init_weights(True):
             # Since Baichuan2 is not importable from transformers, we can only initialize the HF model
             # from a known checkpoint. The model_name will need to be passed in.
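
Both files apply the same fix: the default Hugging Face checkpoint name moves out of the init signature into the body behind a None sentinel, so the resolved default is unchanged but the signature no longer hard-codes a checkpoint string. A minimal standalone sketch of this sentinel-default pattern, with hypothetical names (load_checkpoint, DEFAULT_CHECKPOINT) that are not part of NeMo:

    # Sketch of the None-sentinel default used in this commit; names are illustrative.
    DEFAULT_CHECKPOINT = "baichuan-inc/Baichuan2-7B-Base"

    def load_checkpoint(model_name=None):
        # Resolve the fallback inside the body: the signature advertises the
        # argument as optional rather than baking in one concrete checkpoint,
        # and callers that explicitly pass None still get a usable default.
        if model_name is None:
            model_name = DEFAULT_CHECKPOINT
        return model_name

    assert load_checkpoint() == "baichuan-inc/Baichuan2-7B-Base"
    assert load_checkpoint("my/custom-ckpt") == "my/custom-ckpt"

The commit message does not spell out the motivation, but one practical effect of this shape is that explicitly passing model_name=None now behaves the same as omitting the argument, which a signature-level string default does not give you; a caller that forwards None would previously have handed None straight to the checkpoint loader.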

nemo/collections/llm/gpt/model/chatglm.py: +3 −1

@@ -218,10 +218,12 @@ class HFChatGLMExporter(io.ModelConnector[ChatGLMModel, "AutoModelForCausalLM"])
     ChatGLMForCausalLM format, including weight mapping and configuration translation.
     """
 
-    def init(self, dtype=torch.bfloat16, model_name="THUDM/chatglm3-6b") -> "AutoModelForCausalLM":
+    def init(self, dtype=torch.bfloat16, model_name=None) -> "AutoModelForCausalLM":
         from transformers import AutoModelForCausalLM
         from transformers.modeling_utils import no_init_weights
 
+        if model_name is None:
+            model_name = "THUDM/chatglm3-6b"
         with no_init_weights(True):
             # Since ChatGLM is not importable from transformers, we can only initialize the HF model
             # from a known checkpoint. The model_name will need to be passed in.
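
For context on the unchanged surrounding code: no_init_weights from transformers.modeling_utils is a context manager that skips weight initialization when constructing a model skeleton, which fits this exporter because the parameters are immediately overwritten with converted weights. A hedged sketch, assuming a transformers version where no_init_weights accepts an enable flag as in the diff above (newer releases take no argument) and that ChatGLM's remote code is trusted:

    import torch
    from transformers import AutoConfig, AutoModelForCausalLM
    from transformers.modeling_utils import no_init_weights

    # Build the model structure without running random weight init; the
    # parameters are expected to be overwritten afterwards, e.g. by weights
    # converted from a NeMo checkpoint.
    config = AutoConfig.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
    with no_init_weights(True):
        model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)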
