Skip to content

Commit

Permalink
feat: Add BF16 tensor support via dlpack (#371)
Browse files Browse the repository at this point in the history
  • Loading branch information
rmccorm4 authored Jul 30, 2024
1 parent c8b188f commit 2b12abe
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 7 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1557,6 +1557,10 @@ input0 = pb_utils.Tensor.from_dlpack("INPUT0", pytorch_tensor)
This method only supports contiguous Tensors that are in C-order. If the tensor
is not C-order contiguous an exception will be raised.

For Python models with input or output tensors of type BFloat16 (BF16), the
`as_numpy()` method is not supported because NumPy has no native BF16 dtype;
the `from_dlpack` and `to_dlpack` methods must be used instead.

## `pb_utils.Tensor.is_cpu() -> bool`

This function can be used to check whether a tensor is placed in CPU or not.
Expand Down
17 changes: 16 additions & 1 deletion src/pb_stub_utils.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -168,6 +168,8 @@ triton_to_pybind_dtype(TRITONSERVER_DataType data_type)
dtype_numpy = py::dtype(py::format_descriptor<uint8_t>::format());
break;
case TRITONSERVER_TYPE_BF16:
// NOTE: Currently skipping this call via `if (BF16)` check, but may
// want to better handle this or set some default/invalid dtype.
throw PythonBackendException("TYPE_BF16 not currently supported.");
case TRITONSERVER_TYPE_INVALID:
throw PythonBackendException("Dtype is invalid.");
Expand Down Expand Up @@ -240,6 +242,10 @@ triton_to_dlpack_type(TRITONSERVER_DataType triton_dtype)
case TRITONSERVER_TYPE_BYTES:
throw PythonBackendException(
"TYPE_BYTES tensors cannot be converted to DLPack.");
case TRITONSERVER_TYPE_BF16:
dl_code = DLDataTypeCode::kDLBfloat;
dt_size = 16;
break;

default:
throw PythonBackendException(
Expand Down Expand Up @@ -301,6 +307,15 @@ dlpack_to_triton_type(const DLDataType& data_type)
}
}

if (data_type.code == DLDataTypeCode::kDLBfloat) {
if (data_type.bits != 16) {
throw PythonBackendException(
"Expected BF16 tensor to have 16 bits, but had: " +
std::to_string(data_type.bits));
}
return TRITONSERVER_TYPE_BF16;
}

return TRITONSERVER_TYPE_INVALID;
}
}}} // namespace triton::backend::python
24 changes: 18 additions & 6 deletions src/pb_tensor.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2021-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -152,7 +152,10 @@ PbTensor::PbTensor(
#ifdef TRITON_PB_STUB
if (memory_type_ == TRITONSERVER_MEMORY_CPU ||
memory_type_ == TRITONSERVER_MEMORY_CPU_PINNED) {
if (dtype != TRITONSERVER_TYPE_BYTES) {
if (dtype == TRITONSERVER_TYPE_BF16) {
// No native numpy representation for BF16. DLPack should be used instead.
numpy_array_ = py::none();
} else if (dtype != TRITONSERVER_TYPE_BYTES) {
py::object numpy_array =
py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
Expand Down Expand Up @@ -512,12 +515,18 @@ PbTensor::Name() const
// Returns a borrowed pointer to the tensor's NumPy view.
//
// Only valid for CPU tensors with a NumPy-representable dtype. For BF16
// tensors, numpy_array_ holds py::none() (NumPy has no native BF16 dtype),
// so callers must go through to_dlpack()/from_dlpack() instead.
//
// Throws PythonBackendException if the tensor lives in GPU memory or has
// dtype BF16.
const py::array*
PbTensor::AsNumpy() const
{
  // Guard clauses first: reject the two cases that have no NumPy view.
  if (!IsCPU()) {
    throw PythonBackendException(
        "Tensor is stored in GPU and cannot be converted to NumPy.");
  }

  if (dtype_ == TRITONSERVER_TYPE_BF16) {
    throw PythonBackendException(
        "Tensor dtype is BF16 and cannot be converted to NumPy. Use "
        "to_dlpack() and from_dlpack() instead.");
  }

  return &numpy_array_;
}
#endif // TRITON_PB_STUB

Expand Down Expand Up @@ -643,7 +652,10 @@ PbTensor::PbTensor(
#ifdef TRITON_PB_STUB
if (memory_type_ == TRITONSERVER_MEMORY_CPU ||
memory_type_ == TRITONSERVER_MEMORY_CPU_PINNED) {
if (dtype_ != TRITONSERVER_TYPE_BYTES) {
if (dtype_ == TRITONSERVER_TYPE_BF16) {
// No native numpy representation for BF16. DLPack should be used instead.
numpy_array_ = py::none();
} else if (dtype_ != TRITONSERVER_TYPE_BYTES) {
py::object numpy_array =
py::array(triton_to_pybind_dtype(dtype_), dims_, (void*)memory_ptr_);
numpy_array_ = numpy_array.attr("view")(triton_to_numpy_type(dtype_));
Expand Down

0 comments on commit 2b12abe

Please sign in to comment.