add allocator
grimoire committed Feb 7, 2024
1 parent 4159718 commit 91fc08b
Showing 15 changed files with 123 additions and 448 deletions.
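
The headline change, the allocator itself, sits in the part of the diff that is not loaded in this excerpt. Since .isort.cfg now lists mmdeploy as a third-party import (first file below), the new allocator plausibly follows mmdeploy's TorchAllocator pattern: a trt.IGpuAllocator subclass that serves TensorRT's device-memory requests from PyTorch's caching allocator, so the engine and PyTorch tensors draw from one memory pool. A minimal sketch under that assumption (class name and details are guesses, not the committed code):

import tensorrt as trt
import torch


class TorchAllocator(trt.IGpuAllocator):
    """Hypothetical sketch: route TensorRT allocations through PyTorch.

    Not the code from this commit; assumes the mmdeploy-style pattern.
    """

    def __init__(self, device_id: int = 0):
        super().__init__()
        self.device_id = device_id
        self.mems = set()

    def allocate(self, size: int, alignment: int, flags: int) -> int:
        # Borrow memory from PyTorch's caching allocator on the current stream.
        stream = torch.cuda.current_stream(self.device_id)
        ptr = torch.cuda.caching_allocator_alloc(size, self.device_id,
                                                 stream.cuda_stream)
        self.mems.add(ptr)
        return ptr

    def deallocate(self, memory: int) -> bool:
        # Only release pointers we handed out ourselves.
        if memory not in self.mems:
            return False
        torch.cuda.caching_allocator_delete(memory)
        self.mems.discard(memory)
        return True

Such an allocator would be attached with runtime.gpu_allocator = TorchAllocator() when deserializing, or builder.gpu_allocator at build time.
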
2 changes: 1 addition & 1 deletion .isort.cfg
@@ -1,2 +1,2 @@
 [settings]
-known_third_party = distutils,graphviz,numpy,packaging,setuptools,tensorrt,termcolor,torch,torchvision
+known_third_party = distutils,graphviz,mmdeploy,numpy,packaging,setuptools,tensorrt,torch,torchvision
30 changes: 0 additions & 30 deletions test.sh

This file was deleted.

File renamed without changes.
26 changes: 26 additions & 0 deletions tests/test_convert.py
@@ -0,0 +1,26 @@
+import torch
+from torch2trt_dynamic import TRTModule, module2trt
+from torchvision.models import resnet18
+
+
+def test_convert(tmp_path):
+    model = resnet18().cuda().eval()
+
+    trt_model = module2trt(
+        model,
+        args=[torch.rand(1, 3, 224, 224).cuda()],
+    )
+
+    model_path = tmp_path / 'tmp.pth'
+    torch.save(trt_model.state_dict(), model_path)
+    assert model_path.exists()
+
+    trt_model = TRTModule()
+    trt_model.load_state_dict(torch.load(model_path))
+
+    x = torch.rand(1, 3, 224, 224).cuda()
+    with torch.no_grad():
+        y = model(x)
+        y_trt = trt_model(x)
+
+    torch.testing.assert_close(y, y_trt)
8 changes: 4 additions & 4 deletions torch2trt_dynamic/converters/Conv2d.py
@@ -34,15 +34,15 @@ def convert_Conv2d(ctx):
     if module.bias is not None:
         bias = module.bias.detach().cpu().numpy()
 
-    layer = ctx.network.add_convolution(
+    layer = ctx.network.add_convolution_nd(
        input=input_trt,
        num_output_maps=module.out_channels,
        kernel_shape=kernel_size,
        kernel=kernel,
        bias=bias)
-    layer.stride = stride
-    layer.padding = padding
-    layer.dilation = dilation
+    layer.stride_nd = stride
+    layer.padding_nd = padding
+    layer.dilation_nd = dilation
 
     if module.groups is not None:
         layer.num_groups = module.groups
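
This Conv2d edit is the first of three identical migrations in this commit (Linear.py and max_pool2d.py follow the same pattern below): TensorRT deprecated its 2D-only builder methods and scalar layer attributes in favor of the N-dimensional variants, which take per-dimension tuples. The substitution is mechanical:

# 2D-only API (deprecated, later removed)    N-dimensional replacement
# network.add_convolution(...)           ->  network.add_convolution_nd(...)
# network.add_pooling(...)               ->  network.add_pooling_nd(...)
# layer.stride   = (2, 2)                ->  layer.stride_nd   = (2, 2)
# layer.padding  = (1, 1)                ->  layer.padding_nd  = (1, 1)
# layer.dilation = (1, 1)                ->  layer.dilation_nd = (1, 1)
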
2 changes: 1 addition & 1 deletion torch2trt_dynamic/converters/Linear.py
@@ -21,7 +21,7 @@ def convert_Linear(ctx):
     if module.bias is not None:
         bias = module.bias.detach().cpu().numpy()
 
-    layer = ctx.network.add_convolution(
+    layer = ctx.network.add_convolution_nd(
        input=layer.get_output(0),
        num_output_maps=module.out_features,
        kernel_shape=(1, 1),
8 changes: 1 addition & 7 deletions torch2trt_dynamic/converters/__init__.py
@@ -7,10 +7,9 @@
 
 # supported converters will override dummy converters
 
+from . import add  # noqa: F401
 from .activation import (convert_elu, convert_leaky_relu, convert_selu,
                          convert_softplus, convert_softsign)
-from .add import (convert_add, test_add_basic, test_add_iadd,
-                  test_add_radd_float, test_add_radd_int, test_add_torchadd)
 from .addcmul import convert_addcmul, test_addcmul
 from .arange import convert_arange
 from .argmax import convert_argmax
@@ -123,11 +122,6 @@
     'convert_softsign',
     'convert_softplus',
 ]
-# add
-__all__ += [
-    'convert_add', 'test_add_basic', 'test_add_iadd', 'test_add_radd_float',
-    'test_add_radd_int', 'test_add_torchadd'
-]
 # addcmul
 __all__ += ['convert_addcmul', 'test_addcmul']
 # arange
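
The bare from . import add  # noqa: F401 works because converters register themselves as a side effect of being imported; in upstream torch2trt this is done by the tensorrt_converter decorator, so neither the converter function nor the old inline test_* helpers need to be re-exported through __all__. A minimal sketch of that pattern, assuming torch2trt_dynamic keeps the upstream decorator (the import path and the simplified body are guesses):

import tensorrt as trt
from torch2trt_dynamic import tensorrt_converter  # import path is an assumption


@tensorrt_converter('torch.Tensor.__add__')
def convert_add(ctx):
    # ctx exposes the intercepted call: arguments, return value, and the
    # TensorRT network under construction.
    a, b = ctx.method_args[:2]
    output = ctx.method_return
    # Simplified: the real converter also wraps scalar operands as constants.
    layer = ctx.network.add_elementwise(a._trt, b._trt,
                                        trt.ElementWiseOperation.SUM)
    output._trt = layer.get_output(0)
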
6 changes: 3 additions & 3 deletions torch2trt_dynamic/converters/max_pool2d.py
@@ -32,11 +32,11 @@ def convert_max_pool2d(ctx):
     if not isinstance(padding, tuple):
         padding = (padding, ) * 2
 
-    layer = ctx.network.add_pooling(
+    layer = ctx.network.add_pooling_nd(
         input=input_trt, type=trt.PoolingType.MAX, window_size=kernel_size)
 
-    layer.stride = stride
-    layer.padding = padding
+    layer.stride_nd = stride
+    layer.padding_nd = padding
 
     if ceil_mode:
         layer.padding_mode = trt.PaddingMode.EXPLICIT_ROUND_UP
143 changes: 0 additions & 143 deletions torch2trt_dynamic/test.py

This file was deleted.

Empty file.