Skip to content

Commit

Permalink
Unload all models must wait until all models complete loading or unloading
Browse files Browse the repository at this point in the history
  • Loading branch information
kthui committed Jan 26, 2024
1 parent edd64b5 commit 90f3f53
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 7 deletions.
53 changes: 46 additions & 7 deletions src/model_repository_manager/model_repository_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -957,15 +957,44 @@ ModelRepositoryManager::PollModels(
// Unload every model currently tracked in 'infos_'.
//
// The call first waits until no model is in the middle of a load/unload by
// repeatedly trying to lock all nodes on a copy of the dependency graph. Once
// every model has been locked, an asynchronous unload is issued for each one.
//
// 'mu_' is held for the whole call, except while blocked waiting on a
// conflicting model.
//
// Returns Status::Success unconditionally: a failure to unload an individual
// model is logged and otherwise ignored, because this is used while the server
// is shutting down and there is nothing to do about the error.
Status
ModelRepositoryManager::UnloadAllModels()
{
  std::unique_lock<std::mutex> lock(mu_);

  // Get a set of all models, and make sure none of them are loading/unloading.
  std::set<ModelIdentifier> all_models;
  while (true) {
    // Make a copy of the dependency graph. The lock attempt below operates on
    // this copy and is rebuilt from scratch on every retry.
    std::unordered_map<std::string, std::set<ModelIdentifier>> global_map(
        global_map_);
    DependencyGraph dependency_graph(dependency_graph_, &global_map);

    // Try to lock all models.
    all_models = infos_.GetModelIdentifiers();
    std::shared_ptr<std::condition_variable> retry_notify_cv;
    auto conflict_model =
        dependency_graph.LockNodes(all_models, &retry_notify_cv);
    if (!conflict_model) {
      // No conflict reported: every model is locked.
      break;
    }

    LOG_VERBOSE(2) << "Unload all models conflict '" << conflict_model->str()
                   << "'";
    // A model is loading/unloading. Wait for it to complete. The wait releases
    // 'mu_' while blocked and re-acquires it before returning; a spurious
    // wakeup is harmless because the loop simply retries the lock attempt.
    retry_notify_cv->wait(lock);
    // There could be changes to other models as well. The dependency graph
    // and models have to be reloaded, which the next iteration does.
  }

  // Unload all models.
  for (const auto& model_id : all_models) {
    Status status = model_life_cycle_->AsyncUnload(model_id);
    if (!status.IsOk()) {
      // The server is shutting down. There is nothing to do about the error but
      // to move forward.
      LOG_ERROR << "Unload all models failed on '" << model_id << "'; "
                << status.Message();
    }
  }

  return Status::Success;
}

Expand Down Expand Up @@ -2220,6 +2249,16 @@ ModelRepositoryManager::ModelInfoMap::operator=(const ModelInfoMap& rhs)
return *this;
}

std::set<ModelIdentifier>
ModelRepositoryManager::ModelInfoMap::GetModelIdentifiers()
{
std::set<ModelIdentifier> model_ids;
for (const auto& pair : map_) {
model_ids.emplace(pair.first);
}
return model_ids;
}

void
ModelRepositoryManager::ModelInfoMap::Writeback(
const ModelInfoMap& updated_model_info,
Expand Down
3 changes: 3 additions & 0 deletions src/model_repository_manager/model_repository_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ class ModelRepositoryManager {
return map_.find(key);
}

// Return a copy of every key (model identifier) currently in the map,
// collected into a std::set. The returned set is independent of the map, so
// later changes to the map are not reflected in it.
std::set<ModelIdentifier> GetModelIdentifiers();

// Write updated model info back to this object after model load/unload.
void Writeback(
const ModelInfoMap& updated_model_info,
Expand Down

0 comments on commit 90f3f53

Please sign in to comment.