mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-04 09:32:00 +00:00
cuda: use simpler loop in get_rows
This commit is contained in:
@@ -11,19 +11,14 @@ static __global__ void k_get_rows(
|
|||||||
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
|
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
|
||||||
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
|
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
|
||||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/, const size_t block_num_y) {
|
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||||
|
|
||||||
for (int64_t iy = blockIdx.y; iy < block_num_y; iy+=MAX_GRIDDIM_Y) {
|
for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
|
||||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||||
const int i00 = (iy * blockDim.x + threadIdx.x)*2;
|
|
||||||
const int i10 = blockIdx.x;
|
const int i10 = blockIdx.x;
|
||||||
const int i11 = blockIdx.z / ne12;
|
const int i11 = blockIdx.z / ne12;
|
||||||
const int i12 = blockIdx.z % ne12;
|
const int i12 = blockIdx.z % ne12;
|
||||||
|
|
||||||
if (i00 >= ne00) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
||||||
|
|
||||||
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
||||||
@@ -35,7 +30,7 @@ static __global__ void k_get_rows(
|
|||||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||||
|
|
||||||
// dequantize
|
// dequantize
|
||||||
float2 v;
|
dfloat2 v;
|
||||||
dequantize_kernel(src0_row, ib, iqs, v);
|
dequantize_kernel(src0_row, ib, iqs, v);
|
||||||
|
|
||||||
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
|
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
|
||||||
@@ -50,12 +45,10 @@ static __global__ void k_get_rows_float(
|
|||||||
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
|
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
|
||||||
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
|
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
|
||||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/, const size_t block_num_y) {
|
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||||
|
|
||||||
for (int64_t iy = blockIdx.y; iy < block_num_y; iy+=MAX_GRIDDIM_Y) {
|
|
||||||
|
|
||||||
|
for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
|
||||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||||
const int i00 = iy * blockDim.x + threadIdx.x;
|
|
||||||
const int i10 = blockIdx.x;
|
const int i10 = blockIdx.x;
|
||||||
const int i11 = blockIdx.z / ne12;
|
const int i11 = blockIdx.z / ne12;
|
||||||
const int i12 = blockIdx.z % ne12;
|
const int i12 = blockIdx.z % ne12;
|
||||||
@@ -126,7 +119,7 @@ static void get_rows_cuda_q(
|
|||||||
/*ne10, ne11,*/ ne12, /*ne13,*/
|
/*ne10, ne11,*/ ne12, /*ne13,*/
|
||||||
/* s0,*/ s1, s2, s3,
|
/* s0,*/ s1, s2, s3,
|
||||||
/* nb00,*/ nb01, nb02, nb03,
|
/* nb00,*/ nb01, nb02, nb03,
|
||||||
s10, s11, s12/*, s13*/, block_num_y);
|
s10, s11, s12/*, s13*/);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename src0_t, typename dst_t>
|
template<typename src0_t, typename dst_t>
|
||||||
@@ -157,7 +150,7 @@ static void get_rows_cuda_float(
|
|||||||
/*ne10, ne11,*/ ne12, /*ne13,*/
|
/*ne10, ne11,*/ ne12, /*ne13,*/
|
||||||
/* s0,*/ s1, s2, s3,
|
/* s0,*/ s1, s2, s3,
|
||||||
/* nb00,*/ nb01, nb02, nb03,
|
/* nb00,*/ nb01, nb02, nb03,
|
||||||
s10, s11, s12/*, s13*/, block_num_y);
|
s10, s11, s12/*, s13*/);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename dst_t>
|
template <typename dst_t>
|
||||||
|
|||||||
Reference in New Issue
Block a user