mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-01 09:01:57 +00:00
CUDA: use fastdiv in set-rows (#16834)
* CUDA: use fastdiv in set-rows * add assert about value fitting in u32
This commit is contained in:
@@ -625,8 +625,11 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
|
||||
// and a shift:
|
||||
//
|
||||
// n/d = (mulhi(n, mp) + n) >> L;
|
||||
static const uint3 init_fastdiv_values(uint32_t d) {
|
||||
GGML_ASSERT(d != 0);
|
||||
static const uint3 init_fastdiv_values(uint64_t d_64) {
|
||||
GGML_ASSERT(d_64 != 0);
|
||||
GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
|
||||
|
||||
uint32_t d = (uint32_t)d_64;
|
||||
|
||||
// compute L = ceil(log2(d));
|
||||
uint32_t L = 0;
|
||||
|
||||
@@ -4,30 +4,53 @@
|
||||
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
|
||||
|
||||
// Generic quantized set_rows kernel template
|
||||
template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
|
||||
static __global__ void k_set_rows_quant(
|
||||
const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t s10, const int64_t s11, const int64_t s12,
|
||||
const int64_t s1, const int64_t s2, const int64_t s3) {
|
||||
|
||||
template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
|
||||
static __global__ void k_set_rows_quant(const float * __restrict__ src0,
|
||||
const idx_t * __restrict__ src1,
|
||||
block_type * __restrict__ dst,
|
||||
const int64_t ne_total,
|
||||
const int64_t ne10,
|
||||
const int64_t ne11,
|
||||
const int64_t ne12,
|
||||
const int64_t ne13,
|
||||
const int64_t s01,
|
||||
const int64_t s02,
|
||||
const int64_t s03,
|
||||
const int64_t s10,
|
||||
const int64_t s11,
|
||||
const int64_t s12,
|
||||
const int64_t s1,
|
||||
const int64_t s2,
|
||||
const int64_t s3,
|
||||
const uint3 ne00,
|
||||
const uint3 ne01,
|
||||
const uint3 ne02,
|
||||
const uint3 ne11_fd,
|
||||
const uint3 ne12_fd) {
|
||||
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
|
||||
const int64_t ne_total = (ne00 * ne01 * ne02 * ne03) / qk;
|
||||
|
||||
if (i >= ne_total) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i_base = i * qk;
|
||||
const int64_t i03 = i_base / (ne00 * ne01 * ne02);
|
||||
const int64_t i02 = (i_base - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
|
||||
const int64_t i01 = (i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
|
||||
const int64_t i00 = i_base - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
|
||||
uint32_t tmp = (uint32_t) i_base;
|
||||
uint2 div_mod;
|
||||
|
||||
const int64_t i12 = i03 % ne12;
|
||||
const int64_t i11 = i02 % ne11;
|
||||
div_mod = fast_div_modulo(tmp, ne00);
|
||||
const int64_t i00 = div_mod.y;
|
||||
tmp = div_mod.x;
|
||||
|
||||
div_mod = fast_div_modulo(tmp, ne01);
|
||||
const int64_t i01 = div_mod.y;
|
||||
tmp = div_mod.x;
|
||||
|
||||
div_mod = fast_div_modulo(tmp, ne02);
|
||||
const int64_t i02 = div_mod.y;
|
||||
const int64_t i03 = div_mod.x;
|
||||
|
||||
const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
|
||||
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
|
||||
const int64_t i10 = i01;
|
||||
|
||||
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
|
||||
@@ -41,6 +64,8 @@ static __global__ void k_set_rows_quant(
|
||||
quantize_func(src_block, dst_block);
|
||||
|
||||
GGML_UNUSED(ne10);
|
||||
GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12);
|
||||
GGML_UNUSED(ne13);
|
||||
}
|
||||
|
||||
@@ -71,40 +96,65 @@ static void set_rows_cuda_quant(
|
||||
const int64_t s2 = nb2;
|
||||
const int64_t s3 = nb3;
|
||||
|
||||
if (ne_total > 0) {
|
||||
if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
|
||||
const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
|
||||
const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
|
||||
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
|
||||
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
|
||||
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
|
||||
|
||||
k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, ne01, ne02, ne03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
s01, s02, s03,
|
||||
s10, s11, s12,
|
||||
s1, s2, s3);
|
||||
src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
|
||||
ne01_fd, ne02_fd, ne11_fd, ne12_fd);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename src_t, typename idx_t, typename dst_t>
|
||||
static __global__ void k_set_rows(
|
||||
const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
|
||||
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
|
||||
const int64_t s01, const int64_t s02, const int64_t s03,
|
||||
const int64_t s10, const int64_t s11, const int64_t s12,
|
||||
const int64_t s1, const int64_t s2, const int64_t s3) {
|
||||
|
||||
template <typename src_t, typename idx_t, typename dst_t>
|
||||
static __global__ void k_set_rows(const src_t * __restrict__ src0,
|
||||
const idx_t * __restrict__ src1,
|
||||
dst_t * __restrict__ dst,
|
||||
const int64_t ne_total,
|
||||
const int64_t ne10,
|
||||
const int64_t ne11,
|
||||
const int64_t ne12,
|
||||
const int64_t ne13,
|
||||
const int64_t s01,
|
||||
const int64_t s02,
|
||||
const int64_t s03,
|
||||
const int64_t s10,
|
||||
const int64_t s11,
|
||||
const int64_t s12,
|
||||
const int64_t s1,
|
||||
const int64_t s2,
|
||||
const int64_t s3,
|
||||
const uint3 ne00,
|
||||
const uint3 ne01,
|
||||
const uint3 ne02,
|
||||
const uint3 ne11_fd,
|
||||
const uint3 ne12_fd) {
|
||||
const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;
|
||||
const int64_t ne_total = ne00 * ne01 * ne02 * ne03;
|
||||
|
||||
if (i >= ne_total) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int64_t i03 = i / (ne00 * ne01 * ne02);
|
||||
const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
|
||||
const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01) / ne00;
|
||||
const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne00 * ne01 - i01 * ne00;
|
||||
uint32_t tmp = (uint32_t) i;
|
||||
uint2 div_mod;
|
||||
|
||||
const int64_t i12 = i03 % ne12;
|
||||
const int64_t i11 = i02 % ne11;
|
||||
div_mod = fast_div_modulo(tmp, ne00);
|
||||
const int64_t i00 = div_mod.y;
|
||||
tmp = div_mod.x;
|
||||
|
||||
div_mod = fast_div_modulo(tmp, ne01);
|
||||
const int64_t i01 = div_mod.y;
|
||||
tmp = div_mod.x;
|
||||
|
||||
div_mod = fast_div_modulo(tmp, ne02);
|
||||
const int64_t i02 = div_mod.y;
|
||||
const int64_t i03 = div_mod.x;
|
||||
|
||||
const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
|
||||
const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
|
||||
const int64_t i10 = i01;
|
||||
|
||||
const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
|
||||
@@ -115,6 +165,8 @@ static __global__ void k_set_rows(
|
||||
dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
|
||||
|
||||
GGML_UNUSED(ne10);
|
||||
GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12);
|
||||
GGML_UNUSED(ne13);
|
||||
}
|
||||
|
||||
@@ -144,14 +196,16 @@ static void set_rows_cuda(
|
||||
const int64_t s2 = nb2/sizeof(dst_t);
|
||||
const int64_t s3 = nb3/sizeof(dst_t);
|
||||
|
||||
if (ne_total > 0) {
|
||||
k_set_rows<<<grid_size, block_size, 0, stream>>>(
|
||||
src0_d, src1_d, dst_d,
|
||||
ne00, ne01, ne02, ne03,
|
||||
ne10, ne11, ne12, ne13,
|
||||
s01, s02, s03,
|
||||
s10, s11, s12,
|
||||
s1, s2, s3);
|
||||
if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
|
||||
const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
|
||||
const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
|
||||
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
|
||||
const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
|
||||
const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);
|
||||
|
||||
k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
|
||||
s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
|
||||
ne11_fd, ne12_fd);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user