kleidiai: add optimized per-channel kernels for Q8_0 (#16993)

This commit is contained in:
Charles Xu
2025-11-11 12:20:31 +01:00
committed by GitHub
parent 4a5b8aff40
commit 8c583242ad
4 changed files with 538 additions and 41 deletions

View File

@@ -590,6 +590,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
${KLEIDIAI_SRC}/kai/ukernels/
${KLEIDIAI_SRC}/kai/ukernels/matmul/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
@@ -608,23 +609,34 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c)
if (NOT DOTPROD_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c)
endif()
if (NOT I8MM_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c)
endif()
if (NOT SME_ENABLED MATCHES -1)
list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c

View File

@@ -4,6 +4,7 @@
// KleidiAI micro-kernels
#include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
#include "kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
@@ -11,20 +12,31 @@
#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
#include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32.h"
#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
#include "kai_lhs_quant_pack_qai8dxp_f32.h"
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
#include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
#include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"
#include "kai_common.h"
#include "simd-mappings.h"
#define GGML_COMMON_DECL_CPP
#include "ggml-common.h"
#include "kernels.h"
#define NELEMS(x) sizeof(x) / sizeof(*x)
@@ -55,6 +67,14 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
}
template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
const void* lhs, const void* rhs, void* dst,
size_t dst_stride_row, size_t dst_stride_col,
float clamp_min, float clamp_max) {
Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
}
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
return Fn(m, k, bl, mr, kr, sr);
@@ -93,6 +113,12 @@ static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t m
Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
static inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed) {
Fn(m, k, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
}
template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
return Fn(n, k, nr, kr, bl);
@@ -124,6 +150,18 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
void* rhs_packed, size_t extra_bytes, const void* params) {
Fn(num_groups, n, k, nr, kr, sr,
static_cast<const int8_t*>(rhs),
static_cast<const float*>(bias),
static_cast<const float*>(scale),
rhs_packed, extra_bytes,
static_cast<const kai_rhs_pack_qsi8cx_params*>(params));
}
template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
@@ -213,6 +251,57 @@ static void dequantize_row_qsi4c32ps1s0scalef16(
GGML_UNUSED(kr);
}
static void dequantize_row_qsi8cxp(
const void *packed_data,
int32_t row_idx,
int64_t k,
float *out,
size_t nr,
size_t packed_row_stride,
size_t kr,
size_t bl,
size_t num_bytes_multiplier
) {
GGML_UNUSED(bl);
GGML_UNUSED(num_bytes_multiplier);
const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;
const size_t group_idx = row_idx / nr;
const size_t row_in_group = row_idx % nr;
const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;
const int8_t * data_base = reinterpret_cast<const int8_t *>(group_ptr);
const size_t num_blocks = k_internal / kr;
for (size_t block = 0; block < num_blocks; ++block) {
const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;
for (size_t i = 0; i < kr; ++i) {
const size_t k_idx = block * kr + i;
if (k_idx < (size_t) k) {
out[k_idx] = static_cast<float>(block_ptr[i]);
}
}
}
const uint8_t * sums_ptr = group_ptr + nr * k_internal;
GGML_UNUSED(sums_ptr);
const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));
const float scale = scale_ptr[row_in_group];
if (scale == 0.0f) {
for (size_t i = 0; i < (size_t) k; ++i) {
out[i] = 0.0f;
}
return;
}
for (size_t i = 0; i < (size_t) k; ++i) {
out[i] *= scale;
}
}
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
#if defined(__ARM_FEATURE_SME)
{
@@ -548,6 +637,174 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
#endif
};
static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
#if defined(__ARM_FEATURE_SME)
{
/* SME GEMM */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* SME GEMV */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* .rhs_info = */ {
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
/* .to_float = */ dequantize_row_qsi8cxp,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
},
/* .required_cpu = */ CPU_FEATURE_SME,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q8_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#if defined(__ARM_FEATURE_MATMUL_INT8)
{
/* I8MM GEMM */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* I8MM GEMV (dotprod fallback) */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* .rhs_info = */ {
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
/* .to_float = */ dequantize_row_qsi8cxp,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q8_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
#if defined(__ARM_FEATURE_DOTPROD)
{
/* DOTPROD GEMM */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
},
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* DOTPROD GEMV */
{
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
/* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
/* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
/* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
},
/* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
/* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
/* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
/* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
},
/* .rhs_info = */ {
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
/* .to_float = */ dequantize_row_qsi8cxp,
/* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
/* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
/* .lhs_type = */ GGML_TYPE_F32,
/* .rhs_type = */ GGML_TYPE_Q8_0,
/* .op_type = */ GGML_TYPE_F32,
},
#endif
};
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
ggml_kleidiai_kernels * kernel = nullptr;
@@ -562,6 +819,17 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
break;
}
}
if (!kernel) {
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
gemm_gemv_kernels_q8[i].op_type == tensor->type) {
kernel = &gemm_gemv_kernels_q8[i];
break;
}
}
}
#endif
}
@@ -582,3 +850,18 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)
return kernels;
}
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {
ggml_kleidiai_kernels * kernels = nullptr;
#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
kernels = &gemm_gemv_kernels_q8[i];
break;
}
}
#endif
return kernels;
}

View File

@@ -87,3 +87,4 @@ struct ggml_kleidiai_kernels {
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features);

View File

@@ -5,10 +5,13 @@
#include <assert.h>
#include <atomic>
#include <cfloat>
#include <cmath>
#include <algorithm>
#include <stdexcept>
#include <stdint.h>
#include <string.h>
#include <string>
#include <vector>
#if defined(__linux__)
#include <asm/hwcap.h>
#include <sys/auxv.h>
@@ -38,8 +41,9 @@
struct ggml_kleidiai_context {
cpu_feature features;
ggml_kleidiai_kernels * kernels;
} static ctx = { CPU_FEATURE_NONE, NULL };
ggml_kleidiai_kernels * kernels_q4;
ggml_kleidiai_kernels * kernels_q8;
} static ctx = { CPU_FEATURE_NONE, NULL, NULL };
static const char* cpu_feature_to_string(cpu_feature f) {
switch (f) {
@@ -73,10 +77,14 @@ static void init_kleidiai_context(void) {
if (sme_enabled != 0) {
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
}
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
#ifndef NDEBUG
if (ctx.kernels) {
GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
if (ctx.kernels_q4) {
GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
}
if (ctx.kernels_q8) {
GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
}
#endif
}
@@ -130,6 +138,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
if (kernels->rhs_type == GGML_TYPE_Q4_0) {
if (!lhs_info->packed_size_ex) return false;
size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
if (!lhs_info->packed_size_ex) return false;
size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_F16) {
if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
const int64_t lhs_batch_size0 = op->src[1]->ne[2];
@@ -149,11 +160,13 @@ class tensor_traits : public ggml::cpu::tensor_traits {
if (dst->op == GGML_OP_MUL_MAT) {
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
return compute_forward_q4_0(params, dst);
} else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
return compute_forward_q8_0(params, dst);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
return compute_forward_fp16(params, dst);
}
} else if (dst->op == GGML_OP_GET_ROWS) {
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
return compute_forward_get_rows(params, dst);
}
}
@@ -400,19 +413,120 @@ class tensor_traits : public ggml::cpu::tensor_traits {
return true;
}
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
if (!ctx.kernels) {
return false;
}
bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
kernel_info * kernel = &ctx.kernels->gemm;
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
if (!kernels) {
return false;
}
bool is_gemv = src1->ne[1] == 1;
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
!kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
return false;
}
const int ith = params->ith;
const int nth_raw = params->nth;
const int nth = nth_raw > 0 ? nth_raw : 1;
const size_t k = ne00;
const size_t m = ne11;
const size_t n = ne01;
size_t mr = kernel->get_mr();
size_t kr = kernel->get_kr();
size_t sr = kernel->get_sr();
const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
const size_t n_step = kernel->get_n_step();
const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
const size_t n_start = ith * num_n_per_thread;
size_t n_to_process = 0;
if (n_start < n) {
n_to_process = num_n_per_thread;
if ((n_start + n_to_process) > n) {
n_to_process = n - n_start;
}
}
const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
const size_t m_start = ith * num_m_per_thread;
size_t m_to_process = num_m_per_thread;
if ((m_start + m_to_process) > m) {
m_to_process = m - m_start;
}
if (m_start < m) {
const size_t src_stride = src1->nb[1];
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
}
ggml_barrier(params->threadpool);
const size_t dst_stride = dst->nb[1];
const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
float * dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
if (n_to_process > 0) {
kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
sizeof(float), -FLT_MAX, FLT_MAX);
}
return true;
}
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
ggml_kleidiai_kernels * kernels = nullptr;
size_t block_len = 0;
size_t num_bytes_multiplier = 0;
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
if (!ctx.kernels_q4) {
return false;
}
kernels = ctx.kernels_q4;
block_len = QK4_0;
num_bytes_multiplier = sizeof(uint16_t);
} else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
if (!ctx.kernels_q8) {
return false;
}
kernels = ctx.kernels_q8;
block_len = QK8_0;
num_bytes_multiplier = sizeof(float);
} else {
return false;
}
rhs_packing_info * rhs_info = &kernels->rhs_info;
kernel_info * kernel = &kernels->gemm;
if (!rhs_info->to_float || !kernel->get_nr) {
return false;
}
@@ -423,8 +537,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const size_t block_rows = kernel->get_nr();
const size_t kr = kernel->get_kr();
const size_t num_bytes_multiplier = sizeof(uint16_t);
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);
const int ith = params->ith;
const int nth = params->nth;
@@ -439,7 +552,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
float *out = (float *)((char *)dst->data + i * nb1);
rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
}
return true;
@@ -447,21 +560,91 @@ class tensor_traits : public ggml::cpu::tensor_traits {
public:
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
GGML_ASSERT(ctx.kernels);
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
size_t nr = ctx.kernels->gemm.get_nr();
size_t kr = ctx.kernels->gemm.get_kr();
size_t sr = ctx.kernels->gemm.get_sr();
struct kai_rhs_pack_qs4cxs1s0_param params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, &params);
if (tensor->type == GGML_TYPE_Q4_0) {
if (!ctx.kernels_q4) {
return -1;
}
size_t nr = ctx.kernels_q4->gemm.get_nr();
size_t kr = ctx.kernels_q4->gemm.get_kr();
size_t sr = ctx.kernels_q4->gemm.get_sr();
struct kai_rhs_pack_qs4cxs1s0_param params;
params.lhs_zero_point = 1;
params.rhs_zero_point = 8;
ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
static_cast<const uint8_t *>(data),
nullptr, nullptr, tensor->data, 0, &params);
GGML_UNUSED(data_size);
return 0;
} else if (tensor->type == GGML_TYPE_Q8_0) {
if (!ctx.kernels_q8) {
return -1;
}
const size_t row_stride = tensor->nb[1];
const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;
std::vector<int8_t> qdata(n * k, 0);
std::vector<float> scales(n, 0.0f);
for (size_t row = 0; row < n; ++row) {
const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
static_cast<const uint8_t *>(data) + row * row_stride);
float max_abs = 0.0f;
for (size_t block = 0; block < k_blocks; ++block) {
const block_q8_0 & blk = row_blocks[block];
const float d = GGML_FP16_TO_FP32(blk.d);
for (size_t l = 0; l < QK8_0; ++l) {
const size_t linear_idx = block * QK8_0 + l;
if (linear_idx >= k) {
break;
}
const float value = d * blk.qs[l];
max_abs = std::max(max_abs, std::fabs(value));
}
}
float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
scales[row] = scale;
const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
for (size_t block = 0; block < k_blocks; ++block) {
const block_q8_0 & blk = row_blocks[block];
const float d = GGML_FP16_TO_FP32(blk.d);
for (size_t l = 0; l < QK8_0; ++l) {
const size_t linear_idx = block * QK8_0 + l;
if (linear_idx >= k) {
break;
}
const float value = d * blk.qs[l];
int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
q = std::clamp(q, -127, 127);
qdata[row * k + linear_idx] = static_cast<int8_t>(q);
}
}
}
size_t nr = ctx.kernels_q8->gemm.get_nr();
size_t kr = ctx.kernels_q8->gemm.get_kr();
size_t sr = ctx.kernels_q8->gemm.get_sr();
struct kai_rhs_pack_qsi8cx_params params;
params.lhs_zero_point = 1;
params.scale_multiplier = 1.0f;
ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
qdata.data(), nullptr, scales.data(),
tensor->data, 0, &params);
GGML_UNUSED(data_size);
return 0;
}
return 0;
GGML_UNUSED(data_size);
return -1;
}
};
@@ -518,27 +701,45 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
}
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
GGML_ASSERT(ctx.kernels);
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
const size_t nr = ctx.kernels->gemm.get_nr();
const size_t kr = ctx.kernels->gemm.get_kr();
return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0);
GGML_UNUSED(buft);
const size_t n = tensor->ne[1];
const size_t k = tensor->ne[0];
ggml_kleidiai_kernels * kernels = nullptr;
size_t block_len = 0;
if (tensor->type == GGML_TYPE_Q4_0) {
GGML_ASSERT(ctx.kernels_q4);
kernels = ctx.kernels_q4;
block_len = QK4_0;
} else if (tensor->type == GGML_TYPE_Q8_0) {
GGML_ASSERT(ctx.kernels_q8);
kernels = ctx.kernels_q8;
block_len = QK8_0;
} else {
return 0;
}
const size_t nr = kernels->gemm.get_nr();
const size_t kr = kernels->gemm.get_kr();
const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
const size_t raw = ggml_nbytes(tensor);
return packed > raw ? packed : raw;
}
namespace ggml::cpu::kleidiai {
class extra_buffer_type : ggml::cpu::extra_buffer_type {
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
op->src[0]->type == GGML_TYPE_Q4_0 &&
(op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
op->src[0]->buffer &&
(ggml_n_dims(op->src[0]) == 2) &&
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
return false;
}
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false;
}