-
Notifications
You must be signed in to change notification settings - Fork 75
/
Copy pathodla_tensorrt.cc
1871 lines (1664 loc) · 67.6 KB
/
odla_tensorrt.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
//===- odla_tensorrt.cc ---------------------------------------------------===//
//
// Copyright (C) 2019-2021 Alibaba Group Holding Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <NvInfer.h>
#include <NvInferPlugin.h>
#include <NvInferRuntime.h>
#include <ODLA/odla.h>
#include <bits/stdint-intn.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <time.h>

#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <ctime>
#include <fstream>
#include <iostream>
#include <memory>
#include <mutex>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

#include "plugins/initPlugin.h"
using namespace nvinfer1;
using namespace std;
#if !defined(ODLA_VERSION_NUMBER) || (ODLA_VERSION_NUMBER < 50)
#error This library requires minimum ODLA version 0.5
#endif
// Deleter that releases a TensorRT object through its destroy() method.
template <class T>
struct TrtDestroyer {
  void operator()(T* object) { object->destroy(); }
};
// Owning pointer for TensorRT objects: calls destroy() when it lets go of a
// non-null pointer.
template <class T>
using TrtUniquePtr = std::unique_ptr<T, TrtDestroyer<T>>;
// Reports a failed CUDA runtime call on stderr.
// Returns true when `e` is cudaSuccess and false otherwise, so callers can
// branch on the result of CHECK(...).
inline bool check(cudaError_t e, int line, const char* file_name) {
  if (e != cudaSuccess) {
    // Terminate the report with a newline so consecutive errors do not run
    // together on a single line.
    std::cerr << "CUDA runtime API error " << cudaGetErrorName(e) << " at line "
              << line << " in file " << file_name << "\n";
    return false;
  }
  return true;
}
// Reports a failed boolean check (e.g. a TensorRT call that returned false)
// on stderr. Returns `result` unchanged.
inline bool check(bool result, int line, const char* file_name) {
  if (!result) {
    // Newline-terminated so consecutive reports stay on separate lines.
    std::cerr << "Error at line " << line << " in file " << file_name << "\n";
    return false;
  }
  return true;
}
// Evaluates `call` and reports a failure together with the caller's line and
// file. Expands to an expression that is true on success.
#define CHECK(call) check(call, __LINE__, __FILE__)
namespace open_dla_tensorrt {
// TensorRT logger that prints only internal-error and error messages to
// stderr, prefixed with a numeric level (0 = internal error, 1 = error).
class Logger : public nvinfer1::ILogger {
 public:
  void log(ILogger::Severity severity, const char* msg) override {
    int log_level;
    switch (severity) {
      case ILogger::Severity::kINTERNAL_ERROR:
        log_level = 0;
        break;
      case ILogger::Severity::kERROR:
        log_level = 1;
        break;
      case ILogger::Severity::kWARNING:
        log_level = 2;
        break;
      case ILogger::Severity::kINFO:
        log_level = 3;
        break;
      case ILogger::Severity::kVERBOSE:
        log_level = 4;
        // Explicit break: without it kVERBOSE fell through to `default` and
        // was reported as level 5.
        break;
      default:
        log_level = 5;
        break;
    }
    if (log_level <= 1) {
      std::cerr << "[" << log_level << "]: " << msg << "\n";
    }
  }
};
}  // namespace open_dla_tensorrt
// Logger instance handed to the TensorRT builder, runtime and plugin setup.
static open_dla_tensorrt::Logger Logger;
// A value handed out through the ODLA API. It wraps a TensorRT tensor and,
// when it was created from a layer, that layer, together with the ODLA type
// and name assigned to it.
struct _odla_value {
  // Wraps an existing tensor and gives it the name `name`.
  _odla_value(nvinfer1::ITensor* tensor, const odla_value_type& type,
              const char* name)
      : tensor(tensor), type(type), name(name) {
    tensor->setName(name);
  }
  // Wraps the first output of `layer` and gives the layer the name `name`.
  _odla_value(nvinfer1::ILayer* layer, const odla_value_type& type,
              const char* name)
      : layer(layer), tensor(layer->getOutput(0)), type(type), name(name) {
    layer->setName(name);
  }
  // Empty placeholder; every member takes the default given below.
  _odla_value() {}
  operator nvinfer1::ITensor&() { return *tensor; }
  nvinfer1::ILayer* layer = nullptr;
  nvinfer1::ITensor* tensor = nullptr;
  // Set by odla_CreateConstant for values backed by a constant layer.
  nvinfer1::IConstantLayer* const_layer = nullptr;
  // Value-initialized so a default-constructed value never holds garbage.
  odla_value_type type{};
  // Not owned: points at the caller-provided id or a TensorRT binding name.
  const char* name = nullptr;
};
// Default TensorRT builder workspace limit in bytes. A build-time
// MAX_WORKSPACE_SIZE (in MiB) replaces the 1 GiB default; the
// ODLA_TRT_MAX_WS_MB environment variable can override it again at runtime.
#if defined(MAX_WORKSPACE_SIZE)
static constexpr size_t MAX_WORKSPACE_SIZE_BYTES =
    (size_t)MAX_WORKSPACE_SIZE * 1024 * 1024;
#else
static constexpr size_t MAX_WORKSPACE_SIZE_BYTES = size_t{1} << 30;
#endif
// Number of int elements reserved in g_workspace for the INT64 -> INT32
// conversions performed by ValidateValuePtr.
static const int MAX_INT64_CONVERTION_NUM = 65536;
// Process-wide default copied into each new _odla_computation's
// load_engine_mode; set via ODLA_LOAD_ENGINE_MODE.
static bool g_load_engine_mode = false;
// A TensorRT network under construction, or (in load-engine mode) a
// placeholder whose engine is loaded from a file later. It also holds the
// values created for it and its build options.
struct _odla_computation {
  nvinfer1::IBuilder* builder = nullptr;
  nvinfer1::INetworkDefinition* network = nullptr;
  std::unordered_map<std::string, odla_value> inputs;   // keyed by name
  std::unordered_map<std::string, odla_value> outputs;  // keyed by name
  std::vector<std::vector<float>> buffers;
  std::vector<std::unique_ptr<_odla_value>> vals;  // owns every created value
  std::vector<odla_value> input_vals;   // arguments, in creation order
  std::vector<odla_value> output_vals;  // outputs, in the order marked
  bool fp16_mode = false;
  bool is_dynamic_batch = false;
  int min_batch_size = 0;
  int max_batch_size = 0;
  int opt_batch_size = 0;
  // When true no builder or network is created.
  bool load_engine_mode = false;
  size_t max_workspace_size = MAX_WORKSPACE_SIZE_BYTES;

  _odla_computation() {
    load_engine_mode = g_load_engine_mode;
    // ODLA_TRT_MAX_WS_MB overrides the builder workspace limit, in MiB.
    if (const char* env_p = std::getenv("ODLA_TRT_MAX_WS_MB")) {
      char* end = nullptr;
      const long mb = std::strtol(env_p, &end, 10);
      if (end == env_p || mb < 0) {
        // std::stoi used to throw here (or a negative value wrapped to a huge
        // size); keep the default instead.
        std::cerr << "Ignoring invalid ODLA_TRT_MAX_WS_MB: " << env_p << "\n";
      } else if (mb > 0) {
        // Shift in size_t: `int << 20` overflows from 2048 MiB upwards.
        max_workspace_size = static_cast<size_t>(mb) << 20;
      }
    }
    if (!load_engine_mode) {
      builder = nvinfer1::createInferBuilder(Logger);
      if (builder == nullptr) {
        std::cerr << "Failed to create TensorRT builder\n";
        return;
      }
#if NV_TENSORRT_MAJOR < 7
      builder->setMaxWorkspaceSize(max_workspace_size);
      network = builder->createNetwork();
#else
      initODLAPlugin(&Logger, "");
      nvinfer1::NetworkDefinitionCreationFlags flags = 0;
      // ODLA_TRT_USE_EXPLICIT_BATCH enables an explicit-batch network unless
      // its value starts with '0'.
      if (const char* env_p = std::getenv("ODLA_TRT_USE_EXPLICIT_BATCH")) {
        if (*env_p != '0') {
          flags = 1U << static_cast<uint32_t>(
                      nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
        }
      }
      network = builder->createNetworkV2(flags);
#endif
    }
  }

  ~_odla_computation() {
    // Both pointers stay null in load-engine mode. The network is released
    // before the builder that created it.
    if (network != nullptr) {
      network->destroy();
    }
    if (builder != nullptr) {
      builder->destroy();
    }
    network = nullptr;
    builder = nullptr;
  }
};
struct _odla_context {
odla_computation comp = nullptr;
nvinfer1::ICudaEngine* engine = nullptr;
nvinfer1::IExecutionContext* ctx = nullptr;
void* temp_input_ptr = nullptr;
void* temp_output_ptr = nullptr;
#if NV_TENSORRT_MAJOR >= 7
nvinfer1::IBuilderConfig* builder_cfg = nullptr;
nvinfer1::IOptimizationProfile* builder_profile = nullptr;
#endif
typedef struct {
void* host_ptr;
void* dev_ptr;
size_t len;
odla_value_type vt;
} OutputPtrInfo;
typedef struct {
const void* host_ptr;
void* dev_ptr;
} InputPtrInfo;
std::unordered_map<std::string, OutputPtrInfo> output_ptrs;
std::unordered_map<std::string, InputPtrInfo> input_ptrs;
int run_batch_size = 0;
// CUdeviceptr cumemalloc_address;
_odla_context(odla_computation comp) : comp(comp) {
if (!comp->load_engine_mode) {
#if NV_TENSORRT_MAJOR < 7
engine = comp->builder->buildCudaEngine(*comp->network);
#else
builder_cfg = comp->builder->createBuilderConfig();
if (comp->is_dynamic_batch) {
builder_profile = comp->builder->createOptimizationProfile();
for (auto& input : comp->inputs) {
const char* input_name = input.first.c_str();
odla_value value = input.second;
int d1 = value->type.shape.dims[1];
int d2 = value->type.shape.dims[2];
int d3 = value->type.shape.dims[3];
builder_profile->setDimensions(
input_name, OptProfileSelector::kMIN,
Dims{4, {comp->min_batch_size, d1, d2, d3}});
builder_profile->setDimensions(
input_name, OptProfileSelector::kOPT,
Dims{4, {comp->opt_batch_size, d1, d2, d3}});
builder_profile->setDimensions(
input_name, OptProfileSelector::kMAX,
Dims{4, {comp->max_batch_size, d1, d2, d3}});
}
builder_cfg->addOptimizationProfile(builder_profile);
}
builder_cfg->setMaxWorkspaceSize(comp->max_workspace_size);
if (comp->fp16_mode) {
builder_cfg->setFlag(BuilderFlag::kFP16);
builder_cfg->setFlag(BuilderFlag::kSTRICT_TYPES);
}
engine =
comp->builder->buildEngineWithConfig(*comp->network, *builder_cfg);
#endif
ctx = engine->createExecutionContext();
}
}
~_odla_context() {
ctx->destroy();
if (!comp->load_engine_mode) {
engine->destroy();
#if NV_TENSORRT_MAJOR >= 7
builder_cfg->destroy();
#endif
}
comp = nullptr;
engine = nullptr;
ctx = nullptr;
}
};
// Pairs a context with its computation for the executable-level API.
// DLACore (-1 = no DLA) is passed to odla_LoadEngine when an engine file is
// loaded. `val` is a scratch value that odla_GetValFromExecutableByIdx
// overwrites on every call.
struct _odla_executable {
  odla_context context = nullptr;
  odla_computation computation = nullptr;
  int DLACore = -1;
  std::unique_ptr<_odla_value> val;
  _odla_executable(odla_context context, odla_computation computation)
      : context(context),
        computation(computation),
        val(std::make_unique<_odla_value>()) {}
  ~_odla_executable() = default;
};
// Maps a TensorRT data type to the matching ODLA element type. kFLOAT and
// any type without an ODLA counterpart map to ODLA_FLOAT32.
static odla_element_type GetODLAType(DataType type) {
  if (type == nvinfer1::DataType::kHALF) {
    return ODLA_FLOAT16;
  }
  if (type == nvinfer1::DataType::kINT32) {
    return ODLA_INT32;
  }
  if (type == nvinfer1::DataType::kINT8) {
    return ODLA_INT8;
  }
  if (type == nvinfer1::DataType::kBOOL) {
    return ODLA_BOOL;
  }
  return ODLA_FLOAT32;
}
// Number of elements described by `dims`: the product of its extents, or 1
// for a scalar (size 0). Negative extents are multiplied in as-is.
static int64_t GetTotalElements(const odla_value_shape& dims) {
  if (dims.size == 0) {
    return 1;
  }
  // std::accumulate accumulates in the type of its initial value. With a
  // plain `1` (int) the product was truncated to 32 bits after every step.
  // An int64_t seed keeps the whole product in 64 bits.
  return std::accumulate(dims.dims, dims.dims + dims.size, int64_t{1},
                         std::multiplies<int64_t>());
}
// Out-of-class definition of nvinfer1::Dims::MAX_DIMS. GetNVDims passes it to
// std::min, which takes its arguments by reference (an ODR-use), and that
// needs a definition when the member is a non-inline static constant.
// NOTE(review): this file uses C++17 features, so this is redundant (and
// deprecated) if the TensorRT headers declare MAX_DIMS constexpr — confirm.
const int nvinfer1::Dims::MAX_DIMS;
// Builds TensorRT dims from `n` unsigned extents. For a scalar (n == 0) the
// unused first extent is set to 0.
static nvinfer1::Dims GetNVDims(int n, const odla_uint32* dims) {
  assert(n <= nvinfer1::Dims::MAX_DIMS);
  nvinfer1::Dims result;
  result.nbDims = n;
  if (n == 0) {
    result.d[0] = 0;
    return result;
  }
  for (int axis = 0; axis < n; ++axis) {
    result.d[axis] = static_cast<int>(dims[axis]);
  }
  return result;
}
// Converts an ODLA shape to TensorRT dims of the same rank. For a scalar
// (size 0) the unused first extent is set to 0.
static nvinfer1::Dims GetNVDims(const odla_value_shape& dims) {
  assert(dims.size <= std::min(nvinfer1::Dims::MAX_DIMS, ODLA_MAX_DIMENSION));
  nvinfer1::Dims result;
  result.nbDims = dims.size;
  if (dims.size == 0) {
    result.d[0] = 0;
    return result;
  }
  for (int axis = 0; axis < dims.size; ++axis) {
    result.d[axis] = dims.dims[axis];
  }
  return result;
}
// True when both dims have the same rank and identical extents.
static bool SameNVDims(const nvinfer1::Dims& d1, const nvinfer1::Dims& d2) {
  bool same = (d1.nbDims == d2.nbDims);
  for (int axis = 0; same && axis < d1.nbDims; ++axis) {
    same = (d1.d[axis] == d2.d[axis]);
  }
  return same;
}
// Left-pads `dims` with 1-sized extents up to rank `dim_size`, e.g. (3,4)
// padded to rank 4 becomes (1,1,3,4). Shapes that already have at least
// `dim_size` dimensions are converted unchanged.
static nvinfer1::Dims BroadcastDims(const odla_value_shape& dims,
                                    size_t dim_size) {
  if (dims.size >= dim_size) {
    return GetNVDims(dims);
  }
  nvinfer1::Dims padded;
  padded.nbDims = dim_size;
  const int pad = dim_size - dims.size;
  for (int axis = 0; axis < static_cast<int>(dim_size); ++axis) {
    padded.d[axis] = axis < pad ? 1 : dims.dims[axis - pad];
  }
  return padded;
}
// The computation that graph-building calls (CreateValue, odla_CreateArgument,
// odla_CreateConstant, ...) add to. It is thread_local, so each thread has its
// own active computation.
thread_local odla_computation g_comp;
// Owns every computation created by odla_CreateComputation.
// NOTE(review): shared by all threads without locking.
static std::vector<std::unique_ptr<_odla_computation>> g_comps;
// Storage for INT64 -> INT32 conversions done by ValidateValuePtr, which
// hands out pointers into it. odla_CreateComputation reserves
// MAX_INT64_CONVERTION_NUM elements, so the vector must not reallocate while
// those pointers are still in use.
static std::vector<int> g_workspace;
// Maps an ODLA element type to the TensorRT type used to hold it. INT64 is
// held as INT32 because TensorRT does not support INT64. ODLA_FLOAT32 and any
// unrecognized type map to kFLOAT.
static nvinfer1::DataType GetNVDataType(odla_element_type type) {
  nvinfer1::DataType nv_type = nvinfer1::DataType::kFLOAT;
  switch (type) {
    case ODLA_FLOAT16:
      nv_type = nvinfer1::DataType::kHALF;
      break;
    case ODLA_INT32:
    case ODLA_INT64:
      nv_type = nvinfer1::DataType::kINT32;
      break;
    case ODLA_INT8:
      nv_type = nvinfer1::DataType::kINT8;
      break;
    case ODLA_BOOL:
      nv_type = nvinfer1::DataType::kBOOL;
      break;
    default:
      break;
  }
  return nv_type;
}
// Bytes per element as stored for TensorRT. INT64 takes 4 bytes because it is
// held as INT32. Returns 0 for an unrecognized type.
static unsigned GetElementSize(odla_element_type type) {
  unsigned bytes = 0;
  switch (type) {
    case ODLA_FLOAT32:
      bytes = sizeof(float);
      break;
    case ODLA_FLOAT16:
      bytes = sizeof(int16_t);
      break;
    case ODLA_INT32:
    case ODLA_INT64:
      bytes = sizeof(int32_t);
      break;
    case ODLA_INT8:
    case ODLA_BOOL:
      bytes = 1;
      break;
    default:
      break;
  }
  return bytes;
}
// Returns `type` with ODLA_INT64 replaced by ODLA_INT32 (TensorRT has no
// INT64). All other types are returned unchanged.
static odla_value_type ValidateValueType(const odla_value_type& type) {
  if (type.element_type != ODLA_INT64) {
    return type;
  }
  odla_value_type narrowed{};
  narrowed.element_type = ODLA_INT32;
  narrowed.shape = type.shape;
  return narrowed;
}
// Returns a pointer to `ptr`'s data in the layout TensorRT expects for
// `type`. Non-INT64 data is returned as-is. INT64 data is narrowed to INT32
// into storage that stays valid for the rest of the process, because TensorRT
// keeps referencing constant weights until the engine is built. That storage
// is g_workspace while it has reserved capacity left, and a dedicated heap
// block otherwise.
static void* ValidateValuePtr(const odla_value_type& type, void* ptr) {
  if (type.element_type != ODLA_INT64) {
    return ptr;
  }
  const int64_t* src = static_cast<const int64_t*>(ptr);
  const int64_t count = GetTotalElements(type.shape);
  const size_t num_elements = count > 0 ? static_cast<size_t>(count) : 0;
  const size_t used = g_workspace.size();
  int* dst = nullptr;
  if (used + num_elements <= g_workspace.capacity()) {
    // Growing within the reserved capacity never reallocates, so pointers
    // handed out earlier stay valid. Filling it exactly is allowed.
    g_workspace.resize(used + num_elements);
    dst = g_workspace.data() + used;
  } else {
    // Out of reserved capacity. Appending to g_workspace would reallocate and
    // leave earlier pointers dangling; the previous assert() guarding this
    // disappears under NDEBUG. Use a separate, stable block instead.
    static std::vector<std::unique_ptr<int[]>> overflow_blocks;
    overflow_blocks.push_back(std::make_unique<int[]>(num_elements));
    dst = overflow_blocks.back().get();
  }
  for (size_t i = 0; i < num_elements; ++i) {
    dst[i] = static_cast<int>(src[i]);
  }
  return dst;
}
// Wraps `t` (a TensorRT tensor or layer) in a new _odla_value named after
// `id`. Ownership goes to the active computation (g_comp->vals); the returned
// handle does not own the value.
template <typename T>
static odla_value CreateValue(T* t, const odla_value_type& type,
                              const odla_value_id id) {
  auto owned = std::make_unique<_odla_value>(
      t, type, reinterpret_cast<const char*>(id));
  odla_value handle = owned.get();
  g_comp->vals.push_back(std::move(owned));
  return handle;
}
extern "C" {
// Creates a computation, makes it the calling thread's active computation
// (g_comp) and returns it through `computation`. Also reserves the
// g_workspace storage that ValidateValuePtr converts INT64 data into.
odla_status odla_CreateComputation(odla_computation* computation) {
  g_comps.push_back(std::make_unique<_odla_computation>());
  odla_computation created = g_comps.back().get();
  g_comp = created;
  *computation = created;
  g_workspace.reserve(MAX_INT64_CONVERTION_NUM);
  return ODLA_SUCCESS;
}
// Makes `computation` the calling thread's active computation for the
// graph-building API.
odla_status odla_SetActiveComputation(odla_computation computation) {
  g_comp = computation;
  return ODLA_SUCCESS;
}
// Destroys a computation created by odla_CreateComputation. An unknown handle
// asserts in debug builds and returns ODLA_FAILURE otherwise.
odla_status odla_DestroyComputation(odla_computation comp) {
  for (auto it = g_comps.begin(); it != g_comps.end(); ++it) {
    if (it->get() != comp) {
      continue;
    }
    // Don't leave this thread's active computation pointing at freed memory;
    // code such as `g_comp && ...` checks then sees it as unset.
    // NOTE(review): other threads' thread_local g_comp cannot be reset here.
    if (g_comp == comp) {
      g_comp = nullptr;
    }
    g_comps.erase(it);  // destroys the owned computation
    return ODLA_SUCCESS;
  }
  assert(0);
  return ODLA_FAILURE;
}
// Sets a build option on `computation`.
// - ODLA_LOAD_ENGINE_MODE is process-wide: it is copied into computations
//   created afterwards, so `computation` may be null for it.
// - ODLA_BF16_MODE is accepted and ignored.
// - Every other item reads `value` and writes to `computation`.
// Returns ODLA_FAILURE for unsupported items or missing arguments.
odla_status odla_SetComputationItem(odla_computation computation,
                                    odla_item_type type,
                                    odla_item_value value) {
  // Fail (instead of dereferencing null) when a needed argument is missing.
  auto invalid_args = [&](bool needs_computation) {
    if (value != nullptr && (!needs_computation || computation != nullptr)) {
      return false;
    }
    std::cerr << "Invalid arguments for property type: " << type << std::endl;
    return true;
  };
  switch (type) {
    case ODLA_DYNAMIC_BATCH: {
      if (invalid_args(true)) {
        return ODLA_FAILURE;
      }
      const bool is_dynamic_batch = *(reinterpret_cast<bool*>(value));
      if (is_dynamic_batch && !computation->is_dynamic_batch) {
        computation->is_dynamic_batch = true;
        if (!computation->load_engine_mode) {
          // The network is recreated with the explicit-batch flag; anything
          // already added to the old network is discarded with it.
          computation->network->destroy();
          nvinfer1::NetworkDefinitionCreationFlags flags =
              1U << static_cast<uint32_t>(
                  nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
          computation->network = computation->builder->createNetworkV2(flags);
        }
      }
      break;
    }
    case ODLA_MIN_BATCH_SIZE:
      if (invalid_args(true)) {
        return ODLA_FAILURE;
      }
      computation->min_batch_size = *(reinterpret_cast<int*>(value));
      break;
    case ODLA_MAX_BATCH_SIZE:
      if (invalid_args(true)) {
        return ODLA_FAILURE;
      }
      computation->max_batch_size = *(reinterpret_cast<int*>(value));
      break;
    case ODLA_OPT_BATCH_SIZE:
      if (invalid_args(true)) {
        return ODLA_FAILURE;
      }
      computation->opt_batch_size = *(reinterpret_cast<int*>(value));
      break;
    case ODLA_FP16_MODE:
      if (invalid_args(true)) {
        return ODLA_FAILURE;
      }
      computation->fp16_mode = *(reinterpret_cast<bool*>(value));
      break;
    case ODLA_LOAD_ENGINE_MODE:
      if (invalid_args(false)) {
        return ODLA_FAILURE;
      }
      g_load_engine_mode = *(reinterpret_cast<bool*>(value));
      break;
    case ODLA_BF16_MODE:
      break;
    default:
      std::cerr << "Unsupported property type: " << type << std::endl;
      return ODLA_FAILURE;
  }
  return ODLA_SUCCESS;
}
// Sets a per-context option. Only ODLA_RUN_BATCH_SIZE is supported: `value`
// points at an int holding the batch size used by the following bind and
// execute calls.
odla_status odla_SetContextItem(odla_context context, odla_item_type type,
                                odla_item_value value) {
  if (type != ODLA_RUN_BATCH_SIZE) {
    std::cerr << "Unsupported property type: " << type << std::endl;
    return ODLA_FAILURE;
  }
  const int batch = *(reinterpret_cast<int*>(value));
  context->run_batch_size = batch;
  return ODLA_SUCCESS;
}
// Creates an execution context for the thread's active computation (g_comp).
// Unless that computation is in load-engine mode, this builds its engine.
odla_status odla_CreateContext(odla_context* context) {
  odla_context created = new _odla_context(g_comp);
  *context = created;
  return ODLA_SUCCESS;
}
// Deletes a context created by odla_CreateContext.
odla_status odla_DestroyContext(odla_context context) {
  delete context;
  return ODLA_SUCCESS;
}
// Declares a network input named by `id` and registers it as an argument of
// the active computation (g_comp).
// The TensorRT input uses GetNVDataType(type.element_type), but the ODLA type
// is recorded unchanged (INT64 included) so that binding code can tell INT64
// data apart. Returns nullptr if TensorRT rejects the input.
odla_value odla_CreateArgument(odla_value_type type, const odla_value_id id) {
  const char* name = reinterpret_cast<const char*>(id);
  nvinfer1::ITensor* input = g_comp->network->addInput(
      name, GetNVDataType(type.element_type), GetNVDims(type.shape));
  if (input == nullptr) {
    // Previously the null tensor was dereferenced inside CreateValue.
    std::cerr << "Failed to add network input: "
              << (name != nullptr ? name : "<unnamed>") << std::endl;
    return nullptr;
  }
  odla_value v = CreateValue(input, type, id);
  g_comp->inputs[name] = v;
  g_comp->input_vals.push_back(v);
  return v;
}
// Reports how many arguments have been created on `computation`.
odla_status odla_GetNumOfArgsFromComputation(const odla_computation computation,
                                             odla_uint32* num_args) {
  const auto count = computation->input_vals.size();
  *num_args = count;
  return ODLA_SUCCESS;
}
// Returns the `arg_idx`-th argument in creation order. An out-of-range index
// yields nullptr and ODLA_FAILURE.
odla_status odla_GetArgFromComputationByIdx(const odla_computation computation,
                                            const odla_uint32 arg_idx,
                                            odla_value* arg_value) {
  const auto& args = computation->input_vals;
  if (arg_idx < args.size()) {
    *arg_value = args[arg_idx];
    return ODLA_SUCCESS;
  }
  *arg_value = nullptr;
  return ODLA_FAILURE;
}
// Adds a constant layer holding the data at `ptr` (shape and element type
// from `type`) to the active computation. INT64 data is narrowed to INT32 by
// ValidateValuePtr, and the returned value is typed accordingly.
odla_value odla_CreateConstant(odla_value_type type, const void* ptr,
                               const odla_value_id id) {
  nvinfer1::Weights weights;
  weights.type = GetNVDataType(type.element_type);
  weights.values = ValidateValuePtr(type, const_cast<void*>(ptr));
  weights.count = GetTotalElements(type.shape);
  nvinfer1::IConstantLayer* layer =
      g_comp->network->addConstant(GetNVDims(type.shape), weights);
  odla_value v = CreateValue(layer->getOutput(0), ValidateValueType(type), id);
  v->const_layer = layer;
  return v;
}
// Marks `val` as a network output of the active computation. The output is
// registered under its producing layer's name when it has a layer, otherwise
// under its tensor's name, and the tensor is given that name too.
odla_status odla_SetValueAsOutput(const odla_value val) {
  const char* name = nullptr;
  if (val->layer != nullptr) {
    name = val->layer->getName();
  } else {
    name = val->tensor->getName();
  }
  g_comp->outputs[name] = val;
  g_comp->output_vals.push_back(val);
  nvinfer1::ITensor& tensor = *val->tensor;
  tensor.setName(name);
  g_comp->network->markOutput(tensor);
  return ODLA_SUCCESS;
}
// Reports how many values have been marked as outputs of `computation`.
odla_status odla_GetNumOfOutputsFromComputation(
    const odla_computation computation, odla_uint32* num_outputs) {
  const auto count = computation->output_vals.size();
  *num_outputs = count;
  return ODLA_SUCCESS;
}
// Returns the `output_idx`-th output in the order outputs were marked. An
// out-of-range index yields nullptr and ODLA_FAILURE.
odla_status odla_GetOutputFromComputationByIdx(
    const odla_computation computation, const odla_uint32 output_idx,
    odla_value* output_value) {
  const auto& outputs = computation->output_vals;
  if (output_idx < outputs.size()) {
    *output_value = outputs[output_idx];
    return ODLA_SUCCESS;
  }
  *output_value = nullptr;
  return ODLA_FAILURE;
}
//这里 每运行一个batch都会运行
odla_status odla_BindToArgument(odla_value value, const odla_void* data_ptr,
odla_context context) {
// CUdeviceptr dev_ptr;
clock_t startTime, endTime;
void* dev_ptr = nullptr;
odla_value_shape real_shape = value->type.shape;
if ((g_comp && g_comp->is_dynamic_batch) || context->run_batch_size) {
real_shape.dims[0] = context->run_batch_size;
}
size_t bytes =
GetTotalElements(real_shape) * GetElementSize(value->type.element_type);
// CHECK(cuMemAlloc(&dev_ptr, bytes));
// CHECK(cudaMalloc(&dev_ptr, bytes));
// 在这里检测一下有没有预先cudamalloc过,如果有过,将数据传到对应地址
// CUdeviceptr dev_ptr = context->cumemalloc_addres;
// std::cerr << "context->temp_input_ptr:" << context->temp_input_ptr << "\n";
if (context->temp_input_ptr == nullptr) {
CHECK(cudaMalloc(&(context->temp_input_ptr), bytes));
}
dev_ptr = context->temp_input_ptr;
void* validated_data_ptr =
ValidateValuePtr(value->type, const_cast<void*>(data_ptr));
// void* pagelocked_buffer = context->input_ptrs[value->name].host_ptr;
// startTime = clock();
// CHECK(cuMemcpyHtoD(dev_ptr, validated_data_ptr, bytes));
CHECK(cudaMemcpy(dev_ptr, validated_data_ptr, bytes, cudaMemcpyHostToDevice));
// endTime = clock();
// std::cout << "the run time is:" << (double) (endTime - startTime) /CLOCKS_PER_SEC << "s" << std::endl;
// std::ofstream outf;
// outf.open("odla_cudamemcpy_times.txt", std::ios::app);
// outf << (double) (endTime - startTime) /CLOCKS_PER_SEC << std::endl;
// outf.close();
// void* dev1_ptr;
// dev1_ptr = (void*) dev_ptr;
// CHECK(cudaMemcpy(dev_ptr, validated_data_ptr, bytes, cudaMemcpyHostToDevice));
context->input_ptrs[value->name] = {.host_ptr = data_ptr, .dev_ptr = dev_ptr};
return ODLA_SUCCESS;
}
odla_status odla_BindToArgumentById(const odla_value_id value_id,
const odla_void* data_ptr,
odla_context context) {
std::string name((const char*)value_id);
return odla_BindToArgument(context->comp->inputs[name], data_ptr, context);
}
odla_status odla_BindToOutput(odla_value value, odla_void* data_ptr,
odla_context context) {
// CUdeviceptr dst;
void* dst = nullptr;
odla_value_shape real_shape = value->type.shape;
if ((g_comp && g_comp->is_dynamic_batch) || context->run_batch_size) {
real_shape.dims[0] = context->run_batch_size;
}
size_t bytes =
GetTotalElements(real_shape) * GetElementSize(value->type.element_type);
if (context->temp_output_ptr == nullptr){
CHECK(cudaMalloc(&(context->temp_output_ptr), bytes));
}
dst = context->temp_output_ptr;
// CHECK(cudaMalloc(&dst, bytes));
context->output_ptrs[value->name] = {
.host_ptr = data_ptr, .dev_ptr = dst, .len = bytes, .vt = value->type};
return ODLA_SUCCESS;
}
odla_status odla_BindToOutputById(const odla_value_id value_id,
odla_void* data_ptr, odla_context context) {
std::string name((const char*)value_id);
assert(context->comp->outputs.count(name));
auto val = context->comp->outputs[name];
return odla_BindToOutput(val, data_ptr, context);
}
// Serializes the engine held by `context` into the file `file_name`.
// Fails if there is no engine, the file cannot be opened, or writing fails.
static odla_status odla_StoreEngine(odla_context context,
                                    const odla_char* file_name) {
  // A context whose engine failed to build (or was never loaded) has a null
  // engine, which used to be dereferenced below.
  if (context == nullptr || context->engine == nullptr ||
      file_name == nullptr) {
    return ODLA_FAILURE;
  }
  std::string engine = file_name;
  std::ofstream engineFile(engine, std::ios::binary);
  if (!engineFile) {
    std::cerr << "Cannot open engine file: " << engine << std::endl;
    return ODLA_FAILURE;
  }
  TrtUniquePtr<IHostMemory> serializedEngine{context->engine->serialize()};
  if (serializedEngine == nullptr) {
    std::cerr << "Engine serialization failed" << std::endl;
    return ODLA_FAILURE;
  }
  engineFile.write(static_cast<char*>(serializedEngine->data()),
                   serializedEngine->size());
  if (engineFile.fail()) {
    return ODLA_FAILURE;
  }
  return ODLA_SUCCESS;
}
// Writes the engine of `executable` to `file_name`.
odla_status odla_StoreExecutable(const odla_char* file_name,
                                 const odla_executable executable) {
  if (executable == nullptr) {
    return ODLA_FAILURE;
  }
  return odla_StoreEngine(executable->context, file_name);
}
// Reads a serialized TensorRT engine from `file_name` and deserializes it
// (on DLA core `DLACore` unless it is -1). The engine and a new execution
// context are stored in `context`, replacing any it already held.
// Returns ODLA_FAILURE on I/O, deserialization or context-creation errors.
static odla_status odla_LoadEngine(odla_context context,
                                   const odla_char* file_name, int DLACore) {
  if (context == nullptr || file_name == nullptr) {
    return ODLA_FAILURE;
  }
  std::ifstream engineFile(file_name, std::ios::binary);
  if (!engineFile) {
    std::cerr << "Error opening engine file: " << file_name << std::endl;
    return ODLA_FAILURE;
  }
  engineFile.seekg(0, engineFile.end);
  const std::streamoff fsize = engineFile.tellg();
  engineFile.seekg(0, engineFile.beg);
  // tellg() yields -1 on failure, and an empty file cannot hold an engine.
  if (fsize <= 0) {
    std::cerr << "Error loading engine file: " << file_name << std::endl;
    return ODLA_FAILURE;
  }
  std::vector<char> engineData(static_cast<size_t>(fsize));
  engineFile.read(engineData.data(), fsize);
  if (!engineFile) {
    std::cerr << "Error loading engine file: " << file_name << std::endl;
    return ODLA_FAILURE;
  }
  TrtUniquePtr<IRuntime> runtime{createInferRuntime(Logger)};
  if (runtime == nullptr) {
    std::cerr << "Failed to create TensorRT runtime" << std::endl;
    return ODLA_FAILURE;
  }
  if (DLACore != -1) {
    runtime->setDLACore(DLACore);
  }
  // NOTE(review): `runtime` is destroyed when this function returns while the
  // engine lives on. Verify whether the TensorRT version in use requires an
  // IRuntime to outlive the engines it deserialized.
  ICudaEngine* engine =
      runtime->deserializeCudaEngine(engineData.data(), engineData.size(),
                                     nullptr);
  if (engine == nullptr) {
    std::cerr << "Failed to deserialize engine file: " << file_name
              << std::endl;
    return ODLA_FAILURE;
  }
  IExecutionContext* exec_ctx = engine->createExecutionContext();
  if (exec_ctx == nullptr) {
    std::cerr << "Failed to create execution context for: " << file_name
              << std::endl;
    engine->destroy();
    return ODLA_FAILURE;
  }
  // Release (rather than leak) an engine/context the context already held.
  if (context->ctx != nullptr) {
    context->ctx->destroy();
  }
  if (context->engine != nullptr) {
    context->engine->destroy();
  }
  context->engine = engine;
  context->ctx = exec_ctx;
  return ODLA_SUCCESS;
}
// Loads a serialized engine from `file_name` into `*executable`.
// `*executable` must already exist; its DLACore selects the DLA core
// (-1 = none). A null `*computation` is replaced by a new computation in
// load-engine mode with dynamic batch enabled, and a null `*context` by a new
// context for it. On success the executable refers to both.
odla_status odla_LoadExecutable(const odla_char* file_name,
                                odla_executable* executable,
                                odla_context* context,
                                odla_computation* computation) {
  // These handles used to be dereferenced without any check.
  if (executable == nullptr || *executable == nullptr || context == nullptr ||
      computation == nullptr) {
    return ODLA_FAILURE;
  }
  const int DLACore = (*executable)->DLACore;
  if (*computation == nullptr) {
    // ODLA_LOAD_ENGINE_MODE is read through a bool*, so pass a real bool
    // (an int was passed before).
    // NOTE(review): this leaves the process-wide load-engine mode switched on
    // for every computation created afterwards.
    bool load_engine_mode = true;
    odla_SetComputationItem(nullptr, ODLA_LOAD_ENGINE_MODE,
                            (odla_item_value)&load_engine_mode);
    if (odla_CreateComputation(computation) != ODLA_SUCCESS) {
      return ODLA_FAILURE;
    }
    bool is_dynamic_batch = true;
    odla_SetComputationItem(*computation, ODLA_DYNAMIC_BATCH,
                            (odla_item_value)&is_dynamic_batch);
  }
  if (*context == nullptr && odla_CreateContext(context) != ODLA_SUCCESS) {
    return ODLA_FAILURE;
  }
  if (odla_LoadEngine(*context, file_name, DLACore) != ODLA_SUCCESS) {
    return ODLA_FAILURE;
  }
  (*executable)->computation = *computation;
  (*executable)->context = *context;
  return ODLA_SUCCESS;
}
odla_status odla_GetNumOfOutputsFromExecutable(const odla_executable executable,
odla_uint32* num_outputs) {
odla_context Ctx = executable->context;
int numBindings = Ctx->engine->getNbBindings();
int numOutput = 0;
for (int i = 0; i < numBindings; ++i) {
if (!Ctx->engine->bindingIsInput(i)) {
numOutput++;
}
}
*num_outputs = numOutput;
return ODLA_SUCCESS;
}
odla_status odla_GetNumOfArgsFromExecutable(const odla_executable executable,
odla_uint32* num_args) {
odla_context Ctx = executable->context;
int numBindings = Ctx->engine->getNbBindings();
int numInput = 0;
for (int i = 0; i < numBindings; ++i) {
if (Ctx->engine->bindingIsInput(i)) {
numInput++;
}
}
*num_args = numInput;
return ODLA_SUCCESS;
}
// Allocates an executable that pairs `context` with `computation`.
odla_status odla_CreateExecutable(odla_executable* executable,
                                  odla_context context,
                                  odla_computation computation) {
  odla_executable created = new _odla_executable(context, computation);
  *executable = created;
  return ODLA_SUCCESS;
}
// Deletes an executable. Its context and computation are not destroyed.
odla_status odla_DestroyExecutable(odla_executable executable) {
  delete executable;
  return ODLA_SUCCESS;
}
// Describes engine binding `idx` of `executable`: its element type, dims and
// binding name. Fails for a missing engine or an out-of-range index.
// NOTE(review): every call overwrites the executable's single scratch value
// (`executable->val`), so a returned handle is only meaningful until the next
// call.
static odla_status odla_GetValFromExecutableByIdx(
    const odla_executable executable, const odla_uint32 idx,
    odla_value* value) {
  if (executable == nullptr || executable->context == nullptr ||
      executable->context->engine == nullptr || value == nullptr) {
    return ODLA_FAILURE;
  }
  const nvinfer1::ICudaEngine* engine = executable->context->engine;
  // Also rejects a "not found" index of -1 that was converted to odla_uint32.
  if (idx >= static_cast<odla_uint32>(engine->getNbBindings())) {
    return ODLA_FAILURE;
  }
  const int binding = static_cast<int>(idx);
  _odla_value* val = executable->val.get();
  const nvinfer1::Dims dims = engine->getBindingDimensions(binding);
  odla_value_type* type = &val->type;
  type->element_type = GetODLAType(engine->getBindingDataType(binding));
  type->shape.size = dims.nbDims;
  for (int i = 0; i < dims.nbDims; ++i) {
    type->shape.dims[i] = dims.d[i];
  }
  val->name = engine->getBindingName(binding);
  *value = val;
  return ODLA_SUCCESS;
}
odla_status odla_GetArgFromExecutableByIdx(const odla_executable executable,
const odla_uint32 arg_idx,
odla_value* arg_value) {
odla_context Ctx = executable->context;
auto getIdxInEngine = [&Ctx, arg_idx] {
int numBindings = Ctx->engine->getNbBindings();
for (int i = 0, j = 0; i < numBindings; ++i) {
if (Ctx->engine->bindingIsInput(i)) {
if (j == arg_idx) {
return i;
}
j++;
}
}
};
return odla_GetValFromExecutableByIdx(executable, getIdxInEngine(),
arg_value);
}
odla_status ODLA_API_CALL odla_GetOutputFromExecutableByIdx(
const odla_executable executable, const odla_uint32 output_idx,
odla_value* output_value) {
odla_context Ctx = executable->context;
auto getIdxInEngine = [&Ctx, output_idx] {
int numBindings = Ctx->engine->getNbBindings();
for (int i = 0, j = 0; i < numBindings; ++i) {
if (!Ctx->engine->bindingIsInput(i)) {
if (j == output_idx) {
return i;
}
j++;
}
}
};
return odla_GetValFromExecutableByIdx(executable, getIdxInEngine(),
output_value);
}
// Copies the ODLA type recorded for `value` into `*value_type`.
odla_status odla_GetValueType(const odla_value value,
                              odla_value_type* value_type) {
  const odla_value_type& recorded = value->type;
  *value_type = recorded;
  return ODLA_SUCCESS;
}
// Runs the context's engine on the device buffers bound since the previous
// run. Each bound output is copied back to its host pointer; results whose
// ODLA type is INT64 are widened from the INT32 the engine produces. The
// bindings are then cleared.
// With dynamic batch the input binding dims take context->run_batch_size as
// dim 0 and executeV2() is used; otherwise execute() runs with batch 1.
// Returns ODLA_FAILURE if the context has no engine or any TensorRT/CUDA step
// fails.
odla_status odla_ExecuteComputation(odla_computation comp, odla_context context,
                                    odla_compute_mode mode,
                                    odla_device device) {
  if (context == nullptr || context->engine == nullptr ||
      context->ctx == nullptr) {
    return ODLA_FAILURE;
  }
  if (comp == nullptr) {
    comp = context->comp;
  }
  const bool dynamic_batch = comp != nullptr && comp->is_dynamic_batch;

  // One slot per engine binding, so TensorRT never reads past the end of the
  // array. Previously the array only grew up to the highest bound index.
  std::vector<void*> buffers(
      static_cast<size_t>(context->engine->getNbBindings()), nullptr);
  auto add_to_buffer = [&](const std::string& name, void* ptr) {
    const int idx = context->engine->getBindingIndex(name.c_str());
    if (idx >= 0 && idx < static_cast<int>(buffers.size())) {
      buffers[idx] = ptr;
    }
  };
  for (auto& kv : context->input_ptrs) {
    add_to_buffer(kv.first, kv.second.dev_ptr);
  }
  for (auto& kv : context->output_ptrs) {
    add_to_buffer(kv.first, kv.second.dev_ptr);
  }

  bool ok = true;
  if (dynamic_batch) {
    for (auto& input_ptr : context->input_ptrs) {
      const int idx =
          context->engine->getBindingIndex(input_ptr.first.c_str());
      if (idx < 0) {
        continue;
      }
      nvinfer1::Dims dims = context->ctx->getBindingDimensions(idx);
      dims.d[0] = context->run_batch_size;
      ok = CHECK(context->ctx->setBindingDimensions(idx, dims)) && ok;
    }
    ok = ok && CHECK(context->ctx->executeV2(buffers.data()));
  } else {
    const int batch = 1;
    ok = CHECK(context->ctx->execute(batch, buffers.data()));
  }

  // Copy results back only if execution succeeded.
  for (auto& kv : context->output_ptrs) {
    if (!ok) {
      break;
    }
    const auto& out = kv.second;
    if (out.vt.element_type == ODLA_INT64) {
      // The device holds INT32. Size the staging buffer from the bound byte
      // length rather than the declared shape, which differs under dynamic
      // batch and used to overflow either the staging or the caller's buffer.
      std::vector<int32_t> host_tmp(out.len / sizeof(int32_t));
      ok = CHECK(cudaMemcpy(host_tmp.data(), out.dev_ptr,
                            host_tmp.size() * sizeof(int32_t),
                            cudaMemcpyDeviceToHost));
      if (ok) {
        int64_t* dst = static_cast<int64_t*>(out.host_ptr);
        for (int32_t d : host_tmp) {
          *dst++ = static_cast<int64_t>(d);
        }
      }
    } else {
      ok = CHECK(cudaMemcpy(out.host_ptr, out.dev_ptr, out.len,
                            cudaMemcpyDeviceToHost));
    }
  }
  context->input_ptrs.clear();
  context->output_ptrs.clear();
  return ok ? ODLA_SUCCESS : ODLA_FAILURE;
}
// Brings lhs and rhs to the same rank by reshaping the lower-rank operand so
// that 1-sized dimensions are prepended (see BroadcastDims). On return lhs
// and rhs refer to the possibly reshaped tensors. The result is the shape of
// the higher-rank operand (dims_lhs when the ranks already match).
static odla_value_shape broadcastTensor(odla_computation comp,
                                        nvinfer1::ITensor*& lhs,
                                        nvinfer1::ITensor*& rhs,
                                        odla_value_shape dims_lhs,
                                        odla_value_shape dims_rhs) {
  if (dims_lhs.size == dims_rhs.size) {
    return dims_lhs;
  }
  // Add the reshape to the computation passed in. The `comp` parameter used
  // to be ignored in favor of g_comp; fall back to g_comp only if it is null.
  nvinfer1::INetworkDefinition* network =
      (comp != nullptr ? comp : g_comp)->network;
  if (dims_lhs.size > dims_rhs.size) {
    auto reshape = network->addShuffle(*rhs);
    reshape->setReshapeDimensions(BroadcastDims(dims_rhs, dims_lhs.size));
    rhs = reshape->getOutput(0);
    return dims_lhs;
  }
  auto reshape = network->addShuffle(*lhs);
  reshape->setReshapeDimensions(BroadcastDims(dims_lhs, dims_rhs.size));
  lhs = reshape->getOutput(0);
  return dims_rhs;
}
// Adds an element-wise `op` layer over lhs and rhs to the active computation,
// first aligning their ranks with broadcastTensor. The result keeps lhs's
// element type. Its shape is taken from the layer's actual output whenever
// the aligned operands still differ, i.e. when both sides broadcast, e.g.
// (256,1) op (1,768) -> (256,768).
static odla_value binary_op(nvinfer1::ElementWiseOperation op, odla_value lhs,
                            odla_value rhs, const odla_value_id id) {
  nvinfer1::ITensor* a = lhs->tensor;
  nvinfer1::ITensor* b = rhs->tensor;
  odla_value_shape out_shape =
      broadcastTensor(g_comp, a, b, lhs->type.shape, rhs->type.shape);
  nvinfer1::IElementWiseLayer* layer =
      g_comp->network->addElementWise(*a, *b, op);
  const nvinfer1::Dims actual = layer->getOutput(0)->getDimensions();
  const bool both_broadcast = !SameNVDims(b->getDimensions(), a->getDimensions());
  if (both_broadcast) {
    out_shape.size = (odla_int32)actual.nbDims;
    for (int axis = 0; axis < actual.nbDims; ++axis) {
      out_shape.dims[axis] = (odla_int64)actual.d[axis];
    }
  }
  return CreateValue(layer, {lhs->type.element_type, out_shape}, id);
}
odla_value odla_Add(odla_value lhs, odla_value rhs, const odla_value_id id) {