Skip to content

Commit 482744b

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Complex Support: optimized bmm
Reviewed By: swolchok Differential Revision: D73052497
1 parent e42dafc commit 482744b

File tree

4 files changed

+137
-8
lines changed

4 files changed

+137
-8
lines changed

kernels/optimized/blas/CPUBlas.cpp

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
// clang-format off
1818
extern "C" void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, const double *a, int *lda, const double *b, int *ldb, double *beta, double *c, int *ldc);
1919
extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, const float *a, int *lda, const float *b, int *ldb, float *beta, float *c, int *ldc);
20+
extern "C" void cgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc);
21+
extern "C" void zgemm_(char *transa, char *transb, int *m, int *n, int *k, void *alpha, const void *a, int *lda, const void *b, int *ldb, void *beta, void *c, int *ldc);
2022
// clang-format on
2123
#endif // ET_BUILD_FOR_APPLE
2224
#endif // ET_BUILD_WITH_BLAS
@@ -26,6 +28,7 @@ namespace cpublas {
2628

2729
using executorch::aten::BFloat16;
2830
using executorch::aten::Half;
31+
using executorch::aten::complex;
2932

3033
#ifdef ET_BUILD_WITH_BLAS
3134
#ifdef ET_BUILD_FOR_APPLE
@@ -197,5 +200,100 @@ void gemm(
197200
}
198201
// clang-format on
199202

203+
// clang-format off
204+
void gemm(
205+
TransposeType transa, TransposeType transb,
206+
int64_t m, int64_t n, int64_t k,
207+
const complex<double> alpha,
208+
const complex<double> *a, int64_t lda,
209+
const complex<double> *b, int64_t ldb,
210+
const complex<double> beta,
211+
complex<double> *c, int64_t ldc) {
212+
normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
213+
#if defined(ET_BUILD_WITH_BLAS) && !defined(ET_BUILD_FOR_APPLE)
214+
complex<double> alpha_ = alpha, beta_ = beta;
215+
int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
216+
char transa_ = to_blas(transa), transb_ = to_blas(transb);
217+
zgemm_(
218+
&transa_, &transb_,
219+
&m_, &n_, &k_,
220+
&alpha_,
221+
a, &lda_,
222+
b, &ldb_,
223+
&beta_,
224+
c, &ldc_);
225+
#else
226+
using acc_type = utils::compute_dtype<complex<double>>;
227+
gemm_impl(
228+
transa, transb,
229+
m, n, k,
230+
static_cast<const acc_type>(alpha),
231+
a, lda,
232+
b, ldb,
233+
static_cast<const acc_type>(beta),
234+
c, ldc);
235+
#endif
236+
}
237+
// clang-format on
238+
239+
// clang-format off
240+
void gemm(
241+
TransposeType transa, TransposeType transb,
242+
int64_t m, int64_t n, int64_t k,
243+
const complex<float> alpha,
244+
const complex<float> *a, int64_t lda,
245+
const complex<float> *b, int64_t ldb,
246+
const complex<float> beta,
247+
complex<float> *c, int64_t ldc) {
248+
normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
249+
#if defined(ET_BUILD_WITH_BLAS) && !defined(ET_BUILD_FOR_APPLE)
250+
complex<float> alpha_ = alpha, beta_ = beta;
251+
int m_ = m, n_ = n, k_ = k, lda_ = lda, ldb_ = ldb, ldc_ = ldc;
252+
char transa_ = to_blas(transa), transb_ = to_blas(transb);
253+
cgemm_(
254+
&transa_, &transb_,
255+
&m_, &n_, &k_,
256+
&alpha_,
257+
a, &lda_,
258+
b, &ldb_,
259+
&beta_,
260+
c, &ldc_);
261+
#else
262+
using acc_type = utils::compute_dtype<complex<float>>;
263+
gemm_impl(
264+
transa, transb,
265+
m, n, k,
266+
static_cast<const acc_type>(alpha),
267+
a, lda,
268+
b, ldb,
269+
static_cast<const acc_type>(beta),
270+
c, ldc);
271+
#endif
272+
}
273+
// clang-format on
274+
275+
// clang-format off
276+
void gemm(
277+
TransposeType transa, TransposeType transb,
278+
int64_t m, int64_t n, int64_t k,
279+
const complex<Half> alpha,
280+
const complex<Half> *a, int64_t lda,
281+
const complex<Half> *b, int64_t ldb,
282+
const complex<Half> beta,
283+
complex<Half> *c, int64_t ldc) {
284+
normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc);
285+
286+
using acc_type = utils::compute_dtype<complex<Half>>;
287+
gemm_impl(
288+
transa, transb,
289+
m, n, k,
290+
static_cast<const acc_type>(alpha),
291+
a, lda,
292+
b, ldb,
293+
static_cast<const acc_type>(beta),
294+
c, ldc);
295+
}
296+
// clang-format on
297+
200298
} // namespace cpublas
201299
} // namespace executorch

kernels/optimized/blas/CPUBlas.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,33 @@ void gemm(
111111
const executorch::aten::BFloat16 *b, int64_t ldb,
112112
const executorch::aten::BFloat16 beta,
113113
executorch::aten::BFloat16 *c, int64_t ldc);
114+
115+
void gemm(
116+
TransposeType transa, TransposeType transb,
117+
int64_t m, int64_t n, int64_t k,
118+
const executorch::aten::complex<double> alpha,
119+
const executorch::aten::complex<double> *a, int64_t lda,
120+
const executorch::aten::complex<double> *b, int64_t ldb,
121+
const executorch::aten::complex<double> beta,
122+
executorch::aten::complex<double> *c, int64_t ldc);
123+
124+
void gemm(
125+
TransposeType transa, TransposeType transb,
126+
int64_t m, int64_t n, int64_t k,
127+
const executorch::aten::complex<float> alpha,
128+
const executorch::aten::complex<float> *a, int64_t lda,
129+
const executorch::aten::complex<float> *b, int64_t ldb,
130+
const executorch::aten::complex<float> beta,
131+
executorch::aten::complex<float> *c, int64_t ldc);
132+
133+
void gemm(
134+
TransposeType transa, TransposeType transb,
135+
int64_t m, int64_t n, int64_t k,
136+
const executorch::aten::complex<executorch::aten::Half> alpha,
137+
const executorch::aten::complex<executorch::aten::Half> *a, int64_t lda,
138+
const executorch::aten::complex<executorch::aten::Half> *b, int64_t ldb,
139+
const executorch::aten::complex<executorch::aten::Half> beta,
140+
executorch::aten::complex<executorch::aten::Half> *c, int64_t ldc);
114141
// clang-format on
115142

116143
// clang-format off

kernels/optimized/cpu/op_bmm.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ Tensor& opt_bmm_out(
155155

156156
if (executorch::runtime::isComplexType(self_type)) {
157157
ET_SWITCH_COMPLEXH_TYPES(self_type, ctx, name, CTYPE, [&]() {
158-
internal::bmm_out_impl<CTYPE>(self, mat2, out);
158+
bmm_kernel<CTYPE>(self, mat2, out);
159159
});
160160
} else {
161161
ET_SWITCH_REALH_TYPES(self_type, ctx, name, CTYPE, [&]() {

kernels/optimized/test/libblas_test.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,17 @@
1313

1414
#include <vector>
1515

16-
#define TEST_FORALL_SUPPORTED_CTYPES(_, N) \
17-
_<double, N>(); \
18-
_<float, N>(); \
19-
_<int64_t, N>(); \
20-
_<uint8_t, N>(); \
21-
_<int32_t, N>(); \
22-
_<executorch::aten::BFloat16, N>();
16+
#define TEST_FORALL_SUPPORTED_CTYPES(_, N) \
17+
_<double, N>(); \
18+
_<float, N>(); \
19+
_<int64_t, N>(); \
20+
_<uint8_t, N>(); \
21+
_<int32_t, N>(); \
22+
_<executorch::aten::Half, N>(); \
23+
_<executorch::aten::BFloat16, N>(); \
24+
_<executorch::aten::complex<double>, N>(); \
25+
_<executorch::aten::complex<float>, N>(); \
26+
_<executorch::aten::complex<executorch::aten::Half>, N>();
2327

2428
namespace {
2529

0 commit comments

Comments
 (0)