mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : reduce register pressure
ggml-ci
This commit is contained in:
		| @@ -1901,7 +1901,7 @@ void mul_vec_q_n_f32_impl( | ||||
|         float sumy[2] = { 0.f, 0.f }; | ||||
|  | ||||
| #pragma unroll | ||||
|         for (int i = 0; i < 8; i += 2) { | ||||
|         for (short i = 0; i < 8; i += 2) { | ||||
|             sumy[0]  += yb[i +  0] + yb[i +  1]; | ||||
|             yl[i + 0] = yb[i +  0]; | ||||
|             yl[i + 1] = yb[i +  1]/256.f; | ||||
| @@ -1912,7 +1912,7 @@ void mul_vec_q_n_f32_impl( | ||||
|         } | ||||
|  | ||||
| #pragma unroll | ||||
|         for (int row = 0; row < nr0; row++) { | ||||
|         for (short row = 0; row < nr0; row++) { | ||||
|             sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il); | ||||
|         } | ||||
|  | ||||
| @@ -2025,7 +2025,7 @@ void kernel_mul_mv_q8_0_f32_impl( | ||||
|             yl[i] = yb[i]; | ||||
|         } | ||||
|  | ||||
|         for (int row = 0; row < nr0; row++) { | ||||
|         for (short row = 0; row < nr0; row++) { | ||||
|             device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; | ||||
|             float sumq = 0.f; | ||||
|             for (short iq = 0; iq < NB_Q8_0; ++iq) { | ||||
| @@ -4449,18 +4449,17 @@ void kernel_mul_mv_q2_K_f32_impl( | ||||
|     float yl[32]; | ||||
|     float sumf[nr0]={0.f}; | ||||
|  | ||||
|     const int ix = tiisg/8;  // 0...3 | ||||
|     const int it = tiisg%8;  // 0...7 | ||||
|     const int iq = it/4;     // 0 or 1 | ||||
|     const int ir = it%4;     // 0...3 | ||||
|     const int is = (8*ir)/16;// 0 or 1 | ||||
|     const short ix = tiisg/8;  // 0...3 | ||||
|     const short it = tiisg%8;  // 0...7 | ||||
|     const short iq = it/4;     // 0 or 1 | ||||
|     const short ir = it%4;     // 0...3 | ||||
|     const short is = (8*ir)/16;// 0 or 1 | ||||
|  | ||||
|     device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; | ||||
|  | ||||
|     for (int ib = ix; ib < nb; ib += 4) { | ||||
|  | ||||
|         float4 sumy = {0.f, 0.f, 0.f, 0.f}; | ||||
|         for (int i = 0; i < 8; ++i) { | ||||
|         for (short i = 0; i < 8; ++i) { | ||||
|             yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; | ||||
|             yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; | ||||
|             yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; | ||||
| @@ -4471,7 +4470,7 @@ void kernel_mul_mv_q2_K_f32_impl( | ||||
|         device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; | ||||
|         device const half     * dh = &x[ib].d; | ||||
|  | ||||
|         for (int row = 0; row < nr0; row++) { | ||||
|         for (short row = 0; row < nr0; row++) { | ||||
|             float4 acc1 = {0.f, 0.f, 0.f, 0.f}; | ||||
|             float4 acc2 = {0.f, 0.f, 0.f, 0.f}; | ||||
|             for (int i = 0; i < 8; i += 2) { | ||||
| @@ -4561,8 +4560,7 @@ void kernel_mul_mv_q3_K_f32_impl( | ||||
|     const int ip  = tid/4;          // 0 or 1 | ||||
|     const int il  = 2*((tid%4)/2);  // 0 or 2 | ||||
|     const int ir  = tid%2; | ||||
|     const int n   = 8; | ||||
|     const int l0  = n*ir; | ||||
|     const int l0  = 8*ir; | ||||
|  | ||||
|     // One would think that the Metal compiler would figure out that ip and il can only have | ||||
|     // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it | ||||
| @@ -4600,7 +4598,7 @@ void kernel_mul_mv_q3_K_f32_impl( | ||||
|     float sumf2[nr0] = {0.f}; | ||||
|  | ||||
|     for (int i = ix; i < nb; i += 4) { | ||||
|         for (int l = 0; l < 8; ++l) { | ||||
|         for (short l = 0; l < 8; ++l) { | ||||
|             yl[l+ 0] = y1[l+ 0]; | ||||
|             yl[l+ 8] = y1[l+16]; | ||||
|             yl[l+16] = y1[l+32]; | ||||
| @@ -4612,7 +4610,7 @@ void kernel_mul_mv_q3_K_f32_impl( | ||||
|         device const uint16_t * a = (device const uint16_t *)(x[i].scales); | ||||
|         device const half * dh = &x[i].d; | ||||
|  | ||||
|         for (int row = 0; row < nr0; ++row) { | ||||
|         for (short row = 0; row < nr0; ++row) { | ||||
|             const float d_all = (float)dh[0]; | ||||
|  | ||||
|             scales16[0] = a[4]; | ||||
| @@ -4623,7 +4621,7 @@ void kernel_mul_mv_q3_K_f32_impl( | ||||
|             scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32; | ||||
|  | ||||
|             float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0; | ||||
|             for (int l = 0; l < n; l += 2) { | ||||
|             for (short l = 0; l < 8; l += 2) { | ||||
|                 const int32_t qs = q[l/2]; | ||||
|                 s1 += yl[l+0] * (qs & qm[il/2][0]); | ||||
|                 s2 += yl[l+1] * (qs & qm[il/2][1]); | ||||
| @@ -4638,7 +4636,7 @@ void kernel_mul_mv_q3_K_f32_impl( | ||||
|             sumf2[row] += d2 * (scales[2] - 32); | ||||
|  | ||||
|             s1 = s2 = s3 = s4 = s5 = s6 = 0; | ||||
|             for (int l = 0; l < n; l += 2) { | ||||
|             for (short l = 0; l < 8; l += 2) { | ||||
|                 const int32_t qs = q[l/2+8]; | ||||
|                 s1 += yl[l+8] * (qs & qm[il/2][0]); | ||||
|                 s2 += yl[l+9] * (qs & qm[il/2][1]); | ||||
| @@ -4846,15 +4844,14 @@ void kernel_mul_mv_q5_K_f32_impl( | ||||
|     const uint16_t kmask2 = 0x0f0f; | ||||
|     const uint16_t kmask3 = 0xc0c0; | ||||
|  | ||||
|     const int tid = tiisg/4; | ||||
|     const int ix  = tiisg%4; | ||||
|     const int iq  = tid/4; | ||||
|     const int ir  = tid%4; | ||||
|     const int n   = 8; | ||||
|     const short tid = tiisg/4; | ||||
|     const short ix  = tiisg%4; | ||||
|     const short iq  = tid/4; | ||||
|     const short ir  = tid%4; | ||||
|  | ||||
|     const int l0 = n*ir; | ||||
|     const int q_offset = 32*iq + l0; | ||||
|     const int y_offset = 64*iq + l0; | ||||
|     const short l0 = 8*ir; | ||||
|     const short q_offset = 32*iq + l0; | ||||
|     const short y_offset = 64*iq + l0; | ||||
|  | ||||
|     const uint8_t hm1 = 1u << (2*iq); | ||||
|     const uint8_t hm2 = hm1 << 1; | ||||
| @@ -4874,14 +4871,14 @@ void kernel_mul_mv_q5_K_f32_impl( | ||||
|  | ||||
|         device const float * y2 = y1 + 128; | ||||
|         float4 sumy = {0.f, 0.f, 0.f, 0.f}; | ||||
|         for (int l = 0; l < 8; ++l) { | ||||
|         for (short l = 0; l < 8; ++l) { | ||||
|             yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; | ||||
|             yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; | ||||
|             yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; | ||||
|             yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; | ||||
|         } | ||||
|  | ||||
|         for (int row = 0; row < nr0; ++row) { | ||||
|         for (short row = 0; row < nr0; ++row) { | ||||
|             device const uint8_t * q2 = q1 + 64; | ||||
|  | ||||
|             sc16[0] = a[0] & kmask1; | ||||
| @@ -4891,7 +4888,7 @@ void kernel_mul_mv_q5_K_f32_impl( | ||||
|  | ||||
|             float4 acc1 = {0.f}; | ||||
|             float4 acc2 = {0.f}; | ||||
|             for (int l = 0; l < n; ++l) { | ||||
|             for (short l = 0; l < 8; ++l) { | ||||
|                 uint8_t h = qh[l]; | ||||
|                 acc1[0] += yl[l+0] * (q1[l] & 0x0F); | ||||
|                 acc1[1] += yl[l+8] * (q1[l] & 0xF0); | ||||
| @@ -5102,8 +5099,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( | ||||
|     device const float * y4 = y + 32 * ix; | ||||
|  | ||||
|     for (int ib32 = ix; ib32 < nb32; ib32 += 32) { | ||||
|  | ||||
|         for (int i = 0; i < 32; ++i) { | ||||
|         for (short i = 0; i < 32; ++i) { | ||||
|             yl[i] = y4[i]; | ||||
|         } | ||||
|  | ||||
| @@ -5114,18 +5110,17 @@ void kernel_mul_mv_iq2_xxs_f32_impl( | ||||
|         device const uint16_t * q2 = xr->qs + 4 * ib; | ||||
|         device const half * dh = &xr->d; | ||||
|  | ||||
|         for (int row = 0; row < nr0; row++) { | ||||
|  | ||||
|         for (short row = 0; row < nr0; row++) { | ||||
|             const float db = dh[0]; | ||||
|             device const uint8_t * aux8 = (device const uint8_t *)q2; | ||||
|             const uint32_t aux32 = q2[2] | (q2[3] << 16); | ||||
|             const float d = db * (0.5f + (aux32 >> 28)); | ||||
|  | ||||
|             float sum = 0; | ||||
|             for (int l = 0; l < 4; ++l) { | ||||
|             for (short l = 0; l < 4; ++l) { | ||||
|                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]); | ||||
|                 const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; | ||||
|                 for (int j = 0; j < 8; ++j) { | ||||
|                 for (short j = 0; j < 8; ++j) { | ||||
|                     sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); | ||||
|                 } | ||||
|             } | ||||
| @@ -5210,8 +5205,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( | ||||
|     device const float * y4 = y + 32 * ix; | ||||
|  | ||||
|     for (int ib32 = ix; ib32 < nb32; ib32 += 32) { | ||||
|  | ||||
|         for (int i = 0; i < 32; ++i) { | ||||
|         for (short i = 0; i < 32; ++i) { | ||||
|             yl[i] = y4[i]; | ||||
|         } | ||||
|  | ||||
| @@ -5223,8 +5217,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( | ||||
|         device const uint8_t  * sc = xr->scales + ib; | ||||
|         device const half * dh = &xr->d; | ||||
|  | ||||
|         for (int row = 0; row < nr0; row++) { | ||||
|  | ||||
|         for (short row = 0; row < nr0; row++) { | ||||
|             const float db = dh[0]; | ||||
|             const uint8_t ls1 = sc[0] & 0xf; | ||||
|             const uint8_t ls2 = sc[0] >>  4; | ||||
| @@ -5232,17 +5225,17 @@ void kernel_mul_mv_iq2_xs_f32_impl( | ||||
|             const float d2 = db * (0.5f + ls2); | ||||
|  | ||||
|             float sum1 = 0, sum2 = 0; | ||||
|             for (int l = 0; l < 2; ++l) { | ||||
|             for (short l = 0; l < 2; ++l) { | ||||
|                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); | ||||
|                 const uint8_t signs = ssigns[(q2[l] >> 9)]; | ||||
|                 for (int j = 0; j < 8; ++j) { | ||||
|                 for (short j = 0; j < 8; ++j) { | ||||
|                     sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); | ||||
|                 } | ||||
|             } | ||||
|             for (int l = 2; l < 4; ++l) { | ||||
|             for (short l = 2; l < 4; ++l) { | ||||
|                 const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); | ||||
|                 const uint8_t signs = ssigns[(q2[l] >> 9)]; | ||||
|                 for (int j = 0; j < 8; ++j) { | ||||
|                 for (short j = 0; j < 8; ++j) { | ||||
|                     sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); | ||||
|                 } | ||||
|             } | ||||
| @@ -5329,7 +5322,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( | ||||
|     device const float * y4 = y + 32 * ix; | ||||
|  | ||||
|     for (int ib32 = ix; ib32 < nb32; ib32 += 32) { | ||||
|         for (int i = 0; i < 32; ++i) { | ||||
|         for (short i = 0; i < 32; ++i) { | ||||
|             yl[i] = y4[i]; | ||||
|         } | ||||
|  | ||||
| @@ -5341,17 +5334,17 @@ void kernel_mul_mv_iq3_xxs_f32_impl( | ||||
|         device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib; | ||||
|         device const half * dh = &xr->d; | ||||
|  | ||||
|         for (int row = 0; row < nr0; row++) { | ||||
|         for (short row = 0; row < nr0; row++) { | ||||
|             const float db = dh[0]; | ||||
|             const uint32_t aux32 = gas[0] | (gas[1] << 16); | ||||
|             const float d = db * (0.5f + (aux32 >> 28)); | ||||
|  | ||||
|             float2 sum = {0}; | ||||
|             for (int l = 0; l < 4; ++l) { | ||||
|             for (short l = 0; l < 4; ++l) { | ||||
|                 const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]); | ||||
|                 const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]); | ||||
|                 const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; | ||||
|                 for (int j = 0; j < 4; ++j) { | ||||
|                 for (short j = 0; j < 4; ++j) { | ||||
|                     sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); | ||||
|                     sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); | ||||
|                 } | ||||
| @@ -5435,8 +5428,7 @@ void kernel_mul_mv_iq3_s_f32_impl( | ||||
|     device const float * y4 = y + 32 * ix; | ||||
|  | ||||
|     for (int ib32 = ix; ib32 < nb32; ib32 += 32) { | ||||
|  | ||||
|         for (int i = 0; i < 32; ++i) { | ||||
|         for (short i = 0; i < 32; ++i) { | ||||
|             yl[i] = y4[i]; | ||||
|         } | ||||
|  | ||||
| @@ -5450,18 +5442,17 @@ void kernel_mul_mv_iq3_s_f32_impl( | ||||
|         device const uint8_t * signs = xr->signs + 4 * ib; | ||||
|         device const half * dh = &xr->d; | ||||
|  | ||||
|         for (int row = 0; row < nr0; row++) { | ||||
|  | ||||
|         for (short row = 0; row < nr0; row++) { | ||||
|             const float db = dh[0]; | ||||
|             const float d = db * (1 + 2*((sc[0] >> 4*(ib%2)) & 0xf)); | ||||
|  | ||||
|             float2 sum = {0}; | ||||
|             for (int l = 0; l < 4; ++l) { | ||||
|             for (short l = 0; l < 4; ++l) { | ||||
|                 const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues; | ||||
|                 const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues; | ||||
|                 const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); | ||||
|                 const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); | ||||
|                 for (int j = 0; j < 4; ++j) { | ||||
|                 for (short j = 0; j < 4; ++j) { | ||||
|                     sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]); | ||||
|                     sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]); | ||||
|                 } | ||||
| @@ -5542,13 +5533,12 @@ void kernel_mul_mv_iq2_s_f32_impl( | ||||
|     //    threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|     //} | ||||
|  | ||||
|     const int ix = tiisg; | ||||
|     const short ix = tiisg; | ||||
|  | ||||
|     device const float * y4 = y + 32 * ix; | ||||
|  | ||||
|     for (int ib32 = ix; ib32 < nb32; ib32 += 32) { | ||||
|  | ||||
|         for (int i = 0; i < 32; ++i) { | ||||
|         for (short i = 0; i < 32; ++i) { | ||||
|             yl[i] = y4[i]; | ||||
|         } | ||||
|  | ||||
| @@ -5562,19 +5552,18 @@ void kernel_mul_mv_iq2_s_f32_impl( | ||||
|         device const uint8_t * signs = qs + QK_K/8; | ||||
|         device const half * dh = &xr->d; | ||||
|  | ||||
|         for (int row = 0; row < nr0; row++) { | ||||
|  | ||||
|         for (short row = 0; row < nr0; row++) { | ||||
|             const float db = dh[0]; | ||||
|             const float d1 = db * (0.5f + (sc[0] & 0xf)); | ||||
|             const float d2 = db * (0.5f + (sc[0] >>  4)); | ||||
|  | ||||
|             float2 sum = {0}; | ||||
|             for (int l = 0; l < 2; ++l) { | ||||
|             for (short l = 0; l < 2; ++l) { | ||||
|                 //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); | ||||
|                 //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); | ||||
|                 constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); | ||||
|                 constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); | ||||
|                 for (int j = 0; j < 8; ++j) { | ||||
|                 for (short j = 0; j < 8; ++j) { | ||||
|                     sum[0] += yl[8*l + j +  0] * grid1[j] * select(1, -1, signs[l+0] & kmask_iq2xs[j]); | ||||
|                     sum[1] += yl[8*l + j + 16] * grid2[j] * select(1, -1, signs[l+2] & kmask_iq2xs[j]); | ||||
|                 } | ||||
| @@ -5647,14 +5636,13 @@ void kernel_mul_mv_iq1_s_f32_impl( | ||||
|  | ||||
|     const int nb32 = nb * (QK_K / 32); | ||||
|  | ||||
|     const int ix = tiisg; | ||||
|     const short ix = tiisg; | ||||
|  | ||||
|     device const float * y4 = y + 32 * ix; | ||||
|  | ||||
|     for (int ib32 = ix; ib32 < nb32; ib32 += 32) { | ||||
|  | ||||
|         float sumy = 0; | ||||
|         for (int i = 0; i < 32; ++i) { | ||||
|         for (short i = 0; i < 32; ++i) { | ||||
|             yl[i] = y4[i]; | ||||
|             sumy += yl[i]; | ||||
|         } | ||||
| @@ -5667,15 +5655,14 @@ void kernel_mul_mv_iq1_s_f32_impl( | ||||
|         device const uint16_t * qh = xr->qh + ib; | ||||
|         device const half     * dh = &xr->d; | ||||
|  | ||||
|         for (int row = 0; row < nr0; row++) { | ||||
|  | ||||
|         for (short row = 0; row < nr0; row++) { | ||||
|             constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); | ||||
|             constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 5) & 0x700))); | ||||
|             constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[0] << 2) & 0x700))); | ||||
|             constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[0] >> 1) & 0x700))); | ||||
|  | ||||
|             float sum = 0; | ||||
|             for (int j = 0; j < 4; ++j) { | ||||
|             for (short j = 0; j < 4; ++j) { | ||||
|                 sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) | ||||
|                      + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) | ||||
|                      + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) | ||||
| @@ -5747,16 +5734,15 @@ void kernel_mul_mv_iq1_m_f32_impl( | ||||
|  | ||||
|     const int nb32 = nb * (QK_K / 32); | ||||
|  | ||||
|     const int ix = tiisg; | ||||
|     const short ix = tiisg; | ||||
|  | ||||
|     device const float * y4 = y + 32 * ix; | ||||
|  | ||||
|     iq1m_scale_t scale; | ||||
|  | ||||
|     for (int ib32 = ix; ib32 < nb32; ib32 += 32) { | ||||
|  | ||||
|         float4 sumy = {0.f}; | ||||
|         for (int i = 0; i < 8; ++i) { | ||||
|         for (short i = 0; i < 8; ++i) { | ||||
|             yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; | ||||
|             yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8]; | ||||
|             yl[i+16] = y4[i+16]; sumy[2] += yl[i+16]; | ||||
| @@ -5771,7 +5757,7 @@ void kernel_mul_mv_iq1_m_f32_impl( | ||||
|         device const uint8_t  * qh = xr->qh + 2 * ib; | ||||
|         device const uint16_t * sc = (device const uint16_t *)xr->scales; | ||||
|  | ||||
|         for (int row = 0; row < nr0; row++) { | ||||
|         for (short row = 0; row < nr0; row++) { | ||||
|             scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000); | ||||
|  | ||||
|             constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700))); | ||||
| @@ -5780,7 +5766,7 @@ void kernel_mul_mv_iq1_m_f32_impl( | ||||
|             constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700))); | ||||
|  | ||||
|             float2 sum = {0.f}; | ||||
|             for (int j = 0; j < 4; ++j) { | ||||
|             for (short j = 0; j < 4; ++j) { | ||||
|                 sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) | ||||
|                         + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4); | ||||
|                 sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) | ||||
| @@ -5852,8 +5838,8 @@ void kernel_mul_mv_iq4_nl_f32_impl( | ||||
|     device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); | ||||
|     device const float        * y = (device const float        *) (src1 + offset1); | ||||
|  | ||||
|     const int ix = tiisg/2;  // 0...15 | ||||
|     const int it = tiisg%2;  // 0 or 1 | ||||
|     const short ix = tiisg/2;  // 0...15 | ||||
|     const short it = tiisg%2;  // 0 or 1 | ||||
|  | ||||
|     shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
| @@ -5869,12 +5855,13 @@ void kernel_mul_mv_iq4_nl_f32_impl( | ||||
|     float4 qf1, qf2; | ||||
|  | ||||
|     for (int ib = ix; ib < nb; ib += 16) { | ||||
|  | ||||
|         device const float4 * y4 = (device const float4 *)yb; | ||||
|         yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; | ||||
|  | ||||
|         for (int row = 0; row < nr0 && first_row + row < args.ne01; ++row) { | ||||
|         yl[0] = y4[0]; | ||||
|         yl[1] = y4[4]; | ||||
|         yl[2] = y4[1]; | ||||
|         yl[3] = y4[5]; | ||||
|  | ||||
|         for (short row = 0; row < nr0; row++) { | ||||
|             device const block_iq4_nl & xb = x[row*nb + ib]; | ||||
|             device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); | ||||
|  | ||||
| @@ -5899,7 +5886,6 @@ void kernel_mul_mv_iq4_nl_f32_impl( | ||||
|             acc1 += acc2; | ||||
|  | ||||
|             sumf[row] += (float)xb.d * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); | ||||
|  | ||||
|         } | ||||
|  | ||||
|         yb += 16 * QK4_NL; | ||||
| @@ -5956,10 +5942,10 @@ void kernel_mul_mv_iq4_xs_f32_impl( | ||||
|     device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); | ||||
|     device const float        * y = (device const float        *) (src1 + offset1); | ||||
|  | ||||
|     const int ix = tiisg/16;  // 0 or 1 | ||||
|     const int it = tiisg%16;  // 0...15 | ||||
|     const int ib = it/2; | ||||
|     const int il = it%2; | ||||
|     const short ix = tiisg/16;  // 0 or 1 | ||||
|     const short it = tiisg%16;  // 0...15 | ||||
|     const short ib = it/2; | ||||
|     const short il = it%2; | ||||
|  | ||||
|     shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; | ||||
|     threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
| @@ -5976,9 +5962,12 @@ void kernel_mul_mv_iq4_xs_f32_impl( | ||||
|  | ||||
|     for (int ibl = ix; ibl < nb; ibl += 2) { | ||||
|         device const float4 * y4 = (device const float4 *)yb; | ||||
|         yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; | ||||
|         yl[0] = y4[0]; | ||||
|         yl[1] = y4[4]; | ||||
|         yl[2] = y4[1]; | ||||
|         yl[3] = y4[5]; | ||||
|  | ||||
|         for (int row = 0; row < nr0; ++row) { | ||||
|         for (short row = 0; row < nr0; ++row) { | ||||
|             device const block_iq4_xs & xb = x[row*nb + ibl]; | ||||
|             device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); | ||||
|  | ||||
| @@ -6002,7 +5991,6 @@ void kernel_mul_mv_iq4_xs_f32_impl( | ||||
|  | ||||
|             const int ls = (((xb.scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((xb.scales_h >> 2*ib) & 3) << 4)) - 32; | ||||
|             sumf[row] += (float)xb.d * ls * (acc1[0] + acc1[1] + acc1[2] + acc1[3]); | ||||
|  | ||||
|         } | ||||
|  | ||||
|         yb += 2 * QK_K; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov