From 91fc08b86f3625a9da2eb8099cd2794932a0fb26 Mon Sep 17 00:00:00 2001 From: grimoire Date: Wed, 7 Feb 2024 19:40:48 +0800 Subject: [PATCH] add allocator --- .isort.cfg | 2 +- test.sh | 30 --- .../tests => tests}/__init__.py | 0 tests/test_convert.py | 26 +++ torch2trt_dynamic/converters/Conv2d.py | 8 +- torch2trt_dynamic/converters/Linear.py | 2 +- torch2trt_dynamic/converters/__init__.py | 8 +- torch2trt_dynamic/converters/max_pool2d.py | 6 +- torch2trt_dynamic/test.py | 143 -------------- .../tests/torchvision/__init__.py | 0 .../tests/torchvision/classification.py | 177 ------------------ .../tests/torchvision/save_load.py | 24 --- .../tests/torchvision/segmentation.py | 45 ----- torch2trt_dynamic/torch2trt_dynamic.py | 39 ++-- torch2trt_dynamic/torch_allocator.py | 61 ++++++ 15 files changed, 123 insertions(+), 448 deletions(-) delete mode 100755 test.sh rename {torch2trt_dynamic/tests => tests}/__init__.py (100%) create mode 100644 tests/test_convert.py delete mode 100644 torch2trt_dynamic/test.py delete mode 100644 torch2trt_dynamic/tests/torchvision/__init__.py delete mode 100644 torch2trt_dynamic/tests/torchvision/classification.py delete mode 100644 torch2trt_dynamic/tests/torchvision/save_load.py delete mode 100644 torch2trt_dynamic/tests/torchvision/segmentation.py create mode 100644 torch2trt_dynamic/torch_allocator.py diff --git a/.isort.cfg b/.isort.cfg index 63a2064..4c267ec 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -1,2 +1,2 @@ [settings] -known_third_party = distutils,graphviz,numpy,packaging,setuptools,tensorrt,termcolor,torch,torchvision +known_third_party = distutils,graphviz,mmdeploy,numpy,packaging,setuptools,tensorrt,torch,torchvision diff --git a/test.sh b/test.sh deleted file mode 100755 index 3bf4a4d..0000000 --- a/test.sh +++ /dev/null @@ -1,30 +0,0 @@ -#!/bin/bash - -OUTPUT_FILE=$1 - -touch $OUTPUT_FILE - -echo "| Name | Data Type | Input Shapes | torch2trt kwargs | Max Error | Throughput (PyTorch) | Throughput (TensorRT) | Latency (PyTorch) | Latency (TensorRT) |" >> $OUTPUT_FILE -echo "|------|-----------|--------------|------------------|-----------|----------------------|-----------------------|-------------------|--------------------|" >> $OUTPUT_FILE - -python3 -m torch2trt.test -o $OUTPUT_FILE --name alexnet --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name squeezenet1_0 --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name squeezenet1_1 --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet18 --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet34 --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet50 --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet101 --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name resnet152 --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name densenet121 --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name densenet169 --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name densenet201 --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name densenet161 --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg11$ --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg13$ --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg16$ --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg19$ --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg11_bn --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg13_bn --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg16_bn --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name vgg19_bn --include=torch2trt.tests.torchvision.classification -python3 -m torch2trt.test -o $OUTPUT_FILE --name mobilenet_v2 --include=torch2trt.tests.torchvision.classification diff --git a/torch2trt_dynamic/tests/__init__.py b/tests/__init__.py similarity index 100% rename from torch2trt_dynamic/tests/__init__.py rename to tests/__init__.py diff --git a/tests/test_convert.py b/tests/test_convert.py new file mode 100644 index 0000000..be05162 --- /dev/null +++ b/tests/test_convert.py @@ -0,0 +1,26 @@ +import torch +from torch2trt_dynamic import TRTModule, module2trt +from torchvision.models import resnet18 + + +def test_convert(tmp_path): + model = resnet18().cuda().eval() + + trt_model = module2trt( + model, + args=[torch.rand(1, 3, 224, 224).cuda()], + ) + + model_path = tmp_path / 'tmp.pth' + torch.save(trt_model.state_dict(), model_path) + assert model_path.exists() + + trt_model = TRTModule() + trt_model.load_state_dict(torch.load(model_path)) + + x = torch.rand(1, 3, 224, 224).cuda() + with torch.no_grad(): + y = model(x) + y_trt = trt_model(x) + + torch.testing.assert_close(y, y_trt) diff --git a/torch2trt_dynamic/converters/Conv2d.py b/torch2trt_dynamic/converters/Conv2d.py index e598ebb..9440cdb 100644 --- a/torch2trt_dynamic/converters/Conv2d.py +++ b/torch2trt_dynamic/converters/Conv2d.py @@ -34,15 +34,15 @@ def convert_Conv2d(ctx): if module.bias is not None: bias = module.bias.detach().cpu().numpy() - layer = ctx.network.add_convolution( + layer = ctx.network.add_convolution_nd( input=input_trt, num_output_maps=module.out_channels, kernel_shape=kernel_size, kernel=kernel, bias=bias) - layer.stride = stride - layer.padding = padding - layer.dilation = dilation + layer.stride_nd = stride + layer.padding_nd = padding + layer.dilation_nd = dilation if module.groups is not None: layer.num_groups = module.groups diff --git a/torch2trt_dynamic/converters/Linear.py b/torch2trt_dynamic/converters/Linear.py index 2f4fd31..634ab69 100644 --- a/torch2trt_dynamic/converters/Linear.py +++ b/torch2trt_dynamic/converters/Linear.py @@ -21,7 +21,7 @@ def convert_Linear(ctx): if module.bias is not None: bias = module.bias.detach().cpu().numpy() - layer = ctx.network.add_convolution( + layer = ctx.network.add_convolution_nd( input=layer.get_output(0), num_output_maps=module.out_features, kernel_shape=(1, 1), diff --git a/torch2trt_dynamic/converters/__init__.py b/torch2trt_dynamic/converters/__init__.py index b0b1d94..bf6b0a6 100644 --- a/torch2trt_dynamic/converters/__init__.py +++ b/torch2trt_dynamic/converters/__init__.py @@ -7,10 +7,9 @@ # supported converters will override dummy converters +from . import add # noqa: F401 from .activation import (convert_elu, convert_leaky_relu, convert_selu, convert_softplus, convert_softsign) -from .add import (convert_add, test_add_basic, test_add_iadd, - test_add_radd_float, test_add_radd_int, test_add_torchadd) from .addcmul import convert_addcmul, test_addcmul from .arange import convert_arange from .argmax import convert_argmax @@ -123,11 +122,6 @@ 'convert_softsign', 'convert_softplus', ] -# add -__all__ += [ - 'convert_add', 'test_add_basic', 'test_add_iadd', 'test_add_radd_float', - 'test_add_radd_int', 'test_add_torchadd' -] # addcmul __all__ += ['convert_addcmul', 'test_addcmul'] # arange diff --git a/torch2trt_dynamic/converters/max_pool2d.py b/torch2trt_dynamic/converters/max_pool2d.py index b861dca..ee79124 100644 --- a/torch2trt_dynamic/converters/max_pool2d.py +++ b/torch2trt_dynamic/converters/max_pool2d.py @@ -32,11 +32,11 @@ def convert_max_pool2d(ctx): if not isinstance(padding, tuple): padding = (padding, ) * 2 - layer = ctx.network.add_pooling( + layer = ctx.network.add_pooling_nd( input=input_trt, type=trt.PoolingType.MAX, window_size=kernel_size) - layer.stride = stride - layer.padding = padding + layer.stride_nd = stride + layer.padding_nd = padding if ceil_mode: layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP diff --git a/torch2trt_dynamic/test.py b/torch2trt_dynamic/test.py deleted file mode 100644 index 4e2250c..0000000 --- a/torch2trt_dynamic/test.py +++ /dev/null @@ -1,143 +0,0 @@ -import argparse -import re -import runpy -import time - -import torch -from termcolor import colored -from torch2trt_dynamic import torch2trt_dynamic - -from .module_test import MODULE_TESTS - - -def run(self): - # create module - module = self.module_fn() - module = module.to(self.device) - module = module.type(self.dtype) - module = module.eval() - - # create inputs for conversion - inputs_conversion = () - for shape in self.input_shapes: - inputs_conversion += (torch.zeros(shape).to(self.device).type( - self.dtype), ) - - # convert module - module_trt = torch2trt_dynamic(module, inputs_conversion, - **self.torch2trt_kwargs) - - # create inputs for torch/trt.. copy of inputs to handle inplace ops - inputs = () - for shape in self.input_shapes: - inputs += (torch.randn(shape).to(self.device).type(self.dtype), ) - inputs_trt = tuple([tensor.clone() for tensor in inputs]) - - # test output against original - outputs = module(*inputs) - outputs_trt = module_trt(*inputs_trt) - - if not isinstance(outputs, tuple): - outputs = (outputs, ) - - # compute max error - max_error = 0 - for i in range(len(outputs)): - max_error_i = torch.max(torch.abs(outputs[i] - outputs_trt[i])) - if max_error_i > max_error: - max_error = max_error_i - - # benchmark pytorch throughput - torch.cuda.current_stream().synchronize() - t0 = time.time() - for i in range(50): - outputs = module(*inputs) - torch.cuda.current_stream().synchronize() - t1 = time.time() - - fps = 50.0 / (t1 - t0) - - # benchmark tensorrt throughput - torch.cuda.current_stream().synchronize() - t0 = time.time() - for i in range(50): - outputs = module_trt(*inputs) - torch.cuda.current_stream().synchronize() - t1 = time.time() - - fps_trt = 50.0 / (t1 - t0) - - # benchmark pytorch latency - torch.cuda.current_stream().synchronize() - t0 = time.time() - for i in range(50): - outputs = module(*inputs) - torch.cuda.current_stream().synchronize() - t1 = time.time() - - ms = 1000.0 * (t1 - t0) / 50.0 - - # benchmark tensorrt latency - torch.cuda.current_stream().synchronize() - t0 = time.time() - for i in range(50): - outputs = module_trt(*inputs) - torch.cuda.current_stream().synchronize() - t1 = time.time() - - ms_trt = 1000.0 * (t1 - t0) / 50.0 - - return max_error, fps, fps_trt, ms, ms_trt - - -if __name__ == '__main__': - - parser = argparse.ArgumentParser() - parser.add_argument( - '--output', - '-o', - help='Test output file path', - type=str, - default='torch2trt_test.md') - parser.add_argument( - '--name', - help='Regular expression to filter modules to test by name', - type=str, - default='.*') - parser.add_argument( - '--tolerance', - help='Maximum error to print warning for entry', - type=float, - default='-1') - parser.add_argument( - '--include', - help='Addition python file to include defining additional tests', - action='append', - default=[]) - args = parser.parse_args() - - for include in args.include: - runpy.run_module(include) - - for test in MODULE_TESTS: - - # filter by module name - name = test.module_name() - if not re.search(args.name, name): - continue - - # run test - max_error, fps, fps_trt, ms, ms_trt = run(test) - - # write entry - line = '| %s | %s | %s | %s | %.2E | %.3g | %.3g | %.3g | %.3g |' % ( - name, test.dtype.__repr__().split('.')[-1], str(test.input_shapes), - str(test.torch2trt_kwargs), max_error, fps, fps_trt, ms, ms_trt) - - if args.tolerance >= 0 and max_error > args.tolerance: - print(colored(line, 'yellow')) - else: - print(line) - - with open(args.output, 'a+') as f: - f.write(line + '\n') diff --git a/torch2trt_dynamic/tests/torchvision/__init__.py b/torch2trt_dynamic/tests/torchvision/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/torch2trt_dynamic/tests/torchvision/classification.py b/torch2trt_dynamic/tests/torchvision/classification.py deleted file mode 100644 index 4e7f4c9..0000000 --- a/torch2trt_dynamic/tests/torchvision/classification.py +++ /dev/null @@ -1,177 +0,0 @@ -import torch -import torchvision -from torch2trt_dynamic.module_test import add_module_test - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def alexnet(): - return torchvision.models.alexnet(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def squeezenet1_0(): - return torchvision.models.squeezenet1_0(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def squeezenet1_1(): - return torchvision.models.squeezenet1_1(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def resnet18(): - return torchvision.models.resnet18(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def resnet34(): - return torchvision.models.resnet34(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def resnet50(): - return torchvision.models.resnet50(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def resnet101(): - return torchvision.models.resnet101(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def resnet152(): - return torchvision.models.resnet152(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def densenet121(): - return torchvision.models.densenet121(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def densenet169(): - return torchvision.models.densenet169(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def densenet201(): - return torchvision.models.densenet201(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def densenet161(): - return torchvision.models.densenet161(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def vgg11(): - return torchvision.models.vgg11(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def vgg13(): - return torchvision.models.vgg13(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def vgg16(): - return torchvision.models.vgg16(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def vgg19(): - return torchvision.models.vgg19(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def vgg11_bn(): - return torchvision.models.vgg11_bn(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def vgg13_bn(): - return torchvision.models.vgg13_bn(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def vgg16_bn(): - return torchvision.models.vgg16_bn(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def vgg19_bn(): - return torchvision.models.vgg19_bn(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def mobilenet_v2(): - return torchvision.models.mobilenet_v2(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def shufflenet_v2_x0_5(): - return torchvision.models.shufflenet_v2_x0_5(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def shufflenet_v2_x1_0(): - return torchvision.models.shufflenet_v2_x1_0(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def shufflenet_v2_x1_5(): - return torchvision.models.shufflenet_v2_x1_5(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def shufflenet_v2_x2_0(): - return torchvision.models.shufflenet_v2_x2_0(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def mnasnet0_5(): - return torchvision.models.mnasnet0_5(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def mnasnet0_75(): - return torchvision.models.mnasnet0_75(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def mnasnet1_0(): - return torchvision.models.mnasnet1_0(pretrained=False) - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def mnasnet1_3(): - return torchvision.models.mnasnet1_3(pretrained=False) diff --git a/torch2trt_dynamic/tests/torchvision/save_load.py b/torch2trt_dynamic/tests/torchvision/save_load.py deleted file mode 100644 index b6e29a2..0000000 --- a/torch2trt_dynamic/tests/torchvision/save_load.py +++ /dev/null @@ -1,24 +0,0 @@ -import torch -from torch2trt_dynamic import TRTModule, torch2trt_dynamic - -from .segmentation import deeplabv3_resnet50 - -if __name__ == '__main__': - model = deeplabv3_resnet50().cuda().eval().half() - data = torch.randn((1, 3, 224, 224)).cuda().half() - - print('Running torch2trt...') - model_trt = torch2trt_dynamic( - model, [data], fp16_mode=True, max_workspace_size=1 << 25) - - print('Saving model...') - torch.save(model_trt.state_dict(), '.test_model.pth') - - print('Loading model...') - model_trt_2 = TRTModule() - model_trt_2.load_state_dict(torch.load('.test_model.pth')) - - assert (model_trt_2.engine is not None) - - print(torch.max(torch.abs(model_trt_2(data) - model(data)))) - print(torch.max(torch.abs(model_trt_2(data) - model_trt(data)))) diff --git a/torch2trt_dynamic/tests/torchvision/segmentation.py b/torch2trt_dynamic/tests/torchvision/segmentation.py deleted file mode 100644 index d7abe55..0000000 --- a/torch2trt_dynamic/tests/torchvision/segmentation.py +++ /dev/null @@ -1,45 +0,0 @@ -import torch -import torchvision -from torch2trt_dynamic.module_test import add_module_test - - -class ModelWrapper(torch.nn.Module): - - def __init__(self, model): - super(ModelWrapper, self).__init__() - self.model = model - - def forward(self, x): - return self.model(x)['out'] - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def deeplabv3_resnet50(): - bb = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False) - model = ModelWrapper(bb) - return model - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def deeplabv3_resnet101(): - bb = torchvision.models.segmentation.deeplabv3_resnet101(pretrained=False) - model = ModelWrapper(bb) - return model - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def fcn_resnet50(): - bb = torchvision.models.segmentation.fcn_resnet50(pretrained=False) - model = ModelWrapper(bb) - return model - - -@add_module_test( - torch.float16, torch.device('cuda'), [(1, 3, 224, 224)], fp16_mode=True) -def fcn_resnet101(): - bb = torchvision.models.segmentation.fcn_resnet101(pretrained=False) - model = ModelWrapper(bb) - return model diff --git a/torch2trt_dynamic/torch2trt_dynamic.py b/torch2trt_dynamic/torch2trt_dynamic.py index 7d5d626..d277787 100644 --- a/torch2trt_dynamic/torch2trt_dynamic.py +++ b/torch2trt_dynamic/torch2trt_dynamic.py @@ -9,6 +9,7 @@ from .calibration import (DEFAULT_CALIBRATION_ALGORITHM, DatasetCalibrator, SequenceDataset) from .shape_converter import ShapeConverter +from .torch_allocator import TorchAllocator # UTILITY FUNCTIONS @@ -478,12 +479,16 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, engine_bytes = state_dict[prefix + 'engine'] self.meta = state_dict[prefix + 'meta'] - with trt.Logger() as logger, trt.Runtime(logger) as runtime: - self.engine = runtime.deserialize_cuda_engine(engine_bytes) + logger = trt.Logger() + runtime = trt.Runtime(logger) + self.engine = runtime.deserialize_cuda_engine(engine_bytes) self.update_context() def update_context(self): self.context = self.engine.create_execution_context() + self.allocator = TorchAllocator() + if hasattr(self.context, 'temporary_allocator'): + self.context.temporary_allocator = self.allocator @property def input_names(self): @@ -516,7 +521,7 @@ def __check_range(name, shape, min_shape, max_shape): for name, tensor in inputs.items(): shape = tensor.shape - input_shapes = self.engine.get_profile_shape(0, name) + input_shapes = self.engine.get_tensor_profile_shape(name, 0) min_shape, opt_shape, max_shape = input_shapes assert len(shape) == len(opt_shape), ( f'input <{name}> dimension mismatch: ', @@ -536,17 +541,18 @@ def forward(self, *args, **kwargs): def __setup_inputs(inputs: Dict, bindings: Sequence): for input_name, tensor in inputs.items(): idx = self.engine.get_binding_index(input_name) - self.context.set_binding_shape(idx, tuple(tensor.shape)) + self.context.set_input_shape(input_name, tuple(tensor.shape)) bindings[idx] = tensor.data_ptr() def __setup_outputs(bindings: Sequence): outputs = dict() - for i, output_name in enumerate(self.output_names): + for output_name in self.output_names: idx = self.engine.get_binding_index(output_name) dtype = torch_dtype_from_trt( - self.engine.get_binding_dtype(idx)) - shape = tuple(self.context.get_binding_shape(idx)) - device = torch_device_from_trt(self.engine.get_location(idx)) + self.engine.get_tensor_dtype(output_name)) + shape = tuple(self.context.get_tensor_shape(output_name)) + device = torch_device_from_trt( + self.engine.get_tensor_location(output_name)) output = torch.empty(size=shape, dtype=dtype, device=device) outputs[output_name] = output bindings[idx] = output.data_ptr() @@ -564,11 +570,15 @@ def __get_return_value(outputs: Dict): f'{self.output_type}') inputs = self._bind_inputs(*args, **kwargs) - __setup_inputs(inputs, bindings) - outputs = __setup_outputs(bindings) + device = tuple(inputs.values())[0].device - self.context.execute_async_v2(bindings, - torch.cuda.current_stream().cuda_stream) + with torch.cuda.device(device): + __setup_inputs(inputs, bindings) + outputs = __setup_outputs(bindings) + + self.context.execute_async_v2( + bindings, + torch.cuda.current_stream().cuda_stream) return __get_return_value(outputs) def enable_profiling(self): @@ -693,7 +703,10 @@ def __make_builder_config(builder: trt.Builder, shape_ranges: Dict): builder = trt.Builder(logger) network, module_meta = build_network(builder, func, inputs, config=config) builder_config = __make_builder_config(builder, shape_ranges) - engine = builder.build_engine(network, builder_config) + host_mem = builder.build_serialized_network(network, builder_config) + + runtime = trt.Runtime(logger) + engine = runtime.deserialize_cuda_engine(host_mem) if engine is None: raise RuntimeError('Failed to build TensorRT engine') diff --git a/torch2trt_dynamic/torch_allocator.py b/torch2trt_dynamic/torch_allocator.py new file mode 100644 index 0000000..00602ae --- /dev/null +++ b/torch2trt_dynamic/torch_allocator.py @@ -0,0 +1,61 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import tensorrt as trt +import torch +from mmdeploy.utils import get_root_logger + + +class TorchAllocator(trt.IGpuAllocator): + """PyTorch Cuda Allocator Wrapper.""" + + def __init__(self, device_id: int = None) -> None: + super().__init__() + + self.device_id = device_id + self.mems = set() + self.caching_delete = torch._C._cuda_cudaCachingAllocator_raw_delete + + def __del__(self): + """destructor.""" + mems = self.mems.copy() + (self.deallocate(mem) for mem in mems) + + def allocate(self: trt.IGpuAllocator, size: int, alignment: int, + flags: int) -> int: + """allocate gpu memory. + + Args: + self (trt.IGpuAllocator): gpu allocator + size (int): memory size. + alignment (int): memory alignment. + flags (int): flags. + + Returns: + int: memory address. + """ + torch_stream = torch.cuda.current_stream(self.device_id) + logger = get_root_logger() + logger.debug(f'allocate {size} memory with TorchAllocator.') + assert alignment >= 0 + if alignment > 0: + size = size | (alignment - 1) + 1 + mem = torch.cuda.caching_allocator_alloc( + size, device=self.device_id, stream=torch_stream) + self.mems.add(mem) + return mem + + def deallocate(self: trt.IGpuAllocator, memory: int) -> bool: + """deallocate memory. + + Args: + self (trt.IGpuAllocator): gpu allocator + memory (int): memory address. + + Returns: + bool: deallocate success. + """ + if memory not in self.mems: + return False + + self.caching_delete(memory) + self.mems.discard(memory) + return True