CUDA: faster tile FA, add oob checks, more HSs (#16492)

This commit is contained in:
Johannes Gäßler
2025-10-11 20:54:32 +02:00
committed by GitHub
parent a3cb04744f
commit 11f0af5504
18 changed files with 1358 additions and 784 deletions

View File

@@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
file(GLOB GGML_SOURCES_CUDA "*.cu")
file(GLOB SRCS "template-instances/fattn-tile*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/mmq*.cu")

View File

@@ -245,7 +245,8 @@ static bool fp16_available(const int cc) {
}
static bool fast_fp16_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc);
return GGML_CUDA_CC_IS_AMD(cc) ||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
}
// To be used for feature selection of external libraries, e.g. cuBLAS.
@@ -571,6 +572,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v,
}
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
// Important: do not use this function if dst and src both point at registers.
// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types.
// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions.
// If dst and src point at different address spaces then they are guaranteed to not be aliased.
template <int nbytes, int alignment = 0>
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
if constexpr (alignment != 0) {

View File

@@ -793,8 +793,6 @@ void launch_fattn(
GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
"the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int id = ggml_cuda_get_device();
@@ -878,7 +876,7 @@ void launch_fattn(
// Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
// multiple sequences of possibly different lengths.
if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
const int s31 = mask->nb[1] / sizeof(half2);
const int s33 = mask->nb[3] / sizeof(half2);
@@ -916,8 +914,7 @@ void launch_fattn(
dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
} else {
GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0);
const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size.
// parallel_blocks must not be larger than what the tensor size allows:
parallel_blocks = std::min(parallel_blocks, ntiles_KQ);
@@ -946,7 +943,7 @@ void launch_fattn(
blocks_num.x = ntiles_x;
blocks_num.y = parallel_blocks;
blocks_num.z = Q->ne[2]*Q->ne[3];
blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3];
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));

View File

@@ -1,756 +1,45 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile.cuh"
#include "fattn-wmma-f16.cuh"
// kq_stride == number of KQ rows to process per iteration
// kq_nbatch == number of K columns to load in parallel for KQ calculation
static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
if (GGML_CUDA_CC_IS_AMD(cc)) {
if (GGML_CUDA_CC_IS_RDNA(cc)) {
switch (D) {
case 64:
return 128;
case 128:
case 256:
return ncols <= 16 ? 128 : 64;
default:
GGML_ABORT("fatal error");
return -1;
}
}
switch (D) {
case 64:
return ncols == 32 ? 128 : 64;
case 128:
return ncols == 32 ? 64 : 32;
case 256:
return 32;
default:
GGML_ABORT("fatal error");
return -1;
}
}
if (fast_fp16_available(cc)) {
switch (D) {
case 64:
case 128:
case 256:
return ncols <= 16 ? 128 : 64;
default:
GGML_ABORT("fatal error");
return -1;
}
}
switch (D) {
case 64:
return ncols <= 16 ? 128 : 64;
case 128:
return ncols <= 16 ? 64 : 32;
case 256:
return 32;
default:
GGML_ABORT("fatal error");
return -1;
}
GGML_UNUSED(warp_size);
}
static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
#ifdef GGML_USE_HIP
#ifdef RDNA
switch (D) {
case 64:
return 128;
case 128:
case 256:
return ncols <= 16 ? 128 : 64;
default:
return -1;
}
#else
switch (D) {
case 64:
return ncols == 32 ? 128 : 64;
case 128:
return ncols == 32 ? 64 : 32;
case 256:
return 32;
default:
return -1;
}
#endif // RDNA
#else
#ifdef FAST_FP16_AVAILABLE
switch (D) {
case 64:
case 128:
case 256:
return ncols <= 16 ? 128 : 64;
default:
return -1;
}
#else
switch (D) {
case 64:
return ncols <= 16 ? 128 : 64;
case 128:
return ncols <= 16 ? 64 : 32;
case 256:
return 32;
default:
return -1;
}
#endif // FAST_FP16_AVAILABLE
#endif // GGML_USE_HIP
GGML_UNUSED_VARS(ncols, warp_size);
}
static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) {
#ifdef GGML_USE_HIP
switch (D) {
case 64:
return 64;
case 128:
case 256:
return 128;
default:
return -1;
}
#else
#ifdef FAST_FP16_AVAILABLE
switch (D) {
case 64:
return 64;
case 128:
case 256:
return 128;
default:
return -1;
}
#else
switch (D) {
case 64:
return 64;
case 128:
return 128;
case 256:
return ncols <= 16 ? 128 : 64;
default:
return -1;
}
#endif // FAST_FP16_AVAILABLE
#endif // GGML_USE_HIP
GGML_UNUSED_VARS(ncols, warp_size);
}
static int fattn_tile_get_nthreads_host(const int cc, const int ncols) {
return 256;
GGML_UNUSED_VARS(cc, ncols);
}
static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) {
return 256;
GGML_UNUSED(ncols);
}
static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) {
#ifdef RDNA
return 3;
#else
return ncols <= 16 ? 3 : 2;
#endif // RDNA
GGML_UNUSED(ncols);
}
template<int D, int ncols, bool use_logit_softcap> // D == head size
__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols))
static __global__ void flash_attn_tile(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
#ifdef GGML_USE_WMMA_FATTN
NO_DEVICE_CODE;
return;
#endif // GGML_USE_WMMA_FATTN
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
return;
}
constexpr int warp_size = 32;
constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size;
constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch");
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes();
constexpr int cpy_ne = cpy_nb / 4;
constexpr int cpw = ncols/nwarps; // cols per warp
// softmax_iter_j == number of KQ columns for which to calculate softmax in parallel.
// KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes.
#ifdef FAST_FP16_AVAILABLE
constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne;
__shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
__shared__ half2 Q_tmp[ncols][D/2];
__shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
#else
constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne;
__shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j];
__shared__ float Q_tmp[ncols][D];
__shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
#endif // FAST_FP16_AVAILABLE
static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j");
float KQ_max[cpw];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
KQ_max[j0/nwarps] = -FLT_MAX/2.0f;
}
float KQ_sum[cpw] = {0.0f};
// Load Q data, convert to FP16 if fast.
#pragma unroll
for (int j0 = 0; j0 < cpw; ++j0) {
const int j = j0 + threadIdx.y*cpw;
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
float tmp_f[cpy_ne_D] = {0.0f};
if (ic0 + j < ne01) {
ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]);
}
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
tmp_f[i1] *= scale;
}
#ifdef FAST_FP16_AVAILABLE
half2 tmp_h2[cpy_ne_D/2];
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) {
tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]);
}
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2);
#else
ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f);
#endif // FAST_FP16_AVAILABLE
}
}
__syncthreads();
// Main loop over KV cache:
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
// Calculate KQ tile and keep track of new maximum KQ values:
float KQ_max_new[cpw];
#pragma unroll
for (int j = 0; j < cpw; ++j) {
KQ_max_new[j] = KQ_max[j];
}
float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication.
// KQ = K @ Q matrix multiplication:
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y;
#ifdef FAST_FP16_AVAILABLE
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size);
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) {
ggml_cuda_memcpy_1<cpy_ne_kqnb*4>(
&KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb],
&K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]);
}
#else
constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size;
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) {
half2 tmp_h2[cpy_ne_kqnb/2];
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]);
float2 tmp_f2[cpy_ne_kqnb/2];
#pragma unroll
for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) {
tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]);
}
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
&KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2);
}
#endif // FAST_FP16_AVAILABLE
}
__syncthreads();
#ifdef FAST_FP16_AVAILABLE
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
half2 K_k[kq_stride/warp_size][cpy_ne];
half2 Q_k[cpw][cpy_ne];
#else
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
float K_k[kq_stride/warp_size][cpy_ne];
float Q_k[cpw][cpy_ne];
#endif // FAST_FP16_AVAILABLE
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#ifdef FAST_FP16_AVAILABLE
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
#else
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
#ifdef FAST_FP16_AVAILABLE
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
#else
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
#pragma unroll
for (int k = 0; k < cpy_ne; ++k) {
ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]);
}
}
}
}
if (k_KQ_0 + kq_nbatch < D) {
__syncthreads(); // Sync not needed on last iteration.
}
}
// Apply logit softcap, mask, update KQ_max:
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) {
const int j_KQ = j_KQ_0 + threadIdx.y*cpw;
if (use_logit_softcap) {
KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
}
KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]);
}
}
__syncthreads();
// Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators:
#pragma unroll
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
#ifdef FAST_FP16_AVAILABLE
half tmp[kq_stride/warp_size][softmax_iter_j];
#else
float tmp[kq_stride/warp_size][softmax_iter_j];
#endif // FAST_FP16_AVAILABLE
#pragma unroll
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]);
const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]);
KQ_max[j0+j1] = KQ_max_new[j0+j1];
float KQ_sum_add = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]);
KQ_sum_add += val;
tmp[i0/warp_size][j1] = val;
}
KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add;
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2;
}
#else
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale;
VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
const int i = i0 + threadIdx.x;
ggml_cuda_memcpy_1<sizeof(tmp[0])>(
KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]);
}
}
// VKQ = V @ KQ matrix multiplication:
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K.
static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
#pragma unroll
for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
#pragma unroll
for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
const int k_tile = k1 + threadIdx.y;
#ifdef FAST_FP16_AVAILABLE
constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
ggml_cuda_memcpy_1<cpy_ne_D*4>(
&KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D],
&V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]);
}
#else
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
half2 tmp_h2[cpy_ne_D/2];
ggml_cuda_memcpy_1<sizeof(tmp_h2)>(
tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]);
float2 tmp_f2[cpy_ne_D/2];
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
tmp_f2[i1] = __half22float2(tmp_h2[i1]);
}
ggml_cuda_memcpy_1<sizeof(tmp_f2)>(
&KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2);
}
#endif // FAST_FP16_AVAILABLE
}
__syncthreads();
#ifdef FAST_FP16_AVAILABLE
#pragma unroll
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
half2 V_k[(D/2)/warp_size];
half2 KQ_k[cpw];
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]);
}
#pragma unroll
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
half tmp[softmax_iter_j];
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>(
&tmp, KQ[j][k0 + k1]);
#pragma unroll
for (int j1 = 0; j1 < softmax_iter_j; ++j1) {
KQ_k[j0+j1] = __half2half2(tmp[j1]);
}
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
#pragma unroll
for (int j0 = 0; j0 < cpw; ++j0) {
VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0];
}
}
}
#else
#pragma unroll
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
float2 V_k[(D/2)/warp_size];
float KQ_k[cpw];
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]);
}
#pragma unroll
for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) {
const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j);
ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>(
&KQ_k[j0], KQ[j][k0 + k1]);
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
#pragma unroll
for (int j0 = 0; j0 < cpw; ++j0) {
VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0];
VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0];
}
}
}
#endif // FAST_FP16_AVAILABLE
__syncthreads();
}
}
// Attention sink: adjust running max and sum once per head
if (sinksf && blockIdx.y == 0) {
const float sink = sinksf[head];
#pragma unroll
for (int j0 = 0; j0 < cpw; ++j0) {
float KQ_max_new_j = fmaxf(KQ_max[j0], sink);
KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j);
const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j);
KQ_max[j0] = KQ_max_new_j;
const float val = expf(sink - KQ_max[j0]);
KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale;
if (threadIdx.x == 0) {
KQ_sum[j0] += val;
}
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0][i0/warp_size] *= KQ_max_scale_h2;
}
#else
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0][i0/warp_size].x *= KQ_max_scale;
VKQ[j0][i0/warp_size].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
}
}
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]);
}
if (gridDim.y == 1) {
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]);
#pragma unroll
for (int i = 0; i < (D/2)/warp_size; ++i) {
VKQ[j_VKQ_0][i] *= KQ_sum_j_inv;
}
#else
const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0];
#pragma unroll
for (int i = 0; i < (D/2)/warp_size; ++i) {
VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv;
VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv;
}
#endif // FAST_FP16_AVAILABLE
}
}
// Write back results:
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) {
const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw;
if (ic0 + j_VKQ >= ne01) {
return;
}
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
#ifdef FAST_FP16_AVAILABLE
constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) {
float2 tmp[cpy_ne_D];
#pragma unroll
for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]);
}
ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
}
#else
constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) {
ggml_cuda_memcpy_1<cpy_ne_D*4>(
&dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]);
}
#endif // FAST_FP16_AVAILABLE
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]);
}
}
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
template <int D, bool use_logit_softcap>
static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int warp_size = 32;
constexpr size_t nbytes_shared = 0;
#ifdef GGML_USE_HIP
if constexpr (D <= 128) {
if (Q->ne[1] > 32) {
constexpr int cols_per_block = 64;
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
return;
}
}
#endif // GGML_USE_HIP
if (Q->ne[1] > 16) {
constexpr int cols_per_block = 32;
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
return;
}
constexpr int cols_per_block = 16;
const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size;
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
}
template <bool use_logit_softcap>
static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
switch (K->ne[0]) {
case 40: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst);
} break;
case 64: {
launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst);
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst);
} break;
case 80: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst);
} break;
case 96: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst);
} break;
case 112: {
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst);
} break;
case 128: {
launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst);
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst);
} break;
case 256: {
launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst);
GGML_ASSERT(V->ne[0] == K->ne[0]);
ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst);
} break;
case 576: {
GGML_ASSERT(V->ne[0] == 512);
ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst);
} break;
default: {
GGML_ABORT("Unsupported head size");
} break;
}
}
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,3 +1,5 @@
#pragma once
#include "common.cuh"
#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA)

View File

@@ -198,6 +198,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
return BEST_FATTN_KERNEL_NONE;
#endif// FLASH_ATTN_AVAILABLE
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
@@ -206,37 +207,32 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
const int gqa_ratio = Q->ne[2] / K->ne[2];
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
// The effective batch size for the kernel can be increased by gqa_ratio.
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded,
const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0;
const int cc = ggml_cuda_info().devices[device].cc;
// TODO: temporary until support is extended
// https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206
if (K->ne[1] % FATTN_KQ_STRIDE != 0) {
return BEST_FATTN_KERNEL_NONE;
}
switch (K->ne[0]) {
case 40:
case 64:
case 128:
case 256:
if (V->ne[0] != K->ne[0]) {
return BEST_FATTN_KERNEL_NONE;
}
break;
case 80:
case 96:
case 128:
case 112:
case 256:
if (V->ne[0] != K->ne[0]) {
return BEST_FATTN_KERNEL_NONE;
}
if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) {
return BEST_FATTN_KERNEL_NONE;
}
break;
case 576:
if (V->ne[0] != 512) {
return BEST_FATTN_KERNEL_NONE;
}
if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) {
if (!gqa_opt_applies || gqa_ratio % 16 != 0) {
return BEST_FATTN_KERNEL_NONE;
}
break;
@@ -270,47 +266,57 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
return BEST_FATTN_KERNEL_NONE;
}
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0;
// If Turing tensor cores available, use them except for some cases with batch size 1:
if (turing_mma_available(cc)) {
best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
// For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes:
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0;
// If Turing tensor cores available, use them:
if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) {
if (can_use_vector_kernel) {
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
best = BEST_FATTN_KERNEL_VEC;
return BEST_FATTN_KERNEL_VEC;
}
} else {
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
if (Q->ne[1] <= 2) {
best = BEST_FATTN_KERNEL_VEC;
return BEST_FATTN_KERNEL_VEC;
}
} else {
if (Q->ne[1] == 1) {
best = BEST_FATTN_KERNEL_VEC;
return BEST_FATTN_KERNEL_VEC;
}
}
}
if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
if (!gqa_opt_applies && Q->ne[1] == 1) {
return BEST_FATTN_KERNEL_VEC;
}
}
return best;
return BEST_FATTN_KERNEL_MMA_F16;
}
// Use kernels specialized for small batch sizes if possible:
if (Q->ne[1] <= 8 && can_use_vector_kernel) {
return BEST_FATTN_KERNEL_VEC;
}
// For large batch sizes, use the WMMA kernel if possible:
if (ggml_cuda_should_use_wmma_fattn(cc)) {
// Use the WMMA kernel if possible:
if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) {
if (can_use_vector_kernel && Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
return BEST_FATTN_KERNEL_WMMA_F16;
}
// If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
// If there are no tensor cores available, use the generic tile kernel:
if (can_use_vector_kernel) {
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
if (Q->ne[1] == 1) {
if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_VEC;
}
}
} else {
if (Q->ne[1] <= 2) {
return BEST_FATTN_KERNEL_VEC;
}
}
}
return BEST_FATTN_KERNEL_TILE;
}

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(112, 112);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(128, 128);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(256, 256);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(40, 40);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(576, 512);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(64, 64);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(80, 80);

View File

@@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(96, 96);

View File

@@ -3,8 +3,17 @@
from glob import glob
import os
HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576]
TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"]
SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v});
"""
SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-vec.cuh"
@@ -51,6 +60,11 @@ def get_short_name(long_quant_name):
for filename in glob("*.cu"):
os.remove(filename)
for head_size_kq in HEAD_SIZES_KQ:
head_size_v = head_size_kq if head_size_kq != 576 else 512
with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f:
f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v))
for type_k in TYPES_KV:
for type_v in TYPES_KV:
with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
@@ -64,7 +78,9 @@ for ncols in [8, 16, 32, 64]:
with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f:
f.write(SOURCE_FATTN_MMA_START)
for head_size_kq in [64, 80, 96, 112, 128, 256, 576]:
for head_size_kq in HEAD_SIZES_KQ:
if head_size_kq == 40:
continue
if head_size_kq != 576 and ncols2 == 16:
continue
if head_size_kq == 576 and ncols2 != 16:

View File

@@ -53,6 +53,8 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")

View File

@@ -30,6 +30,8 @@ if (MUSAToolkit_FOUND)
list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")