cuda: use simpler loop in get_rows

This commit is contained in:
leejet
2025-08-31 00:21:24 +08:00
parent 131ae2d585
commit 0d5eb51252

View File

@@ -11,19 +11,14 @@ static __global__ void k_get_rows(
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/, const size_t block_num_y) { const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
for (int64_t iy = blockIdx.y; iy < block_num_y; iy+=MAX_GRIDDIM_Y) { for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i00 = (iy * blockDim.x + threadIdx.x)*2;
const int i10 = blockIdx.x; const int i10 = blockIdx.x;
const int i11 = blockIdx.z / ne12; const int i11 = blockIdx.z / ne12;
const int i12 = blockIdx.z % ne12; const int i12 = blockIdx.z % ne12;
if (i00 >= ne00) {
return;
}
const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
@@ -35,7 +30,7 @@ static __global__ void k_get_rows(
const int y_offset = qr == 1 ? 1 : qk/2; const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize // dequantize
float2 v; dfloat2 v;
dequantize_kernel(src0_row, ib, iqs, v); dequantize_kernel(src0_row, ib, iqs, v);
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x); dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
@@ -50,12 +45,10 @@ static __global__ void k_get_rows_float(
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ /*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/, const size_t block_num_y) { const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
for (int64_t iy = blockIdx.y; iy < block_num_y; iy+=MAX_GRIDDIM_Y) {
for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i00 = iy * blockDim.x + threadIdx.x;
const int i10 = blockIdx.x; const int i10 = blockIdx.x;
const int i11 = blockIdx.z / ne12; const int i11 = blockIdx.z / ne12;
const int i12 = blockIdx.z % ne12; const int i12 = blockIdx.z % ne12;
@@ -126,7 +119,7 @@ static void get_rows_cuda_q(
/*ne10, ne11,*/ ne12, /*ne13,*/ /*ne10, ne11,*/ ne12, /*ne13,*/
/* s0,*/ s1, s2, s3, /* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03, /* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/, block_num_y); s10, s11, s12/*, s13*/);
} }
template<typename src0_t, typename dst_t> template<typename src0_t, typename dst_t>
@@ -157,7 +150,7 @@ static void get_rows_cuda_float(
/*ne10, ne11,*/ ne12, /*ne13,*/ /*ne10, ne11,*/ ne12, /*ne13,*/
/* s0,*/ s1, s2, s3, /* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03, /* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/, block_num_y); s10, s11, s12/*, s13*/);
} }
template <typename dst_t> template <typename dst_t>