@@ -38,8 +38,9 @@ CenterPointTRT::CenterPointTRT(
38
38
encoder_trt_ptr_ = std::make_unique<VoxelEncoderTRT>(config_, verbose_);
39
39
encoder_trt_ptr_->init (
40
40
encoder_param.onnx_path (), encoder_param.engine_path (), encoder_param.trt_precision ());
41
- encoder_trt_ptr_->context_ ->setBindingDimensions (
42
- 0 ,
41
+ std::string name_tensor_encoder_in = encoder_trt_ptr_->engine_ ->getIOTensorName (0 );
42
+ encoder_trt_ptr_->context_ ->setInputShape (
43
+ name_tensor_encoder_in.c_str (),
43
44
nvinfer1::Dims3 (
44
45
config_.max_voxel_size_ , config_.max_point_in_voxel_size_ , config_.encoder_in_feature_size_ ));
45
46
@@ -49,10 +50,11 @@ CenterPointTRT::CenterPointTRT(
49
50
config_.head_out_dim_size_ , config_.head_out_rot_size_ , config_.head_out_vel_size_ };
50
51
head_trt_ptr_ = std::make_unique<HeadTRT>(out_channel_sizes, config_, verbose_);
51
52
head_trt_ptr_->init (head_param.onnx_path (), head_param.engine_path (), head_param.trt_precision ());
52
- head_trt_ptr_->context_ ->setBindingDimensions (
53
- 0 , nvinfer1::Dims4 (
54
- config_.batch_size_ , config_.encoder_out_feature_size_ , config_.grid_size_y_ ,
55
- config_.grid_size_x_ ));
53
+ std::string name_tensor_head_in = head_trt_ptr_->engine_ ->getIOTensorName (0 );
54
+ head_trt_ptr_->context_ ->setInputShape (
55
+ name_tensor_head_in.c_str (), nvinfer1::Dims4 (
56
+ config_.batch_size_ , config_.encoder_out_feature_size_ ,
57
+ config_.grid_size_y_ , config_.grid_size_x_ ));
56
58
57
59
initPtr ();
58
60
@@ -166,8 +168,7 @@ void CenterPointTRT::inference()
166
168
}
167
169
168
170
// pillar encoder network
169
- std::vector<void *> encoder_buffers{encoder_in_features_d_.get (), pillar_features_d_.get ()};
170
- encoder_trt_ptr_->context_ ->enqueueV2 (encoder_buffers.data (), stream_, nullptr );
171
+ encoder_trt_ptr_->context_ ->enqueueV3 (stream_);
171
172
172
173
// scatter
173
174
CHECK_CUDA_ERROR (scatterFeatures_launch (
@@ -176,11 +177,7 @@ void CenterPointTRT::inference()
176
177
spatial_features_d_.get (), stream_));
177
178
178
179
// head network
179
- std::vector<void *> head_buffers = {spatial_features_d_.get (), head_out_heatmap_d_.get (),
180
- head_out_offset_d_.get (), head_out_z_d_.get (),
181
- head_out_dim_d_.get (), head_out_rot_d_.get (),
182
- head_out_vel_d_.get ()};
183
- head_trt_ptr_->context_ ->enqueueV2 (head_buffers.data (), stream_, nullptr );
180
+ head_trt_ptr_->context_ ->enqueueV3 (stream_);
184
181
}
185
182
186
183
void CenterPointTRT::postProcess (std::vector<Box3D> & det_boxes3d)
0 commit comments