mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	mat vec double buffer (#12188)
This commit is contained in:
		| @@ -5,23 +5,24 @@ | ||||
|  | ||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||
|  | ||||
| shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16]; | ||||
| shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16]; | ||||
| shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16]; | ||||
| shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16]; | ||||
|  | ||||
| FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; | ||||
| uint csel = 0; | ||||
|  | ||||
| void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { | ||||
|     const uint y_idx = i * QUANT_K + y_offset; | ||||
|  | ||||
|     [[unroll]] for (uint n = 0; n < num_rows; ++n) { | ||||
|         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; | ||||
|         csel ^= 1; | ||||
|  | ||||
|         barrier(); | ||||
|         if (!all_threads) { // when we don't have enough blocks to use all threads | ||||
|             if (i < num_blocks_per_row) { | ||||
|                 const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); | ||||
|                 sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); | ||||
|                 sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); | ||||
|                 sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF); | ||||
|                 sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); | ||||
|             } | ||||
|             barrier(); | ||||
|  | ||||
| @@ -29,8 +30,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, | ||||
|                 continue; | ||||
|         } else { | ||||
|             const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); | ||||
|             sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); | ||||
|             sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); | ||||
|             sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF); | ||||
|             sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); | ||||
|             barrier(); | ||||
|         } | ||||
|  | ||||
| @@ -57,22 +58,22 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, | ||||
|             FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); | ||||
|             FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); | ||||
|             [[unroll]] for (int l = 0; l < 2; ++l) { | ||||
|                 sum1 = fma(FLOAT_TYPE(b0[l]),   sccache1[ix][    8*v_im] * qs_u32_0[l  ], | ||||
|                        fma(FLOAT_TYPE(b16[l]),  sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2], | ||||
|                        fma(FLOAT_TYPE(b32[l]),  sccache1[ix][2 + 8*v_im] * qs_u32_2[l  ], | ||||
|                        fma(FLOAT_TYPE(b48[l]),  sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2], | ||||
|                        fma(FLOAT_TYPE(b64[l]),  sccache1[ix][4 + 8*v_im] * qs_u32_4[l  ], | ||||
|                        fma(FLOAT_TYPE(b80[l]),  sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2], | ||||
|                        fma(FLOAT_TYPE(b96[l]),  sccache1[ix][6 + 8*v_im] * qs_u32_6[l  ], | ||||
|                        fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); | ||||
|                 sum2 = fma(FLOAT_TYPE(b0[l]),   sccache2[ix][    8*v_im], | ||||
|                        fma(FLOAT_TYPE(b16[l]),  sccache2[ix][1 + 8*v_im], | ||||
|                        fma(FLOAT_TYPE(b32[l]),  sccache2[ix][2 + 8*v_im], | ||||
|                        fma(FLOAT_TYPE(b48[l]),  sccache2[ix][3 + 8*v_im], | ||||
|                        fma(FLOAT_TYPE(b64[l]),  sccache2[ix][4 + 8*v_im], | ||||
|                        fma(FLOAT_TYPE(b80[l]),  sccache2[ix][5 + 8*v_im], | ||||
|                        fma(FLOAT_TYPE(b96[l]),  sccache2[ix][6 + 8*v_im], | ||||
|                        fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2)))))))); | ||||
|                 sum1 = fma(FLOAT_TYPE(b0[l]),   sccache1[csel][ix][    8*v_im] * qs_u32_0[l  ], | ||||
|                        fma(FLOAT_TYPE(b16[l]),  sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2], | ||||
|                        fma(FLOAT_TYPE(b32[l]),  sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l  ], | ||||
|                        fma(FLOAT_TYPE(b48[l]),  sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2], | ||||
|                        fma(FLOAT_TYPE(b64[l]),  sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l  ], | ||||
|                        fma(FLOAT_TYPE(b80[l]),  sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2], | ||||
|                        fma(FLOAT_TYPE(b96[l]),  sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l  ], | ||||
|                        fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); | ||||
|                 sum2 = fma(FLOAT_TYPE(b0[l]),   sccache2[csel][ix][    8*v_im], | ||||
|                        fma(FLOAT_TYPE(b16[l]),  sccache2[csel][ix][1 + 8*v_im], | ||||
|                        fma(FLOAT_TYPE(b32[l]),  sccache2[csel][ix][2 + 8*v_im], | ||||
|                        fma(FLOAT_TYPE(b48[l]),  sccache2[csel][ix][3 + 8*v_im], | ||||
|                        fma(FLOAT_TYPE(b64[l]),  sccache2[csel][ix][4 + 8*v_im], | ||||
|                        fma(FLOAT_TYPE(b80[l]),  sccache2[csel][ix][5 + 8*v_im], | ||||
|                        fma(FLOAT_TYPE(b96[l]),  sccache2[csel][ix][6 + 8*v_im], | ||||
|                        fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2)))))))); | ||||
|             } | ||||
|             temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); | ||||
|         } | ||||
|   | ||||
| @@ -5,20 +5,21 @@ | ||||
|  | ||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||
|  | ||||
| shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8]; | ||||
| shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8]; | ||||
|  | ||||
| FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; | ||||
| uint csel = 0; | ||||
|  | ||||
| void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { | ||||
|     const uint y_idx = i * QUANT_K + y_offset; | ||||
|  | ||||
|     [[unroll]] for (uint n = 0; n < num_rows; ++n) { | ||||
|         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; | ||||
|         csel ^= 1; | ||||
|  | ||||
|         if (!all_threads) { // when we don't have enough blocks to use all threads | ||||
|             barrier(); | ||||
|             if (i < num_blocks_per_row) | ||||
|                 sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); | ||||
|                 sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); | ||||
|             barrier(); | ||||
|  | ||||
|             if (i >= num_blocks_per_row) | ||||
| @@ -40,8 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co | ||||
|         const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); | ||||
|  | ||||
|         if (all_threads) { | ||||
|             barrier(); | ||||
|             sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); | ||||
|             sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); | ||||
|             barrier(); | ||||
|         } | ||||
|  | ||||
| @@ -59,14 +59,14 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co | ||||
|  | ||||
|             FLOAT_TYPE sum = FLOAT_TYPE(0.0); | ||||
|             [[unroll]] for (int l = 0; l < 2; ++l) { | ||||
|                 sum = fma(FLOAT_TYPE(  b0[l]) * sccache[ix][v_im][0], qs_u32_0[l  ] - hmk_0[l  ], | ||||
|                       fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], | ||||
|                       fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l  ] - hmk_1[l  ], | ||||
|                       fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], | ||||
|                       fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l  ] - hmk_2[l  ], | ||||
|                       fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], | ||||
|                       fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l  ] - hmk_3[l  ], | ||||
|                       fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); | ||||
|                 sum = fma(FLOAT_TYPE(  b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l  ] - hmk_0[l  ], | ||||
|                       fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], | ||||
|                       fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l  ] - hmk_1[l  ], | ||||
|                       fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], | ||||
|                       fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l  ] - hmk_2[l  ], | ||||
|                       fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], | ||||
|                       fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l  ] - hmk_3[l  ], | ||||
|                       fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); | ||||
|             } | ||||
|             temp[j][n] = fma(d, sum, temp[j][n]); | ||||
|         } | ||||
|   | ||||
| @@ -6,20 +6,21 @@ | ||||
|  | ||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||
|  | ||||
| shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16]; | ||||
| shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16]; | ||||
|  | ||||
| FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; | ||||
| uint csel = 0; | ||||
|  | ||||
| void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { | ||||
|     const uint y_idx = i * QUANT_K + y_offset; | ||||
|  | ||||
|     [[unroll]] for (uint n = 0; n < num_rows; ++n) { | ||||
|         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; | ||||
|         csel ^= 1; | ||||
|  | ||||
|         if (!all_threads) { // when we don't have enough blocks to use all threads | ||||
|             barrier(); | ||||
|             if (i < num_blocks_per_row) | ||||
|                 sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); | ||||
|                 sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); | ||||
|             barrier(); | ||||
|  | ||||
|             if (i >= num_blocks_per_row) | ||||
| @@ -51,8 +52,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, | ||||
|         const vec4 q3 = vec4(unpack8(q3_u32)) - 32; | ||||
|  | ||||
|         if (all_threads) { | ||||
|             barrier(); | ||||
|             sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); | ||||
|             sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); | ||||
|             barrier(); | ||||
|         } | ||||
|  | ||||
| @@ -71,7 +71,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, | ||||
|                 sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]); | ||||
|                 sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]); | ||||
|             } | ||||
|             temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]); | ||||
|             temp[j][n] = fma(fma(sum[0], sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]); | ||||
|         } | ||||
|     } | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Eve
					Eve