From ea51c02342386cb431e089e0145447a536c0b3ac Mon Sep 17 00:00:00 2001 From: Yingge He Date: Thu, 18 Jul 2024 16:27:24 -0700 Subject: [PATCH 1/3] Refactor string input checks --- src/libtorch.cc | 60 ++++++++----------------------------------------- 1 file changed, 9 insertions(+), 51 deletions(-) diff --git a/src/libtorch.cc b/src/libtorch.cc index dbea502..3aca4da 100644 --- a/src/libtorch.cc +++ b/src/libtorch.cc @@ -1937,64 +1937,22 @@ SetStringInputTensor( } #endif // TRITON_ENABLE_GPU - // Parse content and assign to 'tensor'. Each string in 'content' - // is a 4-byte length followed by the string itself with no - // null-terminator. - while (content_byte_size >= sizeof(uint32_t)) { - if (element_idx >= request_element_cnt) { - RESPOND_AND_SET_NULL_IF_ERROR( - response, - TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - std::string( - "unexpected number of string elements " + - std::to_string(element_idx + 1) + " for inference input '" + - name + "', expecting " + std::to_string(request_element_cnt)) - .c_str())); - return cuda_copy; - } - - const uint32_t len = *(reinterpret_cast(content)); - content += sizeof(uint32_t); - content_byte_size -= sizeof(uint32_t); - - if (content_byte_size < len) { - RESPOND_AND_SET_NULL_IF_ERROR( - response, - TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INVALID_ARG, - std::string( - "incomplete string data for inference input '" + - std::string(name) + "', expecting string of length " + - std::to_string(len) + " but only " + - std::to_string(content_byte_size) + " bytes available") - .c_str())); - FillStringTensor(input_list, request_element_cnt - element_idx); - return cuda_copy; - } - + auto callback = [](torch::List* input_list, const char* content, + const uint32_t len) { // Set string value input_list->push_back(std::string(content, len)); + }; + auto fn = std::bind( + callback, input_list, std::placeholders::_2, std::placeholders::_3); - content += len; - content_byte_size -= len; - element_idx++; - } - - if ((*response != nullptr) && (element_idx != request_element_cnt)) { - RESPOND_AND_SET_NULL_IF_ERROR( - response, TRITONSERVER_ErrorNew( - TRITONSERVER_ERROR_INTERNAL, - std::string( - "expected " + std::to_string(request_element_cnt) + - " strings for inference input '" + name + "', got " + - std::to_string(element_idx)) - .c_str())); + err = ValidateStringBuffer( + content, content_byte_size, request_element_cnt, name, &element_idx, fn); + if (err != nullptr) { + RESPOND_AND_SET_NULL_IF_ERROR(response, err); if (element_idx < request_element_cnt) { FillStringTensor(input_list, request_element_cnt - element_idx); } } - return cuda_copy; } From e6c7cb95550bb2506c2fbf43903de6f10abe910b Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 19 Jul 2024 15:54:36 -0700 Subject: [PATCH 2/3] Improve readability --- src/libtorch.cc | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/src/libtorch.cc b/src/libtorch.cc index 3aca4da..7ce5c58 100644 --- a/src/libtorch.cc +++ b/src/libtorch.cc @@ -1937,21 +1937,17 @@ SetStringInputTensor( } #endif // TRITON_ENABLE_GPU - auto callback = [](torch::List* input_list, const char* content, - const uint32_t len) { - // Set string value - input_list->push_back(std::string(content, len)); - }; - auto fn = std::bind( - callback, input_list, std::placeholders::_2, std::placeholders::_3); - + std::vector> str_list; err = ValidateStringBuffer( - content, content_byte_size, request_element_cnt, name, &element_idx, fn); + content, content_byte_size, request_element_cnt, name, &str_list); + // Set string values. + for (const auto& [addr, len] : str_list) { + input_list->push_back(std::string(addr, len)); + } + if (err != nullptr) { RESPOND_AND_SET_NULL_IF_ERROR(response, err); - if (element_idx < request_element_cnt) { - FillStringTensor(input_list, request_element_cnt - element_idx); - } + FillStringTensor(input_list, request_element_cnt - element_idx); } return cuda_copy; } From 182f6ad52e9fcab5deedc962bc5bf8d0118bafd2 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 26 Jul 2024 11:02:23 -0700 Subject: [PATCH 3/3] Minor fix --- src/libtorch.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/libtorch.cc b/src/libtorch.cc index 7ce5c58..c4e964c 100644 --- a/src/libtorch.cc +++ b/src/libtorch.cc @@ -1911,7 +1911,6 @@ SetStringInputTensor( cudaStream_t stream, const char* host_policy_name) { bool cuda_copy = false; - size_t element_idx = 0; // For string data type, we always need to have the data on CPU so // that we can read string length and construct the string @@ -1926,7 +1925,7 @@ SetStringInputTensor( stream, &cuda_copy); if (err != nullptr) { RESPOND_AND_SET_NULL_IF_ERROR(response, err); - FillStringTensor(input_list, request_element_cnt - element_idx); + FillStringTensor(input_list, request_element_cnt); return cuda_copy; } @@ -1945,9 +1944,10 @@ SetStringInputTensor( input_list->push_back(std::string(addr, len)); } + size_t element_cnt = str_list.size(); if (err != nullptr) { RESPOND_AND_SET_NULL_IF_ERROR(response, err); - FillStringTensor(input_list, request_element_cnt - element_idx); + FillStringTensor(input_list, request_element_cnt - element_cnt); } return cuda_copy; }