Skip to content

Commit

Permalink
Add support for FX in all backends path
Browse files Browse the repository at this point in the history
- Update documentation
- Add new backend for only-TorchScript benchmarks
  • Loading branch information
gs-olive committed Dec 5, 2022
1 parent f1bf283 commit 360f6c4
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
4 changes: 2 additions & 2 deletions tools/perf/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ There are two sample configuration files added.

| Name | Supported Values | Description |
| ----------------- | ------------------------------------ | ------------------------------------------------------------ |
| backend | all, torch, torch_tensorrt, tensorrt, fx2trt | Supported backends for inference. |
| backend | all, torchscript, fx2trt, torch, torch_tensorrt, tensorrt | Supported backends for inference. "all" implies the last four methods in the list at left, and "torchscript" implies the last three (excludes fx path) |
| input | - | Input binding names. Expected to list shapes of each input bindings |
| model | - | Configure the model filename and name |
| model_torch | - | Name of torch model file and name (used for fx2trt) (optional) |
Expand Down Expand Up @@ -113,7 +113,7 @@ Note:

Here are the list of `CompileSpec` options that can be provided directly to compile the pytorch module

* `--backends` : Comma separated string of backends. Eg: torch, torch_tensorrt, tensorrt or fx2trt
* `--backends` : Comma separated string of backends. Eg: torch,torch_tensorrt,tensorrt,fx2trt
* `--model` : Name of the model file (Can be a torchscript module or a tensorrt engine (ending in `.plan` extension)). If the backend is `fx2trt`, the input should be a Pytorch module (instead of a torchscript module) and the options for model are (`vgg16` | `resnet50` | `efficientnet_b0`)
* `--model_torch` : Name of the PyTorch model file (optional, only necessary if fx2trt is a chosen backend)
* `--inputs` : List of input shapes & dtypes. Eg: (1, 3, 224, 224)@fp32 for Resnet or (1, 128)@int32;(1, 128)@int32 for BERT
Expand Down
34 changes: 28 additions & 6 deletions tools/perf/perf_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,13 @@ def run(
)
continue

if (model_torch is None) and (backend in ("all", "fx2trt")):
warnings.warn(
f"Requested backend {backend} without specifying a PyTorch Model, "
+ "skipping this backend"
)
continue

if backend == "all":
run_torch(model, input_tensors, params, precision, batch_size)
run_torch_tensorrt(
Expand All @@ -318,6 +325,27 @@ def run(
is_trt_engine,
batch_size,
)
run_fx2trt(model_torch, input_tensors, params, precision, batch_size)

elif backend == "torchscript":
run_torch(model, input_tensors, params, precision, batch_size)
run_torch_tensorrt(
model,
input_tensors,
params,
precision,
truncate_long_and_double,
batch_size,
)
run_tensorrt(
model,
input_tensors,
params,
precision,
truncate_long_and_double,
is_trt_engine,
batch_size,
)

elif backend == "torch":
run_torch(model, input_tensors, params, precision, batch_size)
Expand All @@ -333,12 +361,6 @@ def run(
)

elif backend == "fx2trt":
if model_torch is None:
warnings.warn(
"Requested backend fx2trt without specifying a PyTorch Model, "
+ "skipping this backend"
)
continue
run_fx2trt(model_torch, input_tensors, params, precision, batch_size)

elif backend == "tensorrt":
Expand Down

0 comments on commit 360f6c4

Please sign in to comment.