
Commit 71b8890

use validation dataloader inside retinanet eval (tinygrad#9747)
1 parent 5f7c796 commit 71b8890

2 files changed: +38 -32 lines changed

examples/mlperf/model_eval.py

Lines changed: 32 additions & 32 deletions
@@ -81,61 +81,61 @@ def eval_unet3d():
 
 def eval_retinanet():
   # RetinaNet with ResNeXt50_32X4D
+  from examples.mlperf.dataloader import batch_load_retinanet
+  from extra.datasets.openimages import normalize, download_dataset, BASEDIR
   from extra.models.resnet import ResNeXt50_32X4D
   from extra.models.retinanet import RetinaNet
-  mdl = RetinaNet(ResNeXt50_32X4D())
-  mdl.load_from_pretrained()
-
-  input_mean = Tensor([0.485, 0.456, 0.406]).reshape(1, -1, 1, 1)
-  input_std = Tensor([0.229, 0.224, 0.225]).reshape(1, -1, 1, 1)
-  def input_fixup(x):
-    x = x.permute([0,3,1,2]) / 255.0
-    x -= input_mean
-    x /= input_std
-    return x
-
-  from extra.datasets.openimages import download_dataset, iterate, BASEDIR
   from pycocotools.coco import COCO
   from pycocotools.cocoeval import COCOeval
   from contextlib import redirect_stdout
+  tlog("imports")
+
+  mdl = RetinaNet(ResNeXt50_32X4D())
+  mdl.load_from_pretrained()
+  tlog("loaded models")
+
   coco = COCO(download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), 'validation'))
   coco_eval = COCOeval(coco, iouType="bbox")
   coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng)
+  tlog("loaded dataset")
 
-  from tinygrad.engine.jit import TinyJit
-  mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize())
-
-  n, bs = 0, 8
+  iterator = batch_load_retinanet(coco, True, Path(base_dir), getenv("BS", 8), shuffle=False)
+  def data_get():
+    x, img_ids, img_sizes, cookie = next(iterator)
+    return x.to(Device.DEFAULT).realize(), img_ids, img_sizes, cookie
+  n = 0
+  proc = data_get()
+  tlog("loaded initial data")
   st = time.perf_counter()
-  for x, targets in iterate(coco, base_dir, bs):
-    dat = Tensor(x.astype(np.float32))
-    mt = time.perf_counter()
-    if dat.shape[0] == bs:
-      outs = mdlrun(dat).numpy()
-    else:
-      mdlrun._jit_cache = []
-      outs = mdl(input_fixup(dat)).numpy()
-    et = time.perf_counter()
-    predictions = mdl.postprocess_detections(outs, input_size=dat.shape[1:3], orig_image_sizes=[t["image_size"] for t in targets])
-    ext = time.perf_counter()
-    n += len(targets)
-    print(f"[{n}/{len(coco.imgs)}] == {(mt-st)*1000:.2f} ms loading data, {(et-mt)*1000:.2f} ms to run model, {(ext-et)*1000:.2f} ms for postprocessing")
-    img_ids = [t["image_id"] for t in targets]
-    coco_results = [{"image_id": targets[i]["image_id"], "category_id": label, "bbox": box.tolist(), "score": score}
+  while proc is not None:
+    GlobalCounters.reset()
+    proc = (mdl(normalize(proc[0])), proc[1], proc[2], proc[3])
+    run = time.perf_counter()
+    # load the next data here
+    try: next_proc = data_get()
+    except StopIteration: next_proc = None
+    nd = time.perf_counter()
+    predictions, img_ids = mdl.postprocess_detections(proc[0].numpy(), orig_image_sizes=proc[2]), proc[1]
+    coco_results = [{"image_id": img_ids[i], "category_id": label, "bbox": box.tolist(), "score": score}
       for i, prediction in enumerate(predictions) for box, score, label in zip(*prediction.values())]
     with redirect_stdout(None):
       coco_eval.cocoDt = coco.loadRes(coco_results)
       coco_eval.params.imgIds = img_ids
       coco_eval.evaluate()
     evaluated_imgs.extend(img_ids)
     coco_evalimgs.append(np.array(coco_eval.evalImgs).reshape(ncats, narea, len(img_ids)))
-    st = time.perf_counter()
+    n += len(proc[0])
+    et = time.perf_counter()
+    tlog(f"****** {(run-st)*1000:7.2f} ms to enqueue, {(et-run)*1000:7.2f} ms to realize ({(nd-run)*1000:7.2f} ms fetching). {(len(proc))/(et-st):8.2f} examples/sec. {GlobalCounters.global_ops*1e-12/(et-st):5.2f} TFLOPS")
+    st = et
+    proc, next_proc = next_proc, None
 
   coco_eval.params.imgIds = evaluated_imgs
   coco_eval._paramsEval.imgIds = evaluated_imgs
   coco_eval.evalImgs = list(np.concatenate(coco_evalimgs, -1).flatten())
   coco_eval.accumulate()
   coco_eval.summarize()
+  tlog("done")
 
 def eval_rnnt():
   # RNN-T
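
The new loop double-buffers the dataloader: the model runs on the current batch while data_get() already fetches the next one, and each batch's cookie appears to be kept referenced until that batch is dropped. Below is a minimal, self-contained sketch of the same pattern; make_batches and run_model are hypothetical stand-ins for batch_load_retinanet and mdl(normalize(x)), not the real APIs.

# Standalone sketch of the prefetch loop above (stand-ins, not the real dataloader or model).
import time

def make_batches(n=3, bs=4):
  # stand-in for batch_load_retinanet: yields (data, img_ids, img_sizes, cookie)
  for i in range(n):
    yield [float(i)] * bs, list(range(i * bs, (i + 1) * bs)), [(800, 800)] * bs, object()

iterator = make_batches()
def data_get():
  x, img_ids, img_sizes, cookie = next(iterator)
  return x, img_ids, img_sizes, cookie  # cookie stays referenced until the batch is dropped

def run_model(x): return [v * 2 for v in x]  # stand-in for mdl(normalize(x))

proc = data_get()
st = time.perf_counter()
while proc is not None:
  proc = (run_model(proc[0]), proc[1], proc[2], proc[3])  # run the model on the current batch
  try: next_proc = data_get()                             # overlap: fetch the next batch now
  except StopIteration: next_proc = None
  print(f"ids {proc[1]} -> {proc[0]} in {(time.perf_counter()-st)*1000:.2f} ms")
  st = time.perf_counter()
  proc, next_proc = next_proc, None                       # release the old batch (and its cookie)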

extra/datasets/openimages.py

Lines changed: 6 additions & 0 deletions
@@ -205,6 +205,12 @@ def resize(img:Image, tgt:dict[str, np.ndarray|tuple]|None=None, size:tuple[int,
 
   return img, img_size
 
+def normalize(img:Tensor, device:list[str]|None = None):
+  mean = Tensor([0.485, 0.456, 0.406], device=device, dtype=dtypes.float32).reshape(1, -1, 1, 1)
+  std = Tensor([0.229, 0.224, 0.225], device=device, dtype=dtypes.float32).reshape(1, -1, 1, 1)
+  img = ((img.permute([0, 3, 1, 2]) / 255.0) - mean) / std
+  return img.cast(dtypes.default_float)
+
 if __name__ == "__main__":
   download_dataset(base_dir:=getenv("BASE_DIR", BASEDIR), "train")
   download_dataset(base_dir, "validation")
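
The new normalize helper takes an NHWC image batch, permutes it to NCHW, scales to [0, 1], applies the ImageNet mean/std, and casts to the default float dtype. A usage sketch, assuming a tinygrad checkout where extra.datasets.openimages is importable and this change is applied:

# Usage sketch only; assumes the repo layout above and that this diff is applied.
import numpy as np
from tinygrad import Tensor, dtypes
from extra.datasets.openimages import normalize

# fake batch of two 800x800 NHWC uint8 images, the layout the dataloader hands back
batch = Tensor(np.random.randint(0, 256, size=(2, 800, 800, 3), dtype=np.uint8))
out = normalize(batch)
print(out.shape, out.dtype)  # expected: (2, 3, 800, 800) and dtypes.default_float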
