CUDA: batched+noncont MMQ, refactor bs>1 MoE code (#13199)

Author:    Johannes Gäßler
Date:      2025-04-30 23:12:59 +02:00
Committed: via GitHub
Commit:    e1e8e0991f (parent 6f67cf1f48)
9 changed files with 869 additions and 440 deletions
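
The diff below generalizes mul_mat_q to batched, non-contiguous inputs and routes MoE (batch size > 1) write-back through two index buffers: expert_bounds, where entries zt and zt+1 delimit the columns of the sorted activation buffer belonging to channel/expert zt, and ids_dst, which maps each sorted column back to its destination column in dst. As a rough orientation, a minimal host-side sketch of how such buffers could be laid out for a top-1 routing follows; build_moe_ids and its exact layout are hypothetical illustrations, not code from this commit.

```cpp
// Hypothetical host-side sketch of the index buffers consumed by the refactored
// mul_mat_q kernel for MoE:
//   - expert_bounds[e]..expert_bounds[e+1] delimits the sorted columns of expert e,
//   - ids_dst[expert_bounds[e] + s] is the destination column for the s-th token
//     routed to expert e.
#include <cstdint>
#include <vector>

struct moe_ids {
    std::vector<int32_t> ids_dst;       // sorted position -> destination column in dst
    std::vector<int32_t> expert_bounds; // prefix sums, size n_experts + 1
};

static moe_ids build_moe_ids(const std::vector<int32_t> & token_expert, int n_experts) {
    moe_ids out;
    out.expert_bounds.assign(n_experts + 1, 0);
    for (int32_t e : token_expert) {
        out.expert_bounds[e + 1]++;                       // histogram of tokens per expert
    }
    for (int e = 0; e < n_experts; ++e) {
        out.expert_bounds[e + 1] += out.expert_bounds[e]; // prefix sum -> bounds
    }
    out.ids_dst.resize(token_expert.size());
    std::vector<int32_t> cursor(out.expert_bounds.begin(), out.expert_bounds.end() - 1);
    for (int32_t t = 0; t < (int32_t) token_expert.size(); ++t) {
        out.ids_dst[cursor[token_expert[t]]++] = t;       // scatter token into its expert's range
    }
    return out;
}
```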


@@ -13,9 +13,10 @@ using namespace ggml_cuda_mma;
#define MMQ_ITER_K 256
#define MMQ_NWARPS 8
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int & kbx0, const int & i_max, const int & stride);
typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00);
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max);
typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
typedef void (*vec_dot_mmq_t)(const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00);
typedef void (*mmq_write_back_t)(const float * __restrict__ sum, const int32_t * __restrict__ get_rows_to_sorted,
float * __restrict__ dst, const int stride, const int i_max, const int j_max);
enum mmq_q8_1_ds_layout {
MMQ_Q8_1_DS_LAYOUT_D4,
@@ -233,7 +234,7 @@ static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */
// ------------------------------------------------------------
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -289,7 +290,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y);
const int * x_qs = (const int *) x;
@@ -328,7 +329,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a(
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -384,7 +385,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y);
const int * x_qs = (const int *) x;
@@ -423,7 +424,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a(
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -495,7 +496,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -565,7 +566,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -621,7 +622,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y);
const int * x_qs = (const int *) x;
@@ -651,7 +652,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
template <int mmq_x, int mmq_y, int nwarps, mmq_q8_1_ds_layout ds_layout>
static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
typedef tile<16, 8, int> tile_A;
typedef tile< 8, 8, int> tile_B;
@@ -732,7 +733,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y);
const int * x_qs = (const int *) x;
@@ -762,7 +763,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
typedef tile<16, 8, int> tile_A;
typedef tile< 8, 8, int> tile_B;
@@ -839,7 +840,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16;
const int * x_qs = (const int *) x;
@@ -871,7 +872,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#ifdef NEW_MMA_AVAILABLE
typedef tile<16, 4, int> tile_A;
@@ -955,7 +956,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -1011,7 +1012,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y);
const int * x_qs = (const int *) x;
@@ -1074,7 +1075,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#ifdef NEW_MMA_AVAILABLE
typedef tile<16, 4, int> tile_A;
@@ -1201,7 +1202,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q3_K(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -1298,7 +1299,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y);
const int * x_qs = (const int *) x;
@@ -1340,7 +1341,7 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -1437,7 +1438,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y);
const int * x_qs = (const int *) x;
@@ -1469,7 +1470,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a(
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -1578,7 +1579,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y);
const int * x_qs = (const int *) x;
@@ -1610,7 +1611,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a(
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -1693,7 +1694,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y);
const int * x_qs = (const int *) x;
@@ -1726,7 +1727,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a(
template <int mmq_x, int mmq_y, int nwarps>
static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int & k00) {
const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
#ifdef NEW_MMA_AVAILABLE
typedef tile<16, 4, int> tile_A;
@@ -1835,7 +1836,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_nl(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -1893,7 +1894,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -1951,7 +1952,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xs(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -2007,7 +2008,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_s(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -2070,7 +2071,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_xxs(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -2126,7 +2127,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq3_s(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -2189,7 +2190,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq1_s(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -2245,7 +2246,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
}
template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_xs(
const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) {
#ifdef NEW_MMA_AVAILABLE
int * x_qs = (int *) x_tile;
@@ -2306,8 +2307,8 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_dp4a(
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst,
const int stride, const int i_max, const int j_max) {
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int j = j0 + threadIdx.y;
@@ -2324,15 +2325,15 @@ static __device__ __forceinline__ void mmq_write_back_dp4a(
continue;
}
dst[j*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
}
}
}
template<int mmq_x, int mmq_y, int nwarps, bool need_check>
static __device__ __forceinline__ void mmq_write_back_mma(
const float * __restrict__ sum, float * __restrict__ dst, const int & stride, const int & i_max, const int & j_max) {
const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst,
const int stride, const int i_max, const int j_max) {
typedef tile<16, 8, int> tile_C;
constexpr int granularity = mmq_get_granularity_device(mmq_x);
@@ -2362,7 +2363,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
continue;
}
dst[j*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
dst[ids_dst[j]*stride + i] = sum[(j0/tile_C::J + n)*tile_C::ne + l];
}
}
}
@@ -2518,17 +2519,18 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
};
template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
static __device__ void mul_mat_q_process_tile(
const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int & ne00, const int & ne01, const int & stride01, const int & ne10, const int & ne11, const int & stride11, const int & ne0,
const int & it, const int & jt, const int & kb0_start, const int & kb0_stop) {
static __device__ __forceinline__ void mul_mat_q_process_tile(
const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int mmq_y = get_mmq_y_device();
constexpr load_tiles_mmq_t load_tiles = mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, type>::load_tiles;
extern __shared__ char data_mul_mat_q[];
int * tile_y = (int *) data_mul_mat_q;
extern __shared__ int data_mul_mat_q[];
int * tile_y = data_mul_mat_q + mmq_x;
int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE);
#ifdef NEW_MMA_AVAILABLE
@@ -2543,16 +2545,11 @@ static __device__ void mul_mat_q_process_tile(
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
const int tile_x_max_i = ne01 - it*mmq_y - 1;
const int tile_y_max_j = ne11 - jt*mmq_x - 1;
const int * y = (const int *) yc + jt*(mmq_x*sizeof(block_q8_1_mmq)/sizeof(int));
for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
load_tiles(x, tile_x, stride01*it*mmq_y + kb0, tile_x_max_i, stride01);
load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
{
const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
#pragma unroll
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2568,7 +2565,7 @@ static __device__ void mul_mat_q_process_tile(
__syncthreads();
{
const int * by0 = y + stride11*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
#pragma unroll
for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) {
int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x;
@@ -2585,12 +2582,10 @@ static __device__ void mul_mat_q_process_tile(
}
if (fixup) {
write_back(sum, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
write_back(sum, ids_dst, tmp_fixup + blockIdx.x*(mmq_x*mmq_y), mmq_y, mmq_y, mmq_x);
} else {
write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
write_back(sum, ids_dst, dst, stride_col_dst, tile_x_max_i, tile_y_max_j);
}
GGML_UNUSED(ne00); GGML_UNUSED(ne10);
}
@@ -2609,8 +2604,11 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
static __global__ void mul_mat_q(
const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int ne00, const int ne01, const int stride01, const int ne10, const int ne11, const int stride11, const int ne0) {
const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
const int ncols_x, const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
// Skip unused template specializations for faster compilation:
if (mmq_x > get_mmq_x_max_device() || mmq_x % mmq_get_granularity_device(mmq_x) != 0) {
@@ -2621,26 +2619,85 @@ static __global__ void mul_mat_q(
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int mmq_y = get_mmq_y_device();
const int ntx = (ncols_y + mmq_x - 1) / mmq_x; // Number of tiles x
const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
// Initialize the ids for writing back data with just the index.
// For regular matrix multiplications this is never changed.
// For MoE the correct indices are loaded from ids_dst.
extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory.
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
break;
}
ids_dst_shared[j] = j;
}
// On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
{
const int wt = blockIdx.z / nchannels_y;
const int zt = blockIdx.z - wt*nchannels_y;
const int jt = blockIdx.y;
const int it = blockIdx.x;
// Defaults for regular matrix multiplication:
int col_low = 0;
int col_high = ncols_y;
int col_diff = ncols_y;
int offset_y = wt*stride_sample_y + zt*stride_channel_y;
int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
if (ids_dst) {
col_low = expert_bounds[zt + 0];
col_high = expert_bounds[zt + 1];
col_diff = col_high - col_low;
offset_y = 0;
offset_dst = 0;
if (jt*mmq_x >= col_diff) {
return;
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
break;
}
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
}
}
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
offset_dst += it*mmq_y;
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
constexpr bool fixup = false;
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
(x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
blockIdx.x, blockIdx.y, 0, ne00/qk);
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
return;
}
#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
const int64_t blocks_per_ne00 = ne00 / qk;
const int64_t blocks_per_ne00 = ncols_x / qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
const int ntx = (ne11 + mmq_x - 1) / mmq_x; // Number of tiles x
const int nty = (ne01 + mmq_y - 1) / mmq_y; // Number of tiles y
// kbc == k block continuous, current index in continuous ijk space.
int64_t kbc = (int64_t) blockIdx.x *blocks_per_ne00*ntx*nty / gridDim.x;
int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*blocks_per_ne00*ntx*nty / gridDim.x;
int64_t kbc = (int64_t) blockIdx.x *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
int64_t kbc_stop = (int64_t)(blockIdx.x + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
kbc_stop -= (kbc_stop % blocks_per_ne00) % blocks_per_iter;
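
Restating the stream-k split computed just above: each CUDA block gets a contiguous slice of the linearized work space (samples_y × channels_y × column tiles × row tiles × k-blocks), with the slice boundaries snapped down so a tile's k range is only split at MMQ_ITER_K boundaries. A minimal sketch, assuming plain C++ host types (the helper name is illustrative):

```cpp
// Sketch of the stream-k work split: block bidx covers
// [stream_k_start(bidx, ...), stream_k_start(bidx + 1, ...)).
#include <cstdint>

static int64_t stream_k_start(int64_t bidx, int64_t gridDim_x,
                              int64_t nsamples_y, int64_t nchannels_y,
                              int64_t ntx, int64_t nty,
                              int64_t blocks_per_ne00, int64_t blocks_per_iter) {
    int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim_x;
    kbc -= (kbc % blocks_per_ne00) % blocks_per_iter; // snap down to an iteration boundary
    return kbc;
}
```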
@@ -2649,13 +2706,64 @@ static __global__ void mul_mat_q(
int kb0_start = kbc % blocks_per_ne00;
int kb0_stop = min(blocks_per_ne00, kb0_start + kbc_stop - kbc);
while (kbc < kbc_stop && kb0_stop == blocks_per_ne00) {
const int jt = kbc / (blocks_per_ne00*nty); // j index of current tile.
const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00; // i index of current tile.
int tmp = kbc;
const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
const int zt = tmp / (ntx*blocks_per_ne00);
tmp -= zt * (ntx*blocks_per_ne00);
const int jt = tmp / blocks_per_ne00;
// Defaults for regular matrix multiplication:
int col_low = 0;
int col_high = ncols_y;
int col_diff = ncols_y;
int offset_y = wt*stride_sample_y + zt*stride_channel_y;
int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
if (ids_dst) {
col_low = expert_bounds[zt + 0];
col_high = expert_bounds[zt + 1];
col_diff = col_high - col_low;
offset_y = 0;
offset_dst = 0;
if (jt*mmq_x >= col_diff) {
kbc += blocks_per_ne00;
kbc -= kbc % blocks_per_ne00;
kb0_start = 0;
kb0_stop = min(blocks_per_ne00, kbc_stop - kbc);
continue;
}
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
break;
}
ids_dst_shared[j] = ids_dst[col_low + jt*mmq_x + j];
}
}
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
offset_dst += it*mmq_y;
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
(x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
it, jt, kb0_start, kb0_stop);
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
kbc += blocks_per_ne00;
kbc -= kbc % blocks_per_ne00;
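
The loop above recovers the tile coordinates from the linear index kbc, outermost to innermost: row tile it, sample wt, channel/expert zt, column tile jt, with the remainder giving the k-block offset. The same arithmetic as an isolated sketch (helper and struct names are illustrative):

```cpp
// Plain C++ restatement of the kbc decomposition used in the kernel.
#include <cstdint>

struct tile_coords { int it, wt, zt, jt; int64_t kb0; };

static tile_coords decompose_kbc(int64_t kbc, int nsamples_y, int nchannels_y,
                                 int ntx, int64_t blocks_per_ne00) {
    int64_t tmp = kbc;
    const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00); // row tile of x
    tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
    const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);            // sample
    tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
    const int zt = tmp / (ntx*blocks_per_ne00);                        // channel/expert
    tmp -= zt * (ntx*blocks_per_ne00);
    const int jt = tmp / blocks_per_ne00;                              // column tile of y
    return { it, wt, zt, jt, tmp - jt*blocks_per_ne00 };               // remainder: k-block offset
}
```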
@@ -2668,55 +2776,106 @@ static __global__ void mul_mat_q(
return;
}
const int jt = kbc / (blocks_per_ne00*nty);
const int it = (kbc - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
int tmp = kbc;
const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
const int zt = tmp / (ntx*blocks_per_ne00);
tmp -= zt * (ntx*blocks_per_ne00);
const int jt = tmp / blocks_per_ne00;
// Defaults for regular matrix multiplication:
int col_low = 0;
int col_high = ncols_y;
int col_diff = ncols_y;
int offset_y = wt*stride_sample_y + zt*stride_channel_y;
int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
if (ids_dst) {
col_low = expert_bounds[zt + 0];
col_high = expert_bounds[zt + 1];
col_diff = col_high - col_low;
offset_y = 0;
offset_dst = 0;
if (jt*mmq_x >= col_diff) {
return;
}
// The memory layout for the fixup buffer is always contiguous, therefore reset ids:
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) {
const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x;
if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) {
break;
}
ids_dst_shared[j] = j;
}
}
offset_y += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
offset_dst += it*mmq_y;
const int tile_x_max_i = nrows_x - it*mmq_y - 1;
const int tile_y_max_j = col_diff - jt*mmq_x - 1;
const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x;
constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
(x, yc, dst, tmp_fixup, ne00, ne01, stride01, ne10, ne11, stride11, ne0,
it, jt, kb0_start, kb0_stop);
(x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
}
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
static __global__ void mul_mat_q_stream_k_fixup(
float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ne00, const int ne01, const int ne11, const int ne0, const int block_num_mmq) {
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
const int ncols_x, const int nrows_x, const int ncols_y, const int stride_col_dst,
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
constexpr int mmq_y = get_mmq_y_device();
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int blocks_per_iter = MMQ_ITER_K / qk;
const int64_t blocks_per_ne00 = ne00 / qk;
const int64_t blocks_per_ne00 = ncols_x / qk;
float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
const int ntx = (ne11 + mmq_x - 1) / mmq_x;
const int nty = (ne01 + mmq_y - 1) / mmq_y;
const int ntx = (ncols_y + mmq_x - 1) / mmq_x;
const int nty = (nrows_x + mmq_y - 1) / mmq_y;
const int bidx0 = blockIdx.x;
// kbc == k block continuous, current index in continuous ijk space.
int64_t kbc0 = (int64_t) bidx0 *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
int64_t kbc0_stop = (int64_t)(bidx0 + 1)*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
kbc0 -= (kbc0 % blocks_per_ne00) % blocks_per_iter;
kbc0_stop -= (kbc0_stop % blocks_per_ne00) % blocks_per_iter;
const bool did_not_have_any_data = kbc0 == kbc0_stop;
const bool wrote_beginning_of_tile = kbc0 % blocks_per_ne00 == 0;
const bool did_not_write_last = kbc0/blocks_per_ne00 == kbc0_stop/blocks_per_ne00 && kbc0_stop % blocks_per_ne00 != 0;
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
return;
}
bool any_fixup = false;
const int bidx_start = ((blockIdx.y*nty + blockIdx.x) * block_num_mmq) / (gridDim.y*gridDim.x);
const int bidx_stop = ((blockIdx.y*nty + blockIdx.x + 1) * block_num_mmq + gridDim.y*gridDim.x - 1) / (gridDim.y*gridDim.x);
// Iterate over previous blocks and sum up partial sums written to fixup buffer.
// All CUDA blocks that get here must have a previous block that needs a fixup.
int64_t bidx = bidx0 - 1;
int64_t kbc_stop = kbc0;
while(true) {
int64_t kbc = bidx*nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
kbc -= (kbc % blocks_per_ne00) % blocks_per_iter;
int64_t kbc_0;
int64_t kbc_stop_0 = (int64_t) bidx_start*blocks_per_ne00*ntx*nty / block_num_mmq;
for (int bidx = bidx_start; bidx < bidx_stop; ++bidx) {
kbc_0 = kbc_stop_0;
kbc_stop_0 = (int64_t) (bidx + 1)*blocks_per_ne00*ntx*nty / block_num_mmq;
const int64_t kbc = kbc_0 - (kbc_0 % blocks_per_ne00) % blocks_per_iter;
const int64_t kbc_stop = kbc_stop_0 - (kbc_stop_0 % blocks_per_ne00) % blocks_per_iter;
// Skip fixup tile if the MMQ CUDA block never wrote anything to it:
if (kbc == kbc_stop || kbc_stop % blocks_per_ne00 == 0) {
continue;
}
const int jt = kbc_stop / (blocks_per_ne00*nty);
const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
// Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
if ((unsigned)it != blockIdx.x || (unsigned)jt != blockIdx.y) {
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
continue;
}
@@ -2733,16 +2892,71 @@ static __global__ void mul_mat_q_stream_k_fixup(
sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i];
}
}
// If this block started in a previous tile we are done and don't need to combine additional partial results.
if (kbc % blocks_per_ne00 == 0 || kbc/blocks_per_ne00 < kbc0/blocks_per_ne00) {
break;
}
bidx--;
kbc_stop = kbc;
}
if (!any_fixup) {
return;
}
dst += blockIdx.y*mmq_x*ne0 + blockIdx.x*mmq_y;
int tmp = kbc0;
const int it = tmp / (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
tmp -= it * (nsamples_y*nchannels_y*ntx*blocks_per_ne00);
const int wt = tmp / (nchannels_y*ntx*blocks_per_ne00);
tmp -= wt * (nchannels_y*ntx*blocks_per_ne00);
const int zt = tmp / (ntx*blocks_per_ne00);
tmp -= zt * (ntx*blocks_per_ne00);
const int jt = tmp / blocks_per_ne00;
const int i_max = ne01 - blockIdx.x*mmq_y - 1;
const int j_max = ne11 - blockIdx.y*mmq_x - 1;
if (!ids_dst) {
const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
dst += offset_dst;
const int i_max = nrows_x - it*mmq_y - 1;
const int j_max = ncols_y - jt*mmq_x - 1;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j > j_max) {
return;
}
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (need_check && i > i_max) {
continue;
}
dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
}
}
return;
}
__shared__ int ids_dst_shared[mmq_x];
const int col_low = expert_bounds[zt + 0];
const int col_high = expert_bounds[zt + 1];
const int col_diff = col_high - col_low;
for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) {
ids_dst_shared[j] = ids_dst[col_low + j];
}
const int offset_dst = it*mmq_y;
dst += offset_dst;
const int i_max = nrows_x - it*mmq_y - 1;
const int j_max = col_diff - jt*mmq_x - 1;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
@@ -2760,26 +2974,27 @@ static __global__ void mul_mat_q_stream_k_fixup(
continue;
}
dst[j*ne0 + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE];
}
}
}
struct mmq_args {
const char * x; const char * y; float * dst;
int64_t ne00; int64_t ne01; int64_t stride01;
int64_t ne10; int64_t ne11; int64_t stride11;
int64_t ne0;
const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
int64_t ncols_x; int64_t nrows_x; int64_t ncols_y; int64_t stride_row_x; int64_t nrows_dst;
int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
bool use_stream_k;
};
template<ggml_type type>
static int mmq_get_shmem(const int mmq_x, const int mmq_y, const int cc) {
static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) {
const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y);
const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
const int shmem_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const int shmem_y = mmq_x*sizeof(block_q8_1_mmq);
return shmem_x + GGML_PAD(shmem_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
const size_t nbs_ids = mmq_x*sizeof(int);
const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int));
}
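
mmq_get_nbytes_shared above sizes the same dynamic shared memory block that the kernel partitions into the write-back ids (now stored at the beginning), the q8_1 tile of y, and the tile of x. A minimal sketch of that layout with example values (WARP_SIZE, QI8_1 and the padding granularity use the usual ggml CUDA constants; mmq_x is an arbitrary example):

```cpp
// Illustration of the dynamic shared memory layout:
//   [ids_dst_shared : mmq_x ints][tile_y : padded q8_1 tile][tile_x : quantized x tile]
#include <cstddef>
#include <cstdio>

int main() {
    const size_t warp_size = 32, qi8_1 = 8, nwarps = 8;   // WARP_SIZE, QI8_1, MMQ_NWARPS
    const size_t mmq_x     = 64;                          // example tile width
    const size_t sizeof_block_q8_1_mmq = (warp_size + warp_size/qi8_1)*sizeof(int);

    const size_t pad       = nwarps*warp_size*sizeof(int); // GGML_PAD granularity
    const size_t ids_bytes = mmq_x*sizeof(int);
    const size_t y_bytes   = mmq_x*sizeof_block_q8_1_mmq;
    const size_t y_padded  = ((y_bytes + pad - 1)/pad)*pad;

    printf("ids at offset 0, tile_y at %zu, tile_x at %zu\n", ids_bytes, ids_bytes + y_padded);
    return 0;
}
```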
template <ggml_type type, int mmq_x>
@@ -2791,86 +3006,114 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1);
const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
const int nbytes_shared = mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc);
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shmem_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
shmem_limit_raised[id] = true;
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shared_memory_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes_shared));
shared_memory_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
const dim3 block_nums_xy_tiling(nty, ntx, 1);
const int nty = (args.nrows_x + mmq_y - 1) / mmq_y;
const int ntx = (args.ncols_y + mmq_x - 1) / mmq_x;
const int ntzw = args.nchannels_y * args.nsamples_y;
const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
GGML_ASSERT(args.nchannels_y % args.nchannels_x == 0);
GGML_ASSERT(args.nsamples_y % args.nsamples_x == 0);
const int channel_ratio = args.nchannels_y / args.nchannels_x;
const int sample_ratio = args.nsamples_y / args.nsamples_x;
if (!args.use_stream_k) {
if (args.ne01 % mmq_y == 0) {
if (args.nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
(args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
} else {
constexpr bool need_check = true;
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, shmem, stream>>>
(args.x, args.y, args.dst, nullptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
}
return;
}
const dim3 block_nums_mmq(nsm, 1, 1);
const dim3 block_nums_stream_k(nsm, 1, 1);
const bool fixup_needed = ntx*nty*ntzw % nsm != 0;
ggml_cuda_pool & pool = ctx.pool(id);
ggml_cuda_pool_alloc<float> tmp_fixup(pool, block_nums_mmq.x * mmq_x*mmq_y);
ggml_cuda_pool_alloc<float> tmp_fixup(pool);
if (fixup_needed) {
tmp_fixup.alloc(block_nums_stream_k.x * mmq_x*mmq_y);
}
if (args.ne01 % mmq_y == 0) {
if (args.nrows_x % mmq_y == 0) {
constexpr bool need_check = false;
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
(args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
(args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
if (!fixup_needed) {
return;
}
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
} else {
constexpr bool need_check = true;
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_mmq, block_dims, shmem, stream>>>
(args.x, args.y, args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.stride01, args.ne10, args.ne11, args.stride11, args.ne0);
mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
(args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, 0, stream>>>
(args.dst, tmp_fixup.ptr, args.ne00, args.ne01, args.ne11, args.ne0, block_nums_mmq.x);
if (!fixup_needed) {
return;
}
mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
(args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
}
}
template <ggml_type type>
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int smpbo = ggml_cuda_info().devices[id].smpbo;
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
const int mmq_x_max = get_mmq_x_max_host(cc);
const int mmq_y = get_mmq_y_host(cc);
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
int mmq_x_best = 0;
int nparts_best = INT_MAX;
int ntiles_x_best = INT_MAX;
for (int mmq_x = 8; mmq_x <= mmq_x_max && nparts_best > 1; mmq_x += 8) {
for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
const int granularity = mmq_get_granularity_host(mmq_x, cc);
if (mmq_x % granularity != 0 || mmq_get_shmem<type>(mmq_x, mmq_y, cc) > smpbo) {
if (mmq_x % granularity != 0 || mmq_get_nbytes_shared<type>(mmq_x, mmq_y, cc) > smpbo) {
continue;
}
const int ntiles_x = (args.ne11 + mmq_x - 1) / mmq_x;
const int nwaves_xy_tiling = ntiles_x*block_num_y;
const int nparts = use_stream_k ? ntiles_x : nwaves_xy_tiling;
const int ntiles_x = (args.ncols_y + mmq_x - 1) / mmq_x;
if (nparts < nparts_best) {
mmq_x_best = mmq_x;
nparts_best = nparts;
if (ntiles_x < ntiles_x_best) {
mmq_x_best = mmq_x;
ntiles_x_best = ntiles_x;
}
}
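
The selection loop above walks mmq_x upward in steps of 8, skips candidates that violate the MMA granularity or exceed the per-block shared memory opt-in limit (smpbo), and keeps the candidate that minimizes the number of column tiles, stopping once a single tile suffices. A simplified restatement as a free-standing sketch (the helper and its function-pointer parameters are illustrative, not this file's API):

```cpp
// Sketch of the mmq_x selection heuristic; granularity/nbytes_shared stand in for
// mmq_get_granularity_host and mmq_get_nbytes_shared<type>.
#include <climits>
#include <cstddef>

static int pick_mmq_x(int ncols_y, int mmq_x_max, size_t smpbo,
                      int (*granularity)(int mmq_x), size_t (*nbytes_shared)(int mmq_x)) {
    int mmq_x_best    = 0;
    int ntiles_x_best = INT_MAX;
    for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) {
        if (mmq_x % granularity(mmq_x) != 0 || nbytes_shared(mmq_x) > smpbo) {
            continue; // candidate violates MMA granularity or shared memory limit
        }
        const int ntiles_x = (ncols_y + mmq_x - 1) / mmq_x;
        if (ntiles_x < ntiles_x_best) {
            mmq_x_best    = mmq_x;
            ntiles_x_best = ntiles_x;
        }
    }
    return mmq_x_best;
}
```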
@@ -2954,6 +3197,9 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
// -------------------------------------------------------------------------------------------------------------------------
void ggml_cuda_mul_mat_q(
ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
void ggml_cuda_op_mul_mat_q(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,