opencl: support ne3 in get_rows (#15866)

This commit is contained in:
lhez
2025-09-30 09:55:13 -07:00
committed by GitHub
parent 364a7a6d4a
commit d1c84a662d
2 changed files with 59 additions and 28 deletions

View File

@@ -4222,15 +4222,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
GGML_ASSERT(dst); GGML_ASSERT(dst);
GGML_ASSERT(dst->extra); GGML_ASSERT(dst->extra);
const int ne00 = src0 ? src0->ne[0] : 0; const int ne00 = src0->ne[0];
const cl_ulong nb01 = src0 ? src0->nb[1] : 0; const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0 ? src0->nb[2] : 0; const cl_ulong nb02 = src0->nb[2];
const int ne10 = src1 ? src1->ne[0] : 0; const cl_ulong nb03 = src0->nb[3];
const cl_ulong nb10 = src1 ? src1->nb[0] : 0; const int ne10 = src1->ne[0];
const int ne11 = src1 ? src1->ne[1] : 0; const cl_ulong nb10 = src1->nb[0];
const cl_ulong nb11 = src1 ? src1->nb[1] : 0; const int ne11 = src1->ne[1];
const cl_ulong nb1 = dst ? dst->nb[1] : 0; const int ne12 = src1->ne[2];
const cl_ulong nb2 = dst ? dst->nb[2] : 0; const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb12 = src1->nb[2];
const cl_ulong nb1 = dst->nb[1];
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@@ -4267,14 +4271,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10)); CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11)); CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1)); CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3));
size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1}; size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12};
size_t local_work_size[] = {1, 1, 1}; size_t local_work_size[] = {64, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
} }

View File

@@ -69,11 +69,14 @@ kernel void kernel_get_rows_f32(
int ne00, int ne00,
ulong nb01, ulong nb01,
ulong nb02, ulong nb02,
ulong nb03,
int ne10, int ne10,
ulong nb10, ulong nb10,
ulong nb11, ulong nb11,
ulong nb12,
ulong nb1, ulong nb1,
ulong nb2 ulong nb2,
ulong nb3
) { ) {
src0 = (global void*)((global char*)src0 + offset0); src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1); src1 = (global int*)((global char*)src1 + offset1);
@@ -81,14 +84,19 @@ kernel void kernel_get_rows_f32(
int i10 = get_group_id(0); int i10 = get_group_id(0);
int i11 = get_group_id(1); int i11 = get_group_id(1);
int i12 = get_group_id(2);
int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
int i02 = i11; int i02 = i11;
int i03 = i12;
for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = if (ind >= ne00) {
((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; return;
}
((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
} }
} }
@@ -102,11 +110,14 @@ kernel void kernel_get_rows_f16(
int ne00, int ne00,
ulong nb01, ulong nb01,
ulong nb02, ulong nb02,
ulong nb03,
int ne10, int ne10,
ulong nb10, ulong nb10,
ulong nb11, ulong nb11,
ulong nb12,
ulong nb1, ulong nb1,
ulong nb2 ulong nb2,
ulong nb3
) { ) {
src0 = (global void*)((global char*)src0 + offset0); src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1); src1 = (global int*)((global char*)src1 + offset1);
@@ -114,14 +125,19 @@ kernel void kernel_get_rows_f16(
int i10 = get_group_id(0); int i10 = get_group_id(0);
int i11 = get_group_id(1); int i11 = get_group_id(1);
int i12 = get_group_id(2);
int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
int i02 = i11; int i02 = i11;
int i03 = i12;
for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) {
((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = if (ind >= ne00) {
((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; return;
}
((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] =
((global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind];
} }
} }
@@ -135,11 +151,14 @@ kernel void kernel_get_rows_q4_0(
int ne00, int ne00,
ulong nb01, ulong nb01,
ulong nb02, ulong nb02,
ulong nb03,
int ne10, int ne10,
ulong nb10, ulong nb10,
ulong nb11, ulong nb11,
ulong nb12,
ulong nb1, ulong nb1,
ulong nb2 ulong nb2,
ulong nb3
) { ) {
src0 = (global void*)((global char*)src0 + offset0); src0 = (global void*)((global char*)src0 + offset0);
src1 = (global int*)((global char*)src1 + offset1); src1 = (global int*)((global char*)src1 + offset1);
@@ -149,15 +168,20 @@ kernel void kernel_get_rows_q4_0(
int i10 = get_group_id(0); int i10 = get_group_id(0);
int i11 = get_group_id(1); int i11 = get_group_id(1);
int i12 = get_group_id(2);
int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0];
int i02 = i11; int i02 = i11;
int i03 = i12;
for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) {
float16 temp; float16 temp;
if (ind >= ne00) {
return;
}
dequantize_q4_0_f32( dequantize_q4_0_f32(
((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03)) + ind/NL, ind%NL, &temp);
*(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; *(((global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1)) + ind) = temp;
} }
} }