opencl: support ne3 in get_rows (#15866)

This commit is contained in:
lhez
2025-09-30 09:55:13 -07:00
committed by GitHub
parent 364a7a6d4a
commit d1c84a662d
2 changed files with 59 additions and 28 deletions

View File

@@ -4222,15 +4222,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
const int ne00 = src0 ? src0->ne[0] : 0;
const cl_ulong nb01 = src0 ? src0->nb[1] : 0;
const cl_ulong nb02 = src0 ? src0->nb[2] : 0;
const int ne10 = src1 ? src1->ne[0] : 0;
const cl_ulong nb10 = src1 ? src1->nb[0] : 0;
const int ne11 = src1 ? src1->ne[1] : 0;
const cl_ulong nb11 = src1 ? src1->nb[1] : 0;
const cl_ulong nb1 = dst ? dst->nb[1] : 0;
const cl_ulong nb2 = dst ? dst->nb[2] : 0;
const int ne00 = src0->ne[0];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const int ne10 = src1->ne[0];
const cl_ulong nb10 = src1->nb[0];
const int ne11 = src1->ne[1];
const int ne12 = src1->ne[2];
const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb12 = src1->nb[2];
const cl_ulong nb1 = dst->nb[1];
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -4267,14 +4271,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3));
size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1};
size_t local_work_size[] = {1, 1, 1};
size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12};
size_t local_work_size[] = {64, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}

View File

@@ -69,11 +69,14 @@ kernel void kernel_get_rows_f32(
int ne00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb1,
ulong nb2
ulong nb2,
ulong nb3
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
@@ -81,14 +84,19 @@ kernel void kernel_get_rows_f32(
int i10 = get_group_id(0);
int i11 = get_group_id(1);
int i12 = get_group_id(2);
int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
int i02 = i11;
int i03 = i12;
for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
if (ind >= ne00) {
return;
}
((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
}
}
@@ -102,11 +110,14 @@ kernel void kernel_get_rows_f16(
int ne00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb1,
ulong nb2
ulong nb2,
ulong nb3
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
@@ -114,14 +125,19 @@ kernel void kernel_get_rows_f16(
int i10 = get_group_id(0);
int i11 = get_group_id(1);
int i12 = get_group_id(2);
int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
int i02 = i11;
int i03 = i12;
for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] =
((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind];
if (ind >= ne00) {
return;
}
((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
((global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
}
}
@@ -135,11 +151,14 @@ kernel void kernel_get_rows_q4_0(
int ne00,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
ulong nb10,
ulong nb11,
ulong nb12,
ulong nb1,
ulong nb2
ulong nb2,
ulong nb3
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1);
@@ -149,15 +168,20 @@ kernel void kernel_get_rows_q4_0(
int i10 = get_group_id(0);
int i11 = get_group_id(1);
int i12 = get_group_id(2);
int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0];
int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
int i02 = i11;
int i03 = i12;
for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
float16 temp;
if (ind >= ne00) {
return;
}
dequantize_q4_0_f32(
((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp);
*(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03)) + ind/NL, ind%NL, &temp);
*(((global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1)) + ind) = temp;
}
}