Skip to content

Commit 92497e1

Browse files
JohannesGaesslerkalomaze
authored andcommitted
CUDA: mul_mat_id always on GPU for batches >= 32
1 parent 97fa427 commit 92497e1

File tree

1 file changed

+28
-7
lines changed

1 file changed

+28
-7
lines changed

ggml-cuda.cu

+28-7
Original file line numberDiff line numberDiff line change
@@ -8727,8 +8727,6 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
87278727
// TODO: mmq/mmv support
87288728
#endif
87298729

8730-
GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
8731-
87328730
const int64_t nb11 = src1->nb[1];
87338731
const int64_t nb1 = dst->nb[1];
87348732

@@ -8757,13 +8755,24 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
87578755
ggml_tensor src1_row = *src1;
87588756
ggml_tensor dst_row = *dst;
87598757

8758+
ggml_backend_type src1_original_backend = src1_row.backend;
8759+
ggml_backend_type dst_original_backend = dst_row.backend;
8760+
8761+
src1_row.backend = GGML_BACKEND_GPU;
8762+
dst_row.backend = GGML_BACKEND_GPU;
8763+
87608764
src1_row.extra = &src1_row_extra;
87618765
dst_row.extra = &dst_row_extra;
87628766

8763-
char * src1_original = (char *) src1_extra->data_device[g_main_device];
8764-
char * dst_original = (char *) dst_extra->data_device[g_main_device];
8767+
char * src1_original = src1_original_backend == GGML_BACKEND_CPU ?
8768+
(char *) src1->data : (char *) src1_extra->data_device[g_main_device];
8769+
char * dst_original = dst_original_backend == GGML_BACKEND_CPU ?
8770+
(char *) dst->data : (char *) dst_extra->data_device[g_main_device];
87658771

87668772
if (src1->ne[1] == 1) {
8773+
GGML_ASSERT(src1_original_backend == GGML_BACKEND_GPU);
8774+
GGML_ASSERT(dst_original_backend == GGML_BACKEND_GPU);
8775+
87678776
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
87688777
//int32_t row_id;
87698778
//CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
@@ -8791,6 +8800,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
87918800
src1_row_extra.data_device[g_main_device] = src1_contiguous;
87928801
dst_row_extra.data_device[g_main_device] = dst_contiguous;
87938802

8803+
const cudaMemcpyKind src1_kind = src1_original_backend == GGML_BACKEND_CPU ?
8804+
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
8805+
const cudaMemcpyKind dst_kind = src1_original_backend == GGML_BACKEND_CPU ?
8806+
cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
8807+
87948808
for (int32_t row_id = 0; row_id < n_as; ++row_id) {
87958809
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
87968810

@@ -8805,7 +8819,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
88058819
GGML_ASSERT(row_id >= 0 && row_id < n_as);
88068820

88078821
CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
8808-
nb11, cudaMemcpyDeviceToDevice, stream));
8822+
nb11, src1_kind, stream));
88098823
num_src1_rows++;
88108824
}
88118825

@@ -8837,14 +8851,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
88378851
GGML_ASSERT(row_id >= 0 && row_id < n_as);
88388852

88398853
CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
8840-
nb1, cudaMemcpyDeviceToDevice, stream));
8854+
nb1, dst_kind, stream));
88418855
num_src1_rows++;
88428856
}
88438857
}
88448858

88458859
ggml_cuda_pool_free(src1_contiguous, as_src1);
88468860
ggml_cuda_pool_free(dst_contiguous, as_dst);
88478861
}
8862+
8863+
if (dst_original_backend == GGML_BACKEND_CPU) {
8864+
CUDA_CHECK(cudaStreamSynchronize(stream));
8865+
}
8866+
8867+
src1_row.backend = src1_original_backend;
8868+
dst_row.backend = dst_original_backend;
88488869
}
88498870

88508871
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -9247,7 +9268,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
92479268
|| (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
92489269
|| (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);
92499270

9250-
if (!any_on_device && tensor->op != GGML_OP_MUL_MAT) {
9271+
if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) {
92519272
return false;
92529273
}
92539274

0 commit comments

Comments
 (0)