@@ -216,15 +216,15 @@ struct _odla_context {
216
216
#endif
217
217
218
218
typedef struct {
219
- void * host_ptr;
220
- void * dev_ptr;
221
- size_t len;
219
+ void * host_ptr = nullptr ;
220
+ void * dev_ptr = nullptr ;
221
+ size_t len = 0 ;
222
222
odla_value_type vt;
223
223
} OutputPtrInfo;
224
224
225
225
typedef struct {
226
- const void * host_ptr;
227
- void * dev_ptr;
226
+ const void * host_ptr = nullptr ;
227
+ void * dev_ptr = nullptr ;
228
228
} InputPtrInfo;
229
229
std::unordered_map<std::string, OutputPtrInfo> output_ptrs;
230
230
std::unordered_map<std::string, InputPtrInfo> input_ptrs;
@@ -387,7 +387,7 @@ static nvinfer1::Dims SqueezeNVDims(const nvinfer1::Dims dims, int index) {
387
387
388
388
thread_local odla_computation g_comp;
389
389
static std::vector<std::unique_ptr<_odla_computation>> g_comps;
390
- static std::vector<int > g_workspace;
390
+ static std::vector<std::unique_ptr< int []> > g_workspace;
391
391
392
392
static nvinfer1::DataType GetNVDataType (odla_element_type type) {
393
393
switch (type) {
@@ -433,19 +433,20 @@ static odla_value_type ValidateValueType(const odla_value_type& type) {
433
433
return type;
434
434
}
435
435
436
- static void * ValidateValuePtr (const odla_value_type& type, void * ptr) {
436
+ static std::unique_ptr<int []> ConvertData (const odla_value_type& type,
437
+ const void * ptr) {
437
438
if (type.element_type == ODLA_INT64) {
438
- int64_t * src = static_cast <int64_t *>(ptr);
439
+ const int64_t * src = static_cast <const int64_t *>(ptr);
439
440
auto num_elements = GetTotalElements (type.shape );
440
- auto workspace_size = g_workspace.size ();
441
- assert (workspace_size + num_elements < MAX_INT64_CONVERTION_NUM);
442
- int * tmp = g_workspace.data () + workspace_size;
441
+ auto buf = std::make_unique<int []>(num_elements);
442
+ int * tmp = buf.get ();
443
443
for (int i = 0 ; i < num_elements; ++i) {
444
- g_workspace.push_back (static_cast <int >(*src++));
444
+ assert (*src < MAX_INT64_CONVERTION_NUM);
445
+ tmp[i] = (static_cast <int >(*src++));
445
446
}
446
- return tmp ;
447
+ return buf ;
447
448
}
448
- return ptr ;
449
+ return nullptr ;
449
450
}
450
451
451
452
template <typename T>
@@ -496,7 +497,6 @@ odla_status odla_CreateComputation(odla_computation* computation) {
496
497
g_comps.push_back (std::make_unique<_odla_computation>());
497
498
g_comp = g_comps.back ().get ();
498
499
*computation = g_comp;
499
- g_workspace.reserve (MAX_INT64_CONVERTION_NUM);
500
500
return ODLA_SUCCESS;
501
501
}
502
502
@@ -622,10 +622,15 @@ odla_status odla_GetArgFromComputationByIdx(const odla_computation computation,
622
622
623
623
odla_value odla_CreateConstant (odla_value_type type, const void * ptr,
624
624
const odla_value_id id) {
625
- nvinfer1::Weights weight{
626
- .type = GetNVDataType (type.element_type ),
627
- .values = ValidateValuePtr (type, const_cast <void *>(ptr)),
628
- .count = GetTotalElements (type.shape )};
625
+ void * host_ptr = const_cast <void *>(ptr);
626
+ auto buf = ConvertData (type, ptr);
627
+ if (buf != nullptr ) {
628
+ host_ptr = buf.get ();
629
+ g_workspace.push_back (std::move (buf));
630
+ }
631
+ nvinfer1::Weights weight{.type = GetNVDataType (type.element_type ),
632
+ .values = host_ptr,
633
+ .count = GetTotalElements (type.shape )};
629
634
auto c = g_comp->network ->addConstant (GetNVDims (type.shape ), weight);
630
635
odla_value v = CreateValue (c->getOutput (0 ), ValidateValueType (type), id);
631
636
v->const_layer = c;
@@ -661,20 +666,27 @@ odla_status odla_GetOutputFromComputationByIdx(
661
666
662
667
odla_status odla_BindToArgument (odla_value value, const odla_void* data_ptr,
663
668
odla_context context) {
664
- void * dev_ptr = nullptr ;
665
669
odla_value_shape real_shape = value->type .shape ;
670
+ bool dynamic_input_size = false ;
666
671
if ((g_comp && g_comp->is_dynamic_batch ) || context->run_batch_size ) {
667
672
real_shape.dims [0 ] = context->run_batch_size ;
673
+ dynamic_input_size = true ;
668
674
}
669
675
size_t bytes =
670
676
GetTotalElements (real_shape) * GetElementSize (value->type .element_type );
671
- CHECK (cudaMalloc (&dev_ptr, bytes));
672
- void * validated_data_ptr =
673
- ValidateValuePtr (value->type , const_cast <void *>(data_ptr));
674
- CHECK (cudaMemcpy (dev_ptr, validated_data_ptr, bytes, cudaMemcpyHostToDevice));
675
-
677
+ void * validated_data_ptr = const_cast <void *>(data_ptr);
678
+ auto buf = ConvertData (value->type , data_ptr);
679
+ if (buf != nullptr ) {
680
+ validated_data_ptr = buf.get ();
681
+ }
682
+ void * dev_ptr = context->input_ptrs [value->name ].dev_ptr ;
683
+ if (dev_ptr == nullptr ) {
684
+ CHECK (cudaMalloc (&dev_ptr, bytes));
685
+ }
676
686
context->input_ptrs [value->name ] = {.host_ptr = data_ptr, .dev_ptr = dev_ptr};
677
687
688
+ CHECK (cudaMemcpy (dev_ptr, validated_data_ptr, bytes, cudaMemcpyHostToDevice));
689
+
678
690
return ODLA_SUCCESS;
679
691
}
680
692
@@ -688,15 +700,17 @@ odla_status odla_BindToArgumentById(const odla_value_id value_id,
688
700
689
701
odla_status odla_BindToOutput (odla_value value, odla_void* data_ptr,
690
702
odla_context context) {
691
- void * dst = nullptr ;
692
703
odla_value_shape real_shape = value->type .shape ;
693
704
if ((g_comp && g_comp->is_dynamic_batch ) || context->run_batch_size ) {
694
705
real_shape.dims [0 ] = context->run_batch_size ;
695
706
}
696
707
size_t bytes =
697
708
GetTotalElements (real_shape) * GetElementSize (value->type .element_type );
698
-
699
- CHECK (cudaMalloc (&dst, bytes));
709
+ // TODO: convert to int64 for int64 outputs?
710
+ void * dst = context->output_ptrs [value->name ].dev_ptr ;
711
+ if (dst == nullptr ) {
712
+ CHECK (cudaMalloc (&dst, bytes));
713
+ }
700
714
701
715
context->output_ptrs [value->name ] = {
702
716
.host_ptr = data_ptr, .dev_ptr = dst, .len = bytes, .vt = value->type };
@@ -959,7 +973,9 @@ odla_status odla_ExecuteComputation(odla_computation comp, odla_context context,
959
973
cudaMemcpyDeviceToHost));
960
974
}
961
975
}
962
-
976
+ if (!comp->is_dynamic_batch ) {
977
+ return ODLA_SUCCESS;
978
+ }
963
979
// copy results and free temp buffers.
964
980
for (auto & ptr : buffers) {
965
981
CHECK (cudaFree (ptr));
@@ -1948,7 +1964,8 @@ odla_values odla_LSTM(odla_value input, odla_rnn_weight_format weight_format,
1948
1964
return g_comp->network ->addConstant (GetNVDims (dim), weight)->getOutput (0 );
1949
1965
};
1950
1966
nvinfer1::ITensor* init_hidden_t = getInitTensor (initial_h);
1951
- // LOG_VERBOSE("init_hidden dim:" + gen_str(init_hidden_t->getDimensions()));
1967
+ // LOG_VERBOSE("init_hidden dim:" +
1968
+ // gen_str(init_hidden_t->getDimensions()));
1952
1969
nvinfer1::ITensor* init_cell_t = getInitTensor (initial_c);
1953
1970
rnn_layer->setHiddenState (*init_hidden_t );
1954
1971
rnn_layer->setCellState (*init_cell_t );
@@ -2090,7 +2107,8 @@ odla_values odla_LSTM(odla_value input, odla_rnn_weight_format weight_format,
2090
2107
return g_comp->network ->addConstant (GetNVDims (dim), weight)->getOutput (0 );
2091
2108
};
2092
2109
nvinfer1::ITensor* init_hidden_t = getInitTensor (initial_h);
2093
- // LOG_VERBOSE("init_hidden dim:" + gen_str(init_hidden_t->getDimensions()));
2110
+ // LOG_VERBOSE("init_hidden dim:" +
2111
+ // gen_str(init_hidden_t->getDimensions()));
2094
2112
nvinfer1::ITensor* init_cell_t = getInitTensor (initial_c);
2095
2113
// LOG("init_cell dim:" << init_cell_t->getDimensions());
2096
2114
0 commit comments