mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-01 09:01:57 +00:00
metal : mv q6_K support nr0 > 1
ggml-ci
This commit is contained in:
@@ -4975,9 +4975,9 @@ void kernel_mul_mv_q6_K_f32_impl(
|
||||
device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
|
||||
device const float * yy = (device const float *) (src1 + offset1);
|
||||
|
||||
// TODO: support nr0 > 1
|
||||
static_assert(nr0 == 1, "nr0 > 1 not supported");
|
||||
float sumf[1] = { 0.f };
|
||||
float sumf[nr0] = { 0.f };
|
||||
|
||||
float yl[16];
|
||||
|
||||
const short tid = tiisg/2;
|
||||
const short ix = tiisg%2;
|
||||
@@ -4995,22 +4995,37 @@ void kernel_mul_mv_q6_K_f32_impl(
|
||||
device const uint8_t * q2 = q1 + 32;
|
||||
device const uint8_t * qh = x[i].qh + q_offset_h;
|
||||
device const int8_t * sc = x[i].scales + is;
|
||||
device const half * dh = &x[i].d;
|
||||
|
||||
device const float * y = yy + i * QK_K + y_offset;
|
||||
|
||||
const float dall = x[i].d;
|
||||
|
||||
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
||||
|
||||
#pragma unroll(4)
|
||||
for (short l = 0; l < 4; ++l) {
|
||||
sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
||||
sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
||||
sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
||||
sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
||||
yl[4*l + 0] = y[l + 0];
|
||||
yl[4*l + 1] = y[l + 32];
|
||||
yl[4*l + 2] = y[l + 64];
|
||||
yl[4*l + 3] = y[l + 96];
|
||||
}
|
||||
|
||||
sumf[0] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
||||
for (short row = 0; row < nr0; ++row) {
|
||||
const float dall = dh[0];
|
||||
|
||||
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
||||
|
||||
for (short l = 0; l < 4; ++l) {
|
||||
sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
||||
sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
||||
sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
||||
sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
||||
}
|
||||
|
||||
sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
||||
|
||||
q1 += args.nb01;
|
||||
q2 += args.nb01;
|
||||
qh += args.nb01;
|
||||
sc += args.nb01;
|
||||
dh += args.nb01/2;
|
||||
}
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
||||
|
||||
Reference in New Issue
Block a user