diff --git a/perception/autoware_tensorrt_common/include/autoware/tensorrt_common/tensorrt_common.hpp b/perception/autoware_tensorrt_common/include/autoware/tensorrt_common/tensorrt_common.hpp index 9355732962e23..d35efbfd0ab4a 100644 --- a/perception/autoware_tensorrt_common/include/autoware/tensorrt_common/tensorrt_common.hpp +++ b/perception/autoware_tensorrt_common/include/autoware/tensorrt_common/tensorrt_common.hpp @@ -46,6 +46,10 @@ using ProfileDimsPtr = std::unique_ptr>; using TensorsVec = std::vector>; using TensorsMap = std::unordered_map>; +constexpr int TRT_MAJOR_IDX = 24; +constexpr int TRT_MINOR_IDX = 25; +constexpr int TRT_PATCH_IDX = 26; + /** * @class TrtCommon * @brief TensorRT common library. @@ -317,6 +321,13 @@ class TrtCommon // NOLINT */ bool buildEngineFromOnnx(); + /** + * @brief Validate TensorRT engine. + * + * @return Whether TensorRT version used for building engine is compatible. + */ + bool validateEngine(); + /** * @brief Load TensorRT engine. * diff --git a/perception/autoware_tensorrt_common/src/tensorrt_common.cpp b/perception/autoware_tensorrt_common/src/tensorrt_common.cpp index ba422277416ab..1cd9c58667992 100644 --- a/perception/autoware_tensorrt_common/src/tensorrt_common.cpp +++ b/perception/autoware_tensorrt_common/src/tensorrt_common.cpp @@ -115,7 +115,15 @@ bool TrtCommon::setup(ProfileDimsPtr profile_dims, NetworkIOPtr network_io) // Load engine file if it exists if (fs::exists(trt_config_->engine_path)) { logger_->log(nvinfer1::ILogger::Severity::kINFO, "Loading engine"); - if (!loadEngine()) { + if (!validateEngine()) { + logger_->log( + nvinfer1::ILogger::Severity::kWARNING, + "Engine validation failed for loaded engine from file. Rebuilding engine"); + // Rebuild engine if version mismatch occurred + if (!build_engine_with_log()) { + return false; + } + } else if (!loadEngine()) { return false; } logger_->log(nvinfer1::ILogger::Severity::kINFO, "Network validation"); @@ -543,6 +551,33 @@ bool TrtCommon::buildEngineFromOnnx() return true; } +bool TrtCommon::validateEngine() +{ +#if (NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSOR_PATCH >= 8600 + std::ifstream engine_file(trt_config_->engine_path); + std::stringstream engine_buffer; + engine_buffer << engine_file.rdbuf(); + std::string engine_str = engine_buffer.str(); + + auto const blob = reinterpret_cast(engine_str.data()); + logger_->log( + nvinfer1::ILogger::Severity::kINFO, "Plan was created with TensorRT %d.%d.%d", + static_cast(blob[TRT_MAJOR_IDX]), static_cast(blob[TRT_MINOR_IDX]), + static_cast(blob[TRT_PATCH_IDX])); + auto plan_ver = static_cast(blob[TRT_MAJOR_IDX]) * 1000 + + static_cast(blob[TRT_MINOR_IDX]) * 100 + + static_cast(blob[TRT_PATCH_IDX]); + if (plan_ver != (NV_TENSORRT_MAJOR * 1000) + (NV_TENSORRT_MINOR * 100) + NV_TENSORRT_PATCH) { + logger_->log( + nvinfer1::ILogger::Severity::kWARNING, + "Plan was created with a different version of TensorRT! Current version: %d.%d.%d", + NV_TENSORRT_MAJOR, NV_TENSORRT_MINOR, NV_TENSORRT_PATCH); + return false; + } +#endif + return true; +} + bool TrtCommon::loadEngine() { std::ifstream engine_file(trt_config_->engine_path);