mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	CUDA: fastdiv, launch bounds for mmvq + q8_1 quant (#15802)
* CUDA: fastdiv, launch bounds for mmvq + q8_1 quant
This commit is contained in:
		@@ -570,6 +570,8 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 | 
			
		||||
//
 | 
			
		||||
// n/d = (mulhi(n, mp) + n) >> L;
 | 
			
		||||
static const uint3 init_fastdiv_values(uint32_t d) {
 | 
			
		||||
    GGML_ASSERT(d != 0);
 | 
			
		||||
 | 
			
		||||
    // compute L = ceil(log2(d));
 | 
			
		||||
    uint32_t L = 0;
 | 
			
		||||
    while (L < 32 && (uint32_t{ 1 } << L) < d) {
 | 
			
		||||
 
 | 
			
		||||
@@ -141,9 +141,10 @@ template <ggml_type type, int ncols_dst>
 | 
			
		||||
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
 | 
			
		||||
static __global__ void mul_mat_vec_q(
 | 
			
		||||
        const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
 | 
			
		||||
        const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
 | 
			
		||||
        const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
 | 
			
		||||
        const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
 | 
			
		||||
        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
 | 
			
		||||
        const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
 | 
			
		||||
        const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
 | 
			
		||||
        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
 | 
			
		||||
 | 
			
		||||
    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
 | 
			
		||||
    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
 | 
			
		||||
@@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q(
 | 
			
		||||
    constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
 | 
			
		||||
 | 
			
		||||
    // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
 | 
			
		||||
    const int channel_dst = blockIdx.y;
 | 
			
		||||
    const int channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]          : channel_dst / channel_ratio;
 | 
			
		||||
    const int channel_y   = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
 | 
			
		||||
    const int sample_dst  = blockIdx.z;
 | 
			
		||||
    const int sample_x    = sample_dst / sample_ratio;
 | 
			
		||||
    const int sample_y    = sample_dst;
 | 
			
		||||
    const uint32_t channel_dst = blockIdx.y;
 | 
			
		||||
    const uint32_t channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio);
 | 
			
		||||
    const uint32_t channel_y   = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
 | 
			
		||||
    const uint32_t sample_dst  = blockIdx.z;
 | 
			
		||||
    const uint32_t sample_x    = fastdiv(sample_dst, sample_ratio);
 | 
			
		||||
    const uint32_t sample_y    = sample_dst;
 | 
			
		||||
 | 
			
		||||
    // partial sum for each thread
 | 
			
		||||
    float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
 | 
			
		||||
@@ -247,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst(
 | 
			
		||||
    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
 | 
			
		||||
    GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
 | 
			
		||||
 | 
			
		||||
    const int channel_ratio = nchannels_dst / nchannels_x;
 | 
			
		||||
    const int sample_ratio  = nsamples_dst  / nsamples_x;
 | 
			
		||||
    const uint3 nchannels_y_fd   = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
 | 
			
		||||
    const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0)              : init_fastdiv_values(nchannels_dst / nchannels_x);
 | 
			
		||||
    const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst  / nsamples_x);
 | 
			
		||||
 | 
			
		||||
    const int device = ggml_cuda_get_device();
 | 
			
		||||
    const int warp_size = ggml_cuda_info().devices[device].warp_size;
 | 
			
		||||
@@ -256,86 +258,70 @@ static void mul_mat_vec_q_switch_ncols_dst(
 | 
			
		||||
 | 
			
		||||
    GGML_ASSERT(!ids || ncols_dst == 1);
 | 
			
		||||
    switch (ncols_dst) {
 | 
			
		||||
        case 1:
 | 
			
		||||
        {
 | 
			
		||||
        case 1: {
 | 
			
		||||
            constexpr int c_ncols_dst = 1;
 | 
			
		||||
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
 | 
			
		||||
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
            break;
 | 
			
		||||
        }
 | 
			
		||||
        case 2:
 | 
			
		||||
        {
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
        } break;
 | 
			
		||||
        case 2: {
 | 
			
		||||
            constexpr int c_ncols_dst = 2;
 | 
			
		||||
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
 | 
			
		||||
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
            break;
 | 
			
		||||
        }
 | 
			
		||||
        case 3:
 | 
			
		||||
        {
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
        } break;
 | 
			
		||||
        case 3: {
 | 
			
		||||
            constexpr int c_ncols_dst = 3;
 | 
			
		||||
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
 | 
			
		||||
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
            break;
 | 
			
		||||
        }
 | 
			
		||||
        case 4:
 | 
			
		||||
        {
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
        } break;
 | 
			
		||||
        case 4: {
 | 
			
		||||
            constexpr int c_ncols_dst = 4;
 | 
			
		||||
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
 | 
			
		||||
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
            break;
 | 
			
		||||
        }
 | 
			
		||||
        case 5:
 | 
			
		||||
        {
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
        } break;
 | 
			
		||||
        case 5: {
 | 
			
		||||
            constexpr int c_ncols_dst = 5;
 | 
			
		||||
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
 | 
			
		||||
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
            break;
 | 
			
		||||
        }
 | 
			
		||||
        case 6:
 | 
			
		||||
        {
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
        } break;
 | 
			
		||||
        case 6: {
 | 
			
		||||
            constexpr int c_ncols_dst = 6;
 | 
			
		||||
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
 | 
			
		||||
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
            break;
 | 
			
		||||
        }
 | 
			
		||||
        case 7:
 | 
			
		||||
        {
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
        } break;
 | 
			
		||||
        case 7: {
 | 
			
		||||
            constexpr int c_ncols_dst = 7;
 | 
			
		||||
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
 | 
			
		||||
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
            break;
 | 
			
		||||
        }
 | 
			
		||||
        case 8:
 | 
			
		||||
        {
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
        } break;
 | 
			
		||||
        case 8: {
 | 
			
		||||
            constexpr int c_ncols_dst = 8;
 | 
			
		||||
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
 | 
			
		||||
            mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
            break;
 | 
			
		||||
        }
 | 
			
		||||
                (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
 | 
			
		||||
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
 | 
			
		||||
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
 | 
			
		||||
        } break;
 | 
			
		||||
        default:
 | 
			
		||||
            GGML_ABORT("fatal error");
 | 
			
		||||
            break;
 | 
			
		||||
 
 | 
			
		||||
@@ -1,26 +1,27 @@
 | 
			
		||||
#include "quantize.cuh"
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
 | 
			
		||||
__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1)
 | 
			
		||||
static __global__ void quantize_q8_1(
 | 
			
		||||
        const float * __restrict__ x, void * __restrict__ vy,
 | 
			
		||||
        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
 | 
			
		||||
        const int64_t ne0, const int ne1, const int ne2) {
 | 
			
		||||
        const int64_t ne0, const uint32_t ne1, const uint3 ne2) {
 | 
			
		||||
    const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
 | 
			
		||||
 | 
			
		||||
    if (i0 >= ne0) {
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const int64_t i3 = fastdiv(blockIdx.z, ne2);
 | 
			
		||||
    const int64_t i2 = blockIdx.z - i3*ne2.z;
 | 
			
		||||
    const int64_t i1 = blockIdx.y;
 | 
			
		||||
    const int64_t i2 = blockIdx.z % ne2;
 | 
			
		||||
    const int64_t i3 = blockIdx.z / ne2;
 | 
			
		||||
 | 
			
		||||
    const int64_t & i00 = i0;
 | 
			
		||||
    const int64_t & i01 = i1;
 | 
			
		||||
    const int64_t & i02 = i2;
 | 
			
		||||
    const int64_t & i03 = i3;
 | 
			
		||||
 | 
			
		||||
    const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0;
 | 
			
		||||
    const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0;
 | 
			
		||||
 | 
			
		||||
    block_q8_1 * y = (block_q8_1 *) vy;
 | 
			
		||||
 | 
			
		||||
@@ -31,10 +32,10 @@ static __global__ void quantize_q8_1(
 | 
			
		||||
    float amax = fabsf(xi);
 | 
			
		||||
    float sum = xi;
 | 
			
		||||
 | 
			
		||||
    amax = warp_reduce_max(amax);
 | 
			
		||||
    sum  = warp_reduce_sum(sum);
 | 
			
		||||
    amax = warp_reduce_max<QK8_1>(amax);
 | 
			
		||||
    sum  = warp_reduce_sum<QK8_1>(sum);
 | 
			
		||||
 | 
			
		||||
    const float  d = amax / 127;
 | 
			
		||||
    const float  d = amax / 127.0f;
 | 
			
		||||
    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
 | 
			
		||||
 | 
			
		||||
    y[ib].qs[iqs] = q;
 | 
			
		||||
@@ -43,8 +44,7 @@ static __global__ void quantize_q8_1(
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    reinterpret_cast<half&>(y[ib].ds.x) = d;
 | 
			
		||||
    reinterpret_cast<half&>(y[ib].ds.y) = sum;
 | 
			
		||||
    y[ib].ds = make_half2(d, sum);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template <mmq_q8_1_ds_layout ds_layout>
 | 
			
		||||
@@ -152,10 +152,12 @@ void quantize_row_q8_1_cuda(
 | 
			
		||||
    GGML_ASSERT(!ids);
 | 
			
		||||
    GGML_ASSERT(ne0 % QK8_1 == 0);
 | 
			
		||||
 | 
			
		||||
    const uint3 ne2_fastdiv = init_fastdiv_values(ne2);
 | 
			
		||||
 | 
			
		||||
    const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
 | 
			
		||||
    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
 | 
			
		||||
    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
 | 
			
		||||
    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
 | 
			
		||||
    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
 | 
			
		||||
    GGML_UNUSED(type_src0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user