Commit 9eb45eb

add a flag to skip bert train (tinygrad#9349)
1 parent 14c88ab

File tree

1 file changed, +41 -39 lines changed


examples/mlperf/model_train.py (+41, -39)
@@ -771,48 +771,50 @@ def repeat_fake(bs):
   if MLLOGGER:
     MLLOGGER.start(key=mllog_constants.EPOCH_START, value=i*BS, metadata={"epoch_num": i*BS})
 
+  # TODO: put copy into jit
   while train_data is not None and i < train_steps and not achieved:
-    Tensor.training = True
-    BEAM.value = TRAIN_BEAM
-    st = time.perf_counter()
-    GlobalCounters.reset()
-    loss, global_norm = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
-      train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
-      train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)
+    if getenv("TRAIN", 1):
+      Tensor.training = True
+      BEAM.value = TRAIN_BEAM
+      st = time.perf_counter()
+      GlobalCounters.reset()
+      loss, global_norm = train_step_bert(model, optimizer_group, scheduler_group, loss_scaler,
+        train_data["input_ids"], train_data["segment_ids"], train_data["input_mask"], train_data["masked_lm_positions"], \
+        train_data["masked_lm_ids"], train_data["masked_lm_weights"], train_data["next_sentence_labels"], GPUS)
 
-    pt = time.perf_counter()
+      pt = time.perf_counter()
 
-    try:
-      next_data = next(train_it)
-    except StopIteration:
-      next_data = None
-
-    dt = time.perf_counter()
-
-    device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
-    loss = loss.item()
-
-    cl = time.perf_counter()
-    if BENCHMARK: step_times.append(cl - st)
-
-    tqdm.write(
-      f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
-      f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
-      f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
-    if WANDB:
-      wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st,
-                 "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
-                 "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*BS})
-
-    train_data, next_data = next_data, None
-    i += 1
-
-    if i == BENCHMARK:
-      median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
-      estimated_total_minutes = int(median_step_time * train_steps / 60)
-      print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
-      print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
-            f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
+      try:
+        next_data = next(train_it)
+      except StopIteration:
+        next_data = None
+
+      dt = time.perf_counter()
+
+      device_str = loss.device if isinstance(loss.device, str) else f"{loss.device[0]} * {len(loss.device)}"
+      loss = loss.item()
+
+      cl = time.perf_counter()
+      if BENCHMARK: step_times.append(cl - st)
+
+      tqdm.write(
+        f"{i:5} {((cl - st)) * 1000.0:7.2f} ms run, {(pt - st) * 1000.0:7.2f} ms python, {(dt - pt) * 1000.0:6.2f} ms fetch data, "
+        f"{(cl - dt) * 1000.0:7.2f} ms {device_str}, {loss:5.2f} loss, {optimizer_wd.lr.numpy()[0]:.6f} LR, "
+        f"{GlobalCounters.mem_used / 1e9:.2f} GB used, {GlobalCounters.global_ops * 1e-9 / (cl - st):9.2f} GFLOPS")
+      if WANDB:
+        wandb.log({"lr": optimizer_wd.lr.numpy(), "train/loss": loss, "train/global_norm": global_norm.item(), "train/step_time": cl - st,
+                   "train/python_time": pt - st, "train/data_time": dt - pt, "train/cl_time": cl - dt,
+                   "train/GFLOPS": GlobalCounters.global_ops * 1e-9 / (cl - st), "epoch": (i+1)*BS})
+
+      train_data, next_data = next_data, None
+      i += 1
+
+      if i == BENCHMARK:
+        median_step_time = sorted(step_times)[(BENCHMARK + 1) // 2] # in seconds
+        estimated_total_minutes = int(median_step_time * train_steps / 60)
+        print(f"Estimated training time: {estimated_total_minutes // 60}h{estimated_total_minutes % 60}m")
+        print(f"epoch global_ops: {train_steps * GlobalCounters.global_ops:_}, "
+              f"epoch global_mem: {train_steps * GlobalCounters.global_mem:_}")
 
   # ** eval loop **
   if i % eval_step_freq == 0 or (BENCHMARK and i == BENCHMARK) or i == train_steps:
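
For context, the new guard relies on tinygrad's getenv helper, which reads an integer-valued environment variable with a default. Below is a minimal sketch of the pattern (hypothetical function names, not the actual training loop): with TRAIN=0 in the environment the guarded training work is skipped and each pass through the loop falls through to whatever follows the guard, such as the eval check.

  from tinygrad.helpers import getenv

  def run_train_step(batch):
    # placeholder standing in for train_step_bert(...); returns a dummy loss
    return 0.0

  def run_eval():
    # placeholder for the eval path that sits outside the guard
    print("running eval")

  def loop_body(i, batch):
    if getenv("TRAIN", 1):  # TRAIN=0 makes this falsy, skipping the training step
      loss = run_train_step(batch)
      print(f"step {i}: loss {loss:.2f}")
      i += 1
    run_eval()              # eval still runs even when training is skipped
    return i

In practice the flag would be set on the command line when launching the script (for example, prefixing the usual invocation of examples/mlperf/model_train.py with TRAIN=0 alongside whatever other environment variables the run normally uses).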
