kleidiai: add optimized per-channel kernels for Q8_0 (#16993)
This commit adds KleidiAI per-channel int8 matmul support for Q8_0 weights to llama.cpp's CPU backend: it compiles the new qai8dxp x qsi8cxp micro-kernels (DOTPROD, I8MM and SME2 variants), registers them in a second kernel-selection table, repacks Q8_0 tensors into KleidiAI's per-channel layout, and extends the buffer-type and op-support plumbing to dispatch on both Q4_0 and Q8_0.
@@ -590,6 +590,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                     ${KLEIDIAI_SRC}/kai/ukernels/
                     ${KLEIDIAI_SRC}/kai/ukernels/matmul/
                     ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
+                    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/
                     ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/
                     ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)

@@ -608,23 +609,34 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
                     ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
                     ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
                     ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
-                    ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
+                    ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c
+                    ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
+                    ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.c)

                 if (NOT DOTPROD_ENABLED MATCHES -1)
                     list(APPEND GGML_KLEIDIAI_SOURCES
                         ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.c
                         ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.c
-                        ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c)
+                        ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.c
+                        ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.c
+                        ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.c
+                        ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.c)
                 endif()

                 if (NOT I8MM_ENABLED MATCHES -1)
-                    list(APPEND GGML_KLEIDIAI_SOURCES ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c)
+                    list(APPEND GGML_KLEIDIAI_SOURCES
+                        ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.c
+                        ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.c)
                 endif()

                 if (NOT SME_ENABLED MATCHES -1)
                     list(APPEND GGML_KLEIDIAI_SOURCES
                         ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c
                         ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
+                        ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c
+                        ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S
+                        ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.c
+                        ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S
                         ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
                         ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S
                         ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
@@ -4,6 +4,7 @@

 // KleidiAI micro-kernels
 #include "kai_matmul_clamp_f32_qsi8d32p_qsi4c32p_interface.h"
+#include "kai_matmul_clamp_f32_qai8dxp_qsi8cxp_interface.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h"
 #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h"
@@ -11,20 +12,31 @@
 #include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h"
 #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h"
 #include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h"
+#include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod.h"
+#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h"

 #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
 #include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
 #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
+#include "kai_lhs_quant_pack_qai8dxp_f32.h"

 #include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
 #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h"
 #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h"
+#include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h"

 #include "kai_common.h"

 #include "simd-mappings.h"

+#define GGML_COMMON_DECL_CPP
+#include "ggml-common.h"
+
 #include "kernels.h"

 #define NELEMS(x) sizeof(x) / sizeof(*x)
@@ -55,6 +67,14 @@ static inline void kernel_run_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
     Fn(m, n, k, lhs, rhs, dst, dst_stride_row, dst_stride_col, clamp_min, clamp_max);
 }

+template<void(*Fn)(size_t,size_t,size_t,const void*,const void*,float*,size_t,size_t,float,float)>
+static inline void kernel_run_float_fn10(size_t m, size_t n, size_t k, size_t /*bl*/,
+                                         const void* lhs, const void* rhs, void* dst,
+                                         size_t dst_stride_row, size_t dst_stride_col,
+                                         float clamp_min, float clamp_max) {
+    Fn(m, n, k, lhs, rhs, static_cast<float*>(dst), dst_stride_row, dst_stride_col, clamp_min, clamp_max);
+}
+
 template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t)>
 static inline size_t lhs_ps_fn6(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr) {
     return Fn(m, k, bl, mr, kr, sr);
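The `_ex` wrappers added in this file all follow one pattern: the concrete KleidiAI entry point is baked in as a non-type template parameter, and the wrapper adapts its strongly typed signature (here, a `float*` destination) to the generic `void*`-based slot stored in the kernel tables, ignoring arguments such as `bl` that per-channel kernels do not take. A minimal standalone sketch of the same pattern, with hypothetical names:

    #include <cstddef>

    // A "real" micro-kernel with a typed destination, standing in for the
    // KleidiAI C entry points (hypothetical, for illustration only).
    static void my_kernel(size_t m, float * dst) {
        for (size_t i = 0; i < m; ++i) dst[i] = 0.0f;
    }

    // Generic slot type stored in the dispatch tables: everything is void*.
    using run_fn = void (*)(size_t, void *);

    // Adapter: the concrete kernel is a compile-time constant, so the cast
    // is the only code the wrapper adds, and it inlines away.
    template <void (*Fn)(size_t, float *)>
    static void run_adapter(size_t m, void * dst) {
        Fn(m, static_cast<float *>(dst));
    }

    int main() {
        float buf[4];
        run_fn entry = &run_adapter<my_kernel>;  // what a table row stores
        entry(4, buf);
        return 0;
    }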
@@ -93,6 +113,12 @@ static inline void lhs_pack_void_fn9(size_t m, size_t k, size_t /*bl*/, size_t m
     Fn(m, k, mr, kr, sr, m_idx_start, lhs, lhs_stride, lhs_packed);
 }

+template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const float*,size_t,void*)>
+static inline void lhs_pack_float_fn9_no_bl(size_t m, size_t k, size_t /*bl*/, size_t mr, size_t kr, size_t sr,
+                                            size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed) {
+    Fn(m, k, mr, kr, sr, m_idx_start, static_cast<const float*>(lhs), lhs_stride, lhs_packed);
+}
+
 template<size_t(*Fn)(size_t,size_t,size_t,size_t,size_t)>
 static inline size_t rhs_ps_fn5(size_t n, size_t k, size_t nr, size_t kr, size_t bl) {
     return Fn(n, k, nr, kr, bl);
@@ -124,6 +150,18 @@ static inline void rhs_pack_fn12(size_t num_groups, size_t n, size_t k, size_t n
                       static_cast<const kai_rhs_pack_qs4cxs1s0_param*>(params));
 }

+template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,const int8_t*,const float*,const float*,void*,size_t,const struct kai_rhs_pack_qsi8cx_params*)>
+static inline void rhs_pack_scale_fn12(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
+                                       size_t /*rhs_stride*/, const void* rhs, const void* bias, const void* scale,
+                                       void* rhs_packed, size_t extra_bytes, const void* params) {
+    Fn(num_groups, n, k, nr, kr, sr,
+       static_cast<const int8_t*>(rhs),
+       static_cast<const float*>(bias),
+       static_cast<const float*>(scale),
+       rhs_packed, extra_bytes,
+       static_cast<const kai_rhs_pack_qsi8cx_params*>(params));
+}
+
 template<void(*Fn)(size_t,size_t,size_t,size_t,size_t,size_t,size_t,const void*,const void*,const void*,void*,size_t,const void*)>
 static inline void rhs_pack_fn13(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t /*bl*/,
                                  size_t rhs_stride, const void* rhs, const void* bias, const void* scale,
@@ -213,6 +251,57 @@ static void dequantize_row_qsi4c32ps1s0scalef16(
     GGML_UNUSED(kr);
 }

+static void dequantize_row_qsi8cxp(
+    const void *packed_data,
+    int32_t row_idx,
+    int64_t k,
+    float *out,
+    size_t nr,
+    size_t packed_row_stride,
+    size_t kr,
+    size_t bl,
+    size_t num_bytes_multiplier
+) {
+    GGML_UNUSED(bl);
+    GGML_UNUSED(num_bytes_multiplier);
+
+    const size_t k_internal = ((size_t) k + QK8_0 - 1) / QK8_0 * QK8_0;
+    const size_t group_idx = row_idx / nr;
+    const size_t row_in_group = row_idx % nr;
+
+    const uint8_t * group_ptr = static_cast<const uint8_t *>(packed_data) + group_idx * packed_row_stride;
+    const int8_t * data_base = reinterpret_cast<const int8_t *>(group_ptr);
+
+    const size_t num_blocks = k_internal / kr;
+
+    for (size_t block = 0; block < num_blocks; ++block) {
+        const int8_t * block_ptr = data_base + (block * nr + row_in_group) * kr;
+        for (size_t i = 0; i < kr; ++i) {
+            const size_t k_idx = block * kr + i;
+            if (k_idx < (size_t) k) {
+                out[k_idx] = static_cast<float>(block_ptr[i]);
+            }
+        }
+    }
+
+    const uint8_t * sums_ptr = group_ptr + nr * k_internal;
+    GGML_UNUSED(sums_ptr);
+
+    const float * scale_ptr = reinterpret_cast<const float *>(sums_ptr + nr * sizeof(int32_t));
+    const float scale = scale_ptr[row_in_group];
+
+    if (scale == 0.0f) {
+        for (size_t i = 0; i < (size_t) k; ++i) {
+            out[i] = 0.0f;
+        }
+        return;
+    }
+
+    for (size_t i = 0; i < (size_t) k; ++i) {
+        out[i] *= scale;
+    }
+}
+
 static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #if defined(__ARM_FEATURE_SME)
     {
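The walk performed by dequantize_row_qsi8cxp above implies the per-channel packed RHS layout: rows are grouped nr at a time, and each group stores nr * k_internal int8 values interleaved in runs of kr, then nr int32 row sums, then nr float per-row scales. A small sketch computing the offsets the function dereferences (shapes illustrative, layout inferred from the code above):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    // Per group of nr rows:
    // [ nr * k_internal int8 data | nr int32 row sums | nr float scales ]
    int main() {
        const size_t QK8_0 = 32;              // ggml Q8_0 block size
        const size_t k = 35, nr = 4, kr = 4;  // illustrative shapes
        const size_t k_internal = (k + QK8_0 - 1) / QK8_0 * QK8_0;  // 64: k padded up

        const size_t data_bytes    = nr * k_internal;                     // int8 payload
        const size_t sums_offset   = data_bytes;                          // row sums start here
        const size_t scales_offset = sums_offset + nr * sizeof(int32_t);  // then the scales

        // Element (row_in_group, k_idx) sits at a block-interleaved position:
        const size_t row_in_group = 2, k_idx = 10;
        const size_t block = k_idx / kr, i = k_idx % kr;
        const size_t value_offset = (block * nr + row_in_group) * kr + i;

        std::printf("sums @ %zu, scales @ %zu, value @ %zu\n",
                    sums_offset, scales_offset, value_offset);
        return 0;
    }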
@@ -548,6 +637,174 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
 #endif
 };

+static ggml_kleidiai_kernels gemm_gemv_kernels_q8[] = {
+#if defined(__ARM_FEATURE_SME)
+    {
+        /* SME GEMM */
+        {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa,
+            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
+            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa>,
+        },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* SME GEMV */
+        {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot,
+            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
+            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot>,
+        },
+        /* .gemv_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* .rhs_info = */ {
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
+            /* .to_float = */ dequantize_row_qsi8cxp,
+            /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+        },
+        /* .required_cpu = */ CPU_FEATURE_SME,
+        /* .lhs_type = */ GGML_TYPE_F32,
+        /* .rhs_type = */ GGML_TYPE_Q8_0,
+        /* .op_type = */ GGML_TYPE_F32,
+    },
+#endif
+#if defined(__ARM_FEATURE_MATMUL_INT8)
+    {
+        /* I8MM GEMM */
+        {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm,
+            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
+            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm>,
+        },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* I8MM GEMV (dotprod fallback) */
+        {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod,
+            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
+            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod>,
+        },
+        /* .gemv_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* .rhs_info = */ {
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
+            /* .to_float = */ dequantize_row_qsi8cxp,
+            /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+        },
+        /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
+        /* .lhs_type = */ GGML_TYPE_F32,
+        /* .rhs_type = */ GGML_TYPE_Q8_0,
+        /* .op_type = */ GGML_TYPE_F32,
+    },
+#endif
+#if defined(__ARM_FEATURE_DOTPROD)
+    {
+        /* DOTPROD GEMM */
+        {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod,
+            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
+            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod>,
+        },
+        /* .gemm_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* DOTPROD GEMV */
+        {
+            /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod,
+            /* .get_lhs_offset_ex = */ &kernel_offs_fn2<kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
+            /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn2<kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
+            /* .run_kernel_ex = */ &kernel_run_float_fn10<kai_run_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod>,
+        },
+        /* .gemv_lhs_info = */ {
+            /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32,
+            /* .get_packed_offset_ex = */ &lhs_offs_fn5<kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32>,
+            /* .packed_size_ex = */ &lhs_ps_fn5<kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32>,
+            /* .pack_func_ex = */ &lhs_pack_float_fn9_no_bl<kai_run_lhs_quant_pack_qai8dxp_f32>,
+        },
+        /* .rhs_info = */ {
+            /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon,
+            /* .to_float = */ dequantize_row_qsi8cxp,
+            /* .packed_size_ex = */ &rhs_ps_fn5<kai_get_rhs_packed_size_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .packed_stride_ex = */ &rhs_stride_fn4<kai_get_rhs_packed_stride_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+            /* .pack_func_ex = */ &rhs_pack_scale_fn12<kai_run_rhs_pack_nxk_qsi8cxp_qsi8cx_neon>,
+        },
+        /* .required_cpu = */ CPU_FEATURE_DOTPROD,
+        /* .lhs_type = */ GGML_TYPE_F32,
+        /* .rhs_type = */ GGML_TYPE_Q8_0,
+        /* .op_type = */ GGML_TYPE_F32,
+    },
+#endif
+};
+
 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor) {
     ggml_kleidiai_kernels * kernel = nullptr;

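Each entry in this table (like the q4 table above it) is plain data: function pointers harvested from one micro-kernel's generated getters, plus the `_ex` adapters, so run-time dispatch is just indirect calls through the selected row. A reduced sketch of that shape, with hypothetical slots and the same comment-per-field style as the source:

    #include <cstddef>
    #include <cstdio>

    struct kernel_entry {                   // reduced stand-in for ggml_kleidiai_kernels
        size_t (*get_nr)(void);             // tile-geometry getter
        void   (*run)(size_t m, size_t n);  // adapted kernel entry point
        int      required_cpu;              // feature mask checked at selection time
    };

    static size_t demo_get_nr(void)        { return 4; }
    static void   demo_run(size_t, size_t) {}

    static kernel_entry table[] = {
        {
            /* .get_nr       = */ demo_get_nr,
            /* .run          = */ demo_run,
            /* .required_cpu = */ 1,
        },
    };

    int main() {
        std::printf("nr = %zu\n", table[0].get_nr());
        return 0;
    }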
@@ -562,6 +819,17 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, c
                 break;
             }
         }
+        if (!kernel) {
+            for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
+                if ((cpu_features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu &&
+                    gemm_gemv_kernels_q8[i].lhs_type == tensor->src[1]->type &&
+                    gemm_gemv_kernels_q8[i].rhs_type == tensor->src[0]->type &&
+                    gemm_gemv_kernels_q8[i].op_type == tensor->type) {
+                    kernel = &gemm_gemv_kernels_q8[i];
+                    break;
+                }
+            }
+        }
 #endif
     }

@@ -582,3 +850,18 @@ ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features)

     return kernels;
 }
+
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features) {
+    ggml_kleidiai_kernels * kernels = nullptr;
+
+#if defined(__ARM_FEATURE_SME) || defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)
+    for (size_t i = 0; i < NELEMS(gemm_gemv_kernels_q8); ++i) {
+        if ((features & gemm_gemv_kernels_q8[i].required_cpu) == gemm_gemv_kernels_q8[i].required_cpu) {
+            kernels = &gemm_gemv_kernels_q8[i];
+            break;
+        }
+    }
+#endif
+
+    return kernels;
+}
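Note that gemm_gemv_kernels_q8 is ordered by preference (SME2, then I8MM, then DOTPROD) and both selection loops take the first entry whose full feature mask is satisfied: `(features & required) == required` is a subset test on the bitmask, so the I8MM entry, which requires CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, is skipped on a dotprod-only core. A minimal sketch of that check (flag values assumed to mirror the backend's cpu_feature enum):

    #include <cstdio>

    enum cpu_feature {                // bit flags, assumed layout
        CPU_FEATURE_NONE    = 0,
        CPU_FEATURE_DOTPROD = 1,
        CPU_FEATURE_I8MM    = 2,
        CPU_FEATURE_SVE     = 4,
        CPU_FEATURE_SME     = 8,
    };

    static bool satisfies(int features, int required) {
        return (features & required) == required;  // all required bits present
    }

    int main() {
        const int core = CPU_FEATURE_DOTPROD;      // dotprod-only core
        std::printf("i8mm entry usable:    %d\n",
                    satisfies(core, CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM));  // 0
        std::printf("dotprod entry usable: %d\n",
                    satisfies(core, CPU_FEATURE_DOTPROD));                     // 1
        return 0;
    }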
@@ -87,3 +87,4 @@ struct ggml_kleidiai_kernels {

 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels(cpu_feature cpu_features, const ggml_tensor * tensor);
 ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q4_0(cpu_feature features);
+ggml_kleidiai_kernels * ggml_kleidiai_select_kernels_q8_0(cpu_feature features);
@@ -5,10 +5,13 @@
 #include <assert.h>
 #include <atomic>
 #include <cfloat>
+#include <cmath>
+#include <algorithm>
 #include <stdexcept>
 #include <stdint.h>
 #include <string.h>
 #include <string>
+#include <vector>
 #if defined(__linux__)
 #include <asm/hwcap.h>
 #include <sys/auxv.h>
@@ -38,8 +41,9 @@

 struct ggml_kleidiai_context {
     cpu_feature features;
-    ggml_kleidiai_kernels * kernels;
-} static ctx = { CPU_FEATURE_NONE, NULL };
+    ggml_kleidiai_kernels * kernels_q4;
+    ggml_kleidiai_kernels * kernels_q8;
+} static ctx = { CPU_FEATURE_NONE, NULL, NULL };

 static const char* cpu_feature_to_string(cpu_feature f) {
     switch (f) {
@@ -73,10 +77,14 @@ static void init_kleidiai_context(void) {
         if (sme_enabled != 0) {
            ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
         }
-        ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+        ctx.kernels_q4 = ggml_kleidiai_select_kernels_q4_0(ctx.features);
+        ctx.kernels_q8 = ggml_kleidiai_select_kernels_q8_0(ctx.features);
 #ifndef NDEBUG
-        if (ctx.kernels) {
-            GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
+        if (ctx.kernels_q4) {
+            GGML_LOG_DEBUG("kleidiai: using q4 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q4->required_cpu));
+        }
+        if (ctx.kernels_q8) {
+            GGML_LOG_DEBUG("kleidiai: using q8 kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels_q8->required_cpu));
         }
 #endif
 }
@@ -130,6 +138,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         if (kernels->rhs_type == GGML_TYPE_Q4_0) {
             if (!lhs_info->packed_size_ex) return false;
             size = lhs_info->packed_size_ex(m, k, QK4_0, mr, kr, sr);
+        } else if (kernels->rhs_type == GGML_TYPE_Q8_0) {
+            if (!lhs_info->packed_size_ex) return false;
+            size = lhs_info->packed_size_ex(m, k, QK8_0, mr, kr, sr);
         } else if (kernels->rhs_type == GGML_TYPE_F16) {
             if (!lhs_info->packed_size_ex || !kernels->rhs_info.packed_size_ex) return false;
             const int64_t lhs_batch_size0 = op->src[1]->ne[2];
@@ -149,11 +160,13 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         if (dst->op == GGML_OP_MUL_MAT) {
             if (dst->src[0]->type == GGML_TYPE_Q4_0) {
                 return compute_forward_q4_0(params, dst);
+            } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
+                return compute_forward_q8_0(params, dst);
             } else if (dst->src[0]->type == GGML_TYPE_F16) {
                 return compute_forward_fp16(params, dst);
             }
         } else if (dst->op == GGML_OP_GET_ROWS) {
-            if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+            if (dst->src[0]->type == GGML_TYPE_Q4_0 || dst->src[0]->type == GGML_TYPE_Q8_0) {
                 return compute_forward_get_rows(params, dst);
             }
         }
@@ -400,19 +413,120 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         return true;
     }

-    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
-        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
-        if (!ctx.kernels) {
-            return false;
-        }
+    bool compute_forward_q8_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q8_0);

         const ggml_tensor * src0 = dst->src[0];
         const ggml_tensor * src1 = dst->src[1];

         GGML_TENSOR_BINARY_OP_LOCALS

-        rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
-        kernel_info * kernel = &ctx.kernels->gemm;
+        ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
+        if (!kernels) {
+            return false;
+        }
+
+        bool is_gemv = src1->ne[1] == 1;
+        kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
+        lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
+
+        if (!kernel || !lhs_info->get_packed_offset_ex || !lhs_info->pack_func_ex ||
+            !kernel->get_rhs_packed_offset_ex || !kernel->run_kernel_ex || !kernel->get_dst_offset) {
+            return false;
+        }
+
+        const int ith = params->ith;
+        const int nth_raw = params->nth;
+        const int nth = nth_raw > 0 ? nth_raw : 1;
+
+        const size_t k = ne00;
+        const size_t m = ne11;
+        const size_t n = ne01;
+
+        size_t mr = kernel->get_mr();
+        size_t kr = kernel->get_kr();
+        size_t sr = kernel->get_sr();
+
+        const uint8_t * lhs = static_cast<const uint8_t *>(src1->data);
+        uint8_t * lhs_packed = static_cast<uint8_t *>(params->wdata);
+        const uint8_t * rhs_packed = static_cast<const uint8_t *>(src0->data);
+
+        const size_t n_step = kernel->get_n_step();
+        const size_t num_n_per_thread = kai_roundup(kai_roundup(n, nth) / nth, n_step);
+        const size_t n_start = ith * num_n_per_thread;
+
+        size_t n_to_process = 0;
+        if (n_start < n) {
+            n_to_process = num_n_per_thread;
+            if ((n_start + n_to_process) > n) {
+                n_to_process = n - n_start;
+            }
+        }
+
+        const size_t num_m_per_thread = kai_roundup(m, mr * nth) / nth;
+        const size_t m_start = ith * num_m_per_thread;
+        size_t m_to_process = num_m_per_thread;
+        if ((m_start + m_to_process) > m) {
+            m_to_process = m - m_start;
+        }
+
+        if (m_start < m) {
+            const size_t src_stride = src1->nb[1];
+            const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(m_start, dst->src[1]->nb[1]));
+            const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(m_start, k, 0, mr, kr, sr);
+            void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
+
+            lhs_info->pack_func_ex(m_to_process, k, 0, mr, kr, sr, 0, src_ptr, src_stride, lhs_packed_ptr);
+        }
+
+        ggml_barrier(params->threadpool);
+
+        const size_t dst_stride = dst->nb[1];
+        const size_t lhs_packed_offset = lhs_info->get_packed_offset_ex(0, k, 0, mr, kr, sr);
+        const size_t rhs_packed_offset = kernel->get_rhs_packed_offset_ex(n_start, k, 0);
+        const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
+        const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
+        const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
+        float * dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
+
+        if (n_to_process > 0) {
+            kernel->run_kernel_ex(m, n_to_process, k, 0, lhs_ptr, rhs_ptr, dst_ptr, dst_stride,
+                                  sizeof(float), -FLT_MAX, FLT_MAX);
+        }
+
+        return true;
+    }
+
+    bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
+        const ggml_tensor * src0 = dst->src[0];
+        const ggml_tensor * src1 = dst->src[1];
+
+        GGML_TENSOR_BINARY_OP_LOCALS
+
+        ggml_kleidiai_kernels * kernels = nullptr;
+        size_t block_len = 0;
+        size_t num_bytes_multiplier = 0;
+
+        if (dst->src[0]->type == GGML_TYPE_Q4_0) {
+            if (!ctx.kernels_q4) {
+                return false;
+            }
+            kernels = ctx.kernels_q4;
+            block_len = QK4_0;
+            num_bytes_multiplier = sizeof(uint16_t);
+        } else if (dst->src[0]->type == GGML_TYPE_Q8_0) {
+            if (!ctx.kernels_q8) {
+                return false;
+            }
+            kernels = ctx.kernels_q8;
+            block_len = QK8_0;
+            num_bytes_multiplier = sizeof(float);
+        } else {
+            return false;
+        }
+
+        rhs_packing_info * rhs_info = &kernels->rhs_info;
+        kernel_info * kernel = &kernels->gemm;
         if (!rhs_info->to_float || !kernel->get_nr) {
             return false;
         }
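The q8_0 compute path above splits the work twice: LHS quantize-and-pack is parallelized over rows (m), while the matmul itself is parallelized over output columns (n) rounded to the kernel's n_step, with ggml_barrier between the two phases so every thread sees the fully packed LHS. A worked sketch of the n-split arithmetic, using a local round-up helper with the same meaning as kai_roundup and illustrative shapes:

    #include <cstddef>
    #include <cstdio>

    // Round a up to the next multiple of b (same semantics as kai_roundup).
    static size_t roundup(size_t a, size_t b) { return (a + b - 1) / b * b; }

    int main() {
        const size_t n = 100, nth = 4, n_step = 8;  // illustrative shapes
        const size_t num_n_per_thread = roundup(roundup(n, nth) / nth, n_step);  // 32

        for (size_t ith = 0; ith < nth; ++ith) {
            const size_t n_start = ith * num_n_per_thread;
            size_t n_to_process = 0;
            if (n_start < n) {
                n_to_process = num_n_per_thread;
                if (n_start + n_to_process > n) n_to_process = n - n_start;  // clamp tail
            }
            // Prints cols [0,32), [32,64), [64,96), [96,100)
            std::printf("thread %zu: cols [%zu, %zu)\n", ith, n_start, n_start + n_to_process);
        }
        return 0;
    }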
@@ -423,8 +537,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
         const size_t block_rows = kernel->get_nr();
         const size_t kr = kernel->get_kr();

-        const size_t num_bytes_multiplier = sizeof(uint16_t);
-        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
+        const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, block_len);

         const int ith = params->ith;
         const int nth = params->nth;
@@ -439,7 +552,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
             GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);

             float *out = (float *)((char *)dst->data + i * nb1);
-            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
+            rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, block_len, num_bytes_multiplier);
         }

         return true;
@@ -447,21 +560,91 @@ class tensor_traits : public ggml::cpu::tensor_traits {

 public:
     int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
-        GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
-        GGML_ASSERT(ctx.kernels);
         const size_t n = tensor->ne[1];
         const size_t k = tensor->ne[0];
-        size_t nr = ctx.kernels->gemm.get_nr();
-        size_t kr = ctx.kernels->gemm.get_kr();
-        size_t sr = ctx.kernels->gemm.get_sr();
-
-        struct kai_rhs_pack_qs4cxs1s0_param params;
-        params.lhs_zero_point = 1;
-        params.rhs_zero_point = 8;
-        ctx.kernels->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0, (const uint8_t*)data, nullptr, nullptr, tensor->data, 0, &params);
-
-        return 0;
-        GGML_UNUSED(data_size);
+
+        if (tensor->type == GGML_TYPE_Q4_0) {
+            if (!ctx.kernels_q4) {
+                return -1;
+            }
+            size_t nr = ctx.kernels_q4->gemm.get_nr();
+            size_t kr = ctx.kernels_q4->gemm.get_kr();
+            size_t sr = ctx.kernels_q4->gemm.get_sr();
+
+            struct kai_rhs_pack_qs4cxs1s0_param params;
+            params.lhs_zero_point = 1;
+            params.rhs_zero_point = 8;
+            ctx.kernels_q4->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, QK4_0, 0,
+                                                  static_cast<const uint8_t *>(data),
+                                                  nullptr, nullptr, tensor->data, 0, &params);
+            GGML_UNUSED(data_size);
+            return 0;
+        } else if (tensor->type == GGML_TYPE_Q8_0) {
+            if (!ctx.kernels_q8) {
+                return -1;
+            }
+
+            const size_t row_stride = tensor->nb[1];
+            const size_t k_blocks = (k + QK8_0 - 1) / QK8_0;
+
+            std::vector<int8_t> qdata(n * k, 0);
+            std::vector<float> scales(n, 0.0f);
+
+            for (size_t row = 0; row < n; ++row) {
+                const auto * row_blocks = reinterpret_cast<const block_q8_0 *>(
+                    static_cast<const uint8_t *>(data) + row * row_stride);
+
+                float max_abs = 0.0f;
+                for (size_t block = 0; block < k_blocks; ++block) {
+                    const block_q8_0 & blk = row_blocks[block];
+                    const float d = GGML_FP16_TO_FP32(blk.d);
+                    for (size_t l = 0; l < QK8_0; ++l) {
+                        const size_t linear_idx = block * QK8_0 + l;
+                        if (linear_idx >= k) {
+                            break;
+                        }
+                        const float value = d * blk.qs[l];
+                        max_abs = std::max(max_abs, std::fabs(value));
+                    }
+                }
+
+                float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;
+                scales[row] = scale;
+                const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;
+
+                for (size_t block = 0; block < k_blocks; ++block) {
+                    const block_q8_0 & blk = row_blocks[block];
+                    const float d = GGML_FP16_TO_FP32(blk.d);
+                    for (size_t l = 0; l < QK8_0; ++l) {
+                        const size_t linear_idx = block * QK8_0 + l;
+                        if (linear_idx >= k) {
+                            break;
+                        }
+                        const float value = d * blk.qs[l];
+                        int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(value * inv_scale)) : 0;
+                        q = std::clamp(q, -127, 127);
+                        qdata[row * k + linear_idx] = static_cast<int8_t>(q);
+                    }
+                }
+            }
+
+            size_t nr = ctx.kernels_q8->gemm.get_nr();
+            size_t kr = ctx.kernels_q8->gemm.get_kr();
+            size_t sr = ctx.kernels_q8->gemm.get_sr();
+
+            struct kai_rhs_pack_qsi8cx_params params;
+            params.lhs_zero_point = 1;
+            params.scale_multiplier = 1.0f;
+
+            ctx.kernels_q8->rhs_info.pack_func_ex(1, n, k, nr, kr, sr, 0, 0,
+                                                  qdata.data(), nullptr, scales.data(),
+                                                  tensor->data, 0, &params);
+            GGML_UNUSED(data_size);
+            return 0;
+        }
+
+        GGML_UNUSED(data_size);
+        return -1;
     }
 };
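The Q8_0 repack path above has to requantize, not just reorder: ggml's Q8_0 stores one fp16 scale per 32-value block, while the qsi8cxp kernels expect a single float scale per output channel (row). The code therefore dequantizes each row, derives one scale from the row-wide max-abs, and requantizes to the range +/-127. A compressed sketch of that per-row math, with illustrative data:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // One dequantized row (in the real code: d * qs[l] per Q8_0 block).
        const std::vector<float> row = { 0.5f, -5.08f, 2.0f, 1.27f };

        float max_abs = 0.0f;
        for (float v : row) max_abs = std::max(max_abs, std::fabs(v));

        const float scale = max_abs > 0.0f ? max_abs / 127.0f : 0.0f;  // one scale per channel
        const float inv_scale = scale > 0.0f ? 1.0f / scale : 0.0f;

        for (float v : row) {
            int32_t q = scale > 0.0f ? static_cast<int32_t>(std::lround(v * inv_scale)) : 0;
            q = std::clamp(q, -127, 127);
            std::printf("%+.2f -> %+d (reconstructs to %+.4f)\n", v, q, q * scale);
        }
        return 0;
    }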
@@ -518,27 +701,45 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
 }

 static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
-    GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
-    GGML_ASSERT(ctx.kernels);
+    GGML_UNUSED(buft);
+
     const size_t n = tensor->ne[1];
     const size_t k = tensor->ne[0];
-    const size_t nr = ctx.kernels->gemm.get_nr();
-    const size_t kr = ctx.kernels->gemm.get_kr();

-    return ctx.kernels->rhs_info.packed_size_ex(n, k, nr, kr, QK4_0);
+    ggml_kleidiai_kernels * kernels = nullptr;
+    size_t block_len = 0;

-    GGML_UNUSED(buft);
+    if (tensor->type == GGML_TYPE_Q4_0) {
+        GGML_ASSERT(ctx.kernels_q4);
+        kernels = ctx.kernels_q4;
+        block_len = QK4_0;
+    } else if (tensor->type == GGML_TYPE_Q8_0) {
+        GGML_ASSERT(ctx.kernels_q8);
+        kernels = ctx.kernels_q8;
+        block_len = QK8_0;
+    } else {
+        return 0;
+    }
+
+    const size_t nr = kernels->gemm.get_nr();
+    const size_t kr = kernels->gemm.get_kr();
+    const size_t packed = kernels->rhs_info.packed_size_ex(n, k, nr, kr, block_len);
+    const size_t raw = ggml_nbytes(tensor);
+
+    return packed > raw ? packed : raw;
 }

 namespace ggml::cpu::kleidiai {
 class extra_buffer_type : ggml::cpu::extra_buffer_type {
     bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
         if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
-            op->src[0]->type == GGML_TYPE_Q4_0 &&
+            (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q8_0) &&
             op->src[0]->buffer &&
             (ggml_n_dims(op->src[0]) == 2) &&
-            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
+            op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
+            if (((op->src[0]->type == GGML_TYPE_Q4_0) ? ctx.kernels_q4 : ctx.kernels_q8) == nullptr) {
+                return false;
+            }
             if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
                 return false;
             }
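One detail worth noting in get_alloc_size: it now returns the larger of the KleidiAI packed size and the tensor's plain ggml size, presumably so that any code sizing copies or bounds-checking by ggml_nbytes() stays valid even when the packed form is smaller than the raw tensor. A trivial sketch of that guard:

    #include <cstddef>

    // Packed and raw sizes for the same tensor can differ in either direction,
    // so the allocation must cover both (values hypothetical).
    static size_t alloc_size(size_t packed_size, size_t raw_nbytes) {
        return packed_size > raw_nbytes ? packed_size : raw_nbytes;
    }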