|
1 | 1 | #pragma once
|
2 |
| - |
3 |
| -#define GGML_COMMON_DECL_C |
4 |
| -#include "ggml-common.h" |
5 |
| - |
6 | 2 | #include "ggml.h"
|
7 | 3 |
|
8 |
| -// les definitions / converstion FP8 <=> FP32 |
9 | 4 | #ifdef __cplusplus
|
10 | 5 | extern "C" {
|
11 | 6 | #endif
|
12 | 7 |
|
| 8 | +#define FP8_QK 256 |
| 9 | + |
13 | 10 | typedef struct { uint8_t bits; } ggml_e5m2_t;
|
14 | 11 | typedef struct { uint8_t bits; } ggml_e4m3_t;
|
15 | 12 | typedef struct { uint8_t bits; } ggml_e3m4_t;
|
16 | 13 |
|
17 |
| - void ggml_e5m2_to_fp32_row(const ggml_e5m2_t * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); |
18 |
| - void ggml_fp32_to_e5m2_row(const float * GGML_RESTRICT x, ggml_e5m2_t * GGML_RESTRICT y, int64_t k); |
19 |
| - void ggml_fp32_to_e5m2_row_ref(const float * GGML_RESTRICT x, ggml_e5m2_t * GGML_RESTRICT y, int64_t k); |
| 14 | + // fp8 with bloc delta => 8.125 bpw |
| 15 | + typedef struct { |
| 16 | + float d; // delta |
| 17 | + uint8_t qs[FP8_QK]; |
| 18 | + } block_e4m3_q; |
| 19 | + static_assert(sizeof(block_e4m3_q) == sizeof(float) + FP8_QK, "wrong block_e4m3_q block size/padding"); |
| 20 | + |
| 21 | + typedef struct { |
| 22 | + float d; // delta |
| 23 | + uint8_t qs[FP8_QK]; |
| 24 | + } block_e3m4_q; |
| 25 | + static_assert(sizeof(block_e3m4_q) == sizeof(float) + FP8_QK, "wrong block_e3m4_q block size/padding"); |
| 26 | + |
| 27 | + GGML_API void ggml_e5m2_to_fp32_row(const ggml_e5m2_t * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); |
| 28 | + GGML_API void ggml_fp32_to_e5m2_row(const float * GGML_RESTRICT x, ggml_e5m2_t * GGML_RESTRICT y, int64_t k); |
| 29 | + GGML_API void ggml_fp32_to_e5m2_row_ref(const float * GGML_RESTRICT x, ggml_e5m2_t * GGML_RESTRICT y, int64_t k); |
20 | 30 |
|
21 |
| - void ggml_e4m3_to_fp32_row(const ggml_e4m3_t * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); |
22 |
| - void ggml_fp32_to_e4m3_row(const float * GGML_RESTRICT x, ggml_e4m3_t * GGML_RESTRICT y, int64_t k); |
23 |
| - void ggml_fp32_to_e4m3_row_ref(const float * GGML_RESTRICT x, ggml_e4m3_t * GGML_RESTRICT y, int64_t k); |
| 31 | + GGML_API void ggml_e4m3_to_fp32_row(const ggml_e4m3_t * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); |
| 32 | + GGML_API void ggml_fp32_to_e4m3_row(const float * GGML_RESTRICT x, ggml_e4m3_t * GGML_RESTRICT y, int64_t k); |
| 33 | + GGML_API void ggml_fp32_to_e4m3_row_ref(const float * GGML_RESTRICT x, ggml_e4m3_t * GGML_RESTRICT y, int64_t k); |
24 | 34 |
|
25 |
| - void dequantize_row_e4m3_q(const block_e4m3_q * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); |
26 |
| - void quantize_row_e4m3_q(const float * GGML_RESTRICT x, block_e4m3_q * GGML_RESTRICT y, int64_t k); |
27 |
| - void quantize_row_e4m3_q_ref(const float * GGML_RESTRICT x, block_e4m3_q * GGML_RESTRICT y, int64_t k); |
| 35 | + GGML_API void dequantize_row_e4m3_q(const block_e4m3_q * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); |
| 36 | + GGML_API void quantize_row_e4m3_q(const float * GGML_RESTRICT x, block_e4m3_q * GGML_RESTRICT y, int64_t k); |
| 37 | + GGML_API void quantize_row_e4m3_q_ref(const float * GGML_RESTRICT x, block_e4m3_q * GGML_RESTRICT y, int64_t k); |
28 | 38 |
|
29 |
| - void dequantize_row_e3m4_q(const block_e3m4_q * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); |
30 |
| - void quantize_row_e3m4_q(const float * GGML_RESTRICT x, block_e3m4_q * GGML_RESTRICT y, int64_t k); |
31 |
| - void quantize_row_e3m4_q_ref(const float * GGML_RESTRICT x, block_e3m4_q * GGML_RESTRICT y, int64_t k); |
| 39 | + GGML_API void dequantize_row_e3m4_q(const block_e3m4_q * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); |
| 40 | + GGML_API void quantize_row_e3m4_q(const float * GGML_RESTRICT x, block_e3m4_q * GGML_RESTRICT y, int64_t k); |
| 41 | + GGML_API void quantize_row_e3m4_q_ref(const float * GGML_RESTRICT x, block_e3m4_q * GGML_RESTRICT y, int64_t k); |
32 | 42 |
|
33 | 43 | // TODO: the best depend on the CPU fp32 / bf16 / fp16
|
34 | 44 | #define GGML_FP8_VECT_DOT_TYPE GGML_TYPE_F32
|
35 |
| - void ggml_vec_dot_e5m2(int n, float * GGML_RESTRICT s, size_t bs, const ggml_e5m2_t * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT vy, size_t by, int nrc); |
36 |
| - void ggml_vec_dot_e4m3(int n, float * GGML_RESTRICT s, size_t bs, const ggml_e4m3_t * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT vy, size_t by, int nrc); |
37 |
| - void ggml_vec_dot_e4m3_q(int n, float * GGML_RESTRICT s, size_t bs, const block_e4m3_q * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT vy, size_t by, int nrc); |
38 |
| - void ggml_vec_dot_e3m4_q(int n, float * GGML_RESTRICT s, size_t bs, const block_e3m4_q * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT vy, size_t by, int nrc); |
| 45 | + GGML_API void ggml_vec_dot_e5m2(int n, float * GGML_RESTRICT s, size_t bs, const ggml_e5m2_t * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT vy, size_t by, int nrc); |
| 46 | + GGML_API void ggml_vec_dot_e4m3(int n, float * GGML_RESTRICT s, size_t bs, const ggml_e4m3_t * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT vy, size_t by, int nrc); |
| 47 | + GGML_API void ggml_vec_dot_e4m3_q(int n, float * GGML_RESTRICT s, size_t bs, const block_e4m3_q * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT vy, size_t by, int nrc); |
| 48 | + GGML_API void ggml_vec_dot_e3m4_q(int n, float * GGML_RESTRICT s, size_t bs, const block_e3m4_q * GGML_RESTRICT vx, size_t bx, const float * GGML_RESTRICT vy, size_t by, int nrc); |
39 | 49 |
|
40 | 50 | #ifdef __cplusplus
|
41 | 51 | }
|
|
0 commit comments