mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	vulkan: scale caching for k quants + misc fixes (#11081)
* q6_k scale caching * 16 bit unpack * q4_k test (slow) * revert it * q3_k * q2_k * little stuff * try precalculating products of a and q2_k scales * Revert "try precalculating products of a and q2_k scales" This reverts commit 65110b81f23f66331a50c6e889a7c1ab9470a86b. * unpack should be u16, add vim swap to gitignore (about time) * better q4_k scales * q5_k * better q6_k with separate paths for all threads and partial threads in use, plus some more optimizations * q2_k better dequant * q3_k optimizations * q3_k use hmask simd from cpu avx version * make the caches happy * q3_k separate out calculation * q2_k separate out * little stuff * use calc_superblock everywhere * q2_k optimize scale calculation * more barriers
This commit is contained in:
		
							
								
								
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @@ -18,6 +18,7 @@ | |||||||
| *.metallib | *.metallib | ||||||
| *.o | *.o | ||||||
| *.so | *.so | ||||||
|  | *.swp | ||||||
| *.tmp | *.tmp | ||||||
|  |  | ||||||
| # IDE / OS | # IDE / OS | ||||||
|   | |||||||
| @@ -5,6 +5,80 @@ | |||||||
|  |  | ||||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||||
|  |  | ||||||
|  | shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16]; | ||||||
|  | shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16]; | ||||||
|  |  | ||||||
|  | FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; | ||||||
|  |  | ||||||
|  | void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { | ||||||
|  |     const uint y_idx = i * QUANT_K + y_offset; | ||||||
|  |  | ||||||
|  |     [[unroll]] for (uint n = 0; n < num_rows; ++n) { | ||||||
|  |         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; | ||||||
|  |  | ||||||
|  |         barrier(); | ||||||
|  |         if (!all_threads) { // when we don't have enough blocks to use all threads | ||||||
|  |             if (i < num_blocks_per_row) { | ||||||
|  |                 const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); | ||||||
|  |                 sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); | ||||||
|  |                 sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); | ||||||
|  |             } | ||||||
|  |             barrier(); | ||||||
|  |  | ||||||
|  |             if (i >= num_blocks_per_row) | ||||||
|  |                 continue; | ||||||
|  |         } else { | ||||||
|  |             const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); | ||||||
|  |             sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF); | ||||||
|  |             sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); | ||||||
|  |             barrier(); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); | ||||||
|  |         const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); | ||||||
|  |         const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); | ||||||
|  |         const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); | ||||||
|  |         const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); | ||||||
|  |  | ||||||
|  |         vec2 d = vec2(data_a[ib0 + i].d); | ||||||
|  |         const FLOAT_TYPE dall = FLOAT_TYPE(d.x); | ||||||
|  |         const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); | ||||||
|  |  | ||||||
|  |         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { | ||||||
|  |             vec2 b0 =   vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  0]); | ||||||
|  |             vec2 b16 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  8]); | ||||||
|  |             vec2 b32 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); | ||||||
|  |             vec2 b48 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); | ||||||
|  |             vec2 b64 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); | ||||||
|  |             vec2 b80 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); | ||||||
|  |             vec2 b96 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); | ||||||
|  |             vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); | ||||||
|  |  | ||||||
|  |             FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); | ||||||
|  |             FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); | ||||||
|  |             [[unroll]] for (int l = 0; l < 2; ++l) { | ||||||
|  |                 sum1 = fma(FLOAT_TYPE(b0[l]),   sccache1[ix][    8*v_im] * qs_u32_0[l  ], | ||||||
|  |                        fma(FLOAT_TYPE(b16[l]),  sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2], | ||||||
|  |                        fma(FLOAT_TYPE(b32[l]),  sccache1[ix][2 + 8*v_im] * qs_u32_2[l  ], | ||||||
|  |                        fma(FLOAT_TYPE(b48[l]),  sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2], | ||||||
|  |                        fma(FLOAT_TYPE(b64[l]),  sccache1[ix][4 + 8*v_im] * qs_u32_4[l  ], | ||||||
|  |                        fma(FLOAT_TYPE(b80[l]),  sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2], | ||||||
|  |                        fma(FLOAT_TYPE(b96[l]),  sccache1[ix][6 + 8*v_im] * qs_u32_6[l  ], | ||||||
|  |                        fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); | ||||||
|  |                 sum2 = fma(FLOAT_TYPE(b0[l]),   sccache2[ix][    8*v_im], | ||||||
|  |                        fma(FLOAT_TYPE(b16[l]),  sccache2[ix][1 + 8*v_im], | ||||||
|  |                        fma(FLOAT_TYPE(b32[l]),  sccache2[ix][2 + 8*v_im], | ||||||
|  |                        fma(FLOAT_TYPE(b48[l]),  sccache2[ix][3 + 8*v_im], | ||||||
|  |                        fma(FLOAT_TYPE(b64[l]),  sccache2[ix][4 + 8*v_im], | ||||||
|  |                        fma(FLOAT_TYPE(b80[l]),  sccache2[ix][5 + 8*v_im], | ||||||
|  |                        fma(FLOAT_TYPE(b96[l]),  sccache2[ix][6 + 8*v_im], | ||||||
|  |                        fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2)))))))); | ||||||
|  |             } | ||||||
|  |             temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | ||||||
|     uint a_offset, b_offset, d_offset; |     uint a_offset, b_offset, d_offset; | ||||||
|     get_offsets(a_offset, b_offset, d_offset); |     get_offsets(a_offset, b_offset, d_offset); | ||||||
| @@ -14,88 +88,28 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | |||||||
|     // 16 threads are used to process each block |     // 16 threads are used to process each block | ||||||
|     const uint it_size = gl_WorkGroupSize.x/16; |     const uint it_size = gl_WorkGroupSize.x/16; | ||||||
|     const uint tid = gl_LocalInvocationID.x; |     const uint tid = gl_LocalInvocationID.x; | ||||||
|     const uint itid = tid%16;  // 0...16 |     const uint itid = tid%16;  // 0...15 | ||||||
|     const uint ix  = tid/16; |     const uint ix = tid/16; | ||||||
|  |  | ||||||
|     const uint step = 8; |     const uint v_im = itid/8;                                // 0 or 1. 0 computes 0..., 1 computes 128... | ||||||
|  |     const uint v_in = itid - 8*v_im;                         // 0...7 | ||||||
|     const uint v_im = itid/step;                             // 0 or 1. 0 computes 0..., 1 computes 128... |  | ||||||
|     const uint v_in = itid - step*v_im;                      // 0...15 or 0...7 |  | ||||||
|  |  | ||||||
|     const uint l0 = 2*v_in;                                  // 0...15 |     const uint l0 = 2*v_in;                                  // 0...15 | ||||||
|     const uint q_offset = 32*v_im + l0; |     const uint q_offset = 32*v_im + l0; | ||||||
|     const uint s_offset = 8*v_im; |  | ||||||
|     const uint y_offset = 128*v_im + l0; |     const uint y_offset = 128*v_im + l0; | ||||||
|  |  | ||||||
|     FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; |  | ||||||
|  |  | ||||||
|     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { |     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { | ||||||
|         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { |         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { | ||||||
|             temp[j][i] = FLOAT_TYPE(0); |             temp[j][i] = FLOAT_TYPE(0); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { |     const uint nbr_par_th = num_blocks_per_row%it_size; | ||||||
|         const uint y_idx = i * QUANT_K + y_offset; |     const uint nbr_all_th = num_blocks_per_row - nbr_par_th; | ||||||
|  |     uint i0 = 0; | ||||||
|         [[unroll]] for (uint n = 0; n < num_rows; ++n) { |     [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) | ||||||
|             const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; |         calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); | ||||||
|             vec2 d = vec2(data_a[ib0 + i].d); |     calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); | ||||||
|             const FLOAT_TYPE dall = FLOAT_TYPE(d.x); |  | ||||||
|             const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); |  | ||||||
|  |  | ||||||
|             uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0]; |  | ||||||
|             uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1]; |  | ||||||
|  |  | ||||||
|             uint32_t s0_lo4_u32 = s0_u32 & 0x0F0F0F0F; |  | ||||||
|             uint32_t s0_hi4_u32 = (s0_u32 >> 4) & 0x0F0F0F0F; |  | ||||||
|             uint32_t s4_lo4_u32 = s4_u32 & 0x0F0F0F0F; |  | ||||||
|             uint32_t s4_hi4_u32 = (s4_u32 >> 4) & 0x0F0F0F0F; |  | ||||||
|  |  | ||||||
|             uvec4 s0_lo4 = uvec4(unpack8(s0_lo4_u32)); |  | ||||||
|             uvec4 s4_lo4 = uvec4(unpack8(s4_lo4_u32)); |  | ||||||
|             uvec4 s0_hi4 = uvec4(unpack8(s0_hi4_u32)); |  | ||||||
|             uvec4 s4_hi4 = uvec4(unpack8(s4_hi4_u32)); |  | ||||||
|  |  | ||||||
|             uint16_t qs0_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 0]; |  | ||||||
|             uint16_t qs16_u16 = data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]; |  | ||||||
|             uvec2 qs0 =  uvec2(unpack8(qs0_u16)); |  | ||||||
|             uvec2 qs16 = uvec2(unpack8(qs16_u16)); |  | ||||||
|  |  | ||||||
|             [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { |  | ||||||
|                 vec2 b0 =   vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  0]); |  | ||||||
|                 vec2 b16 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  8]); |  | ||||||
|                 vec2 b32 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); |  | ||||||
|                 vec2 b48 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); |  | ||||||
|                 vec2 b64 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); |  | ||||||
|                 vec2 b80 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); |  | ||||||
|                 vec2 b96 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); |  | ||||||
|                 vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); |  | ||||||
|  |  | ||||||
|                 FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); |  | ||||||
|                 FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); |  | ||||||
|                 [[unroll]] for (int l = 0; l < 2; ++l) { |  | ||||||
|                     sum1 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 0) & 3), |  | ||||||
|                            fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3), |  | ||||||
|                            fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 2) & 3), |  | ||||||
|                            fma(FLOAT_TYPE(b48[l]),  FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3), |  | ||||||
|                            fma(FLOAT_TYPE(b64[l]),  FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l]  >> 4) & 3), |  | ||||||
|                            fma(FLOAT_TYPE(b80[l]),  FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3), |  | ||||||
|                            fma(FLOAT_TYPE(b96[l]),  FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l]  >> 6) & 3), |  | ||||||
|                            fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1)))))))); |  | ||||||
|                     sum2 = fma(FLOAT_TYPE(b0[l]),   FLOAT_TYPE(s0_hi4[0]), |  | ||||||
|                            fma(FLOAT_TYPE(b16[l]),  FLOAT_TYPE(s0_hi4[1]), |  | ||||||
|                            fma(FLOAT_TYPE(b32[l]),  FLOAT_TYPE(s0_hi4[2]), |  | ||||||
|                            fma(FLOAT_TYPE(b48[l]),  FLOAT_TYPE(s0_hi4[3]), |  | ||||||
|                            fma(FLOAT_TYPE(b64[l]),  FLOAT_TYPE(s4_hi4[0]), |  | ||||||
|                            fma(FLOAT_TYPE(b80[l]),  FLOAT_TYPE(s4_hi4[1]), |  | ||||||
|                            fma(FLOAT_TYPE(b96[l]),  FLOAT_TYPE(s4_hi4[2]), |  | ||||||
|                            fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2)))))))); |  | ||||||
|                 } |  | ||||||
|                 temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     reduce_result(temp, d_offset, first_row, num_rows, tid); |     reduce_result(temp, d_offset, first_row, num_rows, tid); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -5,6 +5,74 @@ | |||||||
|  |  | ||||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||||
|  |  | ||||||
|  | shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8]; | ||||||
|  |  | ||||||
|  | FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; | ||||||
|  |  | ||||||
|  | void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { | ||||||
|  |     const uint y_idx = i * QUANT_K + y_offset; | ||||||
|  |  | ||||||
|  |     [[unroll]] for (uint n = 0; n < num_rows; ++n) { | ||||||
|  |         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; | ||||||
|  |  | ||||||
|  |         if (!all_threads) { // when we don't have enough blocks to use all threads | ||||||
|  |             barrier(); | ||||||
|  |             if (i < num_blocks_per_row) | ||||||
|  |                 sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); | ||||||
|  |             barrier(); | ||||||
|  |  | ||||||
|  |             if (i >= num_blocks_per_row) | ||||||
|  |                 continue; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16)); | ||||||
|  |         const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> (    v_im4)) << 2)); | ||||||
|  |         const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2)); | ||||||
|  |         const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2)); | ||||||
|  |         const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2)); | ||||||
|  |  | ||||||
|  |         // 0, 1, 16, 17 | ||||||
|  |         uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8); | ||||||
|  |         qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16; | ||||||
|  |         const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); | ||||||
|  |         const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); | ||||||
|  |         const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); | ||||||
|  |         const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); | ||||||
|  |  | ||||||
|  |         if (all_threads) { | ||||||
|  |             barrier(); | ||||||
|  |             sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); | ||||||
|  |             barrier(); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); | ||||||
|  |  | ||||||
|  |         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { | ||||||
|  |             vec2 b0 =   vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  0]); | ||||||
|  |             vec2 b16 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  8]); | ||||||
|  |             vec2 b32 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); | ||||||
|  |             vec2 b48 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); | ||||||
|  |             vec2 b64 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); | ||||||
|  |             vec2 b80 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); | ||||||
|  |             vec2 b96 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); | ||||||
|  |             vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); | ||||||
|  |  | ||||||
|  |             FLOAT_TYPE sum = FLOAT_TYPE(0.0); | ||||||
|  |             [[unroll]] for (int l = 0; l < 2; ++l) { | ||||||
|  |                 sum = fma(FLOAT_TYPE(  b0[l]) * sccache[ix][v_im][0], qs_u32_0[l  ] - hmk_0[l  ], | ||||||
|  |                       fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], | ||||||
|  |                       fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l  ] - hmk_1[l  ], | ||||||
|  |                       fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], | ||||||
|  |                       fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l  ] - hmk_2[l  ], | ||||||
|  |                       fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], | ||||||
|  |                       fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l  ] - hmk_3[l  ], | ||||||
|  |                       fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); | ||||||
|  |             } | ||||||
|  |             temp[j][n] = fma(d, sum, temp[j][n]); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | ||||||
|     uint a_offset, b_offset, d_offset; |     uint a_offset, b_offset, d_offset; | ||||||
|     get_offsets(a_offset, b_offset, d_offset); |     get_offsets(a_offset, b_offset, d_offset); | ||||||
| @@ -14,76 +82,37 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | |||||||
|     // 16 threads are used to process each block |     // 16 threads are used to process each block | ||||||
|     const uint it_size = gl_WorkGroupSize.x/16; |     const uint it_size = gl_WorkGroupSize.x/16; | ||||||
|     const uint tid = gl_LocalInvocationID.x; |     const uint tid = gl_LocalInvocationID.x; | ||||||
|     const uint itid = tid%16;  // 0...16 |     const uint itid = tid%16;  // 0...15 | ||||||
|     const uint ix  = tid/16; |     const uint ix = tid/16; | ||||||
|  |     const uint itid8 = itid%8; | ||||||
|  |  | ||||||
|     const uint step = 8; |     const uint v_im = itid/8;                               // 0 or 1. 0 computes 0..., 1 computes 128... | ||||||
|  |     const uint v_im4 = v_im*4; | ||||||
|  |     const uint v_in = itid - 8*v_im;                        // 0...7 | ||||||
|  |  | ||||||
|     const uint v_im = itid/step;                            // 0 or 1. 0 computes 0..., 1 computes 128... |     const uint32_t m = 0x01010101 << (4 * v_im); | ||||||
|     const uint v_in = itid - step*v_im;                     // 0...15 or 0...7 |     uint32_t hm_m[4]; | ||||||
|  |     [[unroll]] for (uint j = 0; j < 4; ++j) | ||||||
|     const uint8_t m = uint8_t(1 << (4 * v_im)); |         hm_m[j] = m << j; | ||||||
|  |  | ||||||
|     const uint l0 = 2*v_in;                                 // 0...15 |     const uint l0 = 2*v_in;                                 // 0...15 | ||||||
|     const uint q_offset = 32*v_im + l0; |     const uint q_offset = 32*v_im + l0; | ||||||
|     const uint y_offset = 128*v_im + l0; |     const uint y_offset = 128*v_im + l0; | ||||||
|  |  | ||||||
|     FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; |  | ||||||
|  |  | ||||||
|     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { |     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { | ||||||
|         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { |         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { | ||||||
|             temp[j][i] = FLOAT_TYPE(0); |             temp[j][i] = FLOAT_TYPE(0); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     const uint s_shift = 4 * v_im; |     const uint s_shift = v_im4 + 2*(itid8/4); | ||||||
|  |  | ||||||
|     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { |     const uint nbr_par_th = num_blocks_per_row%it_size; | ||||||
|         const uint y_idx = i * QUANT_K + y_offset; |     const uint nbr_all_th = num_blocks_per_row - nbr_par_th; | ||||||
|  |     uint i0 = 0; | ||||||
|         [[unroll]] for (uint n = 0; n < num_rows; ++n) { |     [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) | ||||||
|             const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; |         calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true); | ||||||
|             const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); |     calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false); | ||||||
|  |  | ||||||
|             uint16_t s0_16 = data_a_packed16[ib0 + i].scales[0]; |  | ||||||
|             uint16_t s2_16 = data_a_packed16[ib0 + i].scales[1]; |  | ||||||
|             uint16_t s4_16 = data_a_packed16[ib0 + i].scales[2]; |  | ||||||
|             uint16_t s6_16 = data_a_packed16[ib0 + i].scales[3]; |  | ||||||
|             uint16_t s8_16 = data_a_packed16[ib0 + i].scales[4]; |  | ||||||
|             uint16_t s10_16 = data_a_packed16[ib0 + i].scales[5]; |  | ||||||
|             u8vec2 s0 = unpack8(s0_16); |  | ||||||
|             u8vec2 s2 = unpack8(s2_16); |  | ||||||
|             u8vec2 s4 = unpack8(s4_16); |  | ||||||
|             u8vec2 s6 = unpack8(s6_16); |  | ||||||
|             u8vec2 s8 = unpack8(s8_16); |  | ||||||
|             u8vec2 s10 = unpack8(s10_16); |  | ||||||
|  |  | ||||||
|             [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { |  | ||||||
|  |  | ||||||
|                 vec2 b0 =   vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  0]); |  | ||||||
|                 vec2 b16 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 +  8]); |  | ||||||
|                 vec2 b32 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); |  | ||||||
|                 vec2 b48 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); |  | ||||||
|                 vec2 b64 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); |  | ||||||
|                 vec2 b80 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); |  | ||||||
|                 vec2 b96 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); |  | ||||||
|                 vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); |  | ||||||
|  |  | ||||||
|                 FLOAT_TYPE sum = FLOAT_TYPE(0.0); |  | ||||||
|                 [[unroll]] for (int l = 0; l < 2; ++l) { |  | ||||||
|                     sum = fma(FLOAT_TYPE(b0[l])   * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0]  >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 0)) != 0) ? 0 : 4)), |  | ||||||
|                           fma(FLOAT_TYPE(b32[l])  * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 1)) != 0) ? 0 : 4)), |  | ||||||
|                           fma(FLOAT_TYPE(b64[l])  * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0]  >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 2)) != 0) ? 0 : 4)), |  | ||||||
|                           fma(FLOAT_TYPE(b96[l])  * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l   ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l   ] & (m << 3)) != 0) ? 0 : 4)), |  | ||||||
|                           fma(FLOAT_TYPE(b16[l])  * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1]  >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16]     ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)), |  | ||||||
|                           fma(FLOAT_TYPE(b48[l])  * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)), |  | ||||||
|                           fma(FLOAT_TYPE(b80[l])  * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1]  >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)), |  | ||||||
|                           fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum)))))))); |  | ||||||
|                 } |  | ||||||
|                 temp[j][n] = fma(d, sum, temp[j][n]); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     reduce_result(temp, d_offset, first_row, num_rows, tid); |     reduce_result(temp, d_offset, first_row, num_rows, tid); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -6,6 +6,86 @@ | |||||||
|  |  | ||||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||||
|  |  | ||||||
|  | FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; | ||||||
|  |  | ||||||
|  | void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { | ||||||
|  |     const uint y1_idx = i * QUANT_K + y_offset; | ||||||
|  |     const uint y2_idx = y1_idx + 128; | ||||||
|  |  | ||||||
|  |     [[unroll]] for (uint n = 0; n < num_rows; ++n) { | ||||||
|  |         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; | ||||||
|  |         vec2 d = vec2(data_a[ib0 + i].d); | ||||||
|  |         const FLOAT_TYPE dall = FLOAT_TYPE(d.x); | ||||||
|  |         const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); | ||||||
|  |  | ||||||
|  |         const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ]; | ||||||
|  |         const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; | ||||||
|  |         const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; | ||||||
|  |  | ||||||
|  |         const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; | ||||||
|  |         const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; | ||||||
|  |         const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); | ||||||
|  |         const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); | ||||||
|  |  | ||||||
|  |         const FLOAT_TYPE sc0 = scale_0_4_l_f.x; | ||||||
|  |         const FLOAT_TYPE sc1 = scale_0_4_l_f.y; | ||||||
|  |         const FLOAT_TYPE sc2 = scale_0_4_l_f.z; | ||||||
|  |         const FLOAT_TYPE sc3 = scale_0_4_l_f.w; | ||||||
|  |         const FLOAT_TYPE sc4 = scale8_f.x; | ||||||
|  |         const FLOAT_TYPE sc5 = scale8_f.y; | ||||||
|  |         const FLOAT_TYPE sc6 = scale8_f.z; | ||||||
|  |         const FLOAT_TYPE sc7 = scale8_f.w; | ||||||
|  |  | ||||||
|  |         const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; | ||||||
|  |         const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; | ||||||
|  |  | ||||||
|  |         const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; | ||||||
|  |         const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; | ||||||
|  |         const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; | ||||||
|  |         const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; | ||||||
|  |  | ||||||
|  |         const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4)); | ||||||
|  |         const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4)); | ||||||
|  |         const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4)); | ||||||
|  |         const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4)); | ||||||
|  |  | ||||||
|  |         const FLOAT_TYPE q4_0  = qs0_lo4.x; | ||||||
|  |         const FLOAT_TYPE q4_1  = qs0_lo4.y; | ||||||
|  |         const FLOAT_TYPE q4_2  = qs0_lo4.z; | ||||||
|  |         const FLOAT_TYPE q4_3  = qs0_lo4.w; | ||||||
|  |         const FLOAT_TYPE q4_4  = qs0_hi4.x; | ||||||
|  |         const FLOAT_TYPE q4_5  = qs0_hi4.y; | ||||||
|  |         const FLOAT_TYPE q4_6  = qs0_hi4.z; | ||||||
|  |         const FLOAT_TYPE q4_7  = qs0_hi4.w; | ||||||
|  |         const FLOAT_TYPE q4_8  = qs64_lo4.x; | ||||||
|  |         const FLOAT_TYPE q4_9  = qs64_lo4.y; | ||||||
|  |         const FLOAT_TYPE q4_10 = qs64_lo4.z; | ||||||
|  |         const FLOAT_TYPE q4_11 = qs64_lo4.w; | ||||||
|  |         const FLOAT_TYPE q4_12 = qs64_hi4.x; | ||||||
|  |         const FLOAT_TYPE q4_13 = qs64_hi4.y; | ||||||
|  |         const FLOAT_TYPE q4_14 = qs64_hi4.z; | ||||||
|  |         const FLOAT_TYPE q4_15 = qs64_hi4.w; | ||||||
|  |  | ||||||
|  |         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { | ||||||
|  |             vec4 by10 =  vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4    ]); | ||||||
|  |             vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]); | ||||||
|  |             vec4 by20 =  vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4    ]); | ||||||
|  |             vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]); | ||||||
|  |  | ||||||
|  |             const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x),      q4_0,  fma(FLOAT_TYPE(by10.y),  q4_1,  fma(FLOAT_TYPE(by10.z),  q4_2,  FLOAT_TYPE(by10.w) *  q4_3))); | ||||||
|  |             const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x),     q4_4,  fma(FLOAT_TYPE(by132.y), q4_5,  fma(FLOAT_TYPE(by132.z), q4_6,  FLOAT_TYPE(by132.w) * q4_7))); | ||||||
|  |             const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x),      q4_8,  fma(FLOAT_TYPE(by20.y),  q4_9,  fma(FLOAT_TYPE(by20.z),  q4_10, FLOAT_TYPE(by20.w) *  q4_11))); | ||||||
|  |             const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x),     q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); | ||||||
|  |             const FLOAT_TYPE smin = | ||||||
|  |                 fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, | ||||||
|  |                 fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, | ||||||
|  |                 fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, | ||||||
|  |                 fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6,     FLOAT_TYPE(by232.w) * sc7))))))))))))))); | ||||||
|  |             temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | ||||||
|     uint a_offset, b_offset, d_offset; |     uint a_offset, b_offset, d_offset; | ||||||
|     get_offsets(a_offset, b_offset, d_offset); |     get_offsets(a_offset, b_offset, d_offset); | ||||||
| @@ -15,13 +95,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | |||||||
|     // 16 threads are used to process each block |     // 16 threads are used to process each block | ||||||
|     const uint it_size = gl_WorkGroupSize.x/16; |     const uint it_size = gl_WorkGroupSize.x/16; | ||||||
|     const uint tid = gl_LocalInvocationID.x; |     const uint tid = gl_LocalInvocationID.x; | ||||||
|     const uint itid = tid%16;  // 0...16 |     const uint itid = tid%16;  // 0...15 | ||||||
|     const uint ix  = tid/16; |     const uint ix = tid/16; | ||||||
|  |  | ||||||
|     const uint step = 4; |     const uint il = itid/4;                         // 0...3 | ||||||
|  |     const uint ir = itid - 4*il;                    // 0...3 | ||||||
|     const uint il = itid/step;                      // 0...3 |  | ||||||
|     const uint ir = itid - step*il;                 // 0...7 or 0...3 |  | ||||||
|     const uint n =  4; |     const uint n =  4; | ||||||
|  |  | ||||||
|     const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 |     const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 | ||||||
| @@ -31,89 +109,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | |||||||
|     const uint q_offset = 32*v_im + l0; |     const uint q_offset = 32*v_im + l0; | ||||||
|     const uint y_offset = 64*v_im + l0; |     const uint y_offset = 64*v_im + l0; | ||||||
|  |  | ||||||
|     FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; |  | ||||||
|  |  | ||||||
|     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { |     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { | ||||||
|         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { |         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { | ||||||
|             temp[j][i] = FLOAT_TYPE(0); |             temp[j][i] = FLOAT_TYPE(0); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { |     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) | ||||||
|         const uint y1_idx = i * QUANT_K + y_offset; |         calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); | ||||||
|         const uint y2_idx = y1_idx + 128; |  | ||||||
|  |  | ||||||
|         [[unroll]] for (uint n = 0; n < num_rows; ++n) { |  | ||||||
|             const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; |  | ||||||
|             vec2 d = vec2(data_a[ib0 + i].d); |  | ||||||
|             const FLOAT_TYPE dall = FLOAT_TYPE(d.x); |  | ||||||
|             const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); |  | ||||||
|  |  | ||||||
|             uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ]; |  | ||||||
|             uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; |  | ||||||
|             uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; |  | ||||||
|             uvec4 scale0 = uvec4(unpack8(scale0_u32)); |  | ||||||
|             uvec4 scale4 = uvec4(unpack8(scale4_u32)); |  | ||||||
|             uvec4 scale8 = uvec4(unpack8(scale8_u32)); |  | ||||||
|  |  | ||||||
|             const uint32_t sc0 = (  scale0.x       & 0x3f); |  | ||||||
|             const uint32_t sc1 = (  scale0.y       & 0x3f); |  | ||||||
|             const uint32_t sc2 = (  scale4.x       & 0x3f); |  | ||||||
|             const uint32_t sc3 = (  scale4.y       & 0x3f); |  | ||||||
|             const uint32_t sc4 = (( scale8.x       & 0x0f) | ((scale0.x & 0xc0) >> 2)); |  | ||||||
|             const uint32_t sc5 = (( scale8.y       & 0x0f) | ((scale0.y & 0xc0) >> 2)); |  | ||||||
|             const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); |  | ||||||
|             const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); |  | ||||||
|  |  | ||||||
|             uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; |  | ||||||
|             uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; |  | ||||||
|  |  | ||||||
|             uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; |  | ||||||
|             uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; |  | ||||||
|             uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; |  | ||||||
|             uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; |  | ||||||
|  |  | ||||||
|             uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4)); |  | ||||||
|             uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4)); |  | ||||||
|             uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4)); |  | ||||||
|             uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4)); |  | ||||||
|  |  | ||||||
|             const uint32_t q4_0  = qs0_lo4.x; |  | ||||||
|             const uint32_t q4_1  = qs0_lo4.y; |  | ||||||
|             const uint32_t q4_2  = qs0_lo4.z; |  | ||||||
|             const uint32_t q4_3  = qs0_lo4.w; |  | ||||||
|             const uint32_t q4_4  = qs0_hi4.x; |  | ||||||
|             const uint32_t q4_5  = qs0_hi4.y; |  | ||||||
|             const uint32_t q4_6  = qs0_hi4.z; |  | ||||||
|             const uint32_t q4_7  = qs0_hi4.w; |  | ||||||
|             const uint32_t q4_8  = qs64_lo4.x; |  | ||||||
|             const uint32_t q4_9  = qs64_lo4.y; |  | ||||||
|             const uint32_t q4_10 = qs64_lo4.z; |  | ||||||
|             const uint32_t q4_11 = qs64_lo4.w; |  | ||||||
|             const uint32_t q4_12 = qs64_hi4.x; |  | ||||||
|             const uint32_t q4_13 = qs64_hi4.y; |  | ||||||
|             const uint32_t q4_14 = qs64_hi4.z; |  | ||||||
|             const uint32_t q4_15 = qs64_hi4.w; |  | ||||||
|  |  | ||||||
|             [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { |  | ||||||
|                 vec4 by10 =  vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4    ]); |  | ||||||
|                 vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]); |  | ||||||
|                 vec4 by20 =  vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4    ]); |  | ||||||
|                 vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]); |  | ||||||
|  |  | ||||||
|                 const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x),      q4_0,  fma(FLOAT_TYPE(by10.y),  q4_1,  fma(FLOAT_TYPE(by10.z),  q4_2,  FLOAT_TYPE(by10.w) *  q4_3))); |  | ||||||
|                 const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x),     q4_4,  fma(FLOAT_TYPE(by132.y), q4_5,  fma(FLOAT_TYPE(by132.z), q4_6,  FLOAT_TYPE(by132.w) * q4_7))); |  | ||||||
|                 const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x),      q4_8,  fma(FLOAT_TYPE(by20.y),  q4_9,  fma(FLOAT_TYPE(by20.z),  q4_10, FLOAT_TYPE(by20.w) *  q4_11))); |  | ||||||
|                 const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x),     q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); |  | ||||||
|                 const FLOAT_TYPE smin = |  | ||||||
|                     fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, |  | ||||||
|                     fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, |  | ||||||
|                     fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, |  | ||||||
|                     fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6,     FLOAT_TYPE(by232.w) * sc7))))))))))))))); |  | ||||||
|                 temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     reduce_result(temp, d_offset, first_row, num_rows, tid); |     reduce_result(temp, d_offset, first_row, num_rows, tid); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -6,6 +6,118 @@ | |||||||
|  |  | ||||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||||
|  |  | ||||||
|  | FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; | ||||||
|  |  | ||||||
|  | void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { | ||||||
|  |     const uint y1_idx = i * QUANT_K + y_offset; | ||||||
|  |     const uint y2_idx = y1_idx + 128; | ||||||
|  |  | ||||||
|  |     [[unroll]] for (uint n = 0; n < num_rows; ++n) { | ||||||
|  |         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; | ||||||
|  |         vec2 d = vec2(data_a[ib0 + i].d); | ||||||
|  |         const FLOAT_TYPE dall = FLOAT_TYPE(d.x); | ||||||
|  |         const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); | ||||||
|  |  | ||||||
|  |         const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ]; | ||||||
|  |         const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; | ||||||
|  |         const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; | ||||||
|  |  | ||||||
|  |         const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; | ||||||
|  |         const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; | ||||||
|  |         const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); | ||||||
|  |         const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); | ||||||
|  |  | ||||||
|  |         const FLOAT_TYPE sc0 = scale_0_4_l_f.x; | ||||||
|  |         const FLOAT_TYPE sc1 = scale_0_4_l_f.y; | ||||||
|  |         const FLOAT_TYPE sc2 = scale_0_4_l_f.z; | ||||||
|  |         const FLOAT_TYPE sc3 = scale_0_4_l_f.w; | ||||||
|  |         const FLOAT_TYPE sc4 = scale8_f.x; | ||||||
|  |         const FLOAT_TYPE sc5 = scale8_f.y; | ||||||
|  |         const FLOAT_TYPE sc6 = scale8_f.z; | ||||||
|  |         const FLOAT_TYPE sc7 = scale8_f.w; | ||||||
|  |  | ||||||
|  |         const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); | ||||||
|  |         const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); | ||||||
|  |  | ||||||
|  |         uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; | ||||||
|  |         uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; | ||||||
|  |         uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; | ||||||
|  |         uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; | ||||||
|  |  | ||||||
|  |         const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); | ||||||
|  |  | ||||||
|  |         const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; | ||||||
|  |         const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; | ||||||
|  |         const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); | ||||||
|  |         const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; | ||||||
|  |  | ||||||
|  |         qs0_16_u32_lo4 += qs0_16_lo4_offset16; | ||||||
|  |         qs0_16_u32_hi4 += qs0_16_hi4_offset16; | ||||||
|  |         qs64_80_u32_lo4 += qs64_80_lo4_offset16; | ||||||
|  |         qs64_80_u32_hi4 += qs64_80_hi4_offset16; | ||||||
|  |  | ||||||
|  |         const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4)); | ||||||
|  |         const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4)); | ||||||
|  |         const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4)); | ||||||
|  |         const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4)); | ||||||
|  |  | ||||||
|  |         const FLOAT_TYPE q4_0  = qs0_16_lo4.x; | ||||||
|  |         const FLOAT_TYPE q4_1  = qs0_16_lo4.y; | ||||||
|  |         const FLOAT_TYPE q4_2  = qs0_16_lo4.z; | ||||||
|  |         const FLOAT_TYPE q4_3  = qs0_16_lo4.w; | ||||||
|  |         const FLOAT_TYPE q4_4  = qs0_16_hi4.x; | ||||||
|  |         const FLOAT_TYPE q4_5  = qs0_16_hi4.y; | ||||||
|  |         const FLOAT_TYPE q4_6  = qs0_16_hi4.z; | ||||||
|  |         const FLOAT_TYPE q4_7  = qs0_16_hi4.w; | ||||||
|  |         const FLOAT_TYPE q4_8  = qs64_80_lo4.x; | ||||||
|  |         const FLOAT_TYPE q4_9  = qs64_80_lo4.y; | ||||||
|  |         const FLOAT_TYPE q4_10 = qs64_80_lo4.z; | ||||||
|  |         const FLOAT_TYPE q4_11 = qs64_80_lo4.w; | ||||||
|  |         const FLOAT_TYPE q4_12 = qs64_80_hi4.x; | ||||||
|  |         const FLOAT_TYPE q4_13 = qs64_80_hi4.y; | ||||||
|  |         const FLOAT_TYPE q4_14 = qs64_80_hi4.z; | ||||||
|  |         const FLOAT_TYPE q4_15 = qs64_80_hi4.w; | ||||||
|  |  | ||||||
|  |         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { | ||||||
|  |             vec2 by10 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2     ]); | ||||||
|  |             vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 +  8]); | ||||||
|  |             vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]); | ||||||
|  |             vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]); | ||||||
|  |             vec2 by20 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2     ]); | ||||||
|  |             vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 +  8]); | ||||||
|  |             vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]); | ||||||
|  |             vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]); | ||||||
|  |  | ||||||
|  |             const FLOAT_TYPE sx = | ||||||
|  |               fma(FLOAT_TYPE(by10.x), q4_0, | ||||||
|  |               fma(FLOAT_TYPE(by10.y), q4_1, | ||||||
|  |               fma(FLOAT_TYPE(by116.x), q4_2, | ||||||
|  |                  FLOAT_TYPE(by116.y) * q4_3))); | ||||||
|  |             const FLOAT_TYPE sy = | ||||||
|  |               fma(FLOAT_TYPE(by132.x), q4_4, | ||||||
|  |               fma(FLOAT_TYPE(by132.y), q4_5, | ||||||
|  |               fma(FLOAT_TYPE(by148.x), q4_6, | ||||||
|  |                  FLOAT_TYPE(by148.y) * q4_7))); | ||||||
|  |             const FLOAT_TYPE sz = | ||||||
|  |               fma(FLOAT_TYPE(by20.x), q4_8, | ||||||
|  |               fma(FLOAT_TYPE(by20.y), q4_9, | ||||||
|  |               fma(FLOAT_TYPE(by216.x), q4_10, | ||||||
|  |                  FLOAT_TYPE(by216.y) * q4_11))); | ||||||
|  |             const FLOAT_TYPE sw = | ||||||
|  |               fma(FLOAT_TYPE(by232.x), q4_12, | ||||||
|  |               fma(FLOAT_TYPE(by232.y), q4_13, | ||||||
|  |               fma(FLOAT_TYPE(by248.x), q4_14, | ||||||
|  |                  FLOAT_TYPE(by248.y) * q4_15))); | ||||||
|  |             const FLOAT_TYPE smin = | ||||||
|  |               fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, | ||||||
|  |               fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, | ||||||
|  |               fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, | ||||||
|  |                   (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); | ||||||
|  |             temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | ||||||
|     uint a_offset, b_offset, d_offset; |     uint a_offset, b_offset, d_offset; | ||||||
|     get_offsets(a_offset, b_offset, d_offset); |     get_offsets(a_offset, b_offset, d_offset); | ||||||
| @@ -15,11 +127,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | |||||||
|     // 16 threads are used to process each block |     // 16 threads are used to process each block | ||||||
|     const uint it_size = gl_WorkGroupSize.x/16; |     const uint it_size = gl_WorkGroupSize.x/16; | ||||||
|     const uint tid = gl_LocalInvocationID.x; |     const uint tid = gl_LocalInvocationID.x; | ||||||
|     const uint itid = tid%16;  // 0...16 |     const uint itid = tid%16;  // 0...15 | ||||||
|     const uint ix  = tid/16; |     const uint ix = tid/16; | ||||||
|  |  | ||||||
|     const uint il = itid/4;                          // 0...3 |     const uint il = itid/4;                          // 0...3 | ||||||
|     const uint ir = itid - 4*il;                     // 0...7 or 0...3 |     const uint ir = itid - 4*il;                     // 0...3 | ||||||
|  |  | ||||||
|     const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 |     const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 | ||||||
|     const uint v_in = il % 2; |     const uint v_in = il % 2; | ||||||
| @@ -28,121 +140,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | |||||||
|     const uint q_offset = 32*v_im + l0; |     const uint q_offset = 32*v_im + l0; | ||||||
|     const uint y_offset = 64*v_im + l0; |     const uint y_offset = 64*v_im + l0; | ||||||
|  |  | ||||||
|     FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; |  | ||||||
|  |  | ||||||
|     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { |     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { | ||||||
|         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { |         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { | ||||||
|             temp[j][i] = FLOAT_TYPE(0); |             temp[j][i] = FLOAT_TYPE(0); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { |     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) | ||||||
|         const uint y1_idx = i * QUANT_K + y_offset; |         calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); | ||||||
|         const uint y2_idx = y1_idx + 128; |  | ||||||
|  |  | ||||||
|         [[unroll]] for (uint n = 0; n < num_rows; ++n) { |  | ||||||
|             const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; |  | ||||||
|             vec2 d = vec2(data_a[ib0 + i].d); |  | ||||||
|             const FLOAT_TYPE dall = FLOAT_TYPE(d.x); |  | ||||||
|             const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); |  | ||||||
|  |  | ||||||
|             uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ]; |  | ||||||
|             uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; |  | ||||||
|             uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; |  | ||||||
|             uvec4 scale0 = uvec4(unpack8(scale0_u32)); |  | ||||||
|             uvec4 scale4 = uvec4(unpack8(scale4_u32)); |  | ||||||
|             uvec4 scale8 = uvec4(unpack8(scale8_u32)); |  | ||||||
|  |  | ||||||
|             const uint32_t sc0 = (  scale0.x       & 0x3f); |  | ||||||
|             const uint32_t sc1 = (  scale0.y       & 0x3f); |  | ||||||
|             const uint32_t sc2 = (  scale4.x       & 0x3f); |  | ||||||
|             const uint32_t sc3 = (  scale4.y       & 0x3f); |  | ||||||
|             const uint32_t sc4 = (( scale8.x       & 0x0f) | ((scale0.x & 0xc0) >> 2)); |  | ||||||
|             const uint32_t sc5 = (( scale8.y       & 0x0f) | ((scale0.y & 0xc0) >> 2)); |  | ||||||
|             const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2)); |  | ||||||
|             const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2)); |  | ||||||
|  |  | ||||||
|             uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); |  | ||||||
|             uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); |  | ||||||
|  |  | ||||||
|             uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; |  | ||||||
|             uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; |  | ||||||
|             uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; |  | ||||||
|             uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; |  | ||||||
|  |  | ||||||
|             uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); |  | ||||||
|  |  | ||||||
|             uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; |  | ||||||
|             uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; |  | ||||||
|             uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010) << 0; |  | ||||||
|             uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; |  | ||||||
|  |  | ||||||
|             qs0_16_u32_lo4 += qs0_16_lo4_offset16; |  | ||||||
|             qs0_16_u32_hi4 += qs0_16_hi4_offset16; |  | ||||||
|             qs64_80_u32_lo4 += qs64_80_lo4_offset16; |  | ||||||
|             qs64_80_u32_hi4 += qs64_80_hi4_offset16; |  | ||||||
|  |  | ||||||
|             uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4)); |  | ||||||
|             uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4)); |  | ||||||
|             uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4)); |  | ||||||
|             uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4)); |  | ||||||
|  |  | ||||||
|             const uint32_t q4_0  = qs0_16_lo4.x; |  | ||||||
|             const uint32_t q4_1  = qs0_16_lo4.y; |  | ||||||
|             const uint32_t q4_2  = qs0_16_lo4.z; |  | ||||||
|             const uint32_t q4_3  = qs0_16_lo4.w; |  | ||||||
|             const uint32_t q4_4  = qs0_16_hi4.x; |  | ||||||
|             const uint32_t q4_5  = qs0_16_hi4.y; |  | ||||||
|             const uint32_t q4_6  = qs0_16_hi4.z; |  | ||||||
|             const uint32_t q4_7  = qs0_16_hi4.w; |  | ||||||
|             const uint32_t q4_8  = qs64_80_lo4.x; |  | ||||||
|             const uint32_t q4_9  = qs64_80_lo4.y; |  | ||||||
|             const uint32_t q4_10 = qs64_80_lo4.z; |  | ||||||
|             const uint32_t q4_11 = qs64_80_lo4.w; |  | ||||||
|             const uint32_t q4_12 = qs64_80_hi4.x; |  | ||||||
|             const uint32_t q4_13 = qs64_80_hi4.y; |  | ||||||
|             const uint32_t q4_14 = qs64_80_hi4.z; |  | ||||||
|             const uint32_t q4_15 = qs64_80_hi4.w; |  | ||||||
|  |  | ||||||
|             [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { |  | ||||||
|                 vec2 by10 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2     ]); |  | ||||||
|                 vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 +  8]); |  | ||||||
|                 vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]); |  | ||||||
|                 vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]); |  | ||||||
|                 vec2 by20 =  vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2     ]); |  | ||||||
|                 vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 +  8]); |  | ||||||
|                 vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]); |  | ||||||
|                 vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]); |  | ||||||
|  |  | ||||||
|                 const FLOAT_TYPE sx = |  | ||||||
|                   fma(FLOAT_TYPE(by10.x), q4_0, |  | ||||||
|                   fma(FLOAT_TYPE(by10.y), q4_1, |  | ||||||
|                   fma(FLOAT_TYPE(by116.x), q4_2, |  | ||||||
|                      FLOAT_TYPE(by116.y) * q4_3))); |  | ||||||
|                 const FLOAT_TYPE sy = |  | ||||||
|                   fma(FLOAT_TYPE(by132.x), q4_4, |  | ||||||
|                   fma(FLOAT_TYPE(by132.y), q4_5, |  | ||||||
|                   fma(FLOAT_TYPE(by148.x), q4_6, |  | ||||||
|                      FLOAT_TYPE(by148.y) * q4_7))); |  | ||||||
|                 const FLOAT_TYPE sz = |  | ||||||
|                   fma(FLOAT_TYPE(by20.x), q4_8, |  | ||||||
|                   fma(FLOAT_TYPE(by20.y), q4_9, |  | ||||||
|                   fma(FLOAT_TYPE(by216.x), q4_10, |  | ||||||
|                      FLOAT_TYPE(by216.y) * q4_11))); |  | ||||||
|                 const FLOAT_TYPE sw = |  | ||||||
|                   fma(FLOAT_TYPE(by232.x), q4_12, |  | ||||||
|                   fma(FLOAT_TYPE(by232.y), q4_13, |  | ||||||
|                   fma(FLOAT_TYPE(by248.x), q4_14, |  | ||||||
|                      FLOAT_TYPE(by248.y) * q4_15))); |  | ||||||
|                 const FLOAT_TYPE smin = |  | ||||||
|                   fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, |  | ||||||
|                   fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, |  | ||||||
|                   fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, |  | ||||||
|                       (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); |  | ||||||
|                 temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     reduce_result(temp, d_offset, first_row, num_rows, tid); |     reduce_result(temp, d_offset, first_row, num_rows, tid); | ||||||
| } | } | ||||||
|   | |||||||
| @@ -6,7 +6,77 @@ | |||||||
|  |  | ||||||
| layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; | ||||||
|  |  | ||||||
| void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16]; | ||||||
|  |  | ||||||
|  | FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; | ||||||
|  |  | ||||||
|  | void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { | ||||||
|  |     const uint y_idx = i * QUANT_K + y_offset; | ||||||
|  |  | ||||||
|  |     [[unroll]] for (uint n = 0; n < num_rows; ++n) { | ||||||
|  |         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; | ||||||
|  |  | ||||||
|  |         if (!all_threads) { // when we don't have enough blocks to use all threads | ||||||
|  |             barrier(); | ||||||
|  |             if (i < num_blocks_per_row) | ||||||
|  |                 sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); | ||||||
|  |             barrier(); | ||||||
|  |  | ||||||
|  |             if (i >= num_blocks_per_row) | ||||||
|  |                 continue; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         const uint32_t ql0_u32 =  uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); | ||||||
|  |         const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); | ||||||
|  |  | ||||||
|  |         const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; | ||||||
|  |         const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; | ||||||
|  |         const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; | ||||||
|  |         const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; | ||||||
|  |  | ||||||
|  |         const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); | ||||||
|  |         const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; | ||||||
|  |         const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; | ||||||
|  |         const uint32_t qh4_u32 = (qh_u32 & 0x30303030); | ||||||
|  |         const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; | ||||||
|  |  | ||||||
|  |         const uint32_t q0_u32 = ql0_u32_lo4  | qh0_u32; | ||||||
|  |         const uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; | ||||||
|  |         const uint32_t q2_u32 = ql0_u32_hi4  | qh4_u32; | ||||||
|  |         const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; | ||||||
|  |  | ||||||
|  |         const vec4 q0 = vec4(unpack8(q0_u32)) - 32; | ||||||
|  |         const vec4 q1 = vec4(unpack8(q1_u32)) - 32; | ||||||
|  |         const vec4 q2 = vec4(unpack8(q2_u32)) - 32; | ||||||
|  |         const vec4 q3 = vec4(unpack8(q3_u32)) - 32; | ||||||
|  |  | ||||||
|  |         if (all_threads) { | ||||||
|  |             barrier(); | ||||||
|  |             sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); | ||||||
|  |             barrier(); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); | ||||||
|  |  | ||||||
|  |         [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { | ||||||
|  |             vec4 by0  = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4     ]); | ||||||
|  |             vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 +  8]); | ||||||
|  |             vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]); | ||||||
|  |             vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]); | ||||||
|  |  | ||||||
|  |             FLOAT_TYPE sum[4] = {0, 0, 0, 0}; | ||||||
|  |             [[unroll]] for (uint l = 0; l < 4; ++l) { | ||||||
|  |                 sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]); | ||||||
|  |                 sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]); | ||||||
|  |                 sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]); | ||||||
|  |                 sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]); | ||||||
|  |             } | ||||||
|  |             temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void compute_outputs(const uint first_row, const uint num_rows) { | ||||||
|     uint a_offset, b_offset, d_offset; |     uint a_offset, b_offset, d_offset; | ||||||
|     get_offsets(a_offset, b_offset, d_offset); |     get_offsets(a_offset, b_offset, d_offset); | ||||||
|  |  | ||||||
| @@ -15,13 +85,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | |||||||
|     // 16 threads are used to process each block |     // 16 threads are used to process each block | ||||||
|     const uint it_size = gl_WorkGroupSize.x/16; |     const uint it_size = gl_WorkGroupSize.x/16; | ||||||
|     const uint tid = gl_LocalInvocationID.x; |     const uint tid = gl_LocalInvocationID.x; | ||||||
|     const uint itid = tid%16;  // 0...16 |     const uint itid = tid%16;  // 0...15 | ||||||
|     const uint ix  = tid/16; |     const uint ix = tid/16; | ||||||
|  |  | ||||||
|     const uint step = 8; |     const uint v_im = itid/8;                               // 0 or 1. 0 computes 0..., 1 computes 128... | ||||||
|  |     const uint v_in = itid - 8*v_im;                        // 0...7 | ||||||
|     const uint v_im = itid/step;                            // 0 or 1. 0 computes 0..., 1 computes 128... |  | ||||||
|     const uint v_in = itid - step*v_im;                     // 0...15 or 0...7 |  | ||||||
|  |  | ||||||
|     const uint l0 = 4 * v_in;                               // 0, 4, 8, ..., 28 |     const uint l0 = 4 * v_in;                               // 0, 4, 8, ..., 28 | ||||||
|     const uint is = v_in / 4; |     const uint is = v_in / 4; | ||||||
| @@ -31,68 +99,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { | |||||||
|     const uint s_offset  =  8*v_im + is; |     const uint s_offset  =  8*v_im + is; | ||||||
|     const uint y_offset = 128*v_im + l0; |     const uint y_offset = 128*v_im + l0; | ||||||
|  |  | ||||||
|     FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; |  | ||||||
|  |  | ||||||
|     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { |     [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { | ||||||
|         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { |         [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { | ||||||
|             temp[j][i] = FLOAT_TYPE(0); |             temp[j][i] = FLOAT_TYPE(0); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) { |     const uint nbr_par_th = num_blocks_per_row%it_size; | ||||||
|         const uint y_idx = i * QUANT_K + y_offset; |     const uint nbr_all_th = num_blocks_per_row - nbr_par_th; | ||||||
|  |     uint i0 = 0; | ||||||
|         [[unroll]] for (uint n = 0; n < num_rows; ++n) { |     [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) | ||||||
|             const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; |         calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); | ||||||
|             const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); |     calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); | ||||||
|  |  | ||||||
|             FLOAT_TYPE scales[4]; |  | ||||||
|             scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]); |  | ||||||
|             scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]); |  | ||||||
|             scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]); |  | ||||||
|             scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]); |  | ||||||
|  |  | ||||||
|             uint32_t ql0_u32 =  uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); |  | ||||||
|             uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); |  | ||||||
|  |  | ||||||
|             uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; |  | ||||||
|             uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; |  | ||||||
|             uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; |  | ||||||
|             uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; |  | ||||||
|  |  | ||||||
|             uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); |  | ||||||
|             uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; |  | ||||||
|             uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; |  | ||||||
|             uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0; |  | ||||||
|             uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; |  | ||||||
|  |  | ||||||
|             uint32_t q0_u32 = ql0_u32_lo4  | qh0_u32; |  | ||||||
|             uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; |  | ||||||
|             uint32_t q2_u32 = ql0_u32_hi4  | qh4_u32; |  | ||||||
|             uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; |  | ||||||
|  |  | ||||||
|             uvec4 q0 = uvec4(unpack8(q0_u32)); |  | ||||||
|             uvec4 q1 = uvec4(unpack8(q1_u32)); |  | ||||||
|             uvec4 q2 = uvec4(unpack8(q2_u32)); |  | ||||||
|             uvec4 q3 = uvec4(unpack8(q3_u32)); |  | ||||||
|  |  | ||||||
|             [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { |  | ||||||
|                 vec4 by0  = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4     ]); |  | ||||||
|                 vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 +  8]); |  | ||||||
|                 vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]); |  | ||||||
|                 vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]); |  | ||||||
|  |  | ||||||
|                 FLOAT_TYPE sum = FLOAT_TYPE(0.0); |  | ||||||
|                 [[unroll]] for (int l = 0; l < 4; ++l) { |  | ||||||
|                     sum = fma(FLOAT_TYPE(by0[l])  * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32), |  | ||||||
|                           fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32), |  | ||||||
|                           fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32), |  | ||||||
|                           fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum)))); |  | ||||||
|                 } |  | ||||||
|                 temp[j][n] += sum * d; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     reduce_result(temp, d_offset, first_row, num_rows, tid); |     reduce_result(temp, d_offset, first_row, num_rows, tid); | ||||||
| } | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Eve
					Eve