Load 4 quant blocks into shared memory in one step

@@ -77,9 +77,13 @@ layout (constant_id = 10) const uint WARP = 32;
 
 #include "mul_mmq_shmem_types.glsl"
 
+#ifndef BK_STEP
+#define BK_STEP 4
+#endif
+
 // Shared memory cache
-shared block_a_cache buf_a[BM];
-shared block_b_cache buf_b[BN];
+shared block_a_cache buf_a[BM * BK_STEP / QUANT_BLOCK_FACTOR];
+shared block_b_cache buf_b[BN * BK_STEP / QUANT_BLOCK_FACTOR];
 // Register cache
 block_a_cache cache_a[WMITER * TM];
 block_b_cache cache_b;
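
BK_STEP is the number of BK-wide quant-block slices staged into shared memory per synchronization, and QUANT_BLOCK_FACTOR (defined per quant family in the shmem-types hunk further down) counts how many of those slices a single cache entry already spans, so the shared arrays only grow where the packing does not cover the extra slices. A minimal C sketch of the sizing arithmetic, assuming a hypothetical tile height BM = 64 purely for illustration:

    /* Hypothetical sizes; in the shader these come from the tile configuration. */
    enum { BM = 64, BK_STEP = 4 };
    enum { FACTOR_LEGACY = 1,  /* legacy quants: one cache entry per BK slice */
           FACTOR_K      = 4   /* K-quants: one entry spans four BK slices    */
    };

    /* Legacy quants need 4x the entries to hold four slices per tile row,    */
    _Static_assert(BM * BK_STEP / FACTOR_LEGACY == 4 * BM, "4x the entries");
    /* while a K-quant entry already interleaves four slices in its packing.  */
    _Static_assert(BM * BK_STEP / FACTOR_K == BM, "entry count unchanged");
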
@@ -185,70 +189,64 @@ void main() {
         sums[i] = ACC_TYPE_VEC2(0.0f);
     }
 
-    for (uint block = start_k; block < end_k; block += BK) {
+    for (uint block = start_k; block < end_k; block += BK * BK_STEP) {
         [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) {
             const uint buf_ib = loadc_a + l;
             const uint ib = pos_a_ib + buf_ib * p.stride_a / BK;
             const uint iqs = loadr_a;
 
-            block_a_to_shmem(buf_ib, ib, iqs);
+            [[unroll]] for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
+                block_a_to_shmem(k_step * BM + buf_ib, ib + k_step, iqs);
+            }
         }
         [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) {
+            const uint buf_ib = loadc_b + l;
+
 #ifdef MUL_MAT_ID
-            const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l];
+            const u16vec2 row_idx = row_ids[ic * BN + buf_ib];
             const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
             const uint ib = idx / 8;
             const uint iqs = idx & 0x7;
 #else
-            const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
-            const uint ib_outer = ib / 4;
-            const uint ib_inner = ib % 4;
-
+            const uint ib = pos_b_ib + buf_ib * p.stride_b / BK;
             const uint iqs = loadr_b;
 #endif
 
-            const uint buf_ib = loadc_b + l;
-
-            if (iqs == 0) {
-                buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
-            }
-            const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
-            buf_b[buf_ib].qs[iqs * 4    ] = values.x;
-            buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
-            buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
-            buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
+            [[unroll]] for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
+                block_b_to_shmem(k_step * BN + buf_ib, ib + k_step, iqs);
+            }
         }
 
         barrier();
 
-        pos_a_ib += 1;
-        pos_b_ib += 1;
+        pos_a_ib += BK_STEP;
+        pos_b_ib += BK_STEP;
 
-        // Load from shared into cache
-        [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-            [[unroll]] for (uint cr = 0; cr < TM; cr++) {
-                const uint reg_ib = wsir * TM + cr;
-                const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
-
-                block_a_to_registers(reg_ib, buf_ib);
-            }
-        }
-
-        [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
-            [[unroll]] for (uint cc = 0; cc < TN; cc++) {
-                const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
-                cache_b.ds = buf_b[ib].ds;
-                [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
-                    cache_b.qs[iqs] = buf_b[ib].qs[iqs];
-                }
-
-                [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
-                        const uint cache_a_idx = wsir * TM + cr * 2;
-                        const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr;
-
-                        sums[sums_idx].x += mmq_dot_product(cache_a_idx);
-                        sums[sums_idx].y += mmq_dot_product(cache_a_idx + 1);
+        for (uint k_step = 0; k_step < BK_STEP / QUANT_BLOCK_FACTOR; k_step++) {
+            // Load from shared into cache
+            [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+                [[unroll]] for (uint cr = 0; cr < TM; cr++) {
+                    const uint reg_ib = wsir * TM + cr;
+                    const uint buf_ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr;
+
+                    block_a_to_registers(reg_ib, k_step * BM + buf_ib);
+                }
+            }
+
+            [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) {
+                [[unroll]] for (uint cc = 0; cc < TN; cc++) {
+                    const uint ib = k_step * BN + warp_c * WN + wsic * WSUBN + tiwc * TN + cc;
+                    block_b_to_registers(ib);
+
+                    [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) {
+                        [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) {
+                            const uint cache_a_idx = wsir * TM + cr * 2;
+                            const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr;
+
+                            sums[sums_idx].x += mmq_dot_product(cache_a_idx);
+                            sums[sums_idx].y += mmq_dot_product(cache_a_idx + 1);
+                        }
                     }
                 }
             }
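
This hunk restructures the k-loop: it now advances BK * BK_STEP columns per iteration, stages all slices of both operands behind a single barrier(), then drains them slice by slice from shared memory, so the workgroup pays one synchronization per BK_STEP slices instead of one per slice. A hedged C sketch of that control flow; the stub names are illustrative, not shader symbols, and the trailing barrier before the next refill is assumed from the surrounding shader code:

    #include <stdint.h>

    #define BK 32                /* assumed quant-block width          */
    #define BK_STEP 4
    #define QUANT_BLOCK_FACTOR 1 /* legacy quants; 4 for K-quants      */

    /* Stubs standing in for block_{a,b}_to_shmem and the register-load
     * plus dot-product inner loops. */
    static void load_slice_to_shmem(uint32_t k_step) { (void)k_step; }
    static void compute_slice(uint32_t k_step)       { (void)k_step; }
    static void workgroup_barrier(void)              {}

    static void k_loop(uint32_t start_k, uint32_t end_k) {
        for (uint32_t block = start_k; block < end_k; block += BK * BK_STEP) {
            /* Stage every slice first: one barrier per BK_STEP slices. */
            for (uint32_t s = 0; s < BK_STEP / QUANT_BLOCK_FACTOR; s++) {
                load_slice_to_shmem(s);
            }
            workgroup_barrier();
            /* Drain the staged slices from shared memory. */
            for (uint32_t s = 0; s < BK_STEP / QUANT_BLOCK_FACTOR; s++) {
                compute_slice(s);
            }
            workgroup_barrier(); /* assumed: protects the next refill */
        }
    }
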
@@ -233,71 +233,113 @@ ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, cons
 #ifdef MMQ_SHMEM
 void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
     const uint ib_k = ib / 8;
-    const uint iqs_k = (ib % 8) * 8 + iqs * 4;
+    const uint iqs_k = (ib % 8) * 8 + iqs;
 
     const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
-    const uint qs_shift = ((iqs_k % 32) / 8) * 2;
+    // const uint qs_shift = ((iqs_k % 32) / 8) * 2;
 
     // Repack 4x4 quants into one int
-    const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303;
-    const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
-    const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
-    const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
+    // const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx    ] >> qs_shift) & 0x03030303;
+    // const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303;
+    // const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303;
+    // const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303;
 
-    buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
+    buf_a[buf_ib].qs[iqs] = data_a_packed32[ib_k].qs[qs_idx]; // vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6);
 
     if (iqs == 0) {
-        buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]);
         buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
+        buf_a[buf_ib].scales[0] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16]);
+    }
+    if (iqs == 1) {
+        buf_a[buf_ib].scales[1] = unpack8(data_a_packed32[ib_k].scales[iqs_k / 16 + 1]);
     }
 }
 
 void block_a_to_registers(const uint reg_ib, const uint buf_ib) {
     cache_a[reg_ib].dm = buf_a[buf_ib].dm;
-    cache_a[reg_ib].scales = buf_a[buf_ib].scales;
 
     [[unroll]] for (uint iqs = 0; iqs < 2; iqs++) {
+        cache_a[reg_ib].scales[iqs] = buf_a[buf_ib].scales[iqs];
+    }
+
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
         cache_a[reg_ib].qs[iqs] = buf_a[buf_ib].qs[iqs];
     }
 }
 
 ACC_TYPE mmq_dot_product(const uint ib_a) {
-    int32_t sum_d = 0;
-    int32_t sum_m = 0;
+    float sum_d = 0;
+    float sum_m = 0;
 
-    uint8_t scale = cache_a[ib_a].scales[0];
-    int32_t scale_m = int32_t(scale >> 4);
-    scale_m |= scale_m << 8;
-    scale_m |= scale_m << 16;
-
-    [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) {
-        const uint qs_shift = iqs * 2;
-
-        const int32_t qs_a = int32_t((cache_a[ib_a].qs[0] >> qs_shift) & 0x03030303);
-
-        sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
-        sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
-    }
-
-    scale = cache_a[ib_a].scales[1];
-    scale_m = int32_t(scale >> 4);
-    scale_m |= scale_m << 8;
-    scale_m |= scale_m << 16;
-
-    [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) {
-        const uint qs_shift = (iqs - 4) * 2;
-
-        const int32_t qs_a = int32_t((cache_a[ib_a].qs[1] >> qs_shift) & 0x03030303);
-
-        sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF);
-        sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
-    }
-
-    return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
+    [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) {
+        const uint32_t qs_a_packed = cache_a[ib_a].qs[iqs];
+        [[unroll]] for (uint ib_b = 0; ib_b < 4; ib_b++) {
+            const uint8_t scale = cache_a[ib_a].scales[ib_b / 2][(ib_b % 2) * 2 + (iqs / 4)];
+            const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
+            const int32_t qs_a = int32_t((qs_a_packed >> (ib_b * 2)) & 0x03030303);
+
+            sum_d += cache_b.ds[ib_b].x * float(dotPacked4x8EXT(qs_a, cache_b.qs[ib_b * 8 + iqs]) * (scale & 0xF));
+            sum_m += cache_b.ds[ib_b].x * float(dotPacked4x8EXT(scale_m, cache_b.qs[ib_b * 8 + iqs]));
+        }
+    }
+
+    return ACC_TYPE(cache_a[ib_a].dm.x * sum_d - cache_a[ib_a].dm.y * sum_m);
 }
 #endif // MMQ_SHMEM
 #endif
+
+#ifdef MMQ_SHMEM
+#if defined(DATA_A_QUANT_LEGACY)
+void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_outer = ib / 4;
+    const uint ib_inner = ib % 4;
+
+    if (iqs == 0) {
+        buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
+    }
+
+    const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
+    buf_b[buf_ib].qs[iqs * 4    ] = values.x;
+    buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
+    buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
+    buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
+}
+
+void block_b_to_registers(const uint ib) {
+    cache_b.ds = buf_b[ib].ds;
+    [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
+        cache_b.qs[iqs] = buf_b[ib].qs[iqs];
+    }
+}
+#elif defined(DATA_A_QUANT_K)
+void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
+    const uint ib_outer = ib / 4;
+
+    buf_b[buf_ib].ds[iqs * 2    ] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2    ]);
+    buf_b[buf_ib].ds[iqs * 2 + 1] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[iqs * 2 + 1]);
+
+    [[unroll]] for (uint ib_inner = 0; ib_inner < 4; ib_inner++) {
+        const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
+        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4    ] = values.x;
+        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 1] = values.y;
+        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 2] = values.z;
+        buf_b[buf_ib].qs[ib_inner * 8 + iqs * 4 + 3] = values.w;
+    }
+}
+
+void block_b_to_registers(const uint ib) {
+    [[unroll]] for (uint i = 0; i < 4; i++) {
+        cache_b.ds[i] = buf_b[ib].ds[i];
+    }
+    [[unroll]] for (uint iqs = 0; iqs < 32; iqs++) {
+        cache_b.qs[iqs] = buf_b[ib].qs[iqs];
+    }
+}
+#else
+#error unimplemented
+#endif
+#endif
+
 #if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
 FLOAT_TYPE get_d(uint ib) {
     return FLOAT_TYPE(data_a[ib].d);
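
The rewritten mmq_dot_product exploits the Q2_K bit layout: each byte position of a qs word stacks four 2-bit quants, one per 32-quant sub-block ib_b, so (qs >> (ib_b * 2)) & 0x03030303 peels off one sub-block across all four byte lanes, and multiplying the 4-bit minimum scale by 0x01010101 broadcasts it into every lane for a packed dot product against B. A small host-side C emulation of those two tricks; dot_packed_4x8() stands in for the GLSL dotPacked4x8EXT() builtin (GL_EXT_integer_dot_product), and the constants are made up for the demo:

    #include <stdint.h>
    #include <stdio.h>

    /* Signed 4x8-bit packed dot product, the semantics of dotPacked4x8EXT. */
    static int32_t dot_packed_4x8(uint32_t a, uint32_t b) {
        int32_t sum = 0;
        for (int i = 0; i < 4; i++) {
            sum += (int32_t)(int8_t)(a >> (8 * i)) * (int32_t)(int8_t)(b >> (8 * i));
        }
        return sum;
    }

    int main(void) {
        /* Every byte lane stacks quants 0,1,2,3 for sub-blocks 0..3 (0xE4 = 0b11100100). */
        const uint32_t qs_packed = 0xE4E4E4E4u;
        const uint8_t  scale     = 0x5A; /* low nibble: d-scale, high nibble: m-scale */

        /* Peel off sub-block ib_b = 1: every byte lane becomes 1. */
        const uint32_t qs_a = (qs_packed >> 2) & 0x03030303u;

        /* Broadcast the 4-bit m-scale into all four byte lanes. */
        const int32_t scale_m = (int32_t)(scale >> 4) * 0x01010101;

        const uint32_t qs_b = 0x01010101u; /* four int8 quants of B, all equal to 1 */
        printf("sum_d term: %d\n", dot_packed_4x8(qs_a, qs_b) * (scale & 0xF)); /* 4 * 10 = 40 */
        printf("sum_m term: %d\n", dot_packed_4x8((uint32_t)scale_m, qs_b));   /* 5 * 4  = 20 */
        return 0;
    }
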
@@ -31,17 +31,31 @@ struct block_a_cache {
     FLOAT_TYPE dm;
 };
 #elif defined(DATA_A_Q2_K)
-#define QUANT_R_MMQ 4
+#define QUANT_R_MMQ 1
 struct block_a_cache
 {
-    uint32_t qs[2];
-    u8vec2 scales;
+    uint32_t qs[8];
+    u8vec4 scales[2];
     FLOAT_TYPE_VEC2 dm;
 };
 #endif
 
+#if defined(DATA_A_QUANT_LEGACY)
+#define QUANT_BLOCK_FACTOR 1
+
 struct block_b_cache
 {
     int32_t qs[8];
     FLOAT_TYPE_VEC2 ds;
 };
+#elif defined(DATA_A_QUANT_K)
+#define QUANT_BLOCK_FACTOR 4
+
+struct block_b_cache
+{
+    int32_t qs[32];
+    FLOAT_TYPE_VEC2 ds[4];
+};
+#else
+#error unimplemented
+#endif
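
With QUANT_BLOCK_FACTOR = 4, one K-quant block_b_cache entry carries four Q8_1 sub-blocks, so it is exactly four legacy entries wide. A hedged C mirror of the two layouts, approximating FLOAT_TYPE_VEC2 as two floats, just to make the sizes concrete:

    #include <stdint.h>

    typedef struct { float d, s; } f32vec2; /* stand-in for FLOAT_TYPE_VEC2 */

    typedef struct {        /* DATA_A_QUANT_LEGACY: one 32-quant Q8_1 block */
        int32_t qs[8];      /* 32 int8 quants, packed four per int          */
        f32vec2 ds;         /* per-block scale d and sum term s             */
    } block_b_cache_legacy;

    typedef struct {        /* DATA_A_QUANT_K: four Q8_1 blocks at once     */
        int32_t qs[32];     /* 128 int8 quants                              */
        f32vec2 ds[4];      /* one (d, s) pair per 32-quant sub-block       */
    } block_b_cache_k;

    _Static_assert(sizeof(block_b_cache_k) == 4 * sizeof(block_b_cache_legacy),
                   "one K-quant entry spans four legacy entries");
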
Author: 0cc4m