Skip to content

Commit

Permalink
refactor: Refactor string input checks (#104)
Browse files (browse the repository at this point in the history)
Refactor string input tensor checks
  • Loading branch information
yinggeh authored Jul 31, 2024
1 parent 515466c commit 80296d0
Showing 1 changed file with 14 additions and 63 deletions.
77 changes: 14 additions & 63 deletions src/tensorflow.cc
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -567,7 +567,6 @@ SetStringInputTensor(
cudaStream_t stream, const char* host_policy_name)
{
bool cuda_copy = false;
size_t element_idx = 0;

// For string data type, we always need to have the data on CPU so
// that we can read string length and construct the string
Expand All @@ -582,8 +581,7 @@ SetStringInputTensor(
&contiguous_buffer, stream, &cuda_copy);
if (err != nullptr) {
RESPOND_AND_SET_NULL_IF_ERROR(response, err);
FillStringTensor(
tensor, tensor_offset + element_idx, request_element_cnt - element_idx);
FillStringTensor(tensor, tensor_offset, request_element_cnt);
free(contiguous_buffer);
return cuda_copy;
}
Expand All @@ -595,68 +593,21 @@ SetStringInputTensor(
}
#endif // TRITON_ENABLE_GPU

// Parse content and assign to 'tensor'. Each string in 'content'
// is a 4-byte length followed by the string itself with no
// null-terminator.
while (content_byte_size >= sizeof(uint32_t)) {
if (element_idx >= request_element_cnt) {
RESPOND_AND_SET_NULL_IF_ERROR(
response,
TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string(
"unexpected number of string elements " +
std::to_string(element_idx + 1) + " for inference input '" +
name + "', expecting " + std::to_string(request_element_cnt))
.c_str()));
FillStringTensor(
tensor, tensor_offset + element_idx,
request_element_cnt - element_idx);
free(contiguous_buffer);
return cuda_copy;
}

const uint32_t len = *(reinterpret_cast<const uint32_t*>(content));
content += sizeof(uint32_t);
content_byte_size -= sizeof(uint32_t);

if (content_byte_size < len) {
RESPOND_AND_SET_NULL_IF_ERROR(
response,
TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
std::string(
"incomplete string data for inference input '" +
std::string(name) + "', expecting string of length " +
std::to_string(len) + " but only " +
std::to_string(content_byte_size) + " bytes available")
.c_str()));
FillStringTensor(
tensor, tensor_offset + element_idx,
request_element_cnt - element_idx);
free(contiguous_buffer);
return cuda_copy;
}
std::vector<std::pair<const char*, const uint32_t>> str_list;
err = ValidateStringBuffer(
content, content_byte_size, request_element_cnt, name, &str_list);
// Set string values.
for (size_t element_idx = 0; element_idx < str_list.size(); ++element_idx) {
const auto& [addr, len] = str_list[element_idx];
TRITONTF_TensorSetString(tensor, tensor_offset + element_idx, addr, len);
}

TRITONTF_TensorSetString(tensor, tensor_offset + element_idx, content, len);
content += len;
content_byte_size -= len;
element_idx++;
}

if ((*response != nullptr) && (element_idx != request_element_cnt)) {
RESPOND_AND_SET_NULL_IF_ERROR(
response, TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
std::string(
"expected " + std::to_string(request_element_cnt) +
" strings for inference input '" + name + "', got " +
std::to_string(element_idx))
.c_str()));
size_t element_cnt = str_list.size();
if (err != nullptr) {
RESPOND_AND_SET_NULL_IF_ERROR(response, err);
FillStringTensor(
tensor, tensor_offset + element_idx, request_element_cnt - element_idx);
tensor, tensor_offset + element_cnt, request_element_cnt - element_cnt);
}

free(contiguous_buffer);
return cuda_copy;
}
Expand Down

0 comments on commit 80296d0

Please sign in to comment.