Merge pull request #1529 from gs-olive/perf_docs
feat: Add functionality to FX benchmarking + Improve documentation
peri044 authored Dec 13, 2022
2 parents 27733ba + 360f6c4 commit 6f73a23
Showing 4 changed files with 106 additions and 20 deletions.
13 changes: 10 additions & 3 deletions tools/perf/README.md
@@ -66,9 +66,10 @@ There are two sample configuration files added.

| Name | Supported Values | Description |
| ----------------- | ------------------------------------ | ------------------------------------------------------------ |
| backend | all, torch, torch_tensorrt, tensorrt | Supported backends for inference. |
| backend | all, torchscript, fx2trt, torch, torch_tensorrt, tensorrt | Supported backends for inference. "all" implies the last four entries in the list at left (fx2trt, torch, torch_tensorrt, tensorrt), while "torchscript" implies the last three (torch, torch_tensorrt, tensorrt), i.e. excludes the fx path |
| input | - | Input binding names. Expected to list the shape of each input binding |
| model | - | Configure the model filename and name |
| model_torch | - | Filename and name of the PyTorch (non-TorchScript) model (optional; required when fx2trt is among the backends) |
| filename | - | Model file name to load from disk. |
| name | - | Model name |
| runtime | - | Runtime configurations |
@@ -83,6 +84,7 @@ backend:
- torch
- torch_tensorrt
- tensorrt
- fx2trt
input:
input0:
- 3
@@ -92,6 +94,9 @@ input:
model:
filename: model.plan
name: vgg16
model_torch:
filename: model_torch.pt
name: vgg16
runtime:
device: 0
precision:
@@ -108,8 +113,9 @@ Note:

Here is the list of `CompileSpec` options that can be provided directly to compile the PyTorch module

* `--backends` : Comma separated string of backends. Eg: torch,torch_tensorrt, tensorrt or fx2trt
* `--backends` : Comma separated string of backends. Eg: torch,torch_tensorrt,tensorrt,fx2trt
* `--model` : Name of the model file (can be a TorchScript module or a TensorRT engine ending in the `.plan` extension). If the backend is `fx2trt`, the input should be a PyTorch module (instead of a TorchScript module) and the options for model are (`vgg16` | `resnet50` | `efficientnet_b0`)
* `--model_torch` : Name of the PyTorch model file (optional, only necessary if fx2trt is a chosen backend)
* `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT
* `--batch_size` : Batch size
* `--precision` : Comma separated list of precisions to build TensorRT engine Eg: fp32,fp16
@@ -122,9 +128,10 @@ Eg:

```
python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \
--model_torch ${MODELS_DIR}/vgg16_torch.pt \
--precision fp32,fp16 --inputs="(1, 3, 224, 224)@fp32" \
--batch_size 1 \
--backends torch,torch_tensorrt,tensorrt \
--backends torch,torch_tensorrt,tensorrt,fx2trt \
--report "vgg_perf_bs1.txt"
```
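
The example assumes both model files already exist under `${MODELS_DIR}`. A minimal sketch of how such files could be produced (not part of this change; assumes torchvision >= 0.13, and the file names simply mirror the command above):

```
# Sketch only: create the TorchScript and plain-PyTorch model files
# referenced by --model and --model_torch above (names are illustrative).
import torch
import torchvision.models as models

model = models.vgg16(weights=models.VGG16_Weights.DEFAULT).eval()

# TorchScript module, loaded by perf_run.py via torch.jit.load (--model)
torch.jit.script(model).save("vgg16_scripted.jit.pt")

# Full PyTorch module, loaded by perf_run.py via torch.load (--model_torch)
torch.save(model, "vgg16_torch.pt")
```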

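When driving the benchmark from a YAML configuration like the one documented above (rather than individual flags), `perf_run.py` is instead pointed at the configuration file through its config argument (read into `ConfigParser` via `args.config`), e.g. `python perf_run.py --config=config.yml`; the config file name here is only illustrative.
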
5 changes: 4 additions & 1 deletion tools/perf/hub.py
@@ -26,7 +26,10 @@

# Key models selected for benchmarking with their respective paths
BENCHMARK_MODELS = {
"vgg16": {"model": models.vgg16(pretrained=True), "path": ["script", "pytorch"]},
"vgg16": {
"model": models.vgg16(weights=models.VGG16_Weights.DEFAULT),
"path": ["script", "pytorch"],
},
"resnet50": {
"model": models.resnet50(weights=None),
"path": ["script", "pytorch"],
103 changes: 88 additions & 15 deletions tools/perf/perf_run.py
@@ -292,6 +292,20 @@ def run(
print("int8 precision expects calibration cache file for inference")
return False

if (model is None) and (backend != "fx2trt"):
warnings.warn(
f"Requested backend {backend} without specifying a TorchScript Model, "
+ "skipping this backend"
)
continue

if (model_torch is None) and (backend in ("all", "fx2trt")):
warnings.warn(
f"Requested backend {backend} without specifying a PyTorch Model, "
+ "skipping this backend"
)
continue

if backend == "all":
run_torch(model, input_tensors, params, precision, batch_size)
run_torch_tensorrt(
@@ -311,6 +325,27 @@
is_trt_engine,
batch_size,
)
run_fx2trt(model_torch, input_tensors, params, precision, batch_size)

elif backend == "torchscript":
run_torch(model, input_tensors, params, precision, batch_size)
run_torch_tensorrt(
model,
input_tensors,
params,
precision,
truncate_long_and_double,
batch_size,
)
run_tensorrt(
model,
input_tensors,
params,
precision,
truncate_long_and_double,
is_trt_engine,
batch_size,
)

elif backend == "torch":
run_torch(model, input_tensors, params, precision, batch_size)
@@ -326,12 +361,6 @@
)

elif backend == "fx2trt":
if model_torch is None:
warnings.warn(
"Requested backend fx2trt without specifying a PyTorch Model, "
+ "skipping this backend"
)
continue
run_fx2trt(model_torch, input_tensors, params, precision, batch_size)

elif backend == "tensorrt":
@@ -371,9 +400,14 @@ def recordStats(backend, timings, precision, batch_size=1, compile_time_ms=None)
results.append(stats)


def load_model(params):
def load_ts_model(params):
model = None
is_trt_engine = False

# No TorchScript Model Specified
if len(params.get("model", "")) == 0:
return None, None, is_trt_engine

# Load torch model traced/scripted
model_file = params.get("model").get("filename")
try:
@@ -393,6 +427,26 @@ def load_model(params):
return model, model_name, is_trt_engine


def load_torch_model(params):
model = None

# No Torch Model Specified
if len(params.get("model_torch", "")) == 0:
return None, None

# Load torch model
model_file = params.get("model_torch").get("filename")
try:
model_name = params.get("model_torch").get("name")
except:
model_name = model_file

print("Loading Torch model: ", model_file)
model = torch.load(model_file).cuda()

return model, model_name


if __name__ == "__main__":
arg_parser = argparse.ArgumentParser(
description="Run inference on a model with random input values"
@@ -408,7 +462,9 @@ def load_model(params):
type=str,
help="Comma separated string of backends. Eg: torch,torch_tensorrt,fx2trt,tensorrt",
)
arg_parser.add_argument("--model", type=str, help="Name of torchscript model file")
arg_parser.add_argument(
"--model", type=str, default="", help="Name of torchscript model file"
)
arg_parser.add_argument(
"--model_torch",
type=str,
@@ -458,7 +514,16 @@ def load_model(params):
parser = ConfigParser(args.config)
# Load YAML params
params = parser.read_config()
model, model_name, is_trt_engine = load_model(params)
model, model_name, is_trt_engine = load_ts_model(params)
model_torch, model_name_torch = load_torch_model(params)

# If neither model type was provided
if (model is None) and (model_torch is None):
raise ValueError(
"No valid models specified. Please provide a torchscript model file or model name "
+ "(among the following options vgg16|resnet50|efficientnet_b0|vit) "
+ "or provide a torch model file"
)

# Default device is set to 0. Configurable using yaml config file.
torch.cuda.set_device(params.get("runtime").get("device", 0))
@@ -489,7 +554,10 @@ def load_model(params):

if not is_trt_engine and (precision == "fp16" or precision == "half"):
# If model is TensorRT serialized engine then model.half will report failure
model = model.half()
if model is not None:
model = model.half()
if model_torch is not None:
model_torch = model_torch.half()

backends = params.get("backend")
# Run inference
@@ -502,6 +570,7 @@ def load_model(params):
truncate_long_and_double,
batch_size,
is_trt_engine,
model_torch,
)
else:
params = vars(args)
@@ -511,23 +580,27 @@ def load_model(params):
model_name_torch = params["model_torch"]
model_torch = None

# Load TorchScript model
# Load TorchScript model, if provided
if os.path.exists(model_name):
print("Loading user provided torchscript model: ", model_name)
model = torch.jit.load(model_name).cuda().eval()
elif model_name in BENCHMARK_MODELS:
print("Loading torchscript model from BENCHMARK_MODELS for: ", model_name)
model = BENCHMARK_MODELS[model_name]["model"].eval().cuda()
else:
raise ValueError(
"Invalid model name. Please provide a torchscript model file or model name (among the following options vgg16|resnet50|efficientnet_b0|vit)"
)

# Load PyTorch Model, if provided
if len(model_name_torch) > 0 and os.path.exists(model_name_torch):
print("Loading user provided torch model: ", model_name_torch)
model_torch = torch.load(model_name_torch).eval().cuda()

# If neither model type was provided
if (model is None) and (model_torch is None):
raise ValueError(
"No valid models specified. Please provide a torchscript model file or model name "
+ "(among the following options vgg16|resnet50|efficientnet_b0|vit) "
+ "or provide a torch model file"
)

backends = parse_backends(params["backends"])
truncate_long_and_double = params["truncate"]
batch_size = params["batch_size"]
5 changes: 4 additions & 1 deletion tools/perf/utils.py
@@ -5,7 +5,10 @@
import timm

BENCHMARK_MODELS = {
"vgg16": {"model": models.vgg16(pretrained=True), "path": ["script", "pytorch"]},
"vgg16": {
"model": models.vgg16(weights=models.VGG16_Weights.DEFAULT),
"path": ["script", "pytorch"],
},
"resnet50": {
"model": models.resnet50(weights=None),
"path": ["script", "pytorch"],
