Skip to content

Commit d8c5025

Browse files
author
M. Fatih Cırıt
committed
fix(tensorrt): update tensorrt code of lidar_centerpoint
Signed-off-by: M. Fatih Cırıt <mfc@leodrive.ai>
1 parent 9376697 commit d8c5025

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

perception/lidar_centerpoint/include/lidar_centerpoint/network/tensorrt_wrapper.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class TensorRTWrapper
6464
const std::string & onnx_path, const std::string & engine_path, const std::string & precision);
6565

6666
unique_ptr<nvinfer1::IExecutionContext> context_ = nullptr;
67+
unique_ptr<nvinfer1::ICudaEngine> engine_ = nullptr;
6768

6869
protected:
6970
virtual bool setProfile(
@@ -86,7 +87,6 @@ class TensorRTWrapper
8687

8788
unique_ptr<nvinfer1::IRuntime> runtime_ = nullptr;
8889
unique_ptr<nvinfer1::IHostMemory> plan_ = nullptr;
89-
unique_ptr<nvinfer1::ICudaEngine> engine_ = nullptr;
9090
};
9191

9292
} // namespace centerpoint

perception/lidar_centerpoint/lib/centerpoint_trt.cpp

+29
Original file line numberDiff line numberDiff line change
@@ -38,21 +38,40 @@ CenterPointTRT::CenterPointTRT(
3838
encoder_trt_ptr_ = std::make_unique<VoxelEncoderTRT>(config_, verbose_);
3939
encoder_trt_ptr_->init(
4040
encoder_param.onnx_path(), encoder_param.engine_path(), encoder_param.trt_precision());
41+
42+
#if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500
43+
encoder_trt_ptr_->context_->setInputShape(
44+
encoder_trt_ptr_->engine_->getIOTensorName(0),
45+
nvinfer1::Dims3(
46+
config_.max_voxel_size_, config_.max_point_in_voxel_size_, config_.encoder_in_feature_size_));
47+
#else
48+
// Deprecated since 8.5
4149
encoder_trt_ptr_->context_->setBindingDimensions(
4250
0,
4351
nvinfer1::Dims3(
4452
config_.max_voxel_size_, config_.max_point_in_voxel_size_, config_.encoder_in_feature_size_));
53+
#endif
4554

4655
// head
4756
std::vector<std::size_t> out_channel_sizes = {
4857
config_.class_size_, config_.head_out_offset_size_, config_.head_out_z_size_,
4958
config_.head_out_dim_size_, config_.head_out_rot_size_, config_.head_out_vel_size_};
5059
head_trt_ptr_ = std::make_unique<HeadTRT>(out_channel_sizes, config_, verbose_);
5160
head_trt_ptr_->init(head_param.onnx_path(), head_param.engine_path(), head_param.trt_precision());
61+
62+
#if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500
63+
head_trt_ptr_->context_->setInputShape(
64+
head_trt_ptr_->engine_->getIOTensorName(0),
65+
nvinfer1::Dims4(
66+
config_.batch_size_, config_.encoder_out_feature_size_, config_.grid_size_y_,
67+
config_.grid_size_x_));
68+
#else
69+
// Deprecated since 8.5
5270
head_trt_ptr_->context_->setBindingDimensions(
5371
0, nvinfer1::Dims4(
5472
config_.batch_size_, config_.encoder_out_feature_size_, config_.grid_size_y_,
5573
config_.grid_size_x_));
74+
#endif
5675

5776
initPtr();
5877

@@ -166,8 +185,13 @@ void CenterPointTRT::inference()
166185
}
167186

168187
// pillar encoder network
188+
#if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500
189+
encoder_trt_ptr_->context_->enqueueV3(stream_);
190+
#else
191+
// Deprecated since 8.5
169192
std::vector<void *> encoder_buffers{encoder_in_features_d_.get(), pillar_features_d_.get()};
170193
encoder_trt_ptr_->context_->enqueueV2(encoder_buffers.data(), stream_, nullptr);
194+
#endif
171195

172196
// scatter
173197
CHECK_CUDA_ERROR(scatterFeatures_launch(
@@ -176,11 +200,16 @@ void CenterPointTRT::inference()
176200
spatial_features_d_.get(), stream_));
177201

178202
// head network
203+
#if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500
204+
encoder_trt_ptr_->context_->enqueueV3(stream_);
205+
#else
206+
// Deprecated since 8.5
179207
std::vector<void *> head_buffers = {spatial_features_d_.get(), head_out_heatmap_d_.get(),
180208
head_out_offset_d_.get(), head_out_z_d_.get(),
181209
head_out_dim_d_.get(), head_out_rot_d_.get(),
182210
head_out_vel_d_.get()};
183211
head_trt_ptr_->context_->enqueueV2(head_buffers.data(), stream_, nullptr);
212+
#endif
184213
}
185214

186215
void CenterPointTRT::postProcess(std::vector<Box3D> & det_boxes3d)

0 commit comments

Comments
 (0)