mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	Reduce mmq register use
This commit is contained in:
		| @@ -10,12 +10,6 @@ | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require | ||||
| #endif | ||||
|  | ||||
| #ifdef COOPMAT | ||||
| #extension GL_KHR_cooperative_matrix : enable | ||||
| #extension GL_KHR_memory_scope_semantics : enable | ||||
| #extension GL_KHR_shader_subgroup_basic : enable | ||||
| #endif | ||||
|  | ||||
| #ifdef MUL_MAT_ID | ||||
| #extension GL_EXT_shader_explicit_arithmetic_types_int16 : require | ||||
| #endif | ||||
| @@ -79,10 +73,6 @@ layout (constant_id = 10) const uint WARP = 32; | ||||
|  | ||||
| #define BK 32 | ||||
|  | ||||
| #ifdef COOPMAT | ||||
| #define SHMEM_STRIDE (BK / 4 + 4) | ||||
| #endif | ||||
|  | ||||
| #define MMQ_SHMEM | ||||
|  | ||||
| #include "mul_mmq_shmem_types.glsl" | ||||
| @@ -92,7 +82,7 @@ shared block_a_cache buf_a[BM]; | ||||
| shared block_b_cache buf_b[BN]; | ||||
| // Register cache | ||||
| block_a_cache cache_a[WMITER * TM]; | ||||
| block_b_cache cache_b[TN]; | ||||
| block_b_cache cache_b; | ||||
|  | ||||
| #define LOAD_VEC_A (4 * QUANT_R_MMQ) | ||||
| #define LOAD_VEC_B 16 | ||||
| @@ -104,10 +94,6 @@ shared u16vec2 row_ids[4096]; | ||||
|  | ||||
| #define NUM_WARPS (BLOCK_SIZE / WARP) | ||||
|  | ||||
| #ifdef COOPMAT | ||||
| shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; | ||||
| #endif | ||||
|  | ||||
| #include "mul_mmq_funcs.glsl" | ||||
|  | ||||
| void main() { | ||||
| @@ -137,26 +123,12 @@ void main() { | ||||
|     const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); | ||||
|     const uint WSUBM = WM / WMITER; | ||||
|     const uint WSUBN = WN / WNITER; | ||||
|  | ||||
| #ifdef COOPMAT | ||||
|     const uint warp_i = gl_SubgroupID; | ||||
|  | ||||
|     const uint tiw = gl_SubgroupInvocationID; | ||||
|  | ||||
|     const uint cms_per_row = WM / TM; | ||||
|     const uint cms_per_col = WN / TN; | ||||
|  | ||||
|     const uint storestride = WARP / TM; | ||||
|     const uint store_r = tiw % TM; | ||||
|     const uint store_c = tiw / TM; | ||||
| #else | ||||
|     const uint warp_i = gl_LocalInvocationID.x / WARP; | ||||
|  | ||||
|     const uint tiw = gl_LocalInvocationID.x % WARP; | ||||
|  | ||||
|     const uint tiwr = tiw % (WSUBM / TM); | ||||
|     const uint tiwc = tiw / (WSUBM / TM); | ||||
| #endif | ||||
|  | ||||
|     const uint warp_r = warp_i % (BM / WM); | ||||
|     const uint warp_c = warp_i / (BM / WM); | ||||
| @@ -207,26 +179,11 @@ void main() { | ||||
|     uint pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK; | ||||
| #endif | ||||
|  | ||||
| #ifdef COOPMAT | ||||
|     coopmat<int8_t, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a; | ||||
|     coopmat<int8_t, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b; | ||||
|     coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_result; | ||||
|     ACC_TYPE_VEC2 sums[WMITER * TM * WNITER * TN / 2]; | ||||
|  | ||||
|     coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> factors[cms_per_row * cms_per_col]; | ||||
|  | ||||
|     coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col]; | ||||
|  | ||||
|     [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { | ||||
|         sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f); | ||||
|     [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN/2; i++) { | ||||
|         sums[i] = ACC_TYPE_VEC2(0.0f); | ||||
|     } | ||||
| #else | ||||
|  | ||||
|     ACC_TYPE sums[WMITER * TM * WNITER * TN]; | ||||
|  | ||||
|     [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { | ||||
|         sums[i] = ACC_TYPE(0.0f); | ||||
|     } | ||||
| #endif | ||||
|  | ||||
|     for (uint block = start_k; block < end_k; block += BK) { | ||||
|         [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) { | ||||
| @@ -267,38 +224,6 @@ void main() { | ||||
|         pos_a_ib += 1; | ||||
|         pos_b_ib += 1; | ||||
|  | ||||
| #ifdef COOPMAT | ||||
|         [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { | ||||
|             const uint ib_a = warp_r * WM + cm_row * TM; | ||||
|             // Load from shared into cache | ||||
|             coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); | ||||
|  | ||||
|             // TODO: only cache values that are actually needed | ||||
|             [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) { | ||||
|                 cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx]; | ||||
|             } | ||||
|  | ||||
|             [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { | ||||
|                 const uint ib_b = warp_c * WN + cm_col * TN; | ||||
|                 coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); | ||||
|  | ||||
|                 // TODO: only cache values that are actually needed | ||||
|                 [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) { | ||||
|                     cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx]; | ||||
|                 } | ||||
|  | ||||
|                 cm_result = coopmat<int32_t, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0); | ||||
|                 cm_result = coopMatMulAdd(cache_a, cache_b, cm_result); | ||||
|  | ||||
|                 [[unroll]] for (uint col = 0; col < TN; col += storestride) { | ||||
|                     coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col])); | ||||
|                 } | ||||
|  | ||||
|                 coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); | ||||
|                 sums[cm_col * cms_per_row + cm_row] += factors * coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(cm_result); | ||||
|             } | ||||
|         } | ||||
| #else | ||||
|         // Load from shared into cache | ||||
|         [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { | ||||
|             [[unroll]] for (uint cr = 0; cr < TM; cr++) { | ||||
| @@ -312,24 +237,22 @@ void main() { | ||||
|         [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { | ||||
|             [[unroll]] for (uint cc = 0; cc < TN; cc++) { | ||||
|                 const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc; | ||||
|                 cache_b[cc].ds = buf_b[ib].ds; | ||||
|                 cache_b.ds = buf_b[ib].ds; | ||||
|                 [[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) { | ||||
|                     cache_b[cc].qs[iqs] = buf_b[ib].qs[iqs]; | ||||
|                 } | ||||
|                     cache_b.qs[iqs] = buf_b[ib].qs[iqs]; | ||||
|                 } | ||||
|  | ||||
|                 [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { | ||||
|                 [[unroll]] for (uint cc = 0; cc < TN; cc++) { | ||||
|                     [[unroll]] for (uint cr = 0; cr < TM; cr++) { | ||||
|                         const uint cache_a_idx = wsir * TM + cr; | ||||
|                         const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; | ||||
|                     [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { | ||||
|                         const uint cache_a_idx = wsir * TM + cr * 2; | ||||
|                         const uint sums_idx = (wsic * TN + cc) * (WMITER * TM / 2) + wsir * TM / 2 + cr; | ||||
|  | ||||
|                         sums[sums_idx] += mmq_dot_product(cache_a_idx, cc); | ||||
|                         sums[sums_idx].x += mmq_dot_product(cache_a_idx); | ||||
|                         sums[sums_idx].y += mmq_dot_product(cache_a_idx + 1); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| #endif | ||||
|  | ||||
|         barrier(); | ||||
|     } | ||||
| @@ -341,54 +264,6 @@ void main() { | ||||
|     const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; | ||||
| #endif | ||||
|  | ||||
| #ifdef COOPMAT | ||||
| #ifdef MUL_MAT_ID | ||||
|     [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { | ||||
|         [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { | ||||
|             coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); | ||||
|  | ||||
|             [[unroll]] for (uint col = 0; col < BN; col += storestride) { | ||||
|                 const uint row_i = dc + cm_col * TN + col + store_c; | ||||
|                 if (row_i >= _ne1) break; | ||||
|  | ||||
|                 const u16vec2 row_idx = row_ids[row_i]; | ||||
|  | ||||
|                 data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| #else | ||||
|     const bool is_aligned = p.stride_d % 4 == 0;  // Assumption: D_TYPE == float | ||||
|  | ||||
|     [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { | ||||
|         [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { | ||||
|             const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; | ||||
|  | ||||
|             if (is_aligned && is_in_bounds) { | ||||
|                 // Full coopMat is within bounds and stride_d is aligned with 16B | ||||
|                 coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> cm_dtype = coopmat<D_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(sums[cm_col * cms_per_row + cm_row]); | ||||
|                 coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); | ||||
|             } else if (is_in_bounds) { | ||||
|                 // Full coopMat is within bounds, but stride_d is not aligned | ||||
|                 coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); | ||||
|  | ||||
|                 [[unroll]] for (uint col = 0; col < TN; col += storestride) { | ||||
|                     data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); | ||||
|                 } | ||||
|             } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { | ||||
|                 // Partial coopMat is within bounds | ||||
|                 coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); | ||||
|  | ||||
|                 [[unroll]] for (uint col = 0; col < TN; col += storestride) { | ||||
|                     if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { | ||||
|                         data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| #endif // MUL_MAT_ID | ||||
| #else | ||||
|     [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { | ||||
|         [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { | ||||
|  | ||||
| @@ -399,19 +274,27 @@ void main() { | ||||
|                 const uint row_i = dc_warp + cc; | ||||
|                 if (row_i >= _ne1) break; | ||||
|  | ||||
|                 const u16vec2 row_idx = row_ids[row_i]; | ||||
|                 const u16vec2 row_idx = row_ids[row_i - ic * BN]; | ||||
| #endif // MUL_MAT_ID | ||||
|                 [[unroll]] for (uint cr = 0; cr < TM; cr++) { | ||||
|                 [[unroll]] for (uint cr = 0; cr < TM / 2; cr++) { | ||||
|                     const uint sums_idx = (wsic * TN + cc) * WMITER * (TM / 2) + wsir * (TM / 2) + cr; | ||||
| #ifdef MUL_MAT_ID | ||||
|                     data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); | ||||
|                     if (dr_warp + 2 * cr < p.M) { | ||||
|                         data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x); | ||||
|                     } | ||||
|                     if (dr_warp + 2 * cr + 1 < p.M) { | ||||
|                         data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y); | ||||
|                     } | ||||
| #else | ||||
|                     if (dr_warp + cr < p.M && dc_warp + cc < p.N) { | ||||
|                         data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); | ||||
|                     if (dr_warp + 2 * cr < p.M && dc_warp + cc < p.N) { | ||||
|                         data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr] = D_TYPE(sums[sums_idx].x); | ||||
|                     } | ||||
|                     if (dr_warp + 2 * cr + 1 < p.M && dc_warp + cc < p.N) { | ||||
|                         data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + 2 * cr + 1] = D_TYPE(sums[sums_idx].y); | ||||
|                     } | ||||
| #endif // MUL_MAT_ID | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| #endif // COOPMAT | ||||
| } | ||||
|   | ||||
| @@ -62,21 +62,21 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) { | ||||
| void block_a_to_registers(const uint reg_ib, const uint buf_ib, const uint iqs) { | ||||
| } | ||||
|  | ||||
| ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) { | ||||
| ACC_TYPE mmq_dot_product(const uint ib_a) { | ||||
|     int32_t q_sum = 0; | ||||
|     [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { | ||||
|         const uint32_t vui = cache_a[ib_a].qs[iqs]; | ||||
|         const i32vec2 qs_a = i32vec2( vui       & 0x0F0F0F0F, | ||||
|                                      (vui >> 4) & 0x0F0F0F0F); | ||||
|  | ||||
|         const int32_t qs_b0 = cache_b[ib_b].qs[iqs]; | ||||
|         const int32_t qs_b1 = cache_b[ib_b].qs[iqs + 4]; | ||||
|         const int32_t qs_b0 = cache_b.qs[iqs]; | ||||
|         const int32_t qs_b1 = cache_b.qs[iqs + 4]; | ||||
|  | ||||
|         q_sum += dotPacked4x8EXT(qs_a.x, qs_b0); | ||||
|         q_sum += dotPacked4x8EXT(qs_a.y, qs_b1); | ||||
|     } | ||||
|  | ||||
|     return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b[ib_b].ds, 1); | ||||
|     return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); | ||||
| } | ||||
| #endif // MMQ_SHMEM | ||||
|  | ||||
| @@ -140,7 +140,7 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) { | ||||
|     } | ||||
| } | ||||
|  | ||||
| ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) { | ||||
| ACC_TYPE mmq_dot_product(const uint ib_a) { | ||||
|     int32_t q_sum = 0; | ||||
|     [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { | ||||
|         const uint32_t vui = cache_a[ib_a].qs[iqs]; | ||||
| @@ -150,14 +150,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) { | ||||
|         const int32_t qs_a1 = int32_t((vui >> 4) & 0x0F0F0F0F) | ||||
|                          | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) | ||||
|  | ||||
|         const int32_t qs_b0 = cache_b[ib_b].qs[iqs]; | ||||
|         const int32_t qs_b1 = cache_b[ib_b].qs[iqs + 4]; | ||||
|         const int32_t qs_b0 = cache_b.qs[iqs]; | ||||
|         const int32_t qs_b1 = cache_b.qs[iqs + 4]; | ||||
|  | ||||
|         q_sum += dotPacked4x8EXT(qs_a0, qs_b0); | ||||
|         q_sum += dotPacked4x8EXT(qs_a1, qs_b1); | ||||
|     } | ||||
|  | ||||
|     return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b[ib_b].ds, 1); | ||||
|     return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); | ||||
| } | ||||
| #endif // MMQ_SHMEM | ||||
| #endif | ||||
| @@ -191,16 +191,16 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) { | ||||
|     } | ||||
| } | ||||
|  | ||||
| ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) { | ||||
| ACC_TYPE mmq_dot_product(const uint ib_a) { | ||||
|     int32_t q_sum = 0; | ||||
|     [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { | ||||
|         const int32_t qs_a = cache_a[ib_a].qs[iqs]; | ||||
|         const int32_t qs_b = cache_b[ib_b].qs[iqs]; | ||||
|         const int32_t qs_b = cache_b.qs[iqs]; | ||||
|  | ||||
|         q_sum += dotPacked4x8EXT(qs_a, qs_b); | ||||
|     } | ||||
|  | ||||
|     return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b[ib_b].ds, 1); | ||||
|     return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); | ||||
| } | ||||
| #endif // MMQ_SHMEM | ||||
| #endif | ||||
| @@ -247,7 +247,7 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { | ||||
|     buf_a[buf_ib].qs[iqs] = vals0 | (vals1 << 2) | (vals2 << 4) | (vals3 << 6); | ||||
|  | ||||
|     if (iqs == 0) { | ||||
|         buf_a[buf_ib].scales = u8vec2(data_a[ib_k].scales[iqs_k / 4], data_a[ib_k].scales[iqs_k / 4 + 1]); | ||||
|         buf_a[buf_ib].scales = unpack8(data_a_packed16[ib_k].scales[iqs_k / 8]); | ||||
|         buf_a[buf_ib].dm = FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); | ||||
|     } | ||||
| } | ||||
| @@ -261,26 +261,39 @@ void block_a_to_registers(const uint reg_ib, const uint buf_ib) { | ||||
|     } | ||||
| } | ||||
|  | ||||
| ACC_TYPE mmq_dot_product(const uint ib_a, const uint ib_b) { | ||||
| ACC_TYPE mmq_dot_product(const uint ib_a) { | ||||
|     int32_t sum_d = 0; | ||||
|     int32_t sum_m = 0; | ||||
|  | ||||
|     const i32vec2 scales = i32vec2(cache_a[ib_a].scales); | ||||
|     i32vec2 scale_m = scales >> 4; | ||||
|     uint8_t scale = cache_a[ib_a].scales[0]; | ||||
|     int32_t scale_m = int32_t(scale >> 4); | ||||
|     scale_m |= scale_m << 8; | ||||
|     scale_m |= scale_m << 16; | ||||
|  | ||||
|     [[unroll]] for (uint iqs = 0; iqs < 8; iqs++) { | ||||
|         const uint idx_half = iqs / 4; | ||||
|         const uint qs_shift = (iqs % 4) * 2; | ||||
|     [[unroll]] for (uint iqs = 0; iqs < 4; iqs++) { | ||||
|         const uint qs_shift = iqs * 2; | ||||
|  | ||||
|         const int32_t qs_a = int32_t((cache_a[ib_a].qs[idx_half] >> qs_shift) & 0x03030303); | ||||
|         const int32_t qs_a = int32_t((cache_a[ib_a].qs[0] >> qs_shift) & 0x03030303); | ||||
|  | ||||
|         sum_d += dotPacked4x8EXT(qs_a, cache_b[ib_b].qs[iqs]) * (scales[idx_half] & 0xF); | ||||
|         sum_m += dotPacked4x8EXT(scale_m[idx_half], cache_b[ib_b].qs[iqs]); | ||||
|         sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF); | ||||
|         sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]); | ||||
|     } | ||||
|  | ||||
|     return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b[ib_b].ds, 1); | ||||
|     scale = cache_a[ib_a].scales[1]; | ||||
|     scale_m = int32_t(scale >> 4); | ||||
|     scale_m |= scale_m << 8; | ||||
|     scale_m |= scale_m << 16; | ||||
|  | ||||
|     [[unroll]] for (uint iqs = 4; iqs < 8; iqs++) { | ||||
|         const uint qs_shift = (iqs - 4) * 2; | ||||
|  | ||||
|         const int32_t qs_a = int32_t((cache_a[ib_a].qs[1] >> qs_shift) & 0x03030303); | ||||
|  | ||||
|         sum_d += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]) * (scale & 0xF); | ||||
|         sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]); | ||||
|     } | ||||
|  | ||||
|     return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1); | ||||
| } | ||||
| #endif // MMQ_SHMEM | ||||
| #endif | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 0cc4m
					0cc4m