mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)
metal : reduce register pressure
ggml-ci
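The pattern applied throughout the diff below is the same in every kernel: loop counters and SIMD-lane-derived offsets whose value ranges are tiny are narrowed from int to short, the usual way to lower register pressure on Apple GPUs, since 16-bit values leave more register space for the accumulators and dequantized data. A minimal illustrative Metal sketch of the idiom follows (not part of the commit; the kernel name, buffers, and tile layout are made up for the example):

#include <metal_stdlib>
using namespace metal;

// Each SIMD lane reduces one 8-float slice of a 4x8x8 tile.
// The point is the index types: ix, it and i can only take small values,
// so they are declared as short/ushort instead of int.
kernel void row_sum_example(
        device const float * x   [[buffer(0)]],
        device       float * dst [[buffer(1)]],
        ushort tiisg [[thread_index_in_simdgroup]]) {
    const short ix = tiisg/8; // 0...3
    const short it = tiisg%8; // 0...7

    float sumf = 0.f;

    // fixed, small loop bound -> short counter
    for (short i = 0; i < 8; ++i) {
        sumf += x[64*ix + 8*it + i];
    }

    dst[tiisg] = sumf;
}

Besides the int -> short narrowing, the hunks below also drop a few redundant locals (for example the constant n = 8 in the q3_K and q5_K kernels) and some blank lines.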
@@ -1901,7 +1901,7 @@ void mul_vec_q_n_f32_impl(
         float sumy[2] = { 0.f, 0.f };

 #pragma unroll
-        for (int i = 0; i < 8; i += 2) {
+        for (short i = 0; i < 8; i += 2) {
             sumy[0] += yb[i + 0] + yb[i + 1];
             yl[i + 0] = yb[i + 0];
             yl[i + 1] = yb[i + 1]/256.f;
@@ -1912,7 +1912,7 @@ void mul_vec_q_n_f32_impl(
         }

 #pragma unroll
-        for (int row = 0; row < nr0; row++) {
+        for (short row = 0; row < nr0; row++) {
             sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
         }

@@ -2025,7 +2025,7 @@ void kernel_mul_mv_q8_0_f32_impl(
             yl[i] = yb[i];
         }

-        for (int row = 0; row < nr0; row++) {
+        for (short row = 0; row < nr0; row++) {
             device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
             float sumq = 0.f;
             for (short iq = 0; iq < NB_Q8_0; ++iq) {
@@ -4449,18 +4449,17 @@ void kernel_mul_mv_q2_K_f32_impl(
     float yl[32];
     float sumf[nr0]={0.f};

-    const int ix = tiisg/8;  // 0...3
-    const int it = tiisg%8;  // 0...7
-    const int iq = it/4;     // 0 or 1
-    const int ir = it%4;     // 0...3
-    const int is = (8*ir)/16;// 0 or 1
+    const short ix = tiisg/8;  // 0...3
+    const short it = tiisg%8;  // 0...7
+    const short iq = it/4;     // 0 or 1
+    const short ir = it%4;     // 0...3
+    const short is = (8*ir)/16;// 0 or 1

     device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir;

     for (int ib = ix; ib < nb; ib += 4) {
-
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int i = 0; i < 8; ++i) {
+        for (short i = 0; i < 8; ++i) {
             yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
             yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8];
             yl[i+16] = y4[i+64]; sumy[2] += yl[i+16];
@@ -4471,7 +4470,7 @@ void kernel_mul_mv_q2_K_f32_impl(
         device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir;
         device const half * dh = &x[ib].d;

-        for (int row = 0; row < nr0; row++) {
+        for (short row = 0; row < nr0; row++) {
             float4 acc1 = {0.f, 0.f, 0.f, 0.f};
             float4 acc2 = {0.f, 0.f, 0.f, 0.f};
             for (int i = 0; i < 8; i += 2) {
@@ -4561,8 +4560,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     const int ip = tid/4;          // 0 or 1
     const int il = 2*((tid%4)/2);  // 0 or 2
     const int ir = tid%2;
-    const int n  = 8;
-    const int l0 = n*ir;
+    const int l0 = 8*ir;

     // One would think that the Metal compiler would figure out that ip and il can only have
     // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
@@ -4600,7 +4598,7 @@ void kernel_mul_mv_q3_K_f32_impl(
     float sumf2[nr0] = {0.f};

     for (int i = ix; i < nb; i += 4) {
-        for (int l = 0; l < 8; ++l) {
+        for (short l = 0; l < 8; ++l) {
             yl[l+ 0] = y1[l+ 0];
             yl[l+ 8] = y1[l+16];
             yl[l+16] = y1[l+32];
@@ -4612,7 +4610,7 @@ void kernel_mul_mv_q3_K_f32_impl(
         device const uint16_t * a = (device const uint16_t *)(x[i].scales);
         device const half * dh = &x[i].d;

-        for (int row = 0; row < nr0; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             const float d_all = (float)dh[0];

             scales16[0] = a[4];
@@ -4623,7 +4621,7 @@ void kernel_mul_mv_q3_K_f32_impl(
             scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;

             float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
-            for (int l = 0; l < n; l += 2) {
+            for (short l = 0; l < 8; l += 2) {
                 const int32_t qs = q[l/2];
                 s1 += yl[l+0] * (qs & qm[il/2][0]);
                 s2 += yl[l+1] * (qs & qm[il/2][1]);
@@ -4638,7 +4636,7 @@ void kernel_mul_mv_q3_K_f32_impl(
             sumf2[row] += d2 * (scales[2] - 32);

             s1 = s2 = s3 = s4 = s5 = s6 = 0;
-            for (int l = 0; l < n; l += 2) {
+            for (short l = 0; l < 8; l += 2) {
                 const int32_t qs = q[l/2+8];
                 s1 += yl[l+8] * (qs & qm[il/2][0]);
                 s2 += yl[l+9] * (qs & qm[il/2][1]);
@@ -4846,15 +4844,14 @@ void kernel_mul_mv_q5_K_f32_impl(
     const uint16_t kmask2 = 0x0f0f;
     const uint16_t kmask3 = 0xc0c0;

-    const int tid = tiisg/4;
-    const int ix  = tiisg%4;
-    const int iq  = tid/4;
-    const int ir  = tid%4;
-    const int n   = 8;
+    const short tid = tiisg/4;
+    const short ix  = tiisg%4;
+    const short iq  = tid/4;
+    const short ir  = tid%4;

-    const int l0 = n*ir;
-    const int q_offset = 32*iq + l0;
-    const int y_offset = 64*iq + l0;
+    const short l0 = 8*ir;
+    const short q_offset = 32*iq + l0;
+    const short y_offset = 64*iq + l0;

     const uint8_t hm1 = 1u << (2*iq);
     const uint8_t hm2 = hm1 << 1;
@@ -4874,14 +4871,14 @@ void kernel_mul_mv_q5_K_f32_impl(

         device const float * y2 = y1 + 128;
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
-        for (int l = 0; l < 8; ++l) {
+        for (short l = 0; l < 8; ++l) {
             yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0];
             yl[l+8] = y1[l+32]; sumy[1] += yl[l+8];
             yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0];
             yh[l+8] = y2[l+32]; sumy[3] += yh[l+8];
         }

-        for (int row = 0; row < nr0; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             device const uint8_t * q2 = q1 + 64;

             sc16[0] = a[0] & kmask1;
@@ -4891,7 +4888,7 @@ void kernel_mul_mv_q5_K_f32_impl(

             float4 acc1 = {0.f};
             float4 acc2 = {0.f};
-            for (int l = 0; l < n; ++l) {
+            for (short l = 0; l < 8; ++l) {
                 uint8_t h = qh[l];
                 acc1[0] += yl[l+0] * (q1[l] & 0x0F);
                 acc1[1] += yl[l+8] * (q1[l] & 0xF0);
@@ -5102,8 +5099,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     device const float * y4 = y + 32 * ix;

     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }

@@ -5114,18 +5110,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
         device const uint16_t * q2 = xr->qs + 4 * ib;
         device const half * dh = &xr->d;

-        for (int row = 0; row < nr0; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             device const uint8_t * aux8 = (device const uint8_t *)q2;
             const uint32_t aux32 = q2[2] | (q2[3] << 16);
             const float d = db * (0.5f + (aux32 >> 28));

             float sum = 0;
-            for (int l = 0; l < 4; ++l) {
+            for (short l = 0; l < 4; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
                 const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
@@ -5210,8 +5205,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     device const float * y4 = y + 32 * ix;

     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }

@@ -5223,8 +5217,7 @@ void kernel_mul_mv_iq2_xs_f32_impl(
         device const uint8_t * sc = xr->scales + ib;
         device const half * dh = &xr->d;

-        for (int row = 0; row < nr0; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const uint8_t ls1 = sc[0] & 0xf;
             const uint8_t ls2 = sc[0] >> 4;
@@ -5232,17 +5225,17 @@ void kernel_mul_mv_iq2_xs_f32_impl(
             const float d2 = db * (0.5f + ls2);

             float sum1 = 0, sum2 = 0;
-            for (int l = 0; l < 2; ++l) {
+            for (short l = 0; l < 2; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                 const uint8_t signs = ssigns[(q2[l] >> 9)];
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
-            for (int l = 2; l < 4; ++l) {
+            for (short l = 2; l < 4; ++l) {
                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
                 const uint8_t signs = ssigns[(q2[l] >> 9)];
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
                 }
             }
@@ -5329,7 +5322,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
     device const float * y4 = y + 32 * ix;

     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }

@@ -5341,17 +5334,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
         device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
         device const half * dh = &xr->d;

-        for (int row = 0; row < nr0; row++) {
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const uint32_t aux32 = gas[0] | (gas[1] << 16);
             const float d = db * (0.5f + (aux32 >> 28));

             float2 sum = {0};
-            for (int l = 0; l < 4; ++l) {
+            for (short l = 0; l < 4; ++l) {
                 const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
                 const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
                 const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
-                for (int j = 0; j < 4; ++j) {
+                for (short j = 0; j < 4; ++j) {
                     sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
                     sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
                 }
@@ -5435,8 +5428,7 @@ void kernel_mul_mv_iq3_s_f32_impl(
     device const float * y4 = y + 32 * ix;

     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }

@@ -5450,18 +5442,17 @@ void kernel_mul_mv_iq3_s_f32_impl(
         device const uint8_t * signs = xr->signs + 4 * ib;
         device const half * dh = &xr->d;

-        for (int row = 0; row < nr0; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf));

             float2 sum = {0};
-            for (int l = 0; l < 4; ++l) {
+            for (short l = 0; l < 4; ++l) {
                 const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
                 const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
                 const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
                 const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
-                for (int j = 0; j < 4; ++j) {
+                for (short j = 0; j < 4; ++j) {
                     sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
                     sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
                 }
@@ -5542,13 +5533,12 @@ void kernel_mul_mv_iq2_s_f32_impl(
     //    threadgroup_barrier(mem_flags::mem_threadgroup);
     //}

-    const int ix = tiisg;
+    const short ix = tiisg;

     device const float * y4 = y + 32 * ix;

     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
         }

@@ -5562,19 +5552,18 @@ void kernel_mul_mv_iq2_s_f32_impl(
         device const uint8_t * signs = qs + QK_K/8;
         device const half * dh = &xr->d;

-        for (int row = 0; row < nr0; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             const float db = dh[0];
             const float d1 = db * (0.5f + (sc[0] & 0xf));
             const float d2 = db * (0.5f + (sc[0] >> 4));

             float2 sum = {0};
-            for (int l = 0; l < 2; ++l) {
+            for (short l = 0; l < 2; ++l) {
                 //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                 //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
                 constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
                 constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
-                for (int j = 0; j < 8; ++j) {
+                for (short j = 0; j < 8; ++j) {
                     sum[0] += yl[8*l + j +  0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]);
                     sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]);
                 }
@@ -5647,14 +5636,13 @@ void kernel_mul_mv_iq1_s_f32_impl(

     const int nb32 = nb * (QK_K / 32);

-    const int ix = tiisg;
+    const short ix = tiisg;

     device const float * y4 = y + 32 * ix;

     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
         float sumy = 0;
-        for (int i = 0; i < 32; ++i) {
+        for (short i = 0; i < 32; ++i) {
             yl[i] = y4[i];
             sumy += yl[i];
         }
@@ -5667,15 +5655,14 @@ void kernel_mul_mv_iq1_s_f32_impl(
         device const uint16_t * qh = xr->qh + ib;
         device const half * dh = &xr->d;

-        for (int row = 0; row < nr0; row++) {
-
+        for (short row = 0; row < nr0; row++) {
             constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
             constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700)));
             constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700)));
             constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700)));

             float sum = 0;
-            for (int j = 0; j < 4; ++j) {
+            for (short j = 0; j < 4; ++j) {
                 sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                      + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4)
                      + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -5747,16 +5734,15 @@ void kernel_mul_mv_iq1_m_f32_impl(

     const int nb32 = nb * (QK_K / 32);

-    const int ix = tiisg;
+    const short ix = tiisg;

     device const float * y4 = y + 32 * ix;

     iq1m_scale_t scale;

     for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
         float4 sumy = {0.f};
-        for (int i = 0; i < 8; ++i) {
+        for (short i = 0; i < 8; ++i) {
             yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
             yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
             yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
@@ -5771,7 +5757,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
         device const uint8_t * qh = xr->qh + 2 * ib;
         device const uint16_t * sc = (device const uint16_t *)xr->scales;

-        for (int row = 0; row < nr0; row++) {
+        for (short row = 0; row < nr0; row++) {
             scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);

             constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
@@ -5780,7 +5766,7 @@ void kernel_mul_mv_iq1_m_f32_impl(
             constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));

             float2 sum = {0.f};
-            for (int j = 0; j < 4; ++j) {
+            for (short j = 0; j < 4; ++j) {
                 sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
                         + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
                 sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
@@ -5852,8 +5838,8 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
     device const float * y = (device const float *) (src1 + offset1);

-    const int ix = tiisg/2;  // 0...15
-    const int it = tiisg%2;  // 0 or 1
+    const short ix = tiisg/2;  // 0...15
+    const short it = tiisg%2;  // 0 or 1

     shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -5869,12 +5855,13 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     float4 qf1, qf2;

     for (int ib = ix; ib < nb; ib += 16) {

         device const float4 * y4 = (device const float4 *)yb;
-        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
-
-        for (int row = 0; row < nr0 && first_row + row < args.ne01; ++row) {
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];
+
+        for (short row = 0; row < nr0; row++) {
             device const block_iq4_nl & xb = x[row*nb + ib];
             device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);

@@ -5899,7 +5886,6 @@ void kernel_mul_mv_iq4_nl_f32_impl(
             acc1 += acc2;

             sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
         }

         yb += 16 * QK4_NL;
@@ -5956,10 +5942,10 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
     device const float * y = (device const float *) (src1 + offset1);

-    const int ix = tiisg/16;  // 0 or 1
-    const int it = tiisg%16;  // 0...15
-    const int ib = it/2;
-    const int il = it%2;
+    const short ix = tiisg/16;  // 0 or 1
+    const short it = tiisg%16;  // 0...15
+    const short ib = it/2;
+    const short il = it%2;

     shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
     threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -5976,9 +5962,12 @@ void kernel_mul_mv_iq4_xs_f32_impl(

     for (int ibl = ix; ibl < nb; ibl += 2) {
         device const float4 * y4 = (device const float4 *)yb;
-        yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
+        yl[0] = y4[0];
+        yl[1] = y4[4];
+        yl[2] = y4[1];
+        yl[3] = y4[5];

-        for (int row = 0; row < nr0; ++row) {
+        for (short row = 0; row < nr0; ++row) {
             device const block_iq4_xs & xb = x[row*nb + ibl];
             device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);

@@ -6002,7 +5991,6 @@ void kernel_mul_mv_iq4_xs_f32_impl(

             const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32;
             sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]);
-
         }

         yb += 2 * QK_K;