-
Notifications
You must be signed in to change notification settings - Fork 75
/
Copy pathodla_ops_process.h
342 lines (282 loc) · 11.8 KB
/
odla_ops_process.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
//===- odla_ops_process.h -------------------------------------------------===//
//
// Copyright (C) 2019-2020 Alibaba Group Holding Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef _ODLA_OPERATOR_OPS_PROCESS_H_
#define _ODLA_OPERATOR_OPS_PROCESS_H_
#include <ODLA/odla_common.h>
#include <ODLA/odla_value.h>
/*! \file
* \details This file defines the ODLA value-processing related operators.
*/
#ifdef __cplusplus
extern "C" {
#endif
/*! \brief Interpolation methods. */
typedef enum {
  /*! nearest-neighbor interpolation. */
  ODLA_NEAREST = 0,
  /*! N-linear interpolation, e.g. bilinear for a 2D plane. */
  ODLA_LINEAR = 1,
  /*! N-cubic interpolation, e.g. bicubic for a 2D plane. */
  ODLA_CUBIC = 2,
} odla_interpolation_mode;
/*!
  \brief Modes for coordinate transformation during resizing.

  Each enumerator is documented with the coordinate mapping formula it
  denotes.
*/
typedef enum {
  /*! new_coord = orig_coord * scale
      (the "ASSYMMETRIC" spelling is part of the public API; kept as-is) */
  ODLA_ASSYMMETRIC = 0,
  /*! new_coord = (orig_coord + 0.5) * scale - 0.5 */
  ODLA_HALF_PIXEL = 1,
  /*! new_coord = orig_coord * scale - 0.5 */
  ODLA_HALF_PIXEL_TF = 2,
  /*! new_coord = orig_coord * (orig_dim - 1) / (new_dim - 1) */
  ODLA_ALIGN_CORNERS = 3,
} odla_resize_coordinate_mode;
/*! \brief Methods for filling a value. */
typedef enum {
  /*! ones on the diagonal and zeros elsewhere. */
  ODLA_EyeLike = 0,
  /*! values drawn from a normal distribution. */
  ODLA_RandomNormal = 1,
  /*! values drawn from a random uniform distribution. */
  ODLA_RandomUniform = 2,
} odla_fill_method;
/*!
  \brief Broadcast a value to a target shape.

  Broadcasts \p input so that the result has the shape \p output_shape.

  \param input        the value to broadcast
  \param output_shape the shape of the result after broadcasting
  \param value_id     a unique id for the result value (can be NULL)
  \return the broadcast odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL odla_Broadcast(
    const odla_value input, odla_value_shape output_shape,
    const odla_value_id value_id);
/*!
  \brief Cast the element data type of a value.

  Cast converts the element type of \p input to \p target_type.

  \param input       the value to cast
  \param target_type the element data type of the result
  \param value_id    a unique id for the result value (can be NULL)
  \return the cast odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL odla_Cast(
    odla_value input, odla_element_type target_type,
    const odla_value_id value_id);
/*!
  \brief Select slices based on a condition.

  Selects slices of \p input along \p axis according to the 1-D
  \p condition vector.

  \param input            the value to select slices from
  \param condition        the 1-D condition value
  \param axis             the axis along which slices are selected
  \param max_output_shape the largest possible result shape, i.e. the shape
                          obtained when every condition is true
  \param value_id         a unique id for the result value (can be NULL)
  \return the resulting odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL odla_Compress(
    odla_value input, odla_value condition, odla_int32 axis,
    odla_value_shape max_output_shape, const odla_value_id value_id);
/*!
  \brief Concatenate multiple values into a single value.

  Joins all of \p inputs along \p axis. Every input must have the same
  dimensions, except for the size of dimension \p axis.

  \param inputs       the values to concatenate
  \param axis         the axis along which the inputs are concatenated
  \param output_shape the shape of the result
  \param value_id     a unique id for the result value (can be NULL)
  \return the concatenated odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL odla_Concat(
    odla_values inputs, odla_int32 axis, odla_value_shape output_shape,
    const odla_value_id value_id);
/*!
  \brief Broadcast the input tensor into a given shape.

  ExpandDims broadcasts \p input into the shape given by \p output_dims.

  \param input       the value to expand
  \param output_dims the shape of the result
  \param value_id    a unique id for the result value (can be NULL)
  \return the expanded odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL odla_ExpandDims(
    odla_value input, odla_value_shape output_dims,
    const odla_value_id value_id);
//! \brief Generate a value with data
/*!
  Fill generates a value whose type and dimensions are specified by \p type,
  and fills it with data using the specified \p method:
  - ODLA_RandomNormal: \p p0 is the mean and \p p1 is the standard deviation.
  - ODLA_RandomUniform: \p p0 is the lower bound and \p p1 is the upper bound.
  - any other method (e.g. ODLA_EyeLike): \p p0 and \p p1 are ignored.

  \param type     the type (element type and dimensions) of the generated
                  odla_value
  \param method   the method used to fill the value
  \param p0       mean for a normal distribution, or lower bound for a uniform
                  distribution; ignored for other methods
  \param p1       standard deviation for a normal distribution, or upper bound
                  for a uniform distribution; ignored for other methods
  \param seed     the seed for the random generator (passed as odla_float32)
  \param value_id a unique value id (can be NULL)
  \return odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL
odla_Fill(odla_value_type type, odla_fill_method method, odla_float32 p0,
          odla_float32 p1, odla_float32 seed, const odla_value_id value_id);
/*!
  \brief Gather slices.

  Gathers slices of \p input along \p axis at the positions given by
  \p indices.

  \param input       the value to gather from
  \param indices     the value holding the indices to gather
  \param axis        the axis of \p input along which to gather
  \param output_dims the optional output shape (can be undefined)
  \param value_id    a unique id for the result value (can be NULL)
  \return the gathered odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL odla_Gather(
    odla_value input, odla_value indices, odla_int32 axis,
    odla_value_shape output_dims, const odla_value_id value_id);
//! \brief Gather elements
/*!
  GatherElements gathers individual elements (rather than whole slices, as
  odla_Gather does) from \p input along \p axis, at the positions given by
  \p indices.
  NOTE(review): presumably follows ONNX GatherElements semantics, where the
  result has the same shape as \p indices -- confirm against implementations.

  \param input       the input value
  \param indices     the indices value
  \param axis        the axis of \p input along which elements are gathered
  \param output_dims the optional output shape (can be undefined)
  \param value_id    a unique value id (can be NULL)
  \return odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL
odla_GatherElements(odla_value input, odla_value indices, odla_int32 axis,
                    odla_value_shape output_dims, const odla_value_id value_id);
/*!
  \brief Produce a one-hot value.

  OneHot builds a one-hot value from \p values. \p indices selects where the
  "on" value is placed, and \p depth is the size of the new dimension.

  \param indices     the indices at which the "on" value is placed
  \param depth       the size of the new dimension
  \param values      the pair of "on" and "off" values
  \param axis        the axis of the new dimension
  \param output_dims the optional output shape (can be undefined)
  \param value_id    a unique id for the result value (can be NULL)
  \return the one-hot odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL
odla_OneHot(odla_value indices, odla_int32 depth, odla_value values,
            odla_int32 axis, odla_value_shape output_dims,
            const odla_value_id value_id);
/*!
  \brief Pad the input.

  Pads \p input with the given amounts at its start and its end.
  NOTE(review): padding_front / padding_back are presumably arrays with one
  entry per dimension of \p input -- confirm.

  \param input         the value to pad
  \param padding_front the padding amount applied to the start of \p input
  \param padding_back  the padding amount applied to the end of \p input
  \param output_dims   the optional output shape (can be undefined)
  \param value_id      a unique id for the result value (can be NULL)
  \return the padded odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL odla_Pad(
    odla_value input, const odla_uint32* padding_front,
    const odla_uint32* padding_back, odla_value_shape output_dims,
    const odla_value_id value_id);
/*!
  \brief Reshape a value.

  Reshape gives \p input the new dimensions specified by \p output_dims.

  \param input       the value to reshape
  \param output_dims the new shape
  \param value_id    a unique id for the result value (can be NULL)
  \return the reshaped odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL odla_Reshape(
    odla_value input, odla_value_shape output_dims,
    const odla_value_id value_id);
/*!
  \brief Resize a value by interpolation.

  Resizes \p input using the interpolation method \p interpolation and the
  coordinate transformation mode \p mode.

  \param input         the value to resize
  \param interpolation the interpolation method
  \param mode          the coordinate transformation mode
  \param axes_mask     bit mask of the axes to resize. The least significant
                       bit corresponds to the shape dimension with stride 1;
                       e.g. resizing H and W of an NHWC tensor uses 0b0110.
  \param output_dims   the optional output shape (can be undefined)
  \param value_id      a unique id for the result value (can be NULL)
  \return the resized odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL odla_Resize(
    odla_value input, odla_interpolation_mode interpolation,
    odla_resize_coordinate_mode mode, odla_uint32 axes_mask,
    odla_value_shape output_dims, const odla_value_id value_id);
/*!
  \brief Get the shape of a value.

  Shape returns the shape of \p input as a 1-D odla_value. The element type
  of the result is determined by the implementation.

  \param input       the value whose shape is queried
  \param output_dims the optional output shape (can be undefined)
  \param value_id    a unique id for the result value (can be NULL)
  \return a 1-D odla_value holding the shape of \p input
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL odla_Shape(
    odla_value input, odla_value_shape output_dims,
    const odla_value_id value_id);
/*!
  \brief Extract a slice.

  Slice extracts from \p input the region described by \p start, \p end and
  \p stride, each of which gives one entry per slicing dimension.

  \param input       the value to slice
  \param start       the start offsets at each slicing dimension
  \param end         the end indices (exclusive) at each slicing dimension
  \param stride      the stride at each slicing dimension
  \param output_dims the optional output shape (can be undefined)
  \param value_id    a unique id for the result value (can be NULL)
  \return the sliced odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL odla_Slice(
    odla_value input, const odla_uint32* start, const odla_uint32* end,
    const odla_uint32* stride, odla_value_shape output_dims,
    const odla_value_id value_id);
/*!
  \brief Remove dimensions of size 1.

  Squeeze removes size-1 dimensions from the shape of \p input. If
  \p num_of_axes is zero, all size-1 dimensions are squeezed; otherwise the
  axes to squeeze are listed in \p axes.

  \param input       the value to squeeze
  \param num_of_axes the number of axes to squeeze
  \param axes        the axes to squeeze
  \param output_dims the optional output shape (can be undefined)
  \param value_id    a unique id for the result value (can be NULL)
  \return the squeezed odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL odla_Squeeze(
    odla_value input, odla_size_t num_of_axes, const odla_uint32* axes,
    odla_value_shape output_dims, const odla_value_id value_id);
//! \brief Transpose the input
/*!
  Transpose returns \p input with its axes permuted according to
  \p permutations.
  \param input        the input value
  \param permutations the axes permutation, passed as an odla_value_shape. It
                      should have the same number of entries as the input
                      dimensions and the output dimensions.
  \param output_dims  the optional output shape (can be undefined)
  \param value_id     a unique value id (can be NULL)
  \return odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL
odla_Transpose(odla_value input, odla_value_shape permutations,
               odla_value_shape output_dims, const odla_value_id value_id);
/*!
  \brief Tile the input multiple times.

  Tile replicates \p input along its dimensions, with \p repeat giving the
  number of copies for each dimension.

  \param input       the value to tile
  \param repeat      the number of copies along each of the input's dimensions
  \param output_dims the optional output shape (can be undefined)
  \param value_id    a unique id for the result value (can be NULL)
  \return the tiled odla_value
*/
extern ODLA_API_EXPORT odla_value ODLA_API_CALL odla_Tile(
    odla_value input, const odla_uint32* repeat, odla_value_shape output_dims,
    const odla_value_id value_id);
#ifdef __cplusplus
} // C extern
#endif
#endif // _ODLA_OPERATOR_OPS_PROCESS_H_