mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
CUDA: fix GET_ROWS for large tensors (#15882)
This commit is contained in:
@@ -2,39 +2,39 @@
|
|||||||
#include "dequantize.cuh"
|
#include "dequantize.cuh"
|
||||||
#include "convert.cuh"
|
#include "convert.cuh"
|
||||||
|
|
||||||
#define MAX_GRIDDIM_Y 65535
|
|
||||||
|
|
||||||
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
|
||||||
static __global__ void k_get_rows(
|
static __global__ void k_get_rows(
|
||||||
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||||
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
|
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
|
||||||
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
|
/*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
|
||||||
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
|
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
|
||||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||||
|
|
||||||
for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
|
for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
|
||||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
|
||||||
const int i10 = blockIdx.x;
|
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||||
const int i11 = blockIdx.z / ne12;
|
const int i10 = blockIdx.x;
|
||||||
const int i12 = blockIdx.z % ne12;
|
const int i11 = z / ne12; // TODO fastdiv
|
||||||
|
const int i12 = z % ne12;
|
||||||
|
|
||||||
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
||||||
|
|
||||||
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
||||||
const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
|
const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
|
||||||
|
|
||||||
const int ib = i00/qk; // block index
|
const int ib = i00/qk; // block index
|
||||||
const int iqs = (i00%qk)/qr; // quant index
|
const int iqs = (i00%qk)/qr; // quant index
|
||||||
const int iybs = i00 - i00%qk; // dst block start index
|
const int iybs = i00 - i00%qk; // dst block start index
|
||||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||||
|
|
||||||
// dequantize
|
// dequantize
|
||||||
float2 v;
|
float2 v;
|
||||||
dequantize_kernel(src0_row, ib, iqs, v);
|
dequantize_kernel(src0_row, ib, iqs, v);
|
||||||
|
|
||||||
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
|
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
|
||||||
dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
|
dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -42,27 +42,29 @@ template<typename src0_t, typename dst_t>
|
|||||||
static __global__ void k_get_rows_float(
|
static __global__ void k_get_rows_float(
|
||||||
const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
|
||||||
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
|
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
|
||||||
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
|
/*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
|
||||||
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
|
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
|
||||||
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
|
||||||
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
|
||||||
|
|
||||||
for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
|
for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
|
||||||
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
|
||||||
const int i10 = blockIdx.x;
|
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
|
||||||
const int i11 = blockIdx.z / ne12;
|
const int i10 = blockIdx.x;
|
||||||
const int i12 = blockIdx.z % ne12;
|
const int i11 = z / ne12; // TODO fastdiv
|
||||||
|
const int i12 = z % ne12;
|
||||||
|
|
||||||
if (i00 >= ne00) {
|
if (i00 >= ne00) {
|
||||||
return;
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
||||||
|
|
||||||
|
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
||||||
|
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
|
||||||
|
|
||||||
|
dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
|
||||||
}
|
}
|
||||||
|
|
||||||
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
|
|
||||||
|
|
||||||
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
|
|
||||||
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
|
|
||||||
|
|
||||||
dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -98,7 +100,7 @@ static void get_rows_cuda_q(
|
|||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||||
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
|
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
|
||||||
const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
|
const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
|
||||||
|
|
||||||
// strides in elements
|
// strides in elements
|
||||||
// const size_t s0 = nb0 / sizeof(dst_t);
|
// const size_t s0 = nb0 / sizeof(dst_t);
|
||||||
@@ -116,7 +118,7 @@ static void get_rows_cuda_q(
|
|||||||
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
|
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
|
||||||
src0_d, src1_d, dst_d,
|
src0_d, src1_d, dst_d,
|
||||||
ne00, /*ne01, ne02, ne03,*/
|
ne00, /*ne01, ne02, ne03,*/
|
||||||
/*ne10, ne11,*/ ne12, /*ne13,*/
|
/*ne10,*/ ne11, ne12, /*ne13,*/
|
||||||
/* s0,*/ s1, s2, s3,
|
/* s0,*/ s1, s2, s3,
|
||||||
/* nb00,*/ nb01, nb02, nb03,
|
/* nb00,*/ nb01, nb02, nb03,
|
||||||
s10, s11, s12/*, s13*/);
|
s10, s11, s12/*, s13*/);
|
||||||
@@ -131,7 +133,7 @@ static void get_rows_cuda_float(
|
|||||||
cudaStream_t stream) {
|
cudaStream_t stream) {
|
||||||
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
|
||||||
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
|
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
|
||||||
const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
|
const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
|
||||||
|
|
||||||
// strides in elements
|
// strides in elements
|
||||||
// const size_t s0 = nb0 / sizeof(dst_t);
|
// const size_t s0 = nb0 / sizeof(dst_t);
|
||||||
@@ -147,7 +149,7 @@ static void get_rows_cuda_float(
|
|||||||
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
|
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
|
||||||
src0_d, src1_d, dst_d,
|
src0_d, src1_d, dst_d,
|
||||||
ne00, /*ne01, ne02, ne03,*/
|
ne00, /*ne01, ne02, ne03,*/
|
||||||
/*ne10, ne11,*/ ne12, /*ne13,*/
|
/*ne10,*/ ne11, ne12, /*ne13,*/
|
||||||
/* s0,*/ s1, s2, s3,
|
/* s0,*/ s1, s2, s3,
|
||||||
/* nb00,*/ nb01, nb02, nb03,
|
/* nb00,*/ nb01, nb02, nb03,
|
||||||
s10, s11, s12/*, s13*/);
|
s10, s11, s12/*, s13*/);
|
||||||
|
|||||||
@@ -3393,10 +3393,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||||||
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
{
|
{
|
||||||
// FIXME: https://github.com/ggml-org/llama.cpp/pull/15868
|
|
||||||
if (op->src[1]->ne[1]*op->src[1]->ne[2] > 65535) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
switch (op->src[0]->type) {
|
switch (op->src[0]->type) {
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
|
|||||||
Reference in New Issue
Block a user