Skip to content

Commit ebdb9dd

Browse files
author
M. Fatih Cırıt
committed
fix(tensorrt): update tensorrt code of lidar_centerpoint
1 parent d756050 commit ebdb9dd

File tree

2 files changed

+11
-14
lines changed

2 files changed

+11
-14
lines changed

perception/lidar_centerpoint/include/lidar_centerpoint/network/tensorrt_wrapper.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ class TensorRTWrapper
6464
const std::string & onnx_path, const std::string & engine_path, const std::string & precision);
6565

6666
unique_ptr<nvinfer1::IExecutionContext> context_ = nullptr;
67+
unique_ptr<nvinfer1::ICudaEngine> engine_ = nullptr;
6768

6869
protected:
6970
virtual bool setProfile(
@@ -86,7 +87,6 @@ class TensorRTWrapper
8687

8788
unique_ptr<nvinfer1::IRuntime> runtime_ = nullptr;
8889
unique_ptr<nvinfer1::IHostMemory> plan_ = nullptr;
89-
unique_ptr<nvinfer1::ICudaEngine> engine_ = nullptr;
9090
};
9191

9292
} // namespace centerpoint

perception/lidar_centerpoint/lib/centerpoint_trt.cpp

+10-13
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ CenterPointTRT::CenterPointTRT(
3838
encoder_trt_ptr_ = std::make_unique<VoxelEncoderTRT>(config_, verbose_);
3939
encoder_trt_ptr_->init(
4040
encoder_param.onnx_path(), encoder_param.engine_path(), encoder_param.trt_precision());
41-
encoder_trt_ptr_->context_->setBindingDimensions(
42-
0,
41+
std::string name_tensor_encoder_in = encoder_trt_ptr_->engine_->getIOTensorName(0);
42+
encoder_trt_ptr_->context_->setInputShape(
43+
name_tensor_encoder_in.c_str(),
4344
nvinfer1::Dims3(
4445
config_.max_voxel_size_, config_.max_point_in_voxel_size_, config_.encoder_in_feature_size_));
4546

@@ -49,10 +50,11 @@ CenterPointTRT::CenterPointTRT(
4950
config_.head_out_dim_size_, config_.head_out_rot_size_, config_.head_out_vel_size_};
5051
head_trt_ptr_ = std::make_unique<HeadTRT>(out_channel_sizes, config_, verbose_);
5152
head_trt_ptr_->init(head_param.onnx_path(), head_param.engine_path(), head_param.trt_precision());
52-
head_trt_ptr_->context_->setBindingDimensions(
53-
0, nvinfer1::Dims4(
54-
config_.batch_size_, config_.encoder_out_feature_size_, config_.grid_size_y_,
55-
config_.grid_size_x_));
53+
std::string name_tensor_head_in = head_trt_ptr_->engine_->getIOTensorName(0);
54+
head_trt_ptr_->context_->setInputShape(
55+
name_tensor_head_in.c_str(), nvinfer1::Dims4(
56+
config_.batch_size_, config_.encoder_out_feature_size_,
57+
config_.grid_size_y_, config_.grid_size_x_));
5658

5759
initPtr();
5860

@@ -166,8 +168,7 @@ void CenterPointTRT::inference()
166168
}
167169

168170
// pillar encoder network
169-
std::vector<void *> encoder_buffers{encoder_in_features_d_.get(), pillar_features_d_.get()};
170-
encoder_trt_ptr_->context_->enqueueV2(encoder_buffers.data(), stream_, nullptr);
171+
encoder_trt_ptr_->context_->enqueueV3(stream_);
171172

172173
// scatter
173174
CHECK_CUDA_ERROR(scatterFeatures_launch(
@@ -176,11 +177,7 @@ void CenterPointTRT::inference()
176177
spatial_features_d_.get(), stream_));
177178

178179
// head network
179-
std::vector<void *> head_buffers = {spatial_features_d_.get(), head_out_heatmap_d_.get(),
180-
head_out_offset_d_.get(), head_out_z_d_.get(),
181-
head_out_dim_d_.get(), head_out_rot_d_.get(),
182-
head_out_vel_d_.get()};
183-
head_trt_ptr_->context_->enqueueV2(head_buffers.data(), stream_, nullptr);
180+
head_trt_ptr_->context_->enqueueV3(stream_);
184181
}
185182

186183
void CenterPointTRT::postProcess(std::vector<Box3D> & det_boxes3d)

0 commit comments

Comments
 (0)