metal : mv q6_K support nr0 > 1

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-03-25 17:48:43 +02:00
parent 51dea76888
commit fe12e20a7f

View File

@@ -4975,9 +4975,9 @@ void kernel_mul_mv_q6_K_f32_impl(
device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
device const float * yy = (device const float *) (src1 + offset1);
// TODO: support nr0 > 1
static_assert(nr0 == 1, "nr0 > 1 not supported");
float sumf[1] = { 0.f };
float sumf[nr0] = { 0.f };
float yl[16];
const short tid = tiisg/2;
const short ix = tiisg%2;
@@ -4995,22 +4995,37 @@ void kernel_mul_mv_q6_K_f32_impl(
device const uint8_t * q2 = q1 + 32;
device const uint8_t * qh = x[i].qh + q_offset_h;
device const int8_t * sc = x[i].scales + is;
device const half * dh = &x[i].d;
device const float * y = yy + i * QK_K + y_offset;
const float dall = x[i].d;
float4 sums = {0.f, 0.f, 0.f, 0.f};
#pragma unroll(4)
for (short l = 0; l < 4; ++l) {
sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
yl[4*l + 0] = y[l + 0];
yl[4*l + 1] = y[l + 32];
yl[4*l + 2] = y[l + 64];
yl[4*l + 3] = y[l + 96];
}
sumf[0] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
for (short row = 0; row < nr0; ++row) {
const float dall = dh[0];
float4 sums = {0.f, 0.f, 0.f, 0.f};
for (short l = 0; l < 4; ++l) {
sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
}
sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
q1 += args.nb01;
q2 += args.nb01;
qh += args.nb01;
sc += args.nb01;
dh += args.nb01/2;
}
}
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;