	CUDA: faster tile FA, add oob checks, more HSs (#16492)
Author: Johannes Gäßler
		| @@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND) | |||||||
|     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") |     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") | ||||||
|  |  | ||||||
|     file(GLOB   GGML_SOURCES_CUDA "*.cu") |     file(GLOB   GGML_SOURCES_CUDA "*.cu") | ||||||
|  |     file(GLOB   SRCS "template-instances/fattn-tile*.cu") | ||||||
|  |     list(APPEND GGML_SOURCES_CUDA ${SRCS}) | ||||||
|     file(GLOB   SRCS "template-instances/fattn-mma*.cu") |     file(GLOB   SRCS "template-instances/fattn-mma*.cu") | ||||||
|     list(APPEND GGML_SOURCES_CUDA ${SRCS}) |     list(APPEND GGML_SOURCES_CUDA ${SRCS}) | ||||||
|     file(GLOB   SRCS "template-instances/mmq*.cu") |     file(GLOB   SRCS "template-instances/mmq*.cu") | ||||||
|   | |||||||
| @@ -245,7 +245,8 @@ static bool fp16_available(const int cc) { | |||||||
| } | } | ||||||
|  |  | ||||||
| static bool fast_fp16_available(const int cc) { | static bool fast_fp16_available(const int cc) { | ||||||
|     return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc); |     return GGML_CUDA_CC_IS_AMD(cc) || | ||||||
|  |         (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610); | ||||||
| } | } | ||||||
|  |  | ||||||
| // To be used for feature selection of external libraries, e.g. cuBLAS. | // To be used for feature selection of external libraries, e.g. cuBLAS. | ||||||
| @@ -571,6 +572,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, | |||||||
| } | } | ||||||
|  |  | ||||||
| // Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD. | // Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD. | ||||||
|  | // Important: do not use this function if dst and src both point at registers. | ||||||
|  | //     Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types. | ||||||
|  | //     The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions. | ||||||
|  | //     If dst and src point at different address spaces then they are guaranteed to not be aliased. | ||||||
| template <int nbytes, int alignment = 0> | template <int nbytes, int alignment = 0> | ||||||
| static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) { | static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) { | ||||||
|     if constexpr (alignment != 0) { |     if constexpr (alignment != 0) { | ||||||
|   | |||||||
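The aliasing caveat in the new comment is easiest to see in a small usage sketch. The snippet below is illustrative only: it assumes ggml_cuda_memcpy_1 from common.cuh is in scope, and the helper name, buffer layout and fragment size are made up for the example.

#include <cuda_fp16.h>
#include "common.cuh" // assumed include path; provides ggml_cuda_memcpy_1

// Illustrative device helper, not part of the actual kernels:
static __device__ void load_K_fragment(const half2 * __restrict__ K_sram, half2 (&K_regs)[4], const int tid) {
    // OK: dst (registers) and src (shared/global memory) are in different address spaces,
    // so they cannot alias and the copy can be lowered to a single 16-byte load.
    ggml_cuda_memcpy_1<4*sizeof(half2)>(K_regs, K_sram + 4*tid);

    // NOT OK per the comment above: both pointers would refer to the same registers
    // through different types, which violates strict aliasing.
    // float K_as_float[8];
    // ggml_cuda_memcpy_1<sizeof(K_regs)>(K_as_float, K_regs);
}

Keeping the two sides in different address spaces is what lets the compiler emit one wide transfer instead of several 4-byte ones, which is the point of the helper.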
| @@ -793,8 +793,6 @@ void launch_fattn( | |||||||
|     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && |     GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && | ||||||
|         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); |         "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); | ||||||
|  |  | ||||||
|     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); |  | ||||||
|  |  | ||||||
|     ggml_cuda_pool & pool = ctx.pool(); |     ggml_cuda_pool & pool = ctx.pool(); | ||||||
|     cudaStream_t main_stream = ctx.stream(); |     cudaStream_t main_stream = ctx.stream(); | ||||||
|     const int id  = ggml_cuda_get_device(); |     const int id  = ggml_cuda_get_device(); | ||||||
| @@ -878,7 +876,7 @@ void launch_fattn( | |||||||
|     // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. |     // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. | ||||||
|     // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or |     // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or | ||||||
|     //     multiple sequences of possibly different lengths. |     //     multiple sequences of possibly different lengths. | ||||||
|     if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { |     if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { | ||||||
|         const int s31 = mask->nb[1] / sizeof(half2); |         const int s31 = mask->nb[1] / sizeof(half2); | ||||||
|         const int s33 = mask->nb[3] / sizeof(half2); |         const int s33 = mask->nb[3] / sizeof(half2); | ||||||
|  |  | ||||||
| @@ -916,8 +914,7 @@ void launch_fattn( | |||||||
|  |  | ||||||
|         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); |         dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); | ||||||
|     } else { |     } else { | ||||||
|         GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); |         const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. | ||||||
|         const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. |  | ||||||
|  |  | ||||||
|         // parallel_blocks must not be larger than what the tensor size allows: |         // parallel_blocks must not be larger than what the tensor size allows: | ||||||
|         parallel_blocks = std::min(parallel_blocks, ntiles_KQ); |         parallel_blocks = std::min(parallel_blocks, ntiles_KQ); | ||||||
| @@ -946,7 +943,7 @@ void launch_fattn( | |||||||
|  |  | ||||||
|         blocks_num.x = ntiles_x; |         blocks_num.x = ntiles_x; | ||||||
|         blocks_num.y = parallel_blocks; |         blocks_num.y = parallel_blocks; | ||||||
|         blocks_num.z = Q->ne[2]*Q->ne[3]; |         blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3]; | ||||||
|  |  | ||||||
|         if (parallel_blocks > 1) { |         if (parallel_blocks > 1) { | ||||||
|             dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); |             dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); | ||||||
|   | |||||||
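With the "Incorrect KV cache padding" assert removed, the number of KV tiles becomes a ceiling division and the last tile may be partial, which is why the kernels now need out-of-bounds checks. A minimal host-side illustration of that arithmetic, with an assumed granularity value:

#include <algorithm>
#include <cstdio>

int main() {
    const int KQ_row_granularity = 64;   // assumed tile granularity in KV rows
    const int ne11               = 1000; // KV length, deliberately not a multiple of it

    // Same ceiling division as the new ntiles_KQ computation in launch_fattn:
    const int ntiles_KQ = (ne11 + KQ_row_granularity - 1) / KQ_row_granularity;

    for (int tile = 0; tile < ntiles_KQ; ++tile) {
        const int k0    = tile*KQ_row_granularity;
        const int k_end = std::min(k0 + KQ_row_granularity, ne11); // clamp = the oob check
        printf("tile %2d covers KV rows [%4d, %4d)\n", tile, k0, k_end);
    }
    return 0;
}

Correspondingly, the mask-scan optimization above is now gated on K->ne[1] % FATTN_KQ_STRIDE == 0, since it reasons in whole FATTN_KQ_STRIDE x FATTN_KQ_STRIDE squares.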
| @@ -1,756 +1,45 @@ | |||||||
| #include "common.cuh" | #include "common.cuh" | ||||||
| #include "fattn-common.cuh" |  | ||||||
| #include "fattn-tile.cuh" | #include "fattn-tile.cuh" | ||||||
| #include "fattn-wmma-f16.cuh" | #include "fattn-wmma-f16.cuh" | ||||||
|  |  | ||||||
| // kq_stride == number of KQ rows to process per iteration | void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||||||
| // kq_nbatch == number of K columns to load in parallel for KQ calculation |     const ggml_tensor * K = dst->src[1]; | ||||||
|  |     const ggml_tensor * V = dst->src[2]; | ||||||
| static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) { |     switch (K->ne[0]) { | ||||||
|     if (GGML_CUDA_CC_IS_AMD(cc)) { |         case  40: { | ||||||
|         if (GGML_CUDA_CC_IS_RDNA(cc)) { |             GGML_ASSERT(V->ne[0] == K->ne[0]); | ||||||
|             switch (D) { |             ggml_cuda_flash_attn_ext_tile_case< 40,  40>(ctx, dst); | ||||||
|                 case 64: |         } break; | ||||||
|                     return 128; |  | ||||||
|                 case 128: |  | ||||||
|                 case 256: |  | ||||||
|                     return ncols <= 16 ? 128 : 64; |  | ||||||
|                 default: |  | ||||||
|                     GGML_ABORT("fatal error"); |  | ||||||
|                     return -1; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         switch (D) { |  | ||||||
|             case 64: |  | ||||||
|                 return ncols == 32 ? 128 : 64; |  | ||||||
|             case 128: |  | ||||||
|                 return ncols == 32 ? 64 : 32; |  | ||||||
|             case 256: |  | ||||||
|                 return 32; |  | ||||||
|             default: |  | ||||||
|                 GGML_ABORT("fatal error"); |  | ||||||
|                 return -1; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     if (fast_fp16_available(cc)) { |  | ||||||
|         switch (D) { |  | ||||||
|             case 64: |  | ||||||
|             case 128: |  | ||||||
|             case 256: |  | ||||||
|                 return ncols <= 16 ? 128 : 64; |  | ||||||
|             default: |  | ||||||
|                 GGML_ABORT("fatal error"); |  | ||||||
|                 return -1; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|     switch (D) { |  | ||||||
|         case 64: |  | ||||||
|             return ncols <= 16 ? 128 : 64; |  | ||||||
|         case 128: |  | ||||||
|             return ncols <= 16 ? 64 : 32; |  | ||||||
|         case 256: |  | ||||||
|             return 32; |  | ||||||
|         default: |  | ||||||
|             GGML_ABORT("fatal error"); |  | ||||||
|             return -1; |  | ||||||
|     } |  | ||||||
|     GGML_UNUSED(warp_size); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) { |  | ||||||
| #ifdef GGML_USE_HIP |  | ||||||
| #ifdef RDNA |  | ||||||
|     switch (D) { |  | ||||||
|         case 64: |  | ||||||
|             return 128; |  | ||||||
|         case 128: |  | ||||||
|         case 256: |  | ||||||
|             return ncols <= 16 ? 128 : 64; |  | ||||||
|         default: |  | ||||||
|             return -1; |  | ||||||
|     } |  | ||||||
| #else |  | ||||||
|     switch (D) { |  | ||||||
|         case 64: |  | ||||||
|             return ncols == 32 ? 128 : 64; |  | ||||||
|         case 128: |  | ||||||
|             return ncols == 32 ? 64 : 32; |  | ||||||
|         case 256: |  | ||||||
|             return 32; |  | ||||||
|         default: |  | ||||||
|             return -1; |  | ||||||
|     } |  | ||||||
| #endif // RDNA |  | ||||||
| #else |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
|     switch (D) { |  | ||||||
|         case 64: |  | ||||||
|         case 128: |  | ||||||
|         case 256: |  | ||||||
|             return ncols <= 16 ? 128 : 64; |  | ||||||
|         default: |  | ||||||
|             return -1; |  | ||||||
|     } |  | ||||||
| #else |  | ||||||
|     switch (D) { |  | ||||||
|         case 64: |  | ||||||
|             return ncols <= 16 ? 128 : 64; |  | ||||||
|         case 128: |  | ||||||
|             return ncols <= 16 ? 64 : 32; |  | ||||||
|         case 256: |  | ||||||
|             return 32; |  | ||||||
|         default: |  | ||||||
|             return -1; |  | ||||||
|     } |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
| #endif // GGML_USE_HIP |  | ||||||
|     GGML_UNUSED_VARS(ncols, warp_size); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) { |  | ||||||
| #ifdef GGML_USE_HIP |  | ||||||
|     switch (D) { |  | ||||||
|         case 64: |  | ||||||
|             return 64; |  | ||||||
|         case 128: |  | ||||||
|         case 256: |  | ||||||
|             return 128; |  | ||||||
|         default: |  | ||||||
|             return -1; |  | ||||||
|     } |  | ||||||
| #else |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
|     switch (D) { |  | ||||||
|         case 64: |  | ||||||
|             return 64; |  | ||||||
|         case 128: |  | ||||||
|         case 256: |  | ||||||
|             return 128; |  | ||||||
|         default: |  | ||||||
|             return -1; |  | ||||||
|     } |  | ||||||
| #else |  | ||||||
|     switch (D) { |  | ||||||
|         case 64: |  | ||||||
|             return 64; |  | ||||||
|         case 128: |  | ||||||
|             return 128; |  | ||||||
|         case 256: |  | ||||||
|             return ncols <= 16 ? 128 : 64; |  | ||||||
|         default: |  | ||||||
|             return -1; |  | ||||||
|     } |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
| #endif // GGML_USE_HIP |  | ||||||
|     GGML_UNUSED_VARS(ncols, warp_size); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| static int fattn_tile_get_nthreads_host(const int cc, const int ncols) { |  | ||||||
|     return 256; |  | ||||||
|     GGML_UNUSED_VARS(cc, ncols); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) { |  | ||||||
|     return 256; |  | ||||||
|     GGML_UNUSED(ncols); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) { |  | ||||||
| #ifdef RDNA |  | ||||||
|     return 3; |  | ||||||
| #else |  | ||||||
|     return ncols <= 16 ? 3 : 2; |  | ||||||
| #endif // RDNA |  | ||||||
|     GGML_UNUSED(ncols); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template<int D, int ncols, bool use_logit_softcap> // D == head size |  | ||||||
| __launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols)) |  | ||||||
| static __global__ void flash_attn_tile( |  | ||||||
|         const char * __restrict__ Q, |  | ||||||
|         const char * __restrict__ K, |  | ||||||
|         const char * __restrict__ V, |  | ||||||
|         const char * __restrict__ mask, |  | ||||||
|         const char * __restrict__ sinks, |  | ||||||
|         const int  * __restrict__ KV_max, |  | ||||||
|         float      * __restrict__ dst, |  | ||||||
|         float2     * __restrict__ dst_meta, |  | ||||||
|         const float scale, |  | ||||||
|         const float max_bias, |  | ||||||
|         const float m0, |  | ||||||
|         const float m1, |  | ||||||
|         const uint32_t n_head_log2, |  | ||||||
|         const float logit_softcap, |  | ||||||
|         const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, |  | ||||||
|                             const int32_t nb01, const int32_t nb02, const int32_t nb03, |  | ||||||
|         const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, |  | ||||||
|                             const int32_t nb11, const int32_t nb12, const int64_t nb13, |  | ||||||
|                             const int32_t nb21, const int32_t nb22, const int64_t nb23, |  | ||||||
|                             const int32_t ne31, const int32_t ne32, const int32_t ne33, |  | ||||||
|                             const int32_t nb31, const int32_t nb32, const int64_t nb33) { |  | ||||||
| #ifdef FLASH_ATTN_AVAILABLE |  | ||||||
|  |  | ||||||
|     // Skip unused kernel variants for faster compilation: |  | ||||||
| #ifdef GGML_USE_WMMA_FATTN |  | ||||||
|     NO_DEVICE_CODE; |  | ||||||
|     return; |  | ||||||
| #endif // GGML_USE_WMMA_FATTN |  | ||||||
|  |  | ||||||
|     if (use_logit_softcap && !(D == 128 || D == 256)) { |  | ||||||
|         GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, |  | ||||||
|             max_bias, m0, m1, n_head_log2, logit_softcap, |  | ||||||
|             ne00, ne01, ne02, ne03, |  | ||||||
|                   nb01, nb02, nb03, |  | ||||||
|             ne10, ne11, ne12, ne13, |  | ||||||
|                   nb11, nb12, nb13, |  | ||||||
|                   nb21, nb22, nb23, |  | ||||||
|                   ne31, ne32, ne33, |  | ||||||
|                   nb31, nb32, nb33); |  | ||||||
|         NO_DEVICE_CODE; |  | ||||||
|         return; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     constexpr int warp_size = 32; |  | ||||||
|     constexpr int nwarps    = fattn_tile_get_nthreads_device(ncols) / warp_size; |  | ||||||
|     constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size); |  | ||||||
|     static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size."); |  | ||||||
|     constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size); |  | ||||||
|     static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch"); |  | ||||||
|  |  | ||||||
|     // In this kernel Q, K, V are matrices while i, j, k are matrix indices. |  | ||||||
|  |  | ||||||
|     const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. |  | ||||||
|  |  | ||||||
|     const int sequence = blockIdx.z / ne02; |  | ||||||
|     const int head = blockIdx.z - sequence*ne02; |  | ||||||
|     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. |  | ||||||
|     const float * Q_f    = (const float *) (Q    + nb03* sequence         + nb02* head              + nb01*ic0); |  | ||||||
|     const half2 * K_h2   = (const half2 *) (K    + nb13* sequence         + nb12*(head / gqa_ratio)); |  | ||||||
|     const half2 * V_h2   = (const half2 *) (V    + nb13* sequence         + nb12*(head / gqa_ratio)); // K and V have same shape |  | ||||||
|     const half  * maskh  = (const half  *) (mask + nb33*(sequence % ne33)                           + nb31*ic0); |  | ||||||
|     const float * sinksf = (const float *) (sinks); |  | ||||||
|  |  | ||||||
|     const int stride_KV2 = nb11 / sizeof(half2); |  | ||||||
|  |  | ||||||
|     const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); |  | ||||||
|  |  | ||||||
|     constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); |  | ||||||
|     constexpr int cpy_ne = cpy_nb / 4; |  | ||||||
|  |  | ||||||
|     constexpr int cpw = ncols/nwarps; // cols per warp |  | ||||||
|  |  | ||||||
|     // softmax_iter_j == number of KQ columns for which to calculate softmax in parallel. |  | ||||||
|     // KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes. |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
|     constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne; |  | ||||||
|  |  | ||||||
|     __shared__ half  KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j]; |  | ||||||
|     __shared__ half2 Q_tmp[ncols][D/2]; |  | ||||||
|     __shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts. |  | ||||||
|     half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; |  | ||||||
| #else |  | ||||||
|     constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne; |  | ||||||
|  |  | ||||||
|     __shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j]; |  | ||||||
|     __shared__ float Q_tmp[ncols][D]; |  | ||||||
|     __shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts. |  | ||||||
|     float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
|     static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j"); |  | ||||||
|  |  | ||||||
|     float KQ_max[cpw]; |  | ||||||
| #pragma unroll |  | ||||||
|     for (int j0 = 0; j0 < ncols; j0 += nwarps) { |  | ||||||
|         KQ_max[j0/nwarps] = -FLT_MAX/2.0f; |  | ||||||
|     } |  | ||||||
|     float KQ_sum[cpw] = {0.0f}; |  | ||||||
|  |  | ||||||
|     // Load Q data, convert to FP16 if fast. |  | ||||||
| #pragma unroll |  | ||||||
|     for (int j0 = 0; j0 < cpw; ++j0) { |  | ||||||
|         const int j = j0 + threadIdx.y*cpw; |  | ||||||
|  |  | ||||||
|         constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|         for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { |  | ||||||
|             float tmp_f[cpy_ne_D] = {0.0f}; |  | ||||||
|             if (ic0 + j < ne01) { |  | ||||||
|                 ggml_cuda_memcpy_1<sizeof(tmp_f)>(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]); |  | ||||||
|             } |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|             for (int i1 = 0; i1 < cpy_ne_D; ++i1) { |  | ||||||
|                 tmp_f[i1] *= scale; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
|             half2 tmp_h2[cpy_ne_D/2]; |  | ||||||
| #pragma unroll |  | ||||||
|             for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { |  | ||||||
|                 tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); |  | ||||||
|             } |  | ||||||
|             ggml_cuda_memcpy_1<sizeof(tmp_h2)>(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2); |  | ||||||
| #else |  | ||||||
|             ggml_cuda_memcpy_1<sizeof(tmp_f)> (&Q_tmp[j][i0   + threadIdx.x* cpy_ne_D],    tmp_f); |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     __syncthreads(); |  | ||||||
|  |  | ||||||
|     // Main loop over KV cache: |  | ||||||
|     const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; |  | ||||||
|     for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) { |  | ||||||
|         // Calculate KQ tile and keep track of new maximum KQ values: |  | ||||||
|  |  | ||||||
|         float KQ_max_new[cpw]; |  | ||||||
| #pragma unroll |  | ||||||
|         for (int j = 0; j < cpw; ++j) { |  | ||||||
|             KQ_max_new[j] = KQ_max[j]; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication. |  | ||||||
|  |  | ||||||
|         // KQ = K @ Q matrix multiplication: |  | ||||||
| #pragma unroll |  | ||||||
|         for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) { |  | ||||||
| #pragma unroll |  | ||||||
|             for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) { |  | ||||||
|                 const int i_KQ = i_KQ_0 + threadIdx.y; |  | ||||||
|  |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
|                 constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size); |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) { |  | ||||||
|                     ggml_cuda_memcpy_1<cpy_ne_kqnb*4>( |  | ||||||
|                         &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], |  | ||||||
|                         &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]); |  | ||||||
|                 } |  | ||||||
| #else |  | ||||||
|                 constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size; |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) { |  | ||||||
|                     half2 tmp_h2[cpy_ne_kqnb/2]; |  | ||||||
|                     ggml_cuda_memcpy_1<sizeof(tmp_h2)>( |  | ||||||
|                         tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]); |  | ||||||
|  |  | ||||||
|                     float2 tmp_f2[cpy_ne_kqnb/2]; |  | ||||||
| #pragma unroll |  | ||||||
|                     for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) { |  | ||||||
|                         tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]); |  | ||||||
|                     } |  | ||||||
|                     ggml_cuda_memcpy_1<sizeof(tmp_f2)>( |  | ||||||
|                         &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2); |  | ||||||
|                 } |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             __syncthreads(); |  | ||||||
|  |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
| #pragma unroll |  | ||||||
|             for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) { |  | ||||||
|                 half2 K_k[kq_stride/warp_size][cpy_ne]; |  | ||||||
|                 half2 Q_k[cpw][cpy_ne]; |  | ||||||
| #else |  | ||||||
| #pragma unroll |  | ||||||
|             for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) { |  | ||||||
|                 float K_k[kq_stride/warp_size][cpy_ne]; |  | ||||||
|                 float Q_k[cpw][cpy_ne]; |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { |  | ||||||
|                     const int i_KQ = i_KQ_0 + threadIdx.x; |  | ||||||
|  |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
|                     ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]); |  | ||||||
| #else |  | ||||||
|                     ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch   + cpy_ne) + k_KQ_1]); |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
|                 } |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { |  | ||||||
|                     const int j_KQ = j_KQ_0 + threadIdx.y*cpw; |  | ||||||
|  |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
|                     ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]); |  | ||||||
| #else |  | ||||||
|                     ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0   + k_KQ_1]); |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { |  | ||||||
| #pragma unroll |  | ||||||
|                     for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { |  | ||||||
| #pragma unroll |  | ||||||
|                         for (int k = 0; k < cpy_ne; ++k) { |  | ||||||
|                             ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]); |  | ||||||
|                         } |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             if (k_KQ_0 + kq_nbatch < D) { |  | ||||||
|                 __syncthreads(); // Sync not needed on last iteration. |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         // Apply logit softcap, mask, update KQ_max: |  | ||||||
| #pragma unroll |  | ||||||
|         for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { |  | ||||||
|             const int i_KQ = i_KQ_0 + threadIdx.x; |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|             for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { |  | ||||||
|                 const int j_KQ = j_KQ_0 + threadIdx.y*cpw; |  | ||||||
|  |  | ||||||
|                 if (use_logit_softcap) { |  | ||||||
|                     KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]); |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; |  | ||||||
|  |  | ||||||
|                 KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         __syncthreads(); |  | ||||||
|  |  | ||||||
|         // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: |  | ||||||
| #pragma unroll |  | ||||||
|         for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
|             half  tmp[kq_stride/warp_size][softmax_iter_j]; |  | ||||||
| #else |  | ||||||
|             float tmp[kq_stride/warp_size][softmax_iter_j]; |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|             for (int j1 = 0; j1 < softmax_iter_j; ++j1) { |  | ||||||
|                 KQ_max_new[j0+j1] = warp_reduce_max<warp_size>(KQ_max_new[j0+j1]); |  | ||||||
|                 const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]); |  | ||||||
|                 KQ_max[j0+j1] = KQ_max_new[j0+j1]; |  | ||||||
|  |  | ||||||
|                 float KQ_sum_add = 0.0f; |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { |  | ||||||
|                     const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]); |  | ||||||
|                     KQ_sum_add += val; |  | ||||||
|                     tmp[i0/warp_size][j1] = val; |  | ||||||
|                 } |  | ||||||
|                 KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add; |  | ||||||
|  |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
|                 const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int i0 = 0; i0 < D/2; i0 += warp_size) { |  | ||||||
|                     VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2; |  | ||||||
|                 } |  | ||||||
| #else |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int i0 = 0; i0 < D/2; i0 += warp_size) { |  | ||||||
|                     VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale; |  | ||||||
|                     VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale; |  | ||||||
|                 } |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
|             } |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|             for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { |  | ||||||
|                 const int i = i0 + threadIdx.x; |  | ||||||
|  |  | ||||||
|                 ggml_cuda_memcpy_1<sizeof(tmp[0])>( |  | ||||||
|                     KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         // VKQ = V @ KQ matrix multiplication: |  | ||||||
|         constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K. |  | ||||||
|         static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter"); |  | ||||||
| #pragma unroll |  | ||||||
|         for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) { |  | ||||||
| #pragma unroll |  | ||||||
|             for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) { |  | ||||||
|                 const int k_tile = k1 + threadIdx.y; |  | ||||||
|  |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
|                 constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size); |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { |  | ||||||
|                     ggml_cuda_memcpy_1<cpy_ne_D*4>( |  | ||||||
|                         &KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D], |  | ||||||
|                         &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]); |  | ||||||
|                 } |  | ||||||
| #else |  | ||||||
|                 constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { |  | ||||||
|                     half2 tmp_h2[cpy_ne_D/2]; |  | ||||||
|                     ggml_cuda_memcpy_1<sizeof(tmp_h2)>( |  | ||||||
|                         tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]); |  | ||||||
|  |  | ||||||
|                     float2 tmp_f2[cpy_ne_D/2]; |  | ||||||
| #pragma unroll |  | ||||||
|                     for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { |  | ||||||
|                         tmp_f2[i1] = __half22float2(tmp_h2[i1]); |  | ||||||
|                     } |  | ||||||
|                     ggml_cuda_memcpy_1<sizeof(tmp_f2)>( |  | ||||||
|                         &KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2); |  | ||||||
|                 } |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             __syncthreads(); |  | ||||||
|  |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
| #pragma unroll |  | ||||||
|             for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { |  | ||||||
|                 half2 V_k[(D/2)/warp_size]; |  | ||||||
|                 half2 KQ_k[cpw]; |  | ||||||
|  |  | ||||||
|                 constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size; |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { |  | ||||||
|                     ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]); |  | ||||||
|                 } |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { |  | ||||||
|                     const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j); |  | ||||||
|  |  | ||||||
|                     half tmp[softmax_iter_j]; |  | ||||||
|                     ggml_cuda_memcpy_1<softmax_iter_j*sizeof(half)>( |  | ||||||
|                         &tmp, KQ[j][k0 + k1]); |  | ||||||
| #pragma unroll |  | ||||||
|                     for (int j1 = 0; j1 < softmax_iter_j; ++j1) { |  | ||||||
|                         KQ_k[j0+j1] = __half2half2(tmp[j1]); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int i0 = 0; i0 < D/2; i0 += warp_size) { |  | ||||||
| #pragma unroll |  | ||||||
|                     for (int j0 = 0; j0 < cpw; ++j0) { |  | ||||||
|                         VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0]; |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| #else |  | ||||||
| #pragma unroll |  | ||||||
|             for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { |  | ||||||
|                 float2 V_k[(D/2)/warp_size]; |  | ||||||
|                 float  KQ_k[cpw]; |  | ||||||
|  |  | ||||||
|                 constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { |  | ||||||
|                     ggml_cuda_memcpy_1<cpy_ne_D*4>(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]); |  | ||||||
|                 } |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { |  | ||||||
|                     const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j); |  | ||||||
|  |  | ||||||
|                     ggml_cuda_memcpy_1<softmax_iter_j*sizeof(float)>( |  | ||||||
|                         &KQ_k[j0], KQ[j][k0 + k1]); |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int i0 = 0; i0 < D/2; i0 += warp_size) { |  | ||||||
| #pragma unroll |  | ||||||
|                     for (int j0 = 0; j0 < cpw; ++j0) { |  | ||||||
|                         VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0]; |  | ||||||
|                         VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0]; |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
|  |  | ||||||
|             __syncthreads(); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|  |  | ||||||
|     // Attention sink: adjust running max and sum once per head |  | ||||||
|     if (sinksf && blockIdx.y == 0) { |  | ||||||
|         const float sink = sinksf[head]; |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|         for (int j0 = 0; j0 < cpw; ++j0) { |  | ||||||
|             float KQ_max_new_j = fmaxf(KQ_max[j0], sink); |  | ||||||
|             KQ_max_new_j = warp_reduce_max<warp_size>(KQ_max_new_j); |  | ||||||
|  |  | ||||||
|             const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j); |  | ||||||
|             KQ_max[j0] = KQ_max_new_j; |  | ||||||
|  |  | ||||||
|             const float val = expf(sink - KQ_max[j0]); |  | ||||||
|             KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale; |  | ||||||
|             if (threadIdx.x == 0) { |  | ||||||
|                 KQ_sum[j0] += val; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
|             const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); |  | ||||||
| #pragma unroll |  | ||||||
|             for (int i0 = 0; i0 < D/2; i0 += warp_size) { |  | ||||||
|                 VKQ[j0][i0/warp_size] *= KQ_max_scale_h2; |  | ||||||
|             } |  | ||||||
| #else |  | ||||||
| #pragma unroll |  | ||||||
|             for (int i0 = 0; i0 < D/2; i0 += warp_size) { |  | ||||||
|                 VKQ[j0][i0/warp_size].x *= KQ_max_scale; |  | ||||||
|                 VKQ[j0][i0/warp_size].y *= KQ_max_scale; |  | ||||||
|             } |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|     for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { |  | ||||||
|         KQ_sum[j_VKQ_0] = warp_reduce_sum<warp_size>(KQ_sum[j_VKQ_0]); |  | ||||||
|     } |  | ||||||
|     if (gridDim.y == 1) { |  | ||||||
| #pragma unroll |  | ||||||
|         for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
|             const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]); |  | ||||||
| #pragma unroll |  | ||||||
|             for (int i = 0; i < (D/2)/warp_size; ++i) { |  | ||||||
|                 VKQ[j_VKQ_0][i] *= KQ_sum_j_inv; |  | ||||||
|             } |  | ||||||
| #else |  | ||||||
|             const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0]; |  | ||||||
| #pragma unroll |  | ||||||
|             for (int i = 0; i < (D/2)/warp_size; ++i) { |  | ||||||
|                 VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv; |  | ||||||
|                 VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv; |  | ||||||
|             } |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     // Write back results: |  | ||||||
| #pragma unroll |  | ||||||
|     for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { |  | ||||||
|         const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw; |  | ||||||
|  |  | ||||||
|         if (ic0 + j_VKQ >= ne01) { |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; |  | ||||||
|  |  | ||||||
| #ifdef FAST_FP16_AVAILABLE |  | ||||||
|         constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size; |  | ||||||
| #pragma unroll |  | ||||||
|         for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { |  | ||||||
|             float2 tmp[cpy_ne_D]; |  | ||||||
| #pragma unroll |  | ||||||
|             for (int i1 = 0; i1 < cpy_ne_D; ++i1) { |  | ||||||
|                 tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]); |  | ||||||
|             } |  | ||||||
|             ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); |  | ||||||
|         } |  | ||||||
| #else |  | ||||||
|         constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; |  | ||||||
| #pragma unroll |  | ||||||
|         for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { |  | ||||||
|             ggml_cuda_memcpy_1<cpy_ne_D*4>( |  | ||||||
|                 &dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]); |  | ||||||
|         } |  | ||||||
| #endif // FAST_FP16_AVAILABLE |  | ||||||
|  |  | ||||||
|         if (gridDim.y != 1 && threadIdx.x == 0) { |  | ||||||
|             dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| #else |  | ||||||
|     GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, |  | ||||||
|         max_bias, m0, m1, n_head_log2, logit_softcap, |  | ||||||
|         ne00, ne01, ne02, ne03, |  | ||||||
|               nb01, nb02, nb03, |  | ||||||
|         ne10, ne11, ne12, ne13, |  | ||||||
|               nb11, nb12, nb13, |  | ||||||
|               nb21, nb22, nb23, |  | ||||||
|               ne31, ne32, ne33, |  | ||||||
|               nb31, nb32, nb33); |  | ||||||
|     NO_DEVICE_CODE; |  | ||||||
| #endif // FLASH_ATTN_AVAILABLE |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template <int D, bool use_logit_softcap> |  | ||||||
| static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |  | ||||||
|     const ggml_tensor * Q = dst->src[0]; |  | ||||||
|  |  | ||||||
|     const int id        = ggml_cuda_get_device(); |  | ||||||
|     const int cc        = ggml_cuda_info().devices[id].cc; |  | ||||||
|     const int warp_size = 32; |  | ||||||
|  |  | ||||||
|     constexpr size_t nbytes_shared = 0; |  | ||||||
|  |  | ||||||
| #ifdef GGML_USE_HIP |  | ||||||
|     if constexpr (D <= 128) { |  | ||||||
|         if (Q->ne[1] > 32) { |  | ||||||
|             constexpr int cols_per_block = 64; |  | ||||||
|             const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; |  | ||||||
|             fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>; |  | ||||||
|             const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); |  | ||||||
|             launch_fattn<D, cols_per_block, 1> |  | ||||||
|                 (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| #endif // GGML_USE_HIP |  | ||||||
|  |  | ||||||
|     if (Q->ne[1] > 16) { |  | ||||||
|         constexpr int cols_per_block = 32; |  | ||||||
|         const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; |  | ||||||
|         fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>; |  | ||||||
|         const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); |  | ||||||
|         launch_fattn<D, cols_per_block, 1> |  | ||||||
|             (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); |  | ||||||
|         return; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     constexpr int cols_per_block = 16; |  | ||||||
|     const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; |  | ||||||
|     fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>; |  | ||||||
|     const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); |  | ||||||
|     launch_fattn<D, cols_per_block, 1> |  | ||||||
|         (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| template <bool use_logit_softcap> |  | ||||||
| static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |  | ||||||
|     const ggml_tensor * Q = dst->src[0]; |  | ||||||
|     switch (Q->ne[0]) { |  | ||||||
|         case  64: { |         case  64: { | ||||||
|             launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst); |             GGML_ASSERT(V->ne[0] == K->ne[0]); | ||||||
|  |             ggml_cuda_flash_attn_ext_tile_case< 64,  64>(ctx, dst); | ||||||
|  |         } break; | ||||||
|  |         case  80: { | ||||||
|  |             GGML_ASSERT(V->ne[0] == K->ne[0]); | ||||||
|  |             ggml_cuda_flash_attn_ext_tile_case< 80,  80>(ctx, dst); | ||||||
|  |         } break; | ||||||
|  |         case  96: { | ||||||
|  |             GGML_ASSERT(V->ne[0] == K->ne[0]); | ||||||
|  |             ggml_cuda_flash_attn_ext_tile_case< 96,  96>(ctx, dst); | ||||||
|  |         } break; | ||||||
|  |         case 112: { | ||||||
|  |             GGML_ASSERT(V->ne[0] == K->ne[0]); | ||||||
|  |             ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst); | ||||||
|         } break; |         } break; | ||||||
|         case 128: { |         case 128: { | ||||||
|             launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst); |             GGML_ASSERT(V->ne[0] == K->ne[0]); | ||||||
|  |             ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst); | ||||||
|         } break; |         } break; | ||||||
|         case 256: { |         case 256: { | ||||||
|             launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst); |             GGML_ASSERT(V->ne[0] == K->ne[0]); | ||||||
|  |             ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); | ||||||
|  |         } break; | ||||||
|  |         case 576: { | ||||||
|  |             GGML_ASSERT(V->ne[0] == 512); | ||||||
|  |             ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); | ||||||
|         } break; |         } break; | ||||||
|         default: { |         default: { | ||||||
|             GGML_ABORT("Unsupported head size"); |             GGML_ABORT("Unsupported head size"); | ||||||
|         } break; |         } break; | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
| void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |  | ||||||
|     const ggml_tensor * KQV = dst; |  | ||||||
|  |  | ||||||
|     float logit_softcap; |  | ||||||
|     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); |  | ||||||
|  |  | ||||||
|     if (logit_softcap == 0.0f) { |  | ||||||
|         constexpr bool use_logit_softcap = false; |  | ||||||
|         launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst); |  | ||||||
|     } else { |  | ||||||
|         constexpr bool use_logit_softcap = true; |  | ||||||
|         launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|   | |||||||
										
											
File diff suppressed because it is too large.
											
										
									
								
							| @@ -1,3 +1,5 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
| #include "common.cuh" | #include "common.cuh" | ||||||
|  |  | ||||||
| #if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) | #if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) | ||||||
|   | |||||||
| @@ -198,6 +198,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const | |||||||
|     return BEST_FATTN_KERNEL_NONE; |     return BEST_FATTN_KERNEL_NONE; | ||||||
| #endif// FLASH_ATTN_AVAILABLE | #endif// FLASH_ATTN_AVAILABLE | ||||||
|  |  | ||||||
|  |     const ggml_tensor * KQV   = dst; | ||||||
|     const ggml_tensor * Q     = dst->src[0]; |     const ggml_tensor * Q     = dst->src[0]; | ||||||
|     const ggml_tensor * K     = dst->src[1]; |     const ggml_tensor * K     = dst->src[1]; | ||||||
|     const ggml_tensor * V     = dst->src[2]; |     const ggml_tensor * V     = dst->src[2]; | ||||||
| @@ -206,37 +207,32 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const | |||||||
|     const int gqa_ratio = Q->ne[2] / K->ne[2]; |     const int gqa_ratio = Q->ne[2] / K->ne[2]; | ||||||
|     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); |     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); | ||||||
|  |  | ||||||
|  |     float max_bias = 0.0f; | ||||||
|  |     memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); | ||||||
|  |  | ||||||
|  |     // The effective batch size for the kernel can be increased by gqa_ratio. | ||||||
|  |     // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded. | ||||||
|  |     const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; | ||||||
|  |  | ||||||
|     const int cc = ggml_cuda_info().devices[device].cc; |     const int cc = ggml_cuda_info().devices[device].cc; | ||||||
|  |  | ||||||
|     // TODO: temporary until support is extended |  | ||||||
|     //       https://github.com/ggml-org/llama.cpp/pull/16148#issuecomment-3343525206 |  | ||||||
|     if (K->ne[1] % FATTN_KQ_STRIDE != 0) { |  | ||||||
|         return BEST_FATTN_KERNEL_NONE; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     switch (K->ne[0]) { |     switch (K->ne[0]) { | ||||||
|  |         case  40: | ||||||
|         case  64: |         case  64: | ||||||
|         case 128: |  | ||||||
|         case 256: |  | ||||||
|             if (V->ne[0] != K->ne[0]) { |  | ||||||
|                 return BEST_FATTN_KERNEL_NONE; |  | ||||||
|             } |  | ||||||
|             break; |  | ||||||
|         case  80: |         case  80: | ||||||
|         case  96: |         case  96: | ||||||
|  |         case 128: | ||||||
|         case 112: |         case 112: | ||||||
|  |         case 256: | ||||||
|             if (V->ne[0] != K->ne[0]) { |             if (V->ne[0] != K->ne[0]) { | ||||||
|                 return BEST_FATTN_KERNEL_NONE; |                 return BEST_FATTN_KERNEL_NONE; | ||||||
|             } |             } | ||||||
|             if (!ggml_cuda_should_use_wmma_fattn(cc) && !turing_mma_available(cc)) { |  | ||||||
|                 return BEST_FATTN_KERNEL_NONE; |  | ||||||
|             } |  | ||||||
|             break; |             break; | ||||||
|         case 576: |         case 576: | ||||||
|             if (V->ne[0] != 512) { |             if (V->ne[0] != 512) { | ||||||
|                 return BEST_FATTN_KERNEL_NONE; |                 return BEST_FATTN_KERNEL_NONE; | ||||||
|             } |             } | ||||||
|             if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) { |             if (!gqa_opt_applies || gqa_ratio % 16 != 0) { | ||||||
|                 return BEST_FATTN_KERNEL_NONE; |                 return BEST_FATTN_KERNEL_NONE; | ||||||
|             } |             } | ||||||
|             break; |             break; | ||||||
| @@ -270,47 +266,57 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const | |||||||
|         return BEST_FATTN_KERNEL_NONE; |         return BEST_FATTN_KERNEL_NONE; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0; |     // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: | ||||||
|  |     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; | ||||||
|     // If Turing tensor cores available, use them except for some cases with batch size 1: |  | ||||||
|     if (turing_mma_available(cc)) { |  | ||||||
|         best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16; |  | ||||||
|  |  | ||||||
|  |     // If Turing tensor cores available, use them: | ||||||
|  |     if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) { | ||||||
|         if (can_use_vector_kernel) { |         if (can_use_vector_kernel) { | ||||||
|             if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { |             if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { | ||||||
|                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { |                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { | ||||||
|                     best = BEST_FATTN_KERNEL_VEC; |                     return BEST_FATTN_KERNEL_VEC; | ||||||
|                 } |                 } | ||||||
|             } else { |             } else { | ||||||
|                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { |                 if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { | ||||||
|                     if (Q->ne[1] <= 2) { |                     if (Q->ne[1] <= 2) { | ||||||
|                         best = BEST_FATTN_KERNEL_VEC; |                         return BEST_FATTN_KERNEL_VEC; | ||||||
|                     } |                     } | ||||||
|                 } else { |                 } else { | ||||||
|                     if (Q->ne[1] == 1) { |                     if (Q->ne[1] == 1) { | ||||||
|                         best = BEST_FATTN_KERNEL_VEC; |                         return BEST_FATTN_KERNEL_VEC; | ||||||
|                     } |                     } | ||||||
|                 } |                 } | ||||||
|             } |             } | ||||||
|             if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) { |             if (!gqa_opt_applies && Q->ne[1] == 1) { | ||||||
|                 best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply. |                 return BEST_FATTN_KERNEL_VEC; | ||||||
|             } |             } | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         return best; |         return BEST_FATTN_KERNEL_MMA_F16; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // Use kernels specialized for small batch sizes if possible: |     // Use the WMMA kernel if possible: | ||||||
|     if (Q->ne[1] <= 8 && can_use_vector_kernel) { |     if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) { | ||||||
|         return BEST_FATTN_KERNEL_VEC; |         if (can_use_vector_kernel && Q->ne[1] <= 2) { | ||||||
|     } |             return BEST_FATTN_KERNEL_VEC; | ||||||
|  |         } | ||||||
|     // For large batch sizes, use the WMMA kernel if possible: |  | ||||||
|     if (ggml_cuda_should_use_wmma_fattn(cc)) { |  | ||||||
|         return BEST_FATTN_KERNEL_WMMA_F16; |         return BEST_FATTN_KERNEL_WMMA_F16; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes: |     // If there are no tensor cores available, use the generic tile kernel: | ||||||
|  |     if (can_use_vector_kernel) { | ||||||
|  |         if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { | ||||||
|  |             if (Q->ne[1] == 1) { | ||||||
|  |                 if (!gqa_opt_applies) { | ||||||
|  |                     return BEST_FATTN_KERNEL_VEC; | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } else { | ||||||
|  |             if (Q->ne[1] <= 2) { | ||||||
|  |                 return BEST_FATTN_KERNEL_VEC; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |     } | ||||||
|     return BEST_FATTN_KERNEL_TILE; |     return BEST_FATTN_KERNEL_TILE; | ||||||
| } | } | ||||||
|  |  | ||||||
|   | |||||||
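Reduced to its main conditions, the new kernel selection reads roughly as below. This is a simplified restatement for orientation, not the actual ggml_cuda_get_best_fattn_kernel: the real vector-kernel checks additionally depend on the K/V types, gqa_opt_applies, batch-size thresholds and compute capability.

enum best_fattn_kernel_sketch { KERNEL_VEC, KERNEL_TILE, KERNEL_WMMA_F16, KERNEL_MMA_F16 };

// Simplified stand-in for the decision tree after this change:
static best_fattn_kernel_sketch pick_kernel(const bool turing_mma, const bool wmma_usable,
                                            const bool kv_padded, const int head_size_k,
                                            const bool small_batch_prefers_vec) {
    if (turing_mma && kv_padded && head_size_k != 40) {
        return small_batch_prefers_vec ? KERNEL_VEC : KERNEL_MMA_F16;  // tensor-core MMA path
    }
    if (wmma_usable && kv_padded && head_size_k != 40 && head_size_k != 576) {
        return small_batch_prefers_vec ? KERNEL_VEC : KERNEL_WMMA_F16; // Volta-era WMMA path
    }
    // No usable tensor-core kernel: fall back to the tile kernel, which after this change
    // also covers head size 40 and KV caches that are not padded to FATTN_KQ_STRIDE.
    return small_batch_prefers_vec ? KERNEL_VEC : KERNEL_TILE;
}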
| @@ -0,0 +1,5 @@ | |||||||
|  | // This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
|  | #include "../fattn-tile.cuh" | ||||||
|  |  | ||||||
|  | DECL_FATTN_TILE_CASE(112, 112); | ||||||
| @@ -0,0 +1,5 @@ | |||||||
|  | // This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
|  | #include "../fattn-tile.cuh" | ||||||
|  |  | ||||||
|  | DECL_FATTN_TILE_CASE(128, 128); | ||||||
| @@ -0,0 +1,5 @@ | |||||||
|  | // This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
|  | #include "../fattn-tile.cuh" | ||||||
|  |  | ||||||
|  | DECL_FATTN_TILE_CASE(256, 256); | ||||||
| @@ -0,0 +1,5 @@ | |||||||
|  | // This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
|  | #include "../fattn-tile.cuh" | ||||||
|  |  | ||||||
|  | DECL_FATTN_TILE_CASE(40, 40); | ||||||
| @@ -0,0 +1,5 @@ | |||||||
|  | // This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
|  | #include "../fattn-tile.cuh" | ||||||
|  |  | ||||||
|  | DECL_FATTN_TILE_CASE(576, 512); | ||||||
| @@ -0,0 +1,5 @@ | |||||||
|  | // This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
|  | #include "../fattn-tile.cuh" | ||||||
|  |  | ||||||
|  | DECL_FATTN_TILE_CASE(64, 64); | ||||||
| @@ -0,0 +1,5 @@ | |||||||
|  | // This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
|  | #include "../fattn-tile.cuh" | ||||||
|  |  | ||||||
|  | DECL_FATTN_TILE_CASE(80, 80); | ||||||
| @@ -0,0 +1,5 @@ | |||||||
|  | // This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
|  | #include "../fattn-tile.cuh" | ||||||
|  |  | ||||||
|  | DECL_FATTN_TILE_CASE(96, 96); | ||||||
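Each of the new files above expands DECL_FATTN_TILE_CASE for one (K, V) head-size pair so that a single template instantiation of ggml_cuda_flash_attn_ext_tile_case is compiled per translation unit. The macro itself is not visible on this page (the large tile-header diff is suppressed above); assuming it follows the same explicit-instantiation pattern as the existing fattn-vec and fattn-mma instance headers, it plausibly reduces to something like:

// Hypothetical reconstruction for illustration, not copied from the suppressed header:
#define DECL_FATTN_TILE_CASE(DKQ, DV)                          \
    template void ggml_cuda_flash_attn_ext_tile_case<DKQ, DV>( \
        ggml_backend_cuda_context & ctx, ggml_tensor * dst)

The dispatcher in the new fattn-tile.cu then only needs to switch on K->ne[0] and call the matching instantiation.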
| @@ -3,8 +3,17 @@ | |||||||
| from glob import glob | from glob import glob | ||||||
| import os | import os | ||||||
|  |  | ||||||
|  | HEAD_SIZES_KQ = [40, 64, 80, 96, 112, 128, 256, 576] | ||||||
|  |  | ||||||
| TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] | TYPES_KV = ["GGML_TYPE_F16", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0"] | ||||||
|  |  | ||||||
|  | SOURCE_FATTN_TILE = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
|  | #include "../fattn-tile.cuh" | ||||||
|  |  | ||||||
|  | DECL_FATTN_TILE_CASE({head_size_kq}, {head_size_v}); | ||||||
|  | """ | ||||||
|  |  | ||||||
| SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. | SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
| #include "../fattn-vec.cuh" | #include "../fattn-vec.cuh" | ||||||
| @@ -51,6 +60,11 @@ def get_short_name(long_quant_name): | |||||||
| for filename in glob("*.cu"): | for filename in glob("*.cu"): | ||||||
|     os.remove(filename) |     os.remove(filename) | ||||||
|  |  | ||||||
|  | for head_size_kq in HEAD_SIZES_KQ: | ||||||
|  |     head_size_v = head_size_kq if head_size_kq != 576 else 512 | ||||||
|  |     with open(f"fattn-tile-instance-dkq{head_size_kq}-dv{head_size_v}.cu", "w") as f: | ||||||
|  |         f.write(SOURCE_FATTN_TILE.format(head_size_kq=head_size_kq, head_size_v=head_size_v)) | ||||||
|  |  | ||||||
| for type_k in TYPES_KV: | for type_k in TYPES_KV: | ||||||
|     for type_v in TYPES_KV: |     for type_v in TYPES_KV: | ||||||
|         with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: |         with open(f"fattn-vec-instance-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: | ||||||
| @@ -64,7 +78,9 @@ for ncols in [8, 16, 32, 64]: | |||||||
|         with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f: |         with open(f"fattn-mma-f16-instance-ncols1_{ncols1}-ncols2_{ncols2}.cu", "w") as f: | ||||||
|             f.write(SOURCE_FATTN_MMA_START) |             f.write(SOURCE_FATTN_MMA_START) | ||||||
|  |  | ||||||
|             for head_size_kq in [64, 80, 96, 112, 128, 256, 576]: |             for head_size_kq in HEAD_SIZES_KQ: | ||||||
|  |                 if head_size_kq == 40: | ||||||
|  |                     continue | ||||||
|                 if head_size_kq != 576 and ncols2 == 16: |                 if head_size_kq != 576 and ncols2 == 16: | ||||||
|                     continue |                     continue | ||||||
|                 if head_size_kq == 576 and ncols2 != 16: |                 if head_size_kq == 576 and ncols2 != 16: | ||||||
|   | |||||||
| @@ -53,6 +53,8 @@ file(GLOB   GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") | |||||||
| list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") | list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") | ||||||
|  |  | ||||||
| file(GLOB   GGML_SOURCES_ROCM "../ggml-cuda/*.cu") | file(GLOB   GGML_SOURCES_ROCM "../ggml-cuda/*.cu") | ||||||
|  | file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-tile*.cu") | ||||||
|  | list(APPEND GGML_SOURCES_ROCM ${SRCS}) | ||||||
| file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") | file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") | ||||||
| list(APPEND GGML_SOURCES_ROCM ${SRCS}) | list(APPEND GGML_SOURCES_ROCM ${SRCS}) | ||||||
| file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu") | file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu") | ||||||
|   | |||||||
| @@ -30,6 +30,8 @@ if (MUSAToolkit_FOUND) | |||||||
|     list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh") |     list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh") | ||||||
|  |  | ||||||
|     file(GLOB   GGML_SOURCES_MUSA "../ggml-cuda/*.cu") |     file(GLOB   GGML_SOURCES_MUSA "../ggml-cuda/*.cu") | ||||||
|  |     file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-tile*.cu") | ||||||
|  |     list(APPEND GGML_SOURCES_MUSA ${SRCS}) | ||||||
|     file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") |     file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") | ||||||
|     list(APPEND GGML_SOURCES_MUSA ${SRCS}) |     list(APPEND GGML_SOURCES_MUSA ${SRCS}) | ||||||
|     file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu") |     file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu") | ||||||
|   | |||||||