mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
CUDA: broadcasting for FlashAttention mask (#14500)
This commit is contained in:
committed by
Georgi Gerganov
parent
8875523eb3
commit
12a81af45f
@@ -32,7 +32,9 @@ typedef void (* fattn_kernel_t)(
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
@@ -851,7 +853,8 @@ void launch_fattn(
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
|
||||
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
|
||||
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
|
||||
Reference in New Issue
Block a user