Skip to content

Commit 2d20229

Browse files
committed
odla torchscript
1 parent bfeab7e commit 2d20229

File tree

3 files changed

+334
-0
lines changed

3 files changed

+334
-0
lines changed

ODLA/platforms/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
option(ODLA_BUILD_TF_Wrapper "Build ODLA Tensorflow Wrapper" OFF)
if (ODLA_BUILD_TF_Wrapper)
  add_subdirectory(tensorflow)
endif()

# Opt-in TorchScript backend; see ODLA/platforms/torch/CMakeLists.txt.
option(ODLA_BUILD_TORCH_Wrapper "Build ODLA TORCHSCRIPT Wrapper" OFF)
if (ODLA_BUILD_TORCH_Wrapper)
  add_subdirectory(torch)
endif()

ODLA/platforms/torch/CMakeLists.txt

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
# ==============================================================================
# Copyright (C) 2022 Alibaba Group Holding Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
# ==============================================================================
add_odla_library(odla_torch SHARED odla_torchscript.cc)

# Locate the torch installation: prefer an explicit CMAKE_PREFIX_PATH /
# find_package result, otherwise fall back to the active python3's torch
# package directory.
if (CMAKE_PREFIX_PATH)
  find_package(Torch REQUIRED)
endif()
if (TORCH_INSTALL_PREFIX)
  set(TORCH_PATH ${TORCH_INSTALL_PREFIX})
else()
  execute_process(
    COMMAND python3 -c "import torch; print(torch.__path__[0])"
    OUTPUT_VARIABLE TORCH_PATH
    OUTPUT_STRIP_TRAILING_WHITESPACE
    RESULT_VARIABLE retcode)
  # BUG FIX: a failing probe (no python3 / no torch) was silently ignored;
  # report the real cause instead of the generic error below.
  if (NOT retcode EQUAL 0)
    message(FATAL_ERROR "failed to probe torch via python3 (exit ${retcode}).")
  endif()
endif()
if (NOT TORCH_PATH)
  message(FATAL_ERROR "torch install is not found.")
else()
  message(STATUS "torch install path: ${TORCH_PATH}")
endif()

set(PYTORCH_INC_DIR
  ${TORCH_PATH}/include
  ${TORCH_PATH}/torch/csrc/api/include
  ${TORCH_PATH}/include/TH
  ${TORCH_PATH}/include/THC
)
message(STATUS "torch include dirs: ${PYTORCH_INC_DIR}")
target_include_directories(odla_torch PRIVATE ${PYTORCH_INC_DIR})

set(PYTORCH_LIBS
  -lc10
  -ltorch
  -ltorch_cpu
  -ltorch_python)

# The wrapper must be compiled with the same libstdc++ ABI as libtorch,
# otherwise std::string/std::function symbols fail to link.
if (NOT TORCH_CXX_FLAGS)
  execute_process(
    COMMAND python3 -c "import torch; print(torch._C._GLIBCXX_USE_CXX11_ABI)"
    OUTPUT_VARIABLE TORCH_CXX11_ABI
    OUTPUT_STRIP_TRAILING_WHITESPACE
    RESULT_VARIABLE retcode)
  # BUG FIX: check the probe's exit status before using its output.
  if (NOT retcode EQUAL 0)
    message(FATAL_ERROR "failed to query torch CXX11 ABI (exit ${retcode}).")
  endif()
  # BUG FIX: torch._C._GLIBCXX_USE_CXX11_ABI is a Python bool, so the probe
  # prints "True"/"False"; the original passed that straight into
  # -D_GLIBCXX_USE_CXX11_ABI=True, which libstdc++ evaluates as 0.
  if (TORCH_CXX11_ABI STREQUAL "True")
    set(TORCH_CXX11_ABI 1)
  elseif (TORCH_CXX11_ABI STREQUAL "False")
    set(TORCH_CXX11_ABI 0)
  endif()
  message(STATUS "Torch CXX flags:-D_GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI}")
  target_compile_definitions(odla_torch PRIVATE _GLIBCXX_USE_CXX11_ABI=${TORCH_CXX11_ABI})
else()
  message(STATUS "Torch CXX flags:${TORCH_CXX_FLAGS}")
  set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
endif()
target_link_options(odla_torch PUBLIC -Wl,-rpath,${TORCH_PATH}/lib -L${TORCH_PATH}/lib)
target_link_libraries(odla_torch PUBLIC ${PYTORCH_LIBS})
target_link_libraries(odla_torch PUBLIC ODLA)
+264
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
//===- odla_torchscript.cc ------------------------------------------------===//
2+
//
3+
// Copyright (C) 2022 Alibaba Group Holding Limited.
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
// =============================================================================
17+
#include <c10/util/ArrayRef.h>
18+
#include <torch/csrc/jit/serialization/import.h>
19+
20+
#include <ODLA/odla.h>
21+
22+
#include <sstream>
23+
24+
// Upper bounds for per-context I/O bookkeeping: input slots are pre-sized to
// MAX_INPUT_TENSORS, and the shared input/output handle table in
// odla_LoadExecutable covers at least MAX_OUTPUT_TENSORS entries.
// constexpr (vs const) guarantees compile-time constants usable in any
// constant-expression context.
constexpr uint32_t MAX_OUTPUT_TENSORS = 10;
constexpr uint32_t MAX_INPUT_TENSORS = 20;
26+
27+
struct _odla_device {
28+
c10::DeviceType device_t_;
29+
};
30+
31+
// Opaque handle naming one executable input or output by its index.
struct _odla_value {
  // explicit prevents accidental implicit uint32_t -> _odla_value
  // conversions (single-argument constructor idiom).
  explicit _odla_value(uint32_t v) : id_(v) {}
  uint32_t id_;  // index into per-context input/output tables
};
35+
36+
struct _odla_executable {
37+
torch::jit::Module module_;
38+
std::vector<odla_value> odla_inputs_outputs_;
39+
odla_uint32 num_inputs_;
40+
};
41+
42+
struct _odla_context {
43+
_odla_context();
44+
std::vector<torch::jit::IValue> inputs_;
45+
std::vector<odla_value_type> input_types_;
46+
torch::jit::IValue output_;
47+
std::vector<at::Tensor> output_tensors_;
48+
odla_uint32 num_output_tensors_;
49+
};
50+
51+
// Pre-size the input slots so odla_BindToArgument can index by value id
// before the executable's true input count is known (trimmed later by
// odla_LaunchExecutable).
_odla_context::_odla_context()
    : inputs_(MAX_INPUT_TENSORS), input_types_(MAX_INPUT_TENSORS) {}
55+
56+
size_t static getElementCount(const odla_value_shape& dims) {
57+
return dims.size == 0 ? 1
58+
: std::accumulate(dims.dims, dims.dims + dims.size, 1,
59+
std::multiplies<size_t>());
60+
}
61+
62+
// Converts an ODLA shape into a torch dimension view; a rank-0 (scalar)
// shape is mapped to the 1-element shape [1].
c10::IntArrayRef static toTensorDim(odla_value_shape& dims) {
  // BUG FIX: the original returned c10::IntArrayRef(1), which binds the
  // view to a temporary that dies at the end of the full expression,
  // leaving the returned ArrayRef dangling. Back the scalar case with
  // storage of static duration instead.
  static constexpr int64_t kScalarShape[1] = {1};
  return dims.size == 0 ? c10::IntArrayRef(kScalarShape, 1)
                        : c10::IntArrayRef(dims.dims, dims.size);
}
66+
67+
// Maps an ODLA element type to the corresponding torch scalar type.
// Unsupported types fall back to Float, matching the original lookup-table
// default.
c10::ScalarType static toTensorDataType(odla_element_type dt) {
  switch (dt) {
    case ODLA_FLOAT32:
      return c10::ScalarType::Float;
    case ODLA_INT32:
      return c10::ScalarType::Int;
    case ODLA_BOOL:
      return c10::ScalarType::Bool;
    default:
      return c10::ScalarType::Float;
  }
}
76+
77+
// Maps a torch scalar type back to the corresponding ODLA element type.
// Unsupported types fall back to ODLA_FLOAT32, matching the original
// lookup-table default.
odla_element_type static toODLADataType(const c10::ScalarType& st) {
  switch (st) {
    case c10::ScalarType::Float:
      return ODLA_FLOAT32;
    case c10::ScalarType::Int:
      return ODLA_INT32;
    case c10::ScalarType::Bool:
      return ODLA_BOOL;
    default:
      return ODLA_FLOAT32;
  }
}
86+
87+
// Builds an odla_value_type from a torch scalar type and dimension list.
// NOTE(review): assumes dims.size() fits in the fixed-capacity
// ty.shape.dims array — confirm against ODLA's maximum dimension count.
odla_value_type static toODLAValueType(const c10::ScalarType& dt,
                                       at::IntArrayRef dims) {
  odla_value_type ty;
  ty.element_type = toODLADataType(dt);
  ty.shape.size = dims.size();
  for (size_t i = 0; i < dims.size(); ++i) {
    ty.shape.dims[i] = dims[i];
  }
  return ty;
}
97+
98+
static std::unordered_map<odla_context, std::unique_ptr<_odla_context>> g_ctxs;
99+
static std::unordered_map<odla_executable, std::unique_ptr<_odla_executable>>
100+
g_executables;
101+
102+
static _odla_device g_device{c10::kCUDA};
103+
104+
// Hands out the process-wide static device. `vendor`, `device_name`, and
// `config` are ignored by this backend.
odla_status odla_AllocateDevice(const odla_vendor vendor,
                                const odla_device_name device_name,
                                odla_device* device,
                                const char* config) {
  // Robustness: guard the out-parameter instead of dereferencing null.
  if (device == nullptr) {
    return ODLA_FAILURE;
  }
  *device = &g_device;
  return ODLA_SUCCESS;
}
111+
112+
// Loads a TorchScript module from memory or a file path, inspects its
// forward() schema, and builds the index-keyed input/output handle table.
// Returns ODLA_FAILURE (with *computation = nullptr) for unsupported
// location types.
odla_status odla_LoadExecutable(odla_resource_location location,
                                odla_device device,
                                odla_executable* computation) {
  *computation = nullptr;
  if (location.location_type != ODLA_LOCATION_MEMORY &&
      location.location_type != ODLA_LOCATION_PATH) {
    return ODLA_FAILURE;
  }
  auto comp = std::make_unique<_odla_executable>();
  if (location.location_type == ODLA_LOCATION_MEMORY) {
    // Wrap the caller's in-memory blob in an istream without copying it.
    // NOTE(review): pubsetbuf on a stringbuf is implementation-defined;
    // confirm it actually installs the buffer on the target toolchain.
    std::istringstream s;
    s.rdbuf()->pubsetbuf(
        const_cast<char*>(reinterpret_cast<const char*>(location.location)),
        location.size);
    comp->module_ = torch::jit::load(s, c10::Device(g_device.device_t_));
  } else {
    comp->module_ = torch::jit::load(
        reinterpret_cast<const char*>(location.location),
        c10::Device(g_device.device_t_));
  }
  auto schema = comp->module_.get_method("forward").function().getSchema();
  assert(!schema.is_vararg());
  assert(!schema.is_varret());
  auto num_inputs = comp->module_.get_method("forward").function().num_inputs();
  comp->num_inputs_ = num_inputs - 1;  // exclude the implicit `self` argument
  // One shared id table serves both inputs and outputs, so size it to cover
  // whichever set is larger.
  for (uint32_t idx = 0; idx < std::max(comp->num_inputs_, MAX_OUTPUT_TENSORS);
       ++idx) {
    // BUG FIX: the original pushed v.get() from a unique_ptr destroyed at
    // the end of each iteration, so every stored handle dangled. Transfer
    // ownership out instead. NOTE(review): these values are never freed —
    // _odla_executable has no owning storage for them yet; acceptable for
    // process-lifetime executables, but worth fixing properly.
    comp->odla_inputs_outputs_.push_back(new _odla_value(idx));
  }
  *computation = comp.get();
  g_executables[*computation] = std::move(comp);
  return ODLA_SUCCESS;
}
145+
146+
// Returns the handle for input `idx` (0-based, excluding the implicit
// `self`); fails with *value = nullptr when idx is out of range.
odla_status odla_GetArgFromExecutableByIdx(odla_executable comp,
                                           odla_uint32 idx,
                                           odla_value* value) {
  // BUG FIX: the original tested `idx > num_inputs_`, accepting the
  // out-of-range index idx == num_inputs_. Valid ids are [0, num_inputs_).
  if (idx >= comp->num_inputs_) {
    *value = nullptr;
    return ODLA_FAILURE;
  }
  *value = comp->odla_inputs_outputs_[idx];
  return ODLA_SUCCESS;
}
156+
157+
// Returns the handle for output `output_idx`; fails with
// *output_value = nullptr when the index is out of range.
odla_status odla_GetOutputFromExecutableByIdx(const odla_executable comp,
                                              const odla_uint32 output_idx,
                                              odla_value* output_value) {
  // BUG FIX: the original tested `>`, letting output_idx == size() index
  // one past the end of odla_inputs_outputs_.
  if (output_idx >= comp->odla_inputs_outputs_.size()) {
    *output_value = nullptr;
    return ODLA_FAILURE;
  }
  *output_value = comp->odla_inputs_outputs_[output_idx];
  return ODLA_SUCCESS;
}
167+
168+
// Creates a fresh inference context and registers it in the global owner
// map; the raw pointer handed back to the caller stays valid until
// odla_DestroyContext.
odla_status odla_CreateContext(odla_context* context) {
  *context = nullptr;
  auto owned = std::make_unique<_odla_context>();
  odla_context handle = owned.get();
  g_ctxs[handle] = std::move(owned);
  *context = handle;
  return ODLA_SUCCESS;
}
175+
176+
// Records the caller-declared type for input `v`; consumed later by
// odla_BindToArgument when wrapping the raw buffer in a tensor.
odla_status odla_SetRuntimeValueType(odla_context context, odla_value v,
                                     odla_value_type ty) {
  assert(v->id_ < MAX_INPUT_TENSORS);
  // Robustness: the assert above compiles away under NDEBUG; also fail
  // gracefully in release builds instead of writing out of bounds.
  if (v->id_ >= MAX_INPUT_TENSORS) {
    return ODLA_FAILURE;
  }
  context->input_types_[v->id_] = std::move(ty);
  return ODLA_SUCCESS;
}
181+
182+
// Reports the type (element type + shape) of output `value` as captured by
// the last odla_LaunchExecutable call.
odla_status odla_GetRuntimeValueType(odla_context context, odla_value value,
                                     odla_value_type* ty) {
  // BUG FIX: the original asserted id_ <= num_output_tensors_, which lets
  // id_ == count index one past the end of output_tensors_.
  assert(value->id_ < context->num_output_tensors_);
  const auto& t = context->output_tensors_[value->id_];
  *ty = toODLAValueType(t.scalar_type(), t.sizes());
  return ODLA_SUCCESS;
}
188+
189+
// Wraps the caller-owned buffer `data_ptr` in a tensor (zero-copy on CPU)
// using the type registered via odla_SetRuntimeValueType, and stores it as
// input `value->id_`. When the global device is CUDA, the data is copied
// to the device.
odla_status odla_BindToArgument(odla_value value, const odla_void* data_ptr,
                                odla_context context) {
  assert(value->id_ < MAX_INPUT_TENSORS);
  auto ty = context->input_types_[value->id_];
  const auto opts = c10::TensorOptions()
                        .dtype(toTensorDataType(ty.element_type))
                        .device(c10::kCPU);
  // from_blob does not take ownership; the caller's buffer must outlive
  // the launch (or the CUDA copy below).
  auto tensor =
      at::from_blob(const_cast<void*>(data_ptr), toTensorDim(ty.shape), opts);
  if (g_device.device_t_ == c10::kCUDA) {
    tensor = tensor.to(c10::device(c10::kCUDA));
  }
  context->inputs_[value->id_] = c10::IValue(tensor);
  return ODLA_SUCCESS;
}
203+
204+
// Copies output tensor `value->id_` (captured by odla_LaunchExecutable)
// into the caller-provided host buffer `data_ptr`.
odla_status odla_BindToOutput(odla_value value, odla_void* data_ptr,
                              odla_context context) {
  assert(value->id_ < context->num_output_tensors_);
  auto tensor = context->output_tensors_[value->id_];
  const auto ty = toODLAValueType(tensor.scalar_type(), tensor.sizes());
  const int len =
      at::elementSize(tensor.scalar_type()) * getElementCount(ty.shape);
  // NOTE(review): storage().data() ignores storage_offset and assumes a
  // dense, contiguous tensor — confirm outputs always satisfy this, or
  // switch to tensor.contiguous().data_ptr().
  if (g_device.device_t_ == c10::kCPU) {
    memcpy(data_ptr, tensor.storage().data(), len);
  } else {
    // Device memory cannot be memcpy'd from the host; stage through a CPU
    // copy first.
    tensor = tensor.to(c10::Device(c10::kCPU));
    memcpy(data_ptr, tensor.storage().data(), len);
  }
  return ODLA_SUCCESS;
}
220+
221+
// Reports how many output tensors the last launch produced on `context`.
odla_status odla_GetRuntimeNumOfOutputs(odla_context context,
                                        odla_uint32* num_output_ptr) {
  *num_output_ptr = static_cast<odla_uint32>(context->num_output_tensors_);
  return ODLA_SUCCESS;
}
226+
227+
// Runs the module's forward() on the bound inputs and flattens the result
// (a single tensor, or a tuple of tensors) into context->output_tensors_.
odla_status odla_LaunchExecutable(const odla_executable computation,
                                  const odla_context context) {
  // BUG FIX: drop outputs from any previous launch; the original only
  // appended, so re-launching on the same context accumulated stale
  // tensors and inflated num_output_tensors_.
  context->output_tensors_.clear();
  context->num_output_tensors_ = 0;
  // Trim the pre-sized input slots down to the executable's real arity.
  context->inputs_.resize(computation->num_inputs_);
  context->input_types_.resize(computation->num_inputs_);
  context->output_ = computation->module_.forward(context->inputs_);

  if (context->output_.isTensor()) {
    context->output_tensors_.push_back(context->output_.toTensor());
  } else {
    assert(context->output_.isTuple());
    for (const auto& item : context->output_.toTuple()->elements()) {
      assert(item.isTensor());
      context->output_tensors_.push_back(item.toTensor());
    }
  }
  context->num_output_tensors_ = context->output_tensors_.size();
  return ODLA_SUCCESS;
}
245+
246+
// Releases a context created by odla_CreateContext; fails if the handle is
// unknown (already destroyed or never created).
odla_status odla_DestroyContext(odla_context context) {
  return g_ctxs.erase(context) == 1 ? ODLA_SUCCESS : ODLA_FAILURE;
}
254+
255+
// Releases an executable created by odla_LoadExecutable; fails if the
// handle is unknown.
odla_status odla_DestroyExecutable(odla_executable computation) {
  return g_executables.erase(computation) == 1 ? ODLA_SUCCESS : ODLA_FAILURE;
}
263+
264+
// The device is a process-static singleton; nothing to release.
odla_status odla_DestroyDevice(odla_device device) { return ODLA_SUCCESS; }

0 commit comments

Comments
 (0)