Skip to content

Commit c3f9d25

Browse files
authored
Vulkan: Fix float16 use on devices without float16 support + fix subgroup_size_control validation error (ggml-org#11161)
* Vulkan: Remove float16 use in shaders * Fix validation error about subgroup_size_control extension
1 parent ee7136c commit c3f9d25

9 files changed

+50
-51
lines changed

Diff for: ggml/src/ggml-vulkan/ggml-vulkan.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -2277,6 +2277,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
22772277
if (device->subgroup_size_control) {
22782278
device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize;
22792279
device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize;
2280+
device_extensions.push_back("VK_EXT_subgroup_size_control");
22802281
}
22812282

22822283
device->subgroup_size_control = device->subgroup_size_control &&
@@ -2285,7 +2286,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
22852286

22862287
if (device->subgroup_size_control) {
22872288
device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups;
2288-
device_extensions.push_back("VK_EXT_subgroup_size_control");
22892289
}
22902290

22912291
#if defined(VK_KHR_cooperative_matrix)

Diff for: ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp

+3-6
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
#version 450
22

3-
#ifdef FLOAT16
4-
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
5-
#endif
6-
#extension GL_EXT_shader_explicit_arithmetic_types : require
3+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
74

85
#include "mul_mat_vec_base.comp"
96

@@ -27,8 +24,8 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
2724

2825
#if K_PER_ITER == 8
2926
#if QUANT_R == 2
30-
const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4];
31-
const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4];
27+
const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
28+
const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
3229
const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
3330
const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
3431
#else

Diff for: ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp

+12-12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#version 450
2-
#extension GL_EXT_shader_explicit_arithmetic_types : require
2+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
33

44
#include "mul_mat_vec_base.comp"
55

@@ -40,9 +40,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
4040

4141
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
4242
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
43-
f16vec2 d = data_a[ib0 + i].d;
44-
const FLOAT_TYPE dall = d.x;
45-
const FLOAT_TYPE dmin = d.y;
43+
vec2 d = vec2(data_a[ib0 + i].d);
44+
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
45+
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
4646

4747
uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
4848
uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
@@ -63,14 +63,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
6363
uvec2 qs16 = uvec2(unpack8(qs16_u16));
6464

6565
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
66-
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
67-
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
68-
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
69-
B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
70-
B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
71-
B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
72-
B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
73-
B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
66+
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
67+
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
68+
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
69+
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
70+
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
71+
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
72+
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
73+
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
7474

7575
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
7676
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);

Diff for: ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp

+9-9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#version 450
2-
#extension GL_EXT_shader_explicit_arithmetic_types : require
2+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
33

44
#include "mul_mat_vec_base.comp"
55

@@ -60,14 +60,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
6060

6161
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
6262

63-
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
64-
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
65-
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
66-
B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
67-
B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
68-
B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
69-
B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
70-
B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
63+
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
64+
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
65+
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
66+
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
67+
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
68+
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
69+
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
70+
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
7171

7272
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
7373
[[unroll]] for (int l = 0; l < 2; ++l) {

Diff for: ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp

+6-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#version 450
22

3-
#extension GL_EXT_shader_explicit_arithmetic_types : require
3+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
44

55
#include "mul_mat_vec_base.comp"
66

@@ -45,7 +45,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
4545

4646
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
4747
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
48-
f16vec2 d = data_a[ib0 + i].d;
48+
vec2 d = vec2(data_a[ib0 + i].d);
4949
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
5050
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
5151

@@ -96,10 +96,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
9696
const uint32_t q4_15 = qs64_hi4.w;
9797

9898
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
99-
B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4];
100-
B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8];
101-
B_TYPE_VEC4 by20 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4];
102-
B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8];
99+
vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]);
100+
vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]);
101+
vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]);
102+
vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]);
103103

104104
const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
105105
const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));

Diff for: ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp

+10-10
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#version 450
22

3-
#extension GL_EXT_shader_explicit_arithmetic_types : require
3+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
44

55
#include "mul_mat_vec_base.comp"
66

@@ -42,7 +42,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
4242

4343
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
4444
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
45-
f16vec2 d = data_a[ib0 + i].d;
45+
vec2 d = vec2(data_a[ib0 + i].d);
4646
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
4747
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
4848

@@ -105,14 +105,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
105105
const uint32_t q4_15 = qs64_80_hi4.w;
106106

107107
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
108-
B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2];
109-
B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8];
110-
B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16];
111-
B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24];
112-
B_TYPE_VEC2 by20 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2];
113-
B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8];
114-
B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16];
115-
B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24];
108+
vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]);
109+
vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]);
110+
vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
111+
vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
112+
vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]);
113+
vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]);
114+
vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
115+
vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);
116116

117117
const FLOAT_TYPE sx =
118118
fma(FLOAT_TYPE(by10.x), q4_0,

Diff for: ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#version 450
22

3-
#extension GL_EXT_shader_explicit_arithmetic_types : require
3+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
44

55
#include "mul_mat_vec_base.comp"
66

@@ -77,10 +77,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
7777
uvec4 q3 = uvec4(unpack8(q3_u32));
7878

7979
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
80-
B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
81-
B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
82-
B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
83-
B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
80+
vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]);
81+
vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]);
82+
vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]);
83+
vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]);
8484

8585
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
8686
[[unroll]] for (int l = 0; l < 4; ++l) {

Diff for: ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp

-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#version 450
22

3-
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
43
#extension GL_EXT_control_flow_attributes : enable
54

65
layout (push_constant) uniform parameter

Diff for: ggml/src/ggml-vulkan/vulkan-shaders/types.comp

+4-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
#if !defined(GGML_TYPES_COMP)
33
#define GGML_TYPES_COMP
44

5-
#extension GL_EXT_shader_explicit_arithmetic_types : require
5+
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
6+
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
7+
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
8+
#extension GL_EXT_shader_16bit_storage : require
69

710
#if defined(DATA_A_F32)
811
#define QUANT_K 1

0 commit comments

Comments
 (0)