#include "common.cuh" #include "fattn-tile.cuh" #include "fattn-wmma-f16.cuh" void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; switch (K->ne[0]) { case 40: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst); } break; case 64: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst); } break; case 72: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case< 72, 72>(ctx, dst); } break; case 80: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst); } break; case 96: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst); } break; case 112: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst); } break; case 128: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst); } break; case 256: { GGML_ASSERT(V->ne[0] == K->ne[0]); ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); } break; case 576: { GGML_ASSERT(V->ne[0] == 512); ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); } break; default: { GGML_ABORT("Unsupported head size"); } break; } }