Ycros moe logging #6

Open · wants to merge 6 commits into base: exp-dynatemp-minp-latest
99 changes: 77 additions & 22 deletions ggml-cuda.cu
@@ -5105,7 +5105,6 @@ static inline __device__ void swap(T & a, T & b) {

template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;

@@ -5114,31 +5113,32 @@ static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int n
const float * x_row = x + row * ncols;
int * dst_row = dst + row * ncols;

// initialize indices
if (col < ncols) {
dst_row[col] = col;
// Initialize indices
for (int i = 0; i < ncols; i++) {
dst_row[i] = i;
}
__syncthreads();
__syncthreads(); // Ensure all indices are initialized

for (int k = 2; k <= ncols; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
swap(dst_row[col], dst_row[ixj]);
}
} else {
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
swap(dst_row[col], dst_row[ixj]);
}
}
// Insertion sort
for (int i = 1; i < ncols; i++) {
int j = i;
while (j > 0) {
bool condition = order == GGML_SORT_ASC ?
x_row[dst_row[j-1]] > x_row[dst_row[j]] :
x_row[dst_row[j-1]] < x_row[dst_row[j]];

if (condition) {
// Swap
swap(dst_row[j], dst_row[j - 1]);
} else {
break;
}
__syncthreads();
j--;
}
}
}


static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
const int col = blockDim.y*blockIdx.y + threadIdx.y;
const int row = blockDim.x*blockIdx.x + threadIdx.x;
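For reference (not part of this diff): the kernel above replaces the original bitonic network with a per-row insertion sort over the index array, which no longer needs ncols to be a power of two. A minimal host-side sketch of the same logic, useful for checking kernel output (the function name and the bool flag standing in for ggml_sort_order are illustrative):

static void argsort_f32_i32_ref(const float * x, int * dst, const int ncols, const int nrows, const bool ascending) {
    for (int row = 0; row < nrows; row++) {
        const float * x_row = x + row * ncols;
        int * dst_row = dst + row * ncols;

        // initialize indices 0..ncols-1
        for (int i = 0; i < ncols; i++) {
            dst_row[i] = i;
        }

        // insertion sort on the indices, comparing the values they refer to
        for (int i = 1; i < ncols; i++) {
            for (int j = i; j > 0; j--) {
                const bool out_of_order = ascending
                    ? x_row[dst_row[j-1]] > x_row[dst_row[j]]
                    : x_row[dst_row[j-1]] < x_row[dst_row[j]];
                if (!out_of_order) {
                    break;
                }
                const int tmp = dst_row[j];
                dst_row[j]   = dst_row[j-1];
                dst_row[j-1] = tmp;
            }
        }
    }
}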
@@ -6474,9 +6474,6 @@ static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, con
}

static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
GGML_ASSERT((ncols & (ncols - 1)) == 0);

const dim3 block_dims(ncols, 1, 1);
const dim3 block_nums(1, nrows, 1);
if (order == GGML_SORT_ASC) {
@@ -7614,6 +7611,36 @@ inline void ggml_cuda_op_sum_rows(
(void) src1_dd;
}

// const size_t TOTAL_EXPERTS = 8;
static uint * expert_counter = nullptr;
static size_t expert_count;
static uint experts_per_tok = 0;

void reset_expert_counter(uint experts_per_tok, size_t expert_count) {
if (expert_counter != nullptr) {
free(expert_counter);
}
expert_counter = (uint *) calloc(expert_count, sizeof(uint));
::experts_per_tok = experts_per_tok;
::expert_count = expert_count;
}

void print_expert_counter() {
printf("\n");
uint total = 0;
for (int i = 0; i < expert_count; i++) {
printf("%u:%6u ", i, expert_counter[i]);
total += expert_counter[i];
}

printf("\n");
for (int i = 0; i < expert_count; i++) {
printf("%u:%6.2f%% ", i, 100.0f * expert_counter[i] / total);
}

printf("\n");
}

inline void ggml_cuda_op_argsort(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
@@ -7628,6 +7655,31 @@ inline void ggml_cuda_op_argsort(

argsort_f32_i32_cuda(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);

// printf ncols, nrows
// printf("\nncols: %d, nrows: %d", ncols, nrows);

if (expert_counter != nullptr) {
int* dst_host = (int *) malloc(ncols * nrows * sizeof(int));
CUDA_CHECK(cudaDeviceSynchronize());
CUDA_CHECK(cudaMemcpy(dst_host, dst_dd, ncols * nrows * sizeof(int), cudaMemcpyDeviceToHost));
// printf("%s: \n", ggml_get_name(dst));

// row major order
for (int i = 0; i < nrows; i++) {
for (int j = 0; j < ncols; j++) {
int val = dst_host[i * ncols + j];
if (j < experts_per_tok) {
GGML_ASSERT(val < expert_count);
expert_counter[val]++;
} else {
break;
}
// printf("%d ", val);
}
// printf("\n");
}
}

(void) src1;
(void) dst;
(void) src1_dd;
@@ -9666,12 +9718,15 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
}
}

// printf("\n");
bool ok = ggml_cuda_compute_forward(&params, node);
// printf("\n");
if (!ok) {
fprintf(stderr, "%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
}
GGML_ASSERT(ok);


#if 0
if (node->type == GGML_TYPE_F32) {
cudaDeviceSynchronize();
4 changes: 4 additions & 0 deletions ggml-cuda.h
@@ -59,6 +59,10 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
// pinned host buffer for use with CPU backend for faster copies between CPU and GPU
GGML_API ggml_backend_buffer_type_t ggml_backend_cuda_host_buffer_type(void);

void reset_expert_counter(uint experts_per_tok, size_t expert_count);

void print_expert_counter();

#ifdef __cplusplus
}
#endif
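A sketch of the intended call pattern for these two entry points, mirroring the gpttype_adapter.cpp changes further down (n_expert and n_expert_used come from the model hyperparameters):

#ifdef GGML_USE_CUBLAS
    // zero the per-expert counters before prompt processing
    reset_expert_counter(n_expert_used, n_expert);

    // ... prompt processing: every MoE argsort increments the counters ...

    // dump counts and percentages for the prompt, then start fresh for generation
    print_expert_counter();
    reset_expert_counter(n_expert_used, n_expert);

    // ... token generation ...

    print_expert_counter();
#endif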
11 changes: 11 additions & 0 deletions ggml.c
@@ -12524,6 +12524,17 @@ static void ggml_compute_forward_argsort(
case GGML_TYPE_F32:
{
ggml_compute_forward_argsort_f32(params, src0, dst);

// TODO: for cpu only or partial offload, the counters state would need to be
// moved out of ggml-cuda.cu and maybe into here, and then this would need to
// perform a similar update to what the CUDA code does.
// Additionally the state reset and print functions are currently only being
// called from CUBLAS ifdefs, they'd need to be moved outside of those.

// for (int i = 0; i < ggml_nelements(dst); i++) {
// printf("%d ", ggml_get_i32_1d(dst, i));
// }
// printf("\n");
} break;
default:
{
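The TODO above notes that a CPU-only or partially offloaded run would need an equivalent update here instead of in ggml-cuda.cu. A hypothetical sketch of that update (not part of this PR; it assumes the expert_counter, experts_per_tok and expert_count state has been moved somewhere reachable from ggml.c):

// hypothetical CPU-side counterpart of the counting loop in ggml_cuda_op_argsort
static void count_experts_cpu(const struct ggml_tensor * dst) {
    const int ncols = (int) dst->ne[0];      // experts per row, already argsorted
    const int nrows = (int) ggml_nrows(dst); // one row per token
    const int32_t * data = (const int32_t *) dst->data;

    for (int i = 0; i < nrows; i++) {
        for (int j = 0; j < ncols && j < (int) experts_per_tok; j++) {
            const int val = data[i * ncols + j];
            GGML_ASSERT(val >= 0 && val < (int) expert_count);
            expert_counter[val]++;
        }
    }
}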
24 changes: 24 additions & 0 deletions gpttype_adapter.cpp
@@ -99,6 +99,9 @@ static std::mutex concat_output_mtx;
static std::string concat_output = "";
static std::string concat_output_reader_copy = "";

static uint32_t xx_n_expert = 0;
static uint32_t xx_n_expert_used = 0;

const int extra_context_handle_fragmentation = 80;

inline bool IsNanCheck(float f)
@@ -972,6 +975,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
}
}

xx_n_expert = llamamodel->hparams.n_expert;
xx_n_expert_used = llamamodel->hparams.n_expert_used;

llama_ctx_v4 = llama_new_context_with_model(llamamodel, llama_ctx_params);

if (llama_ctx_v4 == NULL)
@@ -1721,6 +1727,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
printf("%s\n\n", RemoveBell(outstr).c_str());
}

#ifdef GGML_USE_CUBLAS
if (xx_n_expert > 0) {
reset_expert_counter(xx_n_expert_used, xx_n_expert);
}
#endif

while (remaining_tokens > 0)
{
gpt_vocab::id id = 0;
@@ -1841,6 +1853,12 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o

if (!startedsampling)
{
#ifdef GGML_USE_CUBLAS
if (xx_n_expert > 0) {
print_expert_counter();
reset_expert_counter(xx_n_expert_used, xx_n_expert);
}
#endif
startedsampling = true;
params.n_batch = original_batch;
params.n_threads = original_threads;
@@ -2004,5 +2022,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
total_gens += 1;
snprintf(output.text, sizeof(output.text), "%s", concat_output.c_str());

#ifdef GGML_USE_CUBLAS
if (xx_n_expert > 0) {
print_expert_counter();
}
#endif

return output;
}
13 changes: 13 additions & 0 deletions llama.cpp
@@ -4203,6 +4203,11 @@ struct llm_build_context {
}

struct ggml_cgraph * build_llama() {
#ifdef GGML_USE_CUBLAS
// printf("\nexpert_used: %u\n", n_expert_used);
// reset_expert_counter(n_expert_used, n_expert);
#endif

struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);

GGML_ASSERT(n_embd_head == hparams.n_rot);
@@ -6078,6 +6083,8 @@ static int llama_decode_internal(

//printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);

// printf("\n");

ggml_allocr_reset(lctx.alloc);

ggml_cgraph * gf = llama_build_graph(lctx, batch);
@@ -6175,6 +6182,12 @@ static int llama_decode_internal(
ggml_graph_print(gf);
#endif

#ifdef GGML_USE_CUBLAS
// print_expert_counter();
#endif

// ycros_moe_debug(gf);

// plot the computation graph in dot format (for debugging purposes)
//if (n_past%100 == 0) {
// ggml_graph_dump_dot(gf, NULL, "llama.dot");