Skip to content

Commit 224c1ab

Browse files
author
Viviane Potocnik
committed
[tests] Add FA-2 tests to GitLab CI
1 parent ffb2dcc commit 224c1ab

File tree

3 files changed

+12
-9
lines changed

3 files changed

+12
-9
lines changed

.gitlab-ci.yml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ snitch-cluster-vsim:
106106
# Run additional, more extensive tests
107107
- cd sw/apps/blas/gemm/test
108108
- ./run.py runs.yaml --cfg $PWD/cfg/* --simulator vsim -j
109+
- cd ../../../dnn/flashattention_2/test
110+
# FP8 FA-2 tests are failing with precision mismatch
111+
# due to operand ordering
112+
- ./run.py runs.yaml --cfg $PWD/cfg/fp32* --simulator vsim -j
113+
- ./run.py runs.yaml --cfg $PWD/cfg/fp16* --simulator vsim -j
109114

110115
# Banshee
111116
snitch-cluster-banshee:

sw/blas/gemm/src/gemm_fp16.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ void gemm_fp16_naive(uint32_t M, uint32_t N, uint32_t K, void* A_p,
1919
for (uint32_t m = 0; m < M; m++) {
2020
for (uint32_t n = 0; n < N; n++) {
2121
__fp16 c;
22-
if (beta != 0) {
22+
if (beta != 0) {
2323
c = C[m * ldC + n] * beta;
2424
} else {
2525
c = 0.0;

sw/blas/gemm/src/gemm_fp8.h

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,19 @@ void gemm_fp8_naive(uint32_t M, uint32_t N, uint32_t K, void* A_p, uint32_t ldA,
1818
for (uint32_t m = 0; m < M; m++) {
1919
for (uint32_t n = 0; n < N; n++) {
2020
char c;
21-
if (BETA != 0){
21+
if (BETA != 0) {
2222
c = C[m * ldC + n];
2323
// FIXME: get the correct beta value
24-
asm volatile (
24+
asm volatile(
2525
// "fmv.b.x ft0, %[beta]\n"
2626
"fcvt.b.s ft0, %[beta]\n"
2727
"fmv.b.x ft1, %[c]\n"
2828
"fmul.b ft2, ft0, ft1\n"
2929
"fmv.x.b %[c], ft2\n"
30-
: [c] "+r"(c)
31-
: [beta] "f"(1.0f)
32-
: "ft0", "ft1", "ft2"
33-
);
34-
}
35-
else{
30+
: [ c ] "+r"(c)
31+
: [ beta ] "f"(1.0f)
32+
: "ft0", "ft1", "ft2");
33+
} else {
3634
c = 0.0;
3735
}
3836
for (uint32_t k = 0; k < K; k++) {

0 commit comments

Comments
 (0)