diff --git a/perception/lidar_centerpoint/include/lidar_centerpoint/network/tensorrt_wrapper.hpp b/perception/lidar_centerpoint/include/lidar_centerpoint/network/tensorrt_wrapper.hpp index 2cd05400bc9af..3ef8df7b864ca 100644 --- a/perception/lidar_centerpoint/include/lidar_centerpoint/network/tensorrt_wrapper.hpp +++ b/perception/lidar_centerpoint/include/lidar_centerpoint/network/tensorrt_wrapper.hpp @@ -64,6 +64,7 @@ class TensorRTWrapper const std::string & onnx_path, const std::string & engine_path, const std::string & precision); unique_ptr context_ = nullptr; + unique_ptr engine_ = nullptr; protected: virtual bool setProfile( @@ -86,7 +87,6 @@ class TensorRTWrapper unique_ptr runtime_ = nullptr; unique_ptr plan_ = nullptr; - unique_ptr engine_ = nullptr; }; } // namespace centerpoint diff --git a/perception/lidar_centerpoint/lib/centerpoint_trt.cpp b/perception/lidar_centerpoint/lib/centerpoint_trt.cpp index a489be9e86538..c1b0903f01279 100644 --- a/perception/lidar_centerpoint/lib/centerpoint_trt.cpp +++ b/perception/lidar_centerpoint/lib/centerpoint_trt.cpp @@ -38,10 +38,19 @@ CenterPointTRT::CenterPointTRT( encoder_trt_ptr_ = std::make_unique(config_, verbose_); encoder_trt_ptr_->init( encoder_param.onnx_path(), encoder_param.engine_path(), encoder_param.trt_precision()); + +#if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500 + encoder_trt_ptr_->context_->setInputShape( + encoder_trt_ptr_->engine_->getIOTensorName(0), + nvinfer1::Dims3( + config_.max_voxel_size_, config_.max_point_in_voxel_size_, config_.encoder_in_feature_size_)); +#else + // Deprecated since 8.5 encoder_trt_ptr_->context_->setBindingDimensions( 0, nvinfer1::Dims3( config_.max_voxel_size_, config_.max_point_in_voxel_size_, config_.encoder_in_feature_size_)); +#endif // head std::vector out_channel_sizes = { @@ -49,10 +58,20 @@ CenterPointTRT::CenterPointTRT( config_.head_out_dim_size_, config_.head_out_rot_size_, config_.head_out_vel_size_}; head_trt_ptr_ = std::make_unique(out_channel_sizes, config_, verbose_); head_trt_ptr_->init(head_param.onnx_path(), head_param.engine_path(), head_param.trt_precision()); + +#if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500 + head_trt_ptr_->context_->setInputShape( + head_trt_ptr_->engine_->getIOTensorName(0), + nvinfer1::Dims4( + config_.batch_size_, config_.encoder_out_feature_size_, config_.grid_size_y_, + config_.grid_size_x_)); +#else + // Deprecated since 8.5 head_trt_ptr_->context_->setBindingDimensions( 0, nvinfer1::Dims4( config_.batch_size_, config_.encoder_out_feature_size_, config_.grid_size_y_, config_.grid_size_x_)); +#endif initPtr(); @@ -166,8 +185,13 @@ void CenterPointTRT::inference() } // pillar encoder network +#if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500 + encoder_trt_ptr_->context_->enqueueV3(stream_); +#else + // Deprecated since 8.5 std::vector encoder_buffers{encoder_in_features_d_.get(), pillar_features_d_.get()}; encoder_trt_ptr_->context_->enqueueV2(encoder_buffers.data(), stream_, nullptr); +#endif // scatter CHECK_CUDA_ERROR(scatterFeatures_launch( @@ -176,11 +200,16 @@ void CenterPointTRT::inference() spatial_features_d_.get(), stream_)); // head network +#if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500 + encoder_trt_ptr_->context_->enqueueV3(stream_); +#else + // Deprecated since 8.5 std::vector head_buffers = {spatial_features_d_.get(), head_out_heatmap_d_.get(), head_out_offset_d_.get(), head_out_z_d_.get(), head_out_dim_d_.get(), head_out_rot_d_.get(), head_out_vel_d_.get()}; head_trt_ptr_->context_->enqueueV2(head_buffers.data(), stream_, nullptr); +#endif } void CenterPointTRT::postProcess(std::vector & det_boxes3d)