mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-14 11:07:10 +00:00
cpu: introduce chunking for flash attention (#16829)
Factor out the core FA loop into flash_atten_f16_one_chunk and add an outter loop on top that handles the chunks.
This commit is contained in:
@@ -7909,10 +7909,10 @@ void ggml_compute_forward_argsort(
|
|||||||
|
|
||||||
// ggml_compute_forward_flash_attn_ext
|
// ggml_compute_forward_flash_attn_ext
|
||||||
|
|
||||||
static void ggml_compute_forward_flash_attn_ext_f16(
|
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
|
||||||
const ggml_compute_params * params,
|
const ggml_compute_params * params,
|
||||||
ggml_tensor * dst) {
|
ggml_tensor * dst,
|
||||||
|
int ir0, int ir1) {
|
||||||
const ggml_tensor * q = dst->src[0];
|
const ggml_tensor * q = dst->src[0];
|
||||||
const ggml_tensor * k = dst->src[1];
|
const ggml_tensor * k = dst->src[1];
|
||||||
const ggml_tensor * v = dst->src[2];
|
const ggml_tensor * v = dst->src[2];
|
||||||
@@ -7928,9 +7928,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||||
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||||
|
|
||||||
const int ith = params->ith;
|
|
||||||
const int nth = params->nth;
|
|
||||||
|
|
||||||
const int64_t DK = nek0;
|
const int64_t DK = nek0;
|
||||||
const int64_t DV = nev0;
|
const int64_t DV = nev0;
|
||||||
const int64_t N = neq1;
|
const int64_t N = neq1;
|
||||||
@@ -7964,16 +7961,6 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
|
|
||||||
// parallelize by q rows using ggml_vec_dot_f32
|
// parallelize by q rows using ggml_vec_dot_f32
|
||||||
|
|
||||||
// total rows in q
|
|
||||||
const int nr = neq1*neq2*neq3;
|
|
||||||
|
|
||||||
// rows per thread
|
|
||||||
const int dr = (nr + nth - 1)/nth;
|
|
||||||
|
|
||||||
// row range for this thread
|
|
||||||
const int ir0 = dr*ith;
|
|
||||||
const int ir1 = MIN(ir0 + dr, nr);
|
|
||||||
|
|
||||||
float scale = 1.0f;
|
float scale = 1.0f;
|
||||||
float max_bias = 0.0f;
|
float max_bias = 0.0f;
|
||||||
float logit_softcap = 0.0f;
|
float logit_softcap = 0.0f;
|
||||||
@@ -8000,6 +7987,8 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
|
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
|
||||||
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
|
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
|
||||||
|
|
||||||
|
int ith = params->ith;
|
||||||
|
|
||||||
// loop over n_batch and n_head
|
// loop over n_batch and n_head
|
||||||
for (int ir = ir0; ir < ir1; ++ir) {
|
for (int ir = ir0; ir < ir1; ++ir) {
|
||||||
// q indices
|
// q indices
|
||||||
@@ -8147,6 +8136,91 @@ static void ggml_compute_forward_flash_attn_ext_f16(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_flash_attn_ext_f16(
|
||||||
|
const ggml_compute_params * params,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const ggml_tensor * q = dst->src[0];
|
||||||
|
const ggml_tensor * k = dst->src[1];
|
||||||
|
const ggml_tensor * v = dst->src[2];
|
||||||
|
|
||||||
|
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
|
||||||
|
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
|
||||||
|
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
|
||||||
|
GGML_TENSOR_LOCALS(size_t, nbk, k, nb)
|
||||||
|
GGML_TENSOR_LOCALS(int64_t, nev, v, ne)
|
||||||
|
GGML_TENSOR_LOCALS(size_t, nbv, v, nb)
|
||||||
|
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
|
||||||
|
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
|
||||||
|
|
||||||
|
const int64_t DK = nek0;
|
||||||
|
const int64_t DV = nev0;
|
||||||
|
const int64_t N = neq1;
|
||||||
|
|
||||||
|
GGML_ASSERT(ne0 == DV);
|
||||||
|
GGML_ASSERT(ne2 == N);
|
||||||
|
|
||||||
|
// input tensor rows must be contiguous
|
||||||
|
GGML_ASSERT(nbq0 == ggml_type_size(q->type));
|
||||||
|
GGML_ASSERT(nbk0 == ggml_type_size(k->type));
|
||||||
|
GGML_ASSERT(nbv0 == ggml_type_size(v->type));
|
||||||
|
|
||||||
|
GGML_ASSERT(neq0 == DK);
|
||||||
|
GGML_ASSERT(nek0 == DK);
|
||||||
|
GGML_ASSERT(nev0 == DV);
|
||||||
|
|
||||||
|
GGML_ASSERT(neq1 == N);
|
||||||
|
|
||||||
|
// dst cannot be transposed or permuted
|
||||||
|
GGML_ASSERT(nb0 == sizeof(float));
|
||||||
|
GGML_ASSERT(nb0 <= nb1);
|
||||||
|
GGML_ASSERT(nb1 <= nb2);
|
||||||
|
GGML_ASSERT(nb2 <= nb3);
|
||||||
|
|
||||||
|
// parallelize by q rows using ggml_vec_dot_f32
|
||||||
|
|
||||||
|
// total rows in q
|
||||||
|
const int64_t nr = neq1*neq2*neq3;
|
||||||
|
|
||||||
|
// rows per thread
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
// disable for NUMA
|
||||||
|
const bool disable_chunking = ggml_is_numa();
|
||||||
|
|
||||||
|
// 4x chunks per thread
|
||||||
|
int nth_scaled = nth * 4;
|
||||||
|
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
|
||||||
|
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
|
||||||
|
|
||||||
|
if (nth == 1 || nchunk < nth || disable_chunking) {
|
||||||
|
nchunk = nth;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ith == 0) {
|
||||||
|
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
|
||||||
|
ggml_threadpool_chunk_set(params->threadpool, nth);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_barrier(params->threadpool);
|
||||||
|
|
||||||
|
// The number of elements in each chunk
|
||||||
|
const int64_t dr = (nr + nchunk - 1) / nchunk;
|
||||||
|
|
||||||
|
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||||
|
int current_chunk = ith;
|
||||||
|
|
||||||
|
while (current_chunk < nchunk) {
|
||||||
|
const int64_t ir0 = dr * current_chunk;
|
||||||
|
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
|
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1);
|
||||||
|
|
||||||
|
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void ggml_compute_forward_flash_attn_ext(
|
void ggml_compute_forward_flash_attn_ext(
|
||||||
const ggml_compute_params * params,
|
const ggml_compute_params * params,
|
||||||
ggml_tensor * dst) {
|
ggml_tensor * dst) {
|
||||||
|
|||||||
Reference in New Issue
Block a user