mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
kleidiai: generalize compute_forward_kv_cache to compute_forward_fp16 (#15817)
This commit is contained in:
@@ -154,7 +154,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||||||
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
||||||
return compute_forward_q4_0(params, dst);
|
return compute_forward_q4_0(params, dst);
|
||||||
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
} else if (dst->src[0]->type == GGML_TYPE_F16) {
|
||||||
return compute_forward_kv_cache(params, dst);
|
return compute_forward_fp16(params, dst);
|
||||||
}
|
}
|
||||||
} else if (dst->op == GGML_OP_GET_ROWS) {
|
} else if (dst->op == GGML_OP_GET_ROWS) {
|
||||||
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
|
||||||
@@ -164,7 +164,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
|
|||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
|
bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
|
||||||
static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
|
static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
|
||||||
|
|
||||||
const ggml_tensor * src0 = dst->src[0];
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
@@ -534,13 +534,8 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
|
|||||||
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
|
||||||
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
|
||||||
}
|
}
|
||||||
else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
|
else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
|
||||||
op->src[0]->op == GGML_OP_VIEW &&
|
if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
|
||||||
(op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
|
|
||||||
op->src[1]->ne[1] > 1) {
|
|
||||||
if ((op->src[0]->nb[0] != 2) ||
|
|
||||||
(op->src[1]->nb[0] != 4) ||
|
|
||||||
(op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
|
|
||||||
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
|
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user