opencl: add fastdiv and use it in set_rows, ported from cuda (#17090)

* opencl: add fastdiv for mm q8_0

* opencl: use uint4 for fastdiv vals

* opencl: use fastdiv for set_rows

* opencl: do not use fastdiv for q8_0 mm
This commit is contained in:
lhez
2025-11-10 15:00:13 -08:00
committed by GitHub
parent 7bef684118
commit ece0f5c177
2 changed files with 71 additions and 18 deletions

View File

@@ -53,6 +53,37 @@
bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor); bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor);
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
// and a shift:
//
// n/d = (mulhi(n, mp) + n) >> L;
struct fastdiv_vals {
uint32_t mp;
uint32_t L;
uint32_t d;
uint32_t pad;
};
static_assert(sizeof(fastdiv_vals) == 16, "fastdiv_vals size incorrect");
static fastdiv_vals init_fastdiv_values(uint64_t d_64) {
GGML_ASSERT(d_64 != 0);
GGML_ASSERT(d_64 <= std::numeric_limits<uint32_t>::max());
uint32_t d = (uint32_t)d_64;
// compute L = ceil(log2(d));
uint32_t L = 0;
while (L < 32 && (uint32_t{ 1 } << L) < d) {
L++;
}
uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
// pack divisor as well to reduce error surface
return { mp, L, d, 0 };
}
enum GPU_FAMILY { enum GPU_FAMILY {
ADRENO, ADRENO,
INTEL, INTEL,
@@ -4464,6 +4495,9 @@ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, c
GGML_ABORT("not implemented"); GGML_ABORT("not implemented");
} }
fastdiv_vals ne11_ = init_fastdiv_values(ne11);
fastdiv_vals ne12_ = init_fastdiv_values(ne12);
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
@@ -4474,8 +4508,8 @@ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, c
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11)); CL_CHECK(clSetKernelArg(kernel, 10, sizeof(fastdiv_vals), &ne11_));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); CL_CHECK(clSetKernelArg(kernel, 11, sizeof(fastdiv_vals), &ne12_));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10)); CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb10));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb12));

View File

@@ -1,5 +1,16 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable #pragma OPENCL EXTENSION cl_khr_fp16 : enable
// v = { mp, L, d }
inline uint fastdiv(uint n, uint4 v) {
uint msbs;
msbs = mul_hi(n, v.s0);
return (msbs + n) >> v.s1;
}
inline uint fastmod(uint n, uint4 v) {
uint q = fastdiv(n, v);
return n - q * v.s2;
}
kernel void kernel_set_rows_f32_i64( kernel void kernel_set_rows_f32_i64(
global char * src0, global char * src0,
ulong offset0, ulong offset0,
@@ -11,8 +22,8 @@ kernel void kernel_set_rows_f32_i64(
ulong nb01, ulong nb01,
ulong nb02, ulong nb02,
ulong nb03, ulong nb03,
int ne11, uint4 ne11,
int ne12, uint4 ne12,
ulong nb10, ulong nb10,
ulong nb11, ulong nb11,
ulong nb12, ulong nb12,
@@ -33,8 +44,10 @@ kernel void kernel_set_rows_f32_i64(
return; return;
} }
int i12 = i03%ne12; //int i12 = i03%ne12;
int i11 = i02%ne11; //int i11 = i02%ne11;
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01; int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0]; long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
@@ -58,8 +71,8 @@ kernel void kernel_set_rows_f16_i64(
ulong nb01, ulong nb01,
ulong nb02, ulong nb02,
ulong nb03, ulong nb03,
int ne11, uint4 ne11,
int ne12, uint4 ne12,
ulong nb10, ulong nb10,
ulong nb11, ulong nb11,
ulong nb12, ulong nb12,
@@ -80,8 +93,10 @@ kernel void kernel_set_rows_f16_i64(
return; return;
} }
int i12 = i03%ne12; //int i12 = i03%ne12;
int i11 = i02%ne11; //int i11 = i02%ne11;
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01; int i10 = i01;
long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0]; long i1 = ((global long *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
@@ -105,8 +120,8 @@ kernel void kernel_set_rows_f32_i32(
ulong nb01, ulong nb01,
ulong nb02, ulong nb02,
ulong nb03, ulong nb03,
int ne11, uint4 ne11,
int ne12, uint4 ne12,
ulong nb10, ulong nb10,
ulong nb11, ulong nb11,
ulong nb12, ulong nb12,
@@ -127,8 +142,10 @@ kernel void kernel_set_rows_f32_i32(
return; return;
} }
int i12 = i03%ne12; //int i12 = i03%ne12;
int i11 = i02%ne11; //int i11 = i02%ne11;
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01; int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0]; int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
@@ -152,8 +169,8 @@ kernel void kernel_set_rows_f16_i32(
ulong nb01, ulong nb01,
ulong nb02, ulong nb02,
ulong nb03, ulong nb03,
int ne11, uint4 ne11,
int ne12, uint4 ne12,
ulong nb10, ulong nb10,
ulong nb11, ulong nb11,
ulong nb12, ulong nb12,
@@ -174,8 +191,10 @@ kernel void kernel_set_rows_f16_i32(
return; return;
} }
int i12 = i03%ne12; //int i12 = i03%ne12;
int i11 = i02%ne11; //int i11 = i02%ne11;
int i12 = fastmod(i03, ne12);
int i11 = fastmod(i02, ne11);
int i10 = i01; int i10 = i01;
int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0]; int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];