mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-02 09:12:03 +00:00
ggml : support bcast ggml_soft_max_ext, ggml_flash_attn_ext (#14435)
ggml-ci
This commit is contained in:
@@ -3666,9 +3666,11 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
||||
if (mask) {
|
||||
GGML_ASSERT(mask->type == GGML_TYPE_F16 || mask->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_contiguous(mask));
|
||||
GGML_ASSERT(ggml_is_matrix(mask));
|
||||
GGML_ASSERT(ggml_is_3d(mask));
|
||||
GGML_ASSERT(mask->ne[0] == a->ne[0]);
|
||||
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
|
||||
GGML_ASSERT(a->ne[2]%mask->ne[2] == 0);
|
||||
GGML_ASSERT(a->ne[3]%mask->ne[3] == 0);
|
||||
}
|
||||
|
||||
if (max_bias > 0.0f) {
|
||||
@@ -4689,13 +4691,17 @@ struct ggml_tensor * ggml_flash_attn_ext(
|
||||
GGML_ASSERT(ggml_can_mul_mat(k, q));
|
||||
// TODO: check if vT can be multiplied by (k*qT)
|
||||
|
||||
GGML_ASSERT(q->ne[3] == k->ne[3]);
|
||||
GGML_ASSERT(q->ne[3] == v->ne[3]);
|
||||
|
||||
if (mask) {
|
||||
GGML_ASSERT(ggml_is_contiguous(mask));
|
||||
GGML_ASSERT(mask->ne[2] == 1);
|
||||
GGML_ASSERT(mask->ne[3] == 1);
|
||||
GGML_ASSERT(mask->ne[2] == q->ne[3]);
|
||||
GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
|
||||
"the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
|
||||
//GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
|
||||
|
||||
GGML_ASSERT(q->ne[3] % mask->ne[2] == 0);
|
||||
}
|
||||
|
||||
if (max_bias > 0.0f) {
|
||||
|
||||
Reference in New Issue
Block a user