kleidiai: kernel interface refactoring (#16460)

This commit is contained in:
Charles Xu
2025-10-09 09:29:17 +02:00
committed by GitHub
parent b260213755
commit d80d6d2400
3 changed files with 292 additions and 213 deletions

View File

@@ -4,8 +4,6 @@
#pragma once
#include <functional>
#include <variant>
#include "ggml.h"
enum cpu_feature {
@@ -15,6 +13,7 @@ enum cpu_feature {
CPU_FEATURE_SVE = 4,
CPU_FEATURE_SME = 8
};
inline cpu_feature& operator|=(cpu_feature& lhs, cpu_feature rhs) {
lhs = static_cast<cpu_feature>(lhs | rhs);
return lhs;
@@ -30,63 +29,52 @@ struct kernel_info {
size_t (*get_nr)(void);
size_t (*get_kr)(void);
size_t (*get_sr)(void);
std::variant<
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
std::function<size_t(size_t m_idx, size_t k)>
> get_lhs_offset;
std::variant<
std::function<size_t(size_t n_idx, size_t k, size_t bl)>,
std::function<size_t(size_t n_idx, size_t k)>
> get_rhs_packed_offset;
size_t (*get_dst_offset)(size_t m_idx, size_t n_idx, size_t stride);
size_t (*get_dst_size)(size_t m, size_t n);
std::variant<
std::function<void(size_t m, size_t n, size_t k, size_t bl, const void* lhs_packed, const void* rhs_packed,
float* dst, size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max)>,
std::function<void(size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, void* dst, size_t dst_stride_row,
size_t dst_stride_col, float clamp_min, float clamp_max)>
> run_kernel;
size_t (*get_lhs_offset_ex)(size_t m_idx, size_t k, size_t bl);
size_t (*get_rhs_packed_offset_ex)(size_t n_idx, size_t k, size_t bl);
void (*run_kernel_ex)(
size_t m, size_t n, size_t k, size_t bl,
const void* lhs_packed, const void* rhs_packed,
void* dst, size_t dst_stride_row, size_t dst_stride_col,
float clamp_min, float clamp_max);
};
struct lhs_packing_info {
size_t (*get_offset)(size_t m_idx, size_t lhs_stride);
std::variant<
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
std::function<size_t(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr)>
> get_packed_offset;
std::variant<
std::function<size_t(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr)>,
std::function<size_t(size_t m, size_t k, size_t mr, size_t kr, size_t sr)>
> packed_size;
std::variant<
std::function<void(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs,
size_t lhs_stride, void* lhs_packed)>,
std::function<void(size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const void* lhs, size_t lhs_stride,
void* lhs_packed)>
> pack_func;
size_t (*get_packed_offset_ex)(size_t m_idx, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
size_t (*packed_size_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr);
void (*pack_func_ex)(size_t m, size_t k, size_t bl, size_t mr, size_t kr, size_t sr,
size_t m_idx_start, const void * lhs, size_t lhs_stride, void * lhs_packed);
};
struct rhs_packing_info {
std::variant<
std::function<size_t(size_t n, size_t k, size_t nr, size_t kr, size_t bl)>,
std::function<size_t(size_t n, size_t k)>
> packed_size;
size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl);
std::variant<
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl, const uint8_t* rhs,
const float* bias, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_qs4cxs1s0_param* params)>,
std::function<void(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t rhs_stride, const void* rhs,
const void* bias, const void* scale, void* rhs_packed, size_t extra_bytes, const void* params)>
> pack_func;
void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride,
size_t kr, size_t bl, size_t num_bytes_multiplier);
void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out,
size_t nr_pack, size_t packed_row_stride, size_t kr, size_t bl,
size_t num_bytes_multiplier);
size_t (*packed_size_ex)(size_t n, size_t k, size_t nr, size_t kr, size_t bl);
size_t (*packed_stride_ex)(size_t k, size_t nr, size_t kr, size_t bl);
void (*pack_func_ex)(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, size_t bl,
size_t rhs_stride, const void * rhs, const void * bias, const void * scale, void * rhs_packed, size_t extra_bytes, const void * params);
};
struct ggml_kleidiai_kernels {
kernel_info gemm;
kernel_info gemm;
lhs_packing_info gemm_lhs_info;
kernel_info gemv;
kernel_info gemv;
lhs_packing_info gemv_lhs_info;
rhs_packing_info rhs_info;