kleidiai: add support for get_rows (#14676)

* kleidiai: add support for get_rows

* apply fixes based on code review

* apply more fixes based on code review
This commit is contained in:
Charles Xu
2025-07-21 15:49:52 +02:00
committed by GitHub
parent 2ba1333b35
commit 922042601b
4 changed files with 202 additions and 24 deletions

View File

@@ -22,9 +22,94 @@
#include "kai_common.h"
#include "simd-mappings.h"
#include "kernels.h"
#define NELEMS(x) sizeof(x) / sizeof(*x)
static const size_t INT4_PER_BYTE = 2;
static const size_t INT4_BITS = 4;
static const int Q4_0_ZERO_POINT = 8;
const size_t INT4_PER_UINT16 = 4;
static void dequantize_row_qsi4c32pscalef16(
const void *packed_data,
int32_t row_idx,
int64_t nc,
float *out,
size_t nr_pack,
size_t packed_row_stride,
size_t kr,
size_t bl,
size_t num_bytes_multiplier
) {
size_t group_idx = row_idx / nr_pack;
size_t row_in_group = row_idx % nr_pack;
const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
size_t num_blocks = nc / bl;
const uint8_t *block_ptr = packed_group;
for (size_t b = 0; b < num_blocks; ++b) {
uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier));
float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier;
size_t num_segments = bl / kr;
size_t num_bytes_per_segment = kr / INT4_PER_BYTE;
for (size_t s = 0; s < num_segments; ++s) {
const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment;
const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment;
for (size_t k = 0; k < num_bytes_per_segment; ++k) {
uint8_t byte = qbytes[k] ^ 0x88;
int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT;
int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT;
out[b * bl + s * num_bytes_per_segment + k] = x0 * scale;
out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale;
}
}
block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment;
}
}
static void dequantize_row_qsi4c32ps1s0scalef16(
const void *packed_data,
int32_t row_idx,
int64_t k,
float *out,
size_t nr,
size_t packed_row_stride,
size_t kr,
size_t bl,
size_t num_bytes_multiplier
) {
const size_t num_blocks = k / bl;
const size_t bl4 = bl / INT4_PER_UINT16;
size_t group_idx = row_idx / nr;
size_t row_in_group = row_idx % nr;
const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride;
const uint16_t *qdata = (const uint16_t *)packed_group;
const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier));
for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) {
uint16_t scale_f16 = scales[row_in_group + block_idx * nr];
float scale = GGML_CPU_FP16_TO_FP32(scale_f16);
for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) {
uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group];
for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) {
int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT;
out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale;
}
}
}
GGML_UNUSED(kr);
}
static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
#if defined(__ARM_FEATURE_SME)
{
@@ -63,8 +148,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon,
/* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16,
},
/* .required_cpu = */ CPU_FEATURE_SME,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -107,8 +194,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
/* .packed_stride = */ NULL,
/* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme,
/* .to_float = */ NULL,
},
/* .required_cpu = */ CPU_FEATURE_SME,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -154,8 +243,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -200,8 +291,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -247,8 +340,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM,
/* .lhs_type = */ GGML_TYPE_F32,
@@ -293,8 +388,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* .rhs_info = */ {
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0,
/* .to_float = */ dequantize_row_qsi4c32pscalef16,
},
/* .required_cpu = */ CPU_FEATURE_DOTPROD,
/* .lhs_type = */ GGML_TYPE_F32,