@@ -1342,9 +1342,6 @@ __device__ inline void MarlinMoESingle(
1342
1342
1343
1343
template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
1344
1344
const int threads, // number of threads in a threadblock
1345
- const int thread_m_blocks, // number of 16x16 blocks in the m
1346
- // dimension (batchsize) of the
1347
- // threadblock
1348
1345
const int thread_n_blocks, // same for n dimension (output)
1349
1346
const int thread_k_blocks, // same for k dimension (reduction)
1350
1347
const int stages, // number of stages for the async global->shared
@@ -1459,9 +1456,6 @@ __global__ void compute_expert_offsets(int const* __restrict__ topk_ids,
1459
1456
1460
1457
template <const vllm::ScalarTypeId w_type_id, // weight ScalarType id
1461
1458
const int threads, // number of threads in a threadblock
1462
- const int thread_m_blocks, // number of 16x16 blocks in the m
1463
- // dimension (batchsize) of the
1464
- // threadblock
1465
1459
const int thread_n_blocks, // same for n dimension (output)
1466
1460
const int thread_k_blocks, // same for k dimension (reduction)
1467
1461
const int stages, // number of stages for the async global->shared
@@ -1515,26 +1509,24 @@ const int STAGES = 4; // 4 pipeline stages fit into shared memory
1515
1509
static constexpr int min_thread_n = 64 ;
1516
1510
static constexpr int min_thread_k = 64 ;
1517
1511
1518
- #define __CALL_IF_MOE (W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
1519
- THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, \
1520
- NUM_THREADS) \
1521
- else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
1522
- thread_n_blocks == THREAD_N_BLOCKS && \
1523
- thread_k_blocks == THREAD_K_BLOCKS && \
1524
- has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
1525
- num_threads == NUM_THREADS) { \
1526
- cudaFuncSetAttribute ( \
1527
- MarlinMoE<W_TYPE.id (), NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
1528
- THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \
1529
- cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
1530
- MarlinMoE<W_TYPE.id (), NUM_THREADS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
1531
- THREAD_K_BLOCKS, STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \
1532
- <<<blocks, NUM_THREADS, max_shared_mem, stream>>> ( \
1533
- A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
1534
- g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
1535
- num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
1536
- replicate_input, apply_weights, m_block, max_par, \
1537
- exec_cfg.max_m_blocks ); \
1512
+ #define __CALL_IF_MOE (W_TYPE, THREAD_N_BLOCKS, THREAD_K_BLOCKS, HAS_ACT_ORDER, \
1513
+ GROUP_BLOCKS, NUM_THREADS) \
1514
+ else if (q_type == W_TYPE && thread_n_blocks == THREAD_N_BLOCKS && \
1515
+ thread_k_blocks == THREAD_K_BLOCKS && \
1516
+ has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
1517
+ num_threads == NUM_THREADS) { \
1518
+ cudaFuncSetAttribute ( \
1519
+ MarlinMoE<W_TYPE.id (), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
1520
+ STAGES, HAS_ACT_ORDER, GROUP_BLOCKS>, \
1521
+ cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
1522
+ MarlinMoE<W_TYPE.id (), NUM_THREADS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
1523
+ STAGES, HAS_ACT_ORDER, GROUP_BLOCKS> \
1524
+ <<<blocks, NUM_THREADS, max_shared_mem, stream>>> ( \
1525
+ A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \
1526
+ g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \
1527
+ num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \
1528
+ replicate_input, apply_weights, m_block, max_par, \
1529
+ exec_cfg.max_m_blocks ); \
1538
1530
}
1539
1531
1540
1532
typedef struct {
@@ -1711,31 +1703,16 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
1711
1703
return exec_config_t {0 , {-1 , -1 , -1 }};
1712
1704
}
1713
1705
1714
- #define CALL_IF_MOE (W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS ) \
1715
- __CALL_IF_MOE (W_TYPE, 1 , N_BLOCKS, K_BLOCKS, true , 0 , NUM_THREADS) \
1716
- __CALL_IF_MOE (W_TYPE, 2 , N_BLOCKS, K_BLOCKS, true , 0 , NUM_THREADS) \
1717
- __CALL_IF_MOE (W_TYPE, 3 , N_BLOCKS, K_BLOCKS, true , 0 , NUM_THREADS) \
1718
- __CALL_IF_MOE (W_TYPE, 4 , N_BLOCKS, K_BLOCKS, true , 0 , NUM_THREADS) \
1719
- \
1720
- __CALL_IF_MOE (W_TYPE, 1 , N_BLOCKS, K_BLOCKS, false , -1 , NUM_THREADS) \
1721
- __CALL_IF_MOE (W_TYPE, 1 , N_BLOCKS, K_BLOCKS, false , 2 , NUM_THREADS) \
1722
- __CALL_IF_MOE (W_TYPE, 1 , N_BLOCKS, K_BLOCKS, false , 4 , NUM_THREADS) \
1723
- __CALL_IF_MOE (W_TYPE, 1 , N_BLOCKS, K_BLOCKS, false , 8 , NUM_THREADS) \
1724
- \
1725
- __CALL_IF_MOE (W_TYPE, 2 , N_BLOCKS, K_BLOCKS, false , -1 , NUM_THREADS) \
1726
- __CALL_IF_MOE (W_TYPE, 2 , N_BLOCKS, K_BLOCKS, false , 2 , NUM_THREADS) \
1727
- __CALL_IF_MOE (W_TYPE, 2 , N_BLOCKS, K_BLOCKS, false , 4 , NUM_THREADS) \
1728
- __CALL_IF_MOE (W_TYPE, 2 , N_BLOCKS, K_BLOCKS, false , 8 , NUM_THREADS) \
1729
- \
1730
- __CALL_IF_MOE (W_TYPE, 3 , N_BLOCKS, K_BLOCKS, false , -1 , NUM_THREADS) \
1731
- __CALL_IF_MOE (W_TYPE, 3 , N_BLOCKS, K_BLOCKS, false , 2 , NUM_THREADS) \
1732
- __CALL_IF_MOE (W_TYPE, 3 , N_BLOCKS, K_BLOCKS, false , 4 , NUM_THREADS) \
1733
- __CALL_IF_MOE (W_TYPE, 3 , N_BLOCKS, K_BLOCKS, false , 8 , NUM_THREADS) \
1734
- \
1735
- __CALL_IF_MOE (W_TYPE, 4 , N_BLOCKS, K_BLOCKS, false , -1 , NUM_THREADS) \
1736
- __CALL_IF_MOE (W_TYPE, 4 , N_BLOCKS, K_BLOCKS, false , 2 , NUM_THREADS) \
1737
- __CALL_IF_MOE (W_TYPE, 4 , N_BLOCKS, K_BLOCKS, false , 4 , NUM_THREADS) \
1738
- __CALL_IF_MOE (W_TYPE, 4 , N_BLOCKS, K_BLOCKS, false , 8 , NUM_THREADS)
1706
+ #define CALL_IF_MOE (W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS ) \
1707
+ __CALL_IF_MOE (W_TYPE, N_BLOCKS, K_BLOCKS, true , 0 , NUM_THREADS) \
1708
+ __CALL_IF_MOE (W_TYPE, N_BLOCKS, K_BLOCKS, true , 0 , NUM_THREADS) \
1709
+ __CALL_IF_MOE (W_TYPE, N_BLOCKS, K_BLOCKS, true , 0 , NUM_THREADS) \
1710
+ __CALL_IF_MOE (W_TYPE, N_BLOCKS, K_BLOCKS, true , 0 , NUM_THREADS) \
1711
+ \
1712
+ __CALL_IF_MOE (W_TYPE, N_BLOCKS, K_BLOCKS, false , -1 , NUM_THREADS) \
1713
+ __CALL_IF_MOE (W_TYPE, N_BLOCKS, K_BLOCKS, false , 2 , NUM_THREADS) \
1714
+ __CALL_IF_MOE (W_TYPE, N_BLOCKS, K_BLOCKS, false , 4 , NUM_THREADS) \
1715
+ __CALL_IF_MOE (W_TYPE, N_BLOCKS, K_BLOCKS, false , 8 , NUM_THREADS)
1739
1716
1740
1717
void marlin_mm_moe_f16i4 (const void * A, const void * B, void * C,
1741
1718
const void * sorted_ids, const void * topk_weights,
0 commit comments