Commit b58e7b1

Authored Feb 14, 2025
zero out the weight in bert init run (tinygrad#9076)
`DEFAULT_FLOAT=HALF BENCHMARK=10 BS=66 EVAL_BS=6 GPUS=6 MODEL=bert python3 examples/mlperf/model_train.py` no longer OOMs. I think the buffer of randomly initialized weights caused the OOM.
1 parent 82ad0d2 commit b58e7b1
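
For a sense of scale, here is a rough back-of-envelope sketch of what one realized copy of the weights costs under `DEFAULT_FLOAT=HALF`. The ~340M parameter count for BERT-large is an assumption on my part, not a number measured from this repo:

```python
# back-of-envelope only; ~340M params for BERT-large is an assumption
params = 340e6
bytes_per_param = 2                          # DEFAULT_FLOAT=HALF -> fp16
gib = params * bytes_per_param / 2**30
print(f"~{gib:.2f} GiB for one realized copy of the weights")  # ~0.63 GiB
```

The commit message attributes the OOM to the buffers behind the random init; overwriting them with realized zeros lets that computation and its intermediates be freed before training allocates its own state.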

File tree: 2 files changed, +10 -5 lines changed

examples/mlperf/helpers.py (+2, -3)
```diff
@@ -208,7 +208,7 @@ def get_mlperf_bert_config():
     "vocab_size": 30522
   }
 
-def get_mlperf_bert_model(checkpoint_path:Optional[str]=None):
+def get_mlperf_bert_model():
   from extra.models import bert
   from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert
 
@@ -220,8 +220,7 @@ def get_mlperf_bert_model(checkpoint_path:Optional[str]=None):
   config = get_mlperf_bert_config()
   if getenv("DISABLE_DROPOUT", 0):
     config["hidden_dropout_prob"] = config["attention_probs_dropout_prob"] = 0.0
-  model = BertForPretraining(**config)
-  return model.load_from_pretrained(checkpoint_path) if checkpoint_path else model
+  return BertForPretraining(**config)
 
 def get_data_bert(GPUS:list[str], it):
   data: dict[str, Tensor] = next(it)
```
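
With this signature change, checkpoint loading becomes the caller's responsibility instead of an optional argument. A minimal usage sketch of the new API (the checkpoint path is a placeholder):

```python
from examples.mlperf.helpers import get_mlperf_bert_model

init_ckpt = "/path/to/checkpoint"      # placeholder, supplied by the caller
model = get_mlperf_bert_model()        # bare BertForPretraining, no weights loaded
model.load_from_pretrained(init_ckpt)  # explicit load, only when a checkpoint is wanted
```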

examples/mlperf/model_train.py (+8, -2)
```diff
@@ -683,8 +683,14 @@ def train_bert():
 
   # ** init model **
 
-  model = get_mlperf_bert_model(init_ckpt if RUNMLPERF else None)
-
+  model = get_mlperf_bert_model()
+  if RUNMLPERF:
+    model.load_from_pretrained(init_ckpt)
+  else:
+    # for init, zero out all weights
+    for p in get_parameters(model):
+      p = p.assign(Tensor.zeros_like(p).contiguous()).realize()
+
   parameters = get_parameters(model)
   for p in parameters:
     p.to_(GPUS)
```
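
On the `.contiguous().realize()` pattern: my reading of the tinygrad API (not something this commit states) is that `Tensor.zeros_like` is lazy, `.contiguous()` forces the zeros to become a real buffer rather than a constant-folded view, and `.realize()` runs the assignment immediately so the random-init buffers can be released up front. A standalone sketch of the same pattern:

```python
from tinygrad import Tensor
from tinygrad.nn.state import get_parameters  # same helper the training script uses

def zero_weights(model):
  """Overwrite every parameter with realized zeros, dropping the random init."""
  for p in get_parameters(model):
    # zeros_like is lazy; contiguous() materializes an actual buffer, and
    # realize() executes the assign now instead of deferring it
    p.assign(Tensor.zeros_like(p).contiguous()).realize()
```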
