PyTorch to ONNX model conversion problem #8

Open
yhwang-hub opened this issue Aug 4, 2023 · 9 comments
Comments

@yhwang-hub

yhwang-hub commented Aug 4, 2023

When I use the pth file you provided and the export_onnx tool to convert to an ONNX model, with pre_process disabled as in your config, I get the following error:
load checkpoint from local path: models/new-bevdet-lt-d-ft-nearest.pth
[[1, 2, 128, 128], [1, 1, 128, 128], [1, 3, 128, 128], [1, 2, 128, 128], [1, 2, 128, 128], [1, 10, 128, 128]]
['reg_0', 'height_0', 'dim_0', 'rot_0', 'vel_0', 'heatmap_0']
Traceback (most recent call last):
  File "tools/export/export_onnx.py", line 142, in <module>
    torch.onnx.export(
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/__init__.py", line 316, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 107, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 724, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 493, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 437, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 388, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/wyh/BEVDet/mmdet3d/models/detectors/trt_model.py", line 67, in forward
    x = self.img_view_transformer.depth_net(x, mlp_input)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/wyh/BEVDet/mmdet3d/models/necks/view_transformer.py", line 694, in forward
    mlp_input = self.bn(mlp_input.reshape(-1, mlp_input.shape[-1]))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 168, in forward
    return F.batch_norm(
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 2282, in batch_norm
    return torch.batch_norm(
RuntimeError: running_mean should contain 24 elements not 27
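
For readers of the error above: as far as I understand the message, batch norm compares the feature count of its input (dimension 1) against the length of the layer's stored running_mean. Read that way, the mlp_input reaching self.bn has 24 features while the layer holds 27. The plain-Python sketch below only mimics that check. The function name check_num_features and the input shape (6, 24) are assumptions for illustration (6 rows would correspond to B=1, N=6), not code from BEVDet or PyTorch.

```python
def check_num_features(input_shape, running_mean_len):
    """Mimic the size check: dim 1 of the input must match the running stats."""
    num_features = input_shape[1]
    if num_features != running_mean_len:
        raise RuntimeError(
            f"running_mean should contain {num_features} elements "
            f"not {running_mean_len}"
        )


# Assumed sizes: a (6, 24) input fed to a layer whose running_mean has 27 entries.
try:
    check_num_features((6, 24), 27)
    message = None
except RuntimeError as err:
    message = str(err)

print(message)
```

If this reading is right, the code that builds mlp_input and the code that defines the BatchNorm layer disagree about the feature width, which would fit a mix of source versions.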

@LCH1238
Owner

LCH1238 commented Aug 6, 2023

Are you using this code? The error looks like it is caused by a mismatch in the network source code or the config.

@yhwang-hub
Author

I am using the latest BEVDet code: https://github.com/HuangJunJie2017/BEVDet

@LCH1238
Owner

LCH1238 commented Aug 7, 2023

I forked an earlier version of the official BEVDet repository, so parts of the source code may differ.

@zouzouwei

In your code, mmdet3d/datasets/pipelines/loading.py uses the official 2.1 code, but get_mlp in view_transformer.py uses the 2.0 code. If I switch both to 2.0, will that affect anything?

@zouzouwei

Using export_onnx.py from https://github.com/LCH1238/BEVDet/tree/export with the weights bevdet-lt-d-ft-nearest.pth from https://drive.google.com/drive/folders/1jSGT0PhKOmW3fibp6fvlJ7EY6mIBVv6i to export ONNX, the resulting img_stage_lt_d.onnx, shown in the image below, does not match the img_stage_lt_d.onnx that you, the author, provided. [image: 20230807-212252]

Modify lines 115-116 of tools/export/export_onnx.py:

        rot = torch.rand([1, 6, 4, 4], dtype=torch.float, device=f'cuda:{args.gpu_id}')
        tran = torch.rand([1, 6, 4], dtype=torch.float, device=f'cuda:{args.gpu_id}')

@yhwang-hub
Author

yhwang-hub commented Aug 10, 2023

Changing only this place raises an error:

Traceback (most recent call last):
  File "tools/export/export_onnx.py", line 146, in <module>
    torch.onnx.export(
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/__init__.py", line 316, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 107, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 724, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 493, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 437, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 388, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/mmdet3d-1.0.0rc4-py3.8-linux-x86_64.egg/mmdet3d/models/detectors/trt_model.py", line 65, in forward
    mlp_input = self.img_view_transformer.get_mlp_input(
  File "/opt/conda/lib/python3.8/site-packages/mmdet3d-1.0.0rc4-py3.8-linux-x86_64.egg/mmdet3d/models/necks/view_transformer.py", line 617, in get_mlp_input
    sensor2ego = torch.cat([rot, tran.reshape(B, N, 3, 1)],
RuntimeError: shape '[1, 6, 3, 1]' is invalid for input of size 24
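
The arithmetic behind this error can be checked without PyTorch. A dummy tran of shape [1, 6, 4] holds 24 elements, while tran.reshape(B, N, 3, 1) with B=1 and N=6 needs 18, so the reshape cannot succeed. The sketch below is plain Python, and the helper name reshape_is_valid is made up for illustration.

```python
import math


def reshape_is_valid(input_shape, target_shape):
    """A reshape only works when both shapes hold the same number of elements."""
    return math.prod(input_shape) == math.prod(target_shape)


tran_shape = (1, 6, 4)        # dummy `tran` from the suggested change
target_shape = (1, 6, 3, 1)   # tran.reshape(B, N, 3, 1) with B=1, N=6

input_size = math.prod(tran_shape)      # 24 elements available
target_size = math.prod(target_shape)   # 18 elements required

print(input_size, target_size)
print(reshape_is_valid(tran_shape, target_shape))   # False: 24 != 18
print(reshape_is_valid((1, 6, 3), target_shape))    # True: 18 == 18
```

So the 4-element tran only fits a get_mlp_input that expects it, and the one in this traceback still expects 3 elements per camera.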

@zouzouwei

I suggest bringing the PrepareImageInputs function in mmdet3d/datasets/pipelines/loading.py and mmdet3d/models/necks/view_transformer.py to the same version.

@cyn-liu

cyn-liu commented Aug 10, 2023

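@yhwang-hub Have you solved this problem? I ran into it too: the ONNX model I exported differs from the author's and matches the one you got.
I don't understand the replies above about changing the shapes of rot and trans and changing loading.py, because I don't think running export_onnx.py involves anything in loading.py?
It looks like this is a bug in the code? @LCH1238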
