Commit 42b53d1

CUDA: revise q8_1 data layout for mul_mat_q (ggml-org#7824)
1 parent 2decf57 · commit 42b53d1

File tree

5 files changed: +281 -150 lines changed

ggml-cuda.cu

Lines changed: 57 additions & 31 deletions

@@ -1347,10 +1347,30 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
     GGML_UNUSED(main_device);
 }
 
+static cudaError_t ggml_cuda_Memcpy2DPeerAsync(
+    void * dst, int dstDevice, size_t dpitch, void * src, int srcDevice, size_t spitch, size_t width, size_t height, cudaStream_t stream) {
+
+#if !defined(GGML_USE_HIPBLAS)
+    // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
+    cudaMemcpy3DPeerParms p = {};
+    p.dstDevice = dstDevice;
+    p.dstPtr    = make_cudaPitchedPtr(dst, dpitch, dpitch, height);
+    p.srcDevice = srcDevice;
+    p.srcPtr    = make_cudaPitchedPtr(src, spitch, spitch, height);
+    p.extent    = make_cudaExtent(width, height, 1);
+    return cudaMemcpy3DPeerAsync(&p, stream);
+#else
+    // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
+    GGML_UNUSED(dstDevice);
+    GGML_UNUSED(srcDevice);
+    return cudaMemcpy2DAsync(dst, dpitch, src, spitch, width, height, cudaMemcpyDeviceToDevice, stream);
+#endif // !defined(GGML_USE_HIPBLAS)
+}
+
 static void ggml_cuda_op_mul_mat(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_cuda_op_mul_mat_t op,
-    const bool convert_src1_to_q8_1) {
+    quantize_cuda_t quantize_src1) {
 
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
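
The new helper consolidates the CUDA/HIP peer-copy split that was previously inlined further down in this function (see the hunk at line 1587 below). For orientation, a minimal usage sketch; dst_ptr, src_ptr, dst_pitch, src_pitch, and stream are hypothetical placeholders, and peer access is assumed to have been enabled by the surrounding code:

    // Sketch only: copy 8 strided rows of 1024 bytes each from device 1 to
    // device 0, where consecutive rows are dst_pitch/src_pitch bytes apart.
    const size_t width  = 256*sizeof(float); // bytes copied per row
    const size_t height = 8;                 // number of rows
    CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(
        dst_ptr, /*dstDevice=*/0, dst_pitch,
        src_ptr, /*srcDevice=*/1, src_pitch, width, height, stream));
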
@@ -1407,7 +1427,9 @@ static void ggml_cuda_op_mul_mat(
     }
 
     struct dev_data {
-        ggml_cuda_pool_alloc<char> src0_dd_alloc;
+        int cc;
+
+        ggml_cuda_pool_alloc<char>  src0_dd_alloc;
         ggml_cuda_pool_alloc<float> src1_ddf_alloc;
         ggml_cuda_pool_alloc<char>  src1_ddq_alloc;
         ggml_cuda_pool_alloc<float> dst_dd_alloc;
@@ -1426,6 +1448,8 @@ static void ggml_cuda_op_mul_mat(
     int used_devices = 0;
 
     for (int id = 0; id < ggml_backend_cuda_get_device_count(); ++id) {
+        dev[id].cc = ggml_cuda_info().devices[id].cc;
+
         // by default, use all rows
         dev[id].row_low  = 0;
         dev[id].row_high = ne01;
@@ -1476,11 +1500,15 @@ static void ggml_cuda_op_mul_mat(
             dev[id].src1_ddf = dev[id].src1_ddf_alloc.alloc(ctx.pool(id), ggml_nelements(src1));
         }
 
-        if (convert_src1_to_q8_1) {
-            dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
+        if (quantize_src1) {
+            size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
+            if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                src_1_ddq_size += get_mmq_x_max_host(dev[id].cc)*sizeof(block_q8_1_mmq);
+            }
+            dev[id].src1_ddq = dev[id].src1_ddq_alloc.alloc(ctx.pool(id), src_1_ddq_size);
 
             if (src1_on_device && src1_is_contiguous) {
-                quantize_row_q8_1_cuda(dev[id].src1_ddf, dev[id].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
+                quantize_src1(dev[id].src1_ddf, dev[id].src1_ddq, ne10, ne11, ne12*ne13, src1_padded_col_size, src0->type, stream);
                 CUDA_CHECK(cudaGetLastError());
             }
         }
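
To make the new allocation size concrete, here is a self-contained sketch of the arithmetic above. All constants are assumed example values (QK8_1 = 32 and sizeof(block_q8_1) = 36 give q8_1_ts/q8_1_bs = 36/32; get_mmq_x_max_host(cc) = 128 and sizeof(block_q8_1_mmq) = 144 are plausible but not guaranteed by this hunk alone):

    #include <cstdio>
    #include <cstddef>

    int main() {
        const size_t q8_1_ts = 36, q8_1_bs = 32;                // bytes per q8_1 block / values per block (assumed)
        const size_t nrows1 = 512, src1_padded_col_size = 4096; // example shapes
        // base size of the quantized src1 buffer: 512*4096*36/32 = 2359296 bytes
        size_t src_1_ddq_size = nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs;
        const bool mmq = true; // i.e. quantize_src1 == quantize_mmq_q8_1_cuda
        if (mmq) {
            // headroom of one tile of columns: get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq)
            src_1_ddq_size += 128*144; // +18432 bytes
        }
        printf("%zu bytes\n", src_1_ddq_size); // 2377728
        return 0;
    }
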
@@ -1526,7 +1554,12 @@ static void ggml_cuda_op_mul_mat(
             const int64_t i03 = i0 / ne12;
             const int64_t i02 = i0 % ne12;
 
-            const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
+            size_t src1_ddq_i_offset = i0*ne11 * src1_padded_col_size*q8_1_ts/q8_1_bs;
+            if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                src1_ddq_i_offset += src1_col_0 * sizeof(block_q8_1_mmq);
+            } else {
+                src1_ddq_i_offset += src1_col_0 * src1_padded_col_size*q8_1_ts/q8_1_bs;
+            }
 
             // for split tensors the data begins at i0 == i0_offset_low
             char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
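
The branch reflects the revised layout: in the classic q8_1 layout each column occupies a full padded row of blocks, whereas in the MMQ layout the blocks of all ne11 columns are interleaved per block-row, so advancing one column moves by just one block_q8_1_mmq. A hypothetical helper making the two cases explicit (names are placeholders; the interpretation is inferred from this hunk together with the 2D copy below):

    #include <cstdint>
    #include <cstddef>

    static size_t src1_q8_1_offset(bool mmq_layout, int64_t i0, int64_t ne11, int64_t col,
                                   size_t padded_col_size, size_t q8_1_ts, size_t q8_1_bs,
                                   size_t mmq_block_size) {
        size_t off = i0*ne11 * padded_col_size*q8_1_ts/q8_1_bs; // start of batch i0
        if (mmq_layout) {
            off += col * mmq_block_size;                  // columns interleaved per block-row
        } else {
            off += col * padded_col_size*q8_1_ts/q8_1_bs; // one padded row of blocks per column
        }
        return off;
    }
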
@@ -1543,10 +1576,17 @@ static void ggml_cuda_op_mul_mat(
             // copy src0, src1 to device if necessary
             if (src1_is_contiguous) {
                 if (id != ctx.device) {
-                    if (convert_src1_to_q8_1) {
+                    if (quantize_src1) {
                         char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
-                        CUDA_CHECK(cudaMemcpyPeerAsync(src1_ddq_i, id, src1_ddq_i_source, ctx.device,
-                            src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
+                        if (quantize_src1 == quantize_mmq_q8_1_cuda) {
+                            const size_t pitch  = ne11*sizeof(block_q8_1_mmq);
+                            const size_t width  = src1_ncols*sizeof(block_q8_1_mmq);
+                            const size_t height = src1_padded_col_size/(4*QK8_1);
+                            CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(src1_ddq_i, id, pitch, src1_ddq_i_source, ctx.device, pitch, width, height, stream));
+                        } else {
+                            CUDA_CHECK(cudaMemcpyPeerAsync(
+                                src1_ddq_i, id, src1_ddq_i_source, ctx.device, src1_ncols*src1_padded_col_size*q8_1_ts/q8_1_bs, stream));
+                        }
                     } else {
                         float * src1_ddf_i_source = (float *) src1->data;
                         src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
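
Because the MMQ layout interleaves all ne11 columns per block-row, a device's slice of src1_ncols columns is no longer contiguous and needs a strided 2D peer copy instead of a single cudaMemcpyPeerAsync. A worked example of the geometry, with assumed numbers (ne11 = 512 total columns, src1_ncols = 128 columns for this device, src1_padded_col_size = 4096, QK8_1 = 32, sizeof(block_q8_1_mmq) = 144):

    #include <cstddef>

    const size_t pitch  = 512*144;     // 73728 bytes from one block-row to the next
    const size_t width  = 128*144;     // 18432 bytes actually copied per block-row
    const size_t height = 4096/(4*32); // 32 block-rows along ne10
    // i.e. 32 strided reads of 18432 bytes, rather than one contiguous memcpy
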
@@ -1561,8 +1601,8 @@ static void ggml_cuda_op_mul_mat(
                 GGML_ASSERT(false);
             }
 
-            if (convert_src1_to_q8_1 && !src1_is_contiguous) {
-                quantize_row_q8_1_cuda(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
+            if (quantize_src1 && !src1_is_contiguous) {
+                quantize_src1(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, 1, src1_padded_col_size, src0->type, stream);
                 CUDA_CHECK(cudaGetLastError());
             }
 
@@ -1587,22 +1627,8 @@ static void ggml_cuda_op_mul_mat(
                 float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                 GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
                 dhf_dst_i += src1_col_0*ne0 + dev[id].row_low;
-#if !defined(GGML_USE_HIPBLAS)
-                // cudaMemcpy2DAsync may fail with copies between vmm pools of different devices
-                cudaMemcpy3DPeerParms p = {};
-                p.dstDevice = ctx.device;
-                p.dstPtr    = make_cudaPitchedPtr(dhf_dst_i, ne0*sizeof(float), row_diff, src1_ncols);
-                p.srcDevice = id;
-                p.srcPtr    = make_cudaPitchedPtr(dst_dd_i, row_diff*sizeof(float), row_diff, src1_ncols);
-                p.extent    = make_cudaExtent(row_diff*sizeof(float), src1_ncols, 1);
-                CUDA_CHECK(cudaMemcpy3DPeerAsync(&p, stream));
-#else
-                // HIP does not support cudaMemcpy3DPeerAsync or vmm pools
-                CUDA_CHECK(cudaMemcpy2DAsync(dhf_dst_i, ne0*sizeof(float),
-                                             dst_dd_i, row_diff*sizeof(float),
-                                             row_diff*sizeof(float), src1_ncols,
-                                             cudaMemcpyDeviceToDevice, stream));
-#endif
+                CUDA_CHECK(ggml_cuda_Memcpy2DPeerAsync(
+                    dhf_dst_i, ctx.device, ne0*sizeof(float), dst_dd_i, id, row_diff*sizeof(float), row_diff*sizeof(float), src1_ncols, stream));
             } else {
                 float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
                 GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
@@ -1941,13 +1967,13 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
         // KQ + KQV multi-batch
         ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst);
     } else if (use_dequantize_mul_mat_vec) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr);
    } else if (use_mul_mat_vec_q) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, true);
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda);
    } else if (use_mul_mat_q) {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda);
    } else {
-        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
+        ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr);
    }
 }
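
The boolean flag is thus replaced by a function pointer that selects the quantization layout (or none, for the dequantize and cuBLAS paths). The quantize_cuda_t typedef lives in ggml-cuda/quantize.cuh, which is not part of this excerpt; judging from the call sites above it should look roughly like the following, with guessed parameter names:

    // Inferred sketch, not the verbatim typedef from the commit.
    typedef void (*quantize_cuda_t)(
        const float * x, void * vy, int64_t kx0, int64_t kx1, int64_t channels,
        int64_t kx0_padded, ggml_type type_x, cudaStream_t stream);
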
19531979

ggml-cuda/mmq.cu

Lines changed: 2 additions & 1 deletion

@@ -11,6 +11,7 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t nb01 = src0->nb[1];
 
     const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
     GGML_ASSERT(ne10 % QK8_1 == 0);
 
     const int64_t ne0 = dst->ne[0];
@@ -25,7 +26,7 @@ void ggml_cuda_op_mul_mat_q(
     // nrows_dst == nrows of the matrix that the kernel writes into
     const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
 
-    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, nrows_dst};
+    const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst};
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
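
mmq_args gains ne11 because the kernel now needs the total column count to step between interleaved block-rows. The block_q8_1_mmq type itself is defined in one of the three changed files not shown in this excerpt; a layout consistent with the arithmetic above (one block spanning 4*QK8_1 values, with per-group scales) would be roughly:

    // Sketch only, not the verbatim definition from this commit.
    struct block_q8_1_mmq {
        half2  ds[4];       // one (d, s) pair per group of QK8_1 values
        int8_t qs[4*QK8_1]; // 4*32 = 128 int8 quants
    };
    static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1), "unexpected block_q8_1_mmq size");
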
