mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	CUDA: use fastdiv in set-rows (#16834)
* CUDA: use fastdiv in set-rows
* add assert about value fitting in u32
This commit is contained in:
		@@ -625,8 +625,11 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 | 
				
			|||||||
// and a shift:
 | 
					// and a shift:
 | 
				
			||||||
//
 | 
					//
 | 
				
			||||||
// n/d = (mulhi(n, mp) + n) >> L;
 | 
					// n/d = (mulhi(n, mp) + n) >> L;
 | 
				
			||||||
static const uint3 init_fastdiv_values(uint32_t d) {
 | 
					static const uint3 init_fastdiv_values(uint64_t d_64) {
 | 
				
			||||||
    GGML_ASSERT(d != 0);
 | 
					    GGML_ASSERT(d_64 != 0);
 | 
				
			||||||
 | 
					    GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    uint32_t d = (uint32_t)d_64;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // compute L = ceil(log2(d));
 | 
					    // compute L = ceil(log2(d));
 | 
				
			||||||
    uint32_t L = 0;
 | 
					    uint32_t L = 0;
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -4,30 +4,53 @@
 | 
				
			|||||||
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
 | 
					typedef void (*set_rows_kernel_t)(const char * src, char * dst);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Generic quantized set_rows kernel template
 | 
					// Generic quantized set_rows kernel template
 | 
				
			||||||
template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
 | 
					template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
 | 
				
			||||||
static __global__ void k_set_rows_quant(
 | 
					static __global__ void k_set_rows_quant(const float * __restrict__ src0,
 | 
				
			||||||
        const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
 | 
					                                        const idx_t * __restrict__ src1,
 | 
				
			||||||
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
 | 
					                                        block_type * __restrict__ dst,
 | 
				
			||||||
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
 | 
					                                        const int64_t ne_total,
 | 
				
			||||||
        const int64_t s01, const int64_t s02, const int64_t s03,
 | 
					                                        const int64_t ne10,
 | 
				
			||||||
        const int64_t s10, const int64_t s11, const int64_t s12,
 | 
					                                        const int64_t ne11,
 | 
				
			||||||
        const int64_t s1, const int64_t s2, const int64_t s3) {
 | 
					                                        const int64_t ne12,
 | 
				
			||||||
 | 
					                                        const int64_t ne13,
 | 
				
			||||||
 | 
					                                        const int64_t s01,
 | 
				
			||||||
 | 
					                                        const int64_t s02,
 | 
				
			||||||
 | 
					                                        const int64_t s03,
 | 
				
			||||||
 | 
					                                        const int64_t s10,
 | 
				
			||||||
 | 
					                                        const int64_t s11,
 | 
				
			||||||
 | 
					                                        const int64_t s12,
 | 
				
			||||||
 | 
					                                        const int64_t s1,
 | 
				
			||||||
 | 
					                                        const int64_t s2,
 | 
				
			||||||
 | 
					                                        const int64_t s3,
 | 
				
			||||||
 | 
					                                        const uint3   ne00,
 | 
				
			||||||
 | 
					                                        const uint3   ne01,
 | 
				
			||||||
 | 
					                                        const uint3   ne02,
 | 
				
			||||||
 | 
					                                        const uint3   ne11_fd,
 | 
				
			||||||
 | 
					                                        const uint3   ne12_fd) {
 | 
				
			||||||
    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
 | 
					    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
 | 
				
			||||||
    const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (i >= ne_total) {
 | 
					    if (i >= ne_total) {
 | 
				
			||||||
        return;
 | 
					        return;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int64_t i_base = i * qk;
 | 
					    const int64_t i_base = i * qk;
 | 
				
			||||||
    const int64_t i03 = i_base / (ne00 * ne01 * ne02);
 | 
					    uint32_t      tmp    = (uint32_t) i_base;
 | 
				
			||||||
    const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
 | 
					    uint2         div_mod;
 | 
				
			||||||
    const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
 | 
					 | 
				
			||||||
    const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int64_t i12 = i03 % ne12;
 | 
					    div_mod           = fast_div_modulo(tmp, ne00);
 | 
				
			||||||
    const int64_t i11 = i02 % ne11;
 | 
					    const int64_t i00 = div_mod.y;
 | 
				
			||||||
 | 
					    tmp               = div_mod.x;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    div_mod           = fast_div_modulo(tmp, ne01);
 | 
				
			||||||
 | 
					    const int64_t i01 = div_mod.y;
 | 
				
			||||||
 | 
					    tmp               = div_mod.x;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    div_mod           = fast_div_modulo(tmp, ne02);
 | 
				
			||||||
 | 
					    const int64_t i02 = div_mod.y;
 | 
				
			||||||
 | 
					    const int64_t i03 = div_mod.x;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
 | 
				
			||||||
 | 
					    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
 | 
				
			||||||
    const int64_t i10 = i01;
 | 
					    const int64_t i10 = i01;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
 | 
					    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
 | 
				
			||||||
@@ -41,6 +64,8 @@ static __global__ void k_set_rows_quant(
 | 
				
			|||||||
    quantize_func(src_block, dst_block);
 | 
					    quantize_func(src_block, dst_block);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    GGML_UNUSED(ne10);
 | 
					    GGML_UNUSED(ne10);
 | 
				
			||||||
 | 
					    GGML_UNUSED(ne11);
 | 
				
			||||||
 | 
					    GGML_UNUSED(ne12);
 | 
				
			||||||
    GGML_UNUSED(ne13);
 | 
					    GGML_UNUSED(ne13);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -71,40 +96,65 @@ static void set_rows_cuda_quant(
 | 
				
			|||||||
    const int64_t s2  = nb2;
 | 
					    const int64_t s2  = nb2;
 | 
				
			||||||
    const int64_t s3  = nb3;
 | 
					    const int64_t s3  = nb3;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (ne_total > 0) {
 | 
					    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
 | 
				
			||||||
 | 
					        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
 | 
				
			||||||
 | 
					        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
 | 
				
			||||||
 | 
					        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
 | 
				
			||||||
 | 
					        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
 | 
				
			||||||
 | 
					        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
 | 
					        k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
 | 
				
			||||||
            src0_d, src1_d, dst_d,
 | 
					            src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
 | 
				
			||||||
            ne00, ne01, ne02, ne03,
 | 
					            ne01_fd, ne02_fd, ne11_fd, ne12_fd);
 | 
				
			||||||
            ne10, ne11, ne12, ne13,
 | 
					 | 
				
			||||||
            s01, s02, s03,
 | 
					 | 
				
			||||||
            s10, s11, s12,
 | 
					 | 
				
			||||||
            s1, s2, s3);
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template<typename src_t, typename idx_t, typename dst_t>
 | 
					template <typename src_t, typename idx_t, typename dst_t>
 | 
				
			||||||
static __global__ void k_set_rows(
 | 
					static __global__ void k_set_rows(const src_t * __restrict__ src0,
 | 
				
			||||||
        const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
 | 
					                                  const idx_t * __restrict__ src1,
 | 
				
			||||||
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
 | 
					                                  dst_t * __restrict__ dst,
 | 
				
			||||||
        const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
 | 
					                                  const int64_t ne_total,
 | 
				
			||||||
        const int64_t s01, const int64_t s02, const int64_t s03,
 | 
					                                  const int64_t ne10,
 | 
				
			||||||
        const int64_t s10, const int64_t s11, const int64_t s12,
 | 
					                                  const int64_t ne11,
 | 
				
			||||||
        const int64_t s1, const int64_t s2, const int64_t s3) {
 | 
					                                  const int64_t ne12,
 | 
				
			||||||
 | 
					                                  const int64_t ne13,
 | 
				
			||||||
 | 
					                                  const int64_t s01,
 | 
				
			||||||
 | 
					                                  const int64_t s02,
 | 
				
			||||||
 | 
					                                  const int64_t s03,
 | 
				
			||||||
 | 
					                                  const int64_t s10,
 | 
				
			||||||
 | 
					                                  const int64_t s11,
 | 
				
			||||||
 | 
					                                  const int64_t s12,
 | 
				
			||||||
 | 
					                                  const int64_t s1,
 | 
				
			||||||
 | 
					                                  const int64_t s2,
 | 
				
			||||||
 | 
					                                  const int64_t s3,
 | 
				
			||||||
 | 
					                                  const uint3   ne00,
 | 
				
			||||||
 | 
					                                  const uint3   ne01,
 | 
				
			||||||
 | 
					                                  const uint3   ne02,
 | 
				
			||||||
 | 
					                                  const uint3   ne11_fd,
 | 
				
			||||||
 | 
					                                  const uint3   ne12_fd) {
 | 
				
			||||||
    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
 | 
					    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
 | 
				
			||||||
    const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (i >= ne_total) {
 | 
					    if (i >= ne_total) {
 | 
				
			||||||
        return;
 | 
					        return;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int64_t i03 = i / (ne00 * ne01 * ne02);
 | 
					    uint32_t tmp = (uint32_t) i;
 | 
				
			||||||
    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
 | 
					    uint2    div_mod;
 | 
				
			||||||
    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
 | 
					 | 
				
			||||||
    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int64_t i12 = i03 % ne12;
 | 
					    div_mod           = fast_div_modulo(tmp, ne00);
 | 
				
			||||||
    const int64_t i11 = i02 % ne11;
 | 
					    const int64_t i00 = div_mod.y;
 | 
				
			||||||
 | 
					    tmp               = div_mod.x;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    div_mod           = fast_div_modulo(tmp, ne01);
 | 
				
			||||||
 | 
					    const int64_t i01 = div_mod.y;
 | 
				
			||||||
 | 
					    tmp               = div_mod.x;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    div_mod           = fast_div_modulo(tmp, ne02);
 | 
				
			||||||
 | 
					    const int64_t i02 = div_mod.y;
 | 
				
			||||||
 | 
					    const int64_t i03 = div_mod.x;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
 | 
				
			||||||
 | 
					    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
 | 
				
			||||||
    const int64_t i10 = i01;
 | 
					    const int64_t i10 = i01;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
 | 
					    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
 | 
				
			||||||
@@ -115,6 +165,8 @@ static __global__ void k_set_rows(
 | 
				
			|||||||
    dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
 | 
					    dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    GGML_UNUSED(ne10);
 | 
					    GGML_UNUSED(ne10);
 | 
				
			||||||
 | 
					    GGML_UNUSED(ne11);
 | 
				
			||||||
 | 
					    GGML_UNUSED(ne12);
 | 
				
			||||||
    GGML_UNUSED(ne13);
 | 
					    GGML_UNUSED(ne13);
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -144,14 +196,16 @@ static void set_rows_cuda(
 | 
				
			|||||||
    const int64_t s2  = nb2/sizeof(dst_t);
 | 
					    const int64_t s2  = nb2/sizeof(dst_t);
 | 
				
			||||||
    const int64_t s3  = nb3/sizeof(dst_t);
 | 
					    const int64_t s3  = nb3/sizeof(dst_t);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (ne_total > 0) {
 | 
					    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
 | 
				
			||||||
        k_set_rows<<<grid_size, block_size, 0, stream>>>(
 | 
					        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
 | 
				
			||||||
            src0_d, src1_d, dst_d,
 | 
					        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
 | 
				
			||||||
            ne00, ne01, ne02, ne03,
 | 
					        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
 | 
				
			||||||
            ne10, ne11, ne12, ne13,
 | 
					        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
 | 
				
			||||||
            s01, s02, s03,
 | 
					        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
 | 
				
			||||||
            s10, s11, s12,
 | 
					
 | 
				
			||||||
            s1, s2, s3);
 | 
					        k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
 | 
				
			||||||
 | 
					                                                         s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
 | 
				
			||||||
 | 
					                                                         ne11_fd, ne12_fd);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user