	CUDA: skip masked KV slices for all FA kernels (#14924)
Author: Johannes Gäßler
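The idea behind this commit: when several sequences share one batch, or a long causal mask is used, whole FATTN_KQ_STRIDE-wide slices of the KV cache can be entirely masked out with -inf for a given tile of query rows. The mask is now scanned once up front (when the batch is large enough or there are multiple sequences), and each FlashAttention kernel only iterates up to the last slice that still contains a finite mask value. Below is a minimal CPU reference sketch of that scan, with invented names (mask_is_all_inf, kv_upper_bound, kv_stride) and toy sizes, just to illustrate the bound that the KV_max buffer stores; it is not part of the commit.

    // CPU reference of the idea: per tile of query rows, find the exclusive KV bound
    // below which at least one mask entry is still finite.
    #include <cmath>
    #include <cstdio>
    #include <vector>

    static bool mask_is_all_inf(const std::vector<float> & mask, int n_kv,
                                int row0, int nrows, int kv0, int kv_stride) {
        for (int j = row0; j < row0 + nrows; ++j) {
            for (int k = kv0; k < kv0 + kv_stride; ++k) {
                if (!std::isinf(mask[j*n_kv + k])) {
                    return false; // one unmasked position -> this slice cannot be skipped
                }
            }
        }
        return true;
    }

    // Exclusive upper bound of KV positions that actually need to be processed.
    static int kv_upper_bound(const std::vector<float> & mask, int n_kv,
                              int row0, int nrows, int kv_stride) {
        int hi = (n_kv / kv_stride) * kv_stride;
        while (hi > 0 && mask_is_all_inf(mask, n_kv, row0, nrows, hi - kv_stride, kv_stride)) {
            hi -= kv_stride; // drop fully masked slices from the back
        }
        return hi;
    }

    int main() {
        // Two query rows over 8 KV positions, causal-style mask: -inf above the diagonal.
        const int n_kv = 8, kv_stride = 4;
        std::vector<float> mask(2*n_kv, -INFINITY);
        for (int j = 0; j < 2; ++j) {
            for (int k = 0; k <= j + 1; ++k) {
                mask[j*n_kv + k] = 0.0f;
            }
        }
        // Only the first 4-wide slice has finite entries, so the bound is 4 instead of 8.
        printf("KV upper bound: %d\n", kv_upper_bound(mask, n_kv, /*row0=*/0, /*nrows=*/2, kv_stride));
        return 0;
    }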
@@ -432,6 +432,20 @@ static __global__ void reduce_rows_f32(const float * x, float * dst, const int n
     dst[row] = norm ? sum / ncols : sum;
 }
 
+template<int width = WARP_SIZE>
+static __device__ __forceinline__ int warp_reduce_all(int x) {
+#ifdef GGML_USE_HIP
+#pragma unroll
+    for (int offset = width/2; offset > 0; offset >>= 1) {
+        x = x && __shfl_xor_sync(0xffffffff, x, offset, width);
+    }
+    return x;
+#else
+    static_assert(width == WARP_SIZE, "width != WARP_SIZE not implemented");
+    return __all_sync(0xffffffff, x);
+#endif // GGML_USE_HIP
+}
+
 template<int width = WARP_SIZE>
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
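The helper added above computes a warp-wide logical AND: on NVIDIA it maps directly to __all_sync, while on HIP (where the warp size can be 64) it falls back to a shuffle-XOR reduction. The following is a small, standalone CUDA test of the NVIDIA path only; the names warp_all_demo and test_warp_all are mine and assume a warp size of 32, it is not code from the ggml sources.

    #include <cstdio>

    #define WARP_SIZE 32

    // Warp-wide "all lanes true?" vote; __all_sync already does the reduction on NVIDIA.
    static __device__ __forceinline__ int warp_all_demo(int x) {
        return __all_sync(0xffffffff, x);
    }

    static __global__ void test_warp_all(int * out) {
        // Lane 17 votes "false"; after the vote every lane of the warp sees 0.
        const int vote = threadIdx.x == 17 ? 0 : 1;
        out[threadIdx.x] = warp_all_demo(vote);
    }

    int main() {
        int * d_out;
        cudaMalloc(&d_out, WARP_SIZE*sizeof(int));
        test_warp_all<<<1, WARP_SIZE>>>(d_out);
        int h_out[WARP_SIZE];
        cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
        printf("lane 0 result: %d (expected 0)\n", h_out[0]);
        cudaFree(d_out);
        return 0;
    }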
@@ -15,6 +15,7 @@ typedef void (* fattn_kernel_t)(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -500,6 +501,55 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+template <int ncols1>
+__launch_bounds__(FATTN_KQ_STRIDE/2, 1)
+static __global__ void flash_attn_mask_to_KV_max(
+        const half2 * __restrict__ mask, int * __restrict__ KV_max, const int ne30, const int s31, const int s33) {
+    const int ne31     = gridDim.x;
+    const int tid      = threadIdx.x;
+    const int sequence = blockIdx.y;
+    const int jt       = blockIdx.x;
+
+    mask += sequence*s33 + jt*ncols1*s31;
+
+    __shared__ int buf_iw[WARP_SIZE];
+    if (tid < WARP_SIZE) {
+        buf_iw[tid] = 1;
+    }
+    __syncthreads();
+
+    int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE;
+    for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) {
+        int all_inf = 1;
+
+#pragma unroll
+        for (int j = 0; j < ncols1; ++j) {
+            const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]);
+            all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y));
+        }
+
+        all_inf = warp_reduce_all(all_inf);
+        if (tid % WARP_SIZE == 0) {
+            buf_iw[tid / WARP_SIZE] = all_inf;
+        }
+        __syncthreads();
+        all_inf = buf_iw[tid % WARP_SIZE];
+        __syncthreads();
+        all_inf = warp_reduce_all(all_inf);
+
+        if (!all_inf) {
+            KV_max_sj += FATTN_KQ_STRIDE;
+            break;
+        }
+    }
+
+    if (threadIdx.x != 0) {
+        return;
+    }
+
+    KV_max[sequence*ne31 + jt] = KV_max_sj;
+}
+
 template<int D, int ncols1, int ncols2> // D == head size
 __launch_bounds__(D, 1)
 static __global__ void flash_attn_stream_k_fixup(
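flash_attn_mask_to_KV_max launches one block of FATTN_KQ_STRIDE/2 threads per (query tile, sequence) pair and walks the mask slices from the back; a slice may only be skipped if every thread in the block sees -inf, which requires combining per-warp votes through shared memory. The sketch below isolates that block-wide "all" pattern in a self-contained CUDA program; block_all_demo and N_WARPS are my names, and __all_sync stands in for warp_reduce_all on the CUDA path.

    #include <cstdio>

    #define WARP_SIZE 32
    #define N_WARPS   4   // matches a block of FATTN_KQ_STRIDE/2 = 128 threads

    static __device__ int block_all_demo(int x) {
        __shared__ int buf[WARP_SIZE];
        const int tid = threadIdx.x;

        if (tid < WARP_SIZE) {
            buf[tid] = 1; // slots for warps that do not exist stay "true" so they cannot veto the AND
        }
        __syncthreads();

        x = __all_sync(0xffffffff, x);      // reduce within each warp
        if (tid % WARP_SIZE == 0) {
            buf[tid / WARP_SIZE] = x;       // warp leader publishes its warp's result
        }
        __syncthreads();

        x = buf[tid % WARP_SIZE];           // every thread rereads one per-warp result...
        return __all_sync(0xffffffff, x);   // ...and a second warp-level vote combines them
    }

    static __global__ void demo(int * out) {
        // Only thread 70 votes "false"; the whole block must therefore report 0.
        out[threadIdx.x] = block_all_demo(threadIdx.x == 70 ? 0 : 1);
    }

    int main() {
        int * d_out;
        cudaMalloc(&d_out, N_WARPS*WARP_SIZE*sizeof(int));
        demo<<<1, N_WARPS*WARP_SIZE>>>(d_out);
        int h0;
        cudaMemcpy(&h0, d_out, sizeof(int), cudaMemcpyDeviceToHost);
        printf("block-wide all: %d (expected 0)\n", h0);
        cudaFree(d_out);
        return 0;
    }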
@@ -711,6 +761,7 @@ void launch_fattn(
 
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
+    ggml_cuda_pool_alloc<int>    KV_max(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
@@ -779,11 +830,30 @@ void launch_fattn(
         V_data = (char *) V_f16.ptr;
     }
 
-    int parallel_blocks = 1;
-
     const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1);
     const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3];
 
+    // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped.
+    // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or
+    //     multiple sequences of possibly different lengths.
+    if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) {
+        const int s31 = mask->nb[1] / sizeof(half2);
+        const int s33 = mask->nb[3] / sizeof(half2);
+
+        const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1);
+        const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1);
+
+        const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y;
+        const int iter_k = K->ne[1] / FATTN_KQ_STRIDE;
+
+        KV_max.alloc(ne_KV_max);
+        flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>>
+            ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
+    int parallel_blocks = 1;
+
     const dim3 block_dim(warp_size, nwarps, 1);
     int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy.
     CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared));
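For a sense of scale (the numbers below are invented, not from the commit): the scan adds one extra kernel launch with ntiles_x * ne3 blocks and writes a single int per block, so its overhead is small compared with the KV slices it lets the main kernel skip.

    #include <cstdio>

    int main() {
        const int ne01 = 4096;            // Q->ne[1], query tokens (hypothetical)
        const int ne03 = 1;               // Q->ne[3], sequences
        const int ncols1 = 64;            // query columns handled per tile (hypothetical)
        const int ne11 = 8192;            // K->ne[1], KV cache length
        const int FATTN_KQ_STRIDE = 256;

        const int ntiles_x  = (ne01 + ncols1 - 1) / ncols1;   // 64 blocks in x
        const int ne_KV_max = ntiles_x * ne03;                // 64 ints allocated for KV_max
        const int iter_k    = ne11 / FATTN_KQ_STRIDE;         // up to 32 mask slices scanned per tile

        printf("ntiles_x=%d ne_KV_max=%d iter_k=%d\n", ntiles_x, ne_KV_max, iter_k);
        return 0;
    }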
@@ -870,6 +940,7 @@ void launch_fattn(
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
+        KV_max.ptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
@@ -392,7 +392,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask(
     }
 }
 
-template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles, bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
+template<int DKQ, int DV, int ncols1, int ncols2, int nwarps, int ntiles,
+    bool use_logit_softcap, bool mla, bool needs_fixup, bool is_fixup, bool last_iter>
 static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const float2 * const __restrict__ Q_f2,
         const half2  * const __restrict__ K_h2,
@@ -922,7 +923,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
     }
 
     // Iterate over ne11 == previous tokens:
-    for (int kb0 = kb0_start; kb0 < kb0_stop-1; ++kb0) {
+    int kb0 = kb0_start;
+    for (; kb0 < kb0_stop-1; ++kb0) {
         constexpr bool last_iter = false;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
@@ -932,7 +934,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr bool last_iter = true;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
 
     // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1204,6 +1206,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -1280,7 +1283,11 @@ static __global__ void flash_attn_ext_f16(
         const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
         const int kb0_start_kernel = kb0_start * kb_niter;
-        const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+        int       kb0_stop_kernel  = kb0_stop  * kb_niter;
+
+        if (KV_max) {
+            kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+        }
 
         constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
         if (kb0_start == 0) {
@@ -1321,7 +1328,11 @@ static __global__ void flash_attn_ext_f16(
     const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
 
     const int kb0_start_kernel = kb0_start * kb_niter;
-    const int kb0_stop_kernel  = kb0_stop  * kb_niter;
+    int       kb0_stop_kernel  = kb0_stop  * kb_niter;
+
+    if (KV_max) {
+        kb0_stop_kernel = min(kb0_stop_kernel, KV_max[sequence*iter_j + jt] / c::nbatch_fa);
+    }
 
     constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
     constexpr bool needs_fixup = false;
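In the mma kernel the KV range is measured in inner iterations of nbatch_fa KV positions each, so the per-tile bound from KV_max is converted into an iteration count before clamping kb0_stop_kernel. A worked example with invented numbers:

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int nbatch_fa       = 64;    // hypothetical KV positions processed per inner iteration
        const int kb0_stop_naive  = 128;   // would cover 128*64 = 8192 KV positions
        const int kv_max_for_tile = 2048;  // mask scan found nothing unmasked beyond position 2048

        const int kb0_stop_kernel = std::min(kb0_stop_naive, kv_max_for_tile / nbatch_fa);
        printf("iterations: %d instead of %d\n", kb0_stop_kernel, kb0_stop_naive); // 32 instead of 128
        return 0;
    }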
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -90,7 +91,8 @@ static __global__ void flash_attn_tile_ext_f16(
 
     __syncthreads();
 
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         half kqmax_new[ncols/nwarps];
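The tile, vec, and wmma kernels all use the same pattern: read the per-tile bound (one entry per sequence and query tile) when the scan ran, otherwise fall back to the full KV length ne11. A host-side sketch of just the indexing; the layout and numbers below are illustrative only.

    #include <cstdio>

    int main() {
        const int n_tiles = 4;                 // gridDim.x, query tiles per sequence (hypothetical)
        const int ne11    = 1024;              // full KV length
        const int KV_max[2*4] = {              // two sequences, invented bounds per tile
            256, 256, 512, 512,                // sequence 0
            768, 768, 1024, 1024,              // sequence 1
        };

        const int sequence = 1, tile = 2;
        const bool have_KV_max = true;         // stands in for the KV_max != nullptr check
        const int k_VKQ_max = have_KV_max ? KV_max[sequence*n_tiles + tile] : ne11;
        printf("KV loop bound for (seq=%d, tile=%d): %d\n", sequence, tile, k_VKQ_max); // 1024
        return 0;
    }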
@@ -13,6 +13,7 @@ static __global__ void flash_attn_tile_ext_f32(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -99,7 +100,8 @@ static __global__ void flash_attn_tile_ext_f32(
 
     __syncthreads();
 
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         float kqmax_new[ncols/nwarps];
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -177,10 +178,11 @@ static __global__ void flash_attn_vec_ext_f16(
 
     half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
     K     += blockIdx.y*D * nb11;
     V     += blockIdx.y*D * nb21;
     maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
              // Increment pointers after each loop:
              K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
 
@@ -191,29 +193,7 @@ static __global__ void flash_attn_vec_ext_f16(
             for (int j = 0; j < ncols; ++j) {
                 maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
             }
-
             __syncthreads();
-
-            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
-            // In such cases, skip the KV slice.
-            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
-#ifndef GGML_USE_HIP
-            bool skip = true;
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-#pragma unroll
-                for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
-                    const int i = i0 + threadIdx.x;
-
-                    const float2 tmp = __half22float2(((const half2 *) maskh_shared)[j*(D/2) + i]);
-                    skip = skip && isinf(tmp.x) && isinf(tmp.y);
-                }
-            }
-            if (__all_sync(0xFFFFFFFF, skip)) {
-                __syncthreads();
-                continue;
-            }
-#endif // GGML_USE_HIP
         }
 
         // For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
@@ -16,6 +16,7 @@ static __global__ void flash_attn_vec_ext_f32(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -183,10 +184,11 @@ static __global__ void flash_attn_vec_ext_f32(
 
     float VKQ[ncols] = {0.0f};
 
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
    K     += blockIdx.y*D * nb11;
    V     += blockIdx.y*D * nb21;
    maskh += blockIdx.y*D;
-    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
+    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*D,
             // Increment pointers after each loop:
             K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
 
@@ -197,28 +199,7 @@ static __global__ void flash_attn_vec_ext_f32(
            for (int j = 0; j < ncols; ++j) {
                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
            }
-
            __syncthreads();
-
-            // When using multiple parallel sequences in llama.cpp, some KV slices can be fully masked out.
-            // In such cases, skip the KV slice.
-            // On AMD __all_sync would not work correctly because it assumes a warp size of 64.
-#ifndef GGML_USE_HIP
-            bool skip = true;
-#pragma unroll
-            for (int j = 0; j < ncols; ++j) {
-#pragma unroll
-                for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
-                    const int i = i0 + threadIdx.x;
-
-                    skip = skip && isinf(maskf_shared[j*D + i]);
-                }
-            }
-            if (__all_sync(0xFFFFFFFF, skip)) {
-                __syncthreads();
-                continue;
-            }
-#endif // GGML_USE_HIP
         }
 
         float kqmax_new_arr[ncols];
@@ -29,6 +29,7 @@ static __global__ void flash_attn_ext_f16(
         const char * __restrict__ K,
         const char * __restrict__ V,
         const char * __restrict__ mask,
+        const int  * __restrict__ KV_max,
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
@@ -165,7 +166,8 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();
 
     // Iterate over ne11 == previous tokens:
-    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
+    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
+    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
         // Calculate tile of KQ:
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
@@ -315,7 +315,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
     const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
-    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
+    const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies &&
+        (Q->ne[3] > 1 || cc < GGML_CUDA_CC_ADA_LOVELACE) && !mma_needs_data_conversion;
     const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
     if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
        if (prec == GGML_PREC_DEFAULT) {
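With the mask scan in place, the mma path can skip masked KV slices even at batch size 1, so the final hunk also prefers it when several sequences are decoded at once (Q->ne[3] > 1) rather than only on GPUs below Ada Lovelace. A small, hypothetical evaluation of the updated condition; the compute-capability values and flags below are placeholders, not real GGML constants.

    #include <cstdio>

    int main() {
        struct Case { bool new_mma, gqa, needs_conv; int ne3; int cc, cc_ada; };
        const Case cases[] = {
            {true, true, false, 1, 89, 89},  // Ada-class GPU, single sequence  -> vector kernel keeps priority
            {true, true, false, 4, 89, 89},  // Ada-class GPU, four sequences   -> mma now preferred
        };
        for (const Case & c : cases) {
            const bool mma_faster_for_bs1 = c.new_mma && c.gqa &&
                (c.ne3 > 1 || c.cc < c.cc_ada) && !c.needs_conv;
            printf("ne3=%d -> mma_faster_for_bs1=%d\n", c.ne3, mma_faster_for_bs1);
        }
        return 0;
    }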