diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 7268e24ad4bed..d8d8b91dd95c5 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -5105,7 +5105,6 @@ static inline __device__ void swap(T & a, T & b) {
 
 template<ggml_sort_order order>
 static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
-    // bitonic sort
     int col = threadIdx.x;
     int row = blockIdx.y;
 
@@ -5114,31 +5113,32 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
     const float * x_row = x + row * ncols;
     int * dst_row = dst + row * ncols;
 
-    // initialize indices
-    if (col < ncols) {
-        dst_row[col] = col;
+    // Initialize indices
+    for (int i = 0; i < ncols; i++) {
+        dst_row[i] = i;
     }
-    __syncthreads();
+    __syncthreads(); // Ensure all indices are initialized
 
-    for (int k = 2; k <= ncols; k *= 2) {
-        for (int j = k / 2; j > 0; j /= 2) {
-            int ixj = col ^ j;
-            if (ixj > col) {
-                if ((col & k) == 0) {
-                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
-                        swap(dst_row[col], dst_row[ixj]);
-                    }
-                } else {
-                    if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
-                        swap(dst_row[col], dst_row[ixj]);
-                    }
-                }
+    // Insertion sort
+    for (int i = 1; i < ncols; i++) {
+        int j = i;
+        while (j > 0) {
+            bool condition = order == GGML_SORT_ASC ?
+                x_row[dst_row[j-1]] > x_row[dst_row[j]] :
+                x_row[dst_row[j-1]] < x_row[dst_row[j]];
+
+            if (condition) {
+                // Swap
+                swap(dst_row[j], dst_row[j - 1]);
+            } else {
+                break;
             }
-            __syncthreads();
+            j--;
         }
     }
 }
+
 
 static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
     const int col = blockDim.y*blockIdx.y + threadIdx.y;
     const int row = blockDim.x*blockIdx.x + threadIdx.x;
@@ -6474,9 +6474,6 @@ static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, con
 }
 
 static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
-    // bitonic sort requires ncols to be power of 2
-    GGML_ASSERT((ncols & (ncols - 1)) == 0);
-
     const dim3 block_dims(ncols, 1, 1);
     const dim3 block_nums(1, nrows, 1);
     if (order == GGML_SORT_ASC) {
@@ -7614,6 +7611,36 @@ inline void ggml_cuda_op_sum_rows(
     (void) src1_dd;
 }
 
+// const size_t TOTAL_EXPERTS = 8;
+static uint * expert_counter = nullptr;
+static size_t expert_count;
+static uint experts_per_tok = 0;
+
+void reset_expert_counter(uint experts_per_tok, size_t expert_count) {
+    if (expert_counter != nullptr) {
+        free(expert_counter);
+    }
+    expert_counter = (uint *) calloc(expert_count, sizeof(uint));
+    ::experts_per_tok = experts_per_tok;
+    ::expert_count = expert_count;
+}
+
+void print_expert_counter() {
+    printf("\n");
+    uint total = 0;
+    for (int i = 0; i < expert_count; i++) {
+        printf("%u:%6u ", i, expert_counter[i]);
+        total += expert_counter[i];
+    }
+
+    printf("\n");
+    for (int i = 0; i < expert_count; i++) {
+        printf("%u:%6.2f%% ", i, 100.0f * expert_counter[i] / total);
+    }
+
+    printf("\n");
+}
+
 inline void ggml_cuda_op_argsort(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7628,6 +7655,31 @@ inline void ggml_cuda_op_argsort(
 
     argsort_f32_i32_cuda(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
 
+    // printf ncols, nrows
+    // printf("\nncols: %d, nrows: %d", ncols, nrows);
+
+    if (expert_counter != nullptr) {
+        int* dst_host = (int *) malloc(ncols * nrows * sizeof(int));
+        CUDA_CHECK(cudaDeviceSynchronize());
+        CUDA_CHECK(cudaMemcpy(dst_host, dst_dd, ncols * nrows * sizeof(int), cudaMemcpyDeviceToHost));
+        // printf("%s: \n", ggml_get_name(dst));
+
+        // row major order
+        for (int i = 0; i < nrows; i++) {
+            for (int j = 0; j < ncols; j++) {
+                int val = dst_host[i * ncols + j];
+                if (j < experts_per_tok) {
+                    GGML_ASSERT(val < expert_count);
+                    expert_counter[val]++;
+                } else {
+                    break;
+                }
+                // printf("%d ", val);
+            }
+            // printf("\n");
+        }
+    }
+
     (void) src1;
     (void) dst;
     (void) src1_dd;
@@ -9666,12 +9718,15 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
             }
         }
 
+        // printf("\n");
         bool ok = ggml_cuda_compute_forward(&params, node);
+        // printf("\n");
        if (!ok) {
            fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
        }
        GGML_ASSERT(ok);
 
+
 #if 0
         if (node->type == GGML_TYPE_F32) {
             cudaDeviceSynchronize();
diff --git a/ggml-cuda.h b/ggml-cuda.h
index cdb0c0c41618a..8ff238506c1c5 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -59,6 +59,10 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
 // pinned host buffer for use with CPU backend for faster copies between CPU and GPU
 GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);
 
+void reset_expert_counter(uint experts_per_tok, size_t expert_count);
+
+void print_expert_counter();
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/ggml.c b/ggml.c
index a380114faf6ff..34045d95e7000 100644
--- a/ggml.c
+++ b/ggml.c
@@ -12524,6 +12524,17 @@ static void ggml_compute_forward_argsort(
         case GGML_TYPE_F32:
             {
                 ggml_compute_forward_argsort_f32(params, src0, dst);
+
+                // TODO: for cpu only or partial offload, the counters state would need to be
+                // moved out of ggml-cuda.cu and maybe into here, and then this would need to
+                // perform a similar update to what the CUDA code does.
+                // Additionally the state reset and print functions are currently only being
+                // called from CUBLAS ifdefs, they'd need to be moved outside of those.
+
+                // for (int i = 0; i < ggml_nelements(dst); i++) {
+                //     printf("%d ", ggml_get_i32_1d(dst, i));
+                // }
+                // printf("\n");
             } break;
         default:
             {
diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index 0e8b4b259bbab..1496ce0bbbec1 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -99,6 +99,9 @@ static std::mutex concat_output_mtx;
 static std::string concat_output = "";
 static std::string concat_output_reader_copy = "";
 
+static uint32_t xx_n_expert = 0;
+static uint32_t xx_n_expert_used = 0;
+
 const int extra_context_handle_fragmentation = 80;
 
 inline bool IsNanCheck(float f)
@@ -972,6 +975,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
             }
         }
 
+        xx_n_expert = llamamodel->hparams.n_expert;
+        xx_n_expert_used = llamamodel->hparams.n_expert_used;
+
         llama_ctx_v4 = llama_new_context_with_model(llamamodel, llama_ctx_params);
 
         if (llama_ctx_v4 == NULL)
@@ -1721,6 +1727,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
         printf("%s\n\n", RemoveBell(outstr).c_str());
     }
 
+#ifdef GGML_USE_CUBLAS
+    if (xx_n_expert > 0) {
+        reset_expert_counter(xx_n_expert_used, xx_n_expert);
+    }
+#endif
+
     while (remaining_tokens > 0)
     {
         gpt_vocab::id id = 0;
@@ -1841,6 +1853,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 
             if (!startedsampling)
             {
+                #ifdef GGML_USE_CUBLAS
+                if (xx_n_expert > 0) {
+                    print_expert_counter();
+                    reset_expert_counter(xx_n_expert_used, xx_n_expert);
+                }
+                #endif
                 startedsampling = true;
                 params.n_batch = original_batch;
                 params.n_threads = original_threads;
@@ -2004,5 +2022,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     total_gens += 1;
     snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());
 
+#ifdef GGML_USE_CUBLAS
+    if (xx_n_expert > 0) {
+        print_expert_counter();
+    }
+#endif
+
     return output;
 }
diff --git a/llama.cpp b/llama.cpp
index 466c0a65e543e..365f9249235c2 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -4203,6 +4203,11 @@ struct llm_build_context {
     }
 
     struct ggml_cgraph * build_llama() {
+        #ifdef GGML_USE_CUBLAS
+        // printf("\nexpert_used: %u\n", n_expert_used);
+        // reset_expert_counter(n_expert_used, n_expert);
+        #endif
+
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
         GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -6078,6 +6083,8 @@ static int llama_decode_internal(
 
     //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
+    // printf("\n");
+
     ggml_allocr_reset(lctx.alloc);
 
     ggml_cgraph * gf = llama_build_graph(lctx, batch);
@@ -6175,6 +6182,12 @@ static int llama_decode_internal(
     ggml_graph_print(gf);
 #endif
 
+#ifdef GGML_USE_CUBLAS
+    // print_expert_counter();
+#endif
+
+    // ycros_moe_debug(gf);
+
     // plot the computation graph in dot format (for debugging purposes)
     //if (n_past%100 == 0) {
     //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
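Note: the per-token tallying added to ggml_cuda_op_argsort can be sanity-checked on the host without a GPU. Below is a minimal standalone C++ sketch of the same idea, not part of the patch: the helper name count_experts and the sample data are hypothetical, and it assumes the argsort result is laid out row-major with the selected experts in the first experts_per_tok positions of each row, as the CUDA path above reads it back.

// Illustration only: host-side analogue of the expert-usage tallying loop.
#include <cstdio>
#include <vector>

static void count_experts(const std::vector<int> & sorted_idx, int nrows, int ncols,
                          int experts_per_tok, std::vector<unsigned> & counter) {
    // sorted_idx holds the argsort output in row-major order, one row per token;
    // only the first experts_per_tok entries of each row (the selected experts)
    // are counted, mirroring the loop added to ggml-cuda.cu.
    for (int i = 0; i < nrows; i++) {
        for (int j = 0; j < experts_per_tok; j++) {
            counter[sorted_idx[i * ncols + j]]++;
        }
    }
}

int main() {
    const int nrows = 2, ncols = 8, experts_per_tok = 2; // e.g. 8 experts, top-2 routing
    const std::vector<int> sorted_idx = {
        3, 5, 0, 1, 2, 4, 6, 7,   // token 0: experts 3 and 5 selected
        5, 2, 7, 0, 1, 3, 4, 6,   // token 1: experts 5 and 2 selected
    };
    std::vector<unsigned> counter(ncols, 0);
    count_experts(sorted_idx, nrows, ncols, experts_per_tok, counter);
    for (int e = 0; e < ncols; e++) {
        printf("%d:%u ", e, counter[e]);
    }
    printf("\n"); // prints: 0:0 1:0 2:1 3:1 4:0 5:2 6:0 7:0
    return 0;
}

Built with any C++11 compiler, this prints per-expert counts in the same spirit as print_expert_counter (which additionally reports percentages).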