From 90f3f531c4ff4e278c290a481537d1291592fc26 Mon Sep 17 00:00:00 2001 From: kthui <18255193+kthui@users.noreply.github.com> Date: Wed, 24 Jan 2024 16:55:09 -0800 Subject: [PATCH] Unload all models must wait until all models complete loading or unloading --- .../model_repository_manager.cc | 53 ++++++++++++++++--- .../model_repository_manager.h | 3 ++ 2 files changed, 49 insertions(+), 7 deletions(-) diff --git a/src/model_repository_manager/model_repository_manager.cc b/src/model_repository_manager/model_repository_manager.cc index ae982a1ad..edf2bbb55 100644 --- a/src/model_repository_manager/model_repository_manager.cc +++ b/src/model_repository_manager/model_repository_manager.cc @@ -957,15 +957,44 @@ ModelRepositoryManager::PollModels( Status ModelRepositoryManager::UnloadAllModels() { - Status status; - for (const auto& name_info : infos_) { - Status unload_status = model_life_cycle_->AsyncUnload(name_info.first); - if (!unload_status.IsOk()) { - status = Status( - unload_status.ErrorCode(), - "Failed to gracefully unload models: " + unload_status.Message()); + std::unique_lock lock(mu_); + + // Get a set of all models, and make sure non of them are loading/unloading. + std::set all_models; + bool all_models_locked = false; + while (!all_models_locked) { + // Make a copy of the dependency graph. + std::unordered_map> global_map( + global_map_); + DependencyGraph dependency_graph(dependency_graph_, &global_map); + // Try to lock all models. + all_models = infos_.GetModelIdentifiers(); + std::shared_ptr retry_notify_cv; + auto conflict_model = + dependency_graph.LockNodes(all_models, &retry_notify_cv); + if (conflict_model) { + LOG_VERBOSE(2) << "Unload all models conflict '" << conflict_model->str() + << "'"; + // A model is loading/unloading. Wait for it to complete. + retry_notify_cv->wait(lock); + // There could be changes to other models as well. The dependency graph + // and models has to be reloaded. + continue; + } + all_models_locked = true; + } + + // Unload all models. + for (const auto& model_id : all_models) { + Status status = model_life_cycle_->AsyncUnload(model_id); + if (!status.IsOk()) { + // The server is shutting down. There is nothing to do about the error but + // to move forward. + LOG_ERROR << "Unload all models failed on '" << model_id << "'; " + << status.Message(); } } + return Status::Success; } @@ -2220,6 +2249,16 @@ ModelRepositoryManager::ModelInfoMap::operator=(const ModelInfoMap& rhs) return *this; } +std::set +ModelRepositoryManager::ModelInfoMap::GetModelIdentifiers() +{ + std::set model_ids; + for (const auto& pair : map_) { + model_ids.emplace(pair.first); + } + return model_ids; +} + void ModelRepositoryManager::ModelInfoMap::Writeback( const ModelInfoMap& updated_model_info, diff --git a/src/model_repository_manager/model_repository_manager.h b/src/model_repository_manager/model_repository_manager.h index 60793500f..c9a902024 100644 --- a/src/model_repository_manager/model_repository_manager.h +++ b/src/model_repository_manager/model_repository_manager.h @@ -146,6 +146,9 @@ class ModelRepositoryManager { return map_.find(key); } + // Return all keys on the map as a set. + std::set GetModelIdentifiers(); + // Write updated model info back to this object after model load/unload. void Writeback( const ModelInfoMap& updated_model_info,