@@ -8727,8 +8727,6 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     // TODO: mmq/mmv support
 #endif

-    GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
-
     const int64_t nb11 = src1->nb[1];
     const int64_t nb1  =  dst->nb[1];

@@ -8757,13 +8755,24 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row = *dst;

+    ggml_backend_type src1_original_backend = src1_row.backend;
+    ggml_backend_type dst_original_backend  =  dst_row.backend;
+
+    src1_row.backend = GGML_BACKEND_GPU;
+    dst_row.backend  = GGML_BACKEND_GPU;
+
     src1_row.extra = &src1_row_extra;
     dst_row.extra = &dst_row_extra;

-    char * src1_original = (char *) src1_extra->data_device[g_main_device];
-    char * dst_original  = (char *)  dst_extra->data_device[g_main_device];
+    char * src1_original = src1_original_backend == GGML_BACKEND_CPU ?
+        (char *) src1->data : (char *) src1_extra->data_device[g_main_device];
+    char * dst_original  =  dst_original_backend == GGML_BACKEND_CPU ?
+        (char *)  dst->data : (char *)  dst_extra->data_device[g_main_device];

     if (src1->ne[1] == 1) {
+        GGML_ASSERT(src1_original_backend == GGML_BACKEND_GPU);
+        GGML_ASSERT(dst_original_backend  == GGML_BACKEND_GPU);
+
         for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
             // int32_t row_id;
             // CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
@@ -8791,6 +8800,11 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
         src1_row_extra.data_device[g_main_device] = src1_contiguous;
         dst_row_extra.data_device[g_main_device]  =  dst_contiguous;

+        const cudaMemcpyKind src1_kind = src1_original_backend == GGML_BACKEND_CPU ?
+            cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
+        const cudaMemcpyKind dst_kind = dst_original_backend == GGML_BACKEND_CPU ?
+            cudaMemcpyDeviceToHost : cudaMemcpyDeviceToDevice;
+
         for (int32_t row_id = 0; row_id < n_as; ++row_id) {
             const struct ggml_tensor * src0_row = dst->src[row_id + 2];

@@ -8805,7 +8819,7 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
                 GGML_ASSERT(row_id >= 0 && row_id < n_as);

                 CUDA_CHECK(cudaMemcpyAsync(src1_contiguous + num_src1_rows*nb11, src1_original + i01*nb11,
-                                        nb11, cudaMemcpyDeviceToDevice, stream));
+                                        nb11, src1_kind, stream));
                 num_src1_rows++;
             }

@@ -8837,14 +8851,21 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
                 GGML_ASSERT(row_id >= 0 && row_id < n_as);

                 CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous + num_src1_rows*nb1,
-                                        nb1, cudaMemcpyDeviceToDevice, stream));
+                                        nb1, dst_kind, stream));
                 num_src1_rows++;
             }
         }

         ggml_cuda_pool_free(src1_contiguous, as_src1);
         ggml_cuda_pool_free(dst_contiguous, as_dst);
     }
+
+    if (dst_original_backend == GGML_BACKEND_CPU) {
+        CUDA_CHECK(cudaStreamSynchronize(stream));
+    }
+
+    src1_row.backend = src1_original_backend;
+    dst_row.backend  =  dst_original_backend;
 }

 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
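Taken together, the mul_mat_id hunks above let src1 and dst live in host memory: the row pointers fall back to tensor->data when the backend is GGML_BACKEND_CPU, each cudaMemcpyAsync picks its cudaMemcpyKind from where the tensor actually resides, and the stream is synchronized before returning whenever dst is host-resident, so the caller never reads a result that is still in flight. Below is a minimal standalone sketch of that pattern, not the ggml code itself; buf_backend, stage_row, and unstage_row are hypothetical names introduced for illustration.

// A minimal sketch (not ggml) of the backend-aware copy pattern used above.
#include <cuda_runtime.h>
#include <stdio.h>
#include <stdlib.h>

#define CUDA_CHECK(err)                                            \
    do {                                                           \
        cudaError_t err_ = (err);                                  \
        if (err_ != cudaSuccess) {                                 \
            fprintf(stderr, "CUDA error %s at %s:%d\n",            \
                    cudaGetErrorString(err_), __FILE__, __LINE__); \
            exit(1);                                               \
        }                                                          \
    } while (0)

// Where a buffer lives, mirroring GGML_BACKEND_CPU / GGML_BACKEND_GPU.
enum buf_backend { BUF_CPU, BUF_GPU };

// Gather a row into a contiguous device buffer; the transfer kind depends
// on where the source lives, as src1_kind does in the commit.
static void stage_row(char * dev_dst, const char * src, size_t nbytes,
                      enum buf_backend src_backend, cudaStream_t stream) {
    const cudaMemcpyKind kind = src_backend == BUF_CPU ?
        cudaMemcpyHostToDevice : cudaMemcpyDeviceToDevice;
    CUDA_CHECK(cudaMemcpyAsync(dev_dst, src, nbytes, kind, stream));
}

// Scatter a result row back out; DeviceToHost when the destination is
// host-resident, as dst_kind does in the commit.
static void unstage_row(char * dst, const char * dev_src, size_t nbytes,
                        enum buf_backend dst_backend, cudaStream_t stream) {
    const cudaMemcpyKind kind = dst_backend == BUF_CPU ?
        cudaMemcpyDeviceToHost : cudaMemcpyDeviceToDevice;
    CUDA_CHECK(cudaMemcpyAsync(dst, dev_src, nbytes, kind, stream));
}

int main(void) {
    enum { N = 16 };
    float src[N], dst[N];              // both tensors live on the CPU here
    for (int i = 0; i < N; i++) src[i] = (float) i;

    char * staging = NULL;             // contiguous device scratch buffer
    CUDA_CHECK(cudaMalloc((void **) &staging, sizeof(src)));

    cudaStream_t stream;
    CUDA_CHECK(cudaStreamCreate(&stream));

    stage_row(staging, (const char *) src, sizeof(src), BUF_CPU, stream);
    // ... the matrix-multiplication kernel would consume `staging` here ...
    unstage_row((char *) dst, staging, sizeof(dst), BUF_CPU, stream);

    // As in the commit: when dst is on the CPU the stream must be drained
    // before the caller reads the result, because cudaMemcpyAsync only
    // guarantees completion once the stream has finished.
    CUDA_CHECK(cudaStreamSynchronize(stream));
    printf("dst[5] = %.1f\n", dst[5]);

    CUDA_CHECK(cudaStreamDestroy(stream));
    CUDA_CHECK(cudaFree(staging));
    return 0;
}
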
@@ -9247,7 +9268,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))
         || (tensor->src[1] != nullptr && tensor->src[1]->backend == GGML_BACKEND_GPU);

-    if (!any_on_device && tensor->op != GGML_OP_MUL_MAT) {
+    if (!any_on_device && tensor->op != GGML_OP_MUL_MAT && tensor->op != GGML_OP_MUL_MAT_ID) {
         return false;
     }
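The last hunk widens the early-out in ggml_cuda_compute_forward: previously only GGML_OP_MUL_MAT was allowed to proceed when no operand is GPU-resident, and GGML_OP_MUL_MAT_ID now gets the same exception, since with the changes above it too can stage host-resident operands through device buffers. Restated as a hypothetical helper (not ggml API):

// Hypothetical restatement of the gate after this commit: these two ops
// may run on the GPU even when none of their operands is GPU-resident.
static bool op_allowed_without_gpu_operands(enum ggml_op op) {
    return op == GGML_OP_MUL_MAT || op == GGML_OP_MUL_MAT_ID;
}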