diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 869f042664..1d7c518661 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1901,7 +1901,7 @@ void mul_vec_q_n_f32_impl( float sumy[2] = { 0.f, 0.f }; #pragma unroll - for (int i = 0; i < 8; i += 2) { + for (short i = 0; i < 8; i += 2) { sumy[0] += yb[i + 0] + yb[i + 1]; yl[i + 0] = yb[i + 0]; yl[i + 1] = yb[i + 1]/256.f; @@ -1912,7 +1912,7 @@ void mul_vec_q_n_f32_impl( } #pragma unroll - for (int row = 0; row < nr0; row++) { + for (short row = 0; row < nr0; row++) { sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); } @@ -2025,7 +2025,7 @@ void kernel_mul_mv_q8_0_f32_impl( yl[i] = yb[i]; } - for (int row = 0; row < nr0; row++) { + for (short row = 0; row < nr0; row++) { device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; float sumq = 0.f; for (short iq = 0; iq < NB_Q8_0; ++iq) { @@ -4449,18 +4449,17 @@ void kernel_mul_mv_q2_K_f32_impl( float yl[32]; float sumf[nr0]={0.f}; - const int ix = tiisg/8; // 0...3 - const int it = tiisg%8; // 0...7 - const int iq = it/4; // 0 or 1 - const int ir = it%4; // 0...3 - const int is = (8*ir)/16;// 0 or 1 + const short ix = tiisg/8; // 0...3 + const short it = tiisg%8; // 0...7 + const short iq = it/4; // 0 or 1 + const short ir = it%4; // 0...3 + const short is = (8*ir)/16;// 0 or 1 device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; for (int ib = ix; ib < nb; ib += 4) { - float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int i = 0; i < 8; ++i) { + for (short i = 0; i < 8; ++i) { yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; @@ -4471,7 +4470,7 @@ void kernel_mul_mv_q2_K_f32_impl( device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; device const half * dh = &x[ib].d; - for (int row = 0; row < nr0; row++) { + for (short row = 0; row < nr0; row++) { float4 acc1 = {0.f, 0.f, 0.f, 0.f}; float4 acc2 = {0.f, 0.f, 0.f, 0.f}; for (int i = 0; i < 8; i += 2) { @@ -4561,8 +4560,7 @@ void kernel_mul_mv_q3_K_f32_impl( const int ip = tid/4; // 0 or 1 const int il = 2*((tid%4)/2); // 0 or 2 const int ir = tid%2; - const int n = 8; - const int l0 = n*ir; + const int l0 = 8*ir; // One would think that the Metal compiler would figure out that ip and il can only have // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it @@ -4600,7 +4598,7 @@ void kernel_mul_mv_q3_K_f32_impl( float sumf2[nr0] = {0.f}; for (int i = ix; i < nb; i += 4) { - for (int l = 0; l < 8; ++l) { + for (short l = 0; l < 8; ++l) { yl[l+ 0] = y1[l+ 0]; yl[l+ 8] = y1[l+16]; yl[l+16] = y1[l+32]; @@ -4612,7 +4610,7 @@ void kernel_mul_mv_q3_K_f32_impl( device const uint16_t * a = (device const uint16_t *)(x[i].scales); device const half * dh = &x[i].d; - for (int row = 0; row < nr0; ++row) { + for (short row = 0; row < nr0; ++row) { const float d_all = (float)dh[0]; scales16[0] = a[4]; @@ -4623,7 +4621,7 @@ void kernel_mul_mv_q3_K_f32_impl( scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; - for (int l = 0; l < n; l += 2) { + for (short l = 0; l < 8; l += 2) { const int32_t qs = q[l/2]; s1 += yl[l+0] * (qs & qm[il/2][0]); s2 += yl[l+1] * (qs & qm[il/2][1]); @@ -4638,7 +4636,7 @@ void kernel_mul_mv_q3_K_f32_impl( sumf2[row] += d2 * (scales[2] - 32); s1 = s2 = s3 = s4 = s5 = s6 = 0; - for (int l = 0; l < n; l += 2) { + for (short l = 0; l < 8; l += 2) { const int32_t qs = q[l/2+8]; s1 += yl[l+8] * (qs & qm[il/2][0]); s2 += yl[l+9] * (qs & qm[il/2][1]); @@ -4846,15 +4844,14 @@ void kernel_mul_mv_q5_K_f32_impl( const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int tid = tiisg/4; - const int ix = tiisg%4; - const int iq = tid/4; - const int ir = tid%4; - const int n = 8; + const short tid = tiisg/4; + const short ix = tiisg%4; + const short iq = tid/4; + const short ir = tid%4; - const int l0 = n*ir; - const int q_offset = 32*iq + l0; - const int y_offset = 64*iq + l0; + const short l0 = 8*ir; + const short q_offset = 32*iq + l0; + const short y_offset = 64*iq + l0; const uint8_t hm1 = 1u << (2*iq); const uint8_t hm2 = hm1 << 1; @@ -4874,14 +4871,14 @@ void kernel_mul_mv_q5_K_f32_impl( device const float * y2 = y1 + 128; float4 sumy = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < 8; ++l) { + for (short l = 0; l < 8; ++l) { yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; } - for (int row = 0; row < nr0; ++row) { + for (short row = 0; row < nr0; ++row) { device const uint8_t * q2 = q1 + 64; sc16[0] = a[0] & kmask1; @@ -4891,7 +4888,7 @@ void kernel_mul_mv_q5_K_f32_impl( float4 acc1 = {0.f}; float4 acc2 = {0.f}; - for (int l = 0; l < n; ++l) { + for (short l = 0; l < 8; ++l) { uint8_t h = qh[l]; acc1[0] += yl[l+0] * (q1[l] & 0x0F); acc1[1] += yl[l+8] * (q1[l] & 0xF0); @@ -5102,8 +5099,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5114,18 +5110,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const uint16_t * q2 = xr->qs + 4 * ib; device const half * dh = &xr->d; - for (int row = 0; row < nr0; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; device const uint8_t * aux8 = (device const uint8_t *)q2; const uint32_t aux32 = q2[2] | (q2[3] << 16); const float d = db * (0.5f + (aux32 >> 28)); float sum = 0; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]); const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } @@ -5210,8 +5205,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5223,8 +5217,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const uint8_t * sc = xr->scales + ib; device const half * dh = &xr->d; - for (int row = 0; row < nr0; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const uint8_t ls1 = sc[0] & 0xf; const uint8_t ls2 = sc[0] >> 4; @@ -5232,17 +5225,17 @@ void kernel_mul_mv_iq2_xs_f32_impl( const float d2 = db * (0.5f + ls2); float sum1 = 0, sum2 = 0; - for (int l = 0; l < 2; ++l) { + for (short l = 0; l < 2; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); const uint8_t signs = ssigns[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } - for (int l = 2; l < 4; ++l) { + for (short l = 2; l < 4; ++l) { const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); const uint8_t signs = ssigns[(q2[l] >> 9)]; - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } @@ -5329,7 +5322,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5341,17 +5334,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; device const half * dh = &xr->d; - for (int row = 0; row < nr0; row++) { + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const uint32_t aux32 = gas[0] | (gas[1] << 16); const float d = db * (0.5f + (aux32 >> 28)); float2 sum = {0}; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]); const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); } @@ -5435,8 +5428,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5450,18 +5442,17 @@ void kernel_mul_mv_iq3_s_f32_impl( device const uint8_t * signs = xr->signs + 4 * ib; device const half * dh = &xr->d; - for (int row = 0; row < nr0; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); float2 sum = {0}; - for (int l = 0; l < 4; ++l) { + for (short l = 0; l < 4; ++l) { const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues; const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues; const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); } @@ -5542,13 +5533,12 @@ void kernel_mul_mv_iq2_s_f32_impl( // threadgroup_barrier(mem_flags::mem_threadgroup); //} - const int ix = tiisg; + const short ix = tiisg; device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5562,19 +5552,18 @@ void kernel_mul_mv_iq2_s_f32_impl( device const uint8_t * signs = qs + QK_K/8; device const half * dh = &xr->d; - for (int row = 0; row < nr0; row++) { - + for (short row = 0; row < nr0; row++) { const float db = dh[0]; const float d1 = db * (0.5f + (sc[0] & 0xf)); const float d2 = db * (0.5f + (sc[0] >> 4)); float2 sum = {0}; - for (int l = 0; l < 2; ++l) { + for (short l = 0; l < 2; ++l) { //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); - for (int j = 0; j < 8; ++j) { + for (short j = 0; j < 8; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); } @@ -5647,14 +5636,13 @@ void kernel_mul_mv_iq1_s_f32_impl( const int nb32 = nb * (QK_K / 32); - const int ix = tiisg; + const short ix = tiisg; device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - float sumy = 0; - for (int i = 0; i < 32; ++i) { + for (short i = 0; i < 32; ++i) { yl[i] = y4[i]; sumy += yl[i]; } @@ -5667,15 +5655,14 @@ void kernel_mul_mv_iq1_s_f32_impl( device const uint16_t * qh = xr->qh + ib; device const half * dh = &xr->d; - for (int row = 0; row < nr0; row++) { - + for (short row = 0; row < nr0; row++) { constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); float sum = 0; - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) @@ -5747,16 +5734,15 @@ void kernel_mul_mv_iq1_m_f32_impl( const int nb32 = nb * (QK_K / 32); - const int ix = tiisg; + const short ix = tiisg; device const float * y4 = y + 32 * ix; iq1m_scale_t scale; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - float4 sumy = {0.f}; - for (int i = 0; i < 8; ++i) { + for (short i = 0; i < 8; ++i) { yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; @@ -5771,7 +5757,7 @@ void kernel_mul_mv_iq1_m_f32_impl( device const uint8_t * qh = xr->qh + 2 * ib; device const uint16_t * sc = (device const uint16_t *)xr->scales; - for (int row = 0; row < nr0; row++) { + for (short row = 0; row < nr0; row++) { scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); @@ -5780,7 +5766,7 @@ void kernel_mul_mv_iq1_m_f32_impl( constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); float2 sum = {0.f}; - for (int j = 0; j < 4; ++j) { + for (short j = 0; j < 4; ++j) { sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) @@ -5852,8 +5838,8 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); - const int ix = tiisg/2; // 0...15 - const int it = tiisg%2; // 0 or 1 + const short ix = tiisg/2; // 0...15 + const short it = tiisg%2; // 0 or 1 shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); @@ -5869,12 +5855,13 @@ void kernel_mul_mv_iq4_nl_f32_impl( float4 qf1, qf2; for (int ib = ix; ib < nb; ib += 16) { - device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - - for (int row = 0; row < nr0 && first_row + row < args.ne01; ++row) { + yl[0] = y4[0]; + yl[1] = y4[4]; + yl[2] = y4[1]; + yl[3] = y4[5]; + for (short row = 0; row < nr0; row++) { device const block_iq4_nl & xb = x[row*nb + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); @@ -5899,7 +5886,6 @@ void kernel_mul_mv_iq4_nl_f32_impl( acc1 += acc2; sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - } yb += 16 * QK4_NL; @@ -5956,10 +5942,10 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); - const int ix = tiisg/16; // 0 or 1 - const int it = tiisg%16; // 0...15 - const int ib = it/2; - const int il = it%2; + const short ix = tiisg/16; // 0 or 1 + const short it = tiisg%16; // 0...15 + const short ib = it/2; + const short il = it%2; shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); @@ -5976,9 +5962,12 @@ void kernel_mul_mv_iq4_xs_f32_impl( for (int ibl = ix; ibl < nb; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; - yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; + yl[0] = y4[0]; + yl[1] = y4[4]; + yl[2] = y4[1]; + yl[3] = y4[5]; - for (int row = 0; row < nr0; ++row) { + for (short row = 0; row < nr0; ++row) { device const block_iq4_xs & xb = x[row*nb + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); @@ -6002,7 +5991,6 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); - } yb += 2 * QK_K;