//===- odla_torchscript.cc ------------------------------------------------===//
//
// Copyright (C) 2022 Alibaba Group Holding Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
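// This file implements the ODLA executable/runtime interface on top of
// TorchScript: an odla_executable wraps a deserialized torch::jit::Module,
// and launching it runs the module's forward() method on the bound inputs.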
#include <ODLA/odla.h>
#include <c10/util/ArrayRef.h>
#include <torch/csrc/jit/serialization/import.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <sstream>
#include <unordered_map>
#include <vector>

const uint32_t MAX_OUTPUT_TENSORS = 10;
const uint32_t MAX_INPUT_TENSORS = 20;

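// Definitions of the opaque ODLA handle types. An _odla_value is just an
// index into the executable's shared I/O slot table; an _odla_context holds
// the input IValues and collects the output tensors of one inference run.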
struct _odla_device {
  c10::DeviceType device_t_;
};

struct _odla_value {
  explicit _odla_value(uint32_t v) : id_(v) {}
  uint32_t id_;
};

struct _odla_executable {
  torch::jit::Module module_;
  // Owning storage for the value handles returned to the client; the raw
  // pointers handed out remain valid for the executable's lifetime.
  std::vector<std::unique_ptr<_odla_value>> odla_inputs_outputs_;
  odla_uint32 num_inputs_;
};

struct _odla_context {
  _odla_context();
  std::vector<torch::jit::IValue> inputs_;
  std::vector<odla_value_type> input_types_;
  torch::jit::IValue output_;
  std::vector<at::Tensor> output_tensors_;
  odla_uint32 num_output_tensors_ = 0;
};

_odla_context::_odla_context() {
  inputs_.resize(MAX_INPUT_TENSORS);
  input_types_.resize(MAX_INPUT_TENSORS);
}

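// Helpers converting between ODLA element types/shapes and their ATen
// counterparts. Unknown element types conservatively map to FLOAT32/Float.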
static size_t getElementCount(const odla_value_shape& dims) {
  return dims.size == 0 ? 1
                        : std::accumulate(dims.dims, dims.dims + dims.size,
                                          size_t{1},
                                          std::multiplies<size_t>());
}

static c10::IntArrayRef toTensorDim(const odla_value_shape& dims) {
  // A rank-0 ODLA shape maps to a 1-element tensor shape. The referenced
  // dimension must outlive the returned ArrayRef, hence the static constant
  // (an ArrayRef bound to a temporary would dangle).
  static const int64_t kScalarDim = 1;
  return dims.size == 0 ? c10::IntArrayRef(&kScalarDim, 1)
                        : c10::IntArrayRef(dims.dims, dims.size);
}

static c10::ScalarType toTensorDataType(odla_element_type dt) {
  static const std::unordered_map<odla_element_type, c10::ScalarType> dt_map =
      {{ODLA_FLOAT32, c10::ScalarType::Float},
       {ODLA_INT32, c10::ScalarType::Int},
       {ODLA_BOOL, c10::ScalarType::Bool}};
  auto it = dt_map.find(dt);
  return it == dt_map.end() ? c10::ScalarType::Float : it->second;
}

static odla_element_type toODLADataType(const c10::ScalarType& st) {
  static const std::unordered_map<c10::ScalarType, odla_element_type> dt_map =
      {{c10::ScalarType::Float, ODLA_FLOAT32},
       {c10::ScalarType::Int, ODLA_INT32},
       {c10::ScalarType::Bool, ODLA_BOOL}};
  auto it = dt_map.find(st);
  return it == dt_map.end() ? ODLA_FLOAT32 : it->second;
}

static odla_value_type toODLAValueType(const c10::ScalarType& dt,
                                       at::IntArrayRef dims) {
  odla_value_type ty;
  ty.element_type = toODLADataType(dt);
  ty.shape.size = dims.size();
  int i = 0;
  for (auto d : dims) {
    ty.shape.dims[i++] = d;
  }
  return ty;
}

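// Global ownership maps: contexts and executables are handed out to the
// client as raw pointers, while these maps keep the owning unique_ptrs
// alive until the matching odla_Destroy* call.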
static std::unordered_map<odla_context, std::unique_ptr<_odla_context>> g_ctxs;
static std::unordered_map<odla_executable, std::unique_ptr<_odla_executable>>
    g_executables;

static _odla_device g_device{c10::kCUDA};

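// The vendor, device_name, and config arguments are ignored; every caller
// shares the single static device, which defaults to CUDA.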
odla_status odla_AllocateDevice(const odla_vendor vendor,
                                const odla_device_name device_name,
                                odla_device* device, const char* config) {
  *device = &g_device;
  return ODLA_SUCCESS;
}

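// Loads a TorchScript module either from an in-memory buffer or from a file
// path and pre-creates one value handle per I/O slot. Inputs and outputs
// share a single handle table, so it is sized to cover both.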
odla_status odla_LoadExecutable(odla_resource_location location,
                                odla_device device,
                                odla_executable* computation) {
  *computation = nullptr;
  if (location.location_type != ODLA_LOCATION_MEMORY &&
      location.location_type != ODLA_LOCATION_PATH) {
    return ODLA_FAILURE;
  }
  auto comp = std::make_unique<_odla_executable>();
  if (location.location_type == ODLA_LOCATION_MEMORY) {
    // Parse the serialized module directly from the caller's buffer:
    // pubsetbuf installs it as the stream's get area without copying
    // (implementation-defined for std::stringbuf, but supported by libstdc++).
    std::istringstream s;
    s.rdbuf()->pubsetbuf(
        const_cast<char*>(reinterpret_cast<const char*>(location.location)),
        location.size);
    comp->module_ = torch::jit::load(s, c10::Device(g_device.device_t_));
  } else {
    comp->module_ =
        torch::jit::load(reinterpret_cast<const char*>(location.location),
                         c10::Device(g_device.device_t_));
  }
  auto schema = comp->module_.get_method("forward").function().getSchema();
  assert(!schema.is_vararg());
  assert(!schema.is_varret());
  // num_inputs() counts the implicit `self` argument, so subtract it.
  auto num_inputs =
      comp->module_.get_method("forward").function().num_inputs();
  comp->num_inputs_ = num_inputs - 1;
  for (uint32_t idx = 0;
       idx < std::max(comp->num_inputs_, MAX_OUTPUT_TENSORS); ++idx) {
    comp->odla_inputs_outputs_.push_back(std::make_unique<_odla_value>(idx));
  }
  *computation = comp.get();
  g_executables[*computation] = std::move(comp);
  return ODLA_SUCCESS;
}

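// Value-handle lookups. Because inputs and outputs share one handle table,
// both getters resolve a handle by plain indexing; only the bounds they
// check differ.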
odla_status odla_GetArgFromExecutableByIdx(odla_executable comp,
                                           odla_uint32 idx,
                                           odla_value* value) {
  if (idx >= comp->num_inputs_) {
    *value = nullptr;
    return ODLA_FAILURE;
  }
  *value = comp->odla_inputs_outputs_[idx].get();
  return ODLA_SUCCESS;
}

odla_status odla_GetOutputFromExecutableByIdx(const odla_executable comp,
                                              const odla_uint32 output_idx,
                                              odla_value* output_value) {
  if (output_idx >= comp->odla_inputs_outputs_.size()) {
    *output_value = nullptr;
    return ODLA_FAILURE;
  }
  *output_value = comp->odla_inputs_outputs_[output_idx].get();
  return ODLA_SUCCESS;
}

odla_status odla_CreateContext(odla_context* context) {
  *context = nullptr;
  auto ctx = std::make_unique<_odla_context>();
  *context = ctx.get();
  g_ctxs[*context] = std::move(ctx);
  return ODLA_SUCCESS;
}

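// Runtime value types: input types must be recorded per slot before binding,
// while output types are derived from the tensors produced by the last
// launch.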
odla_status odla_SetRuntimeValueType(odla_context context, odla_value v,
                                     odla_value_type ty) {
  assert(v->id_ < MAX_INPUT_TENSORS);
  context->input_types_[v->id_] = std::move(ty);
  return ODLA_SUCCESS;
}

odla_status odla_GetRuntimeValueType(odla_context context, odla_value value,
                                     odla_value_type* ty) {
  assert(value->id_ < context->num_output_tensors_);
  const auto& t = context->output_tensors_[value->id_];
  *ty = toODLAValueType(t.scalar_type(), t.sizes());
  return ODLA_SUCCESS;
}

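// Binds a caller-provided host buffer to an input slot. at::from_blob wraps
// the buffer without copying, so on CPU the buffer must stay alive until the
// launch; on CUDA the data is copied to the device immediately.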
odla_status odla_BindToArgument(odla_value value, const odla_void* data_ptr,
                                odla_context context) {
  assert(value->id_ < MAX_INPUT_TENSORS);
  auto ty = context->input_types_[value->id_];
  auto options = c10::TensorOptions()
                     .dtype(toTensorDataType(ty.element_type))
                     .device(c10::kCPU);
  auto t = at::from_blob(const_cast<void*>(data_ptr), toTensorDim(ty.shape),
                         options);
  if (g_device.device_t_ == c10::kCUDA) {
    t = t.to(c10::device(c10::kCUDA));
  }
  context->inputs_[value->id_] = c10::IValue(t);
  return ODLA_SUCCESS;
}

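// Copies one output tensor back into a caller-provided buffer, staging it
// through host memory first when it lives on a CUDA device.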
odla_status odla_BindToOutput(odla_value value, odla_void* data_ptr,
                              odla_context context) {
  assert(value->id_ < context->num_output_tensors_);
  // Ensure the element order in memory matches the logical shape before the
  // flat memcpy below; contiguous() is a no-op for already-contiguous
  // tensors.
  auto t = context->output_tensors_[value->id_].contiguous();
  auto ty = toODLAValueType(t.scalar_type(), t.sizes());
  size_t len = at::elementSize(t.scalar_type()) * getElementCount(ty.shape);
  if (g_device.device_t_ == c10::kCPU) {
    memcpy(data_ptr, t.data_ptr(), len);
  } else {
    // Stage through host memory rather than calling
    // cudaMemcpy(data_ptr, ..., cudaMemcpyDeviceToHost) directly.
    t = t.to(c10::Device(c10::kCPU));
    memcpy(data_ptr, t.data_ptr(), len);
  }
  return ODLA_SUCCESS;
}

odla_status odla_GetRuntimeNumOfOutputs(odla_context context,
                                        odla_uint32* num_output_ptr) {
  *num_output_ptr = context->num_output_tensors_;
  return ODLA_SUCCESS;
}

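// Runs the module's forward() on the bound inputs and flattens the result:
// a lone tensor becomes one output; a tuple of tensors becomes one output
// per element.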
odla_status odla_LaunchExecutable(const odla_executable computation,
                                  const odla_context context) {
  context->inputs_.resize(computation->num_inputs_);
  context->input_types_.resize(computation->num_inputs_);
  context->output_ = computation->module_.forward(context->inputs_);

  // Reset the outputs of any previous launch on this context.
  context->output_tensors_.clear();
  if (context->output_.isTensor()) {
    context->output_tensors_.push_back(context->output_.toTensor());
  } else {
    assert(context->output_.isTuple());
    for (const auto& item : context->output_.toTuple()->elements()) {
      assert(item.isTensor());
      context->output_tensors_.push_back(item.toTensor());
    }
  }
  context->num_output_tensors_ = context->output_tensors_.size();
  return ODLA_SUCCESS;
}

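// Destroying a context or an executable erases it from its owning map,
// which releases the unique_ptr and everything it holds.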
odla_status odla_DestroyContext(odla_context context) {
  auto it = g_ctxs.find(context);
  if (it == g_ctxs.end()) {
    return ODLA_FAILURE;
  }
  g_ctxs.erase(it);
  return ODLA_SUCCESS;
}

odla_status odla_DestroyExecutable(odla_executable computation) {
  auto it = g_executables.find(computation);
  if (it == g_executables.end()) {
    return ODLA_FAILURE;
  }
  g_executables.erase(it);
  return ODLA_SUCCESS;
}

odla_status odla_DestroyDevice(odla_device device) { return ODLA_SUCCESS; }