mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-17 11:37:10 +00:00
50 lines
1.7 KiB
Plaintext
50 lines
1.7 KiB
Plaintext
#include "common.cuh"
|
|
#include "fattn-tile.cuh"
|
|
#include "fattn-wmma-f16.cuh"
|
|
|
|
// Host-side dispatcher for the tile flash-attention path.
//
// Reads K from dst->src[1] and V from dst->src[2], then uses K->ne[0] (the "head size"
// per the abort message below) to pick the matching compile-time instantiation of
// ggml_cuda_flash_attn_ext_tile_case<K size, V size>. Before each call, V->ne[0] is
// asserted to equal the V size the instantiation was compiled for. A K->ne[0] that is
// not listed here aborts.
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * K = dst->src[1];
    const ggml_tensor * V = dst->src[2];

    // Template arguments must be compile-time constants, so every supported size
    // needs its own explicit case; each one returns as soon as its call completes.
    switch (K->ne[0]) {
        case 40:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_cuda_flash_attn_ext_tile_case< 40,  40>(ctx, dst);
            return;

        case 64:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_cuda_flash_attn_ext_tile_case< 64,  64>(ctx, dst);
            return;

        case 72:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_cuda_flash_attn_ext_tile_case< 72,  72>(ctx, dst);
            return;

        case 80:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_cuda_flash_attn_ext_tile_case< 80,  80>(ctx, dst);
            return;

        case 96:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_cuda_flash_attn_ext_tile_case< 96,  96>(ctx, dst);
            return;

        case 112:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst);
            return;

        case 128:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
            return;

        case 256:
            GGML_ASSERT(V->ne[0] == K->ne[0]);
            ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
            return;

        case 576:
            // The one supported combination where K and V sizes differ.
            GGML_ASSERT(V->ne[0] == 512);
            ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
            return;

        default:
            break;
    }

    // Only reached when no case above matched K->ne[0].
    GGML_ABORT("Unsupported head size");
}
|