mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	CUDA: fix --split-mode row for MMQ (#13323)
This commit is contained in:
		@@ -128,7 +128,7 @@ void ggml_cuda_mul_mat_q(
 | 
			
		||||
 | 
			
		||||
        const mmq_args args = {
 | 
			
		||||
            src0_d, src0->type, (const int *) src1_q8_1.ptr, nullptr, nullptr, dst_d,
 | 
			
		||||
            ne00, ne01, ne1, s01, s1,
 | 
			
		||||
            ne00, ne01, ne1, s01, ne11, s1,
 | 
			
		||||
            ne02, ne12, s02, s12, s2,
 | 
			
		||||
            ne03, ne13, s03, s13, s3,
 | 
			
		||||
            use_stream_k};
 | 
			
		||||
@@ -212,7 +212,7 @@ void ggml_cuda_mul_mat_q(
 | 
			
		||||
    // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
 | 
			
		||||
    const mmq_args args = {
 | 
			
		||||
        src0_d, src0->type, (const int *) src1_q8_1.ptr, ids_dst_dev, expert_bounds_dev, dst_d,
 | 
			
		||||
        ne00, ne01, ne_get_rows, s01, s1,
 | 
			
		||||
        ne00, ne01, ne_get_rows, s01, ne_get_rows, s1,
 | 
			
		||||
        ne02, ne02, s02, s12, s2,
 | 
			
		||||
        ne03, ne13, s03, s13, s3,
 | 
			
		||||
        use_stream_k};
 | 
			
		||||
@@ -251,7 +251,7 @@ void ggml_cuda_op_mul_mat_q(
 | 
			
		||||
        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
 | 
			
		||||
    const mmq_args args = {
 | 
			
		||||
        src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
 | 
			
		||||
        ne00, row_diff, src1_ncols, stride01, nrows_dst,
 | 
			
		||||
        ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
 | 
			
		||||
        1, 1, 0, 0, 0,
 | 
			
		||||
        1, 1, 0, 0, 0,
 | 
			
		||||
        use_stream_k};
 | 
			
		||||
 
 | 
			
		||||
@@ -2522,7 +2522,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check, bool fixup>
 | 
			
		||||
static __device__ __forceinline__ void mul_mat_q_process_tile(
 | 
			
		||||
        const char * __restrict__ x, const int offset_x, const int * __restrict__ y,
 | 
			
		||||
        const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup,
 | 
			
		||||
        const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
 | 
			
		||||
        const int nrows_x, const int stride_row_x, const int ncols_y, const int stride_col_dst,
 | 
			
		||||
        const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) {
 | 
			
		||||
 | 
			
		||||
    constexpr int              qk         = ggml_cuda_type_traits<type>::qk;
 | 
			
		||||
@@ -2606,7 +2606,7 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 | 
			
		||||
static __global__ void mul_mat_q(
 | 
			
		||||
        const char * __restrict__ x, const int * __restrict__ y, const int32_t * __restrict__ ids_dst,
 | 
			
		||||
        const int32_t * __restrict__ expert_bounds, float * __restrict__ dst, float * __restrict__ tmp_fixup,
 | 
			
		||||
        const int ncols_x, const int nrows_x, const int ncols_y, const int stride_row_x, const int stride_col_dst,
 | 
			
		||||
        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_row_x, const int ncols_y, const int stride_col_dst,
 | 
			
		||||
        const int channel_ratio, const int nchannels_y, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
 | 
			
		||||
        const int sample_ratio, const int nsamples_y, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
 | 
			
		||||
 | 
			
		||||
@@ -2619,8 +2619,8 @@ static __global__ void mul_mat_q(
 | 
			
		||||
    constexpr int qk    = ggml_cuda_type_traits<type>::qk;
 | 
			
		||||
    constexpr int mmq_y = get_mmq_y_device();
 | 
			
		||||
 | 
			
		||||
    const int ntx = (ncols_y + mmq_x - 1) / mmq_x; // Number of tiles x
 | 
			
		||||
    const int nty = (nrows_x + mmq_y - 1) / mmq_y; // Number of tiles y
 | 
			
		||||
    const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; // Number of tiles x
 | 
			
		||||
    const int nty = (nrows_x   + mmq_y - 1) / mmq_y; // Number of tiles y
 | 
			
		||||
 | 
			
		||||
    // Initialize the ids for writing back data with just the index.
 | 
			
		||||
    // For regular matrix multiplications this is never changed.
 | 
			
		||||
@@ -2648,8 +2648,8 @@ static __global__ void mul_mat_q(
 | 
			
		||||
 | 
			
		||||
        // Defaults for regular matrix multiplication:
 | 
			
		||||
        int col_low    = 0;
 | 
			
		||||
        int col_high   = ncols_y;
 | 
			
		||||
        int col_diff   = ncols_y;
 | 
			
		||||
        int col_high   = ncols_dst;
 | 
			
		||||
        int col_diff   = ncols_dst;
 | 
			
		||||
        int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
 | 
			
		||||
        int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
 | 
			
		||||
 | 
			
		||||
@@ -2689,7 +2689,7 @@ static __global__ void mul_mat_q(
 | 
			
		||||
 | 
			
		||||
        constexpr bool fixup = false;
 | 
			
		||||
        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
 | 
			
		||||
            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
 | 
			
		||||
            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
 | 
			
		||||
             tile_x_max_i, tile_y_max_j, 0, ncols_x/qk);
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
@@ -2720,8 +2720,8 @@ static __global__ void mul_mat_q(
 | 
			
		||||
 | 
			
		||||
        // Defaults for regular matrix multiplication:
 | 
			
		||||
        int col_low    = 0;
 | 
			
		||||
        int col_high   = ncols_y;
 | 
			
		||||
        int col_diff   = ncols_y;
 | 
			
		||||
        int col_high   = ncols_dst;
 | 
			
		||||
        int col_diff   = ncols_dst;
 | 
			
		||||
        int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
 | 
			
		||||
        int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
 | 
			
		||||
 | 
			
		||||
@@ -2767,7 +2767,7 @@ static __global__ void mul_mat_q(
 | 
			
		||||
 | 
			
		||||
        constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
 | 
			
		||||
        mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
 | 
			
		||||
            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
 | 
			
		||||
            (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
 | 
			
		||||
             tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 | 
			
		||||
 | 
			
		||||
        kbc += blocks_per_ne00;
 | 
			
		||||
@@ -2792,8 +2792,8 @@ static __global__ void mul_mat_q(
 | 
			
		||||
 | 
			
		||||
    // Defaults for regular matrix multiplication:
 | 
			
		||||
    int col_low    = 0;
 | 
			
		||||
    int col_high   = ncols_y;
 | 
			
		||||
    int col_diff   = ncols_y;
 | 
			
		||||
    int col_high   = ncols_dst;
 | 
			
		||||
    int col_diff   = ncols_dst;
 | 
			
		||||
    int offset_y   = wt*stride_sample_y   + zt*stride_channel_y;
 | 
			
		||||
    int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst;
 | 
			
		||||
 | 
			
		||||
@@ -2834,7 +2834,7 @@ static __global__ void mul_mat_q(
 | 
			
		||||
 | 
			
		||||
    constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
 | 
			
		||||
    mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
 | 
			
		||||
        (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, ncols_y, stride_row_x, stride_col_dst,
 | 
			
		||||
        (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, nrows_x, stride_row_x, ncols_y, stride_col_dst,
 | 
			
		||||
         tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -2842,7 +2842,7 @@ static __global__ void mul_mat_q(
 | 
			
		||||
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 | 
			
		||||
static __global__ void mul_mat_q_stream_k_fixup(
 | 
			
		||||
        const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile,
 | 
			
		||||
        const int ncols_x, const int nrows_x, const int ncols_y, const int stride_col_dst,
 | 
			
		||||
        const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst,
 | 
			
		||||
        const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst) {
 | 
			
		||||
    constexpr int     mmq_y           = get_mmq_y_device();
 | 
			
		||||
    constexpr int     qk              = ggml_cuda_type_traits<type>::qk;
 | 
			
		||||
@@ -2851,8 +2851,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
 | 
			
		||||
 | 
			
		||||
    float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f};
 | 
			
		||||
 | 
			
		||||
    const int ntx  = (ncols_y + mmq_x - 1) / mmq_x;
 | 
			
		||||
    const int nty  = (nrows_x + mmq_y - 1) / mmq_y;
 | 
			
		||||
    const int ntx  = (ncols_dst + mmq_x - 1) / mmq_x;
 | 
			
		||||
    const int nty  = (nrows_x   + mmq_y - 1) / mmq_y;
 | 
			
		||||
 | 
			
		||||
    const int bidx0 = blockIdx.x;
 | 
			
		||||
 | 
			
		||||
@@ -2925,8 +2925,8 @@ static __global__ void mul_mat_q_stream_k_fixup(
 | 
			
		||||
        const int offset_dst = wt*stride_sample_dst + zt*stride_channel_dst + jt*mmq_x*stride_col_dst + it*mmq_y;
 | 
			
		||||
        dst += offset_dst;
 | 
			
		||||
 | 
			
		||||
        const int i_max = nrows_x - it*mmq_y - 1;
 | 
			
		||||
        const int j_max = ncols_y - jt*mmq_x - 1;
 | 
			
		||||
        const int i_max = nrows_x   - it*mmq_y - 1;
 | 
			
		||||
        const int j_max = ncols_dst - jt*mmq_x - 1;
 | 
			
		||||
 | 
			
		||||
#pragma unroll
 | 
			
		||||
        for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
 | 
			
		||||
@@ -2989,7 +2989,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
 | 
			
		||||
 | 
			
		||||
struct mmq_args {
 | 
			
		||||
    const char * x; ggml_type type_x; const int * y; const int32_t * ids_dst; const int32_t * expert_bounds; float * dst;
 | 
			
		||||
    int64_t ncols_x; int64_t nrows_x; int64_t ncols_y; int64_t stride_row_x; int64_t nrows_dst;
 | 
			
		||||
    int64_t ncols_x; int64_t nrows_x; int64_t ncols_dst; int64_t stride_row_x; int64_t ncols_y; int64_t nrows_dst;
 | 
			
		||||
    int64_t nchannels_x; int64_t nchannels_y; int64_t stride_channel_x; int64_t stride_channel_y; int64_t stride_channel_dst;
 | 
			
		||||
    int64_t nsamples_x; int64_t nsamples_y; int64_t stride_sample_x; int64_t stride_sample_y; int64_t stride_sample_dst;
 | 
			
		||||
    bool use_stream_k;
 | 
			
		||||
@@ -3025,8 +3025,8 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 | 
			
		||||
    }
 | 
			
		||||
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
 | 
			
		||||
 | 
			
		||||
    const int nty  = (args.nrows_x + mmq_y - 1) / mmq_y;
 | 
			
		||||
    const int ntx  = (args.ncols_y + mmq_x - 1) / mmq_x;
 | 
			
		||||
    const int nty  = (args.nrows_x   + mmq_y - 1) / mmq_y;
 | 
			
		||||
    const int ntx  = (args.ncols_dst + mmq_x - 1) / mmq_x;
 | 
			
		||||
    const int ntzw = args.nchannels_y * args.nsamples_y;
 | 
			
		||||
    const dim3 block_nums_xy_tiling(nty, ntx, ntzw);
 | 
			
		||||
 | 
			
		||||
@@ -3040,14 +3040,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 | 
			
		||||
            constexpr bool need_check = false;
 | 
			
		||||
            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
 | 
			
		||||
                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
 | 
			
		||||
                 args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
 | 
			
		||||
                 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
 | 
			
		||||
                 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
 | 
			
		||||
                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
 | 
			
		||||
        } else {
 | 
			
		||||
            constexpr bool need_check = true;
 | 
			
		||||
            mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_xy_tiling, block_dims, nbytes_shared, stream>>>
 | 
			
		||||
                (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr,
 | 
			
		||||
                 args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
 | 
			
		||||
                 args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
 | 
			
		||||
                 channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
 | 
			
		||||
                 sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
 | 
			
		||||
        }
 | 
			
		||||
@@ -3068,7 +3068,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 | 
			
		||||
 | 
			
		||||
        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
 | 
			
		||||
            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
 | 
			
		||||
             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
 | 
			
		||||
             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
 | 
			
		||||
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
 | 
			
		||||
             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
 | 
			
		||||
 | 
			
		||||
@@ -3077,14 +3077,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
 | 
			
		||||
            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
 | 
			
		||||
            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
 | 
			
		||||
             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
 | 
			
		||||
    } else {
 | 
			
		||||
        constexpr bool need_check = true;
 | 
			
		||||
 | 
			
		||||
        mul_mat_q<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, nbytes_shared, stream>>>
 | 
			
		||||
            (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr,
 | 
			
		||||
             args.ncols_x, args.nrows_x, args.ncols_y, args.stride_row_x, args.nrows_dst,
 | 
			
		||||
             args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst,
 | 
			
		||||
             channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst,
 | 
			
		||||
             sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst);
 | 
			
		||||
 | 
			
		||||
@@ -3093,7 +3093,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        mul_mat_q_stream_k_fixup<type, mmq_x, MMQ_NWARPS, need_check><<<block_nums_stream_k, block_dims, 0, stream>>>
 | 
			
		||||
            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_y,
 | 
			
		||||
            (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst,
 | 
			
		||||
             args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user