PyTorch to ONNX model conversion issue #8
Are you using this code? The error looks like it is caused by a mismatch between the network source code and the config.
I'm using the latest BEVDet code: https://github.com/HuangJunJie2017/BEVDet
I forked an earlier version of the official BEVDet repository, so parts of the source code may differ.
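One way to test the "source code and config are out of sync" hypothesis is to compare the parameter and buffer shapes stored in the checkpoint with those of the model your export script builds from the config. Below is a minimal pure-Python sketch; the helper `find_shape_mismatches` is hypothetical, and the commented lines show an assumed way to collect the shapes with `torch.load` and `model.state_dict()` (where `model` is whatever your export script constructs).

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def find_shape_mismatches(ckpt_shapes: Dict[str, Shape],
                          model_shapes: Dict[str, Shape]) -> List[str]:
    """Hypothetical helper: list tensors whose shapes differ between a
    checkpoint and a model, plus tensors present on only one side."""
    report = []
    for name in sorted(set(ckpt_shapes) | set(model_shapes)):
        a = ckpt_shapes.get(name)
        b = model_shapes.get(name)
        if b is None:
            report.append(f"{name}: only in checkpoint {a}")
        elif a is None:
            report.append(f"{name}: only in model {b}")
        elif a != b:
            report.append(f"{name}: checkpoint {a} vs model {b}")
    return report


# Assumed way to build the inputs (requires torch and the model from your config):
#   ckpt = torch.load("models/new-bevdet-lt-d-ft-nearest.pth", map_location="cpu")
#   sd = ckpt.get("state_dict", ckpt)  # mmcv-style checkpoints nest weights here
#   ckpt_shapes = {k: tuple(v.shape) for k, v in sd.items()}
#   model_shapes = {k: tuple(v.shape) for k, v in model.state_dict().items()}
#   for line in find_shape_mismatches(ckpt_shapes, model_shapes):
#       print(line)
```

This only catches mismatches between stored weights and the model definition; a tensor built with the wrong size at runtime (for example a dummy input in the export script) will not show up here.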
Hi, in your version of the code, mmdet3d/datasets/pipelines
Modify lines 115-116 of tools/export/export_onnx.py to:
rot = torch.rand([1, 6, 4, 4], dtype=torch.float, device=f'cuda:{args.gpu_id}')
tran = torch.rand([1, 6, 4], dtype=torch.float, device=f'cuda:{args.gpu_id}')
Changing only this part still produces an error.
I suggest making the PrepareImageInputs function in mmdet3d/datasets/pipelines/loading.py and mmdet3d/models/necks/view_transformer.py consistent, i.e. taking both from the same version of the code.
@yhwang-hub Have you solved this? I ran into the same problem: the ONNX model I exported differs from the author's and is identical to yours.
I used the .pth file you provided and the export_onnx tool to convert the model to ONNX, and I disabled pre_process following your config, but the following error occurred:
load checkpoint from local path: models/new-bevdet-lt-d-ft-nearest.pth
[[1, 2, 128, 128], [1, 1, 128, 128], [1, 3, 128, 128], [1, 2, 128, 128], [1, 2, 128, 128], [1, 10, 128, 128]]
['reg_0', 'height_0', 'dim_0', 'rot_0', 'vel_0', 'heatmap_0']
Traceback (most recent call last):
File "tools/export/export_onnx.py", line 142, in <module>
torch.onnx.export(
File "/opt/conda/lib/python3.8/site-packages/torch/onnx/__init__.py", line 316, in export
return utils.export(model, args, f, export_params, verbose, training,
File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 107, in export
_export(model, args, f, export_params, verbose, training, input_names, output_names,
File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 724, in _export
_model_to_graph(model, args, verbose, input_names,
File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 493, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 437, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/opt/conda/lib/python3.8/site-packages/torch/onnx/utils.py", line 388, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/opt/conda/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
outs.append(self.inner(*trace_inputs))
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/wyh/BEVDet/mmdet3d/models/detectors/trt_model.py", line 67, in forward
x = self.img_view_transformer.depth_net(x, mlp_input)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/wyh/BEVDet/mmdet3d/models/necks/view_transformer.py", line 694, in forward
mlp_input = self.bn(mlp_input.reshape(-1, mlp_input.shape[-1]))
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
result = self.forward(*input, **kwargs)
File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/batchnorm.py", line 168, in forward
return F.batch_norm(
File "/opt/conda/lib/python3.8/site-packages/torch/nn/functional.py", line 2282, in batch_norm
return torch.batch_norm(
RuntimeError: running_mean should contain 24 elements not 27
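In PyTorch's batch-norm check, the first number in "running_mean should contain X elements not Y" is the channel count of the input and the second is the size of the layer's running_mean buffer. Read that way, the reshaped mlp_input here has 24 features per row while self.bn in the depth net holds 27-element running statistics, which fits the suggestion that the exporting code and view_transformer.py come from different versions. The sketch below reproduces the same kind of failure with a standalone nn.BatchNorm1d; the helper name is hypothetical and the sizes are simply taken from the error message.

```python
import torch
import torch.nn as nn


def reproduce_bn_mismatch(bn_features: int = 27, input_features: int = 24) -> str:
    """Hypothetical helper: run a BatchNorm1d built for `bn_features` channels
    on a 2D input with `input_features` channels and return the error text."""
    bn = nn.BatchNorm1d(bn_features).eval()  # eval mode, as during ONNX export
    # Shape (N, C): e.g. 6 camera rows, each a flattened parameter vector.
    x = torch.zeros(6, input_features)
    try:
        bn(x)
    except RuntimeError as err:
        return str(err)
    return "no error"


print(reproduce_bn_mismatch())        # channel counts differ: raises inside batch_norm
print(reproduce_bn_mismatch(27, 27))  # matching sizes: runs normally
```

If the numbers match the traceback, the fix is on the input side (how mlp_input is assembled from the dummy rot/tran and intrinsics tensors), not in the checkpoint.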