metal : reduce register pressure

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-03-25 21:55:15 +02:00
parent fe12e20a7f
commit e7d14ab26c

View File

@@ -1901,7 +1901,7 @@ void mul_vec_q_n_f32_impl(
float sumy[2] = { 0.f, 0.f };
#pragma unroll
for (int i = 0; i < 8; i += 2) {
for (short i = 0; i < 8; i += 2) {
sumy[0] += yb[i + 0] + yb[i + 1];
yl[i + 0] = yb[i + 0];
yl[i + 1] = yb[i + 1]/256.f;
@@ -1912,7 +1912,7 @@ void mul_vec_q_n_f32_impl(
}
#pragma unroll
for (int row = 0; row < nr0; row++) {
for (short row = 0; row < nr0; row++) {
sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
}
@@ -2025,7 +2025,7 @@ void kernel_mul_mv_q8_0_f32_impl(
yl[i] = yb[i];
}
for (int row = 0; row < nr0; row++) {
for (short row = 0; row < nr0; row++) {
device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
float sumq = 0.f;
for (short iq = 0; iq < NB_Q8_0; ++iq) {
@@ -4449,18 +4449,17 @@ void kernel_mul_mv_q2_K_f32_impl(
float yl[32];
float sumf[nr0]={0.f};
const int ix = tiisg/8; // 0...3
const int it = tiisg%8; // 0...7
const int iq = it/4; // 0 or 1
const int ir = it%4; // 0...3
const int is = (8*ir)/16;// 0 or 1
const short ix = tiisg/8; // 0...3
const short it = tiisg%8; // 0...7
const short iq = it/4; // 0 or 1
const short ir = it%4; // 0...3
const short is = (8*ir)/16;// 0 or 1
device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;
for (int ib = ix; ib < nb; ib += 4) {
float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
for (short i = 0; i < 8; ++i) {
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
@@ -4471,7 +4470,7 @@ void kernel_mul_mv_q2_K_f32_impl(
device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
device const half * dh = &x[ib].d;
for (int row = 0; row < nr0; row++) {
for (short row = 0; row < nr0; row++) {
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
@@ -4561,8 +4560,7 @@ void kernel_mul_mv_q3_K_f32_impl(
const int ip = tid/4; // 0 or 1
const int il = 2*((tid%4)/2); // 0 or 2
const int ir = tid%2;
const int n = 8;
const int l0 = n*ir;
const int l0 = 8*ir;
// One would think that the Metal compiler would figure out that ip and il can only have
// 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
@@ -4600,7 +4598,7 @@ void kernel_mul_mv_q3_K_f32_impl(
float sumf2[nr0] = {0.f};
for (int i = ix; i < nb; i += 4) {
for (int l = 0; l < 8; ++l) {
for (short l = 0; l < 8; ++l) {
yl[l+ 0] = y1[l+ 0];
yl[l+ 8] = y1[l+16];
yl[l+16] = y1[l+32];
@@ -4612,7 +4610,7 @@ void kernel_mul_mv_q3_K_f32_impl(
device const uint16_t * a = (device const uint16_t *)(x[i].scales);
device const half * dh = &x[i].d;
for (int row = 0; row < nr0; ++row) {
for (short row = 0; row < nr0; ++row) {
const float d_all = (float)dh[0];
scales16[0] = a[4];
@@ -4623,7 +4621,7 @@ void kernel_mul_mv_q3_K_f32_impl(
scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
for (int l = 0; l < n; l += 2) {
for (short l = 0; l < 8; l += 2) {
const int32_t qs = q[l/2];
s1 += yl[l+0] * (qs & qm[il/2][0]);
s2 += yl[l+1] * (qs & qm[il/2][1]);
@@ -4638,7 +4636,7 @@ void kernel_mul_mv_q3_K_f32_impl(
sumf2[row] += d2 * (scales[2] - 32);
s1 = s2 = s3 = s4 = s5 = s6 = 0;
for (int l = 0; l < n; l += 2) {
for (short l = 0; l < 8; l += 2) {
const int32_t qs = q[l/2+8];
s1 += yl[l+8] * (qs & qm[il/2][0]);
s2 += yl[l+9] * (qs & qm[il/2][1]);
@@ -4846,15 +4844,14 @@ void kernel_mul_mv_q5_K_f32_impl(
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
const int tid = tiisg/4;
const int ix = tiisg%4;
const int iq = tid/4;
const int ir = tid%4;
const int n = 8;
const short tid = tiisg/4;
const short ix = tiisg%4;
const short iq = tid/4;
const short ir = tid%4;
const int l0 = n*ir;
const int q_offset = 32*iq + l0;
const int y_offset = 64*iq + l0;
const short l0 = 8*ir;
const short q_offset = 32*iq + l0;
const short y_offset = 64*iq + l0;
const uint8_t hm1 = 1u << (2*iq);
const uint8_t hm2 = hm1 << 1;
@@ -4874,14 +4871,14 @@ void kernel_mul_mv_q5_K_f32_impl(
device const float * y2 = y1 + 128;
float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int l = 0; l < 8; ++l) {
for (short l = 0; l < 8; ++l) {
yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
}
for (int row = 0; row < nr0; ++row) {
for (short row = 0; row < nr0; ++row) {
device const uint8_t * q2 = q1 + 64;
sc16[0] = a[0] & kmask1;
@@ -4891,7 +4888,7 @@ void kernel_mul_mv_q5_K_f32_impl(
float4 acc1 = {0.f};
float4 acc2 = {0.f};
for (int l = 0; l < n; ++l) {
for (short l = 0; l < 8; ++l) {
uint8_t h = qh[l];
acc1[0] += yl[l+0] * (q1[l] & 0x0F);
acc1[1] += yl[l+8] * (q1[l] & 0xF0);
@@ -5102,8 +5099,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
for (int i = 0; i < 32; ++i) {
for (short i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
@@ -5114,18 +5110,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
device const uint16_t * q2 = xr->qs + 4 * ib;
device const half * dh = &xr->d;
for (int row = 0; row < nr0; row++) {
for (short row = 0; row < nr0; row++) {
const float db = dh[0];
device const uint8_t * aux8 = (device const uint8_t *)q2;
const uint32_t aux32 = q2[2] | (q2[3] << 16);
const float d = db * (0.5f + (aux32 >> 28));
float sum = 0;
for (int l = 0; l < 4; ++l) {
for (short l = 0; l < 4; ++l) {
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
for (int j = 0; j < 8; ++j) {
for (short j = 0; j < 8; ++j) {
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
}
@@ -5210,8 +5205,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
for (int i = 0; i < 32; ++i) {
for (short i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
@@ -5223,8 +5217,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
device const uint8_t * sc = xr->scales + ib;
device const half * dh = &xr->d;
for (int row = 0; row < nr0; row++) {
for (short row = 0; row < nr0; row++) {
const float db = dh[0];
const uint8_t ls1 = sc[0] & 0xf;
const uint8_t ls2 = sc[0] >> 4;
@@ -5232,17 +5225,17 @@ void kernel_mul_mv_iq2_xs_f32_impl(
const float d2 = db * (0.5f + ls2);
float sum1 = 0, sum2 = 0;
for (int l = 0; l < 2; ++l) {
for (short l = 0; l < 2; ++l) {
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
const uint8_t signs = ssigns[(q2[l] >> 9)];
for (int j = 0; j < 8; ++j) {
for (short j = 0; j < 8; ++j) {
sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
}
for (int l = 2; l < 4; ++l) {
for (short l = 2; l < 4; ++l) {
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
const uint8_t signs = ssigns[(q2[l] >> 9)];
for (int j = 0; j < 8; ++j) {
for (short j = 0; j < 8; ++j) {
sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
}
@@ -5329,7 +5322,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
for (int i = 0; i < 32; ++i) {
for (short i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
@@ -5341,17 +5334,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
device const half * dh = &xr->d;
for (int row = 0; row < nr0; row++) {
for (short row = 0; row < nr0; row++) {
const float db = dh[0];
const uint32_t aux32 = gas[0] | (gas[1] << 16);
const float d = db * (0.5f + (aux32 >> 28));
float2 sum = {0};
for (int l = 0; l < 4; ++l) {
for (short l = 0; l < 4; ++l) {
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
for (int j = 0; j < 4; ++j) {
for (short j = 0; j < 4; ++j) {
sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
@@ -5435,8 +5428,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
for (int i = 0; i < 32; ++i) {
for (short i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
@@ -5450,18 +5442,17 @@ void kernel_mul_mv_iq3_s_f32_impl(
device const uint8_t * signs = xr->signs + 4 * ib;
device const half * dh = &xr->d;
for (int row = 0; row < nr0; row++) {
for (short row = 0; row < nr0; row++) {
const float db = dh[0];
const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));
float2 sum = {0};
for (int l = 0; l < 4; ++l) {
for (short l = 0; l < 4; ++l) {
const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
for (int j = 0; j < 4; ++j) {
for (short j = 0; j < 4; ++j) {
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
}
@@ -5542,13 +5533,12 @@ void kernel_mul_mv_iq2_s_f32_impl(
// threadgroup_barrier(mem_flags::mem_threadgroup);
//}
const int ix = tiisg;
const short ix = tiisg;
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
for (int i = 0; i < 32; ++i) {
for (short i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
@@ -5562,19 +5552,18 @@ void kernel_mul_mv_iq2_s_f32_impl(
device const uint8_t * signs = qs + QK_K/8;
device const half * dh = &xr->d;
for (int row = 0; row < nr0; row++) {
for (short row = 0; row < nr0; row++) {
const float db = dh[0];
const float d1 = db * (0.5f + (sc[0] & 0xf));
const float d2 = db * (0.5f + (sc[0] >> 4));
float2 sum = {0};
for (int l = 0; l < 2; ++l) {
for (short l = 0; l < 2; ++l) {
//const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
//const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
for (int j = 0; j < 8; ++j) {
for (short j = 0; j < 8; ++j) {
sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
}
@@ -5647,14 +5636,13 @@ void kernel_mul_mv_iq1_s_f32_impl(
const int nb32 = nb * (QK_K / 32);
const int ix = tiisg;
const short ix = tiisg;
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
float sumy = 0;
for (int i = 0; i < 32; ++i) {
for (short i = 0; i < 32; ++i) {
yl[i] = y4[i];
sumy += yl[i];
}
@@ -5667,15 +5655,14 @@ void kernel_mul_mv_iq1_s_f32_impl(
device const uint16_t * qh = xr->qh + ib;
device const half * dh = &xr->d;
for (int row = 0; row < nr0; row++) {
for (short row = 0; row < nr0; row++) {
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));
float sum = 0;
for (int j = 0; j < 4; ++j) {
for (short j = 0; j < 4; ++j) {
sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+ yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
+ yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -5747,16 +5734,15 @@ void kernel_mul_mv_iq1_m_f32_impl(
const int nb32 = nb * (QK_K / 32);
const int ix = tiisg;
const short ix = tiisg;
device const float * y4 = y + 32 * ix;
iq1m_scale_t scale;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
float4 sumy = {0.f};
for (int i = 0; i < 8; ++i) {
for (short i = 0; i < 8; ++i) {
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
@@ -5771,7 +5757,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
device const uint8_t * qh = xr->qh + 2 * ib;
device const uint16_t * sc = (device const uint16_t *)xr->scales;
for (int row = 0; row < nr0; row++) {
for (short row = 0; row < nr0; row++) {
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
@@ -5780,7 +5766,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
float2 sum = {0.f};
for (int j = 0; j < 4; ++j) {
for (short j = 0; j < 4; ++j) {
sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
+ yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -5852,8 +5838,8 @@ void kernel_mul_mv_iq4_nl_f32_impl(
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
const int ix = tiisg/2; // 0...15
const int it = tiisg%2; // 0 or 1
const short ix = tiisg/2; // 0...15
const short it = tiisg%2; // 0 or 1
shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -5869,12 +5855,13 @@ void kernel_mul_mv_iq4_nl_f32_impl(
float4 qf1, qf2;
for (int ib = ix; ib < nb; ib += 16) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
for (int row = 0; row < nr0 && first_row + row < args.ne01; ++row) {
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];
for (short row = 0; row < nr0; row++) {
device const block_iq4_nl & xb = x[row*nb + ib];
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
@@ -5899,7 +5886,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
acc1 += acc2;
sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
}
yb += 16 * QK4_NL;
@@ -5956,10 +5942,10 @@ void kernel_mul_mv_iq4_xs_f32_impl(
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
const int ix = tiisg/16; // 0 or 1
const int it = tiisg%16; // 0...15
const int ib = it/2;
const int il = it%2;
const short ix = tiisg/16; // 0 or 1
const short it = tiisg%16; // 0...15
const short ib = it/2;
const short il = it%2;
shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -5976,9 +5962,12 @@ void kernel_mul_mv_iq4_xs_f32_impl(
for (int ibl = ix; ibl < nb; ibl += 2) {
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
yl[0] = y4[0];
yl[1] = y4[4];
yl[2] = y4[1];
yl[3] = y4[5];
for (int row = 0; row < nr0; ++row) {
for (short row = 0; row < nr0; ++row) {
device const block_iq4_xs & xb = x[row*nb + ibl];
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
@@ -6002,7 +5991,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(
const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
}
yb += 2 * QK_K;