mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-20 12:07:33 +00:00
CUDA: broadcasting for FlashAttention mask (#14500)
This commit is contained in:
committed by
Georgi Gerganov
parent
8875523eb3
commit
12a81af45f
@@ -27,7 +27,9 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const int ne12,
|
||||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
@@ -68,7 +70,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
K += nb12*(blockIdx.z / gqa_ratio);
|
||||
V += nb22*(blockIdx.z / gqa_ratio);
|
||||
|
||||
const half * maskh = (const half *) mask + ne11*ic0;
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
@@ -342,8 +344,8 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
|
||||
Reference in New Issue
Block a user