Mirror of https://github.com/ggml-org/llama.cpp.git
	kleidiai: kernel interface refactoring (#16460)
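What this change does: the std::variant / std::function members of the kleidiai kernel descriptor structs are replaced with plain function pointers (the new *_ex fields), and small adapter templates, parameterized on the KleidiAI entry point itself, normalize the differing native signatures onto a single pointer type. A minimal sketch of that pattern, using hypothetical names rather than the committed code:

    #include <cstddef>
    #include <cstdio>
    using std::size_t;

    // Stand-ins for two KleidiAI getters with different arities.
    static size_t native_offset_blocked(size_t n_idx, size_t k, size_t bl) { return n_idx * (k + bl); }
    static size_t native_offset_plain  (size_t n_idx, size_t k)            { return n_idx * k; }

    // One normalized slot type: always three parameters.
    using offset_ex_fn = size_t (*)(size_t, size_t, size_t);

    template <size_t (*Fn)(size_t, size_t, size_t)>
    static size_t offs_fn3(size_t a, size_t b, size_t c) { return Fn(a, b, c); }

    template <size_t (*Fn)(size_t, size_t)>
    static size_t offs_fn2(size_t a, size_t b, size_t /*bl, ignored*/) { return Fn(a, b); }

    // Both functions can now sit behind the same field in a kernel table.
    static offset_ex_fn blocked_slot = &offs_fn3<native_offset_blocked>;
    static offset_ex_fn plain_slot   = &offs_fn2<native_offset_plain>;

    int main() {
        // Callers always pass three arguments; the fn2 adapter drops the unused one.
        std::printf("%zu %zu\n", blocked_slot(2, 8, 32), plain_slot(2, 8, 0));
        return 0;
    }

Each adapter instantiation is an ordinary function, so the table entries in the diff below remain direct calls with no type erasure or heap allocation.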
		| @@ -29,6 +29,108 @@ | ||||
|  | ||||
| #define NELEMS(x) sizeof(x) / sizeof(*x) | ||||
|  | ||||
| template<size_t(*Fn)(size_t,size_t,size_t)> | ||||
| static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) { | ||||
|     return Fn(a, b, c); | ||||
| } | ||||
|  | ||||
| template<size_t(*Fn)(size_t,size_t)> | ||||
| static inline size_t kernel_offs_fn2(size_t a, size_t b, size_t) { | ||||
|     return Fn(a, b); | ||||
| } | ||||
|  | ||||
| template<void(*Fn)(size_t,size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)> | ||||
| static inline void kernel_run_fn11(size_t m, size_t n, size_t k, size_t bl, | ||||
|                                      const void* lhs, const void* rhs, void* dst, | ||||
|                                      size_t dst_stride_row, size_t dst_stride_col, | ||||
|                                      float clamp_min, float clamp_max) { | ||||
|     Fn(m, n, k, bl, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max); | ||||
| } | ||||
|  | ||||
| template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,void*,size_t,size_t,float,float)> | ||||
| static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/, | ||||
|                                    const void* lhs, const void* rhs, void* dst, | ||||
|                                    size_t dst_stride_row, size_t dst_stride_col, | ||||
|                                    float clamp_min, float clamp_max) { | ||||
|     Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max); | ||||
| } | ||||
|  | ||||
| template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)> | ||||
| static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { | ||||
|     return Fn(m, k, bl, mr, kr, sr); | ||||
| } | ||||
|  | ||||
| template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)> | ||||
| static inline size_t lhs_ps_fn5(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) { | ||||
|     return Fn(m, k, mr, kr, sr); | ||||
| } | ||||
|  | ||||
| template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)> | ||||
| static inline size_t lhs_offs_fn6(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { | ||||
|     return Fn(m_idx, k, bl, mr, kr, sr); | ||||
| } | ||||
|  | ||||
| template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)> | ||||
| static inline size_t lhs_offs_fn5(size_t m_idx, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) { | ||||
|     return Fn(m_idx, k, mr, kr, sr); | ||||
| } | ||||
|  | ||||
| template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)> | ||||
| static inline void lhs_pack_float_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, | ||||
|                                             size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { | ||||
|     Fn(m, k, bl, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed); | ||||
| } | ||||
|  | ||||
| template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)> | ||||
| static inline void lhs_pack_void_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, | ||||
|                                            size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { | ||||
|     Fn(m, k, bl, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed); | ||||
| } | ||||
|  | ||||
| template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const void*,size_t,void*)> | ||||
| static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr, | ||||
|                                              size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { | ||||
|     Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed); | ||||
| } | ||||
|  | ||||
| template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)> | ||||
| static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) { | ||||
|     return Fn(n, k, nr, kr, bl); | ||||
| } | ||||
|  | ||||
| template<size_t(*Fn)(size_t,size_t)> | ||||
| static inline size_t rhs_ps_fn2(size_t n, size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) { | ||||
|     return Fn(n, k); | ||||
| } | ||||
|  | ||||
| template<size_t(*Fn)(size_t,size_t,size_t,size_t)> | ||||
| static inline size_t rhs_stride_fn4(size_t k, size_t nr, size_t kr, size_t bl) { | ||||
|     return Fn(k, nr, kr, bl); | ||||
| } | ||||
|  | ||||
| template<size_t(*Fn)(size_t)> | ||||
| static inline size_t rhs_stride_fn1(size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) { | ||||
|     return Fn(k); | ||||
| } | ||||
|  | ||||
| template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const uint8_t*,const float*,void*,size_t,const struct kai_rhs_pack_qs4cxs1s0_param*)> | ||||
| static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, | ||||
|                                       size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* /*scale*/, | ||||
|                                       void* rhs_packed, size_t extra_bytes, const void* params) { | ||||
|     Fn(num_groups, n, k, nr, kr, sr, bl, | ||||
|        static_cast<const uint8_t*>(rhs), | ||||
|        static_cast<const float*>(bias), | ||||
|        rhs_packed, extra_bytes, | ||||
|        static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params)); | ||||
| } | ||||
|  | ||||
| template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)> | ||||
| static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/, | ||||
|                                                size_t rhs_stride, const void* rhs, const void* bias, const void* scale, | ||||
|                                                void* rhs_packed, size_t extra_bytes, const void* params) { | ||||
|     Fn(num_groups, n, k, nr, kr, sr, rhs_stride, rhs, bias, scale, rhs_packed, extra_bytes, params); | ||||
| } | ||||
|  | ||||
| static const size_t INT4_PER_BYTE = 2; | ||||
| static const size_t INT4_BITS     = 4; | ||||
| static const int Q4_0_ZERO_POINT  = 8; | ||||
| @@ -122,17 +224,18 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, | ||||
|             /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>, | ||||
|             /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>, | ||||
|             /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa>, | ||||
|         }, | ||||
|  | ||||
|         /* .gemm_lhs_info = */ { | ||||
|             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, | ||||
|             /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon, | ||||
|             /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, | ||||
|             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, | ||||
|             /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>, | ||||
|             /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>, | ||||
|             /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>, | ||||
|         }, | ||||
|         /* SME GEMV */ | ||||
|         /* .kern_info = */ { | ||||
| @@ -142,23 +245,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, | ||||
|             /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>, | ||||
|             /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>, | ||||
|             /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot>, | ||||
|         }, | ||||
|         /* .gemv_lhs_info = */ { | ||||
|             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, | ||||
|             /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon, | ||||
|             /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, | ||||
|             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, | ||||
|             /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon>, | ||||
|             /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon>, | ||||
|             /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32_neon>, | ||||
|         }, | ||||
|         /* .rhs_info = */ { | ||||
|             /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, | ||||
|             /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, | ||||
|             /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, | ||||
|             /* .to_float              = */ dequantize_row_qsi4c32ps1s0scalef16, | ||||
|             /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>, | ||||
|             /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>, | ||||
|             /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon>, | ||||
|         }, | ||||
|         /* .required_cpu       = */ CPU_FEATURE_SME, | ||||
|         /* .lhs_type           = */ GGML_TYPE_F32, | ||||
| @@ -174,17 +278,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .get_lhs_offset_ex     = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>, | ||||
|             /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>, | ||||
|             /* .run_kernel_ex         = */ &kernel_run_fn10<kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa>, | ||||
|         }, | ||||
|         /* .gemm_lhs_info = */ { | ||||
|             /* .get_offset            = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, | ||||
|             /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, | ||||
|             /* .packed_size           = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, | ||||
|             /* .pack_func             = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, | ||||
|             /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>, | ||||
|             /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>, | ||||
|             /* .pack_func_ex          = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>, | ||||
|         }, | ||||
|         /* SME GEMV */ | ||||
|         /* .kern_info = */ { | ||||
| @@ -194,23 +298,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, | ||||
|             /* .get_lhs_offset_ex     = */ nullptr, | ||||
|             /* .get_rhs_packed_offset_ex = */ nullptr, | ||||
|             /* .run_kernel_ex         = */ nullptr, | ||||
|         }, | ||||
|         /* .gemv_lhs_info = */ { | ||||
|             /* .get_offset            = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, | ||||
|             /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, | ||||
|             /* .packed_size           = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, | ||||
|             /* .pack_func             = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, | ||||
|             /* .get_packed_offset_ex  = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme>, | ||||
|             /* .packed_size_ex        = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme>, | ||||
|             /* .pack_func_ex          = */ &lhs_pack_void_fn9<kai_run_lhs_pack_bf16p2vlx2_f32_sme>, | ||||
|         }, | ||||
|         /* .rhs_info = */ { | ||||
|             /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, | ||||
|             /* .packed_stride = */ NULL, | ||||
|             /* .pack_func     = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, | ||||
|             /* .to_float      = */ NULL, | ||||
|             /* .packed_stride         = */ nullptr, | ||||
|             /* .to_float              = */ nullptr, | ||||
|             /* .packed_size_ex        = */ &rhs_ps_fn2<kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>, | ||||
|             /* .packed_stride_ex      = */ &rhs_stride_fn1<kai_get_rhs_packed_stride_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>, | ||||
|             /* .pack_func_ex          = */ &rhs_pack_fn13<kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme>, | ||||
|         }, | ||||
|         /* .required_cpu       = */ CPU_FEATURE_SME, | ||||
|         /* .lhs_type           = */ GGML_TYPE_F32, | ||||
| @@ -229,17 +334,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>, | ||||
|             /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>, | ||||
|             /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>, | ||||
|         }, | ||||
|         /* .gemm_lhs_info = */ { | ||||
|             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>, | ||||
|             /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>, | ||||
|             /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>, | ||||
|         }, | ||||
|         /* DOTPROD GEMV */ | ||||
|         /* .kern_info = */ { | ||||
| @@ -249,23 +354,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>, | ||||
|             /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>, | ||||
|             /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>, | ||||
|         }, | ||||
|         /* .gemv_lhs_info = */ { | ||||
|             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>, | ||||
|             /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>, | ||||
|             /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>, | ||||
|         }, | ||||
|         /* .rhs_info = */ { | ||||
|             /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .to_float              = */ dequantize_row_qsi4c32pscalef16, | ||||
|             /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, | ||||
|             /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, | ||||
|             /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, | ||||
|         }, | ||||
|         /* .required_cpu       = */ CPU_FEATURE_DOTPROD, | ||||
|         /* .lhs_type           = */ GGML_TYPE_F32, | ||||
| @@ -283,17 +389,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>, | ||||
|             /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>, | ||||
|             /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>, | ||||
|         }, | ||||
|         /* .gemm_lhs_info = */ { | ||||
|             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, | ||||
|             /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, | ||||
|             /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, | ||||
|             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, | ||||
|             /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>, | ||||
|             /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>, | ||||
|             /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>, | ||||
|         }, | ||||
|         /* i8mm GEMV */ | ||||
|         /* .kern_info = */ { | ||||
| @@ -303,23 +409,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>, | ||||
|             /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>, | ||||
|             /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>, | ||||
|         }, | ||||
|         /* .gemv_lhs_info = */ { | ||||
|             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>, | ||||
|             /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>, | ||||
|             /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>, | ||||
|         }, | ||||
|         /* .rhs_info = */ { | ||||
|             /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .to_float              = */ dequantize_row_qsi4c32pscalef16, | ||||
|             /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, | ||||
|             /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, | ||||
|             /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, | ||||
|         }, | ||||
|         /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, | ||||
|         /* .lhs_type           = */ GGML_TYPE_F32, | ||||
| @@ -338,17 +445,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, | ||||
|             /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>, | ||||
|             /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>, | ||||
|             /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm>, | ||||
|         }, | ||||
|         /* .gemm_lhs_info = */ { | ||||
|             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, | ||||
|             /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, | ||||
|             /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, | ||||
|             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, | ||||
|             /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>, | ||||
|             /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>, | ||||
|             /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon>, | ||||
|         }, | ||||
|         /* i8mm GEMV */ | ||||
|         /* .kern_info = */ { | ||||
| @@ -358,23 +465,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, | ||||
|             /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>, | ||||
|             /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>, | ||||
|             /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod>, | ||||
|         }, | ||||
|         /* .gemv_lhs_info = */ { | ||||
|             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>, | ||||
|             /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>, | ||||
|             /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>, | ||||
|         }, | ||||
|         /* .rhs_info = */ { | ||||
|             /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .to_float              = */ dequantize_row_qsi4c32pscalef16, | ||||
|             /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, | ||||
|             /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, | ||||
|             /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, | ||||
|         }, | ||||
|         /* .required_cpu       = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, | ||||
|         /* .lhs_type           = */ GGML_TYPE_F32, | ||||
| @@ -392,17 +500,17 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, | ||||
|             /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>, | ||||
|             /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>, | ||||
|             /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod>, | ||||
|         }, | ||||
|         /* .gemm_lhs_info = */ { | ||||
|             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>, | ||||
|             /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>, | ||||
|             /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>, | ||||
|         }, | ||||
|         /* DOTPROD GEMV */ | ||||
|         /* .kern_info = */ { | ||||
| @@ -412,23 +520,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { | ||||
|             /* .get_nr                = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_kr                = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_sr                = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_lhs_offset        = */ kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_rhs_packed_offset = */ kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_dst_offset        = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_dst_size          = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .run_kernel            = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, | ||||
|             /* .get_lhs_offset_ex     = */ &kernel_offs_fn3<kai_get_lhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>, | ||||
|             /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3<kai_get_rhs_packed_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>, | ||||
|             /* .run_kernel_ex         = */ &kernel_run_fn11<kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod>, | ||||
|         }, | ||||
|         /* .gemv_lhs_info = */ { | ||||
|             /* .get_offset            = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset     = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .packed_size           = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .pack_func             = */ kai_run_lhs_quant_pack_qsi8d32p_f32, | ||||
|             /* .get_packed_offset_ex  = */ &lhs_offs_fn6<kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32>, | ||||
|             /* .packed_size_ex        = */ &lhs_ps_fn6<kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32>, | ||||
|             /* .pack_func_ex          = */ &lhs_pack_float_fn10<kai_run_lhs_quant_pack_qsi8d32p_f32>, | ||||
|         }, | ||||
|         /* .rhs_info = */ { | ||||
|             /* .packed_size   = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .packed_stride         = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .pack_func     = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, | ||||
|             /* .to_float              = */ dequantize_row_qsi4c32pscalef16, | ||||
|             /* .packed_size_ex        = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, | ||||
|             /* .packed_stride_ex      = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, | ||||
|             /* .pack_func_ex          = */ &rhs_pack_fn12<kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0>, | ||||
|         }, | ||||
|         /* .required_cpu       = */ CPU_FEATURE_DOTPROD, | ||||
|         /* .lhs_type           = */ GGML_TYPE_F32, | ||||
| @@ -443,6 +552,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c | ||||
|     ggml_kleidiai_kernels * kernel = nullptr; | ||||
|  | ||||
|     if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) { | ||||
| #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) | ||||
|         for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { | ||||
|             if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu && | ||||
|                 gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type && | ||||
| @@ -452,6 +562,7 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c | ||||
|                 break; | ||||
|             } | ||||
|         } | ||||
| #endif | ||||
|     } | ||||
|  | ||||
|     return kernel; | ||||
| @@ -460,12 +571,14 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c | ||||
| ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) { | ||||
|     ggml_kleidiai_kernels * kernels = nullptr; | ||||
|  | ||||
| #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) | ||||
|     for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { | ||||
|         if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) { | ||||
|             kernels = &gemm_gemv_kernels[i]; | ||||
|             break; | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
|  | ||||
|     return kernels; | ||||
| } | ||||
|   | ||||
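In the tables above, the blocked Q4_0-style kernels receive a real block length through the three-argument getters, while the bf16 SME entries have none, so their lower-arity adapters (kernel_offs_fn2, lhs_ps_fn5, lhs_pack_void_fn9, rhs_stride_fn1) simply discard the bl argument. The bf16 SME GEMV entry also leaves its *_ex slots as nullptr, so callers must probe a slot before dispatching. A hypothetical caller-side helper, not part of the commit, to illustrate the call shape:

    #include <cstddef>

    // Signature of the normalized offset slot declared in the header below.
    using rhs_packed_offset_ex_fn = std::size_t (*)(std::size_t n_idx, std::size_t k, std::size_t bl);

    // Returns a fallback when the entry does not wire up this slot.
    static std::size_t rhs_packed_offset_or(rhs_packed_offset_ex_fn fn,
                                            std::size_t n_idx, std::size_t k, std::size_t bl,
                                            std::size_t fallback) {
        return fn ? fn(n_idx, k, bl) : fallback;
    }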
| @@ -4,8 +4,6 @@ | ||||
|  | ||||
| #pragma once | ||||
|  | ||||
| #include <functional> | ||||
| #include <variant> | ||||
| #include "ggml.h" | ||||
|  | ||||
| enum cpu_feature { | ||||
| @@ -15,6 +13,7 @@ enum cpu_feature { | ||||
|     CPU_FEATURE_SVE     = 4, | ||||
|     CPU_FEATURE_SME     = 8 | ||||
| }; | ||||
|  | ||||
| inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) { | ||||
|     lhs = static_cast<cpu_feature>(lhs | rhs); | ||||
|     return lhs; | ||||
| @@ -30,56 +29,45 @@ struct kernel_info { | ||||
|     size_t (*get_nr)(void); | ||||
|     size_t (*get_kr)(void); | ||||
|     size_t (*get_sr)(void); | ||||
|     std::variant< | ||||
|         std::function<size_t(size_t n_idx, size_t k, size_t bl)>, | ||||
|         std::function<size_t(size_t m_idx, size_t k)> | ||||
|     > get_lhs_offset; | ||||
|     std::variant< | ||||
|         std::function<size_t(size_t n_idx, size_t k, size_t bl)>, | ||||
|         std::function<size_t(size_t n_idx, size_t k)> | ||||
|     > get_rhs_packed_offset; | ||||
|  | ||||
|     size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride); | ||||
|     size_t (*get_dst_size)(size_t m, size_t n); | ||||
|     std::variant< | ||||
|         std::function<void(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed, | ||||
|             float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max)>, | ||||
|         std::function<void(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row, | ||||
|             size_t dst_stride_col, float clamp_min, float clamp_max)> | ||||
|     > run_kernel; | ||||
|  | ||||
|     size_t (*get_lhs_offset_ex)(size_t m_idx, size_t k, size_t bl); | ||||
|  | ||||
|     size_t (*get_rhs_packed_offset_ex)(size_t n_idx, size_t k, size_t bl); | ||||
|  | ||||
|     void (*run_kernel_ex)( | ||||
|         size_t m, size_t n, size_t k, size_t bl, | ||||
|         const void* lhs_packed, const void* rhs_packed, | ||||
|         void* dst, size_t dst_stride_row, size_t dst_stride_col, | ||||
|         float clamp_min, float clamp_max); | ||||
| }; | ||||
|  | ||||
| struct lhs_packing_info { | ||||
|     size_t (*get_offset)(size_t m_idx, size_t lhs_stride); | ||||
|     std::variant< | ||||
|         std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>, | ||||
|         std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)> | ||||
|     > get_packed_offset; | ||||
|     std::variant< | ||||
|         std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>, | ||||
|         std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)> | ||||
|     > packed_size; | ||||
|     std::variant< | ||||
|         std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, | ||||
|             size_t lhs_stride, void* lhs_packed)>, | ||||
|         std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, | ||||
|         void* lhs_packed)> | ||||
|     > pack_func; | ||||
|  | ||||
|     size_t (*get_packed_offset_ex)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); | ||||
|  | ||||
|     size_t (*packed_size_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr); | ||||
|  | ||||
|     void (*pack_func_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, | ||||
|         size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed); | ||||
| }; | ||||
|  | ||||
| struct rhs_packing_info { | ||||
|     std::variant< | ||||
|         std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>, | ||||
|         std::function<size_t(size_t n, size_t k)> | ||||
|     > packed_size; | ||||
|     size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl); | ||||
|     std::variant< | ||||
|         std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs, | ||||
|             const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>, | ||||
|         std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs, | ||||
|             const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)> | ||||
|     > pack_func; | ||||
|     void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride, | ||||
|           size_t kr, size_t bl, size_t num_bytes_multiplier); | ||||
|  | ||||
|     void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, | ||||
|                      size_t nr_pack, size_t packed_row_stride, size_t kr, size_t bl, | ||||
|                      size_t num_bytes_multiplier); | ||||
|  | ||||
|     size_t (*packed_size_ex)(size_t n, size_t k, size_t nr, size_t kr, size_t bl); | ||||
|  | ||||
|     size_t (*packed_stride_ex)(size_t k, size_t nr, size_t kr, size_t bl); | ||||
|  | ||||
|     void (*pack_func_ex)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, | ||||
|         size_t rhs_stride, const void * rhs, const void * bias, const void * scale, void * rhs_packed, size_t extra_bytes, const void * params); | ||||
| }; | ||||
|  | ||||
| struct ggml_kleidiai_kernels { | ||||
|   | ||||
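One practical effect of dropping std::variant and std::function from these structs (and the <functional> / <variant> includes with them) is that every member is now a plain pointer, so a static table of these descriptors can be constant-initialized instead of running std::function constructors at load time. A trimmed sketch under that assumption, with a hypothetical two-field struct and placeholder arithmetic:

    #include <cstddef>
    using std::size_t;

    struct lhs_packing_info_min {
        size_t (*packed_size_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
        void   (*pack_func_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
                               size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed);
    };

    // Placeholder formula, not a real KleidiAI size computation.
    static size_t demo_packed_size(size_t m, size_t k, size_t, size_t mr, size_t kr, size_t) {
        return m * k + mr * kr;
    }

    // constexpr works here; it would not if the fields were std::function objects.
    static constexpr lhs_packing_info_min demo_entry = { &demo_packed_size, nullptr };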
| @@ -8,6 +8,7 @@ | ||||
| #include <stdexcept> | ||||
| #include <stdint.h> | ||||
| #include <string.h> | ||||
| #include <string> | ||||
| #if defined(__linux__) | ||||
| #include <asm/hwcap.h> | ||||
| #include <sys/auxv.h> | ||||
| @@ -87,40 +88,6 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { | ||||
|     return tensor->ne[dim]; | ||||
| } | ||||
|  | ||||
| template <typename Variant, typename Ret, typename... Args, std::size_t... Is> | ||||
| constexpr bool variant_any_invocable_impl(std::index_sequence<Is...>) { | ||||
|     using V = std::remove_reference_t<Variant>; | ||||
|     return (std::is_invocable_r_v< | ||||
|                 Ret, | ||||
|                 std::variant_alternative_t<Is, V>, | ||||
|                 Args...> || ...); | ||||
| } | ||||
|  | ||||
| template <typename Variant, typename Ret, typename... Args> | ||||
| constexpr bool variant_any_invocable_v = | ||||
|     variant_any_invocable_impl<Variant, Ret, Args...>( | ||||
|         std::make_index_sequence< | ||||
|             std::variant_size_v<std::remove_reference_t<Variant>>>{}); | ||||
|  | ||||
| template<typename Ret, typename Variant, typename... Args> | ||||
| static inline Ret variant_call(Variant && var, Args&&... args) { | ||||
|     static_assert(variant_any_invocable_v<std::remove_reference_t<Variant>, Ret, Args...>, | ||||
|                   "No alternative in Variant is invocable with the provided arguments and return type."); | ||||
|  | ||||
|     return std::visit( | ||||
|         [&](auto && f) -> Ret { | ||||
|             using F = std::decay_t<decltype(f)>; | ||||
|             if constexpr (std::is_invocable_r_v<Ret, F, Args...>) { | ||||
|                 return std::invoke(std::forward<decltype(f)>(f), std::forward<Args>(args)...); | ||||
|             } else { | ||||
|                 GGML_ABORT("Invalid function type in variant_call"); | ||||
|                 GGML_UNREACHABLE(); | ||||
|             } | ||||
|         }, | ||||
|         std::forward<Variant>(var) | ||||
|     ); | ||||
| } | ||||
|  | ||||
| namespace ggml::cpu::kleidiai { | ||||
|  | ||||
| static size_t round_down(size_t x, size_t y) { | ||||
| @@ -145,7 +112,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { | ||||
|             return false; | ||||
|         } | ||||
|         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op); | ||||
|         GGML_ASSERT(kernels); | ||||
|         if (!kernels) { | ||||
|             return false; | ||||
|         } | ||||
|         bool is_gemv = op->src[1]->ne[1] == 1; | ||||
|         kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; | ||||
|         lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; | ||||
| @@ -159,16 +128,18 @@ class tensor_traits : public ggml::cpu::tensor_traits { | ||||
|         size_t sr = kernel->get_sr(); | ||||
|  | ||||
|         if (kernels->rhs_type == GGML_TYPE_Q4_0) { | ||||
|             size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr); | ||||
|             if (!lhs_info->packed_size_ex) return false; | ||||
|             size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr); | ||||
|         } else if (kernels->rhs_type == GGML_TYPE_F16) { | ||||
|             if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false; | ||||
|             const int64_t lhs_batch_size0 = op->src[1]->ne[2]; | ||||
|             const int64_t rhs_batch_size0 = op->src[0]->ne[2]; | ||||
|             const int64_t r = lhs_batch_size0 / rhs_batch_size0; | ||||
|             size = variant_call<size_t>(lhs_info->packed_size, m * r, k, mr, kr, sr) + | ||||
|                    variant_call<size_t>(kernels->rhs_info.packed_size, n, k) + | ||||
|             size = lhs_info->packed_size_ex(m * r, k, 0, mr, kr, sr) + | ||||
|                    kernels->rhs_info.packed_size_ex(n, k, kernel->get_nr(), kernel->get_kr(), 0) + | ||||
|                    k * n * sizeof(float) + n * sizeof(float); | ||||
|         } else { | ||||
|             GGML_ASSERT(false); | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         return true; | ||||
| @@ -196,12 +167,18 @@ class tensor_traits : public ggml::cpu::tensor_traits { | ||||
|         GGML_TENSOR_BINARY_OP_LOCALS | ||||
|  | ||||
|         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); | ||||
|         GGML_ASSERT(kernels); | ||||
|         if (!kernels) { | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         const bool is_gemv = src1->ne[1] == 1; | ||||
|         kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; | ||||
|         lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; | ||||
|         GGML_ASSERT(kernel); | ||||
|         if (!kernels->rhs_info.pack_func_ex || | ||||
|             !kernel->get_lhs_offset_ex || !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex) { | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         const int nth = params->nth; | ||||
|         const int ith = params->ith; | ||||
| @@ -228,10 +205,10 @@ class tensor_traits : public ggml::cpu::tensor_traits { | ||||
|         const int64_t kr = (int64_t) kernel->get_kr(); | ||||
|         const int64_t sr = (int64_t) kernel->get_sr(); | ||||
|  | ||||
|         const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); | ||||
|         const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, (size_t)n, (size_t)k); | ||||
|         const size_t kxn_size        = (size_t)k * (size_t)n * sizeof(float); | ||||
|         const size_t bias_size       = (size_t)n * sizeof(float); | ||||
|         const size_t lhs_packed_size = lhs_info->packed_size_ex(m, k, 0, mr, kr, sr); | ||||
|         const size_t rhs_packed_size = kernels->rhs_info.packed_size_ex(n, k, nr, kr, 0); | ||||
|         const size_t kxn_size        = k * n * sizeof(float); | ||||
|         const size_t bias_size       = n * sizeof(float); | ||||
|  | ||||
|         const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size; | ||||
|         GGML_ASSERT(wsize_required <= params->wsize); | ||||
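The F16 path sizes one contiguous workspace holding, in order, the packed LHS, the packed RHS, a k x n float staging buffer for the converted RHS, and n floats of bias handed to the pack routine. A small sketch of that arithmetic, with the packed sizes treated as inputs (the real values come from the packed_size_ex getters above):

    #include <cstddef>

    static std::size_t f16_workspace_bytes(std::size_t n, std::size_t k,
                                           std::size_t lhs_packed_size, std::size_t rhs_packed_size) {
        const std::size_t kxn_size  = k * n * sizeof(float);  // fp32 staging copy of the fp16 RHS
        const std::size_t bias_size = n * sizeof(float);      // bias buffer passed to pack_func_ex
        return lhs_packed_size + rhs_packed_size + kxn_size + bias_size;
    }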
| @@ -259,10 +236,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { | ||||
|                     const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; | ||||
|  | ||||
|                     // Base packed offset (aligned) and per-row stride in bytes | ||||
|                     const size_t base_packed_off = variant_call<size_t>( | ||||
|                         lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); | ||||
|                     const size_t next_block_off = variant_call<size_t>( | ||||
|                         lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); | ||||
|                     const size_t base_packed_off  = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr); | ||||
|                     const size_t next_block_off   = lhs_info->get_packed_offset_ex(m_start + mr, k, 0, mr, kr, sr); | ||||
|                     const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr; | ||||
|  | ||||
|                     int64_t remaining = m_count; | ||||
| @@ -278,9 +253,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { | ||||
|                         const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes; | ||||
|                         void * dst_ptr       = lhs_packed + dst_off; | ||||
|  | ||||
|                         variant_call<void>(lhs_info->pack_func, | ||||
|                                         (size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr, | ||||
|                                         /*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr); | ||||
|                         lhs_info->pack_func_ex(take, k, 0, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr); | ||||
|  | ||||
|                         cur       += take; | ||||
|                         remaining -= take; | ||||
| @@ -296,10 +269,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { | ||||
|                                         reinterpret_cast<const uint16_t *>(rhs_batch_base), | ||||
|                                         rhs_stride); | ||||
|  | ||||
|                 variant_call<void>(kernels->rhs_info.pack_func, | ||||
|                                    /*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr, | ||||
|                                    /*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)), | ||||
|                                    rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr); | ||||
|                 kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, n * sizeof(float), | ||||
|                              rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr); | ||||
|             } | ||||
|  | ||||
|             ggml_barrier(params->threadpool); | ||||
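An editor's reading of the structure around this barrier, condensed into comments; the branch conditions for the RHS pack are elided from the hunks above, so this is a summary rather than a quotation:

    // Not part of this diff: the barrier separates two phases.
    //   phase 1 (above): each thread packs its slice of the LHS rows, and the RHS
    //                    is converted and packed into rhs_packed (with bias staging);
    //   phase 2 (below): threads split the N dimension and each one runs the selected
    //                    GEMM/GEMV kernel on its own column range.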
| @@ -320,20 +291,15 @@ class tensor_traits : public ggml::cpu::tensor_traits { | ||||
|                     const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0; | ||||
|  | ||||
|                     // LHS packed base at row 0 (consistent with packing above) | ||||
|                     const size_t lhs_packed_offset0 = variant_call<size_t>( | ||||
|                         lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); | ||||
|                     const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k); | ||||
|                     const size_t lhs_packed_offset0 = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr); | ||||
|                     const size_t rhs_packed_offset  = kernel->get_rhs_packed_offset_ex(n_start, k, 0); | ||||
|                     const size_t dst_offset         = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride); | ||||
|  | ||||
|                     const void * lhs_ptr = lhs_packed + lhs_packed_offset0; | ||||
|                     const void * rhs_ptr = rhs_packed + rhs_packed_offset; | ||||
|                     float * dst_ptr      = reinterpret_cast<float *>(dst_batch_base + dst_offset); | ||||
|  | ||||
|                     variant_call<void>(kernel->run_kernel, | ||||
|                                        (size_t)m, (size_t)n_to_process, (size_t)k, | ||||
|                                        lhs_ptr, rhs_ptr, | ||||
|                                        dst_ptr, dst_stride, sizeof(float), | ||||
|                                        -FLT_MAX, FLT_MAX); | ||||
|                     kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); | ||||
|                 } | ||||
|             } | ||||
|  | ||||
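Across the two matmul paths the fourth run_kernel_ex argument differs: this f32 path passes 0, while the Q4_0 path further down passes QK4_0. A small caller-side restatement of the f32 call, with that argument named for readability (no behavior is implied beyond what the call sites show):

    // Restatement of the f32 call above; bl is introduced here only for readability.
    const size_t bl = 0;   // the Q4_0 path below passes QK4_0 in this slot instead
    kernel->run_kernel_ex(m, n_to_process, k, bl,
                          lhs_ptr, rhs_ptr, dst_ptr,
                          dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);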
| @@ -354,13 +320,19 @@ class tensor_traits : public ggml::cpu::tensor_traits { | ||||
|         GGML_TENSOR_BINARY_OP_LOCALS | ||||
|  | ||||
|         ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); | ||||
|         GGML_ASSERT(kernels); | ||||
|         if (!kernels) { | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         bool is_gemv = src1->ne[1] == 1; | ||||
|         kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; | ||||
|         lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; | ||||
|  | ||||
|         GGML_ASSERT(kernel); | ||||
|         if (!lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex || | ||||
|             !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) { | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         const int ith = params->ith; | ||||
|         const int nth_raw = params->nth; | ||||
| @@ -402,25 +374,26 @@ class tensor_traits : public ggml::cpu::tensor_traits { | ||||
|             // Transform LHS | ||||
|             const size_t src_stride        = src1->nb[1]; | ||||
|             const float * src_ptr          = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1])); | ||||
|             const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, QK4_0, mr, kr, sr); | ||||
|             const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, QK4_0, mr, kr, sr); | ||||
|             void * lhs_packed_ptr          = static_cast<void *>(lhs_packed + lhs_packed_offset); | ||||
|  | ||||
|             variant_call<void>(lhs_info->pack_func, m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); | ||||
|             // Pack this thread's chunk with m_idx_start = 0 and per-thread output pointer | ||||
|             lhs_info->pack_func_ex(m_to_process, k, QK4_0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr); | ||||
|         } | ||||
|  | ||||
|         ggml_barrier(params->threadpool); | ||||
|  | ||||
|         // Perform the operation | ||||
|         const size_t dst_stride        = dst->nb[1]; | ||||
|         const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, 0, k, QK4_0, mr, kr, sr); | ||||
|         const size_t rhs_packed_offset = variant_call<size_t>(kernel->get_rhs_packed_offset, n_start, k, QK4_0); | ||||
|         const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, QK4_0, mr, kr, sr); | ||||
|         const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, QK4_0); | ||||
|         const size_t dst_offset        = kernel->get_dst_offset(0, n_start, dst_stride); | ||||
|         const void * rhs_ptr           = static_cast<const void *>(rhs_packed + rhs_packed_offset); | ||||
|         const void* lhs_ptr            = (const void*)((const char *)lhs_packed + lhs_packed_offset); | ||||
|         float *dst_ptr                 = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset); | ||||
|  | ||||
|         if (n_to_process > 0) { | ||||
|             variant_call<void>(kernel->run_kernel, m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, | ||||
|             kernel->run_kernel_ex(m, n_to_process, k, QK4_0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, | ||||
|                                sizeof(float), -FLT_MAX, FLT_MAX); | ||||
|         } | ||||
|  | ||||
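A short editor's note on the guard that this brace closes; the per-thread n_start/n_to_process computation is elided from this hunk, so the comment is an inference:

    // Not part of this diff: when more threads run than there are output columns to
    // split, trailing threads get n_to_process == 0, and the guard above keeps them
    // from calling the kernel on an empty range.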
| @@ -429,7 +402,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { | ||||
|  | ||||
|     bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) { | ||||
|         GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0); | ||||
|         GGML_ASSERT(ctx.kernels); | ||||
|         if (!ctx.kernels) { | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         const ggml_tensor * src0 = dst->src[0]; | ||||
|         const ggml_tensor * src1 = dst->src[1]; | ||||
| @@ -438,6 +413,9 @@ class tensor_traits : public ggml::cpu::tensor_traits { | ||||
|  | ||||
|         rhs_packing_info * rhs_info = &ctx.kernels->rhs_info; | ||||
|         kernel_info * kernel        = &ctx.kernels->gemm; | ||||
|         if (!rhs_info->to_float || !kernel->get_nr) { | ||||
|             return false; | ||||
|         } | ||||
|  | ||||
|         const int64_t nc     = ne00; | ||||
|         const int64_t nr     = ggml_nelements(src1); | ||||
| @@ -480,7 +458,7 @@ public: | ||||
|         struct kai_rhs_pack_qs4cxs1s0_param params; | ||||
|         params.lhs_zero_point = 1; | ||||
|         params.rhs_zero_point = 8; | ||||
|         variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, &params); | ||||
|         ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, &params); | ||||
|  | ||||
|         return 0; | ||||
|         GGML_UNUSED(data_size); | ||||
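One interpretive note on the packing parameters set above, added by the editor rather than taken from the commit:

    // Interpretation, not from the commit: rhs_zero_point = 8 matches the Q4_0 layout,
    // where each 4-bit value is stored with an offset of 8 around zero; the meaning of
    // lhs_zero_point = 1 is left to the kernel's own documentation.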
| @@ -548,7 +526,7 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_ | ||||
|     const size_t nr = ctx.kernels->gemm.get_nr(); | ||||
|     const size_t kr = ctx.kernels->gemm.get_kr(); | ||||
|  | ||||
|     return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0); | ||||
|     return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0); | ||||
|  | ||||
|     GGML_UNUSED(buft); | ||||
| } | ||||
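A closing observation connecting this helper to the set_tensor path above, phrased as a hedged comment rather than a claim about the commit:

    // Editor's reading: since the buffer type reports the packed RHS size as the
    // allocation size, tensor->data is large enough for pack_func_ex (in the
    // set_tensor path above) to write the packed representation in place.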
|   | ||||