// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates // SPDX-License-Identifier: MIT // // KleidiAI micro-kernels #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h" #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" #include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h" #include "kai_lhs_quant_pack_qsi8d32p_f32.h" #include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h" #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" #include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" #include "kai_common.h" #include "simd-mappings.h" #include "kernels.h" #define NELEMS(x) sizeof(x) / sizeof(*x) template static inline size_t kernel_offs_fn3(size_t a, size_t b, size_t c) { return Fn(a, b, c); } template static inline size_t kernel_offs_fn2(size_t a, size_t b, size_t) { return Fn(a, b); } template static inline void kernel_run_fn11(size_t m, size_t n, size_t k, size_t bl, const void* lhs, const void* rhs, void* dst, size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max) { Fn(m, n, k, bl, lhs, rhs, static_cast(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max); } template static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/, const void* lhs, const void* rhs, void* dst, size_t dst_stride_row, size_t dst_stride_col, float clamp_min, float clamp_max) { Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max); } template static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { return Fn(m, k, bl, mr, kr, sr); } template static inline size_t lhs_ps_fn5(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) { return Fn(m, k, mr, kr, sr); } template static inline size_t lhs_offs_fn6(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) { return Fn(m_idx, k, bl, mr, kr, sr); } template static inline size_t lhs_offs_fn5(size_t m_idx, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr) { return Fn(m_idx, k, mr, kr, sr); } template static inline void lhs_pack_float_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { Fn(m, k, bl, mr, kr, sr, m_idx_start, static_cast(lhs), lhs_stride, lhs_packed); } template static inline void lhs_pack_void_fn10(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { Fn(m, k, bl, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed); } template static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride, void* lhs_packed) { Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed); } template static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) { return Fn(n, k, nr, kr, bl); } template static inline size_t rhs_ps_fn2(size_t n, size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) { return Fn(n, k); } template static inline size_t rhs_stride_fn4(size_t k, size_t nr, size_t kr, size_t bl) { return Fn(k, nr, kr, bl); } template static inline size_t rhs_stride_fn1(size_t k, size_t /*nr*/, size_t /*kr*/, size_t /*bl*/) { return Fn(k); } template static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* /*scale*/, void* rhs_packed, size_t extra_bytes, const void* params) { Fn(num_groups, n, k, nr, kr, sr, bl, static_cast(rhs), static_cast(bias), rhs_packed, extra_bytes, static_cast(params)); } template static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/, size_t rhs_stride, const void* rhs, const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params) { Fn(num_groups, n, k, nr, kr, sr, rhs_stride, rhs, bias, scale, rhs_packed, extra_bytes, params); } static const size_t INT4_PER_BYTE = 2; static const size_t INT4_BITS = 4; static const int Q4_0_ZERO_POINT = 8; const size_t INT4_PER_UINT16 = 4; static void dequantize_row_qsi4c32pscalef16( const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride, size_t kr, size_t bl, size_t num_bytes_multiplier ) { size_t group_idx = row_idx / nr_pack; size_t row_in_group = row_idx % nr_pack; const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride; size_t num_blocks = nc / bl; const uint8_t *block_ptr = packed_group; for (size_t b = 0; b < num_blocks; ++b) { uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier)); float scale = GGML_CPU_FP16_TO_FP32(scale_f16); const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier; size_t num_segments = bl / kr; size_t num_bytes_per_segment = kr / INT4_PER_BYTE; for (size_t s = 0; s < num_segments; ++s) { const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment; const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment; for (size_t k = 0; k < num_bytes_per_segment; ++k) { uint8_t byte = qbytes[k] ^ 0x88; int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT; int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT; out[b * bl + s * num_bytes_per_segment + k] = x0 * scale; out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale; } } block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment; } } static void dequantize_row_qsi4c32ps1s0scalef16( const void *packed_data, int32_t row_idx, int64_t k, float *out, size_t nr, size_t packed_row_stride, size_t kr, size_t bl, size_t num_bytes_multiplier ) { const size_t num_blocks = k / bl; const size_t bl4 = bl / INT4_PER_UINT16; size_t group_idx = row_idx / nr; size_t row_in_group = row_idx % nr; const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride; const uint16_t *qdata = (const uint16_t *)packed_group; const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier)); for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) { uint16_t scale_f16 = scales[row_in_group + block_idx * nr]; float scale = GGML_CPU_FP16_TO_FP32(scale_f16); for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) { uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group]; for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) { int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT; out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale; } } } GGML_UNUSED(kr); } static ggml_kleidiai_kernels gemm_gemv_kernels[] = { #if defined(__ARM_FEATURE_SME) { /* SME GEMM */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, /* .get_packed_offset_ex = */ &lhs_offs_fn6, /* .packed_size_ex = */ &lhs_ps_fn6, /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* SME GEMV */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, /* .get_packed_offset_ex = */ &lhs_offs_fn6, /* .packed_size_ex = */ &lhs_ps_fn6, /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16, /* .packed_size_ex = */ &rhs_ps_fn5, /* .packed_stride_ex = */ &rhs_stride_fn4, /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_SME, /* .lhs_type = */ GGML_TYPE_F32, /* .rhs_type = */ GGML_TYPE_Q4_0, /* .op_type = */ GGML_TYPE_F32, }, { /* SME GEMM */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_lhs_offset_ex = */ &kernel_offs_fn2, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2, /* .run_kernel_ex = */ &kernel_run_fn10, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, /* .get_packed_offset_ex = */ &lhs_offs_fn5, /* .packed_size_ex = */ &lhs_ps_fn5, /* .pack_func_ex = */ &lhs_pack_void_fn9, }, /* SME GEMV */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_mr = */ kai_get_mr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_nr = */ kai_get_nr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_lhs_offset_ex = */ nullptr, /* .get_rhs_packed_offset_ex = */ nullptr, /* .run_kernel_ex = */ nullptr, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, /* .get_packed_offset_ex = */ &lhs_offs_fn5, /* .packed_size_ex = */ &lhs_ps_fn5, /* .pack_func_ex = */ &lhs_pack_void_fn9, }, /* .rhs_info = */ { /* .packed_stride = */ nullptr, /* .to_float = */ nullptr, /* .packed_size_ex = */ &rhs_ps_fn2, /* .packed_stride_ex = */ &rhs_stride_fn1, /* .pack_func_ex = */ &rhs_pack_fn13, }, /* .required_cpu = */ CPU_FEATURE_SME, /* .lhs_type = */ GGML_TYPE_F32, /* .rhs_type = */ GGML_TYPE_F16, /* .op_type = */ GGML_TYPE_F32, }, #endif #if defined(__APPLE__) #if defined(__ARM_FEATURE_DOTPROD) { /* DOTPROD GEMM */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_packed_offset_ex = */ &lhs_offs_fn6, /* .packed_size_ex = */ &lhs_ps_fn6, /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* DOTPROD GEMV */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_packed_offset_ex = */ &lhs_offs_fn6, /* .packed_size_ex = */ &lhs_ps_fn6, /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, /* .to_float = */ dequantize_row_qsi4c32pscalef16, /* .packed_size_ex = */ &rhs_ps_fn5, /* .packed_stride_ex = */ &rhs_stride_fn4, /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, /* .lhs_type = */ GGML_TYPE_F32, /* .rhs_type = */ GGML_TYPE_Q4_0, /* .op_type = */ GGML_TYPE_F32, }, #endif #if defined(__ARM_FEATURE_MATMUL_INT8) { /* i8mm GEMM */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, /* .get_packed_offset_ex = */ &lhs_offs_fn6, /* .packed_size_ex = */ &lhs_ps_fn6, /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* i8mm GEMV */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_packed_offset_ex = */ &lhs_offs_fn6, /* .packed_size_ex = */ &lhs_ps_fn6, /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, /* .to_float = */ dequantize_row_qsi4c32pscalef16, /* .packed_size_ex = */ &rhs_ps_fn5, /* .packed_stride_ex = */ &rhs_stride_fn4, /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, /* .rhs_type = */ GGML_TYPE_Q4_0, /* .op_type = */ GGML_TYPE_F32, }, #endif #else #if defined(__ARM_FEATURE_MATMUL_INT8) { /* i8mm GEMM */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon, /* .get_packed_offset_ex = */ &lhs_offs_fn6, /* .packed_size_ex = */ &lhs_ps_fn6, /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* i8mm GEMV */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_packed_offset_ex = */ &lhs_offs_fn6, /* .packed_size_ex = */ &lhs_ps_fn6, /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, /* .to_float = */ dequantize_row_qsi4c32pscalef16, /* .packed_size_ex = */ &rhs_ps_fn5, /* .packed_stride_ex = */ &rhs_stride_fn4, /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, /* .rhs_type = */ GGML_TYPE_Q4_0, /* .op_type = */ GGML_TYPE_F32, }, #endif #if defined(__ARM_FEATURE_DOTPROD) { /* DOTPROD GEMM */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_packed_offset_ex = */ &lhs_offs_fn6, /* .packed_size_ex = */ &lhs_ps_fn6, /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* DOTPROD GEMV */ /* .kern_info = */ { /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_lhs_offset_ex = */ &kernel_offs_fn3, /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemv_lhs_info = */ { /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_packed_offset_ex = */ &lhs_offs_fn6, /* .packed_size_ex = */ &lhs_ps_fn6, /* .pack_func_ex = */ &lhs_pack_float_fn10, }, /* .rhs_info = */ { /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, /* .to_float = */ dequantize_row_qsi4c32pscalef16, /* .packed_size_ex = */ &rhs_ps_fn5, /* .packed_stride_ex = */ &rhs_stride_fn4, /* .pack_func_ex = */ &rhs_pack_fn12, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, /* .lhs_type = */ GGML_TYPE_F32, /* .rhs_type = */ GGML_TYPE_Q4_0, /* .op_type = */ GGML_TYPE_F32, }, #endif #endif }; ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) { ggml_kleidiai_kernels * kernel = nullptr; if (tensor->op == GGML_OP_MUL_MAT && tensor->src[0] != nullptr && tensor->src[1] != nullptr) { #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { if ((cpu_features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu && gemm_gemv_kernels[i].lhs_type == tensor->src[1]->type && gemm_gemv_kernels[i].rhs_type == tensor->src[0]->type && gemm_gemv_kernels[i].op_type == tensor->type) { kernel = &gemm_gemv_kernels[i]; break; } } #endif } return kernel; } ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features) { ggml_kleidiai_kernels * kernels = nullptr; #if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8) for (size_t i = 0; i < NELEMS(gemm_gemv_kernels); ++i) { if ((features & gemm_gemv_kernels[i].required_cpu) == gemm_gemv_kernels[i].required_cpu) { kernels = &gemm_gemv_kernels[i]; break; } } #endif return kernels; }