|
40 | 40 | #error This library requires minimum ODLA version 0.5
|
41 | 41 | #endif
|
42 | 42 |
|
43 |
| -enum class alg_unary_eltwise { |
44 |
| - isnan, |
45 |
| - isinf, |
46 |
| - isinf_pos, |
47 |
| - isinf_neg, |
48 |
| - abs, |
49 |
| - acos, |
50 |
| - asin, |
51 |
| - atan, |
52 |
| - ceil, |
53 |
| - cos, |
54 |
| - cosh, |
55 |
| - sin, |
56 |
| - sinh, |
57 |
| - log, |
58 |
| - tan, |
59 |
| - tanh, |
60 |
| - sqrt, |
61 |
| - neg, |
62 |
| - acosh, |
63 |
| - asinh, |
64 |
| - atanh, |
65 |
| - reciprocal, |
66 |
| - sign, |
67 |
| -}; |
68 |
| - |
69 | 43 | struct _odla_context {
|
70 | 44 | odla_computation comp;
|
71 | 45 | std::unique_ptr<dnnl::stream> stream;
|
@@ -528,23 +502,6 @@ odla_value odla_GatherElements(odla_value data, const odla_value indices,
|
528 | 502 | return CreateValue(ret_mem, output_dims, id);
|
529 | 503 | }
|
530 | 504 |
|
531 |
| -static odla_value unary_eltwise_op( |
532 |
| - dnnl::algorithm algo, odla_value input, odla_float32 alpha, |
533 |
| - odla_float32 beta, const odla_value_id id, |
534 |
| - dnnl::primitive_attr attr = dnnl::primitive_attr()) { |
535 |
| - auto eltwise_d = |
536 |
| - dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, algo, |
537 |
| - input->mem.get_desc(), alpha, beta); |
538 |
| - auto pd = dnnl::eltwise_forward::primitive_desc(eltwise_d, attr, g_comp->eng); |
539 |
| - |
540 |
| - dnnl::primitive prim = dnnl::eltwise_forward(pd); |
541 |
| - auto ret_mem = dnnl::memory(input->mem.get_desc(), g_comp->eng); |
542 |
| - odla_value v = CreateValue(ret_mem, input->shape, id); |
543 |
| - add_op(prim, {{DNNL_ARG_SRC, input->mem}, {DNNL_ARG_DST, ret_mem}}); |
544 |
| - InterpretIfNeeded(); |
545 |
| - return v; |
546 |
| -} |
547 |
| - |
548 | 505 | static odla_value binary_eltwise_s32(dnnl::algorithm alg, dnnl::memory lhs_mem,
|
549 | 506 | dnnl::memory rhs_mem,
|
550 | 507 | odla_value_shape shape,
|
@@ -590,16 +547,6 @@ static odla_value binary_eltwise(dnnl::algorithm algo, odla_value lhs,
|
590 | 547 | return v;
|
591 | 548 | }
|
592 | 549 |
|
593 |
| -odla_value odla_Abs(odla_value input, const odla_value_id value_id) { |
594 |
| - return unary_eltwise_op(dnnl::algorithm::eltwise_abs, input, 0.f, 0.f, |
595 |
| - value_id); |
596 |
| -} |
597 |
| - |
598 |
| -odla_value odla_Tanh(odla_value input, const odla_value_id value_id) { |
599 |
| - return unary_eltwise_op(dnnl::algorithm::eltwise_tanh, input, 0.f, 0.f, |
600 |
| - value_id); |
601 |
| -} |
602 |
| - |
603 | 550 | odla_value odla_Add(odla_value lhs, odla_value rhs, const odla_value_id id) {
|
604 | 551 | return binary_eltwise(dnnl::algorithm::binary_add, lhs, rhs, id);
|
605 | 552 | }
|
@@ -903,227 +850,6 @@ odla_value odla_Shift(odla_value input, odla_value shift_amount,
|
903 | 850 | return v;
|
904 | 851 | }
|
905 | 852 |
|
906 |
| -template <typename T> |
907 |
| -static void unary_eltwise_T(alg_unary_eltwise alg, void* dst, const void* input, |
908 |
| - int n) { |
909 |
| - const T* input_t = static_cast<const T*>(input); |
910 |
| - Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>> in(input_t, n); |
911 |
| - T* dst_t = static_cast<T*>(dst); |
912 |
| - Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>> out(dst_t, n); |
913 |
| - switch (alg) { |
914 |
| - case alg_unary_eltwise::abs: |
915 |
| - out = in.abs(); |
916 |
| - break; |
917 |
| - case alg_unary_eltwise::neg: |
918 |
| - out = -in; |
919 |
| - break; |
920 |
| - case alg_unary_eltwise::sign: |
921 |
| - out = (0 < in).select(1, in); |
922 |
| - out = (0 > out).select(-1, out); |
923 |
| - break; |
924 |
| - case alg_unary_eltwise::ceil: |
925 |
| - out = in.ceil(); |
926 |
| - break; |
927 |
| - case alg_unary_eltwise::log: |
928 |
| - out = in.log(); |
929 |
| - break; |
930 |
| - case alg_unary_eltwise::sqrt: |
931 |
| - out = in.sqrt(); |
932 |
| - break; |
933 |
| - case alg_unary_eltwise::reciprocal: |
934 |
| - out = in.pow(-1); |
935 |
| - break; |
936 |
| - case alg_unary_eltwise::sin: |
937 |
| - out = in.sin(); |
938 |
| - break; |
939 |
| - case alg_unary_eltwise::cos: |
940 |
| - out = in.cos(); |
941 |
| - break; |
942 |
| - case alg_unary_eltwise::tan: |
943 |
| - out = in.tan(); |
944 |
| - break; |
945 |
| - case alg_unary_eltwise::acos: |
946 |
| - out = in.acos(); |
947 |
| - break; |
948 |
| - case alg_unary_eltwise::asin: |
949 |
| - out = in.asin(); |
950 |
| - break; |
951 |
| - case alg_unary_eltwise::asinh: |
952 |
| - out = in.asinh(); |
953 |
| - break; |
954 |
| - case alg_unary_eltwise::atan: |
955 |
| - out = in.atan(); |
956 |
| - break; |
957 |
| - case alg_unary_eltwise::atanh: |
958 |
| - out = in.atanh(); |
959 |
| - break; |
960 |
| - case alg_unary_eltwise::sinh: |
961 |
| - out = in.sinh(); |
962 |
| - break; |
963 |
| - case alg_unary_eltwise::tanh: |
964 |
| - out = in.tanh(); |
965 |
| - break; |
966 |
| - case alg_unary_eltwise::cosh: |
967 |
| - out = in.cosh(); |
968 |
| - break; |
969 |
| - case alg_unary_eltwise::acosh: |
970 |
| - out = in.acosh(); |
971 |
| - break; |
972 |
| - default: |
973 |
| - assert(0); |
974 |
| - } |
975 |
| -} |
976 |
| - |
977 |
| -template <typename T> |
978 |
| -static void unary_eltwise_bool(alg_unary_eltwise alg, void* dst, |
979 |
| - const void* input, int n) { |
980 |
| - const T* input_t = static_cast<const T*>(input); |
981 |
| - Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>> in(input_t, n); |
982 |
| - bool* dst_t = static_cast<bool*>(dst); |
983 |
| - Eigen::Map<Eigen::Array<bool, Eigen::Dynamic, 1>> out(dst_t, n); |
984 |
| - switch (alg) { |
985 |
| - case alg_unary_eltwise::isnan: |
986 |
| - out = in.isNaN(); |
987 |
| - break; |
988 |
| - case alg_unary_eltwise::isinf: |
989 |
| - out = in.isInf(); |
990 |
| - break; |
991 |
| - case alg_unary_eltwise::isinf_neg: |
992 |
| - out = in.isInf() && (in < 0); |
993 |
| - break; |
994 |
| - case alg_unary_eltwise::isinf_pos: |
995 |
| - out = in.isInf() && (in > 0); |
996 |
| - break; |
997 |
| - default: |
998 |
| - assert(0); |
999 |
| - } |
1000 |
| -} |
1001 |
| - |
1002 |
| -static odla_value odla_unary_eltwise(alg_unary_eltwise alg, odla_value input, |
1003 |
| - const odla_value_id value_id) { |
1004 |
| - // Extract type and size |
1005 |
| - auto elem_type = input->elem_type; |
1006 |
| - bool ret_bool = |
1007 |
| - (alg == alg_unary_eltwise::isnan || alg == alg_unary_eltwise::isinf || |
1008 |
| - alg == alg_unary_eltwise::isinf_neg || |
1009 |
| - alg == alg_unary_eltwise::isinf_pos); |
1010 |
| - if (ret_bool) { |
1011 |
| - elem_type = ODLA_BOOL; |
1012 |
| - } |
1013 |
| - int n = GetTotalElements(input->shape); |
1014 |
| - // Prepare destination memory |
1015 |
| - dnnl::memory dst_mem; |
1016 |
| - dnnl::memory::desc dst_md = getMemoryDesc({elem_type, input->shape}); |
1017 |
| - dst_mem = dnnl::memory(dst_md, g_comp->eng); |
1018 |
| - auto v = CreateValue(dst_mem, input->shape, value_id); |
1019 |
| - v->elem_type = elem_type; |
1020 |
| - // Create lambda operation |
1021 |
| - auto op = [alg, ret_bool, input, dst_mem, n] { |
1022 |
| - void* dst = dst_mem.get_data_handle(); |
1023 |
| - const void* data = input->mem.get_data_handle(); |
1024 |
| - if (input->elem_type == ODLA_FLOAT32) { |
1025 |
| - ret_bool ? unary_eltwise_bool<float>(alg, dst, data, n) |
1026 |
| - : unary_eltwise_T<float>(alg, dst, data, n); |
1027 |
| - } else if (input->elem_type == ODLA_FLOAT64) { |
1028 |
| - ret_bool ? unary_eltwise_bool<double>(alg, dst, data, n) |
1029 |
| - : unary_eltwise_T<double>(alg, dst, data, n); |
1030 |
| - } else if (input->elem_type == ODLA_UINT8) { |
1031 |
| - ret_bool ? unary_eltwise_bool<uint8_t>(alg, dst, data, n) |
1032 |
| - : unary_eltwise_T<uint8_t>(alg, dst, data, n); |
1033 |
| - } else if (input->elem_type == ODLA_UINT16) { |
1034 |
| - ret_bool ? unary_eltwise_bool<uint16_t>(alg, dst, data, n) |
1035 |
| - : unary_eltwise_T<uint16_t>(alg, dst, data, n); |
1036 |
| - } else if (input->elem_type == ODLA_UINT32) { |
1037 |
| - ret_bool ? unary_eltwise_bool<uint32_t>(alg, dst, data, n) |
1038 |
| - : unary_eltwise_T<uint32_t>(alg, dst, data, n); |
1039 |
| - } else if (input->elem_type == ODLA_UINT64) { |
1040 |
| - ret_bool ? unary_eltwise_bool<uint64_t>(alg, dst, data, n) |
1041 |
| - : unary_eltwise_T<uint64_t>(alg, dst, data, n); |
1042 |
| - } else { |
1043 |
| - assert(0); |
1044 |
| - } |
1045 |
| - }; |
1046 |
| - // Postprocess |
1047 |
| - add_op(op); |
1048 |
| - InterpretIfNeeded(); |
1049 |
| - return v; |
1050 |
| -} |
1051 |
| - |
1052 |
| -odla_value odla_IsNaN(odla_value input, const odla_value_id value_id) { |
1053 |
| - return odla_unary_eltwise(alg_unary_eltwise::isnan, input, value_id); |
1054 |
| -} |
1055 |
| - |
1056 |
| -odla_value odla_IsInf(odla_value input, odla_bool detect_pos, |
1057 |
| - odla_bool detect_neg, const odla_value_id value_id) { |
1058 |
| - if (detect_pos != 0 && detect_neg != 0) { |
1059 |
| - return odla_unary_eltwise(alg_unary_eltwise::isinf, input, value_id); |
1060 |
| - } |
1061 |
| - if (detect_pos != 0) { |
1062 |
| - return odla_unary_eltwise(alg_unary_eltwise::isinf_pos, input, value_id); |
1063 |
| - } |
1064 |
| - return odla_unary_eltwise(alg_unary_eltwise::isinf_neg, input, value_id); |
1065 |
| -} |
1066 |
| - |
1067 |
| -odla_value odla_Cos(odla_value input, const odla_value_id value_id) { |
1068 |
| - return odla_unary_eltwise(alg_unary_eltwise::cos, input, value_id); |
1069 |
| -} |
1070 |
| - |
1071 |
| -odla_value odla_Sin(odla_value input, const odla_value_id value_id) { |
1072 |
| - return odla_unary_eltwise(alg_unary_eltwise::sin, input, value_id); |
1073 |
| -} |
1074 |
| - |
1075 |
| -odla_value odla_Tan(odla_value input, const odla_value_id value_id) { |
1076 |
| - return odla_unary_eltwise(alg_unary_eltwise::tan, input, value_id); |
1077 |
| -} |
1078 |
| - |
1079 |
| -odla_value odla_ACos(odla_value input, const odla_value_id value_id) { |
1080 |
| - return odla_unary_eltwise(alg_unary_eltwise::acos, input, value_id); |
1081 |
| -} |
1082 |
| - |
1083 |
| -odla_value odla_ACosh(odla_value input, const odla_value_id value_id) { |
1084 |
| - return odla_unary_eltwise(alg_unary_eltwise::acosh, input, value_id); |
1085 |
| -} |
1086 |
| - |
1087 |
| -odla_value odla_ASin(odla_value input, const odla_value_id value_id) { |
1088 |
| - return odla_unary_eltwise(alg_unary_eltwise::asin, input, value_id); |
1089 |
| -} |
1090 |
| - |
1091 |
| -odla_value odla_ASinh(odla_value input, const odla_value_id value_id) { |
1092 |
| - return odla_unary_eltwise(alg_unary_eltwise::asinh, input, value_id); |
1093 |
| -} |
1094 |
| - |
1095 |
| -odla_value odla_ATan(odla_value input, const odla_value_id value_id) { |
1096 |
| - return odla_unary_eltwise(alg_unary_eltwise::atan, input, value_id); |
1097 |
| -} |
1098 |
| - |
1099 |
| -odla_value odla_ATanh(odla_value input, const odla_value_id value_id) { |
1100 |
| - return odla_unary_eltwise(alg_unary_eltwise::atanh, input, value_id); |
1101 |
| -} |
1102 |
| - |
1103 |
| -odla_value odla_Sinh(odla_value input, const odla_value_id value_id) { |
1104 |
| - return odla_unary_eltwise(alg_unary_eltwise::sinh, input, value_id); |
1105 |
| -} |
1106 |
| - |
1107 |
| -odla_value odla_Cosh(odla_value input, const odla_value_id value_id) { |
1108 |
| - return odla_unary_eltwise(alg_unary_eltwise::cosh, input, value_id); |
1109 |
| -} |
1110 |
| - |
1111 |
| -odla_value odla_Ceil(odla_value input, const odla_value_id value_id) { |
1112 |
| - return odla_unary_eltwise(alg_unary_eltwise::ceil, input, value_id); |
1113 |
| -} |
1114 |
| - |
1115 |
| -odla_value odla_Neg(odla_value input, const odla_value_id value_id) { |
1116 |
| - return odla_unary_eltwise(alg_unary_eltwise::neg, input, value_id); |
1117 |
| -} |
1118 |
| - |
1119 |
| -odla_value odla_Reciprocal(odla_value input, const odla_value_id value_id) { |
1120 |
| - return odla_unary_eltwise(alg_unary_eltwise::reciprocal, input, value_id); |
1121 |
| -} |
1122 |
| - |
1123 |
| -odla_value odla_Sign(odla_value input, const odla_value_id value_id) { |
1124 |
| - return odla_unary_eltwise(alg_unary_eltwise::sign, input, value_id); |
1125 |
| -} |
1126 |
| - |
1127 | 853 | odla_value odla_Conv(odla_value input, odla_memory_layout input_layout,
|
1128 | 854 | odla_uint32 group, odla_value kernel,
|
1129 | 855 | odla_memory_layout kernel_layout,
|
|
0 commit comments