mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-15 11:17:31 +00:00
kleidiai: add support for get_rows (#14676)
* kleidiai: add support for get_rows * apply fixes based on code review * apply more fixes based on code review
This commit is contained in:
@@ -40,6 +40,17 @@ struct ggml_kleidiai_context {
|
||||
ggml_kleidiai_kernels * kernels;
|
||||
} static ctx = { CPU_FEATURE_NONE, NULL };
|
||||
|
||||
static const char* cpu_feature_to_string(cpu_feature f) {
|
||||
switch (f) {
|
||||
case CPU_FEATURE_NONE: return "NONE";
|
||||
case CPU_FEATURE_DOTPROD: return "DOTPROD";
|
||||
case CPU_FEATURE_I8MM: return "I8MM";
|
||||
case CPU_FEATURE_SVE: return "SVE";
|
||||
case CPU_FEATURE_SME: return "SME";
|
||||
default: return "UNKNOWN";
|
||||
}
|
||||
}
|
||||
|
||||
static void init_kleidiai_context(void) {
|
||||
|
||||
ggml_critical_section_start();
|
||||
@@ -62,6 +73,11 @@ static void init_kleidiai_context(void) {
|
||||
ctx.features |= ggml_cpu_has_sme() ? CPU_FEATURE_SME : CPU_FEATURE_NONE;
|
||||
}
|
||||
ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features);
|
||||
#ifndef NDEBUG
|
||||
if (ctx.kernels) {
|
||||
GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
ggml_critical_section_end();
|
||||
}
|
||||
@@ -102,6 +118,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1
|
||||
|
||||
class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override {
|
||||
if (op->op != GGML_OP_MUL_MAT) {
|
||||
return false;
|
||||
}
|
||||
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
|
||||
GGML_ASSERT(kernels);
|
||||
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm;
|
||||
@@ -135,6 +154,10 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||
return compute_forward_kv_cache(params, dst);
|
||||
}
|
||||
} else if (dst->op == GGML_OP_GET_ROWS) {
|
||||
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
||||
return compute_forward_get_rows(params, dst);
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
@@ -270,6 +293,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
}
|
||||
|
||||
bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
@@ -342,8 +367,49 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0);
|
||||
GGML_ASSERT(ctx.kernels);
|
||||
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
rhs_packing_info * rhs_info = &ctx.kernels->rhs_info;
|
||||
kernel_info * kernel = &ctx.kernels->gemm;
|
||||
|
||||
const int64_t nc = ne00;
|
||||
const int64_t nr = ggml_nelements(src1);
|
||||
|
||||
const size_t block_rows = kernel->get_nr();
|
||||
const size_t kr = kernel->get_kr();
|
||||
|
||||
const size_t num_bytes_multiplier = sizeof(uint16_t);
|
||||
const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0);
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int dr = (nr + nth - 1) / nth;
|
||||
const int ir0 = dr * ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int64_t i = ir0; i < ir1; ++i) {
|
||||
GGML_ASSERT(src1->type == GGML_TYPE_I32);
|
||||
int64_t row_idx = ((const int32_t *)src1->data)[i];
|
||||
GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]);
|
||||
|
||||
float *out = (float *)((char *)dst->data + i * nb1);
|
||||
rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier);
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
public:
|
||||
int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) {
|
||||
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
|
||||
GGML_ASSERT(ctx.kernels);
|
||||
const size_t n = tensor->ne[1];
|
||||
const size_t k = tensor->ne[0];
|
||||
@@ -351,17 +417,12 @@ public:
|
||||
size_t kr = ctx.kernels->gemm.get_kr();
|
||||
size_t sr = ctx.kernels->gemm.get_sr();
|
||||
|
||||
#ifndef NDEBUG
|
||||
const size_t repacked_size = variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
|
||||
GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!");
|
||||
#endif
|
||||
struct kai_rhs_pack_qs4cxs1s0_param params;
|
||||
params.lhs_zero_point = 1;
|
||||
params.rhs_zero_point = 8;
|
||||
variant_call<void>(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const uint8_t*)data, nullptr, tensor->data, 0, ¶ms);
|
||||
|
||||
return 0;
|
||||
|
||||
GGML_UNUSED(data_size);
|
||||
}
|
||||
};
|
||||
@@ -375,8 +436,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc
|
||||
static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
|
||||
tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor);
|
||||
|
||||
GGML_UNUSED(buffer);
|
||||
return GGML_STATUS_SUCCESS;
|
||||
GGML_UNUSED(buffer);
|
||||
}
|
||||
|
||||
static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor,
|
||||
@@ -418,18 +479,35 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
|
||||
GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0);
|
||||
GGML_ASSERT(ctx.kernels);
|
||||
|
||||
const size_t n = tensor->ne[1];
|
||||
const size_t k = tensor->ne[0];
|
||||
const size_t nr = ctx.kernels->gemm.get_nr();
|
||||
const size_t kr = ctx.kernels->gemm.get_kr();
|
||||
|
||||
return variant_call<size_t>(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0);
|
||||
|
||||
GGML_UNUSED(buft);
|
||||
}
|
||||
|
||||
namespace ggml::cpu::kleidiai {
|
||||
class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
||||
bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override {
|
||||
if (op->op == GGML_OP_MUL_MAT &&
|
||||
if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) &&
|
||||
op->src[0]->type == GGML_TYPE_Q4_0 &&
|
||||
op->src[0]->buffer &&
|
||||
(ggml_n_dims(op->src[0]) == 2) &&
|
||||
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
|
||||
if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type == GGML_TYPE_F32 &&
|
||||
if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) &&
|
||||
ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) {
|
||||
return true;
|
||||
}
|
||||
@@ -438,7 +516,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
||||
}
|
||||
|
||||
ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override {
|
||||
if (op->op == GGML_OP_MUL_MAT) {
|
||||
if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) {
|
||||
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
||||
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
||||
}
|
||||
@@ -469,7 +547,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) {
|
||||
/* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer,
|
||||
/* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment,
|
||||
/* .get_max_size = */ nullptr, // defaults to SIZE_MAX
|
||||
/* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes
|
||||
/* .get_alloc_size = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size,
|
||||
/* .is_host = */ nullptr,
|
||||
},
|
||||
/* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0),
|
||||
|
||||
Reference in New Issue
Block a user