mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	CUDA: fastdiv, launch bounds for mmvq + q8_1 quant (#15802)
* CUDA: fastdiv, launch bounds for mmvq + q8_1 quant
This commit is contained in:
		| @@ -570,6 +570,8 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { | |||||||
| // | // | ||||||
| // n/d = (mulhi(n, mp) + n) >> L; | // n/d = (mulhi(n, mp) + n) >> L; | ||||||
| static const uint3 init_fastdiv_values(uint32_t d) { | static const uint3 init_fastdiv_values(uint32_t d) { | ||||||
|  |     GGML_ASSERT(d != 0); | ||||||
|  |  | ||||||
|     // compute L = ceil(log2(d)); |     // compute L = ceil(log2(d)); | ||||||
|     uint32_t L = 0; |     uint32_t L = 0; | ||||||
|     while (L < 32 && (uint32_t{ 1 } << L) < d) { |     while (L < 32 && (uint32_t{ 1 } << L) < d) { | ||||||
|   | |||||||
| @@ -141,9 +141,10 @@ template <ggml_type type, int ncols_dst> | |||||||
| __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) | __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) | ||||||
| static __global__ void mul_mat_vec_q( | static __global__ void mul_mat_vec_q( | ||||||
|         const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst, |         const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst, | ||||||
|         const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst, |         const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, | ||||||
|         const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, |         const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, | ||||||
|         const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { |         const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, | ||||||
|  |         const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) { | ||||||
|  |  | ||||||
|     constexpr int qk  = ggml_cuda_type_traits<type>::qk; |     constexpr int qk  = ggml_cuda_type_traits<type>::qk; | ||||||
|     constexpr int qi  = ggml_cuda_type_traits<type>::qi; |     constexpr int qi  = ggml_cuda_type_traits<type>::qi; | ||||||
| @@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q( | |||||||
|     constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; |     constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; | ||||||
|  |  | ||||||
|     // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. |     // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. | ||||||
|     const int channel_dst = blockIdx.y; |     const uint32_t channel_dst = blockIdx.y; | ||||||
|     const int channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]          : channel_dst / channel_ratio; |     const uint32_t channel_x   = ncols_dst == 1 && ids ? ids[channel_dst]                     : fastdiv(channel_dst, channel_ratio); | ||||||
|     const int channel_y   = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst; |     const uint32_t channel_y   = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; | ||||||
|     const int sample_dst  = blockIdx.z; |     const uint32_t sample_dst  = blockIdx.z; | ||||||
|     const int sample_x    = sample_dst / sample_ratio; |     const uint32_t sample_x    = fastdiv(sample_dst, sample_ratio); | ||||||
|     const int sample_y    = sample_dst; |     const uint32_t sample_y    = sample_dst; | ||||||
|  |  | ||||||
|     // partial sum for each thread |     // partial sum for each thread | ||||||
|     float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; |     float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; | ||||||
| @@ -247,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst( | |||||||
|     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); |     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); | ||||||
|     GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); |     GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); | ||||||
|  |  | ||||||
|     const int channel_ratio = nchannels_dst / nchannels_x; |     const uint3 nchannels_y_fd   = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0); | ||||||
|     const int sample_ratio  = nsamples_dst  / nsamples_x; |     const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0)              : init_fastdiv_values(nchannels_dst / nchannels_x); | ||||||
|  |     const uint3 sample_ratio_fd  = init_fastdiv_values(nsamples_dst  / nsamples_x); | ||||||
|  |  | ||||||
|     const int device = ggml_cuda_get_device(); |     const int device = ggml_cuda_get_device(); | ||||||
|     const int warp_size = ggml_cuda_info().devices[device].warp_size; |     const int warp_size = ggml_cuda_info().devices[device].warp_size; | ||||||
| @@ -256,86 +258,70 @@ static void mul_mat_vec_q_switch_ncols_dst( | |||||||
|  |  | ||||||
|     GGML_ASSERT(!ids || ncols_dst == 1); |     GGML_ASSERT(!ids || ncols_dst == 1); | ||||||
|     switch (ncols_dst) { |     switch (ncols_dst) { | ||||||
|         case 1: |         case 1: { | ||||||
|         { |  | ||||||
|             constexpr int c_ncols_dst = 1; |             constexpr int c_ncols_dst = 1; | ||||||
|             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); |             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); | ||||||
|             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> |             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> | ||||||
|                 (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, |                 (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, | ||||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, |                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); |                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||||
|             break; |         } break; | ||||||
|         } |         case 2: { | ||||||
|         case 2: |  | ||||||
|         { |  | ||||||
|             constexpr int c_ncols_dst = 2; |             constexpr int c_ncols_dst = 2; | ||||||
|             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); |             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); | ||||||
|             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> |             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> | ||||||
|                 (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, |                 (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, | ||||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, |                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); |                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||||
|             break; |         } break; | ||||||
|         } |         case 3: { | ||||||
|         case 3: |  | ||||||
|         { |  | ||||||
|             constexpr int c_ncols_dst = 3; |             constexpr int c_ncols_dst = 3; | ||||||
|             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); |             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); | ||||||
|             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> |             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> | ||||||
|                 (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, |                 (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, | ||||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, |                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); |                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||||
|             break; |         } break; | ||||||
|         } |         case 4: { | ||||||
|         case 4: |  | ||||||
|         { |  | ||||||
|             constexpr int c_ncols_dst = 4; |             constexpr int c_ncols_dst = 4; | ||||||
|             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); |             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); | ||||||
|             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> |             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> | ||||||
|                 (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, |                 (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, | ||||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, |                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); |                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||||
|             break; |         } break; | ||||||
|         } |         case 5: { | ||||||
|         case 5: |  | ||||||
|         { |  | ||||||
|             constexpr int c_ncols_dst = 5; |             constexpr int c_ncols_dst = 5; | ||||||
|             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); |             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); | ||||||
|             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> |             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> | ||||||
|                 (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, |                 (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, | ||||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, |                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); |                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||||
|             break; |         } break; | ||||||
|         } |         case 6: { | ||||||
|         case 6: |  | ||||||
|         { |  | ||||||
|             constexpr int c_ncols_dst = 6; |             constexpr int c_ncols_dst = 6; | ||||||
|             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); |             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); | ||||||
|             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> |             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> | ||||||
|                 (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, |                 (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, | ||||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, |                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); |                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||||
|             break; |         } break; | ||||||
|         } |         case 7: { | ||||||
|         case 7: |  | ||||||
|         { |  | ||||||
|             constexpr int c_ncols_dst = 7; |             constexpr int c_ncols_dst = 7; | ||||||
|             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); |             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); | ||||||
|             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> |             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> | ||||||
|                 (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, |                 (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, | ||||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, |                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); |                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||||
|             break; |         } break; | ||||||
|         } |         case 8: { | ||||||
|         case 8: |  | ||||||
|         { |  | ||||||
|             constexpr int c_ncols_dst = 8; |             constexpr int c_ncols_dst = 8; | ||||||
|             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); |             std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); | ||||||
|             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> |             mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> | ||||||
|                 (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, |                 (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, | ||||||
|                  channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, |                  channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, | ||||||
|                  sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); |                  sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); | ||||||
|             break; |         } break; | ||||||
|         } |  | ||||||
|         default: |         default: | ||||||
|             GGML_ABORT("fatal error"); |             GGML_ABORT("fatal error"); | ||||||
|             break; |             break; | ||||||
|   | |||||||
| @@ -1,26 +1,27 @@ | |||||||
| #include "quantize.cuh" | #include "quantize.cuh" | ||||||
| #include <cstdint> | #include <cstdint> | ||||||
|  |  | ||||||
|  | __launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1) | ||||||
| static __global__ void quantize_q8_1( | static __global__ void quantize_q8_1( | ||||||
|         const float * __restrict__ x, void * __restrict__ vy, |         const float * __restrict__ x, void * __restrict__ vy, | ||||||
|         const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, |         const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, | ||||||
|         const int64_t ne0, const int ne1, const int ne2) { |         const int64_t ne0, const uint32_t ne1, const uint3 ne2) { | ||||||
|     const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; |     const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; | ||||||
|  |  | ||||||
|     if (i0 >= ne0) { |     if (i0 >= ne0) { | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     const int64_t i3 = fastdiv(blockIdx.z, ne2); | ||||||
|  |     const int64_t i2 = blockIdx.z - i3*ne2.z; | ||||||
|     const int64_t i1 = blockIdx.y; |     const int64_t i1 = blockIdx.y; | ||||||
|     const int64_t i2 = blockIdx.z % ne2; |  | ||||||
|     const int64_t i3 = blockIdx.z / ne2; |  | ||||||
|  |  | ||||||
|     const int64_t & i00 = i0; |     const int64_t & i00 = i0; | ||||||
|     const int64_t & i01 = i1; |     const int64_t & i01 = i1; | ||||||
|     const int64_t & i02 = i2; |     const int64_t & i02 = i2; | ||||||
|     const int64_t & i03 = i3; |     const int64_t & i03 = i3; | ||||||
|  |  | ||||||
|     const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0; |     const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0; | ||||||
|  |  | ||||||
|     block_q8_1 * y = (block_q8_1 *) vy; |     block_q8_1 * y = (block_q8_1 *) vy; | ||||||
|  |  | ||||||
| @@ -31,10 +32,10 @@ static __global__ void quantize_q8_1( | |||||||
|     float amax = fabsf(xi); |     float amax = fabsf(xi); | ||||||
|     float sum = xi; |     float sum = xi; | ||||||
|  |  | ||||||
|     amax = warp_reduce_max(amax); |     amax = warp_reduce_max<QK8_1>(amax); | ||||||
|     sum  = warp_reduce_sum(sum); |     sum  = warp_reduce_sum<QK8_1>(sum); | ||||||
|  |  | ||||||
|     const float  d = amax / 127; |     const float  d = amax / 127.0f; | ||||||
|     const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); |     const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); | ||||||
|  |  | ||||||
|     y[ib].qs[iqs] = q; |     y[ib].qs[iqs] = q; | ||||||
| @@ -43,8 +44,7 @@ static __global__ void quantize_q8_1( | |||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     reinterpret_cast<half&>(y[ib].ds.x) = d; |     y[ib].ds = make_half2(d, sum); | ||||||
|     reinterpret_cast<half&>(y[ib].ds.y) = sum; |  | ||||||
| } | } | ||||||
|  |  | ||||||
| template <mmq_q8_1_ds_layout ds_layout> | template <mmq_q8_1_ds_layout ds_layout> | ||||||
| @@ -152,10 +152,12 @@ void quantize_row_q8_1_cuda( | |||||||
|     GGML_ASSERT(!ids); |     GGML_ASSERT(!ids); | ||||||
|     GGML_ASSERT(ne0 % QK8_1 == 0); |     GGML_ASSERT(ne0 % QK8_1 == 0); | ||||||
|  |  | ||||||
|  |     const uint3 ne2_fastdiv = init_fastdiv_values(ne2); | ||||||
|  |  | ||||||
|     const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; |     const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; | ||||||
|     const dim3 num_blocks(block_num_x, ne1, ne2*ne3); |     const dim3 num_blocks(block_num_x, ne1, ne2*ne3); | ||||||
|     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); |     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); | ||||||
|     quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2); |     quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv); | ||||||
|     GGML_UNUSED(type_src0); |     GGML_UNUSED(type_src0); | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Johannes Gäßler
					Johannes Gäßler