Skip to content

Commit 52de40b

Browse files
committed
Add FP8 support to gguf/llama:
E5M2 & E4M3: for use with FP8 distributed model E4M3_Q & E3M4_Q: for gguf quantized model. E5M2 and A4M3 type are use like FP16 / BF16 native. E4M3_Q and E3M4_Q are define like Q8_0 with bloc size of 256 (like QK_K)
1 parent a402322 commit 52de40b

18 files changed

+555
-111
lines changed

CMakeLists.txt

+4
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ if (NOT DEFINED GGML_LLAMAFILE)
8888
set(GGML_LLAMAFILE_DEFAULT ON)
8989
endif()
9090

91+
if (NOT DEFINED GGML_OPENMP_SIMD)
92+
set(GGML_OPENMP_SIMD_DEFAULT ON)
93+
endif()
94+
9195
if (NOT DEFINED GGML_AMX)
9296
set(GGML_AMX ON)
9397
endif()

Makefile

+19-1
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ GGML_NO_OPENMP := 1
138138
DEPRECATE_WARNING := 1
139139
endif
140140

141+
ifdef LLAMA_NO_OPENMP_SIMD
142+
GGML_NO_OPENMP_SIMD := 1
143+
endif
144+
141145
ifdef LLAMA_NO_METAL
142146
GGML_NO_METAL := 1
143147
DEPRECATE_WARNING := 1
@@ -548,6 +552,13 @@ ifndef GGML_NO_OPENMP
548552
endif # GGML_MUSA
549553
endif # GGML_NO_OPENMP
550554

555+
ifndef GGML_NO_OPENMP_SIMD
556+
MK_CPPFLAGS += -DGGML_USE_OPENMP_SIMD
557+
MK_CFLAGS += -fopenmp-simd
558+
MK_CXXFLAGS += -fopenmp-simd
559+
# -openmp:experimental pour MSVC?
560+
endif # GGML_NO_OPENMP_SIMD
561+
551562
ifdef GGML_OPENBLAS
552563
MK_CPPFLAGS += -DGGML_USE_BLAS $(shell pkg-config --cflags-only-I openblas)
553564
MK_CFLAGS += $(shell pkg-config --cflags-only-other openblas)
@@ -919,7 +930,8 @@ OBJ_GGML += \
919930
ggml/src/ggml-alloc.o \
920931
ggml/src/ggml-backend.o \
921932
ggml/src/ggml-quants.o \
922-
ggml/src/ggml-aarch64.o
933+
ggml/src/ggml-aarch64.o \
934+
ggml/src/ggml-fp8.o
923935

924936
OBJ_LLAMA = \
925937
src/llama.o \
@@ -1080,6 +1092,12 @@ ggml/src/ggml-aarch64.o: \
10801092
ggml/src/ggml-common.h
10811093
$(CC) $(CFLAGS) -c $< -o $@
10821094

1095+
ggml/src/ggml-fp8.o: \
1096+
ggml/src/ggml-fp8.cpp \
1097+
ggml/src/ggml-fp8.h \
1098+
ggml/src/ggml-common.h
1099+
$(CXX) $(CXXFLAGS) -c $< -o $@
1100+
10831101
ggml/src/ggml-blas.o: \
10841102
ggml/src/ggml-blas.cpp \
10851103
ggml/include/ggml-blas.h

Package.swift

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ var sources = [
1515
"ggml/src/ggml-backend.cpp",
1616
"ggml/src/ggml-quants.c",
1717
"ggml/src/ggml-aarch64.c",
18+
"ggml/src/ggml-fp8.cpp",
1819
]
1920

2021
var resources: [Resource] = []

examples/quantize/quantize.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
5151
{ "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
5252
{ "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
5353
{ "Q4_0_8_8", LLAMA_FTYPE_MOSTLY_Q4_0_8_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", },
54+
{ "E4M3_Q", LLAMA_FTYPE_MOSTLY_E4M3_Q, "12.21G, 0.0050 kld @ Mistral-Nemo", },
55+
{ "E3M4_Q", LLAMA_FTYPE_MOSTLY_E3M4_Q, "12.21G, 0.0016 kld @ Mistral-Nemo", },
5456
{ "F16", LLAMA_FTYPE_MOSTLY_F16, "14.00G, +0.0020 ppl @ Mistral-7B", },
5557
{ "BF16", LLAMA_FTYPE_MOSTLY_BF16, "14.00G, -0.0050 ppl @ Mistral-7B", },
5658
{ "F32", LLAMA_FTYPE_ALL_F32, "26.00G @ 7B", },

ggml/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ if (NOT GGML_LLAMAFILE_DEFAULT)
6161
set(GGML_LLAMAFILE_DEFAULT OFF)
6262
endif()
6363

64+
if (NOT GGML_OPENMP_SIMD_DEFAULT)
65+
set(GGML_OPENMP_SIMD_DEFAULT OFF)
66+
endif()
67+
6468
if (NOT GGML_CUDA_GRAPHS_DEFAULT)
6569
set(GGML_CUDA_GRAPHS_DEFAULT OFF)
6670
endif()
@@ -109,6 +113,7 @@ endif()
109113
option(GGML_LASX "ggml: enable lasx" ON)
110114
option(GGML_LSX "ggml: enable lsx" ON)
111115
option(GGML_SVE "ggml: enable SVE" OFF)
116+
option(GGML_OPENMP_SIMD "ggml: enable OPENMP_SIMD" ${GGML_OPENMP_SIMD_DEFAULT})
112117

113118
if (WIN32)
114119
set(GGML_WIN_VER "0x602" CACHE STRING "ggml: Windows Version")

ggml/include/ggml.h

+8
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,10 @@ extern "C" {
389389
GGML_TYPE_Q4_0_8_8 = 33,
390390
GGML_TYPE_TQ1_0 = 34,
391391
GGML_TYPE_TQ2_0 = 35,
392+
GGML_TYPE_E5M2 = 36,
393+
GGML_TYPE_E4M3 = 37,
394+
GGML_TYPE_E4M3_Q = 38,
395+
GGML_TYPE_E3M4_Q = 39,
392396
GGML_TYPE_COUNT,
393397
};
394398

@@ -433,6 +437,10 @@ extern "C" {
433437
GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
434438
GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
435439
GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
440+
GGML_FTYPE_MOSTLY_E5M2 = 28, // except 1d tensors
441+
GGML_FTYPE_MOSTLY_E4M3 = 29, // except 1d tensors
442+
GGML_FTYPE_MOSTLY_E4M3_Q = 30, // except 1d tensors
443+
GGML_FTYPE_MOSTLY_E3M4_Q = 31, // except 1d tensors
436444
};
437445

438446
// available tensor operations:

ggml/src/CMakeLists.txt

+22
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ if (GGML_MUSA)
154154
endif()
155155

156156
if (GGML_OPENMP)
157+
# set(OpenMP_RUNTIME_MSVC "experimental")
157158
find_package(OpenMP)
158159
if (OpenMP_FOUND)
159160
message(STATUS "OpenMP found")
@@ -171,6 +172,18 @@ if (GGML_OPENMP)
171172
endif()
172173
endif()
173174

175+
if (GGML_OPENMP_SIMD)
176+
check_cxx_compiler_flag("-fopenmp-simd" SUPPORTS_OPENMP_SIMD)
177+
if (SUPPORTS_OPENMP_SIMD)
178+
# OpenMP_RUNTIME_MSVC=experimental / if (MSVC)
179+
message(STATUS "Using openmp_simd.")
180+
add_compile_definitions(GGML_USE_OPENMP_SIMD)
181+
set(OPENMP_SIMD_FLAGS -fopenmp-simd)
182+
else()
183+
message(WARNING "C++ compiler lacks OPENMP_SIMD support.")
184+
endif()
185+
endif()
186+
174187
if (GGML_BLAS)
175188
if (GGML_STATIC)
176189
set(BLA_STATIC ON)
@@ -1362,6 +1375,14 @@ endif()
13621375
# libraries
13631376
#
13641377

1378+
# FP8
1379+
file(GLOB GGML_HEADERS_FP8 "ggml-fp8.h")
1380+
file(GLOB GGML_SOURCES_FP8 "ggml-fp8.cpp")
1381+
1382+
if (OPENMP_SIMD_FLAGS)
1383+
set_source_files_properties(${GGML_SOURCES_FP8} PROPERTIES COMPILE_FLAGS ${OPENMP_SIMD_FLAGS})
1384+
endif()
1385+
13651386
# ggml
13661387

13671388
add_library(ggml
@@ -1389,6 +1410,7 @@ add_library(ggml
13891410
${GGML_SOURCES_AMX} ${GGML_HEADERS_AMX}
13901411
${GGML_SOURCES_CANN} ${GGML_HEADERS_CANN}
13911412
ggml-aarch64.c ggml-aarch64.h
1413+
${GGML_SOURCES_FP8} ${GGML_HEADERS_FP8}
13921414
)
13931415

13941416
if (EMSCRIPTEN)

ggml/src/ggml-common.h

+59-17
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,20 @@
66
typedef uint16_t ggml_half;
77
typedef uint32_t ggml_half2;
88

9-
#define GGML_COMMON_AGGR
9+
#define GGML_COMMON_AGGR_U
10+
#define GGML_COMMON_AGGR_S
11+
12+
#define GGML_COMMON_DECL
13+
#elif defined(GGML_COMMON_DECL_CPP)
14+
#include <cstdint>
15+
16+
typedef uint16_t ggml_half;
17+
typedef uint32_t ggml_half2;
18+
19+
// std-c++ allow anonymous unions but some compiler warn on it
20+
#define GGML_COMMON_AGGR_U data
21+
// std-c++ do not allow it.
22+
#define GGML_COMMON_AGGR_S data
1023

1124
#define GGML_COMMON_DECL
1225
#elif defined(GGML_COMMON_DECL_METAL)
@@ -15,7 +28,8 @@ typedef uint32_t ggml_half2;
1528
typedef half ggml_half;
1629
typedef half2 ggml_half2;
1730

18-
#define GGML_COMMON_AGGR
31+
#define GGML_COMMON_AGGR_U
32+
#define GGML_COMMON_AGGR_S
1933

2034
#define GGML_COMMON_DECL
2135
#elif defined(GGML_COMMON_DECL_CUDA)
@@ -29,7 +43,8 @@ typedef half2 ggml_half2;
2943
typedef half ggml_half;
3044
typedef half2 ggml_half2;
3145

32-
#define GGML_COMMON_AGGR data
46+
#define GGML_COMMON_AGGR_U
47+
#define GGML_COMMON_AGGR_S data
3348

3449
#define GGML_COMMON_DECL
3550
#elif defined(GGML_COMMON_DECL_HIP)
@@ -39,7 +54,8 @@ typedef half2 ggml_half2;
3954
typedef half ggml_half;
4055
typedef half2 ggml_half2;
4156

42-
#define GGML_COMMON_AGGR data
57+
#define GGML_COMMON_AGGR_U
58+
#define GGML_COMMON_AGGR_S data
4359

4460
#define GGML_COMMON_DECL
4561
#elif defined(GGML_COMMON_DECL_SYCL)
@@ -49,7 +65,8 @@ typedef half2 ggml_half2;
4965
typedef sycl::half ggml_half;
5066
typedef sycl::half2 ggml_half2;
5167

52-
#define GGML_COMMON_AGGR data
68+
#define GGML_COMMON_AGGR_U
69+
#define GGML_COMMON_AGGR_S data
5370

5471
#define GGML_COMMON_DECL
5572
#endif
@@ -154,9 +171,9 @@ typedef struct {
154171
struct {
155172
ggml_half d; // delta
156173
ggml_half m; // min
157-
} GGML_COMMON_AGGR;
174+
} GGML_COMMON_AGGR_S;
158175
ggml_half2 dm;
159-
};
176+
} GGML_COMMON_AGGR_U;
160177
uint8_t qs[QK4_1 / 2]; // nibbles / quants
161178
} block_q4_1;
162179
static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
@@ -175,9 +192,9 @@ typedef struct {
175192
struct {
176193
ggml_half d; // delta
177194
ggml_half m; // min
178-
} GGML_COMMON_AGGR;
195+
} GGML_COMMON_AGGR_S;
179196
ggml_half2 dm;
180-
};
197+
} GGML_COMMON_AGGR_U;
181198
uint8_t qh[4]; // 5-th bit of quants
182199
uint8_t qs[QK5_1 / 2]; // nibbles / quants
183200
} block_q5_1;
@@ -196,9 +213,9 @@ typedef struct {
196213
struct {
197214
ggml_half d; // delta
198215
ggml_half s; // d * sum(qs[i])
199-
} GGML_COMMON_AGGR;
216+
} GGML_COMMON_AGGR_S;
200217
ggml_half2 ds;
201-
};
218+
} GGML_COMMON_AGGR_U;
202219
int8_t qs[QK8_1]; // quants
203220
} block_q8_1;
204221
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_half) + QK8_1, "wrong q8_1 block size/padding");
@@ -261,9 +278,9 @@ typedef struct {
261278
struct {
262279
ggml_half d; // super-block scale for quantized scales
263280
ggml_half dmin; // super-block scale for quantized mins
264-
} GGML_COMMON_AGGR;
281+
} GGML_COMMON_AGGR_S;
265282
ggml_half2 dm;
266-
};
283+
} GGML_COMMON_AGGR_U;
267284
} block_q2_K;
268285
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_half) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
269286

@@ -288,9 +305,9 @@ typedef struct {
288305
struct {
289306
ggml_half d; // super-block scale for quantized scales
290307
ggml_half dmin; // super-block scale for quantized mins
291-
} GGML_COMMON_AGGR;
308+
} GGML_COMMON_AGGR_S;
292309
ggml_half2 dm;
293-
};
310+
} GGML_COMMON_AGGR_U;
294311
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
295312
uint8_t qs[QK_K/2]; // 4--bit quants
296313
} block_q4_K;
@@ -305,9 +322,9 @@ typedef struct {
305322
struct {
306323
ggml_half d; // super-block scale for quantized scales
307324
ggml_half dmin; // super-block scale for quantized mins
308-
} GGML_COMMON_AGGR;
325+
} GGML_COMMON_AGGR_S;
309326
ggml_half2 dm;
310-
};
327+
} GGML_COMMON_AGGR_U;
311328
uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
312329
uint8_t qh[QK_K/8]; // quants, high bit
313330
uint8_t qs[QK_K/2]; // quants, low 4 bits
@@ -418,6 +435,24 @@ typedef struct {
418435
} block_iq4_xs;
419436
static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding");
420437

438+
// fp8 support
439+
// - fp8 simple type
440+
typedef struct { uint8_t bits; } ggml_e5m2_t;
441+
typedef struct { uint8_t bits; } ggml_e4m3_t;
442+
443+
// - fp8 with bloc delta => 8.125 bpw
444+
typedef struct {
445+
float d; // delta
446+
uint8_t qs[QK_K];
447+
} block_e4m3_q;
448+
static_assert(sizeof(block_e4m3_q) == sizeof(float) + QK_K, "wrong block_e4m3_q block size/padding");
449+
450+
typedef struct {
451+
float d; // delta
452+
uint8_t qs[QK_K];
453+
} block_e3m4_q;
454+
static_assert(sizeof(block_e3m4_q) == sizeof(float) + QK_K, "wrong block_e3m4_q block size/padding");
455+
421456
#endif // GGML_COMMON_DECL
422457
#endif // GGML_COMMON_DECL
423458

@@ -431,6 +466,13 @@ static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_
431466
#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
432467
#define GGML_TABLE_END() };
433468

469+
#define GGML_COMMON_IMPL
470+
#elif defined(GGML_COMMON_IMPL_CPP)
471+
#include <cstdint>
472+
473+
#define GGML_TABLE_BEGIN(type, name, size) static const type name[size] = {
474+
#define GGML_TABLE_END() };
475+
434476
#define GGML_COMMON_IMPL
435477
#elif defined(GGML_COMMON_IMPL_METAL)
436478
#include <metal_stdlib>

0 commit comments

Comments
 (0)