From d45c2b1d46fe04978278202271535c94e6781e65 Mon Sep 17 00:00:00 2001 From: imoc Date: Fri, 10 Jan 2025 12:04:45 +0800 Subject: [PATCH 01/17] Change malloc to calloc --- exllamav2/exllamav2_ext/cpp/sampling.cpp | 10 +++++----- exllamav2/exllamav2_ext/cuda/cache.cu | 10 +++++----- exllamav2/exllamav2_ext/cuda/q_matrix.cu | 13 +++++++++++-- exllamav2/exllamav2_ext/cuda/util.cu | 2 +- exllamav2/exllamav2_ext/ext_stloader.cpp | 2 +- 5 files changed, 23 insertions(+), 14 deletions(-) diff --git a/exllamav2/exllamav2_ext/cpp/sampling.cpp b/exllamav2/exllamav2_ext/cpp/sampling.cpp index d4b88197..2fc1d99b 100644 --- a/exllamav2/exllamav2_ext/cpp/sampling.cpp +++ b/exllamav2/exllamav2_ext/cpp/sampling.cpp @@ -38,7 +38,7 @@ void apply_rep_penalty_cpu // { // if (g_rep_mask) free(g_rep_mask); // g_vocab_size = vocab_size; -// g_rep_mask = (bool*) malloc(g_vocab_size * sizeof(bool)); +// g_rep_mask = (bool*) calloc(1, g_vocab_size * sizeof(bool)); // } // memset(g_rep_mask, 0, g_vocab_size * sizeof(bool)); bool* g_rep_mask = (bool*) calloc(vocab_size, sizeof(bool)); @@ -655,7 +655,7 @@ int tfs_cpu int nc = sort_descending(num_candidates, temp_probs, temp_indices, num_candidates); - float* derivative = (float*) malloc(nc * sizeof(float)); + float* derivative = (float*) calloc(1, nc * sizeof(float)); float dsum = 0.0f; for (int i = 0; i < nc - 2; i++) { @@ -759,9 +759,9 @@ int typical_cpu int r_candidates = pre_sort_descending(num_candidates, temp_probs, temp_indices); - float* temp = (float*) malloc(r_candidates * sizeof(float)); - int* entropy_dev_order = (int*) malloc(r_candidates * sizeof(int)); - int* temp_indices_2 = (int*) malloc(r_candidates * sizeof(int)); + float* temp = (float*) calloc(1, r_candidates * sizeof(float)); + int* entropy_dev_order = (int*) calloc(1, r_candidates * sizeof(int)); + int* temp_indices_2 = (int*) calloc(1, r_candidates * sizeof(int)); float neg_entropy = 0.0f; for (int i = 0; i < r_candidates; i++) diff --git 
a/exllamav2/exllamav2_ext/cuda/cache.cu b/exllamav2/exllamav2_ext/cuda/cache.cu index 53ec1cb2..f1f81a12 100644 --- a/exllamav2/exllamav2_ext/cuda/cache.cu +++ b/exllamav2/exllamav2_ext/cuda/cache.cu @@ -165,8 +165,8 @@ __global__ void fp16_to_q_kv_paged_kernel int page = block_table[pages_per_seq * y + x]; int seqlen = cache_seqlens[y]; - int vx_a = page_size * x; - int px_a = seqlen - vx_a; + int vx_a = (int64_t)page_size * x; + int px_a = (int64_t)seqlen - vx_a; int px_b = px_a + q_len; if (dim % BLOCKSIZE_Q) @@ -174,7 +174,7 @@ __global__ void fp16_to_q_kv_paged_kernel while ((px_a * dim) % BLOCKSIZE_Q) px_a--; while ((px_b * dim) % BLOCKSIZE_Q) px_b++; } - + px_a = max(px_a, 0); px_b = min(px_b, page_size); @@ -346,7 +346,7 @@ __global__ void q_to_fp16_kv_paged_kernel int seqlen = cache_seqlens[y]; if (!seqlen) return; - int vx_a = page_size * x; + int vx_a = (int64_t)page_size * x; int vx_b = min(vx_a + page_size, seqlen); if (dim < BLOCKSIZE_Q) @@ -491,4 +491,4 @@ void array_q_to_fp16_kv_cuda v_in, v_scales, v_out, dim, offset, stride ); -} +} \ No newline at end of file diff --git a/exllamav2/exllamav2_ext/cuda/q_matrix.cu b/exllamav2/exllamav2_ext/cuda/q_matrix.cu index 40350e87..bac12e7e 100644 --- a/exllamav2/exllamav2_ext/cuda/q_matrix.cu +++ b/exllamav2/exllamav2_ext/cuda/q_matrix.cu @@ -603,9 +603,18 @@ bool QMatrix::make_sequential(const uint32_t* cpu_g_idx, cudaStream_t stream) return false; } + // Zero out the allocated memory + size_t mem_size = (height / 8) * width * sizeof(uint32_t); + err = cudaMemset(cuda_new_qweight, 0, mem_size); + if (err != cudaSuccess) {;;; + printf("CUDA memset failed: %s\n", cudaGetErrorString(err)); + cudaFree(cuda_new_qweight); // Free the allocated memory in case of error + return err; + } + uint32_t* cpu_g_idx_map = (uint32_t*) calloc(groups, sizeof(uint32_t)); - uint32_t* cpu_x_map = (uint32_t*) malloc(height * sizeof(uint32_t)); - uint32_t* cpu_x_map_inv = (uint32_t*) malloc(height * sizeof(uint32_t)); + uint32_t* 
cpu_x_map = (uint32_t*) calloc(1, height * sizeof(uint32_t)); + uint32_t* cpu_x_map_inv = (uint32_t*) calloc(1, height * sizeof(uint32_t)); // Group histogram diff --git a/exllamav2/exllamav2_ext/cuda/util.cu b/exllamav2/exllamav2_ext/cuda/util.cu index 4f385791..8f7834ae 100644 --- a/exllamav2/exllamav2_ext/cuda/util.cu +++ b/exllamav2/exllamav2_ext/cuda/util.cu @@ -2,7 +2,7 @@ void print_global_mem(const half* ptr, int rows, int columns, int stride) { - half* temp = (half*) malloc(rows * columns * sizeof(half)); + half* temp = (half*) calloc(1, rows * columns * sizeof(half)); cudaDeviceSynchronize(); cudaMemcpyAsync(temp, ptr, rows * columns * sizeof(half), cudaMemcpyDeviceToHost); diff --git a/exllamav2/exllamav2_ext/ext_stloader.cpp b/exllamav2/exllamav2_ext/ext_stloader.cpp index 2b0b4c1e..0a4ce540 100644 --- a/exllamav2/exllamav2_ext/ext_stloader.cpp +++ b/exllamav2/exllamav2_ext/ext_stloader.cpp @@ -31,7 +31,7 @@ void stloader_read } else { - load_buffer = (uint8_t*) malloc(size); + load_buffer = (uint8_t*) calloc(1, size); TORCH_CHECK(load_buffer, "Can't allocate buffer for tensor"); cuda_buffer = (uint8_t*) target.data_ptr(); cudaSetDevice(device.value().index()); From a129ee9afbb6346833ec36f1ccb1675e27f12ac2 Mon Sep 17 00:00:00 2001 From: imoc Date: Fri, 10 Jan 2025 12:05:24 +0800 Subject: [PATCH 02/17] reverse VRAM scratching --- exllamav2/model.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/exllamav2/model.py b/exllamav2/model.py index 80ccf758..d00bd198 100644 --- a/exllamav2/model.py +++ b/exllamav2/model.py @@ -199,40 +199,37 @@ def set_device_map( reserve_bytes_attn = [0 for a in allocation] fixed_bytes = [0 for a in allocation] - current_idx = 0 - for idx, module in enumerate(self.modules): + # Start from the last device index + current_idx = len(allocation_bytes) - 1 + for idx, module in reversed(list(enumerate(self.modules))): # Special case for token embeddings on CPU - if isinstance(module, 
ExLlamaV2Embedding) and embed_cpu: - module.set_device_idx(-1) continue # Special case for attention - attn_bytes_current = 0 - if isinstance(module, ExLlamaV2Attention): attn_bytes_current = module.temp_attn_size() - - # Advance current_idx until module fits in allocation + if isinstance(module, ExLlamaV2Attention): + attn_bytes_current = module.temp_attn_size() + # Move current_idx backward until module fits in allocation footprint = module.weight_footprint() # Footprint, in bytes scratch = module.scratch_space() # Scratch space required by module while True: - assert current_idx < len(allocation_bytes), "Insufficient space in device allocation" + assert current_idx >= 0, "Insufficient space in device allocation" dev_scratch = max(scratch, reserve_bytes[current_idx]) dev_scratch_attn = max(attn_bytes_current, reserve_bytes_attn[current_idx]) - if footprint + dev_scratch + dev_scratch_attn <= allocation_bytes[current_idx]: break - current_idx += 1 + if footprint + dev_scratch + dev_scratch_attn <= allocation_bytes[current_idx]: + break + current_idx -= 1 # Size for fixed tensors - scratch_fixed = module.scratch_space_fixed() fixed_bytes[current_idx] = max(scratch_fixed, fixed_bytes[current_idx]) # Subtract module size from allocation - reserve_bytes[current_idx] = dev_scratch reserve_bytes_attn[current_idx] = dev_scratch_attn allocation_bytes[current_idx] -= footprint From 782d26c21a33a37e303a0b471427396f91c7a20f Mon Sep 17 00:00:00 2001 From: imoc Date: Tue, 7 Jan 2025 10:16:13 +0800 Subject: [PATCH 03/17] improvement v1 --- exllamav2/exllamav2_ext/ext_quant.cpp | 124 ++++++++++++++++++++++++-- 1 file changed, 119 insertions(+), 5 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index 38e7b439..9b354cdd 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -173,6 +173,7 @@ std::tuple>, std::vector, float, ui float norm ) { + // --- Original Simulated Annealing --- 
int num_slots = slots.size(); std::random_device rd; @@ -181,7 +182,7 @@ std::tuple>, std::vector, float, ui std::vector solution_idx(num_slots); uint64_t current_cost = 0; - float current_sum = 0.0; + float current_max_exp_error = 0; // Track max exp error float temp = initial_temp; int iterations_outer = static_cast(std::log(min_temp / temp) / std::log(cooling_factor)); @@ -190,7 +191,7 @@ std::tuple>, std::vector, float, ui { solution[i] = slots[i][0]; current_cost += std::get<0>(slots[i][0]); - current_sum += powf(std::get<1>(slots[i][0]), norm); + current_max_exp_error = std::max(current_max_exp_error, std::get<1>(slots[i][0])); // Initialize max exp error } for (int j = 0; j < iterations_outer; ++j) @@ -202,7 +203,19 @@ std::tuple>, std::vector, float, ui auto new_option = slots[i][n]; auto old_option = solution[i]; uint64_t delta_cost = std::get<0>(new_option) - std::get<0>(old_option); - float delta_e = powf(std::get<1>(new_option), norm) - powf(std::get<1>(old_option), norm); + float delta_e = std::get<1>(new_option) - std::get<1>(old_option); // Change to exp error difference + + // Calculate new max exp error + float new_max_exp_error = current_max_exp_error; + if(std::get<1>(old_option) == current_max_exp_error) { // If old option has max exp error, recalculate max + new_max_exp_error = std::get<1>(new_option); + for(int slot_idx = 0; slot_idx < num_slots; slot_idx++) { + if(slot_idx == i) continue; + new_max_exp_error = std::max(new_max_exp_error, std::get<1>(solution[slot_idx])); + } + } else { // If old option does not have max exp error, only compare with new option + new_max_exp_error = std::max(current_max_exp_error, std::get<1>(new_option)); + } if (current_cost + delta_cost <= max_cost || (delta_cost < 0 && current_cost > max_cost)) { @@ -211,18 +224,119 @@ std::tuple>, std::vector, float, ui { solution[i] = new_option; solution_idx[i] = n; - current_sum += delta_e; current_cost += delta_cost; + current_max_exp_error = new_max_exp_error; } } } 
temp *= cooling_factor; } + // --- Post-processing: Bit Redistribution --- + const float bpw_threshold = 4.0f; + const int redistribution_iterations = 10; // Tune this parameter + const float bpw_penalty_scale = 0.01f; // Tune this parameter + + auto calculate_bpw = [&](const std::tuple& option) { + return 8.0f * std::get<0>(option) / 1024.0f; // Assuming cost is related to BPW + }; + + auto calculate_bpw_stats = [&](const std::vector>& sol) { + std::vector current_bpws(num_slots); + for (int i = 0; i < num_slots; ++i) { + current_bpws[i] = calculate_bpw(sol[i]); + } + float bpw_mean = std::accumulate(current_bpws.begin(), current_bpws.end(), 0.0f) / num_slots; + float bpw_sq_sum = std::inner_product(current_bpws.begin(), current_bpws.end(), current_bpws.begin(), 0.0f); + float bpw_variance = bpw_sq_sum / num_slots - bpw_mean * bpw_mean; + return std::make_pair(bpw_mean, std::sqrt(std::max(0.0f, bpw_variance))); + }; + + for (int r = 0; r < redistribution_iterations; ++r) { + std::vector low_bpw_indices; + std::vector high_bpw_indices; + + for (int i = 0; i < num_slots; ++i) { + float bpw = calculate_bpw(solution[i]); + if (bpw < bpw_threshold) { + low_bpw_indices.push_back(i); + } else { + high_bpw_indices.push_back(i); + } + } + + bool improved = false; + for (int low_idx : low_bpw_indices) { + if (high_bpw_indices.empty()) break; + + int high_idx = high_bpw_indices[std::uniform_int_distribution<>(0, high_bpw_indices.size() - 1)(gen)]; + + // Find a higher BPW option for the low-BPW slot + int best_low_new_idx = -1; + float best_low_new_error = 1e10f; + for (int n = 0; n < slots[low_idx].size(); ++n) { + if (calculate_bpw(slots[low_idx][n]) > calculate_bpw(solution[low_idx])) { + if (std::get<1>(slots[low_idx][n]) < best_low_new_error) { + best_low_new_error = std::get<1>(slots[low_idx][n]); + best_low_new_idx = n; + } + } + } + + // Find a lower BPW option for the high-BPW slot + int best_high_new_idx = -1; + float best_high_new_error = 1e10f; + for (int n = 0; 
n < slots[high_idx].size(); ++n) { + if (calculate_bpw(slots[high_idx][n]) < calculate_bpw(solution[high_idx])) { + if (std::get<1>(slots[high_idx][n]) < best_high_new_error) { + best_high_new_error = std::get<1>(slots[high_idx][n]); + best_high_new_idx = n; + } + } + } + + if (best_low_new_idx != -1 && best_high_new_idx != -1) { + auto new_low_option = slots[low_idx][best_low_new_idx]; + auto new_high_option = slots[high_idx][best_high_new_idx]; + + uint64_t new_cost = current_cost - std::get<0>(solution[low_idx]) - std::get<0>(solution[high_idx]) + std::get<0>(new_low_option) + std::get<0>(new_high_option); + + if (new_cost <= max_cost) { + // Calculate new max exp error + float new_max_exp_error = std::get<1>(new_low_option); + for(int i = 0; i < num_slots; i++) { + if(i == low_idx) continue; + if(i == high_idx) { + new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_high_option)); + } else { + new_max_exp_error = std::max(new_max_exp_error, std::get<1>(solution[i])); + } + } + + // Optional: Add a penalty for BPW imbalance + auto [current_bpw_mean, current_bpw_stddev] = calculate_bpw_stats(solution); + auto [new_bpw_mean, new_bpw_stddev] = calculate_bpw_stats({new_low_option, new_high_option}); + float bpw_penalty = bpw_penalty_scale * (new_bpw_stddev - current_bpw_stddev); + + if (new_max_exp_error + bpw_penalty < current_max_exp_error) { + // Accept the changes + solution[low_idx] = new_low_option; + solution_idx[low_idx] = best_low_new_idx; + solution[high_idx] = new_high_option; + solution_idx[high_idx] = best_high_new_idx; + current_cost = new_cost; + current_max_exp_error = new_max_exp_error; + improved = true; + } + } + } + } + } + float max_err = 0.0f; for (int i = 0; i < num_slots; ++i) max_err = std::max(max_err, std::get<1>(solution[i])); - return { solution, solution_idx, current_sum, current_cost, max_err }; + return { solution, solution_idx, current_max_exp_error, current_cost, max_err }; } From 056c60e49aa3decd24aab959b29a71b9bb45440a 
Mon Sep 17 00:00:00 2001 From: imoc Date: Tue, 7 Jan 2025 11:14:37 +0800 Subject: [PATCH 04/17] improvement v2 6.5bpw OG: -- sum(log(err)): -852.326775 -- max(err): 0.003952 calibration perplexity (quant): 8.0247 5.43bpw: v1: -- sum(log(err)): -833.710759 -- max(err): 0.005545 calibration perplexity (quant): 8.0294 v2: -- sum(log(err)): -865.617083 -- max(err): 0.006786 calibration perplexity (quant): DNF --- exllamav2/exllamav2_ext/ext_quant.cpp | 142 +++++++++++++++++++++----- 1 file changed, 114 insertions(+), 28 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index 9b354cdd..2437cdac 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -173,6 +173,12 @@ std::tuple>, std::vector, float, ui float norm ) { + // --- Internal Parameters --- + const int redistribution_iterations = 25; + const float bpw_penalty_scale = 0.01f; + const float min_bpw_limit = 2.0f; // Minimum allowed BPW + const int opportunistic_iterations = 1000; // Iterations for opportunistic optimization + // --- Original Simulated Annealing --- int num_slots = slots.size(); @@ -182,7 +188,7 @@ std::tuple>, std::vector, float, ui std::vector solution_idx(num_slots); uint64_t current_cost = 0; - float current_max_exp_error = 0; // Track max exp error + float current_max_exp_error = 0; float temp = initial_temp; int iterations_outer = static_cast(std::log(min_temp / temp) / std::log(cooling_factor)); @@ -191,29 +197,28 @@ std::tuple>, std::vector, float, ui { solution[i] = slots[i][0]; current_cost += std::get<0>(slots[i][0]); - current_max_exp_error = std::max(current_max_exp_error, std::get<1>(slots[i][0])); // Initialize max exp error + current_max_exp_error = std::max(current_max_exp_error, std::get<1>(slots[i][0])); } for (int j = 0; j < iterations_outer; ++j) { for (int k = 0; k < iterations; ++k) { - int i = std::uniform_int_distribution<>(0, num_slots - 1)(gen); // target slot - int n = 
std::uniform_int_distribution<>(0, slots[i].size() - 1)(gen); // target option + int i = std::uniform_int_distribution<>(0, num_slots - 1)(gen); + int n = std::uniform_int_distribution<>(0, slots[i].size() - 1)(gen); auto new_option = slots[i][n]; auto old_option = solution[i]; uint64_t delta_cost = std::get<0>(new_option) - std::get<0>(old_option); - float delta_e = std::get<1>(new_option) - std::get<1>(old_option); // Change to exp error difference - - // Calculate new max exp error + float delta_e = std::get<1>(new_option) - std::get<1>(old_option); + float new_max_exp_error = current_max_exp_error; - if(std::get<1>(old_option) == current_max_exp_error) { // If old option has max exp error, recalculate max + if (std::get<1>(old_option) == current_max_exp_error) { new_max_exp_error = std::get<1>(new_option); - for(int slot_idx = 0; slot_idx < num_slots; slot_idx++) { - if(slot_idx == i) continue; + for (int slot_idx = 0; slot_idx < num_slots; slot_idx++) { + if (slot_idx == i) continue; new_max_exp_error = std::max(new_max_exp_error, std::get<1>(solution[slot_idx])); } - } else { // If old option does not have max exp error, only compare with new option + } else { new_max_exp_error = std::max(current_max_exp_error, std::get<1>(new_option)); } @@ -233,12 +238,8 @@ std::tuple>, std::vector, float, ui } // --- Post-processing: Bit Redistribution --- - const float bpw_threshold = 4.0f; - const int redistribution_iterations = 10; // Tune this parameter - const float bpw_penalty_scale = 0.01f; // Tune this parameter - auto calculate_bpw = [&](const std::tuple& option) { - return 8.0f * std::get<0>(option) / 1024.0f; // Assuming cost is related to BPW + return 8.0f * std::get<0>(option) / 1024.0f; }; auto calculate_bpw_stats = [&](const std::vector>& sol) { @@ -253,6 +254,10 @@ std::tuple>, std::vector, float, ui }; for (int r = 0; r < redistribution_iterations; ++r) { + // Calculate BPW statistics and dynamic bpw_threshold + auto [bpw_mean, bpw_stddev] = 
calculate_bpw_stats(solution); + float bpw_threshold = std::max(min_bpw_limit, bpw_mean - 0.5f * bpw_stddev); + std::vector low_bpw_indices; std::vector high_bpw_indices; @@ -268,10 +273,9 @@ std::tuple>, std::vector, float, ui bool improved = false; for (int low_idx : low_bpw_indices) { if (high_bpw_indices.empty()) break; - + int high_idx = high_bpw_indices[std::uniform_int_distribution<>(0, high_bpw_indices.size() - 1)(gen)]; - // Find a higher BPW option for the low-BPW slot int best_low_new_idx = -1; float best_low_new_error = 1e10f; for (int n = 0; n < slots[low_idx].size(); ++n) { @@ -283,7 +287,6 @@ std::tuple>, std::vector, float, ui } } - // Find a lower BPW option for the high-BPW slot int best_high_new_idx = -1; float best_high_new_error = 1e10f; for (int n = 0; n < slots[high_idx].size(); ++n) { @@ -300,26 +303,23 @@ std::tuple>, std::vector, float, ui auto new_high_option = slots[high_idx][best_high_new_idx]; uint64_t new_cost = current_cost - std::get<0>(solution[low_idx]) - std::get<0>(solution[high_idx]) + std::get<0>(new_low_option) + std::get<0>(new_high_option); - + if (new_cost <= max_cost) { - // Calculate new max exp error float new_max_exp_error = std::get<1>(new_low_option); - for(int i = 0; i < num_slots; i++) { - if(i == low_idx) continue; - if(i == high_idx) { + for (int i = 0; i < num_slots; i++) { + if (i == low_idx) continue; + if (i == high_idx) { new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_high_option)); } else { new_max_exp_error = std::max(new_max_exp_error, std::get<1>(solution[i])); } } - // Optional: Add a penalty for BPW imbalance auto [current_bpw_mean, current_bpw_stddev] = calculate_bpw_stats(solution); auto [new_bpw_mean, new_bpw_stddev] = calculate_bpw_stats({new_low_option, new_high_option}); float bpw_penalty = bpw_penalty_scale * (new_bpw_stddev - current_bpw_stddev); if (new_max_exp_error + bpw_penalty < current_max_exp_error) { - // Accept the changes solution[low_idx] = new_low_option; 
solution_idx[low_idx] = best_low_new_idx; solution[high_idx] = new_high_option; @@ -333,10 +333,96 @@ std::tuple>, std::vector, float, ui } } + // --- Opportunistic Optimization --- + for (int i = 0; i < opportunistic_iterations; ++i) { + int slot1 = std::uniform_int_distribution<>(0, num_slots - 1)(gen); + int slot2 = std::uniform_int_distribution<>(0, num_slots - 1)(gen); + if (slot1 == slot2) continue; + + int option1 = solution_idx[slot1]; + int option2 = solution_idx[slot2]; + + // Try to increase BPW of slot1 and decrease BPW of slot2 + int best_option1 = -1; + float best_option1_error = 1e10f; + for(int new_option1 = 0; new_option1 < slots[slot1].size(); new_option1++) { + if(calculate_bpw(slots[slot1][new_option1]) > calculate_bpw(solution[slot1])) { + if (std::get<1>(slots[slot1][new_option1]) < best_option1_error) { + best_option1_error = std::get<1>(slots[slot1][new_option1]); + best_option1 = new_option1; + } + } + } + int best_option2 = -1; + float best_option2_error = 1e10f; + for(int new_option2 = 0; new_option2 < slots[slot2].size(); new_option2++) { + if(calculate_bpw(slots[slot2][new_option2]) < calculate_bpw(solution[slot2])) { + if (std::get<1>(slots[slot2][new_option2]) < best_option2_error) { + best_option2_error = std::get<1>(slots[slot2][new_option2]); + best_option2 = new_option2; + } + } + } + + if (best_option1 != -1 && best_option2 != -1) { + auto new_option1 = slots[slot1][best_option1]; + auto new_option2 = slots[slot2][best_option2]; + + if(calculate_bpw(new_option1) < min_bpw_limit || calculate_bpw(new_option2) < min_bpw_limit) continue; + + uint64_t new_cost = current_cost - std::get<0>(solution[slot1]) - std::get<0>(solution[slot2]) + std::get<0>(new_option1) + std::get<0>(new_option2); + + if (new_cost <= max_cost) { + // Calculate new max exp error + float new_max_exp_error = std::get<1>(new_option1); + for(int j = 0; j < num_slots; j++) { + if(j == slot1) continue; + if(j == slot2) { + new_max_exp_error = 
std::max(new_max_exp_error, std::get<1>(new_option2)); + } else { + new_max_exp_error = std::max(new_max_exp_error, std::get<1>(solution[j])); + } + } + + // Calculate sum of log errors + float new_sum_log_err = 0; + for (int j = 0; j < num_slots; ++j) { + if (j == slot1) { + new_sum_log_err += log(std::get<1>(new_option1)); + } else if (j == slot2) { + new_sum_log_err += log(std::get<1>(new_option2)); + } else { + new_sum_log_err += log(std::get<1>(solution[j])); + } + } + + // Calculate current sum of log errors + float current_sum_log_err = 0; + for (int j = 0; j < num_slots; ++j) { + current_sum_log_err += log(std::get<1>(solution[j])); + } + + // Accept change if it reduces sum of log errors without increasing max error too much + if (new_sum_log_err < current_sum_log_err && new_max_exp_error < current_max_exp_error * 1.05f) { + solution[slot1] = new_option1; + solution_idx[slot1] = best_option1; + solution[slot2] = new_option2; + solution_idx[slot2] = best_option2; + current_cost = new_cost; + current_max_exp_error = new_max_exp_error; + } + } + } + } + + // Calculate final max error and sum of log errors float max_err = 0.0f; - for (int i = 0; i < num_slots; ++i) + float sum_log_err = 0.0; + for (int i = 0; i < num_slots; ++i) { max_err = std::max(max_err, std::get<1>(solution[i])); + sum_log_err += log(std::get<1>(solution[i])); + } - return { solution, solution_idx, current_max_exp_error, current_cost, max_err }; + return { solution, solution_idx, sum_log_err, current_cost, max_err }; } From 272c31d3e4d682a7d817286a94f4c2c5971b4a38 Mon Sep 17 00:00:00 2001 From: imoc Date: Tue, 7 Jan 2025 11:33:47 +0800 Subject: [PATCH 05/17] improvement v3-1 -- sum(log(err)): -866.199803 -- max(err): 0.005706 --- exllamav2/exllamav2_ext/ext_quant.cpp | 83 +++++++++++++++++++++------ 1 file changed, 65 insertions(+), 18 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index 2437cdac..157661a4 100644 --- 
a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -174,10 +174,11 @@ std::tuple>, std::vector, float, ui ) { // --- Internal Parameters --- - const int redistribution_iterations = 25; + const int redistribution_iterations = 25; const float bpw_penalty_scale = 0.01f; - const float min_bpw_limit = 2.0f; // Minimum allowed BPW - const int opportunistic_iterations = 1000; // Iterations for opportunistic optimization + const float min_bpw_limit = 2.0f; + const int opportunistic_iterations = 5000; + const float bpw_transfer_step = 0.0625f; // Amount of BPW to transfer in each step // --- Original Simulated Annealing --- int num_slots = slots.size(); @@ -334,29 +335,50 @@ std::tuple>, std::vector, float, ui } // --- Opportunistic Optimization --- + // Track the best solution found during opportunistic optimization + std::vector> best_solution_opportunistic = solution; + std::vector best_solution_idx_opportunistic = solution_idx; + float best_sum_log_err_opportunistic = 1e18f; + uint64_t best_cost_opportunistic = current_cost; + for (int i = 0; i < opportunistic_iterations; ++i) { - int slot1 = std::uniform_int_distribution<>(0, num_slots - 1)(gen); + auto [bpw_mean, bpw_stddev] = calculate_bpw_stats(solution); + float bpw_threshold = std::max(min_bpw_limit, bpw_mean - 0.5f * bpw_stddev); + + int slot1 = -1; + // Find a slot with BPW above the threshold + std::vector high_bpw_indices; + for(int j = 0; j < num_slots; j++) { + if(calculate_bpw(solution[j]) > bpw_threshold) { + high_bpw_indices.push_back(j); + } + } + if(high_bpw_indices.empty()) continue; + slot1 = high_bpw_indices[std::uniform_int_distribution<>(0, high_bpw_indices.size() - 1)(gen)]; + int slot2 = std::uniform_int_distribution<>(0, num_slots - 1)(gen); if (slot1 == slot2) continue; int option1 = solution_idx[slot1]; int option2 = solution_idx[slot2]; - // Try to increase BPW of slot1 and decrease BPW of slot2 + // Find a lower BPW option for slot1 int best_option1 = -1; 
float best_option1_error = 1e10f; - for(int new_option1 = 0; new_option1 < slots[slot1].size(); new_option1++) { - if(calculate_bpw(slots[slot1][new_option1]) > calculate_bpw(solution[slot1])) { + for (int new_option1 = 0; new_option1 < slots[slot1].size(); new_option1++) { + if (calculate_bpw(slots[slot1][new_option1]) < calculate_bpw(solution[slot1])) { if (std::get<1>(slots[slot1][new_option1]) < best_option1_error) { best_option1_error = std::get<1>(slots[slot1][new_option1]); best_option1 = new_option1; } } } + + // Find a higher BPW option for slot2 int best_option2 = -1; float best_option2_error = 1e10f; - for(int new_option2 = 0; new_option2 < slots[slot2].size(); new_option2++) { - if(calculate_bpw(slots[slot2][new_option2]) < calculate_bpw(solution[slot2])) { + for (int new_option2 = 0; new_option2 < slots[slot2].size(); new_option2++) { + if (calculate_bpw(slots[slot2][new_option2]) > calculate_bpw(solution[slot2])) { if (std::get<1>(slots[slot2][new_option2]) < best_option2_error) { best_option2_error = std::get<1>(slots[slot2][new_option2]); best_option2 = new_option2; @@ -367,18 +389,18 @@ std::tuple>, std::vector, float, ui if (best_option1 != -1 && best_option2 != -1) { auto new_option1 = slots[slot1][best_option1]; auto new_option2 = slots[slot2][best_option2]; - - if(calculate_bpw(new_option1) < min_bpw_limit || calculate_bpw(new_option2) < min_bpw_limit) continue; + + if (calculate_bpw(new_option2) < min_bpw_limit) continue; uint64_t new_cost = current_cost - std::get<0>(solution[slot1]) - std::get<0>(solution[slot2]) + std::get<0>(new_option1) + std::get<0>(new_option2); if (new_cost <= max_cost) { // Calculate new max exp error - float new_max_exp_error = std::get<1>(new_option1); - for(int j = 0; j < num_slots; j++) { - if(j == slot1) continue; - if(j == slot2) { - new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_option2)); + float new_max_exp_error = std::get<1>(new_option2); + for (int j = 0; j < num_slots; j++) { + if (j == 
slot2) continue; + if (j == slot1) { + new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_option1)); } else { new_max_exp_error = std::max(new_max_exp_error, std::get<1>(solution[j])); } @@ -402,19 +424,44 @@ std::tuple>, std::vector, float, ui current_sum_log_err += log(std::get<1>(solution[j])); } - // Accept change if it reduces sum of log errors without increasing max error too much - if (new_sum_log_err < current_sum_log_err && new_max_exp_error < current_max_exp_error * 1.05f) { + // Accept change if it reduces sum of log errors without increasing max error + if (new_sum_log_err < current_sum_log_err && new_max_exp_error <= current_max_exp_error) + { solution[slot1] = new_option1; solution_idx[slot1] = best_option1; solution[slot2] = new_option2; solution_idx[slot2] = best_option2; current_cost = new_cost; current_max_exp_error = new_max_exp_error; + current_sum_log_err = new_sum_log_err; + + // Update best solution found during opportunistic optimization + if (current_sum_log_err < best_sum_log_err_opportunistic) { + best_sum_log_err_opportunistic = current_sum_log_err; + best_cost_opportunistic = current_cost; + best_solution_opportunistic = solution; + best_solution_idx_opportunistic = solution_idx; + } } } } } + // Use the best solution found during opportunistic optimization + if (best_sum_log_err_opportunistic < 1e18f) { + solution = best_solution_opportunistic; + solution_idx = best_solution_idx_opportunistic; + current_cost = best_cost_opportunistic; + } + + // --- Final Cost Check and Rollback (if necessary) --- + if (current_cost > max_cost) { + // Revert to the solution before opportunistic optimization + solution = best_solution_opportunistic; + solution_idx = best_solution_idx_opportunistic; + current_cost = best_cost_opportunistic; + } + // Calculate final max error and sum of log errors float max_err = 0.0f; float sum_log_err = 0.0; From 117a60a352157265c9ba1226468c83d7329807b5 Mon Sep 17 00:00:00 2001 From: imoc Date: Tue, 7 Jan 
2025 15:39:12 +0800 Subject: [PATCH 06/17] improvement v3-2 -- sum(log(err)): -840.236110 -- max(err): 0.005603 --- exllamav2/exllamav2_ext/ext_quant.cpp | 331 ++++++++++++++++++-------- 1 file changed, 227 insertions(+), 104 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index 157661a4..a9f80edd 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -177,8 +177,10 @@ std::tuple>, std::vector, float, ui const int redistribution_iterations = 25; const float bpw_penalty_scale = 0.01f; const float min_bpw_limit = 2.0f; - const int opportunistic_iterations = 5000; - const float bpw_transfer_step = 0.0625f; // Amount of BPW to transfer in each step + const int opportunistic_iterations_stage1 = 5000; + const int opportunistic_iterations_stage2 = 10000; + const float initial_opportunistic_temp = 0.01f; + const float min_exp_error_threshold = 0.001f; // --- Original Simulated Annealing --- int num_slots = slots.size(); @@ -334,132 +336,253 @@ std::tuple>, std::vector, float, ui } } - // --- Opportunistic Optimization --- - // Track the best solution found during opportunistic optimization - std::vector> best_solution_opportunistic = solution; - std::vector best_solution_idx_opportunistic = solution_idx; - float best_sum_log_err_opportunistic = 1e18f; - uint64_t best_cost_opportunistic = current_cost; - - for (int i = 0; i < opportunistic_iterations; ++i) { - auto [bpw_mean, bpw_stddev] = calculate_bpw_stats(solution); - float bpw_threshold = std::max(min_bpw_limit, bpw_mean - 0.5f * bpw_stddev); + // --- Opportunistic Optimization (Stage 1: Focus on Sum of Log Errors) --- + float current_sum_log_err = 0; + for (int i = 0; i < num_slots; ++i) { + current_sum_log_err += log(std::get<1>(solution[i])); + } - int slot1 = -1; - // Find a slot with BPW above the threshold - std::vector high_bpw_indices; - for(int j = 0; j < num_slots; j++) { - if(calculate_bpw(solution[j]) > 
bpw_threshold) { - high_bpw_indices.push_back(j); - } + float best_sum_log_err = current_sum_log_err; + std::vector> best_solution = solution; + std::vector best_solution_idx = solution_idx; + float best_max_exp_error = current_max_exp_error; + + float local_temp = initial_opportunistic_temp; + for (int i = 0; i < opportunistic_iterations_stage1; ++i) { + // Select a neighborhood of slots + int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); + int neighborhood_size = std::min(5, num_slots); + int start_slot = std::max(0, center_slot - neighborhood_size / 2); + int end_slot = std::min(num_slots - 1, center_slot + neighborhood_size / 2); + + // Calculate average BPW in the neighborhood + float neighborhood_bpw_sum = 0; + for (int j = start_slot; j <= end_slot; ++j) { + neighborhood_bpw_sum += calculate_bpw(solution[j]); } - if(high_bpw_indices.empty()) continue; - slot1 = high_bpw_indices[std::uniform_int_distribution<>(0, high_bpw_indices.size() - 1)(gen)]; - - int slot2 = std::uniform_int_distribution<>(0, num_slots - 1)(gen); - if (slot1 == slot2) continue; - - int option1 = solution_idx[slot1]; - int option2 = solution_idx[slot2]; - - // Find a lower BPW option for slot1 - int best_option1 = -1; - float best_option1_error = 1e10f; - for (int new_option1 = 0; new_option1 < slots[slot1].size(); new_option1++) { - if (calculate_bpw(slots[slot1][new_option1]) < calculate_bpw(solution[slot1])) { - if (std::get<1>(slots[slot1][new_option1]) < best_option1_error) { - best_option1_error = std::get<1>(slots[slot1][new_option1]); - best_option1 = new_option1; + float neighborhood_bpw_avg = neighborhood_bpw_sum / (end_slot - start_slot + 1); + + // Adjust BPWs within the neighborhood, weighted by error + std::vector> new_solution = solution; + std::vector new_solution_idx = solution_idx; + float new_sum_log_err = current_sum_log_err; + uint64_t new_cost = current_cost; + + for (int j = start_slot; j <= end_slot; ++j) { + float current_bpw = 
calculate_bpw(solution[j]); + float target_bpw = neighborhood_bpw_avg; + float error = std::get<1>(solution[j]); + float adjustment = 0.125f; + + // Error-weighted adjustment + float error_weight = std::max(0.0f, error - min_exp_error_threshold); + + if (current_bpw < target_bpw) { + // Search for a higher BPW option + for (int n = 0; n < slots[j].size(); ++n) { + auto new_option = slots[j][n]; + if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) + { + if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) + { + new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); + new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); + new_solution[j] = new_option; + new_solution_idx[j] = n; + break; + } + } + } + } else if (current_bpw > target_bpw) { + // Search for a lower BPW option + for (int n = 0; n < slots[j].size(); ++n) { + auto new_option = slots[j][n]; + if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) + { + if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) + { + new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); + new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); + new_solution[j] = new_option; + new_solution_idx[j] = n; + break; + } + } } } } - // Find a higher BPW option for slot2 - int best_option2 = -1; - float best_option2_error = 1e10f; - for (int new_option2 = 0; new_option2 < slots[slot2].size(); new_option2++) { - if (calculate_bpw(slots[slot2][new_option2]) > calculate_bpw(solution[slot2])) { - if (std::get<1>(slots[slot2][new_option2]) < best_option2_error) { - best_option2_error = std::get<1>(slots[slot2][new_option2]); - best_option2 = new_option2; + // Calculate new max exp error + float new_max_exp_error = 0; + for (int j = 0; j < num_slots; ++j) { + 
new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j])); + } + + // Acceptance criterion with a small probability of accepting worse solutions + if (new_cost <= max_cost) { + float delta_sum_log_err = new_sum_log_err - current_sum_log_err; + if (delta_sum_log_err < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err / local_temp)) { + solution = new_solution; + solution_idx = new_solution_idx; + current_sum_log_err = new_sum_log_err; + current_cost = new_cost; + current_max_exp_error = new_max_exp_error; + + if (current_sum_log_err < best_sum_log_err) { + best_sum_log_err = current_sum_log_err; + best_solution = solution; + best_solution_idx = solution_idx; + best_max_exp_error = current_max_exp_error; } } } - if (best_option1 != -1 && best_option2 != -1) { - auto new_option1 = slots[slot1][best_option1]; - auto new_option2 = slots[slot2][best_option2]; - - if (calculate_bpw(new_option2) < min_bpw_limit) continue; - - uint64_t new_cost = current_cost - std::get<0>(solution[slot1]) - std::get<0>(solution[slot2]) + std::get<0>(new_option1) + std::get<0>(new_option2); + local_temp *= 0.95f; + } - if (new_cost <= max_cost) { - // Calculate new max exp error - float new_max_exp_error = std::get<1>(new_option2); - for (int j = 0; j < num_slots; j++) { - if (j == slot2) continue; - if (j == slot1) { - new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_option1)); - } else { - new_max_exp_error = std::max(new_max_exp_error, std::get<1>(solution[j])); + // --- Opportunistic Optimization (Stage 2: Focus on Max Error and Min BPW) --- + local_temp = initial_opportunistic_temp * 0.1f; // Lower temperature for Stage 2 + for (int i = 0; i < opportunistic_iterations_stage2; ++i) { + // Select a neighborhood of slots + int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); + int neighborhood_size = std::min(5, num_slots); + int start_slot = std::max(0, center_slot - neighborhood_size / 2); + int 
end_slot = std::min(num_slots - 1, center_slot + neighborhood_size / 2); + + // Calculate average BPW in the neighborhood + float neighborhood_bpw_sum = 0; + for (int j = start_slot; j <= end_slot; ++j) { + neighborhood_bpw_sum += calculate_bpw(solution[j]); + } + float neighborhood_bpw_avg = neighborhood_bpw_sum / (end_slot - start_slot + 1); + + // Adjust BPWs within the neighborhood, weighted by error + std::vector> new_solution = solution; + std::vector new_solution_idx = solution_idx; + float new_sum_log_err = current_sum_log_err; + uint64_t new_cost = current_cost; + + for (int j = start_slot; j <= end_slot; ++j) { + float current_bpw = calculate_bpw(solution[j]); + float target_bpw = neighborhood_bpw_avg; + float error = std::get<1>(solution[j]); + float adjustment = 0.125f; + + // Focus on increasing BPW if below min_bpw_limit + if (current_bpw < min_bpw_limit) { + // Search for a higher BPW option + for (int n = 0; n < slots[j].size(); ++n) { + auto new_option = slots[j][n]; + if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) + { + if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) + { + new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); + new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); + new_solution[j] = new_option; + new_solution_idx[j] = n; + break; + } } } - - // Calculate sum of log errors - float new_sum_log_err = 0; - for (int j = 0; j < num_slots; ++j) { - if (j == slot1) { - new_sum_log_err += log(std::get<1>(new_option1)); - } else if (j == slot2) { - new_sum_log_err += log(std::get<1>(new_option2)); - } else { - new_sum_log_err += log(std::get<1>(solution[j])); + } else { + // Error-weighted adjustment (less aggressive if error is already low) + float error_weight = std::max(0.0f, error - min_exp_error_threshold); + + if (current_bpw < target_bpw) { + // Search for a higher BPW option + for 
(int n = 0; n < slots[j].size(); ++n) { + auto new_option = slots[j][n]; + if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) + { + if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) + { + new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); + new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); + new_solution[j] = new_option; + new_solution_idx[j] = n; + break; + } + } + } + } else if (current_bpw > target_bpw) { + // Search for a lower BPW option + for (int n = 0; n < slots[j].size(); ++n) { + auto new_option = slots[j][n]; + if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) + { + if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) + { + new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); + new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); + new_solution[j] = new_option; + new_solution_idx[j] = n; + break; + } + } } } + } + } - // Calculate current sum of log errors - float current_sum_log_err = 0; - for (int j = 0; j < num_slots; ++j) { - current_sum_log_err += log(std::get<1>(solution[j])); - } + // Calculate new max exp error + float new_max_exp_error = 0; + for (int j = 0; j < num_slots; ++j) { + new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j])); + } - // Accept change if it reduces sum of log errors without increasing max error - if (new_sum_log_err < current_sum_log_err && new_max_exp_error <= current_max_exp_error) - { - solution[slot1] = new_option1; - solution_idx[slot1] = best_option1; - solution[slot2] = new_option2; - solution_idx[slot2] = best_option2; - current_cost = new_cost; - current_max_exp_error = new_max_exp_error; - current_sum_log_err = new_sum_log_err; - - // Update best solution found during opportunistic optimization - if 
(current_sum_log_err < best_sum_log_err_opportunistic) { - best_sum_log_err_opportunistic = current_sum_log_err; - best_cost_opportunistic = current_cost; - best_solution_opportunistic = solution; - best_solution_idx_opportunistic = solution_idx; - } + // Acceptance criterion (more emphasis on max error and min BPW) + auto [new_bpw_mean, new_bpw_stddev] = calculate_bpw_stats(new_solution); + if (new_cost <= max_cost && new_bpw_mean >= min_bpw_limit) { + float delta_sum_log_err = new_sum_log_err - current_sum_log_err; + float delta_max_exp_error = new_max_exp_error - current_max_exp_error; + + // Prioritize reducing max error and increasing min BPW + if ((delta_max_exp_error < 0 || (delta_max_exp_error == 0 && delta_sum_log_err < 0)) || + std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-(delta_sum_log_err + delta_max_exp_error * 100.0f) / local_temp)) { + solution = new_solution; + solution_idx = new_solution_idx; + current_sum_log_err = new_sum_log_err; + current_cost = new_cost; + current_max_exp_error = new_max_exp_error; + + if (current_sum_log_err < best_sum_log_err) { + best_sum_log_err = current_sum_log_err; + best_solution = solution; + best_solution_idx = solution_idx; + best_max_exp_error = current_max_exp_error; } } } - } - // Use the best solution found during opportunistic optimization - if (best_sum_log_err_opportunistic < 1e18f) { - solution = best_solution_opportunistic; - solution_idx = best_solution_idx_opportunistic; - current_cost = best_cost_opportunistic; + local_temp *= 0.95f; } - // --- Final Cost Check and Rollback (if necessary) --- + // --- Final Cost Correction (if needed) --- if (current_cost > max_cost) { - // Revert to the solution before opportunistic optimization - solution = best_solution_opportunistic; - solution_idx = best_solution_idx_opportunistic; - current_cost = best_cost_opportunistic; + std::vector> error_indices(num_slots); + for (int i = 0; i < num_slots; ++i) { + error_indices[i] = {std::get<1>(solution[i]), i}; + 
} + std::sort(error_indices.begin(), error_indices.end()); + + for (const auto& pair : error_indices) { + int i = pair.second; + for (int n = slots[i].size() - 1; n >= 0; --n) { + if (calculate_bpw(slots[i][n]) < calculate_bpw(solution[i])) + { + if (current_cost - std::get<0>(solution[i]) + std::get<0>(slots[i][n]) <= max_cost) + { + uint64_t delta_cost = std::get<0>(slots[i][n]) - std::get<0>(solution[i]); + current_cost += delta_cost; + solution[i] = slots[i][n]; + solution_idx[i] = n; + break; + } + } + } + if (current_cost <= max_cost) break; + } } // Calculate final max error and sum of log errors From 8483153827595e4d4d5c46ca71163b3f06713428 Mon Sep 17 00:00:00 2001 From: imoc Date: Tue, 7 Jan 2025 16:28:10 +0800 Subject: [PATCH 07/17] improvement v3-3 -- sum(log(err)): -839.939039 -- max(err): 0.005954 +1: try to avoid <4 bpw layer(rng) --- exllamav2/exllamav2_ext/ext_quant.cpp | 322 +++++++++++++------------- 1 file changed, 156 insertions(+), 166 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index a9f80edd..ba83813b 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -177,10 +177,9 @@ std::tuple>, std::vector, float, ui const int redistribution_iterations = 25; const float bpw_penalty_scale = 0.01f; const float min_bpw_limit = 2.0f; - const int opportunistic_iterations_stage1 = 5000; - const int opportunistic_iterations_stage2 = 10000; - const float initial_opportunistic_temp = 0.01f; - const float min_exp_error_threshold = 0.001f; + const int opportunistic_iterations = 10000; + const float opportunistic_temp = 0.01f; + const float error_threshold = 0.005f; // Threshold to switch to Stage 2 optimization // --- Original Simulated Annealing --- int num_slots = slots.size(); @@ -336,7 +335,7 @@ std::tuple>, std::vector, float, ui } } - // --- Opportunistic Optimization (Stage 1: Focus on Sum of Log Errors) --- + // --- Multi-Stage Opportunistic 
Optimization --- float current_sum_log_err = 0; for (int i = 0; i < num_slots; ++i) { current_sum_log_err += log(std::get<1>(solution[i])); @@ -345,160 +344,147 @@ std::tuple>, std::vector, float, ui float best_sum_log_err = current_sum_log_err; std::vector> best_solution = solution; std::vector best_solution_idx = solution_idx; - float best_max_exp_error = current_max_exp_error; - - float local_temp = initial_opportunistic_temp; - for (int i = 0; i < opportunistic_iterations_stage1; ++i) { - // Select a neighborhood of slots - int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); - int neighborhood_size = std::min(5, num_slots); - int start_slot = std::max(0, center_slot - neighborhood_size / 2); - int end_slot = std::min(num_slots - 1, center_slot + neighborhood_size / 2); - - // Calculate average BPW in the neighborhood - float neighborhood_bpw_sum = 0; - for (int j = start_slot; j <= end_slot; ++j) { - neighborhood_bpw_sum += calculate_bpw(solution[j]); - } - float neighborhood_bpw_avg = neighborhood_bpw_sum / (end_slot - start_slot + 1); - - // Adjust BPWs within the neighborhood, weighted by error - std::vector> new_solution = solution; - std::vector new_solution_idx = solution_idx; - float new_sum_log_err = current_sum_log_err; - uint64_t new_cost = current_cost; - - for (int j = start_slot; j <= end_slot; ++j) { - float current_bpw = calculate_bpw(solution[j]); - float target_bpw = neighborhood_bpw_avg; - float error = std::get<1>(solution[j]); - float adjustment = 0.125f; - - // Error-weighted adjustment - float error_weight = std::max(0.0f, error - min_exp_error_threshold); - - if (current_bpw < target_bpw) { - // Search for a higher BPW option - for (int n = 0; n < slots[j].size(); ++n) { - auto new_option = slots[j][n]; - if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) - { - if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) - { - new_cost = new_cost - 
std::get<0>(solution[j]) + std::get<0>(new_option); - new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); - new_solution[j] = new_option; - new_solution_idx[j] = n; - break; + + float local_temp = opportunistic_temp; + for (int i = 0; i < opportunistic_iterations; ++i) { + // STAGE 1: Focus on minimizing sum(log(err)) while keeping max(err) reasonable + if (current_max_exp_error > error_threshold) { + // Select a neighborhood of slots + int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); + int neighborhood_size = std::min(7, num_slots); // Larger neighborhood size + int start_slot = std::max(0, center_slot - neighborhood_size / 2); + int end_slot = std::min(num_slots - 1, center_slot + neighborhood_size / 2); + + // Calculate average BPW in the neighborhood + float neighborhood_bpw_sum = 0; + for (int j = start_slot; j <= end_slot; ++j) { + neighborhood_bpw_sum += calculate_bpw(solution[j]); + } + float neighborhood_bpw_avg = neighborhood_bpw_sum / (end_slot - start_slot + 1); + + // Adjust BPWs within the neighborhood + std::vector> new_solution = solution; + std::vector new_solution_idx = solution_idx; + float new_sum_log_err = current_sum_log_err; + uint64_t new_cost = current_cost; + + for (int j = start_slot; j <= end_slot; ++j) { + float current_bpw = calculate_bpw(solution[j]); + float target_bpw = neighborhood_bpw_avg; + float error = std::get<1>(solution[j]); + + // Dynamic BPW adjustment based on error and BPW difference + float adjustment = std::min(0.5f * std::abs(current_bpw - target_bpw) * error, 0.25f); + + // Adjust BPW towards the target, weighted by error + if (current_bpw < target_bpw) { + // Search for a higher BPW option + for (int n = 0; n < slots[j].size(); ++n) { + auto new_option = slots[j][n]; + if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) + { + if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= 
max_cost) + { + new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); + new_sum_log_err = new_sum_log_err - log(std::max(0.001f, std::get<1>(solution[j]))) + log(std::max(0.001f, std::get<1>(new_option))); + new_solution[j] = new_option; + new_solution_idx[j] = n; + break; + } } } - } - } else if (current_bpw > target_bpw) { - // Search for a lower BPW option - for (int n = 0; n < slots[j].size(); ++n) { - auto new_option = slots[j][n]; - if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) - { - if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) - { - new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); - new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); - new_solution[j] = new_option; - new_solution_idx[j] = n; - break; + } else if (current_bpw > target_bpw) { + // Search for a lower BPW option + for (int n = 0; n < slots[j].size(); ++n) { + auto new_option = slots[j][n]; + if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) + { + if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) + { + new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); + new_sum_log_err = new_sum_log_err - log(std::max(0.001f, std::get<1>(solution[j]))) + log(std::max(0.001f, std::get<1>(new_option))); + new_solution[j] = new_option; + new_solution_idx[j] = n; + break; + } } } } } - } - // Calculate new max exp error - float new_max_exp_error = 0; - for (int j = 0; j < num_slots; ++j) { - new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j])); - } - - // Acceptance criterion with a small probability of accepting worse solutions - if (new_cost <= max_cost) { - float delta_sum_log_err = new_sum_log_err - current_sum_log_err; - if (delta_sum_log_err < 0 || std::uniform_real_distribution<>(0, 1)(gen) < 
std::exp(-delta_sum_log_err / local_temp)) { - solution = new_solution; - solution_idx = new_solution_idx; - current_sum_log_err = new_sum_log_err; - current_cost = new_cost; - current_max_exp_error = new_max_exp_error; - - if (current_sum_log_err < best_sum_log_err) { - best_sum_log_err = current_sum_log_err; - best_solution = solution; - best_solution_idx = solution_idx; - best_max_exp_error = current_max_exp_error; - } + // Calculate new max exp error + float new_max_exp_error = 0; + for (int j = 0; j < num_slots; ++j) { + new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j])); } - } - - local_temp *= 0.95f; - } - // --- Opportunistic Optimization (Stage 2: Focus on Max Error and Min BPW) --- - local_temp = initial_opportunistic_temp * 0.1f; // Lower temperature for Stage 2 - for (int i = 0; i < opportunistic_iterations_stage2; ++i) { - // Select a neighborhood of slots - int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); - int neighborhood_size = std::min(5, num_slots); - int start_slot = std::max(0, center_slot - neighborhood_size / 2); - int end_slot = std::min(num_slots - 1, center_slot + neighborhood_size / 2); - - // Calculate average BPW in the neighborhood - float neighborhood_bpw_sum = 0; - for (int j = start_slot; j <= end_slot; ++j) { - neighborhood_bpw_sum += calculate_bpw(solution[j]); - } - float neighborhood_bpw_avg = neighborhood_bpw_sum / (end_slot - start_slot + 1); - - // Adjust BPWs within the neighborhood, weighted by error - std::vector> new_solution = solution; - std::vector new_solution_idx = solution_idx; - float new_sum_log_err = current_sum_log_err; - uint64_t new_cost = current_cost; - - for (int j = start_slot; j <= end_slot; ++j) { - float current_bpw = calculate_bpw(solution[j]); - float target_bpw = neighborhood_bpw_avg; - float error = std::get<1>(solution[j]); - float adjustment = 0.125f; - - // Focus on increasing BPW if below min_bpw_limit - if (current_bpw < min_bpw_limit) { - // 
Search for a higher BPW option - for (int n = 0; n < slots[j].size(); ++n) { - auto new_option = slots[j][n]; - if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) - { - if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) - { - new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); - new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); - new_solution[j] = new_option; - new_solution_idx[j] = n; + // Acceptance criterion with a small probability of accepting worse solutions + if (new_cost <= max_cost) { + float delta_sum_log_err = new_sum_log_err - current_sum_log_err; + if (delta_sum_log_err < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err / local_temp)) { + bool bpw_within_limit = true; + for (int j = 0; j < num_slots; ++j) { + if (calculate_bpw(new_solution[j]) < min_bpw_limit) { + bpw_within_limit = false; break; } } + if (bpw_within_limit) + { + solution = new_solution; + solution_idx = new_solution_idx; + current_sum_log_err = new_sum_log_err; + current_cost = new_cost; + current_max_exp_error = new_max_exp_error; + + if (current_sum_log_err < best_sum_log_err) { + best_sum_log_err = current_sum_log_err; + best_solution = solution; + best_solution_idx = solution_idx; + } + } } - } else { - // Error-weighted adjustment (less aggressive if error is already low) - float error_weight = std::max(0.0f, error - min_exp_error_threshold); + } + } else { + // STAGE 2: Focus on maintaining max(err) and min(bpw) while trying to improve sum(log(err)) + // Select a neighborhood of slots + int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); + int neighborhood_size = std::min(5, num_slots); // Smaller neighborhood size + int start_slot = std::max(0, center_slot - neighborhood_size / 2); + int end_slot = std::min(num_slots - 1, center_slot + neighborhood_size / 2); + + // Calculate 
average BPW in the neighborhood + float neighborhood_bpw_sum = 0; + for (int j = start_slot; j <= end_slot; ++j) { + neighborhood_bpw_sum += calculate_bpw(solution[j]); + } + float neighborhood_bpw_avg = neighborhood_bpw_sum / (end_slot - start_slot + 1); + // Adjust BPWs within the neighborhood + std::vector> new_solution = solution; + std::vector new_solution_idx = solution_idx; + float new_sum_log_err = current_sum_log_err; + uint64_t new_cost = current_cost; + + for (int j = start_slot; j <= end_slot; ++j) { + float current_bpw = calculate_bpw(solution[j]); + float target_bpw = neighborhood_bpw_avg; + float error = std::get<1>(solution[j]); + + // More cautious BPW adjustment + float adjustment = std::min(0.25f * std::abs(current_bpw - target_bpw) * error, 0.125f); + + // Adjust BPW towards the target, weighted by error if (current_bpw < target_bpw) { // Search for a higher BPW option for (int n = 0; n < slots[j].size(); ++n) { auto new_option = slots[j][n]; if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) - { + { if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) { new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); - new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); + new_sum_log_err = new_sum_log_err - log(std::max(0.001f, std::get<1>(solution[j]))) + log(std::max(0.001f, std::get<1>(new_option))); new_solution[j] = new_option; new_solution_idx[j] = n; break; @@ -510,11 +496,11 @@ std::tuple>, std::vector, float, ui for (int n = 0; n < slots[j].size(); ++n) { auto new_option = slots[j][n]; if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) - { + { if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) { new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); - new_sum_log_err = new_sum_log_err - 
log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); + new_sum_log_err = new_sum_log_err - log(std::max(0.001f, std::get<1>(solution[j]))) + log(std::max(0.001f, std::get<1>(new_option))); new_solution[j] = new_option; new_solution_idx[j] = n; break; @@ -523,34 +509,38 @@ std::tuple>, std::vector, float, ui } } } - } - // Calculate new max exp error - float new_max_exp_error = 0; - for (int j = 0; j < num_slots; ++j) { - new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j])); - } + // Calculate new max exp error + float new_max_exp_error = 0; + for (int j = 0; j < num_slots; ++j) { + new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j])); + } - // Acceptance criterion (more emphasis on max error and min BPW) - auto [new_bpw_mean, new_bpw_stddev] = calculate_bpw_stats(new_solution); - if (new_cost <= max_cost && new_bpw_mean >= min_bpw_limit) { - float delta_sum_log_err = new_sum_log_err - current_sum_log_err; - float delta_max_exp_error = new_max_exp_error - current_max_exp_error; - - // Prioritize reducing max error and increasing min BPW - if ((delta_max_exp_error < 0 || (delta_max_exp_error == 0 && delta_sum_log_err < 0)) || - std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-(delta_sum_log_err + delta_max_exp_error * 100.0f) / local_temp)) { - solution = new_solution; - solution_idx = new_solution_idx; - current_sum_log_err = new_sum_log_err; - current_cost = new_cost; - current_max_exp_error = new_max_exp_error; - - if (current_sum_log_err < best_sum_log_err) { - best_sum_log_err = current_sum_log_err; - best_solution = solution; - best_solution_idx = solution_idx; - best_max_exp_error = current_max_exp_error; + // Acceptance criterion with a focus on maintaining max_err and min_bpw + if (new_cost <= max_cost && new_max_exp_error <= current_max_exp_error * 1.05f) { + bool bpw_within_limit = true; + for (int j = 0; j < num_slots; ++j) { + if (calculate_bpw(new_solution[j]) < min_bpw_limit) { + 
bpw_within_limit = false; + break; + } + } + if (bpw_within_limit) + { + float delta_sum_log_err = new_sum_log_err - current_sum_log_err; + if (delta_sum_log_err < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err / (local_temp * 0.1f))) { + solution = new_solution; + solution_idx = new_solution_idx; + current_sum_log_err = new_sum_log_err; + current_cost = new_cost; + current_max_exp_error = new_max_exp_error; + + if (current_sum_log_err < best_sum_log_err) { + best_sum_log_err = current_sum_log_err; + best_solution = solution; + best_solution_idx = solution_idx; + } + } } } } @@ -564,7 +554,7 @@ std::tuple>, std::vector, float, ui for (int i = 0; i < num_slots; ++i) { error_indices[i] = {std::get<1>(solution[i]), i}; } - std::sort(error_indices.begin(), error_indices.end()); + std::sort(error_indices.begin(), error_indices.end()); // Sort by error (ascending) for (const auto& pair : error_indices) { int i = pair.second; From 21a4d9c405366fe1efbefb63708339f8f1692b33 Mon Sep 17 00:00:00 2001 From: imoc Date: Tue, 7 Jan 2025 21:19:52 +0800 Subject: [PATCH 08/17] improvement v3-4 -- sum(log(err)): -840.398832 -- max(err): 0.006020 +1: try to avoid <4 bpw layer(rng) --- exllamav2/exllamav2_ext/ext_quant.cpp | 380 +++++++++++++------------- 1 file changed, 195 insertions(+), 185 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index ba83813b..ef1de72a 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -175,11 +175,12 @@ std::tuple>, std::vector, float, ui { // --- Internal Parameters --- const int redistribution_iterations = 25; - const float bpw_penalty_scale = 0.01f; + const float bpw_penalty_scale = 0.005f; const float min_bpw_limit = 2.0f; - const int opportunistic_iterations = 10000; - const float opportunistic_temp = 0.01f; - const float error_threshold = 0.005f; // Threshold to switch to Stage 2 optimization + const int 
stage1_iterations = 5000; + const int stage2_iterations = 10000; + const float bpw_transfer_step = 0.0625f; + const float low_error_threshold = 0.001f; // Threshold for applying dampening factor // --- Original Simulated Annealing --- int num_slots = slots.size(); @@ -336,6 +337,8 @@ std::tuple>, std::vector, float, ui } // --- Multi-Stage Opportunistic Optimization --- + + // Stage 1: Focus on minimizing sum(log(err)) with a tolerance for max_err float current_sum_log_err = 0; for (int i = 0; i < num_slots; ++i) { current_sum_log_err += log(std::get<1>(solution[i])); @@ -344,206 +347,215 @@ std::tuple>, std::vector, float, ui float best_sum_log_err = current_sum_log_err; std::vector> best_solution = solution; std::vector best_solution_idx = solution_idx; - - float local_temp = opportunistic_temp; - for (int i = 0; i < opportunistic_iterations; ++i) { - // STAGE 1: Focus on minimizing sum(log(err)) while keeping max(err) reasonable - if (current_max_exp_error > error_threshold) { - // Select a neighborhood of slots - int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); - int neighborhood_size = std::min(7, num_slots); // Larger neighborhood size - int start_slot = std::max(0, center_slot - neighborhood_size / 2); - int end_slot = std::min(num_slots - 1, center_slot + neighborhood_size / 2); - - // Calculate average BPW in the neighborhood - float neighborhood_bpw_sum = 0; - for (int j = start_slot; j <= end_slot; ++j) { - neighborhood_bpw_sum += calculate_bpw(solution[j]); - } - float neighborhood_bpw_avg = neighborhood_bpw_sum / (end_slot - start_slot + 1); - - // Adjust BPWs within the neighborhood - std::vector> new_solution = solution; - std::vector new_solution_idx = solution_idx; - float new_sum_log_err = current_sum_log_err; - uint64_t new_cost = current_cost; - - for (int j = start_slot; j <= end_slot; ++j) { - float current_bpw = calculate_bpw(solution[j]); - float target_bpw = neighborhood_bpw_avg; - float error = 
std::get<1>(solution[j]); - - // Dynamic BPW adjustment based on error and BPW difference - float adjustment = std::min(0.5f * std::abs(current_bpw - target_bpw) * error, 0.25f); - - // Adjust BPW towards the target, weighted by error - if (current_bpw < target_bpw) { - // Search for a higher BPW option - for (int n = 0; n < slots[j].size(); ++n) { - auto new_option = slots[j][n]; - if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) - { - if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) - { - new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); - new_sum_log_err = new_sum_log_err - log(std::max(0.001f, std::get<1>(solution[j]))) + log(std::max(0.001f, std::get<1>(new_option))); - new_solution[j] = new_option; - new_solution_idx[j] = n; - break; - } + float stage1_max_exp_error = current_max_exp_error; + + float local_temp = initial_temp; + for (int i = 0; i < stage1_iterations; ++i) { + + // Select a neighborhood of slots + int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); + int neighborhood_size = std::min(5, num_slots); // Example neighborhood size + int start_slot = std::max(0, center_slot - neighborhood_size / 2); + int end_slot = std::min(num_slots - 1, center_slot + neighborhood_size / 2); + + // Calculate average BPW in the neighborhood + float neighborhood_bpw_sum = 0; + for (int j = start_slot; j <= end_slot; ++j) { + neighborhood_bpw_sum += calculate_bpw(solution[j]); + } + float neighborhood_bpw_avg = neighborhood_bpw_sum / (end_slot - start_slot + 1); + + // Adjust BPWs within the neighborhood + std::vector> new_solution = solution; + std::vector new_solution_idx = solution_idx; + float new_sum_log_err = current_sum_log_err; + uint64_t new_cost = current_cost; + + for (int j = start_slot; j <= end_slot; ++j) { + float current_bpw = calculate_bpw(solution[j]); + float target_bpw = neighborhood_bpw_avg; + float error = 
std::get<1>(solution[j]); + float adjustment = bpw_transfer_step; + + // Adjust BPW towards the target, weighted by error + if (current_bpw < target_bpw) { + // Search for a higher BPW option + for (int n = 0; n < slots[j].size(); ++n) { + auto new_option = slots[j][n]; + if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) + { + if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) + { + + new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); + new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); + new_solution[j] = new_option; + new_solution_idx[j] = n; + break; } } - } else if (current_bpw > target_bpw) { - // Search for a lower BPW option - for (int n = 0; n < slots[j].size(); ++n) { - auto new_option = slots[j][n]; - if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) - { - if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) - { - new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); - new_sum_log_err = new_sum_log_err - log(std::max(0.001f, std::get<1>(solution[j]))) + log(std::max(0.001f, std::get<1>(new_option))); - new_solution[j] = new_option; - new_solution_idx[j] = n; - break; - } + } + } else if (current_bpw > target_bpw) { + // Search for a lower BPW option + for (int n = 0; n < slots[j].size(); ++n) { + auto new_option = slots[j][n]; + if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) + { + if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) + { + new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); + new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); + new_solution[j] = new_option; + new_solution_idx[j] = n; + break; } } } } + } + + // Calculate new max exp error + float 
new_max_exp_error = 0; + for (int j = 0; j < num_slots; ++j) { + new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j])); + } + + // Dampen the influence of very low errors + float new_sum_log_err_dampened = 0; + for (int j = 0; j < num_slots; ++j) { + float error = std::get<1>(new_solution[j]); + float dampening_factor = (error < low_error_threshold) ? (error / low_error_threshold) : 1.0f; + new_sum_log_err_dampened += log(error) * dampening_factor; + } - // Calculate new max exp error - float new_max_exp_error = 0; - for (int j = 0; j < num_slots; ++j) { - new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j])); + // Acceptance criterion with tolerance for max_err increase + if (new_cost <= max_cost && calculate_bpw(new_solution[i]) >= min_bpw_limit) + { + float delta_sum_log_err = new_sum_log_err_dampened - current_sum_log_err; + if (delta_sum_log_err < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err / local_temp)) { + solution = new_solution; + solution_idx = new_solution_idx; + current_sum_log_err = new_sum_log_err_dampened; // Use dampened sum for acceptance + current_cost = new_cost; + current_max_exp_error = new_max_exp_error; + + if (current_sum_log_err < best_sum_log_err) { + best_sum_log_err = current_sum_log_err; + best_solution = solution; + best_solution_idx = solution_idx; + stage1_max_exp_error = current_max_exp_error; // Update stage1_max_exp_error + } } + } + + local_temp *= 0.95f; + } - // Acceptance criterion with a small probability of accepting worse solutions - if (new_cost <= max_cost) { - float delta_sum_log_err = new_sum_log_err - current_sum_log_err; - if (delta_sum_log_err < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err / local_temp)) { - bool bpw_within_limit = true; - for (int j = 0; j < num_slots; ++j) { - if (calculate_bpw(new_solution[j]) < min_bpw_limit) { - bpw_within_limit = false; + // Stage 2: Focus on reducing max_err 
while still considering sum(log(err)) + float max_err_penalty_scale = 1.0f; // Initial penalty scale for max_err + solution = best_solution; + solution_idx = best_solution_idx; + current_sum_log_err = best_sum_log_err; + current_max_exp_error = stage1_max_exp_error; + + for (int i = 0; i < stage2_iterations; ++i) { + // Select a neighborhood of slots + int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); + int neighborhood_size = std::min(5, num_slots); + int start_slot = std::max(0, center_slot - neighborhood_size / 2); + int end_slot = std::min(num_slots - 1, center_slot + neighborhood_size / 2); + + // Calculate average BPW in the neighborhood + float neighborhood_bpw_sum = 0; + for (int j = start_slot; j <= end_slot; ++j) { + neighborhood_bpw_sum += calculate_bpw(solution[j]); + } + float neighborhood_bpw_avg = neighborhood_bpw_sum / (end_slot - start_slot + 1); + + // Adjust BPWs within the neighborhood + std::vector> new_solution = solution; + std::vector new_solution_idx = solution_idx; + float new_sum_log_err = current_sum_log_err; + uint64_t new_cost = current_cost; + + for (int j = start_slot; j <= end_slot; ++j) { + float current_bpw = calculate_bpw(solution[j]); + float target_bpw = neighborhood_bpw_avg; + float error = std::get<1>(solution[j]); + float adjustment = bpw_transfer_step; + + // Adjust BPW towards the target, weighted by error + if (current_bpw < target_bpw) { + for (int n = 0; n < slots[j].size(); ++n) { + auto new_option = slots[j][n]; + if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) { + if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) { + new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); + new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); + new_solution[j] = new_option; + new_solution_idx[j] = n; break; } } - if (bpw_within_limit) - { - solution = new_solution; - 
solution_idx = new_solution_idx; - current_sum_log_err = new_sum_log_err; - current_cost = new_cost; - current_max_exp_error = new_max_exp_error; - - if (current_sum_log_err < best_sum_log_err) { - best_sum_log_err = current_sum_log_err; - best_solution = solution; - best_solution_idx = solution_idx; - } - } } - } - } else { - // STAGE 2: Focus on maintaining max(err) and min(bpw) while trying to improve sum(log(err)) - // Select a neighborhood of slots - int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); - int neighborhood_size = std::min(5, num_slots); // Smaller neighborhood size - int start_slot = std::max(0, center_slot - neighborhood_size / 2); - int end_slot = std::min(num_slots - 1, center_slot + neighborhood_size / 2); - - // Calculate average BPW in the neighborhood - float neighborhood_bpw_sum = 0; - for (int j = start_slot; j <= end_slot; ++j) { - neighborhood_bpw_sum += calculate_bpw(solution[j]); - } - float neighborhood_bpw_avg = neighborhood_bpw_sum / (end_slot - start_slot + 1); - - // Adjust BPWs within the neighborhood - std::vector> new_solution = solution; - std::vector new_solution_idx = solution_idx; - float new_sum_log_err = current_sum_log_err; - uint64_t new_cost = current_cost; - - for (int j = start_slot; j <= end_slot; ++j) { - float current_bpw = calculate_bpw(solution[j]); - float target_bpw = neighborhood_bpw_avg; - float error = std::get<1>(solution[j]); - - // More cautious BPW adjustment - float adjustment = std::min(0.25f * std::abs(current_bpw - target_bpw) * error, 0.125f); - - // Adjust BPW towards the target, weighted by error - if (current_bpw < target_bpw) { - // Search for a higher BPW option - for (int n = 0; n < slots[j].size(); ++n) { - auto new_option = slots[j][n]; - if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) - { - if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) - { - new_cost = new_cost - 
std::get<0>(solution[j]) + std::get<0>(new_option); - new_sum_log_err = new_sum_log_err - log(std::max(0.001f, std::get<1>(solution[j]))) + log(std::max(0.001f, std::get<1>(new_option))); - new_solution[j] = new_option; - new_solution_idx[j] = n; - break; - } - } - } - } else if (current_bpw > target_bpw) { - // Search for a lower BPW option - for (int n = 0; n < slots[j].size(); ++n) { - auto new_option = slots[j][n]; - if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) - { - if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) - { - new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); - new_sum_log_err = new_sum_log_err - log(std::max(0.001f, std::get<1>(solution[j]))) + log(std::max(0.001f, std::get<1>(new_option))); - new_solution[j] = new_option; - new_solution_idx[j] = n; - break; - } + } else if (current_bpw > target_bpw) { + for (int n = 0; n < slots[j].size(); ++n) { + auto new_option = slots[j][n]; + if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) { + if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) { + new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); + new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); + new_solution[j] = new_option; + new_solution_idx[j] = n; + break; } } } } + } - // Calculate new max exp error - float new_max_exp_error = 0; - for (int j = 0; j < num_slots; ++j) { - new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j])); - } + // Calculate new max exp error + float new_max_exp_error = 0; + for (int j = 0; j < num_slots; ++j) { + new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j])); + } - // Acceptance criterion with a focus on maintaining max_err and min_bpw - if (new_cost <= max_cost && new_max_exp_error <= current_max_exp_error * 1.05f) 
{ - bool bpw_within_limit = true; - for (int j = 0; j < num_slots; ++j) { - if (calculate_bpw(new_solution[j]) < min_bpw_limit) { - bpw_within_limit = false; - break; - } - } - if (bpw_within_limit) - { - float delta_sum_log_err = new_sum_log_err - current_sum_log_err; - if (delta_sum_log_err < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err / (local_temp * 0.1f))) { - solution = new_solution; - solution_idx = new_solution_idx; - current_sum_log_err = new_sum_log_err; - current_cost = new_cost; - current_max_exp_error = new_max_exp_error; + // Dynamic penalty for max_err + float max_err_penalty = max_err_penalty_scale * new_max_exp_error; - if (current_sum_log_err < best_sum_log_err) { - best_sum_log_err = current_sum_log_err; - best_solution = solution; - best_solution_idx = solution_idx; - } - } + // Dampen the influence of very low errors + float new_sum_log_err_dampened = 0; + for (int j = 0; j < num_slots; ++j) { + float error = std::get<1>(new_solution[j]); + float dampening_factor = (error < low_error_threshold) ? 
(error / low_error_threshold) : 1.0f; + new_sum_log_err_dampened += log(error) * dampening_factor; + } + + // Acceptance criterion considering both sum(log(err)) and max_err + if (new_cost <= max_cost && calculate_bpw(new_solution[i]) >= min_bpw_limit) + { + float delta_err = (new_sum_log_err_dampened + max_err_penalty) - (current_sum_log_err + max_err_penalty_scale * current_max_exp_error); + if (delta_err < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_err / local_temp)) { + solution = new_solution; + solution_idx = new_solution_idx; + current_sum_log_err = new_sum_log_err_dampened; + current_cost = new_cost; + current_max_exp_error = new_max_exp_error; + + if (current_sum_log_err < best_sum_log_err) { + best_sum_log_err = current_sum_log_err; + best_solution = solution; + best_solution_idx = solution_idx; } } } + + // Update max_err_penalty_scale based on current_max_exp_error + if (current_max_exp_error < low_error_threshold) { + max_err_penalty_scale = std::min(1000.0f, max_err_penalty_scale * 1.1f); // Increase penalty if max_err is low + } else { + max_err_penalty_scale = std::max(0.01f, max_err_penalty_scale * 0.9f); // Decrease penalty if max_err is high + } local_temp *= 0.95f; } @@ -559,10 +571,8 @@ std::tuple>, std::vector, float, ui for (const auto& pair : error_indices) { int i = pair.second; for (int n = slots[i].size() - 1; n >= 0; --n) { - if (calculate_bpw(slots[i][n]) < calculate_bpw(solution[i])) - { - if (current_cost - std::get<0>(solution[i]) + std::get<0>(slots[i][n]) <= max_cost) - { + if (calculate_bpw(slots[i][n]) < calculate_bpw(solution[i])) { + if (current_cost - std::get<0>(solution[i]) + std::get<0>(slots[i][n]) <= max_cost) { uint64_t delta_cost = std::get<0>(slots[i][n]) - std::get<0>(solution[i]); current_cost += delta_cost; solution[i] = slots[i][n]; From 8918b24bf6a98fbf8c1fde725800672677df7ca6 Mon Sep 17 00:00:00 2001 From: imoc Date: Tue, 7 Jan 2025 22:35:11 +0800 Subject: [PATCH 09/17] improvement v3-5 -- 
sum(log(err)): -840.932717 -- max(err): 0.005954 +1: try to avoid <4 bpw layer(rng) --- exllamav2/exllamav2_ext/ext_quant.cpp | 229 ++++++++------------------ 1 file changed, 71 insertions(+), 158 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index ef1de72a..53964f80 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -175,12 +175,11 @@ std::tuple>, std::vector, float, ui { // --- Internal Parameters --- const int redistribution_iterations = 25; - const float bpw_penalty_scale = 0.005f; + const float bpw_penalty_scale = 0.01f; const float min_bpw_limit = 2.0f; - const int stage1_iterations = 5000; - const int stage2_iterations = 10000; - const float bpw_transfer_step = 0.0625f; - const float low_error_threshold = 0.001f; // Threshold for applying dampening factor + const int opportunistic_iterations = 10000; + const float initial_opportunistic_temp = 0.01f; + const float low_error_threshold = 0.0009f; // --- Original Simulated Annealing --- int num_slots = slots.size(); @@ -277,25 +276,35 @@ std::tuple>, std::vector, float, ui for (int low_idx : low_bpw_indices) { if (high_bpw_indices.empty()) break; - int high_idx = high_bpw_indices[std::uniform_int_distribution<>(0, high_bpw_indices.size() - 1)(gen)]; + // Error-weighted selection of high_idx + std::vector high_bpw_errors; + for (int high_idx : high_bpw_indices) { + high_bpw_errors.push_back(std::get<1>(solution[high_idx])); + } + std::discrete_distribution high_idx_dist(high_bpw_errors.begin(), high_bpw_errors.end()); + int high_idx = high_bpw_indices[high_idx_dist(gen)]; + // Find a higher BPW option for the low-BPW slot, with bias towards lower error int best_low_new_idx = -1; float best_low_new_error = 1e10f; for (int n = 0; n < slots[low_idx].size(); ++n) { if (calculate_bpw(slots[low_idx][n]) > calculate_bpw(solution[low_idx])) { - if (std::get<1>(slots[low_idx][n]) < best_low_new_error) { - 
best_low_new_error = std::get<1>(slots[low_idx][n]); + float error_factor = 1.0f + std::get<1>(slots[low_idx][n]); + if (error_factor * std::get<1>(slots[low_idx][n]) < best_low_new_error) { + best_low_new_error = error_factor * std::get<1>(slots[low_idx][n]); best_low_new_idx = n; } } } + // Find a lower BPW option for the high-BPW slot, with bias towards lower error int best_high_new_idx = -1; float best_high_new_error = 1e10f; for (int n = 0; n < slots[high_idx].size(); ++n) { if (calculate_bpw(slots[high_idx][n]) < calculate_bpw(solution[high_idx])) { - if (std::get<1>(slots[high_idx][n]) < best_high_new_error) { - best_high_new_error = std::get<1>(slots[high_idx][n]); + float error_factor = 1.0f + std::get<1>(slots[high_idx][n]); + if (error_factor * std::get<1>(slots[high_idx][n]) < best_high_new_error) { + best_high_new_error = error_factor * std::get<1>(slots[high_idx][n]); best_high_new_idx = n; } } @@ -336,9 +345,7 @@ std::tuple>, std::vector, float, ui } } - // --- Multi-Stage Opportunistic Optimization --- - - // Stage 1: Focus on minimizing sum(log(err)) with a tolerance for max_err + // --- Opportunistic Optimization with Simulated Annealing --- float current_sum_log_err = 0; for (int i = 0; i < num_slots; ++i) { current_sum_log_err += log(std::get<1>(solution[i])); @@ -347,11 +354,9 @@ std::tuple>, std::vector, float, ui float best_sum_log_err = current_sum_log_err; std::vector> best_solution = solution; std::vector best_solution_idx = solution_idx; - float stage1_max_exp_error = current_max_exp_error; - float local_temp = initial_temp; - for (int i = 0; i < stage1_iterations; ++i) { - + float local_temp = initial_opportunistic_temp; + for (int i = 0; i < opportunistic_iterations; ++i) { // Select a neighborhood of slots int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); int neighborhood_size = std::min(5, num_slots); // Example neighborhood size @@ -374,122 +379,25 @@ std::tuple>, std::vector, float, ui for (int j = 
start_slot; j <= end_slot; ++j) { float current_bpw = calculate_bpw(solution[j]); float target_bpw = neighborhood_bpw_avg; - float error = std::get<1>(solution[j]); - float adjustment = bpw_transfer_step; - - // Adjust BPW towards the target, weighted by error - if (current_bpw < target_bpw) { - // Search for a higher BPW option - for (int n = 0; n < slots[j].size(); ++n) { - auto new_option = slots[j][n]; - if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) - { - if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) - { - - new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); - new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); - new_solution[j] = new_option; - new_solution_idx[j] = n; - break; - } - } - } - } else if (current_bpw > target_bpw) { - // Search for a lower BPW option - for (int n = 0; n < slots[j].size(); ++n) { - auto new_option = slots[j][n]; - if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) - { - if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) - { - new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); - new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); - new_solution[j] = new_option; - new_solution_idx[j] = n; - break; - } - } - } - } - } - - // Calculate new max exp error - float new_max_exp_error = 0; - for (int j = 0; j < num_slots; ++j) { - new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j])); - } - // Dampen the influence of very low errors - float new_sum_log_err_dampened = 0; - for (int j = 0; j < num_slots; ++j) { - float error = std::get<1>(new_solution[j]); - float dampening_factor = (error < low_error_threshold) ? 
(error / low_error_threshold) : 1.0f; - new_sum_log_err_dampened += log(error) * dampening_factor; - } - - // Acceptance criterion with tolerance for max_err increase - if (new_cost <= max_cost && calculate_bpw(new_solution[i]) >= min_bpw_limit) - { - float delta_sum_log_err = new_sum_log_err_dampened - current_sum_log_err; - if (delta_sum_log_err < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err / local_temp)) { - solution = new_solution; - solution_idx = new_solution_idx; - current_sum_log_err = new_sum_log_err_dampened; // Use dampened sum for acceptance - current_cost = new_cost; - current_max_exp_error = new_max_exp_error; - - if (current_sum_log_err < best_sum_log_err) { - best_sum_log_err = current_sum_log_err; - best_solution = solution; - best_solution_idx = solution_idx; - stage1_max_exp_error = current_max_exp_error; // Update stage1_max_exp_error - } + // Error-weighted adjustment + float avg_error = 0; + for (int k = start_slot; k <= end_slot; ++k) { + avg_error += std::get<1>(solution[k]); } - } - - local_temp *= 0.95f; - } - - // Stage 2: Focus on reducing max_err while still considering sum(log(err)) - float max_err_penalty_scale = 1.0f; // Initial penalty scale for max_err - solution = best_solution; - solution_idx = best_solution_idx; - current_sum_log_err = best_sum_log_err; - current_max_exp_error = stage1_max_exp_error; - - for (int i = 0; i < stage2_iterations; ++i) { - // Select a neighborhood of slots - int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); - int neighborhood_size = std::min(5, num_slots); - int start_slot = std::max(0, center_slot - neighborhood_size / 2); - int end_slot = std::min(num_slots - 1, center_slot + neighborhood_size / 2); - - // Calculate average BPW in the neighborhood - float neighborhood_bpw_sum = 0; - for (int j = start_slot; j <= end_slot; ++j) { - neighborhood_bpw_sum += calculate_bpw(solution[j]); - } - float neighborhood_bpw_avg = neighborhood_bpw_sum / 
(end_slot - start_slot + 1); - - // Adjust BPWs within the neighborhood - std::vector> new_solution = solution; - std::vector new_solution_idx = solution_idx; - float new_sum_log_err = current_sum_log_err; - uint64_t new_cost = current_cost; + avg_error /= (end_slot - start_slot + 1); + float error_ratio = std::get<1>(solution[j]) / avg_error; - for (int j = start_slot; j <= end_slot; ++j) { - float current_bpw = calculate_bpw(solution[j]); - float target_bpw = neighborhood_bpw_avg; - float error = std::get<1>(solution[j]); - float adjustment = bpw_transfer_step; + float adjustment = 0.125f; // Adjust BPW towards the target, weighted by error if (current_bpw < target_bpw) { + // Search for a higher BPW option for (int n = 0; n < slots[j].size(); ++n) { auto new_option = slots[j][n]; if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) { - if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) { + if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) + { new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); new_solution[j] = new_option; @@ -499,10 +407,12 @@ std::tuple>, std::vector, float, ui } } } else if (current_bpw > target_bpw) { + // Search for a lower BPW option for (int n = 0; n < slots[j].size(); ++n) { auto new_option = slots[j][n]; if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) { - if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) { + if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) + { new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); new_solution[j] = new_option; @@ -520,59 +430,62 @@ std::tuple>, 
std::vector, float, ui new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j])); } - // Dynamic penalty for max_err - float max_err_penalty = max_err_penalty_scale * new_max_exp_error; + // Acceptance criterion with a small probability of accepting worse solutions + bool accept = false; + float delta_sum_log_err = new_sum_log_err - current_sum_log_err; - // Dampen the influence of very low errors - float new_sum_log_err_dampened = 0; - for (int j = 0; j < num_slots; ++j) { - float error = std::get<1>(new_solution[j]); - float dampening_factor = (error < low_error_threshold) ? (error / low_error_threshold) : 1.0f; - new_sum_log_err_dampened += log(error) * dampening_factor; + // Dampen penalty for low errors + float error_factor = 1.0f; + if (current_max_exp_error < low_error_threshold) { + error_factor = 0.1f; // Reduce the weight of sum_log_err } - // Acceptance criterion considering both sum(log(err)) and max_err - if (new_cost <= max_cost && calculate_bpw(new_solution[i]) >= min_bpw_limit) - { - float delta_err = (new_sum_log_err_dampened + max_err_penalty) - (current_sum_log_err + max_err_penalty_scale * current_max_exp_error); - if (delta_err < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_err / local_temp)) { - solution = new_solution; - solution_idx = new_solution_idx; - current_sum_log_err = new_sum_log_err_dampened; - current_cost = new_cost; - current_max_exp_error = new_max_exp_error; - - if (current_sum_log_err < best_sum_log_err) { - best_sum_log_err = current_sum_log_err; - best_solution = solution; - best_solution_idx = solution_idx; - } + if (new_cost <= max_cost && calculate_bpw(new_solution[i]) >= min_bpw_limit) { + if (delta_sum_log_err * error_factor < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err * error_factor / local_temp)) { + accept = true; + } + // Give high priority to the solution that has high minimum BPW + if (calculate_bpw(new_solution[i]) < min_bpw_limit) { + 
accept = false; } } - - // Update max_err_penalty_scale based on current_max_exp_error - if (current_max_exp_error < low_error_threshold) { - max_err_penalty_scale = std::min(1000.0f, max_err_penalty_scale * 1.1f); // Increase penalty if max_err is low - } else { - max_err_penalty_scale = std::max(0.01f, max_err_penalty_scale * 0.9f); // Decrease penalty if max_err is high + + if (accept) { + solution = new_solution; + solution_idx = new_solution_idx; + current_sum_log_err = new_sum_log_err; + current_cost = new_cost; + current_max_exp_error = new_max_exp_error; + + if (current_sum_log_err < best_sum_log_err) { + best_sum_log_err = current_sum_log_err; + best_solution = solution; + best_solution_idx = solution_idx; + } } local_temp *= 0.95f; } - // --- Final Cost Correction (if needed) --- + // Use the best solution found during opportunistic optimization + solution = best_solution; + solution_idx = best_solution_idx; + current_sum_log_err = best_sum_log_err; + + // --- Final Cost Check and Rollback (if necessary) --- if (current_cost > max_cost) { std::vector> error_indices(num_slots); for (int i = 0; i < num_slots; ++i) { error_indices[i] = {std::get<1>(solution[i]), i}; } - std::sort(error_indices.begin(), error_indices.end()); // Sort by error (ascending) + std::sort(error_indices.begin(), error_indices.end()); for (const auto& pair : error_indices) { int i = pair.second; for (int n = slots[i].size() - 1; n >= 0; --n) { if (calculate_bpw(slots[i][n]) < calculate_bpw(solution[i])) { - if (current_cost - std::get<0>(solution[i]) + std::get<0>(slots[i][n]) <= max_cost) { + if (current_cost - std::get<0>(solution[i]) + std::get<0>(slots[i][n]) <= max_cost) + { uint64_t delta_cost = std::get<0>(slots[i][n]) - std::get<0>(solution[i]); current_cost += delta_cost; solution[i] = slots[i][n]; From 228262169dd33de9f8622d12ec4b4ea54cc8260c Mon Sep 17 00:00:00 2001 From: imoc Date: Wed, 8 Jan 2025 22:17:21 +0800 Subject: [PATCH 10/17] improvement v3-5-1 72B: -- 
sum(log(err)): -884.396164 -- max(err): 0.017426 vs OGB: -- sum(log(err)): -842.360744 -- max(err): 0.018692 --- exllamav2/exllamav2_ext/ext_quant.cpp | 123 ++++++++++++++++++++------ 1 file changed, 95 insertions(+), 28 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index 53964f80..670fbd9f 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -175,12 +175,36 @@ std::tuple>, std::vector, float, ui { // --- Internal Parameters --- const int redistribution_iterations = 25; - const float bpw_penalty_scale = 0.01f; - const float min_bpw_limit = 2.0f; + const float bpw_penalty_scale = 0.05f; // Increased penalty + const float min_bpw_base = 2.8f; // Absolute minimum BPW const int opportunistic_iterations = 10000; const float initial_opportunistic_temp = 0.01f; const float low_error_threshold = 0.0009f; + // --- Dynamic Minimum BPW --- + auto calculate_dynamic_min_bpw = [&](float target_bpw, float temp_ratio) { + float scaled_min_bpw = min_bpw_base + 0.5f * (target_bpw - min_bpw_base); + return min_bpw_base + temp_ratio * (scaled_min_bpw - min_bpw_base); + }; + + // --- Calculate BPW --- + auto calculate_bpw = [&](const std::tuple& option) { + return 8.0f * std::get<0>(option) / 1024.0f; + }; + + // --- Calculate BPW stats --- + auto calculate_bpw_stats = [&](const std::vector>& sol) { + int num_slots = sol.size(); + std::vector current_bpws(num_slots); + for (int i = 0; i < num_slots; ++i) { + current_bpws[i] = calculate_bpw(sol[i]); + } + float bpw_mean = std::accumulate(current_bpws.begin(), current_bpws.end(), 0.0f) / num_slots; + float bpw_sq_sum = std::inner_product(current_bpws.begin(), current_bpws.end(), current_bpws.begin(), 0.0f); + float bpw_variance = bpw_sq_sum / num_slots - bpw_mean * bpw_mean; + return std::make_pair(bpw_mean, std::sqrt(std::max(0.0f, bpw_variance))); + }; + // --- Original Simulated Annealing --- int num_slots = slots.size(); @@ -194,6 
+218,7 @@ std::tuple>, std::vector, float, ui float temp = initial_temp; int iterations_outer = static_cast(std::log(min_temp / temp) / std::log(cooling_factor)); + float target_bpw = max_cost * 8.0f / 1024.0f / num_slots; for (int i = 0; i < num_slots; ++i) { @@ -204,6 +229,9 @@ std::tuple>, std::vector, float, ui for (int j = 0; j < iterations_outer; ++j) { + float temp_ratio = temp / initial_temp; + float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio); + for (int k = 0; k < iterations; ++k) { int i = std::uniform_int_distribution<>(0, num_slots - 1)(gen); @@ -224,10 +252,18 @@ std::tuple>, std::vector, float, ui new_max_exp_error = std::max(current_max_exp_error, std::get<1>(new_option)); } + // BPW Penalty (Dynamic and Temperature-Dependent) + float bpw_new = calculate_bpw(new_option); + float bpw_penalty = 0.0f; + + if (bpw_new < min_bpw_limit) { + bpw_penalty = (min_bpw_limit - bpw_new) * bpw_penalty_scale * (1 + temp_ratio); // Stronger penalty at higher temp + } + if (current_cost + delta_cost <= max_cost || (delta_cost < 0 && current_cost > max_cost)) { - if (delta_e < 0 || - std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_e / temp)) + if (delta_e + bpw_penalty < 0 || + std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-(delta_e + bpw_penalty) / temp)) { solution[i] = new_option; solution_idx[i] = n; @@ -240,22 +276,11 @@ std::tuple>, std::vector, float, ui } // --- Post-processing: Bit Redistribution --- - auto calculate_bpw = [&](const std::tuple& option) { - return 8.0f * std::get<0>(option) / 1024.0f; - }; - - auto calculate_bpw_stats = [&](const std::vector>& sol) { - std::vector current_bpws(num_slots); - for (int i = 0; i < num_slots; ++i) { - current_bpws[i] = calculate_bpw(sol[i]); - } - float bpw_mean = std::accumulate(current_bpws.begin(), current_bpws.end(), 0.0f) / num_slots; - float bpw_sq_sum = std::inner_product(current_bpws.begin(), current_bpws.end(), current_bpws.begin(), 0.0f); - float bpw_variance 
= bpw_sq_sum / num_slots - bpw_mean * bpw_mean; - return std::make_pair(bpw_mean, std::sqrt(std::max(0.0f, bpw_variance))); - }; for (int r = 0; r < redistribution_iterations; ++r) { + float temp_ratio = temp / initial_temp; + float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio); + // Calculate BPW statistics and dynamic bpw_threshold auto [bpw_mean, bpw_stddev] = calculate_bpw_stats(solution); float bpw_threshold = std::max(min_bpw_limit, bpw_mean - 0.5f * bpw_stddev); @@ -357,6 +382,9 @@ std::tuple>, std::vector, float, ui float local_temp = initial_opportunistic_temp; for (int i = 0; i < opportunistic_iterations; ++i) { + float temp_ratio = temp / initial_temp; + float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio); + // Select a neighborhood of slots int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); int neighborhood_size = std::min(5, num_slots); // Example neighborhood size @@ -380,7 +408,7 @@ std::tuple>, std::vector, float, ui float current_bpw = calculate_bpw(solution[j]); float target_bpw = neighborhood_bpw_avg; - // Error-weighted adjustment + // Error-weighted adjustment with bias towards higher BPW float avg_error = 0; for (int k = start_slot; k <= end_slot; ++k) { avg_error += std::get<1>(solution[k]); @@ -388,10 +416,10 @@ std::tuple>, std::vector, float, ui avg_error /= (end_slot - start_slot + 1); float error_ratio = std::get<1>(solution[j]) / avg_error; - float adjustment = 0.125f; + float adjustment = 0.25f + 0.25f * error_ratio; // Increased adjustment with bias - // Adjust BPW towards the target, weighted by error - if (current_bpw < target_bpw) { + // Adjust BPW towards the target, weighted by error, with a bias towards higher BPW + if (current_bpw < target_bpw + adjustment) { // Bias towards higher BPW // Search for a higher BPW option for (int n = 0; n < slots[j].size(); ++n) { auto new_option = slots[j][n]; @@ -408,7 +436,7 @@ std::tuple>, std::vector, float, ui } } else if 
(current_bpw > target_bpw) { // Search for a lower BPW option - for (int n = 0; n < slots[j].size(); ++n) { + for (int n = slots[j].size() - 1; n >= 0; --n) { // Iterate in reverse order auto new_option = slots[j][n]; if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) { if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) @@ -440,13 +468,16 @@ std::tuple>, std::vector, float, ui error_factor = 0.1f; // Reduce the weight of sum_log_err } - if (new_cost <= max_cost && calculate_bpw(new_solution[i]) >= min_bpw_limit) { + if (new_cost <= max_cost) { if (delta_sum_log_err * error_factor < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err * error_factor / local_temp)) { accept = true; - } - // Give high priority to the solution that has high minimum BPW - if (calculate_bpw(new_solution[i]) < min_bpw_limit) { - accept = false; + // Further penalize if below min_bpw_limit + for (int j = 0; j < num_slots; ++j) { + if (calculate_bpw(new_solution[j]) < min_bpw_limit) { + accept = false; + break; + } + } } } @@ -472,6 +503,42 @@ std::tuple>, std::vector, float, ui solution_idx = best_solution_idx; current_sum_log_err = best_sum_log_err; + // --- BPW Smoothing (Post-processing) --- + for (int i = 1; i < num_slots - 1; ++i) { + float current_bpw = calculate_bpw(solution[i]); + float prev_bpw = calculate_bpw(solution[i - 1]); + float next_bpw = calculate_bpw(solution[i + 1]); + float avg_neighbor_bpw = (prev_bpw + next_bpw) / 2.0f; + + if (current_bpw < avg_neighbor_bpw - 0.5f) { // Significant difference + // Find a higher BPW option for the current slot + for (int n = 0; n < slots[i].size(); ++n) { + auto new_option = slots[i][n]; + if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= avg_neighbor_bpw) { + if (current_cost - std::get<0>(solution[i]) + std::get<0>(new_option) <= max_cost) { + // Check if the new option doesn't significantly increase 
max_err + float new_max_err = 0; + for (int j = 0; j < num_slots; ++j) { + if (j == i) { + new_max_err = std::max(new_max_err, std::get<1>(new_option)); + } else { + new_max_err = std::max(new_max_err, std::get<1>(solution[j])); + } + } + + if (new_max_err < current_max_exp_error * 1.1f) { // Allow a small increase in max_err + current_cost = current_cost - std::get<0>(solution[i]) + std::get<0>(new_option); + solution[i] = new_option; + solution_idx[i] = n; + current_max_exp_error = new_max_err; + break; + } + } + } + } + } + } + // --- Final Cost Check and Rollback (if necessary) --- if (current_cost > max_cost) { std::vector> error_indices(num_slots); From 556d0e49ed111dae671a2c14817879b4ae4b60b9 Mon Sep 17 00:00:00 2001 From: imoc Date: Thu, 9 Jan 2025 09:34:49 +0800 Subject: [PATCH 11/17] improvement v3-5-2 72B.E: -- sum(log(err)): -880.320887 -- max(err): 0.019905 calibration perplexity (quant): 11.1816 --- exllamav2/exllamav2_ext/ext_quant.cpp | 225 ++++++++++++++++++-------- 1 file changed, 154 insertions(+), 71 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index 670fbd9f..c8753fb3 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -175,11 +175,16 @@ std::tuple>, std::vector, float, ui { // --- Internal Parameters --- const int redistribution_iterations = 25; - const float bpw_penalty_scale = 0.05f; // Increased penalty - const float min_bpw_base = 2.8f; // Absolute minimum BPW - const int opportunistic_iterations = 10000; - const float initial_opportunistic_temp = 0.01f; + const float bpw_penalty_scale = 0.1f; // Further increased BPW penalty + const float min_bpw_base = 3.0f; + const int opportunistic_iterations = 15000; // Increased iterations + const float initial_opportunistic_temp = 0.05f; // Higher initial temperature for opportunistic optimization const float low_error_threshold = 0.0009f; + const float targeted_redistribution_bpw_threshold = 
3.3f; + const float targeted_redistribution_max_err_increase_initial = 1.2f; // Even more initial tolerance for error increase + const float targeted_redistribution_max_err_increase_final = 1.02f; + const float high_bpw_donor_threshold = 5.0f; + const int num_options_to_explore_per_layer = 3; // Explore multiple higher-bpw options in targeted redistribution // --- Dynamic Minimum BPW --- auto calculate_dynamic_min_bpw = [&](float target_bpw, float temp_ratio) { @@ -216,7 +221,7 @@ std::tuple>, std::vector, float, ui uint64_t current_cost = 0; float current_max_exp_error = 0; - float temp = initial_temp; + float temp = initial_temp * 2; // Higher initial temperature int iterations_outer = static_cast(std::log(min_temp / temp) / std::log(cooling_factor)); float target_bpw = max_cost * 8.0f / 1024.0f / num_slots; @@ -229,7 +234,7 @@ std::tuple>, std::vector, float, ui for (int j = 0; j < iterations_outer; ++j) { - float temp_ratio = temp / initial_temp; + float temp_ratio = temp / (initial_temp * 2); float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio); for (int k = 0; k < iterations; ++k) @@ -252,12 +257,12 @@ std::tuple>, std::vector, float, ui new_max_exp_error = std::max(current_max_exp_error, std::get<1>(new_option)); } - // BPW Penalty (Dynamic and Temperature-Dependent) + // BPW Penalty (Dynamic, Temperature-Dependent, and Non-Linear) float bpw_new = calculate_bpw(new_option); float bpw_penalty = 0.0f; - if (bpw_new < min_bpw_limit) { - bpw_penalty = (min_bpw_limit - bpw_new) * bpw_penalty_scale * (1 + temp_ratio); // Stronger penalty at higher temp + bpw_penalty = (min_bpw_limit - bpw_new) * bpw_penalty_scale * (1 + temp_ratio); + bpw_penalty = bpw_penalty * bpw_penalty; // Exponential penalty } if (current_cost + delta_cost <= max_cost || (delta_cost < 0 && current_cost > max_cost)) @@ -278,7 +283,7 @@ std::tuple>, std::vector, float, ui // --- Post-processing: Bit Redistribution --- for (int r = 0; r < redistribution_iterations; ++r) { - 
float temp_ratio = temp / initial_temp; + float temp_ratio = temp / (initial_temp * 2); float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio); // Calculate BPW statistics and dynamic bpw_threshold @@ -379,74 +384,66 @@ std::tuple>, std::vector, float, ui float best_sum_log_err = current_sum_log_err; std::vector> best_solution = solution; std::vector best_solution_idx = solution_idx; - float local_temp = initial_opportunistic_temp; + for (int i = 0; i < opportunistic_iterations; ++i) { - float temp_ratio = temp / initial_temp; + float temp_ratio = temp / (initial_temp * 2); float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio); - // Select a neighborhood of slots - int center_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); - int neighborhood_size = std::min(5, num_slots); // Example neighborhood size - int start_slot = std::max(0, center_slot - neighborhood_size / 2); - int end_slot = std::min(num_slots - 1, center_slot + neighborhood_size / 2); + // Select a slot to adjust + int target_slot = std::uniform_int_distribution<>(0, num_slots - 1)(gen); - // Calculate average BPW in the neighborhood - float neighborhood_bpw_sum = 0; - for (int j = start_slot; j <= end_slot; ++j) { - neighborhood_bpw_sum += calculate_bpw(solution[j]); + // Calculate the global average BPW + float global_bpw_sum = 0; + for (int j = 0; j < num_slots; ++j) { + global_bpw_sum += calculate_bpw(solution[j]); } - float neighborhood_bpw_avg = neighborhood_bpw_sum / (end_slot - start_slot + 1); + float global_bpw_avg = global_bpw_sum / num_slots; - // Adjust BPWs within the neighborhood + // Adjust BPW of the target slot std::vector> new_solution = solution; std::vector new_solution_idx = solution_idx; float new_sum_log_err = current_sum_log_err; uint64_t new_cost = current_cost; - for (int j = start_slot; j <= end_slot; ++j) { - float current_bpw = calculate_bpw(solution[j]); - float target_bpw = neighborhood_bpw_avg; + float current_bpw = 
calculate_bpw(solution[target_slot]); - // Error-weighted adjustment with bias towards higher BPW - float avg_error = 0; - for (int k = start_slot; k <= end_slot; ++k) { - avg_error += std::get<1>(solution[k]); - } - avg_error /= (end_slot - start_slot + 1); - float error_ratio = std::get<1>(solution[j]) / avg_error; - - float adjustment = 0.25f + 0.25f * error_ratio; // Increased adjustment with bias - - // Adjust BPW towards the target, weighted by error, with a bias towards higher BPW - if (current_bpw < target_bpw + adjustment) { // Bias towards higher BPW - // Search for a higher BPW option - for (int n = 0; n < slots[j].size(); ++n) { - auto new_option = slots[j][n]; - if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) { - if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) - { - new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); - new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); - new_solution[j] = new_option; - new_solution_idx[j] = n; - break; - } + // Adjust BPW towards the global average, weighted by error + float avg_error = 0; + for (int k = 0; k < num_slots; ++k) { + avg_error += std::get<1>(solution[k]); + } + avg_error /= num_slots; + float error_ratio = std::get<1>(solution[target_slot]) / avg_error; + + float adjustment = 0.25f + 0.25f * error_ratio; + + // Adjust BPW towards the target, weighted by error, with a bias towards higher BPW + if (current_bpw < global_bpw_avg + adjustment) { + // Search for a higher BPW option + for (int n = 0; n < slots[target_slot].size(); ++n) { + auto new_option = slots[target_slot][n]; + if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) { + if (new_cost - std::get<0>(solution[target_slot]) + std::get<0>(new_option) <= max_cost) { + new_cost = new_cost - std::get<0>(solution[target_slot]) + std::get<0>(new_option); + 
new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[target_slot])) + log(std::get<1>(new_option)); + new_solution[target_slot] = new_option; + new_solution_idx[target_slot] = n; + break; } } - } else if (current_bpw > target_bpw) { - // Search for a lower BPW option - for (int n = slots[j].size() - 1; n >= 0; --n) { // Iterate in reverse order - auto new_option = slots[j][n]; - if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) { - if (new_cost - std::get<0>(solution[j]) + std::get<0>(new_option) <= max_cost) - { - new_cost = new_cost - std::get<0>(solution[j]) + std::get<0>(new_option); - new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[j])) + log(std::get<1>(new_option)); - new_solution[j] = new_option; - new_solution_idx[j] = n; - break; - } + } + } else if (current_bpw > global_bpw_avg) { + // Search for a lower BPW option + for (int n = slots[target_slot].size() - 1; n >= 0; --n) { + auto new_option = slots[target_slot][n]; + if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) { + if (new_cost - std::get<0>(solution[target_slot]) + std::get<0>(new_option) <= max_cost) { + new_cost = new_cost - std::get<0>(solution[target_slot]) + std::get<0>(new_option); + new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[target_slot])) + log(std::get<1>(new_option)); + new_solution[target_slot] = new_option; + new_solution_idx[target_slot] = n; + break; } } } @@ -465,9 +462,10 @@ std::tuple>, std::vector, float, ui // Dampen penalty for low errors float error_factor = 1.0f; if (current_max_exp_error < low_error_threshold) { - error_factor = 0.1f; // Reduce the weight of sum_log_err + error_factor = 0.1f; } + // Introduce a small probability of accepting a worse solution (simulated annealing-like) if (new_cost <= max_cost) { if (delta_sum_log_err * error_factor < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err * 
error_factor / local_temp)) { accept = true; @@ -539,20 +537,105 @@ std::tuple>, std::vector, float, ui } } + // --- Targeted Bit Redistribution (Post-processing) --- + for (int iter = 0; iter < num_slots * 2; ++iter) { // Increased passes + // Create a global pool of donor indices + std::vector donor_indices; + for (int j = 0; j < num_slots; ++j) { + if (calculate_bpw(solution[j]) > high_bpw_donor_threshold) { + donor_indices.push_back(j); + } + } + + if (donor_indices.empty()) continue; // Skip if no donors + + for (int i = 0; i < num_slots; ++i) { + float current_bpw = calculate_bpw(solution[i]); + if (current_bpw < targeted_redistribution_bpw_threshold) { + // Randomly select a donor from the global pool + int donor_idx = donor_indices[std::uniform_int_distribution<>(0, donor_indices.size() - 1)(gen)]; + + // Explore multiple higher BPW options for the current slot + std::vector higher_bpw_options; + for (int n = 0; n < slots[i].size(); ++n) { + if (calculate_bpw(slots[i][n]) > current_bpw && calculate_bpw(slots[i][n]) >= targeted_redistribution_bpw_threshold) { + higher_bpw_options.push_back(n); + } + } + + std::shuffle(higher_bpw_options.begin(), higher_bpw_options.end(), gen); + int options_to_explore = std::min((int)higher_bpw_options.size(), num_options_to_explore_per_layer); + + for (int option_idx = 0; option_idx < options_to_explore; ++option_idx) { + int best_new_idx = higher_bpw_options[option_idx]; + auto new_option = slots[i][best_new_idx]; + + // Find a lower BPW option for the donor slot + int best_donor_new_idx = -1; + float best_donor_new_error = 1e10f; + for (int n = 0; n < slots[donor_idx].size(); ++n) { + if (calculate_bpw(slots[donor_idx][n]) < calculate_bpw(solution[donor_idx])) { + float error_factor = 1.0f + std::get<1>(slots[donor_idx][n]); + if (error_factor * std::get<1>(slots[donor_idx][n]) < best_donor_new_error) { + best_donor_new_error = error_factor * std::get<1>(slots[donor_idx][n]); + best_donor_new_idx = n; + } + } + } + + if 
(best_donor_new_idx != -1) { + auto donor_new_option = slots[donor_idx][best_donor_new_idx]; + uint64_t new_cost = current_cost - std::get<0>(solution[i]) - std::get<0>(solution[donor_idx]) + + std::get<0>(new_option) + std::get<0>(donor_new_option); + + if (new_cost <= max_cost) { + float new_max_err = std::get<1>(new_option); + for (int j = 0; j < num_slots; ++j) { + if (j == i) continue; + if (j == donor_idx) { + new_max_err = std::max(new_max_err, std::get<1>(donor_new_option)); + } else { + new_max_err = std::max(new_max_err, std::get<1>(solution[j])); + } + } + + // Adaptive max_err_increase + float max_err_increase = targeted_redistribution_max_err_increase_initial - + (targeted_redistribution_max_err_increase_initial - targeted_redistribution_max_err_increase_final) * + (static_cast(iter) / (num_slots * 2)); + + if (new_max_err < current_max_exp_error * max_err_increase) { + current_cost = new_cost; + solution[i] = new_option; + solution_idx[i] = best_new_idx; + solution[donor_idx] = donor_new_option; + solution_idx[donor_idx] = best_donor_new_idx; + current_max_exp_error = new_max_err; + break; // Move to the next low-bpw layer after a successful redistribution + } + } + } + } + } + } + } + // --- Final Cost Check and Rollback (if necessary) --- if (current_cost > max_cost) { - std::vector> error_indices(num_slots); + std::vector> bpw_error_indices(num_slots); for (int i = 0; i < num_slots; ++i) { - error_indices[i] = {std::get<1>(solution[i]), i}; + float bpw = calculate_bpw(solution[i]); + float error = std::get<1>(solution[i]); + float penalty = (bpw < targeted_redistribution_bpw_threshold) ? 
1000.0f : 0.0f; // High penalty for low BPW + bpw_error_indices[i] = {error + penalty, bpw, i}; } - std::sort(error_indices.begin(), error_indices.end()); + std::sort(bpw_error_indices.begin(), bpw_error_indices.end(), std::greater>()); - for (const auto& pair : error_indices) { - int i = pair.second; + for (const auto& tuple : bpw_error_indices) { + int i = std::get<2>(tuple); for (int n = slots[i].size() - 1; n >= 0; --n) { if (calculate_bpw(slots[i][n]) < calculate_bpw(solution[i])) { - if (current_cost - std::get<0>(solution[i]) + std::get<0>(slots[i][n]) <= max_cost) - { + if (current_cost - std::get<0>(solution[i]) + std::get<0>(slots[i][n]) <= max_cost) { uint64_t delta_cost = std::get<0>(slots[i][n]) - std::get<0>(solution[i]); current_cost += delta_cost; solution[i] = slots[i][n]; From 02cfba80cc5e47f17fff51b2609a908393c72572 Mon Sep 17 00:00:00 2001 From: imoc Date: Thu, 9 Jan 2025 16:04:40 +0800 Subject: [PATCH 12/17] improvement v3-5-3 72B.E: -- sum(log(err)): -878.606786 -- max(err): 0.018033 --- exllamav2/exllamav2_ext/ext_quant.cpp | 44 ++++++++++++--------------- 1 file changed, 20 insertions(+), 24 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index c8753fb3..09c84b85 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -174,17 +174,17 @@ std::tuple>, std::vector, float, ui ) { // --- Internal Parameters --- - const int redistribution_iterations = 25; - const float bpw_penalty_scale = 0.1f; // Further increased BPW penalty - const float min_bpw_base = 3.0f; - const int opportunistic_iterations = 15000; // Increased iterations - const float initial_opportunistic_temp = 0.05f; // Higher initial temperature for opportunistic optimization - const float low_error_threshold = 0.0009f; - const float targeted_redistribution_bpw_threshold = 3.3f; - const float targeted_redistribution_max_err_increase_initial = 1.2f; // Even more initial tolerance for error 
increase - const float targeted_redistribution_max_err_increase_final = 1.02f; + const int redistribution_iterations = 50; // Increased + const float bpw_penalty_scale = 0.5f; // Further increased BPW penalty + const float min_bpw_base = 3.5f; // Increased baseline + const int opportunistic_iterations = 20000; + const float initial_opportunistic_temp = 0.05f; + const float low_error_threshold = 0.001f; + const float targeted_redistribution_bpw_threshold = 3.5f; // Increased threshold + const float targeted_redistribution_max_err_increase_initial = 1.1f; + const float targeted_redistribution_max_err_increase_final = 1.01f; const float high_bpw_donor_threshold = 5.0f; - const int num_options_to_explore_per_layer = 3; // Explore multiple higher-bpw options in targeted redistribution + const int num_options_to_explore_per_layer = 8; // Explore more options // --- Dynamic Minimum BPW --- auto calculate_dynamic_min_bpw = [&](float target_bpw, float temp_ratio) { @@ -261,8 +261,7 @@ std::tuple>, std::vector, float, ui float bpw_new = calculate_bpw(new_option); float bpw_penalty = 0.0f; if (bpw_new < min_bpw_limit) { - bpw_penalty = (min_bpw_limit - bpw_new) * bpw_penalty_scale * (1 + temp_ratio); - bpw_penalty = bpw_penalty * bpw_penalty; // Exponential penalty + bpw_penalty = powf((min_bpw_limit - bpw_new), 2.0f) * bpw_penalty_scale * (1 + temp_ratio); } if (current_cost + delta_cost <= max_cost || (delta_cost < 0 && current_cost > max_cost)) @@ -359,9 +358,9 @@ std::tuple>, std::vector, float, ui auto [current_bpw_mean, current_bpw_stddev] = calculate_bpw_stats(solution); auto [new_bpw_mean, new_bpw_stddev] = calculate_bpw_stats({new_low_option, new_high_option}); - float bpw_penalty = bpw_penalty_scale * (new_bpw_stddev - current_bpw_stddev); + float bpw_change_penalty = bpw_penalty_scale * std::max(0.0f, new_bpw_stddev - current_bpw_stddev); // Penalize increased variance - if (new_max_exp_error + bpw_penalty < current_max_exp_error) { + if (new_max_exp_error + 
bpw_change_penalty < current_max_exp_error) { solution[low_idx] = new_low_option; solution_idx[low_idx] = best_low_new_idx; solution[high_idx] = new_high_option; @@ -467,15 +466,12 @@ std::tuple>, std::vector, float, ui // Introduce a small probability of accepting a worse solution (simulated annealing-like) if (new_cost <= max_cost) { - if (delta_sum_log_err * error_factor < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err * error_factor / local_temp)) { + float bpw_penalty_opp = 0.0f; + if (calculate_bpw(new_solution[target_slot]) < min_bpw_limit) { + bpw_penalty_opp = powf((min_bpw_limit - calculate_bpw(new_solution[target_slot])), 2.0f) * bpw_penalty_scale * (1 + temp_ratio); + } + if (delta_sum_log_err * error_factor + bpw_penalty_opp < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-(delta_sum_log_err * error_factor + bpw_penalty_opp) / local_temp)) { accept = true; - // Further penalize if below min_bpw_limit - for (int j = 0; j < num_slots; ++j) { - if (calculate_bpw(new_solution[j]) < min_bpw_limit) { - accept = false; - break; - } - } } } @@ -538,7 +534,7 @@ std::tuple>, std::vector, float, ui } // --- Targeted Bit Redistribution (Post-processing) --- - for (int iter = 0; iter < num_slots * 2; ++iter) { // Increased passes + for (int iter = 0; iter < num_slots * 3; ++iter) { // Create a global pool of donor indices std::vector donor_indices; for (int j = 0; j < num_slots; ++j) { @@ -602,7 +598,7 @@ std::tuple>, std::vector, float, ui // Adaptive max_err_increase float max_err_increase = targeted_redistribution_max_err_increase_initial - (targeted_redistribution_max_err_increase_initial - targeted_redistribution_max_err_increase_final) * - (static_cast(iter) / (num_slots * 2)); + (static_cast(iter) / (num_slots * 3)); if (new_max_err < current_max_exp_error * max_err_increase) { current_cost = new_cost; From a0f0e90917a3cf08f20b58033c7f068fda0687e3 Mon Sep 17 00:00:00 2001 From: imoc Date: Thu, 9 Jan 2025 17:03:31 
+0800 Subject: [PATCH 13/17] improvement v3-5-4 -- sum(log(err)): -877.778402 -- max(err): 0.017617 --- exllamav2/exllamav2_ext/ext_quant.cpp | 255 +++++++++++++++----------- 1 file changed, 146 insertions(+), 109 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index 09c84b85..31bd005f 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -173,22 +173,25 @@ std::tuple>, std::vector, float, ui float norm ) { - // --- Internal Parameters --- - const int redistribution_iterations = 50; // Increased - const float bpw_penalty_scale = 0.5f; // Further increased BPW penalty - const float min_bpw_base = 3.5f; // Increased baseline - const int opportunistic_iterations = 20000; - const float initial_opportunistic_temp = 0.05f; + // --- Enhanced Parameters --- + const int redistribution_iterations = 50; // Increased iterations for more thorough redistribution + const float bpw_penalty_scale = 0.5f; // Stronger penalty for low BPW + const float min_bpw_base = 3.5f; // Base minimum BPW + const int opportunistic_iterations = 30000; // More iterations for opportunistic optimization + const float initial_opportunistic_temp = 0.1f; // Higher initial temperature const float low_error_threshold = 0.001f; - const float targeted_redistribution_bpw_threshold = 3.5f; // Increased threshold - const float targeted_redistribution_max_err_increase_initial = 1.1f; - const float targeted_redistribution_max_err_increase_final = 1.01f; - const float high_bpw_donor_threshold = 5.0f; - const int num_options_to_explore_per_layer = 8; // Explore more options + const float error_floor = 0.0001f; // Minimum acceptable error + const float targeted_redistribution_bpw_threshold = 3.5f; // Higher threshold for targeted redistribution + const float targeted_redistribution_max_err_increase_initial = 1.3f; // More initial tolerance for error increase in targeted redistribution + const float 
targeted_redistribution_max_err_increase_final = 1.05f; // Tighter final tolerance + const float high_bpw_donor_threshold = 5.5f; // Threshold for identifying high-BPW donor layers + const int num_options_to_explore_per_layer = 8; // Explore more options in targeted redistribution + const int bpw_smoothing_passes = 5; // Multiple passes for BPW smoothing + const float bpw_smoothing_threshold = 0.75f; // Larger difference for triggering BPW smoothing // --- Dynamic Minimum BPW --- auto calculate_dynamic_min_bpw = [&](float target_bpw, float temp_ratio) { - float scaled_min_bpw = min_bpw_base + 0.5f * (target_bpw - min_bpw_base); + float scaled_min_bpw = min_bpw_base + 0.75f * (target_bpw - min_bpw_base); return min_bpw_base + temp_ratio * (scaled_min_bpw - min_bpw_base); }; @@ -210,7 +213,7 @@ std::tuple>, std::vector, float, ui return std::make_pair(bpw_mean, std::sqrt(std::max(0.0f, bpw_variance))); }; - // --- Original Simulated Annealing --- + // --- Simulated Annealing --- int num_slots = slots.size(); std::random_device rd; @@ -221,24 +224,34 @@ std::tuple>, std::vector, float, ui uint64_t current_cost = 0; float current_max_exp_error = 0; - float temp = initial_temp * 2; // Higher initial temperature + float temp = initial_temp * 2.5f; // Higher initial temperature int iterations_outer = static_cast(std::log(min_temp / temp) / std::log(cooling_factor)); float target_bpw = max_cost * 8.0f / 1024.0f / num_slots; - for (int i = 0; i < num_slots; ++i) - { - solution[i] = slots[i][0]; - current_cost += std::get<0>(slots[i][0]); - current_max_exp_error = std::max(current_max_exp_error, std::get<1>(slots[i][0])); + // --- Initialization (favor higher BPW options) --- + for (int i = 0; i < num_slots; ++i) { + int best_idx = 0; + float best_bpw_diff = 1e10f; + for (int j = 0; j < slots[i].size(); ++j) { + float bpw = calculate_bpw(slots[i][j]); + if (bpw >= target_bpw) { + float bpw_diff = bpw - target_bpw; + if (bpw_diff < best_bpw_diff) { + best_bpw_diff = bpw_diff; 
+ best_idx = j; + } + } + } + solution[i] = slots[i][best_idx]; + current_cost += std::get<0>(slots[i][best_idx]); + current_max_exp_error = std::max(current_max_exp_error, std::get<1>(slots[i][best_idx])); } - for (int j = 0; j < iterations_outer; ++j) - { - float temp_ratio = temp / (initial_temp * 2); + for (int j = 0; j < iterations_outer; ++j) { + float temp_ratio = temp / (initial_temp * 2.5f); float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio); - for (int k = 0; k < iterations; ++k) - { + for (int k = 0; k < iterations; ++k) { int i = std::uniform_int_distribution<>(0, num_slots - 1)(gen); int n = std::uniform_int_distribution<>(0, slots[i].size() - 1)(gen); auto new_option = slots[i][n]; @@ -257,18 +270,17 @@ std::tuple>, std::vector, float, ui new_max_exp_error = std::max(current_max_exp_error, std::get<1>(new_option)); } - // BPW Penalty (Dynamic, Temperature-Dependent, and Non-Linear) + // Enhanced BPW Penalty (Dynamic, Temperature-Dependent, and Non-Linear) float bpw_new = calculate_bpw(new_option); float bpw_penalty = 0.0f; if (bpw_new < min_bpw_limit) { - bpw_penalty = powf((min_bpw_limit - bpw_new), 2.0f) * bpw_penalty_scale * (1 + temp_ratio); + bpw_penalty = (min_bpw_limit - bpw_new) * bpw_penalty_scale * (1 + temp_ratio * 2); // Stronger temperature scaling + bpw_penalty = bpw_penalty * bpw_penalty * bpw_penalty; // Cubic penalty for more aggressive discouragement } - if (current_cost + delta_cost <= max_cost || (delta_cost < 0 && current_cost > max_cost)) - { + if (current_cost + delta_cost <= max_cost || (delta_cost < 0 && current_cost > max_cost)) { if (delta_e + bpw_penalty < 0 || - std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-(delta_e + bpw_penalty) / temp)) - { + std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-(delta_e + bpw_penalty) / temp)) { solution[i] = new_option; solution_idx[i] = n; current_cost += delta_cost; @@ -279,38 +291,36 @@ std::tuple>, std::vector, float, ui temp *= cooling_factor; } - 
// --- Post-processing: Bit Redistribution --- - + // --- Enhanced Bit Redistribution --- for (int r = 0; r < redistribution_iterations; ++r) { - float temp_ratio = temp / (initial_temp * 2); + float temp_ratio = temp / (initial_temp * 2.5f); float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio); // Calculate BPW statistics and dynamic bpw_threshold auto [bpw_mean, bpw_stddev] = calculate_bpw_stats(solution); - float bpw_threshold = std::max(min_bpw_limit, bpw_mean - 0.5f * bpw_stddev); + float bpw_threshold = std::max(min_bpw_limit, bpw_mean - bpw_stddev); // More dynamic threshold std::vector low_bpw_indices; std::vector high_bpw_indices; + std::vector high_bpw_errors; for (int i = 0; i < num_slots; ++i) { float bpw = calculate_bpw(solution[i]); if (bpw < bpw_threshold) { low_bpw_indices.push_back(i); - } else { + } else if (bpw > high_bpw_donor_threshold) { high_bpw_indices.push_back(i); + high_bpw_errors.push_back(std::get<1>(solution[i])); } } + if (high_bpw_indices.empty()) continue; + + // Error-weighted selection of high_idx (donor) with a bias towards lower error + std::discrete_distribution high_idx_dist(high_bpw_errors.begin(), high_bpw_errors.end()); + bool improved = false; for (int low_idx : low_bpw_indices) { - if (high_bpw_indices.empty()) break; - - // Error-weighted selection of high_idx - std::vector high_bpw_errors; - for (int high_idx : high_bpw_indices) { - high_bpw_errors.push_back(std::get<1>(solution[high_idx])); - } - std::discrete_distribution high_idx_dist(high_bpw_errors.begin(), high_bpw_errors.end()); int high_idx = high_bpw_indices[high_idx_dist(gen)]; // Find a higher BPW option for the low-BPW slot, with bias towards lower error @@ -318,9 +328,9 @@ std::tuple>, std::vector, float, ui float best_low_new_error = 1e10f; for (int n = 0; n < slots[low_idx].size(); ++n) { if (calculate_bpw(slots[low_idx][n]) > calculate_bpw(solution[low_idx])) { - float error_factor = 1.0f + std::get<1>(slots[low_idx][n]); - if 
(error_factor * std::get<1>(slots[low_idx][n]) < best_low_new_error) { - best_low_new_error = error_factor * std::get<1>(slots[low_idx][n]); + if (std::get<1>(slots[low_idx][n]) < best_low_new_error) + { + best_low_new_error = std::get<1>(slots[low_idx][n]); best_low_new_idx = n; } } @@ -331,9 +341,8 @@ std::tuple>, std::vector, float, ui float best_high_new_error = 1e10f; for (int n = 0; n < slots[high_idx].size(); ++n) { if (calculate_bpw(slots[high_idx][n]) < calculate_bpw(solution[high_idx])) { - float error_factor = 1.0f + std::get<1>(slots[high_idx][n]); - if (error_factor * std::get<1>(slots[high_idx][n]) < best_high_new_error) { - best_high_new_error = error_factor * std::get<1>(slots[high_idx][n]); + if (std::get<1>(slots[high_idx][n]) < best_high_new_error) { + best_high_new_error = std::get<1>(slots[high_idx][n]); best_high_new_idx = n; } } @@ -343,7 +352,8 @@ std::tuple>, std::vector, float, ui auto new_low_option = slots[low_idx][best_low_new_idx]; auto new_high_option = slots[high_idx][best_high_new_idx]; - uint64_t new_cost = current_cost - std::get<0>(solution[low_idx]) - std::get<0>(solution[high_idx]) + std::get<0>(new_low_option) + std::get<0>(new_high_option); + uint64_t new_cost = current_cost - std::get<0>(solution[low_idx]) - std::get<0>(solution[high_idx]) + + std::get<0>(new_low_option) + std::get<0>(new_high_option); if (new_cost <= max_cost) { float new_max_exp_error = std::get<1>(new_low_option); @@ -356,11 +366,14 @@ std::tuple>, std::vector, float, ui } } + // Consider error floor + if (std::get<1>(new_low_option) < error_floor || std::get<1>(new_high_option) < error_floor) continue; + auto [current_bpw_mean, current_bpw_stddev] = calculate_bpw_stats(solution); auto [new_bpw_mean, new_bpw_stddev] = calculate_bpw_stats({new_low_option, new_high_option}); - float bpw_change_penalty = bpw_penalty_scale * std::max(0.0f, new_bpw_stddev - current_bpw_stddev); // Penalize increased variance + float bpw_penalty = bpw_penalty_scale * 
(new_bpw_stddev - current_bpw_stddev) * (1 + temp_ratio); - if (new_max_exp_error + bpw_change_penalty < current_max_exp_error) { + if (new_max_exp_error + bpw_penalty < current_max_exp_error) { solution[low_idx] = new_low_option; solution_idx[low_idx] = best_low_new_idx; solution[high_idx] = new_high_option; @@ -374,7 +387,7 @@ std::tuple>, std::vector, float, ui } } - // --- Opportunistic Optimization with Simulated Annealing --- + // --- Enhanced Opportunistic Optimization with Simulated Annealing --- float current_sum_log_err = 0; for (int i = 0; i < num_slots; ++i) { current_sum_log_err += log(std::get<1>(solution[i])); @@ -386,7 +399,7 @@ std::tuple>, std::vector, float, ui float local_temp = initial_opportunistic_temp; for (int i = 0; i < opportunistic_iterations; ++i) { - float temp_ratio = temp / (initial_temp * 2); + float temp_ratio = temp / (initial_temp * 2.5f); float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio); // Select a slot to adjust @@ -407,23 +420,28 @@ std::tuple>, std::vector, float, ui float current_bpw = calculate_bpw(solution[target_slot]); - // Adjust BPW towards the global average, weighted by error + // Calculate average error float avg_error = 0; for (int k = 0; k < num_slots; ++k) { avg_error += std::get<1>(solution[k]); } avg_error /= num_slots; + + // Calculate error ratio for the target slot float error_ratio = std::get<1>(solution[target_slot]) / avg_error; - float adjustment = 0.25f + 0.25f * error_ratio; + // Enhanced adjustment factor, more sensitive to error ratio + float adjustment = 0.5f + 0.5f * error_ratio; // Adjust BPW towards the target, weighted by error, with a bias towards higher BPW if (current_bpw < global_bpw_avg + adjustment) { // Search for a higher BPW option for (int n = 0; n < slots[target_slot].size(); ++n) { auto new_option = slots[target_slot][n]; - if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= current_bpw + adjustment) { + float new_option_bpw = 
calculate_bpw(new_option); + if (new_option_bpw > current_bpw && new_option_bpw <= current_bpw + adjustment) { if (new_cost - std::get<0>(solution[target_slot]) + std::get<0>(new_option) <= max_cost) { + if (std::get<1>(new_option) < error_floor) continue; new_cost = new_cost - std::get<0>(solution[target_slot]) + std::get<0>(new_option); new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[target_slot])) + log(std::get<1>(new_option)); new_solution[target_slot] = new_option; @@ -436,8 +454,10 @@ std::tuple>, std::vector, float, ui // Search for a lower BPW option for (int n = slots[target_slot].size() - 1; n >= 0; --n) { auto new_option = slots[target_slot][n]; - if (calculate_bpw(new_option) < current_bpw && calculate_bpw(new_option) >= current_bpw - adjustment) { + float new_option_bpw = calculate_bpw(new_option); + if (new_option_bpw < current_bpw && new_option_bpw >= current_bpw - adjustment) { if (new_cost - std::get<0>(solution[target_slot]) + std::get<0>(new_option) <= max_cost) { + if (std::get<1>(new_option) < error_floor) continue; new_cost = new_cost - std::get<0>(solution[target_slot]) + std::get<0>(new_option); new_sum_log_err = new_sum_log_err - log(std::get<1>(solution[target_slot])) + log(std::get<1>(new_option)); new_solution[target_slot] = new_option; @@ -454,24 +474,27 @@ std::tuple>, std::vector, float, ui new_max_exp_error = std::max(new_max_exp_error, std::get<1>(new_solution[j])); } - // Acceptance criterion with a small probability of accepting worse solutions + // Acceptance criterion with error equalization focus bool accept = false; float delta_sum_log_err = new_sum_log_err - current_sum_log_err; - // Dampen penalty for low errors + // Dampen penalty for low errors, but less aggressively float error_factor = 1.0f; if (current_max_exp_error < low_error_threshold) { - error_factor = 0.1f; + error_factor = 0.25f; // Less dampening } - // Introduce a small probability of accepting a worse solution (simulated annealing-like) + // 
Introduce a probability of accepting a worse solution (simulated annealing-like) if (new_cost <= max_cost) { - float bpw_penalty_opp = 0.0f; - if (calculate_bpw(new_solution[target_slot]) < min_bpw_limit) { - bpw_penalty_opp = powf((min_bpw_limit - calculate_bpw(new_solution[target_slot])), 2.0f) * bpw_penalty_scale * (1 + temp_ratio); - } - if (delta_sum_log_err * error_factor + bpw_penalty_opp < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-(delta_sum_log_err * error_factor + bpw_penalty_opp) / local_temp)) { + if (delta_sum_log_err * error_factor < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err * error_factor / local_temp)) { accept = true; + // Further penalize if below min_bpw_limit, more aggressively + for (int j = 0; j < num_slots; ++j) { + if (calculate_bpw(new_solution[j]) < min_bpw_limit) { + accept = false; + break; + } + } } } @@ -489,7 +512,7 @@ std::tuple>, std::vector, float, ui } } - local_temp *= 0.95f; + local_temp *= 0.95f; // Faster cooling } // Use the best solution found during opportunistic optimization @@ -497,35 +520,38 @@ std::tuple>, std::vector, float, ui solution_idx = best_solution_idx; current_sum_log_err = best_sum_log_err; - // --- BPW Smoothing (Post-processing) --- - for (int i = 1; i < num_slots - 1; ++i) { - float current_bpw = calculate_bpw(solution[i]); - float prev_bpw = calculate_bpw(solution[i - 1]); - float next_bpw = calculate_bpw(solution[i + 1]); - float avg_neighbor_bpw = (prev_bpw + next_bpw) / 2.0f; - - if (current_bpw < avg_neighbor_bpw - 0.5f) { // Significant difference - // Find a higher BPW option for the current slot - for (int n = 0; n < slots[i].size(); ++n) { - auto new_option = slots[i][n]; - if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= avg_neighbor_bpw) { - if (current_cost - std::get<0>(solution[i]) + std::get<0>(new_option) <= max_cost) { - // Check if the new option doesn't significantly increase max_err - float new_max_err = 
0; - for (int j = 0; j < num_slots; ++j) { - if (j == i) { - new_max_err = std::max(new_max_err, std::get<1>(new_option)); - } else { - new_max_err = std::max(new_max_err, std::get<1>(solution[j])); + // --- Enhanced BPW Smoothing (Post-processing) --- + for (int pass = 0; pass < bpw_smoothing_passes; ++pass) { + for (int i = 1; i < num_slots - 1; ++i) { + float current_bpw = calculate_bpw(solution[i]); + float prev_bpw = calculate_bpw(solution[i - 1]); + float next_bpw = calculate_bpw(solution[i + 1]); + float avg_neighbor_bpw = (prev_bpw + next_bpw) / 2.0f; + + if (current_bpw < avg_neighbor_bpw - bpw_smoothing_threshold) { // Larger difference + // Find a higher BPW option for the current slot + for (int n = 0; n < slots[i].size(); ++n) { + auto new_option = slots[i][n]; + if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= avg_neighbor_bpw) { + if (current_cost - std::get<0>(solution[i]) + std::get<0>(new_option) <= max_cost) { + // Check if the new option doesn't significantly increase max_err and is not below error floor + if (std::get<1>(new_option) < error_floor) continue; + float new_max_err = 0; + for (int j = 0; j < num_slots; ++j) { + if (j == i) { + new_max_err = std::max(new_max_err, std::get<1>(new_option)); + } else { + new_max_err = std::max(new_max_err, std::get<1>(solution[j])); + } } - } - if (new_max_err < current_max_exp_error * 1.1f) { // Allow a small increase in max_err - current_cost = current_cost - std::get<0>(solution[i]) + std::get<0>(new_option); - solution[i] = new_option; - solution_idx[i] = n; - current_max_exp_error = new_max_err; - break; + if (new_max_err < current_max_exp_error * 1.2f) { // Allow a larger increase in max_err + current_cost = current_cost - std::get<0>(solution[i]) + std::get<0>(new_option); + solution[i] = new_option; + solution_idx[i] = n; + current_max_exp_error = new_max_err; + break; + } } } } @@ -533,28 +559,33 @@ std::tuple>, std::vector, float, ui } } - // --- Targeted Bit 
Redistribution (Post-processing) --- - for (int iter = 0; iter < num_slots * 3; ++iter) { - // Create a global pool of donor indices + // --- Enhanced Targeted Bit Redistribution (Post-processing) --- + for (int iter = 0; iter < num_slots * 3; ++iter) { // Increased passes + // Create a global pool of donor indices, considering both high BPW and relatively low error std::vector donor_indices; + std::vector donor_errors; for (int j = 0; j < num_slots; ++j) { - if (calculate_bpw(solution[j]) > high_bpw_donor_threshold) { + if (calculate_bpw(solution[j]) > high_bpw_donor_threshold && std::get<1>(solution[j]) < low_error_threshold) { donor_indices.push_back(j); + donor_errors.push_back(std::get<1>(solution[j])); } } - if (donor_indices.empty()) continue; // Skip if no donors + if (donor_indices.empty()) continue; // Skip if no suitable donors + + // Error-weighted selection of donor + std::discrete_distribution donor_dist(donor_errors.begin(), donor_errors.end()); for (int i = 0; i < num_slots; ++i) { float current_bpw = calculate_bpw(solution[i]); if (current_bpw < targeted_redistribution_bpw_threshold) { - // Randomly select a donor from the global pool - int donor_idx = donor_indices[std::uniform_int_distribution<>(0, donor_indices.size() - 1)(gen)]; + // Randomly select a donor from the global pool, weighted by error + int donor_idx = donor_indices[donor_dist(gen)]; // Explore multiple higher BPW options for the current slot std::vector higher_bpw_options; for (int n = 0; n < slots[i].size(); ++n) { - if (calculate_bpw(slots[i][n]) > current_bpw && calculate_bpw(slots[i][n]) >= targeted_redistribution_bpw_threshold) { + if (calculate_bpw(slots[i][n]) > current_bpw) { higher_bpw_options.push_back(n); } } @@ -566,14 +597,16 @@ std::tuple>, std::vector, float, ui int best_new_idx = higher_bpw_options[option_idx]; auto new_option = slots[i][best_new_idx]; - // Find a lower BPW option for the donor slot + // Consider error floor + if (std::get<1>(new_option) < 
error_floor) continue; + + // Find a lower BPW option for the donor slot, with bias towards lower error int best_donor_new_idx = -1; float best_donor_new_error = 1e10f; for (int n = 0; n < slots[donor_idx].size(); ++n) { if (calculate_bpw(slots[donor_idx][n]) < calculate_bpw(solution[donor_idx])) { - float error_factor = 1.0f + std::get<1>(slots[donor_idx][n]); - if (error_factor * std::get<1>(slots[donor_idx][n]) < best_donor_new_error) { - best_donor_new_error = error_factor * std::get<1>(slots[donor_idx][n]); + if (std::get<1>(slots[donor_idx][n]) < best_donor_new_error) { + best_donor_new_error = std::get<1>(slots[donor_idx][n]); best_donor_new_idx = n; } } @@ -581,6 +614,10 @@ std::tuple>, std::vector, float, ui if (best_donor_new_idx != -1) { auto donor_new_option = slots[donor_idx][best_donor_new_idx]; + + // Consider error floor + if (std::get<1>(donor_new_option) < error_floor) continue; + uint64_t new_cost = current_cost - std::get<0>(solution[i]) - std::get<0>(solution[donor_idx]) + std::get<0>(new_option) + std::get<0>(donor_new_option); @@ -595,7 +632,7 @@ std::tuple>, std::vector, float, ui } } - // Adaptive max_err_increase + // Adaptive max_err_increase based on iteration number float max_err_increase = targeted_redistribution_max_err_increase_initial - (targeted_redistribution_max_err_increase_initial - targeted_redistribution_max_err_increase_final) * (static_cast(iter) / (num_slots * 3)); @@ -631,7 +668,7 @@ std::tuple>, std::vector, float, ui int i = std::get<2>(tuple); for (int n = slots[i].size() - 1; n >= 0; --n) { if (calculate_bpw(slots[i][n]) < calculate_bpw(solution[i])) { - if (current_cost - std::get<0>(solution[i]) + std::get<0>(slots[i][n]) <= max_cost) { +if (current_cost - std::get<0>(solution[i]) + std::get<0>(slots[i][n]) <= max_cost) { uint64_t delta_cost = std::get<0>(slots[i][n]) - std::get<0>(solution[i]); current_cost += delta_cost; solution[i] = slots[i][n]; From 305f312001fa5b3f068103a9d4b29cd0a97b02fb Mon Sep 17 00:00:00 
2001 From: imoc Date: Thu, 9 Jan 2025 17:35:25 +0800 Subject: [PATCH 14/17] improvement v3-5-5 -- sum(log(err)): -878.246132 -- max(err): 0.017426 --- exllamav2/exllamav2_ext/ext_quant.cpp | 125 ++++++++++++++------------ 1 file changed, 68 insertions(+), 57 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index 31bd005f..b1f7211b 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -174,20 +174,21 @@ std::tuple>, std::vector, float, ui ) { // --- Enhanced Parameters --- - const int redistribution_iterations = 50; // Increased iterations for more thorough redistribution - const float bpw_penalty_scale = 0.5f; // Stronger penalty for low BPW - const float min_bpw_base = 3.5f; // Base minimum BPW - const int opportunistic_iterations = 30000; // More iterations for opportunistic optimization - const float initial_opportunistic_temp = 0.1f; // Higher initial temperature - const float low_error_threshold = 0.001f; - const float error_floor = 0.0001f; // Minimum acceptable error - const float targeted_redistribution_bpw_threshold = 3.5f; // Higher threshold for targeted redistribution - const float targeted_redistribution_max_err_increase_initial = 1.3f; // More initial tolerance for error increase in targeted redistribution - const float targeted_redistribution_max_err_increase_final = 1.05f; // Tighter final tolerance - const float high_bpw_donor_threshold = 5.5f; // Threshold for identifying high-BPW donor layers - const int num_options_to_explore_per_layer = 8; // Explore more options in targeted redistribution - const int bpw_smoothing_passes = 5; // Multiple passes for BPW smoothing - const float bpw_smoothing_threshold = 0.75f; // Larger difference for triggering BPW smoothing + const int redistribution_iterations = 50; + const float bpw_penalty_scale = 0.6f; // Stronger penalty for low BPW + const float min_bpw_base = 3.3f; // Higher base minimum BPW, we want 
higher bpw + const int opportunistic_iterations = 30000; + const float initial_opportunistic_temp = 0.12f; + const float low_error_threshold = 0.002f; + const float error_floor = 0.0005f; + const float targeted_redistribution_bpw_threshold = 3.6f; + const float targeted_redistribution_max_err_increase_initial = 1.5f; // Increased initial tolerance + const float targeted_redistribution_max_err_increase_final = 1.1f; // Slightly increased final tolerance + const float high_bpw_donor_threshold = 5.5f; + const int num_options_to_explore_per_layer = 8; + const int bpw_smoothing_passes = 8; + const float bpw_smoothing_threshold = 0.75f; + const float bpw_uniformity_factor = 1.8f; // Control trade-off between BPW uniformity and error, higher value will make bpw more uniform // --- Dynamic Minimum BPW --- auto calculate_dynamic_min_bpw = [&](float target_bpw, float temp_ratio) { @@ -224,22 +225,22 @@ std::tuple>, std::vector, float, ui uint64_t current_cost = 0; float current_max_exp_error = 0; - float temp = initial_temp * 2.5f; // Higher initial temperature + float temp = initial_temp * 2.5f; int iterations_outer = static_cast(std::log(min_temp / temp) / std::log(cooling_factor)); float target_bpw = max_cost * 8.0f / 1024.0f / num_slots; - // --- Initialization (favor higher BPW options) --- + // --- Balanced Initialization --- for (int i = 0; i < num_slots; ++i) { int best_idx = 0; - float best_bpw_diff = 1e10f; + float best_score = -1e10f; // Lower score is better for (int j = 0; j < slots[i].size(); ++j) { float bpw = calculate_bpw(slots[i][j]); - if (bpw >= target_bpw) { - float bpw_diff = bpw - target_bpw; - if (bpw_diff < best_bpw_diff) { - best_bpw_diff = bpw_diff; - best_idx = j; - } + float error = std::get<1>(slots[i][j]); + // Favor options with BPW close to target and relatively high error + float score = -std::abs(bpw - target_bpw) + error * bpw_uniformity_factor; + if (score > best_score) { + best_score = score; + best_idx = j; } } solution[i] = 
slots[i][best_idx]; @@ -270,12 +271,14 @@ std::tuple>, std::vector, float, ui new_max_exp_error = std::max(current_max_exp_error, std::get<1>(new_option)); } - // Enhanced BPW Penalty (Dynamic, Temperature-Dependent, and Non-Linear) + // Enhanced Layer-Specific BPW Penalty float bpw_new = calculate_bpw(new_option); float bpw_penalty = 0.0f; if (bpw_new < min_bpw_limit) { - bpw_penalty = (min_bpw_limit - bpw_new) * bpw_penalty_scale * (1 + temp_ratio * 2); // Stronger temperature scaling - bpw_penalty = bpw_penalty * bpw_penalty * bpw_penalty; // Cubic penalty for more aggressive discouragement + // Stronger penalty for earlier layers + float layer_penalty_factor = std::max(0.0f, 1.0f - static_cast(i) / num_slots); + bpw_penalty = (min_bpw_limit - bpw_new) * bpw_penalty_scale * (1 + temp_ratio * 2) * (1 + layer_penalty_factor * bpw_uniformity_factor); + bpw_penalty = bpw_penalty * bpw_penalty * bpw_penalty; } if (current_cost + delta_cost <= max_cost || (delta_cost < 0 && current_cost > max_cost)) { @@ -291,24 +294,35 @@ std::tuple>, std::vector, float, ui temp *= cooling_factor; } - // --- Enhanced Bit Redistribution --- + // --- Enhanced Bit Redistribution with Early Layer Prioritization --- for (int r = 0; r < redistribution_iterations; ++r) { float temp_ratio = temp / (initial_temp * 2.5f); float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio); // Calculate BPW statistics and dynamic bpw_threshold auto [bpw_mean, bpw_stddev] = calculate_bpw_stats(solution); - float bpw_threshold = std::max(min_bpw_limit, bpw_mean - bpw_stddev); // More dynamic threshold + float bpw_threshold = std::max(min_bpw_limit, bpw_mean - bpw_stddev); std::vector low_bpw_indices; std::vector high_bpw_indices; std::vector high_bpw_errors; + // Prioritize early layers + for (int i = 0; i < num_slots / 2; ++i) { + if (calculate_bpw(solution[i]) < bpw_threshold) { + low_bpw_indices.push_back(i); + } + } + // Then consider other layers + for (int i = num_slots / 2; i < 
num_slots; ++i) { + if (calculate_bpw(solution[i]) < bpw_threshold) { + low_bpw_indices.push_back(i); + } + } + for (int i = 0; i < num_slots; ++i) { float bpw = calculate_bpw(solution[i]); - if (bpw < bpw_threshold) { - low_bpw_indices.push_back(i); - } else if (bpw > high_bpw_donor_threshold) { + if (bpw > high_bpw_donor_threshold) { high_bpw_indices.push_back(i); high_bpw_errors.push_back(std::get<1>(solution[i])); } @@ -316,14 +330,13 @@ std::tuple>, std::vector, float, ui if (high_bpw_indices.empty()) continue; - // Error-weighted selection of high_idx (donor) with a bias towards lower error std::discrete_distribution high_idx_dist(high_bpw_errors.begin(), high_bpw_errors.end()); bool improved = false; for (int low_idx : low_bpw_indices) { int high_idx = high_bpw_indices[high_idx_dist(gen)]; - // Find a higher BPW option for the low-BPW slot, with bias towards lower error + // Find a higher BPW option for the low-BPW slot int best_low_new_idx = -1; float best_low_new_error = 1e10f; for (int n = 0; n < slots[low_idx].size(); ++n) { @@ -336,7 +349,7 @@ std::tuple>, std::vector, float, ui } } - // Find a lower BPW option for the high-BPW slot, with bias towards lower error + // Find a lower BPW option for the high-BPW slot int best_high_new_idx = -1; float best_high_new_error = 1e10f; for (int n = 0; n < slots[high_idx].size(); ++n) { @@ -366,14 +379,17 @@ std::tuple>, std::vector, float, ui } } - // Consider error floor if (std::get<1>(new_low_option) < error_floor || std::get<1>(new_high_option) < error_floor) continue; auto [current_bpw_mean, current_bpw_stddev] = calculate_bpw_stats(solution); auto [new_bpw_mean, new_bpw_stddev] = calculate_bpw_stats({new_low_option, new_high_option}); - float bpw_penalty = bpw_penalty_scale * (new_bpw_stddev - current_bpw_stddev) * (1 + temp_ratio); + // Penalty is less relevant here, we are aiming for higher bpw for the low bpw layers anyway + // float bpw_penalty = bpw_penalty_scale * (new_bpw_stddev - current_bpw_stddev) 
* (1 + temp_ratio); + + // Relaxed max_err constraint for early layers + float max_err_increase = (low_idx < num_slots / 2) ? 1.0f + (targeted_redistribution_max_err_increase_initial - 1.0f) * bpw_uniformity_factor : targeted_redistribution_max_err_increase_initial; - if (new_max_exp_error + bpw_penalty < current_max_exp_error) { + if (new_max_exp_error < current_max_exp_error * max_err_increase) { solution[low_idx] = new_low_option; solution_idx[low_idx] = best_low_new_idx; solution[high_idx] = new_high_option; @@ -481,14 +497,12 @@ std::tuple>, std::vector, float, ui // Dampen penalty for low errors, but less aggressively float error_factor = 1.0f; if (current_max_exp_error < low_error_threshold) { - error_factor = 0.25f; // Less dampening + error_factor = 0.25f; } - // Introduce a probability of accepting a worse solution (simulated annealing-like) if (new_cost <= max_cost) { if (delta_sum_log_err * error_factor < 0 || std::uniform_real_distribution<>(0, 1)(gen) < std::exp(-delta_sum_log_err * error_factor / local_temp)) { accept = true; - // Further penalize if below min_bpw_limit, more aggressively for (int j = 0; j < num_slots; ++j) { if (calculate_bpw(new_solution[j]) < min_bpw_limit) { accept = false; @@ -512,7 +526,7 @@ std::tuple>, std::vector, float, ui } } - local_temp *= 0.95f; // Faster cooling + local_temp *= 0.95f; } // Use the best solution found during opportunistic optimization @@ -528,13 +542,12 @@ std::tuple>, std::vector, float, ui float next_bpw = calculate_bpw(solution[i + 1]); float avg_neighbor_bpw = (prev_bpw + next_bpw) / 2.0f; - if (current_bpw < avg_neighbor_bpw - bpw_smoothing_threshold) { // Larger difference + if (current_bpw < avg_neighbor_bpw - bpw_smoothing_threshold) { // Find a higher BPW option for the current slot for (int n = 0; n < slots[i].size(); ++n) { auto new_option = slots[i][n]; if (calculate_bpw(new_option) > current_bpw && calculate_bpw(new_option) <= avg_neighbor_bpw) { if (current_cost - std::get<0>(solution[i]) 
+ std::get<0>(new_option) <= max_cost) { - // Check if the new option doesn't significantly increase max_err and is not below error floor if (std::get<1>(new_option) < error_floor) continue; float new_max_err = 0; for (int j = 0; j < num_slots; ++j) { @@ -545,7 +558,7 @@ std::tuple>, std::vector, float, ui } } - if (new_max_err < current_max_exp_error * 1.2f) { // Allow a larger increase in max_err + if (new_max_err < current_max_exp_error * 1.2f) { current_cost = current_cost - std::get<0>(solution[i]) + std::get<0>(new_option); solution[i] = new_option; solution_idx[i] = n; @@ -560,8 +573,8 @@ std::tuple>, std::vector, float, ui } // --- Enhanced Targeted Bit Redistribution (Post-processing) --- - for (int iter = 0; iter < num_slots * 3; ++iter) { // Increased passes - // Create a global pool of donor indices, considering both high BPW and relatively low error + for (int iter = 0; iter < num_slots * 3; ++iter) { + // Create a global pool of donor indices std::vector donor_indices; std::vector donor_errors; for (int j = 0; j < num_slots; ++j) { @@ -571,18 +584,15 @@ std::tuple>, std::vector, float, ui } } - if (donor_indices.empty()) continue; // Skip if no suitable donors + if (donor_indices.empty()) continue; - // Error-weighted selection of donor std::discrete_distribution donor_dist(donor_errors.begin(), donor_errors.end()); for (int i = 0; i < num_slots; ++i) { float current_bpw = calculate_bpw(solution[i]); if (current_bpw < targeted_redistribution_bpw_threshold) { - // Randomly select a donor from the global pool, weighted by error int donor_idx = donor_indices[donor_dist(gen)]; - // Explore multiple higher BPW options for the current slot std::vector higher_bpw_options; for (int n = 0; n < slots[i].size(); ++n) { if (calculate_bpw(slots[i][n]) > current_bpw) { @@ -597,10 +607,8 @@ std::tuple>, std::vector, float, ui int best_new_idx = higher_bpw_options[option_idx]; auto new_option = slots[i][best_new_idx]; - // Consider error floor if 
(std::get<1>(new_option) < error_floor) continue; - // Find a lower BPW option for the donor slot, with bias towards lower error int best_donor_new_idx = -1; float best_donor_new_error = 1e10f; for (int n = 0; n < slots[donor_idx].size(); ++n) { @@ -615,7 +623,6 @@ std::tuple>, std::vector, float, ui if (best_donor_new_idx != -1) { auto donor_new_option = slots[donor_idx][best_donor_new_idx]; - // Consider error floor if (std::get<1>(donor_new_option) < error_floor) continue; uint64_t new_cost = current_cost - std::get<0>(solution[i]) - std::get<0>(solution[donor_idx]) @@ -632,11 +639,15 @@ std::tuple>, std::vector, float, ui } } - // Adaptive max_err_increase based on iteration number float max_err_increase = targeted_redistribution_max_err_increase_initial - (targeted_redistribution_max_err_increase_initial - targeted_redistribution_max_err_increase_final) * (static_cast(iter) / (num_slots * 3)); + // Relaxed constraint for early layers + if (i < num_slots / 2) { + max_err_increase *= bpw_uniformity_factor; + } + if (new_max_err < current_max_exp_error * max_err_increase) { current_cost = new_cost; solution[i] = new_option; @@ -644,7 +655,7 @@ std::tuple>, std::vector, float, ui solution[donor_idx] = donor_new_option; solution_idx[donor_idx] = best_donor_new_idx; current_max_exp_error = new_max_err; - break; // Move to the next low-bpw layer after a successful redistribution + break; } } } @@ -659,7 +670,7 @@ std::tuple>, std::vector, float, ui for (int i = 0; i < num_slots; ++i) { float bpw = calculate_bpw(solution[i]); float error = std::get<1>(solution[i]); - float penalty = (bpw < targeted_redistribution_bpw_threshold) ? 1000.0f : 0.0f; // High penalty for low BPW + float penalty = (bpw < targeted_redistribution_bpw_threshold) ? 
1000.0f : 0.0f; bpw_error_indices[i] = {error + penalty, bpw, i}; } std::sort(bpw_error_indices.begin(), bpw_error_indices.end(), std::greater>()); @@ -668,7 +679,7 @@ std::tuple>, std::vector, float, ui int i = std::get<2>(tuple); for (int n = slots[i].size() - 1; n >= 0; --n) { if (calculate_bpw(slots[i][n]) < calculate_bpw(solution[i])) { -if (current_cost - std::get<0>(solution[i]) + std::get<0>(slots[i][n]) <= max_cost) { + if (current_cost - std::get<0>(solution[i]) + std::get<0>(slots[i][n]) <= max_cost) { uint64_t delta_cost = std::get<0>(slots[i][n]) - std::get<0>(solution[i]); current_cost += delta_cost; solution[i] = slots[i][n]; From bb120fc520ad217f4d111a47e63096713ad92179 Mon Sep 17 00:00:00 2001 From: imoc Date: Fri, 10 Jan 2025 15:42:25 +0800 Subject: [PATCH 15/17] improvement v3-5-6 remove layer position --- exllamav2/exllamav2_ext/ext_quant.cpp | 63 +++++++-------------------- 1 file changed, 16 insertions(+), 47 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index b1f7211b..dc476516 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -182,13 +182,12 @@ std::tuple>, std::vector, float, ui const float low_error_threshold = 0.002f; const float error_floor = 0.0005f; const float targeted_redistribution_bpw_threshold = 3.6f; - const float targeted_redistribution_max_err_increase_initial = 1.5f; // Increased initial tolerance - const float targeted_redistribution_max_err_increase_final = 1.1f; // Slightly increased final tolerance + const float targeted_redistribution_max_err_increase = 1.5f; // Increased tolerance for error increase in targeted redistribution const float high_bpw_donor_threshold = 5.5f; const int num_options_to_explore_per_layer = 8; const int bpw_smoothing_passes = 8; const float bpw_smoothing_threshold = 0.75f; - const float bpw_uniformity_factor = 1.8f; // Control trade-off between BPW uniformity and error, higher value will make 
bpw more uniform + const float bpw_balance_factor = 1.8f; // Control trade-off between BPW uniformity and error // --- Dynamic Minimum BPW --- auto calculate_dynamic_min_bpw = [&](float target_bpw, float temp_ratio) { @@ -232,12 +231,12 @@ std::tuple>, std::vector, float, ui // --- Balanced Initialization --- for (int i = 0; i < num_slots; ++i) { int best_idx = 0; - float best_score = -1e10f; // Lower score is better + float best_score = -1e10f; for (int j = 0; j < slots[i].size(); ++j) { float bpw = calculate_bpw(slots[i][j]); float error = std::get<1>(slots[i][j]); // Favor options with BPW close to target and relatively high error - float score = -std::abs(bpw - target_bpw) + error * bpw_uniformity_factor; + float score = -std::abs(bpw - target_bpw) + error * bpw_balance_factor; if (score > best_score) { best_score = score; best_idx = j; @@ -271,14 +270,13 @@ std::tuple>, std::vector, float, ui new_max_exp_error = std::max(current_max_exp_error, std::get<1>(new_option)); } - // Enhanced Layer-Specific BPW Penalty + // Enhanced Dynamic BPW Penalty (applied uniformly to all layers) float bpw_new = calculate_bpw(new_option); float bpw_penalty = 0.0f; if (bpw_new < min_bpw_limit) { - // Stronger penalty for earlier layers - float layer_penalty_factor = std::max(0.0f, 1.0f - static_cast(i) / num_slots); - bpw_penalty = (min_bpw_limit - bpw_new) * bpw_penalty_scale * (1 + temp_ratio * 2) * (1 + layer_penalty_factor * bpw_uniformity_factor); - bpw_penalty = bpw_penalty * bpw_penalty * bpw_penalty; + // Clear formula for BPW penalty + bpw_penalty = (min_bpw_limit - bpw_new) * bpw_penalty_scale * (1 + temp_ratio) * bpw_balance_factor; + bpw_penalty = bpw_penalty * bpw_penalty; // Squared penalty } if (current_cost + delta_cost <= max_cost || (delta_cost < 0 && current_cost > max_cost)) { @@ -294,34 +292,24 @@ std::tuple>, std::vector, float, ui temp *= cooling_factor; } - // --- Enhanced Bit Redistribution with Early Layer Prioritization --- + // --- Error-Weighted Bit 
Redistribution --- for (int r = 0; r < redistribution_iterations; ++r) { float temp_ratio = temp / (initial_temp * 2.5f); float min_bpw_limit = calculate_dynamic_min_bpw(target_bpw, temp_ratio); // Calculate BPW statistics and dynamic bpw_threshold auto [bpw_mean, bpw_stddev] = calculate_bpw_stats(solution); - float bpw_threshold = std::max(min_bpw_limit, bpw_mean - bpw_stddev); + float bpw_threshold = std::max(min_bpw_limit, bpw_mean - bpw_stddev * bpw_balance_factor); std::vector low_bpw_indices; std::vector high_bpw_indices; std::vector high_bpw_errors; - // Prioritize early layers - for (int i = 0; i < num_slots / 2; ++i) { - if (calculate_bpw(solution[i]) < bpw_threshold) { - low_bpw_indices.push_back(i); - } - } - // Then consider other layers - for (int i = num_slots / 2; i < num_slots; ++i) { - if (calculate_bpw(solution[i]) < bpw_threshold) { - low_bpw_indices.push_back(i); - } - } - for (int i = 0; i < num_slots; ++i) { float bpw = calculate_bpw(solution[i]); + if (bpw < bpw_threshold) { + low_bpw_indices.push_back(i); + } if (bpw > high_bpw_donor_threshold) { high_bpw_indices.push_back(i); high_bpw_errors.push_back(std::get<1>(solution[i])); @@ -336,7 +324,6 @@ std::tuple>, std::vector, float, ui for (int low_idx : low_bpw_indices) { int high_idx = high_bpw_indices[high_idx_dist(gen)]; - // Find a higher BPW option for the low-BPW slot int best_low_new_idx = -1; float best_low_new_error = 1e10f; for (int n = 0; n < slots[low_idx].size(); ++n) { @@ -349,7 +336,6 @@ std::tuple>, std::vector, float, ui } } - // Find a lower BPW option for the high-BPW slot int best_high_new_idx = -1; float best_high_new_error = 1e10f; for (int n = 0; n < slots[high_idx].size(); ++n) { @@ -381,15 +367,7 @@ std::tuple>, std::vector, float, ui if (std::get<1>(new_low_option) < error_floor || std::get<1>(new_high_option) < error_floor) continue; - auto [current_bpw_mean, current_bpw_stddev] = calculate_bpw_stats(solution); - auto [new_bpw_mean, new_bpw_stddev] = 
calculate_bpw_stats({new_low_option, new_high_option}); - // Penalty is less relevant here, we are aiming for higher bpw for the low bpw layers anyway - // float bpw_penalty = bpw_penalty_scale * (new_bpw_stddev - current_bpw_stddev) * (1 + temp_ratio); - - // Relaxed max_err constraint for early layers - float max_err_increase = (low_idx < num_slots / 2) ? 1.0f + (targeted_redistribution_max_err_increase_initial - 1.0f) * bpw_uniformity_factor : targeted_redistribution_max_err_increase_initial; - - if (new_max_exp_error < current_max_exp_error * max_err_increase) { + if (new_max_exp_error < current_max_exp_error * (1 + 0.1f * bpw_balance_factor)) { solution[low_idx] = new_low_option; solution_idx[low_idx] = best_low_new_idx; solution[high_idx] = new_high_option; @@ -558,7 +536,7 @@ std::tuple>, std::vector, float, ui } } - if (new_max_err < current_max_exp_error * 1.2f) { + if (new_max_err < current_max_exp_error * (1 + 0.1f * bpw_balance_factor)) { current_cost = current_cost - std::get<0>(solution[i]) + std::get<0>(new_option); solution[i] = new_option; solution_idx[i] = n; @@ -639,16 +617,7 @@ std::tuple>, std::vector, float, ui } } - float max_err_increase = targeted_redistribution_max_err_increase_initial - - (targeted_redistribution_max_err_increase_initial - targeted_redistribution_max_err_increase_final) * - (static_cast(iter) / (num_slots * 3)); - - // Relaxed constraint for early layers - if (i < num_slots / 2) { - max_err_increase *= bpw_uniformity_factor; - } - - if (new_max_err < current_max_exp_error * max_err_increase) { + if (new_max_err < current_max_exp_error * targeted_redistribution_max_err_increase) { current_cost = new_cost; solution[i] = new_option; solution_idx[i] = best_new_idx; From 165c90924137e86f15024e653085aaa379f7152f Mon Sep 17 00:00:00 2001 From: imoc Date: Fri, 10 Jan 2025 16:32:58 +0800 Subject: [PATCH 16/17] improvement v3-5-f parameterize on a scale of higher min(bpw) or lower exp. 
error --- exllamav2/exllamav2_ext/ext_quant.cpp | 91 ++++++++++++++++++++++----- 1 file changed, 76 insertions(+), 15 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index dc476516..287708b8 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -173,21 +173,82 @@ std::tuple>, std::vector, float, ui float norm ) { - // --- Enhanced Parameters --- - const int redistribution_iterations = 50; - const float bpw_penalty_scale = 0.6f; // Stronger penalty for low BPW - const float min_bpw_base = 3.3f; // Higher base minimum BPW, we want higher bpw - const int opportunistic_iterations = 30000; - const float initial_opportunistic_temp = 0.12f; - const float low_error_threshold = 0.002f; - const float error_floor = 0.0005f; - const float targeted_redistribution_bpw_threshold = 3.6f; - const float targeted_redistribution_max_err_increase = 1.5f; // Increased tolerance for error increase in targeted redistribution - const float high_bpw_donor_threshold = 5.5f; - const int num_options_to_explore_per_layer = 8; - const int bpw_smoothing_passes = 8; - const float bpw_smoothing_threshold = 0.75f; - const float bpw_balance_factor = 1.8f; // Control trade-off between BPW uniformity and error + // --- Mode-Specific Parameters --- + enum Mode { MODE_BALANCED, MODE_UNIFORM, MODE_AGGRESSIVE, MODE_3_5_2, MODE_3_5_6, MODE_CUSTOM }; + // --- Mode Selection --- + Mode mode = MODE_3_5_2; // Default mode, Can be changed into other mode or MODE_CUSTOM + + // Define a struct to hold parameters for different modes + struct ModeParams { + float bpw_penalty_scale; + float min_bpw_base; + float opportunistic_temp; + float error_floor; + float targeted_redistribution_max_err_increase; + float high_bpw_donor_threshold; + float bpw_balance_factor; + float low_error_threshold; + int redistribution_iterations; + int opportunistic_iterations; + int num_options_to_explore_per_layer; + int 
bpw_smoothing_passes; + float bpw_smoothing_threshold; + float targeted_redistribution_bpw_threshold; + }; + + // Define the parameter sets for each mode + const std::vector mode_params = { + // MODE_BALANCED: Balanced trade-off between BPW uniformity and error + {0.6f, 3.2f, 0.1f, 0.0001f, 1.3f, 5.5f, 1.5f, 0.001f, 60, 30000, 8, 8, 0.75f, 3.5f}, + + // MODE_UNIFORM: Strong emphasis on BPW uniformity + {0.8f, 3.5f, 0.12f, 0.0005f, 1.5f, 6.0f, 3.0f, 0.001f, 80, 40000, 8, 10, 0.8f, 3.7f}, + + // MODE_AGGRESSIVE: Aggressively avoids low BPW, potentially higher error + {1.0f, 3.8f, 0.15f, 0.001f, 1.6f, 6.5f, 4.0f, 0.001f, 100, 50000, 8, 12, 0.9f, 3.9f}, + + // MODE_3_5_2: Approximates the behavior of Version 3-5-2 + {0.1f, 3.0f, 0.05f, 0.0f, 1.2f, 5.0f, 0.1f, 0.0009f, 25, 15000, 3, 5, 0.5f, 3.3f}, + + // MODE_3_5_6: Replicates the behavior of Version 3-5-6 + {0.6f, 3.3f, 0.12f, 0.0005f, 1.5f, 5.5f, 1.8f, 0.002f, 50, 30000, 8, 8, 0.75f, 3.6f}, + + // MODE_CUSTOM: User-defined parameters, will be overwritten if using custom mode + {0.6f, 3.2f, 0.1f, 0.0001f, 1.3f, 5.5f, 1.5f, 0.001f, 60, 30000, 8, 8, 0.75f, 3.5f} + }; + + ModeParams params; + if (mode == MODE_CUSTOM) + { + params = {0.7f, 3.3f, 0.11f, 0.0002f, 1.35f, 5.7f, 2.0f, 0.001f, 70, 35000, 8, 9, 0.8f, 3.6f}; // Example custom parameters, you should change this + } else { + params = mode_params[mode]; + } + + // --- Parameter Application --- + // (Consolidated parameters are grouped together) + + // Penalty-related parameters + const float bpw_penalty_scale = params.bpw_penalty_scale; + const float min_bpw_base = params.min_bpw_base; + const float bpw_balance_factor = params.bpw_balance_factor; + + // Redistribution-related parameters + const int redistribution_iterations = params.redistribution_iterations; + const float targeted_redistribution_bpw_threshold = params.targeted_redistribution_bpw_threshold; + const float targeted_redistribution_max_err_increase = params.targeted_redistribution_max_err_increase; + 
const float high_bpw_donor_threshold = params.high_bpw_donor_threshold; + const int num_options_to_explore_per_layer = params.num_options_to_explore_per_layer; + + // Opportunistic optimization parameters + const int opportunistic_iterations = params.opportunistic_iterations; + const float initial_opportunistic_temp = params.opportunistic_temp; + const float low_error_threshold = params.low_error_threshold; + + // Other parameters + const float error_floor = params.error_floor; + const int bpw_smoothing_passes = params.bpw_smoothing_passes; + const float bpw_smoothing_threshold = params.bpw_smoothing_threshold; // --- Dynamic Minimum BPW --- auto calculate_dynamic_min_bpw = [&](float target_bpw, float temp_ratio) { From baaa786a920f1b6cd900345638e359e0908f37c8 Mon Sep 17 00:00:00 2001 From: imoc Date: Mon, 13 Jan 2025 16:21:48 +0800 Subject: [PATCH 17/17] update modes for v3-5-f --- exllamav2/exllamav2_ext/ext_quant.cpp | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/exllamav2/exllamav2_ext/ext_quant.cpp b/exllamav2/exllamav2_ext/ext_quant.cpp index 287708b8..961ca6a9 100644 --- a/exllamav2/exllamav2_ext/ext_quant.cpp +++ b/exllamav2/exllamav2_ext/ext_quant.cpp @@ -174,9 +174,9 @@ std::tuple>, std::vector, float, ui ) { // --- Mode-Specific Parameters --- - enum Mode { MODE_BALANCED, MODE_UNIFORM, MODE_AGGRESSIVE, MODE_3_5_2, MODE_3_5_6, MODE_CUSTOM }; + enum Mode { MODE_RELAXED, MODE_BALANCED, MODE_UNIFORM, MODE_AGGRESSIVE, MODE_3_5_2, MODE_3_5_6, MODE_CUSTOM }; // --- Mode Selection --- - Mode mode = MODE_3_5_2; // Default mode, Can be changed into other mode or MODE_CUSTOM + Mode mode = MODE_UNIFORM; // Default mode, Can be changed into other mode or MODE_CUSTOM // Define a struct to hold parameters for different modes struct ModeParams { @@ -198,23 +198,20 @@ std::tuple>, std::vector, float, ui // Define the parameter sets for each mode const std::vector mode_params = { + // MODE_RELAXED: Minimize error first + {0.1f, 
3.0f, 0.05f, 0.0f, 1.2f, 5.0f, 0.1f, 0.0009f, 25, 15000, 3, 5, 0.5f, 3.3f}, + // MODE_BALANCED: Balanced trade-off between BPW uniformity and error - {0.6f, 3.2f, 0.1f, 0.0001f, 1.3f, 5.5f, 1.5f, 0.001f, 60, 30000, 8, 8, 0.75f, 3.5f}, + {0.6f, 3.3f, 0.12f, 0.0005f, 1.5f, 5.5f, 1.8f, 0.002f, 50, 30000, 8, 8, 0.75f, 3.6f}, // MODE_UNIFORM: Strong emphasis on BPW uniformity - {0.8f, 3.5f, 0.12f, 0.0005f, 1.5f, 6.0f, 3.0f, 0.001f, 80, 40000, 8, 10, 0.8f, 3.7f}, + {0.8f, 3.5f, 0.12f, 0.0005f, 1.6f, 6.0f, 3.0f, 0.001f, 80, 40000, 8, 10, 0.8f, 3.7f}, // MODE_AGGRESSIVE: Aggressively avoids low BPW, potentially higher error - {1.0f, 3.8f, 0.15f, 0.001f, 1.6f, 6.5f, 4.0f, 0.001f, 100, 50000, 8, 12, 0.9f, 3.9f}, - - // MODE_3_5_2: Approximates the behavior of Version 3-5-2 - {0.1f, 3.0f, 0.05f, 0.0f, 1.2f, 5.0f, 0.1f, 0.0009f, 25, 15000, 3, 5, 0.5f, 3.3f}, - - // MODE_3_5_6: Replicates the behavior of Version 3-5-6 - {0.6f, 3.3f, 0.12f, 0.0005f, 1.5f, 5.5f, 1.8f, 0.002f, 50, 30000, 8, 8, 0.75f, 3.6f}, + {1.0f, 3.8f, 0.15f, 0.001f, 1.7f, 6.5f, 4.0f, 0.001f, 100, 50000, 8, 12, 0.9f, 3.9f}, // MODE_CUSTOM: User-defined parameters, will be overwritten if using custom mode - {0.6f, 3.2f, 0.1f, 0.0001f, 1.3f, 5.5f, 1.5f, 0.001f, 60, 30000, 8, 8, 0.75f, 3.5f} + {0.8f, 5.0f, 0.12f, 0.0005f, 1.5f, 6.0f, 3.0f, 0.001f, 80, 40000, 8, 10, 0.8f, 5.5f}, }; ModeParams params;