Mirror of https://github.com/ggml-org/llama.cpp.git, synced 2025-10-30 08:42:00 +00:00

Commit: CUDA: fix overflow in FA, tune performance (#14840)
Author: Johannes Gäßler
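The overflow fixed here lives in the index arithmetic of the FlashAttention kernels: offsets such as k_VKQ_0*stride_K multiply two 32-bit ints, so with a long enough KV cache the product wraps before it ever reaches the 64-bit pointer addition. The recurring pattern in the hunks below is to promote one operand with int64_t(...). A minimal sketch of the bug class (the function itself is hypothetical; only the variable names mirror the diff):

// Illustrates the 32-bit index overflow this commit fixes.
#include <cstdint>
#include <cuda_fp16.h>

__device__ const half2 * tile_base(const half2 * K_h2, const int k_VKQ_0, const int stride_K) {
    // Bad: int*int is evaluated in 32 bits; once k_VKQ_0*stride_K >= 2^31
    // the product wraps negative and the pointer goes out of bounds long
    // before the 64-bit pointer addition happens:
    //return K_h2 + k_VKQ_0*stride_K;

    // Good: promoting one operand forces a 64-bit multiply.
    return K_h2 + int64_t(k_VKQ_0)*stride_K;
}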
@@ -23,33 +23,13 @@ typedef void (* fattn_kernel_t)(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3);
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33);
 
 typedef half (*vec_dot_KQ_f16_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);
@@ -892,14 +872,11 @@ void launch_fattn(
         mask ? ((const char *) mask->data) : nullptr,
         !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
-        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
-        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
-        Q->nb[1], Q->nb[2], Q->nb[3],
-        nb11, nb12, nb13,
+        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3],
+        K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13,
         nb21, nb22, nb23,
-        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+        mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
+        mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0
     );
     CUDA_CHECK(cudaGetLastError());
 
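Two things happen in this first file: the per-kernel parameter list collapses into one shared layout, grouped per tensor (Q dims and strides, then K/V, then mask), with the unused output dims ne0..ne3 dropped entirely; and only the outermost strides (nb13, nb23, nb33) widen to int64_t, presumably because the per-sequence byte stride is the first to outgrow 32 bits in practice. A sketch of why that one stride overflows first (the shape numbers are illustrative, not from the commit):

// ggml stride convention: nb[i] = bytes to step one index along dim i.
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t head_dim = 128, n_ctx = 262144, n_head = 8;
    const int64_t nb10 = 2;              // f16 element
    const int64_t nb11 = nb10*head_dim;  // 256 bytes per token row
    const int64_t nb12 = nb11*n_ctx;     // 64 MiB per head: still < 2^31
    const int64_t nb13 = nb12*n_head;    // 512 MiB per sequence: a few
                                         // more heads or tokens crosses 2^31
    printf("nb11=%lld nb12=%lld nb13=%lld int32 max=%d\n",
           (long long) nb11, (long long) nb12, (long long) nb13, INT32_MAX);
    return 0;
}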
---
@@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         const int stride_K,
         const int stride_V,
         const int stride_mask,
-        const int jt,
         half2        * const __restrict__ tile_Q,
         half2        * const __restrict__ tile_K,
         half2        * const __restrict__ tile_V,
@@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         cp_async_wait_all();
         __syncthreads();
         flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-            (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
+            (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V);
     } else {
         constexpr bool use_cp_async = nstages == 1;
         if (ncols2 > 1 || mask_h2) {
@@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         if (nstages <= 1) {
             constexpr bool use_cp_async = nstages == 1;
             flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-                (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K);
+                (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K);
             if (use_cp_async) {
                 cp_async_wait_all();
             }
@@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
                     (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask);
             }
             flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-                (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
+                (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K);
         }
     }
 
@@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
         if (nstages <= 1 && i0_start < reusable_cutoff) {
             constexpr bool use_cp_async = nstages == 1;
             flash_attn_ext_f16_load_tile<stride_tile_V, nwarps, c::nbatch_fa, use_cp_async>
-                (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
+                (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V);
             if (use_cp_async) {
                 cp_async_wait_all();
             }
@@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
     GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
     GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
     GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
-    GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
+    GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K);
     GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
     GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
     GGML_UNUSED(kb0); GGML_UNUSED(tile_Q);
@@ -920,7 +918,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
                 (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask);
         }
         flash_attn_ext_f16_load_tile<stride_tile_K, nwarps, c::nbatch_fa, use_cp_async>
-            (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
+            (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K);
     }
 
     // Iterate over ne11 == previous tokens:
@@ -928,13 +926,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
         constexpr bool last_iter = false;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0);
     }
     { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally.
         constexpr bool last_iter = true;
         flash_attn_ext_f16_iter<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup, last_iter>
             (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap,
-             ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
+             ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1);
     }
 
     // With multi-stage loading there is no __syncthreads at the end of the iter,
@@ -1214,33 +1212,13 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
@@ -1359,8 +1337,7 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
-    GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb22); GGML_UNUSED(nb23);
    NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
 }
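Besides the int64_t promotions, this file drops the dead jt parameter from flash_attn_ext_f16_iter (it also removes an accidentally duplicated GGML_UNUSED line). The two call sites show a pattern worth naming: the loop over previous tokens is peeled so that last_iter is a template parameter instead of a per-iteration runtime branch. A condensed sketch of that shape (iter_step is a hypothetical stand-in for flash_attn_ext_f16_iter, which takes many more arguments):

template <bool last_iter>
__device__ void iter_step(const int kb0) {
    // ... load the K/V tile for kb0, accumulate KQ, rescale; if last_iter,
    // also run the epilogue work that is only needed once ...
}

__device__ void process_tile(const int kb0_start, const int kb0_stop) {
    for (int kb0 = kb0_start; kb0 < kb0_stop - 1; ++kb0) {
        iter_step<false>(kb0); // hot path: no end-of-loop checks compiled in
    }
    // kb0_start < kb0_stop always holds, so the peeled final iteration can
    // run unconditionally, with last_iter known at compile time.
    iter_step<true>(kb0_stop - 1);
}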
---
@@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
@@ -127,7 +107,7 @@ static __global__ void flash_attn_tile_ext_f16(
             for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                 const int k_KQ = k_KQ_0 + threadIdx.x;
 
-                KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
+                KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
             }
         }
 
@@ -221,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f16(
             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
-                KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i];
+                KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
             }
         }
 
@@ -300,8 +280,7 @@ static __global__ void flash_attn_tile_ext_f16(
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
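The tile kernels stage K and V through a shared-memory scratch buffer, so the overflow fix lands in the row index of each cooperative load. A condensed sketch of the K-tile load above (in the real kernel D is a template parameter and KV_tmp is declared __shared__; this helper is hypothetical):

#include <cstdint>
#include <cuda_fp16.h>

template <int D, int WARP_SIZE>
__device__ void load_k_tile(half2 (*KV_tmp)[D/2], const half2 * K_h2,
                            const int k_VKQ_0, const int i_KQ, const int stride_KV2) {
#pragma unroll
    for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
        const int k_KQ = k_KQ_0 + threadIdx.x; // each lane loads a strided column
        // int64_t() keeps the row offset from wrapping for long contexts:
        KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
    }
}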
---
@@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f32(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
@@ -66,8 +46,7 @@ static __global__ void flash_attn_tile_ext_f32(
         GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
         GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
         GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-        GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+        GGML_UNUSED(nb23);
         NO_DEVICE_CODE;
         return;
     }
@@ -135,7 +114,7 @@ static __global__ void flash_attn_tile_ext_f32(
 
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
-                const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
+                const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
                 KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] =  __low2float(tmp);
                 KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
             }
@@ -231,8 +210,9 @@ static __global__ void flash_attn_tile_ext_f32(
             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                 const int i = i0 + threadIdx.x;
 
-                KV_tmp2[k*(D/2) + i].x =  __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
-                KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]);
+                const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
+                KV_tmp2[k*(D/2) + i].x =  __low2float(tmp);
+                KV_tmp2[k*(D/2) + i].y = __high2float(tmp);
             }
         }
 
@@ -312,7 +292,6 @@ static __global__ void flash_attn_tile_ext_f32(
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
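Besides the index widening, the f32 tile kernel's V load is tightened: the old code indexed the same half2 in global memory twice (once for .x, once for .y); it is now read once into a register and split. A sketch of the resulting pattern (split_half2 is a hypothetical helper, not a function in the commit):

#include <cstdint>
#include <cuda_fp16.h>

__device__ float2 split_half2(const half2 * V_h2, const int64_t idx) {
    const half2 tmp = V_h2[idx]; // single global load instead of two
    return make_float2(__low2float(tmp), __high2float(tmp));
}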
---
@@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 
     // Skip unused kernel variants for faster compilation:
@@ -191,13 +171,16 @@ static __global__ void flash_attn_vec_ext_f16(
 
     half2 VKQ[ncols] = {{0.0f, 0.0f}};
 
+    K     += blockIdx.y*D * nb11;
+    V     += blockIdx.y*D * nb21;
+    maskh += blockIdx.y*D;
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         if (mask) {
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid];
+                maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid];
             }
 
             __syncthreads();
@@ -244,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16(
 
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
+                half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]);
                 sum = warp_reduce_sum((float)sum);
 
                 if (use_logit_softcap) {
@@ -300,14 +283,18 @@ static __global__ void flash_attn_vec_ext_f16(
             }
 
             half2 V_k;
-            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid);
-            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid);
+            reinterpret_cast<half&>(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid);
+            reinterpret_cast<half&>(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid);
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 VKQ[j] += V_k*KQ2[j*(D/2) + k0/2];
             }
         }
 
+        K     += gridDim.y*D * nb11;
+        V     += gridDim.y*D * nb21;
+        maskh += gridDim.y*D;
+
         __syncthreads();
     }
 
@@ -351,8 +338,7 @@ static __global__ void flash_attn_vec_ext_f16(
     GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
     GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-    GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-    GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+    GGML_UNUSED(nb23);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
 }
---
@@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f32(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #ifdef FLASH_ATTN_AVAILABLE
 
     // Skip unused kernel variants for faster compilation:
@@ -59,8 +39,7 @@ static __global__ void flash_attn_vec_ext_f32(
         GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
         GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
         GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
-        GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
-        GGML_UNUSED(ne2); GGML_UNUSED(ne3);
+        GGML_UNUSED(nb23);
         NO_DEVICE_CODE;
         return;
     }
@@ -198,13 +177,16 @@ static __global__ void flash_attn_vec_ext_f32(
 
     float VKQ[ncols] = {0.0f};
 
+    K     += blockIdx.y*D * nb11;
+    V     += blockIdx.y*D * nb21;
+    maskh += blockIdx.y*D;
     for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
 
         if (mask) {
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]);
+                maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]);
             }
 
             __syncthreads();
@@ -246,7 +228,7 @@ static __global__ void flash_attn_vec_ext_f32(
 
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
-                float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
+                float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]);
                 sum = warp_reduce_sum(sum);
 
                 if (use_logit_softcap) {
@@ -297,13 +279,17 @@ static __global__ void flash_attn_vec_ext_f32(
                 break;
             }
 
-            const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid);
+            const float V_ki = dequantize_1_v(V + k*nb21, tid);
 #pragma unroll
             for (int j = 0; j < ncols; ++j) {
                 VKQ[j] += V_ki*KQ[j*D + k];
             }
         }
 
+        K     += gridDim.y*D * nb11;
+        V     += gridDim.y*D * nb21;
+        maskh += gridDim.y*D;
+
         __syncthreads();
     }
 
@@ -348,7 +334,6 @@ static __global__ void flash_attn_vec_ext_f32(
     GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
     GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // FLASH_ATTN_AVAILABLE
 }
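Both vec kernels (f16 above and f32 here) get the same restructuring: instead of folding the absolute position k_VKQ_0 into every K, V, and mask access, the base pointers are advanced once per grid-stride step, so the offsets that remain inside the loop are bounded by D and stay safely in 32-bit range. A condensed sketch of the shape of that change (walk_K and use_row are hypothetical; the real kernels advance V and the mask pointer the same way, and the explicit int64_t casts are added here for clarity):

#include <cstdint>

__device__ void use_row(const char * k_row) {
    // ... vec_dot_KQ / dequantize work on one K row ...
}

__device__ void walk_K(const char * K, const int nb11, const int ne11, const int D) {
    K += int64_t(blockIdx.y)*D * nb11;       // this block's first tile
    for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
        for (int i = 0; i < D; ++i) {
            use_row(K + i*nb11);             // small, overflow-safe offset
        }
        K += int64_t(gridDim.y)*D * nb11;    // advance to the next tile
    }
}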
---
@@ -37,33 +37,13 @@ static __global__ void flash_attn_ext_f16(
         const float m1,
         const uint32_t n_head_log2,
         const float logit_softcap,
-        const int ne00,
-        const int ne01,
-        const int ne02,
-        const int ne03,
-        const int ne10,
-        const int ne11,
-        const int ne12,
-        const int ne13,
-        const int ne31,
-        const int ne32,
-        const int ne33,
-        const int nb31,
-        const int nb32,
-        const int nb33,
-        const int nb01,
-        const int nb02,
-        const int nb03,
-        const int nb11,
-        const int nb12,
-        const int nb13,
-        const int nb21,
-        const int nb22,
-        const int nb23,
-        const int ne0,
-        const int ne1,
-        const int ne2,
-        const int ne3) {
+        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
+                            const int32_t nb01, const int32_t nb02, const int32_t nb03,
+        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
+                            const int32_t nb11, const int32_t nb12, const int64_t nb13,
+                            const int32_t nb21, const int32_t nb22, const int64_t nb23,
+                            const int32_t ne31, const int32_t ne32, const int32_t ne33,
+                            const int32_t nb31, const int32_t nb32, const int64_t nb33) {
 #if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
     // Skip unused kernel variants for faster compilation:
     if (use_logit_softcap && !(D == 128 || D == 256)) {
@@ -197,7 +177,7 @@ static __global__ void flash_attn_ext_f16(
 #pragma unroll
             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
                 frag_a_K K_a;
-                wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
+                wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@@ -344,7 +324,7 @@ static __global__ void flash_attn_ext_f16(
                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
 
                 frag_a_V v_a;
-                wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
+                wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
 #pragma unroll
                 for (int j = 0; j < ncols/frag_n; ++j) {
                     wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@@ -451,7 +431,6 @@ static __global__ void flash_attn_ext_f16(
     GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
     GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
     GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
-    GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
     NO_DEVICE_CODE;
 #endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
 }
---
@@ -280,22 +280,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
-    if (GGML_CUDA_CC_IS_AMD(cc)) {
 #if defined(GGML_HIP_ROCWMMA_FATTN)
-        if (fp16_mma_available(cc)) {
-            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
-            return;
-        }
-#endif // defined(GGML_HIP_ROCWMMA_FATTN)
-
-        // On AMD the tile kernels perform poorly, use the vec kernel instead:
-        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
-            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
-        } else {
-            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
-        }
+    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
+        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
         return;
     }
+#endif // defined(GGML_HIP_ROCWMMA_FATTN)
 
     if (!fast_fp16_available(cc)) {
         if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
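This dispatch change is the "tune performance" half of the commit title: previously any AMD device returned early, forcing the vec kernels whenever rocWMMA was unavailable; now only the rocWMMA fast path keeps its early return, and other AMD devices fall through to the same kernel selection as NVIDIA. Reduced to its control flow (a condensed excerpt of the new shape, not the full function):

#if defined(GGML_HIP_ROCWMMA_FATTN)
    if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); // rocWMMA fast path
        return;
    }
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
    // No AMD-only vec fallback anymore: AMD shares the generic
    // precision- and shape-based selection below with NVIDIA.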