Skip to content

Commit 54e80f3

Browse files
tianbohweimingzha0
authored andcommitted
Refactored unary operator, and provid support for logic NOT operator.
1 parent 1f424d3 commit 54e80f3

File tree

9 files changed

+326
-281
lines changed

9 files changed

+326
-281
lines changed

ODLA/platforms/dnnl/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ set(ODLA_DNNL_SRC
3131
odla_dnnl_cast.cc
3232
odla_dnnl_loss.cc
3333
odla_dnnl_rnn.cc
34+
odla_dnnl_unary.cc
3435
odla_dnnl_binary.cc
3536
odla_dnnl_statistics.cc
3637
)

ODLA/platforms/dnnl/odla_dnnl.cc

-274
Original file line numberDiff line numberDiff line change
@@ -40,32 +40,6 @@
4040
#error This library requires minimum ODLA version 0.5
4141
#endif
4242

43-
enum class alg_unary_eltwise {
44-
isnan,
45-
isinf,
46-
isinf_pos,
47-
isinf_neg,
48-
abs,
49-
acos,
50-
asin,
51-
atan,
52-
ceil,
53-
cos,
54-
cosh,
55-
sin,
56-
sinh,
57-
log,
58-
tan,
59-
tanh,
60-
sqrt,
61-
neg,
62-
acosh,
63-
asinh,
64-
atanh,
65-
reciprocal,
66-
sign,
67-
};
68-
6943
struct _odla_context {
7044
odla_computation comp;
7145
std::unique_ptr<dnnl::stream> stream;
@@ -528,23 +502,6 @@ odla_value odla_GatherElements(odla_value data, const odla_value indices,
528502
return CreateValue(ret_mem, output_dims, id);
529503
}
530504

531-
static odla_value unary_eltwise_op(
532-
dnnl::algorithm algo, odla_value input, odla_float32 alpha,
533-
odla_float32 beta, const odla_value_id id,
534-
dnnl::primitive_attr attr = dnnl::primitive_attr()) {
535-
auto eltwise_d =
536-
dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo,
537-
input->mem.get_desc(), alpha, beta);
538-
auto pd = dnnl::eltwise_forward::primitive_desc(eltwise_d, attr, g_comp->eng);
539-
540-
dnnl::primitive prim = dnnl::eltwise_forward(pd);
541-
auto ret_mem = dnnl::memory(input->mem.get_desc(), g_comp->eng);
542-
odla_value v = CreateValue(ret_mem, input->shape, id);
543-
add_op(prim, {{DNNL_ARG_SRC, input->mem}, {DNNL_ARG_DST, ret_mem}});
544-
InterpretIfNeeded();
545-
return v;
546-
}
547-
548505
static odla_value binary_eltwise_s32(dnnl::algorithm alg, dnnl::memory lhs_mem,
549506
dnnl::memory rhs_mem,
550507
odla_value_shape shape,
@@ -590,16 +547,6 @@ static odla_value binary_eltwise(dnnl::algorithm algo, odla_value lhs,
590547
return v;
591548
}
592549

593-
odla_value odla_Abs(odla_value input, const odla_value_id value_id) {
594-
return unary_eltwise_op(dnnl::algorithm::eltwise_abs, input, 0.f, 0.f,
595-
value_id);
596-
}
597-
598-
odla_value odla_Tanh(odla_value input, const odla_value_id value_id) {
599-
return unary_eltwise_op(dnnl::algorithm::eltwise_tanh, input, 0.f, 0.f,
600-
value_id);
601-
}
602-
603550
odla_value odla_Add(odla_value lhs, odla_value rhs, const odla_value_id id) {
604551
return binary_eltwise(dnnl::algorithm::binary_add, lhs, rhs, id);
605552
}
@@ -903,227 +850,6 @@ odla_value odla_Shift(odla_value input, odla_value shift_amount,
903850
return v;
904851
}
905852

906-
template <typename T>
907-
static void unary_eltwise_T(alg_unary_eltwise alg, void* dst, const void* input,
908-
int n) {
909-
const T* input_t = static_cast<const T*>(input);
910-
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>> in(input_t, n);
911-
T* dst_t = static_cast<T*>(dst);
912-
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> out(dst_t, n);
913-
switch (alg) {
914-
case alg_unary_eltwise::abs:
915-
out = in.abs();
916-
break;
917-
case alg_unary_eltwise::neg:
918-
out = -in;
919-
break;
920-
case alg_unary_eltwise::sign:
921-
out = (0 < in).select(1, in);
922-
out = (0 > out).select(-1, out);
923-
break;
924-
case alg_unary_eltwise::ceil:
925-
out = in.ceil();
926-
break;
927-
case alg_unary_eltwise::log:
928-
out = in.log();
929-
break;
930-
case alg_unary_eltwise::sqrt:
931-
out = in.sqrt();
932-
break;
933-
case alg_unary_eltwise::reciprocal:
934-
out = in.pow(-1);
935-
break;
936-
case alg_unary_eltwise::sin:
937-
out = in.sin();
938-
break;
939-
case alg_unary_eltwise::cos:
940-
out = in.cos();
941-
break;
942-
case alg_unary_eltwise::tan:
943-
out = in.tan();
944-
break;
945-
case alg_unary_eltwise::acos:
946-
out = in.acos();
947-
break;
948-
case alg_unary_eltwise::asin:
949-
out = in.asin();
950-
break;
951-
case alg_unary_eltwise::asinh:
952-
out = in.asinh();
953-
break;
954-
case alg_unary_eltwise::atan:
955-
out = in.atan();
956-
break;
957-
case alg_unary_eltwise::atanh:
958-
out = in.atanh();
959-
break;
960-
case alg_unary_eltwise::sinh:
961-
out = in.sinh();
962-
break;
963-
case alg_unary_eltwise::tanh:
964-
out = in.tanh();
965-
break;
966-
case alg_unary_eltwise::cosh:
967-
out = in.cosh();
968-
break;
969-
case alg_unary_eltwise::acosh:
970-
out = in.acosh();
971-
break;
972-
default:
973-
assert(0);
974-
}
975-
}
976-
977-
template <typename T>
978-
static void unary_eltwise_bool(alg_unary_eltwise alg, void* dst,
979-
const void* input, int n) {
980-
const T* input_t = static_cast<const T*>(input);
981-
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>> in(input_t, n);
982-
bool* dst_t = static_cast<bool*>(dst);
983-
Eigen::Map<Eigen::Array<bool, Eigen::Dynamic, 1>> out(dst_t, n);
984-
switch (alg) {
985-
case alg_unary_eltwise::isnan:
986-
out = in.isNaN();
987-
break;
988-
case alg_unary_eltwise::isinf:
989-
out = in.isInf();
990-
break;
991-
case alg_unary_eltwise::isinf_neg:
992-
out = in.isInf() && (in < 0);
993-
break;
994-
case alg_unary_eltwise::isinf_pos:
995-
out = in.isInf() && (in > 0);
996-
break;
997-
default:
998-
assert(0);
999-
}
1000-
}
1001-
1002-
static odla_value odla_unary_eltwise(alg_unary_eltwise alg, odla_value input,
1003-
const odla_value_id value_id) {
1004-
// Extract type and size
1005-
auto elem_type = input->elem_type;
1006-
bool ret_bool =
1007-
(alg == alg_unary_eltwise::isnan || alg == alg_unary_eltwise::isinf ||
1008-
alg == alg_unary_eltwise::isinf_neg ||
1009-
alg == alg_unary_eltwise::isinf_pos);
1010-
if (ret_bool) {
1011-
elem_type = ODLA_BOOL;
1012-
}
1013-
int n = GetTotalElements(input->shape);
1014-
// Prepare destination memory
1015-
dnnl::memory dst_mem;
1016-
dnnl::memory::desc dst_md = getMemoryDesc({elem_type, input->shape});
1017-
dst_mem = dnnl::memory(dst_md, g_comp->eng);
1018-
auto v = CreateValue(dst_mem, input->shape, value_id);
1019-
v->elem_type = elem_type;
1020-
// Create lambda operation
1021-
auto op = [alg, ret_bool, input, dst_mem, n] {
1022-
void* dst = dst_mem.get_data_handle();
1023-
const void* data = input->mem.get_data_handle();
1024-
if (input->elem_type == ODLA_FLOAT32) {
1025-
ret_bool ? unary_eltwise_bool<float>(alg, dst, data, n)
1026-
: unary_eltwise_T<float>(alg, dst, data, n);
1027-
} else if (input->elem_type == ODLA_FLOAT64) {
1028-
ret_bool ? unary_eltwise_bool<double>(alg, dst, data, n)
1029-
: unary_eltwise_T<double>(alg, dst, data, n);
1030-
} else if (input->elem_type == ODLA_UINT8) {
1031-
ret_bool ? unary_eltwise_bool<uint8_t>(alg, dst, data, n)
1032-
: unary_eltwise_T<uint8_t>(alg, dst, data, n);
1033-
} else if (input->elem_type == ODLA_UINT16) {
1034-
ret_bool ? unary_eltwise_bool<uint16_t>(alg, dst, data, n)
1035-
: unary_eltwise_T<uint16_t>(alg, dst, data, n);
1036-
} else if (input->elem_type == ODLA_UINT32) {
1037-
ret_bool ? unary_eltwise_bool<uint32_t>(alg, dst, data, n)
1038-
: unary_eltwise_T<uint32_t>(alg, dst, data, n);
1039-
} else if (input->elem_type == ODLA_UINT64) {
1040-
ret_bool ? unary_eltwise_bool<uint64_t>(alg, dst, data, n)
1041-
: unary_eltwise_T<uint64_t>(alg, dst, data, n);
1042-
} else {
1043-
assert(0);
1044-
}
1045-
};
1046-
// Postprocess
1047-
add_op(op);
1048-
InterpretIfNeeded();
1049-
return v;
1050-
}
1051-
1052-
odla_value odla_IsNaN(odla_value input, const odla_value_id value_id) {
1053-
return odla_unary_eltwise(alg_unary_eltwise::isnan, input, value_id);
1054-
}
1055-
1056-
odla_value odla_IsInf(odla_value input, odla_bool detect_pos,
1057-
odla_bool detect_neg, const odla_value_id value_id) {
1058-
if (detect_pos != 0 && detect_neg != 0) {
1059-
return odla_unary_eltwise(alg_unary_eltwise::isinf, input, value_id);
1060-
}
1061-
if (detect_pos != 0) {
1062-
return odla_unary_eltwise(alg_unary_eltwise::isinf_pos, input, value_id);
1063-
}
1064-
return odla_unary_eltwise(alg_unary_eltwise::isinf_neg, input, value_id);
1065-
}
1066-
1067-
odla_value odla_Cos(odla_value input, const odla_value_id value_id) {
1068-
return odla_unary_eltwise(alg_unary_eltwise::cos, input, value_id);
1069-
}
1070-
1071-
odla_value odla_Sin(odla_value input, const odla_value_id value_id) {
1072-
return odla_unary_eltwise(alg_unary_eltwise::sin, input, value_id);
1073-
}
1074-
1075-
odla_value odla_Tan(odla_value input, const odla_value_id value_id) {
1076-
return odla_unary_eltwise(alg_unary_eltwise::tan, input, value_id);
1077-
}
1078-
1079-
odla_value odla_ACos(odla_value input, const odla_value_id value_id) {
1080-
return odla_unary_eltwise(alg_unary_eltwise::acos, input, value_id);
1081-
}
1082-
1083-
odla_value odla_ACosh(odla_value input, const odla_value_id value_id) {
1084-
return odla_unary_eltwise(alg_unary_eltwise::acosh, input, value_id);
1085-
}
1086-
1087-
odla_value odla_ASin(odla_value input, const odla_value_id value_id) {
1088-
return odla_unary_eltwise(alg_unary_eltwise::asin, input, value_id);
1089-
}
1090-
1091-
odla_value odla_ASinh(odla_value input, const odla_value_id value_id) {
1092-
return odla_unary_eltwise(alg_unary_eltwise::asinh, input, value_id);
1093-
}
1094-
1095-
odla_value odla_ATan(odla_value input, const odla_value_id value_id) {
1096-
return odla_unary_eltwise(alg_unary_eltwise::atan, input, value_id);
1097-
}
1098-
1099-
odla_value odla_ATanh(odla_value input, const odla_value_id value_id) {
1100-
return odla_unary_eltwise(alg_unary_eltwise::atanh, input, value_id);
1101-
}
1102-
1103-
odla_value odla_Sinh(odla_value input, const odla_value_id value_id) {
1104-
return odla_unary_eltwise(alg_unary_eltwise::sinh, input, value_id);
1105-
}
1106-
1107-
odla_value odla_Cosh(odla_value input, const odla_value_id value_id) {
1108-
return odla_unary_eltwise(alg_unary_eltwise::cosh, input, value_id);
1109-
}
1110-
1111-
odla_value odla_Ceil(odla_value input, const odla_value_id value_id) {
1112-
return odla_unary_eltwise(alg_unary_eltwise::ceil, input, value_id);
1113-
}
1114-
1115-
odla_value odla_Neg(odla_value input, const odla_value_id value_id) {
1116-
return odla_unary_eltwise(alg_unary_eltwise::neg, input, value_id);
1117-
}
1118-
1119-
odla_value odla_Reciprocal(odla_value input, const odla_value_id value_id) {
1120-
return odla_unary_eltwise(alg_unary_eltwise::reciprocal, input, value_id);
1121-
}
1122-
1123-
odla_value odla_Sign(odla_value input, const odla_value_id value_id) {
1124-
return odla_unary_eltwise(alg_unary_eltwise::sign, input, value_id);
1125-
}
1126-
1127853
odla_value odla_Conv(odla_value input, odla_memory_layout input_layout,
1128854
odla_uint32 group, odla_value kernel,
1129855
odla_memory_layout kernel_layout,

ODLA/platforms/dnnl/odla_dnnl.h

+17
Original file line numberDiff line numberDiff line change
@@ -375,4 +375,21 @@ static inline std::pair<dnnl::memory, dnnl::memory> broadcast_operands(
375375
};
376376
}
377377

378+
static inline odla_value unary_eltwise_op(
379+
dnnl::algorithm algo, odla_value input, odla_float32 alpha,
380+
odla_float32 beta, const odla_value_id id,
381+
dnnl::primitive_attr attr = dnnl::primitive_attr()) {
382+
auto eltwise_d =
383+
dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo,
384+
input->mem.get_desc(), alpha, beta);
385+
auto pd = dnnl::eltwise_forward::primitive_desc(eltwise_d, attr, g_comp->eng);
386+
387+
dnnl::primitive prim = dnnl::eltwise_forward(pd);
388+
auto ret_mem = dnnl::memory(input->mem.get_desc(), g_comp->eng);
389+
odla_value v = CreateValue(ret_mem, input->shape, id);
390+
add_op(prim, {{DNNL_ARG_SRC, input->mem}, {DNNL_ARG_DST, ret_mem}});
391+
InterpretIfNeeded();
392+
return v;
393+
}
394+
378395
#endif // ODLA_DNNL_H_

ODLA/platforms/dnnl/odla_dnnl_binary.cc

-2
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
enum class alg_binary_eltwise {
2323
logic_or,
2424
logic_and,
25-
logic_not,
2625
logic_xor,
2726
cmp_equal,
2827
cmp_less,
@@ -114,7 +113,6 @@ static void binary_eltwise_T(alg_binary_eltwise alg, void* dst,
114113
bool binary_ret_bool(alg_binary_eltwise alg) {
115114
return (alg == alg_binary_eltwise::logic_or) ||
116115
(alg == alg_binary_eltwise::logic_and) ||
117-
(alg == alg_binary_eltwise::logic_not) ||
118116
(alg == alg_binary_eltwise::logic_xor) ||
119117
(alg == alg_binary_eltwise::cmp_equal) ||
120118
(alg == alg_binary_eltwise::cmp_less) ||

0 commit comments

Comments
 (0)