Skip to content

Commit a0a1213

Browse files
Weiming Zhao (weimingzha0)
Weiming Zhao
authored and committed
[ODLA/TRT] Fix bug and reuse device memory
The scratchpad for int64->int32 conversion is fixed. For static shape and batch, we can reuse device memory.
1 parent 54e80f3 commit a0a1213

File tree

1 file changed

+49
-31
lines changed

1 file changed

+49
-31
lines changed

ODLA/platforms/tensorrt/odla_tensorrt.cc

+49-31
Original file line numberDiff line numberDiff line change
@@ -216,15 +216,15 @@ struct _odla_context {
216216
#endif
217217

218218
typedef struct {
219-
void* host_ptr;
220-
void* dev_ptr;
221-
size_t len;
219+
void* host_ptr = nullptr;
220+
void* dev_ptr = nullptr;
221+
size_t len = 0;
222222
odla_value_type vt;
223223
} OutputPtrInfo;
224224

225225
typedef struct {
226-
const void* host_ptr;
227-
void* dev_ptr;
226+
const void* host_ptr = nullptr;
227+
void* dev_ptr = nullptr;
228228
} InputPtrInfo;
229229
std::unordered_map<std::string, OutputPtrInfo> output_ptrs;
230230
std::unordered_map<std::string, InputPtrInfo> input_ptrs;
@@ -387,7 +387,7 @@ static nvinfer1::Dims SqueezeNVDims(const nvinfer1::Dims dims, int index) {
387387

388388
thread_local odla_computation g_comp;
389389
static std::vector<std::unique_ptr<_odla_computation>> g_comps;
390-
static std::vector<int> g_workspace;
390+
static std::vector<std::unique_ptr<int[]>> g_workspace;
391391

392392
static nvinfer1::DataType GetNVDataType(odla_element_type type) {
393393
switch (type) {
@@ -433,19 +433,20 @@ static odla_value_type ValidateValueType(const odla_value_type& type) {
433433
return type;
434434
}
435435

436-
static void* ValidateValuePtr(const odla_value_type& type, void* ptr) {
436+
static std::unique_ptr<int[]> ConvertData(const odla_value_type& type,
437+
const void* ptr) {
437438
if (type.element_type == ODLA_INT64) {
438-
int64_t* src = static_cast<int64_t*>(ptr);
439+
const int64_t* src = static_cast<const int64_t*>(ptr);
439440
auto num_elements = GetTotalElements(type.shape);
440-
auto workspace_size = g_workspace.size();
441-
assert(workspace_size + num_elements < MAX_INT64_CONVERTION_NUM);
442-
int* tmp = g_workspace.data() + workspace_size;
441+
auto buf = std::make_unique<int[]>(num_elements);
442+
int* tmp = buf.get();
443443
for (int i = 0; i < num_elements; ++i) {
444-
g_workspace.push_back(static_cast<int>(*src++));
444+
assert(*src < MAX_INT64_CONVERTION_NUM);
445+
tmp[i] = (static_cast<int>(*src++));
445446
}
446-
return tmp;
447+
return buf;
447448
}
448-
return ptr;
449+
return nullptr;
449450
}
450451

451452
template <typename T>
@@ -496,7 +497,6 @@ odla_status odla_CreateComputation(odla_computation* computation) {
496497
g_comps.push_back(std::make_unique<_odla_computation>());
497498
g_comp = g_comps.back().get();
498499
*computation = g_comp;
499-
g_workspace.reserve(MAX_INT64_CONVERTION_NUM);
500500
return ODLA_SUCCESS;
501501
}
502502

@@ -622,10 +622,15 @@ odla_status odla_GetArgFromComputationByIdx(const odla_computation computation,
622622

623623
odla_value odla_CreateConstant(odla_value_type type, const void* ptr,
624624
const odla_value_id id) {
625-
nvinfer1::Weights weight{
626-
.type = GetNVDataType(type.element_type),
627-
.values = ValidateValuePtr(type, const_cast<void*>(ptr)),
628-
.count = GetTotalElements(type.shape)};
625+
void* host_ptr = const_cast<void*>(ptr);
626+
auto buf = ConvertData(type, ptr);
627+
if (buf != nullptr) {
628+
host_ptr = buf.get();
629+
g_workspace.push_back(std::move(buf));
630+
}
631+
nvinfer1::Weights weight{.type = GetNVDataType(type.element_type),
632+
.values = host_ptr,
633+
.count = GetTotalElements(type.shape)};
629634
auto c = g_comp->network->addConstant(GetNVDims(type.shape), weight);
630635
odla_value v = CreateValue(c->getOutput(0), ValidateValueType(type), id);
631636
v->const_layer = c;
@@ -661,20 +666,27 @@ odla_status odla_GetOutputFromComputationByIdx(
661666

662667
odla_status odla_BindToArgument(odla_value value, const odla_void* data_ptr,
663668
odla_context context) {
664-
void* dev_ptr = nullptr;
665669
odla_value_shape real_shape = value->type.shape;
670+
bool dynamic_input_size = false;
666671
if ((g_comp && g_comp->is_dynamic_batch) || context->run_batch_size) {
667672
real_shape.dims[0] = context->run_batch_size;
673+
dynamic_input_size = true;
668674
}
669675
size_t bytes =
670676
GetTotalElements(real_shape) * GetElementSize(value->type.element_type);
671-
CHECK(cudaMalloc(&dev_ptr, bytes));
672-
void* validated_data_ptr =
673-
ValidateValuePtr(value->type, const_cast<void*>(data_ptr));
674-
CHECK(cudaMemcpy(dev_ptr, validated_data_ptr, bytes, cudaMemcpyHostToDevice));
675-
677+
void* validated_data_ptr = const_cast<void*>(data_ptr);
678+
auto buf = ConvertData(value->type, data_ptr);
679+
if (buf != nullptr) {
680+
validated_data_ptr = buf.get();
681+
}
682+
void* dev_ptr = context->input_ptrs[value->name].dev_ptr;
683+
if (dev_ptr == nullptr) {
684+
CHECK(cudaMalloc(&dev_ptr, bytes));
685+
}
676686
context->input_ptrs[value->name] = {.host_ptr = data_ptr, .dev_ptr = dev_ptr};
677687

688+
CHECK(cudaMemcpy(dev_ptr, validated_data_ptr, bytes, cudaMemcpyHostToDevice));
689+
678690
return ODLA_SUCCESS;
679691
}
680692

@@ -688,15 +700,17 @@ odla_status odla_BindToArgumentById(const odla_value_id value_id,
688700

689701
odla_status odla_BindToOutput(odla_value value, odla_void* data_ptr,
690702
odla_context context) {
691-
void* dst = nullptr;
692703
odla_value_shape real_shape = value->type.shape;
693704
if ((g_comp && g_comp->is_dynamic_batch) || context->run_batch_size) {
694705
real_shape.dims[0] = context->run_batch_size;
695706
}
696707
size_t bytes =
697708
GetTotalElements(real_shape) * GetElementSize(value->type.element_type);
698-
699-
CHECK(cudaMalloc(&dst, bytes));
709+
// TODO: convert to int64 for int64 outputs?
710+
void* dst = context->output_ptrs[value->name].dev_ptr;
711+
if (dst == nullptr) {
712+
CHECK(cudaMalloc(&dst, bytes));
713+
}
700714

701715
context->output_ptrs[value->name] = {
702716
.host_ptr = data_ptr, .dev_ptr = dst, .len = bytes, .vt = value->type};
@@ -959,7 +973,9 @@ odla_status odla_ExecuteComputation(odla_computation comp, odla_context context,
959973
cudaMemcpyDeviceToHost));
960974
}
961975
}
962-
976+
if (!comp->is_dynamic_batch) {
977+
return ODLA_SUCCESS;
978+
}
963979
// copy results and free temp buffers.
964980
for (auto& ptr : buffers) {
965981
CHECK(cudaFree(ptr));
@@ -1948,7 +1964,8 @@ odla_values odla_LSTM(odla_value input, odla_rnn_weight_format weight_format,
19481964
return g_comp->network->addConstant(GetNVDims(dim), weight)->getOutput(0);
19491965
};
19501966
nvinfer1::ITensor* init_hidden_t = getInitTensor(initial_h);
1951-
// LOG_VERBOSE("init_hidden dim:" + gen_str(init_hidden_t->getDimensions()));
1967+
// LOG_VERBOSE("init_hidden dim:" +
1968+
// gen_str(init_hidden_t->getDimensions()));
19521969
nvinfer1::ITensor* init_cell_t = getInitTensor(initial_c);
19531970
rnn_layer->setHiddenState(*init_hidden_t);
19541971
rnn_layer->setCellState(*init_cell_t);
@@ -2090,7 +2107,8 @@ odla_values odla_LSTM(odla_value input, odla_rnn_weight_format weight_format,
20902107
return g_comp->network->addConstant(GetNVDims(dim), weight)->getOutput(0);
20912108
};
20922109
nvinfer1::ITensor* init_hidden_t = getInitTensor(initial_h);
2093-
// LOG_VERBOSE("init_hidden dim:" + gen_str(init_hidden_t->getDimensions()));
2110+
// LOG_VERBOSE("init_hidden dim:" +
2111+
// gen_str(init_hidden_t->getDimensions()));
20942112
nvinfer1::ITensor* init_cell_t = getInitTensor(initial_c);
20952113
// LOG("init_cell dim:" << init_cell_t->getDimensions());
20962114

0 commit comments

Comments
 (0)