diff --git a/perception/autoware_tensorrt_common/src/tensorrt_common.cpp b/perception/autoware_tensorrt_common/src/tensorrt_common.cpp index ba422277416ab..c3297e9738553 100644 --- a/perception/autoware_tensorrt_common/src/tensorrt_common.cpp +++ b/perception/autoware_tensorrt_common/src/tensorrt_common.cpp @@ -36,6 +36,45 @@ namespace autoware namespace tensorrt_common { +class TrtErrorRecorder : public nvinfer1::IErrorRecorder +{ +public: + struct Error + { + nvinfer1::ErrorCode code; + std::string desc; + }; + const std::vector & getErrors() const noexcept { return errors_; } + +private: + RefCount incRefCount() noexcept override { return ++ref_count_; } + + RefCount decRefCount() noexcept override { return --ref_count_; } + + int32_t getNbErrors() const noexcept override { return static_cast(errors_.size()); } + + nvinfer1::ErrorCode getErrorCode(int32_t errorIdx) const noexcept override + { + return errors_[errorIdx].code; + } + + ErrorDesc getErrorDesc(int32_t errorIdx) const noexcept override + { + return errors_[errorIdx].desc.c_str(); + } + + bool hasOverflowed() const noexcept override { return false; } + bool reportError(nvinfer1::ErrorCode val, ErrorDesc desc) noexcept override + { + errors_.push_back({val, std::string(desc)}); + return false; + } + void clear() noexcept override { errors_.clear(); } + + std::vector errors_; + std::atomic ref_count_{0}; +}; + TrtCommon::TrtCommon( const TrtCommonConfig & trt_config, const std::shared_ptr & profiler, const std::vector & plugin_paths) @@ -543,6 +582,19 @@ bool TrtCommon::buildEngineFromOnnx() return true; } +auto setup_error_recorder(TrtUniquePtr & runtime) +{ + auto errorRecorder = std::make_unique(); + runtime->setErrorRecorder(errorRecorder.get()); + return std::make_pair(std::move(errorRecorder), runtime->getErrorRecorder()); +} + +void restore_default_recorder( + TrtUniquePtr & runtime, nvinfer1::IErrorRecorder * defRecorder) +{ + runtime->setErrorRecorder(defRecorder); +} + bool TrtCommon::loadEngine() { std::ifstream engine_file(trt_config_->engine_path); @@ -550,12 +602,34 @@ bool TrtCommon::loadEngine() engine_buffer << engine_file.rdbuf(); std::string engine_str = engine_buffer.str(); + auto [errorRecorder, defRecorder] = setup_error_recorder(runtime_); + engine_ = TrtUniquePtr(runtime_->deserializeCudaEngine( reinterpret_cast( // NOLINT engine_str.data()), engine_str.size())); + + restore_default_recorder(runtime_, defRecorder); + if (!engine_) { - logger_->log(nvinfer1::ILogger::Severity::kERROR, "Fail to create engine"); + for (const auto & error : errorRecorder->getErrors()) { + auto code = error.code; + auto desc = error.desc; + logger_->log( + nvinfer1::ILogger::Severity::kERROR, "Error code: %d, Description: %s", code, desc.c_str()); + if ( + code == nvinfer1::ErrorCode::kUNSPECIFIED_ERROR && + desc.find("Serialization assertion stdVersionRead == kSERIALIZATION_VERSION failed") != + std::string::npos) { + fs::remove(trt_config_->engine_path); + logger_->log( + nvinfer1::ILogger::Severity::kERROR, + "Engine file %s removed due to version mismatch. This file will be regenerated, next " + "time the node starts.", + trt_config_->engine_path.c_str()); + } + } + logger_->log(nvinfer1::ILogger::Severity::kERROR, "Failed to create engine"); return false; }