@@ -38,21 +38,40 @@ CenterPointTRT::CenterPointTRT(
38
38
encoder_trt_ptr_ = std::make_unique<VoxelEncoderTRT>(config_, verbose_);
39
39
encoder_trt_ptr_->init (
40
40
encoder_param.onnx_path (), encoder_param.engine_path (), encoder_param.trt_precision ());
41
+
42
+ #if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500
43
+ encoder_trt_ptr_->context_ ->setInputShape (
44
+ encoder_trt_ptr_->engine_ ->getIOTensorName (0 ),
45
+ nvinfer1::Dims3 (
46
+ config_.max_voxel_size_ , config_.max_point_in_voxel_size_ , config_.encoder_in_feature_size_ ));
47
+ #else
48
+ // Deprecated since 8.5
41
49
encoder_trt_ptr_->context_ ->setBindingDimensions (
42
50
0 ,
43
51
nvinfer1::Dims3 (
44
52
config_.max_voxel_size_ , config_.max_point_in_voxel_size_ , config_.encoder_in_feature_size_ ));
53
+ #endif
45
54
46
55
// head
47
56
std::vector<std::size_t > out_channel_sizes = {
48
57
config_.class_size_ , config_.head_out_offset_size_ , config_.head_out_z_size_ ,
49
58
config_.head_out_dim_size_ , config_.head_out_rot_size_ , config_.head_out_vel_size_ };
50
59
head_trt_ptr_ = std::make_unique<HeadTRT>(out_channel_sizes, config_, verbose_);
51
60
head_trt_ptr_->init (head_param.onnx_path (), head_param.engine_path (), head_param.trt_precision ());
61
+
62
+ #if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500
63
+ head_trt_ptr_->context_ ->setInputShape (
64
+ head_trt_ptr_->engine_ ->getIOTensorName (0 ),
65
+ nvinfer1::Dims4 (
66
+ config_.batch_size_ , config_.encoder_out_feature_size_ , config_.grid_size_y_ ,
67
+ config_.grid_size_x_ ));
68
+ #else
69
+ // Deprecated since 8.5
52
70
head_trt_ptr_->context_ ->setBindingDimensions (
53
71
0 , nvinfer1::Dims4 (
54
72
config_.batch_size_ , config_.encoder_out_feature_size_ , config_.grid_size_y_ ,
55
73
config_.grid_size_x_ ));
74
+ #endif
56
75
57
76
initPtr ();
58
77
@@ -166,8 +185,13 @@ void CenterPointTRT::inference()
166
185
}
167
186
168
187
// pillar encoder network
188
+ #if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500
189
+ encoder_trt_ptr_->context_ ->enqueueV3 (stream_);
190
+ #else
191
+ // Deprecated since 8.5
169
192
std::vector<void *> encoder_buffers{encoder_in_features_d_.get (), pillar_features_d_.get ()};
170
193
encoder_trt_ptr_->context_ ->enqueueV2 (encoder_buffers.data (), stream_, nullptr );
194
+ #endif
171
195
172
196
// scatter
173
197
CHECK_CUDA_ERROR (scatterFeatures_launch (
@@ -176,11 +200,16 @@ void CenterPointTRT::inference()
176
200
spatial_features_d_.get (), stream_));
177
201
178
202
// head network
203
+ #if (NV_TENSORRT_MAJOR * 10000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 80500
204
+ encoder_trt_ptr_->context_ ->enqueueV3 (stream_);
205
+ #else
206
+ // Deprecated since 8.5
179
207
std::vector<void *> head_buffers = {spatial_features_d_.get (), head_out_heatmap_d_.get (),
180
208
head_out_offset_d_.get (), head_out_z_d_.get (),
181
209
head_out_dim_d_.get (), head_out_rot_d_.get (),
182
210
head_out_vel_d_.get ()};
183
211
head_trt_ptr_->context_ ->enqueueV2 (head_buffers.data (), stream_, nullptr );
212
+ #endif
184
213
}
185
214
186
215
void CenterPointTRT::postProcess (std::vector<Box3D> & det_boxes3d)
0 commit comments