	CUDA: use mma PTX instructions for FlashAttention (#11583)
* CUDA: use mma PTX instructions for FlashAttention
* __shfl_sync workaround for movmatrix
* add __shfl_sync to HIP

Co-authored-by: Diego Devesa <slarengh@gmail.com>
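For context on the second bullet: `movmatrix.sync.aligned.m8n8.trans.b16` transposes an 8x8 matrix of 16-bit values held in registers across a warp, and only exists as a PTX instruction on Turing and newer NVIDIA GPUs, so other targets need a shuffle-based fallback. Below is a minimal sketch of that idea, assuming the standard ldmatrix m8n8 fragment layout (lane 4*r + c holds elements (r, 2c) and (r, 2c+1) packed into one 32-bit register). This is an illustration of the technique, not the commit's exact code; the function name is mine.

// Illustrative emulation of movmatrix.m8n8.trans.b16 with __shfl_sync (hypothetical
// helper, not from the commit). Input: lane 4*r + c holds elements (r, 2c) and
// (r, 2c+1) of an 8x8 b16 matrix packed into one int. Output: same layout, transposed.
static __device__ __forceinline__ int movmatrix_via_shfl(const int x) {
    const int lane = threadIdx.x % 32;
    const int r_t  = lane / 4; // row this lane owns in the transposed fragment
    const int c_t  = lane % 4; // packed column pair this lane owns

    // Transposed elements (r_t, 2*c_t) and (r_t, 2*c_t + 1) are the original elements
    // (2*c_t, r_t) and (2*c_t + 1, r_t), held by these source lanes:
    const int v_lo = __shfl_sync(0xFFFFFFFF, x, 8*c_t     + r_t/2, 32);
    const int v_hi = __shfl_sync(0xFFFFFFFF, x, 8*c_t + 4 + r_t/2, 32);

    // Each source register packs two b16 values; pick the half selected by r_t's parity:
    const int lo = (r_t % 2 == 0) ? (v_lo & 0xFFFF) : ((v_lo >> 16) & 0xFFFF);
    const int hi = (r_t % 2 == 0) ? (v_hi & 0xFFFF) : ((v_hi >> 16) & 0xFFFF);

    return lo | (hi << 16);
}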
Makefile
@@ -596,7 +596,7 @@ ifdef GGML_RPC
 	OBJ_GGML_EXT += ggml/src/ggml-rpc.o
 endif # GGML_RPC
 
-OBJ_CUDA_TMPL      = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-wmma*.cu))
+OBJ_CUDA_TMPL      = $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/fattn-mma*.cu))
 OBJ_CUDA_TMPL     += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/template-instances/mmq*.cu))
 
 ifdef GGML_CUDA_FA_ALL_QUANTS

@@ -1775,7 +1775,7 @@ extern "C" {
             struct ggml_tensor  * a,
             int                   k);
 
-#define GGML_KQ_MASK_PAD 32
+#define GGML_KQ_MASK_PAD 64
 
     // q:    [n_embd, n_batch,     n_head,    1]
     // k:    [n_embd, n_kv,        n_head_kv, 1]

@@ -28,7 +28,7 @@ if (CUDAToolkit_FOUND)
     list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
 
     file(GLOB   GGML_SOURCES_CUDA "*.cu")
-    file(GLOB   SRCS "template-instances/fattn-wmma*.cu")
+    file(GLOB   SRCS "template-instances/fattn-mma*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})
     file(GLOB   SRCS "template-instances/mmq*.cu")
     list(APPEND GGML_SOURCES_CUDA ${SRCS})

@@ -148,7 +148,7 @@ typedef float2 dfloat2;
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
-#define INT8_MMA_AVAILABLE
+#define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
 #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@@ -159,11 +159,13 @@ static constexpr bool fast_fp16_available(const int cc) {
     return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
 }
 
+// Any FP16 tensor cores are available.
 static constexpr bool fp16_mma_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
 }
 
-static constexpr bool int8_mma_available(const int cc) {
+// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
+static constexpr bool new_mma_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
 }
 

@@ -516,6 +516,104 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
         nullptr;
 }
 
+template<int D, int ncols, int KQ_stride> // D == head size
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(D, 1)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_stream_k_fixup(
+        float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
+    const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
+
+    const int iter_k = ne11 / KQ_stride;
+    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+
+    const int bidx0 = blockIdx.x;
+
+    const int kbc0      = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
+
+    const bool did_not_have_any_data   = kbc0 == kbc0_stop;
+    const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
+    const bool did_not_write_last      = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
+    if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
+        return;
+    }
+
+    const int channel = kbc0 / (iter_k*iter_j);
+    const int jt      = (kbc0 - channel*iter_k*iter_j) / iter_k;
+
+    dst += jt*ncols*ne02*D + channel*D;
+
+    // Load the partial result that needs a fixup:
+    float dst_val[ncols] = {0.0f};
+    float max_val[ncols] = {0.0f};
+    float rowsum[ncols]  = {0.0f};
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            break;
+        }
+        dst_val[j] = dst[j*ne02*D + threadIdx.x];
+
+        const float2 tmp = dst_fixup[bidx0*ncols + j];
+        max_val[j] = tmp.x;
+        rowsum[j]  = tmp.y;
+    }
+
+    // Iterate over previous blocks and compute the combined results.
+    // All CUDA blocks that get here must have a previous block that needs a fixup.
+    int bidx = bidx0 - 1;
+    int kbc_stop = kbc0;
+    while(true) {
+        const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
+        if (kbc == kbc_stop) { // Did not have any data.
+            bidx--;
+            kbc_stop = kbc;
+            continue;
+        }
+
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            if (jt*ncols + j >= ne01) {
+                break;
+            }
+            const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
+
+            const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
+
+            // Scale the current and new value accumulators depending on the max. values.
+            const float max_val_new = fmaxf(max_val[j], tmp.x);
+
+            const float diff_val = max_val[j] - max_val_new;
+            const float diff_add = tmp.x      - max_val_new;
+
+            const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
+            const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
+
+            dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
+            rowsum[j]  = scale_val*rowsum[j]  + scale_add*tmp.y;
+
+            max_val[j] = max_val_new;
+        }
+
+        // If this block started in a previous tile we are done and don't need to combine additional partial results.
+        if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
+            break;
+        }
+        bidx--;
+        kbc_stop = kbc;
+    }
+
+    // Write back final result:
+#pragma unroll
+    for (int j = 0; j < ncols; ++j) {
+        if (jt*ncols + j >= ne01) {
+            return;
+        }
+        dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
+    }
+}
+
 template<int D, int parallel_blocks> // D == head size
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 __launch_bounds__(D, 1)
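The fixup kernel above applies the standard online-softmax merge: each partial result carries its row maximum and row sum, and pairs of partials are rescaled to a common maximum before their accumulators are added. A scalar sketch of the rule applied per output element (illustrative only; the kernel operates on whole rows and float2 metadata, and the struct/function names here are mine):

#include <cmath>

// Each partial result carries (value, max, rowsum) from softmax over one KV-cache slice.
struct partial_softmax { float val, max, rowsum; };

static partial_softmax merge(const partial_softmax a, const partial_softmax b) {
    const float m  = fmaxf(a.max, b.max);
    const float sa = expf(a.max - m); // the kernel additionally flushes tiny scales to
    const float sb = expf(b.max - m); // zero via SOFTMAX_FTZ_THRESHOLD; omitted here
    return { sa*a.val + sb*b.val, m, sa*a.rowsum + sb*b.rowsum };
}
// After all slices are merged the final output is val / rowsum, matching the kernel's
// write-back `dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];`.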
@@ -581,10 +679,11 @@ static void on_no_fattn_vec_case(const int D) {
     }
 }
 
-template <int D, int parallel_blocks>
+// parallel_blocks == 0 is stream-k decomposition
+template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
+    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
 ) {
     const ggml_tensor * Q = dst->src[0];
     const ggml_tensor * K = dst->src[1];
@@ -603,20 +702,23 @@ void launch_fattn(
 
     GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
 
+    GGML_ASSERT(Q->ne[3] == 1);
+
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
+    const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
     ggml_cuda_pool_alloc<half>   K_f16(pool);
     ggml_cuda_pool_alloc<half>   V_f16(pool);
     ggml_cuda_pool_alloc<float>  dst_tmp(pool);
     ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
 
-    char * K_data = (char *) K->data;
+    const char * K_data = (const char *) K->data;
     size_t nb11 = K->nb[1];
     size_t nb12 = K->nb[2];
     size_t nb13 = K->nb[3];
 
-    char * V_data = (char *) V->data;
+    const char * V_data = (const char *) V->data;
     size_t nb21 = V->nb[1];
     size_t nb22 = V->nb[2];
     size_t nb23 = V->nb[3];
@@ -649,39 +751,60 @@ void launch_fattn(
         nb23 = nb23*bs*sizeof(half)/ts;
     }
 
+    const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
+    const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
+
+    const dim3 block_dim(WARP_SIZE, nwarps, 1);
+    dim3 blocks_num;
+    if (parallel_blocks == 0) {
+        // For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
+        const int tiles_nwaves  = (ntiles_total - nsm - 1) / nsm;
+        const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
+        const bool short_context = K->ne[1] < 4096;
+
+        const int nblocks_stream_k = 2*nsm;
+
+        blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
+        blocks_num.y = 1;
+        blocks_num.z = 1;
+
+        dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
+    } else {
+        blocks_num.x = parallel_blocks*ntiles_x;
+        blocks_num.y = Q->ne[2];
+        blocks_num.z = Q->ne[3];
+
         if (parallel_blocks > 1) {
             dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
             dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
         }
+    }
 
-    const dim3 block_dim(WARP_SIZE, nwarps, 1);
-    const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
-    const int  shmem = 0;
-
     float scale         = 1.0f;
     float max_bias      = 0.0f;
     float logit_softcap = 0.0f;
 
-    memcpy(&scale,         (float *) KQV->op_params + 0, sizeof(float));
-    memcpy(&max_bias,      (float *) KQV->op_params + 1, sizeof(float));
-    memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
+    memcpy(&scale,         (const float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias,      (const float *) KQV->op_params + 1, sizeof(float));
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
 
     if (logit_softcap != 0.0f) {
         scale /= logit_softcap;
     }
 
     const uint32_t n_head      = Q->ne[2];
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+    const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
+    fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,
         V_data,
         mask ? ((const char *) mask->data) : nullptr,
-        (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
+        (parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
         scale, max_bias, m0, m1, n_head_log2, logit_softcap,
         Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
         K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@@ -693,16 +816,22 @@ void launch_fattn(
     );
     CUDA_CHECK(cudaGetLastError());
 
-    if ((parallel_blocks) == 1) {
-        return;
-    }
-
+    if constexpr (parallel_blocks == 0) {
+        if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
+            const dim3 block_dim_combine(D, 1, 1);
+            const dim3 blocks_num_combine = blocks_num;
+
+            flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
+                <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
+                ((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
+        }
+    } else if constexpr (parallel_blocks > 1) {
         const dim3 block_dim_combine(D, 1, 1);
         const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
-    const int  shmem_combine = 0;
 
         flash_attn_combine_results<D, parallel_blocks>
-        <<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
+            <<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
            (dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
+    }
    CUDA_CHECK(cudaGetLastError());
 }

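The stream-k decomposition divides the flattened space of iter_k*iter_j*ne02 K blocks evenly over a fixed grid, so tile boundaries generally fall mid-block and flash_attn_stream_k_fixup stitches the seams. A small host-side simulation of the partition arithmetic (a standalone sketch with hypothetical values, not part of the commit):

#include <cstdio>

int main() {
    const int iter_k = 8, iter_j = 3, ne02 = 4; // K blocks per tile, tiles per head, heads
    const int total  = iter_k*iter_j*ne02;
    const int grid   = 10;                      // e.g. 2*nsm CUDA blocks

    for (int b = 0; b < grid; ++b) {
        const int kbc      = (b + 0)*total / grid; // same split as the kernel
        const int kbc_stop = (b + 1)*total / grid;
        // A block writes a fixup entry iff it starts mid-tile (kbc % iter_k != 0).
        printf("block %d: kbc [%2d, %2d), starts mid-tile: %s\n",
               b, kbc, kbc_stop, kbc % iter_k != 0 ? "yes" : "no");
    }
    return 0;
}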
ggml/src/ggml-cuda/fattn-mma-f16.cuh (new file, 637 lines)

@@ -0,0 +1,637 @@
+#include "common.cuh"
+#include "mma.cuh"
+#include "fattn-common.cuh"
+
+template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
+static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
+        const float2 * const __restrict__ Q_f2,
+        const half2  * const __restrict__ K_h2,
+        const half2  * const __restrict__ V_h2,
+        const half   * const __restrict__ maskh,
+        float2       * const __restrict__ dstk,
+        float2       * const __restrict__ dstk_fixup,
+        const float scale,
+        const float slope,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3,
+        const int jt,
+        const int kb0_start,
+        const int kb0_stop) {
+#ifdef NEW_MMA_AVAILABLE
+    // In this kernel Q, K, V are matrices while i, j, k are matrix indices.
+
+    typedef mma_A_I16K8<half2> mma_A;
+    typedef mma_B_J8K8<half2>  mma_B;
+    typedef mma_C_I16J8<float> mma_C_KQ;
+    typedef mma_C_I16J8<half2> mma_C_VKQ;
+
+    static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
+    constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column.
+
+    static_assert(D         % nwarps == 0, "bad D");
+    static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
+
+    constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
+    extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements.
+
+    const int stride_Q    = nb01 / sizeof(float2);
+    const int stride_KV   = nb11 / sizeof(half2);
+    const int stride_mask = nb31 / sizeof(half);
+
+    mma_B Q_B[D/(2*mma_B::K)];
+    mma_C_VKQ VKQ_C[D/mma_C_VKQ::I];
+
+    float2    KQ_rowsum = {0.0f, 0.0f};
+    float2       KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
+    float2 KQ_max_scale = {0.0f, 0.0f};
+
+    // Temporarily load Q data into tile_KV, will be loaded into registers afterwards.
+    // The loading is done with decreasing granularity for D for better memory bandwidth.
+    const half2 scale_h2 = make_half2(scale, scale);
+#pragma unroll
+    for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+        const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
+        const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
+        const int stride_j = WARP_SIZE / stride_k;
+
+        if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
+            break;
+        }
+
+#pragma unroll
+        for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
+            const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+            if (jt*ncols + j < ne01) {
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
+                    tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
+                }
+            } else {
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
+                }
+            }
+        }
+    }
+
+    __syncthreads();
+
+    {
+        const int j0 = (threadIdx.y / np) * mma_B::J;
+
+#pragma unroll
+        for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
+            Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
+        }
+    }
+
+    __syncthreads();
+
+    // Iterate over ne11 == previous tokens:
+    for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
+        const int k_VKQ_0 = kb0*KQ_stride;
+        mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];
+
+        // Load K data into tile with decreasing granularity for D for better memory bandwidth:
+        static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
+#pragma unroll
+        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+            const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
+            const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
+            const int stride_i = WARP_SIZE / stride_k;
+
+#pragma unroll
+            for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
+                const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+
+#pragma unroll
+                for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
+                    const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
+                }
+            }
+        }
+
+        __syncthreads();
+
+        // Calculate tile of KQ:
+#pragma unroll
+        for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
+            const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
+#pragma unroll
+            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
+                mma_A K_A;
+                K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
+                KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
+            }
+        }
+
+        __syncthreads();
+
+        if (use_logit_softcap) {
+            static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+            for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
+#pragma unroll
+                for (int l = 0; l < mma_C_KQ::ne; ++l) {
+                    KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
+                }
+            }
+        }
+
+        if (maskh) {
+            static_assert(KQ_stride % (np       *mma_C_KQ::I) == 0, "bad loop size");
+            static_assert(ncols     % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
+#pragma unroll
+            for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
+                const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
+#pragma unroll
+                for (int l = 0; l < mma_C_KQ::ne; ++l) {
+                    const int i = i0 + mma_C_KQ::get_i(l);
+                    const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);
+
+                    KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
+                }
+            }
+        }
+
+        // Calculate softmax for each KQ column using the current max. value.
+        // The divisor is stored in KQ_rowsum and will be applied at the end.
+        float2 KQ_max_new = KQ_max;
+        static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
+#pragma unroll
+            for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
+                KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
+                KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
+            }
+        }
+
+        // Values per KQ column are spread across 8 threads, does not need full warp reduce:
+#pragma unroll
+        for (int offset = 16; offset > 2; offset >>= 1) {
+            KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
+            KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
+        }
+
+        {
+            const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
+            KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
+            if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
+                KQ_max_scale.x = 0.0f;
+            }
+            if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
+                KQ_max_scale.y = 0.0f;
+            }
+            KQ_max = KQ_max_new;
+        }
+
+        float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
+        static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
+#pragma unroll
+            for (int l = 0; l < mma_C_KQ::ne; ++l) {
+                const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
+                const float diff = KQ_C[k].x[l] - KQ_max_l;
+                KQ_C[k].x[l] = expf(diff);
+                if (diff <= SOFTMAX_FTZ_THRESHOLD) {
+                    KQ_C[k].x[l] = 0.0f;
+                }
+
+                if (l % 2 == 0) {
+                    KQ_rowsum_add.x += KQ_C[k].x[l];
+                } else {
+                    KQ_rowsum_add.y += KQ_C[k].x[l];
+                }
+            }
+        }
+
+        // Scale previous KQ_rowsum to account for a potential increase in KQ_max:
+        KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
+        KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
+
+        const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
+#pragma unroll
+        for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
+#pragma unroll
+            for (int l = 0; l < mma_C_VKQ::ne; ++l) {
+                VKQ_C[i].x[l] *= KQ_max_scale_h2;
+            }
+        }
+
+        // Convert KQ C tiles into B tiles for VKQ calculation:
+        mma_B B[KQ_stride/(np*2*mma_B::K)];
+        static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
+#pragma unroll
+        for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
+            B[k] = KQ_C[k].to_mma_B();
+        }
+
+        // Load V data into tile with decreasing granularity for D for better memory bandwidth:
+        static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
+#pragma unroll
+        for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+            const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
+            const int i0_stop  =                             D/2 - (D/2) % (1*stride_i);
+            const int stride_k = WARP_SIZE / stride_i;
+
+#pragma unroll
+            for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
+                const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);
+
+#pragma unroll
+                for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
+                    const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);
+
+                    tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
+                }
+            }
+        }
+
+        __syncthreads();
+
+        // Calculate VKQ tile:
+#pragma unroll
+        for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
+            static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
+#pragma unroll
+            for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
+                const int k0 = k00 + (threadIdx.y % np)*mma_A::K;
+
+                mma_A A;
+                A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
+                VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
+            }
+        }
+
+        __syncthreads();
+    }
+
+    // Finally, sum up partial KQ rowsums.
+    // The partial sums are spread across 8 threads each, does not need full reduce.
+#pragma unroll
+    for (int offset = 16; offset > 2; offset >>= 1) {
+        KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE);
+        KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE);
+    }
+
+    // Write VKQ accumulators to shared memory in column-major format.
+    // It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
+    // Also for np > 1 the combination is done via these values in shared memory.
+    const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
+#pragma unroll
+    for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
+        const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format.
+
+#pragma unroll
+        for (int l = 0; l < mma_B::ne; ++l) {
+            const int k = k0 + mma_B::get_k(l);
+
+            tile_KV[j_cwd*D2_padded + k] = B.x[l];
+        }
+    }
+
+    const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
+    const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
+    const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
+
+    if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
+        // Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
+        ((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
+    }
+
+    __syncthreads();
+
+    static_assert(np == 1 || np == 2 || np == 4, "bad np");
+    if (np == 1) {
+        // No combination is needed, the meta data can be directly written from registers to VRAM.
+        if (needs_fixup && threadIdx.x < mma_B::J) {
+            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+            dstk_fixup_meta[j_cwm] = KQ_cmr;
+        }
+        if (is_fixup && threadIdx.x < mma_B::J) {
+            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+            dstk_fixup_meta[j_cwm] = KQ_cmr;
+        }
+    } else if (threadIdx.y % np == 0) {
+        // Combine the meta data for parallel warps via shared memory.
+        // Warps with threadIdx.y % np != 0 must NOT return early.
+        // All threads must return simultaneously to avoid race conditions with work on the next tile.
+
+        float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;
+
+        float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
+        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
+            KQ_cm = meta_j[0];
+        }
+
+        float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
+#pragma unroll
+        for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
+            KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
+        }
+
+        const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
+        float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
+        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
+            KQ_crs = KQ_cms*meta_j[1];
+        }
+#pragma unroll
+        for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
+            KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
+        }
+
+        // Write back combined meta data:
+        if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
+            meta_j[0] = KQ_cmn; // Combined max. KQ values.
+            meta_j[1] = KQ_crs; // Combined KQ rowsums.
+            meta_j[2] = KQ_cms; // KQ max scales per parallel warp.
+        }
+        if (needs_fixup && threadIdx.x < mma_B::J) {
+            float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
+            dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+        }
+        if (is_fixup && threadIdx.x < mma_B::J) {
+            float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
+            dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
+        }
+    }
+
+    if (np > 1) {
+        __syncthreads();
+    }
+
+    if (np == 1 || threadIdx.y % np == 0) {
+        // The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
+        // The values after that are for the partial results of the individual blocks.
+        float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2));
+
+#pragma unroll
+        for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
+            const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
+            const int k0_stop  =                             D/2 - (D/2) % (1*stride_k);
+            const int stride_j = WARP_SIZE / stride_k;
+
+            if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
+                break;
+            }
+
+#pragma unroll
+            for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
+                const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
+                const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;
+
+                if (!is_fixup && jt*ncols + j_dst >= ne01) {
+                    continue;
+                }
+                const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
+#pragma unroll
+                for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
+                    const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
+
+                    float2 dstk_val = make_float2(0.0f, 0.0f);
+#pragma unroll
+                    for (int ip = 0; ip < np; ++ip) {
+                        const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
+                        const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
+                        dstk_val.x += dstk_val_add.x*KQ_crs;
+                        dstk_val.y += dstk_val_add.y*KQ_crs;
+                    }
+
+                    if (!needs_fixup && !is_fixup) {
+                        const float KQ_rowsum_j = meta_j[1];
+                        dstk_val.x /= KQ_rowsum_j;
+                        dstk_val.y /= KQ_rowsum_j;
+                    }
+
+                    if (is_fixup) {
+                        dstk_fixup_data[j_dst*(D/2) + k] = dstk_val;
+                    } else {
+                        dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val;
+                    }
+                }
+            }
+        }
+    }
+
+    if (np > 1) {
+        __syncthreads();
+    }
+#else
+    NO_DEVICE_CODE;
+#endif // NEW_MMA_AVAILABLE
+}
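On `D2_padded` above: each row of the shared K/V tile is padded from D/2 to D/2 + 4 half2 elements. Shared memory has 32 four-byte banks and a half2 is exactly one bank word, so with an unpadded stride of 64 half2 (D = 128) the same column of every row lands in the same bank and column-wise accesses serialize; the extra 4 half2 (16 bytes, also reused for the meta data) rotate each row by four banks. A small standalone check of that arithmetic (my illustration, not from the commit):

#include <cstdio>

int main() {
    const int D = 128;
    const int strides[2] = {D/2, D/2 + 4}; // row stride in half2 (one 4-byte bank word) units
    for (const int stride : strides) {
        printf("stride %2d half2:", stride);
        for (int row = 0; row < 8; ++row) {
            printf("  row %d -> bank %2d", row, (row*stride) % 32); // bank of column 0
        }
        printf("\n"); // stride 64: all rows hit bank 0; stride 68: banks 0,4,8,...,28
    }
    return 0;
}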
+
+template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap>
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+__launch_bounds__(nwarps*WARP_SIZE, 2)
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+static __global__ void flash_attn_ext_f16(
+        const char * __restrict__ Q,
+        const char * __restrict__ K,
+        const char * __restrict__ V,
+        const char * __restrict__ mask,
+        float      * __restrict__ dst,
+        float2     * __restrict__ dst_meta,
+        const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
+        const float logit_softcap,
+        const int ne00,
+        const int ne01,
+        const int ne02,
+        const int ne03,
+        const int ne10,
+        const int ne11,
+        const int ne12,
+        const int ne13,
+        const int ne31,
+        const int nb31,
+        const int nb01,
+        const int nb02,
+        const int nb03,
+        const int nb11,
+        const int nb12,
+        const int nb13,
+        const int nb21,
+        const int nb22,
+        const int nb23,
+        const int ne0,
+        const int ne1,
+        const int ne2,
+        const int ne3) {
+    // Skip unused kernel variants for faster compilation:
+    if (use_logit_softcap && !(D == 128 || D == 256)) {
+        NO_DEVICE_CODE;
+        return;
+    }
+
+    static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride");
+
+    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
+
+    const int iter_k = ne11 / KQ_stride;
+    const int iter_j = (ne01 + (ncols - 1)) / ncols;
+
+    // kbc == k block continuous, current index in continuous ijk space.
+    int       kbc      = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x;
+    const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x;
+
+    // If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
+    // For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
+    // In the most general case >2 seams can fall into the same tile.
+
+    // kb0 == k start index when in the output tile.
+    int kb0_start = kbc % iter_k;
+    int kb0_stop  = min(iter_k, kb0_start + kbc_stop - kbc);
+    while (kbc < kbc_stop && kb0_stop == iter_k) {
+        const int channel = kbc / (iter_k*iter_j);
+        const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+
+        const float2 * Q_f2  = (const float2 *) (Q + nb02* channel);
+        const half2  * K_h2  = (const half2  *) (K + nb12*(channel / gqa_ratio));
+        const half2  * V_h2  = (const half2  *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
+        const half   * maskh = mask ? (const half  *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
+        float2       * dstk  = ((float2 *) dst) + channel*(D/2);
+
+        const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
+
+        constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
+        if (kb0_start == 0) {
+            constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
+            flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
+                (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
+                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
+                jt, kb0_start, kb0_stop);
+        } else {
+            constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
+            flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
+                (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
+                ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
+                jt, kb0_start, kb0_stop);
+        }
+
+        kbc += iter_k;
+        kbc -= kbc % iter_k;
+
+        kb0_start = 0;
+        kb0_stop  = min(iter_k, kbc_stop - kbc);
+    }
+
+    if (kbc >= kbc_stop) {
+        return;
+    }
+
+    const int channel = kbc / (iter_k*iter_j);
+    const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
+
+    const float2 * Q_f2  = (const float2 *) (Q + nb02* channel);
+    const half2  * K_h2  = (const half2  *) (K + nb12*(channel / gqa_ratio));
+    const half2  * V_h2  = (const half2  *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
+    const half   * maskh = mask ? (const half  *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
+    float2       * dstk  = ((float2 *) dst) + channel*(D/2);
+
+    const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
+
+    constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
+    constexpr bool needs_fixup = false;
+    flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
+        (Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
+        ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
+        jt, kb0_start, kb0_stop);
+}
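The kernel above walks a continuous kbc index through the nested (head, Q-column tile, K block) space, so `channel`, `jt`, and the K-block offset within a tile fall out by integer division. A standalone decode of one index with illustrative values (my sketch, mirroring the kernel's arithmetic):

#include <cstdio>

int main() {
    const int iter_k = 8, iter_j = 3; // K blocks per tile, Q-column tiles per head
    const int kbc = 59;               // some position in the flattened space
    const int channel = kbc / (iter_k*iter_j);                  // head index -> 2
    const int jt      = (kbc - channel*iter_k*iter_j) / iter_k; // tile within head -> 1
    const int kb0     = kbc % iter_k;                           // K block within tile -> 3
    printf("kbc %d -> channel %d, jt %d, kb0 %d\n", kbc, channel, jt, kb0);
    return 0;
}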
+
+template <int D, int cols_per_block>
+void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    typedef mma_A_I16K8<half2> mma_A;
+    typedef mma_B_J8K8<half2>  mma_B;
+
+    static_assert(D              % mma_B::K == 0, "bad D");
+    static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block");
+
+    const ggml_tensor * KQV = dst;
+
+    constexpr int    KQ_stride     = D <= 128 ? 64 : 32;
+    constexpr int    nwarps        = (KQ_stride == 32 && cols_per_block <= 16) ?
+                                     cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8);
+    constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half);
+
+    float logit_softcap;
+    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
+
+    fattn_kernel_t fattn_kernel;
+    if (logit_softcap == 0.0f) {
+        constexpr bool use_logit_softcap = false;
+        fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
+    } else {
+        constexpr bool use_logit_softcap = true;
+        fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
+    }
+    launch_fattn<D, cols_per_block, 0, KQ_stride>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
+}
+
+#define DECL_FATTN_MMA_F16_CASE(D, cols_per_block)                          \
+    template void ggml_cuda_flash_attn_ext_mma_f16_case                     \
+    <D, cols_per_block>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
+
+extern DECL_FATTN_MMA_F16_CASE( 64,  8);
+extern DECL_FATTN_MMA_F16_CASE( 80,  8);
+extern DECL_FATTN_MMA_F16_CASE( 96,  8);
+extern DECL_FATTN_MMA_F16_CASE(112,  8);
+extern DECL_FATTN_MMA_F16_CASE(128,  8);
+extern DECL_FATTN_MMA_F16_CASE(256,  8);
+
+extern DECL_FATTN_MMA_F16_CASE( 64, 16);
+extern DECL_FATTN_MMA_F16_CASE( 80, 16);
+extern DECL_FATTN_MMA_F16_CASE( 96, 16);
+extern DECL_FATTN_MMA_F16_CASE(112, 16);
+extern DECL_FATTN_MMA_F16_CASE(128, 16);
+extern DECL_FATTN_MMA_F16_CASE(256, 16);
+
+extern DECL_FATTN_MMA_F16_CASE( 64, 32);
+extern DECL_FATTN_MMA_F16_CASE( 80, 32);
+extern DECL_FATTN_MMA_F16_CASE( 96, 32);
+extern DECL_FATTN_MMA_F16_CASE(112, 32);
+extern DECL_FATTN_MMA_F16_CASE(128, 32);
+extern DECL_FATTN_MMA_F16_CASE(256, 32);
+
+extern DECL_FATTN_MMA_F16_CASE( 64, 64);
+extern DECL_FATTN_MMA_F16_CASE( 80, 64);
+extern DECL_FATTN_MMA_F16_CASE( 96, 64);
+extern DECL_FATTN_MMA_F16_CASE(112, 64);
+extern DECL_FATTN_MMA_F16_CASE(128, 64);
+extern DECL_FATTN_MMA_F16_CASE(256, 64);
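As a worked check of the shared memory sizing in ggml_cuda_flash_attn_ext_mma_f16_case (my arithmetic, assuming mma_B::J == 8 per the mma_B_J8K8 type): for D = 128 and cols_per_block = 8, KQ_stride is 64 (D <= 128) and nwarps is 4 (cols_per_block <= 8), so the K/V tile spans max(64, 4*8) = 64 rows of D + 8 = 136 half values:

#include <algorithm>

// 64 rows * 136 halves * 2 bytes = 17408 bytes of dynamic shared memory.
static_assert(std::max(64, 4*8) * (128 + 8) * 2 == 17408,
              "nbytes_shared for D = 128, cols_per_block = 8");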
@@ -45,7 +45,17 @@ static __global__ void flash_attn_tile_ext_f16(
         const int ne2,
         const int ne3) {
 #ifdef FP16_AVAILABLE
+
+#ifndef FLASH_ATTN_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FLASH_ATTN_AVAILABLE
+
     // Skip unused kernel variants for faster compilation:
+#ifdef FP16_MMA_AVAILABLE
+    NO_DEVICE_CODE;
+    return;
+#endif // FP16_MMA_AVAILABLE
     if (use_logit_softcap && !(D == 128 || D == 256)) {
         NO_DEVICE_CODE;
         return;
| @@ -290,14 +300,16 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * | |||||||
|         case  64: { |         case  64: { | ||||||
|             constexpr int    D             = 64; |             constexpr int    D             = 64; | ||||||
|             constexpr int    nwarps        = 8; |             constexpr int    nwarps        = 8; | ||||||
|  |             constexpr size_t nbytes_shared = 0; | ||||||
|             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>; |             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>; | ||||||
|             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); |             launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); | ||||||
|         } break; |         } break; | ||||||
|         case 128: { |         case 128: { | ||||||
|             constexpr int    D             = 128; |             constexpr int    D             = 128; | ||||||
|             constexpr int    nwarps        = 8; |             constexpr int    nwarps        = 8; | ||||||
|  |             constexpr size_t nbytes_shared = 0; | ||||||
|             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>; |             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>; | ||||||
|             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); |             launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); | ||||||
|         } break; |         } break; | ||||||
|         default: { |         default: { | ||||||
|             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); |             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); | ||||||
|   | |||||||
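The two new guards in the tile kernel serve different purposes: the #ifndef FLASH_ATTN_AVAILABLE block compiles the whole kernel down to a stub on targets that cannot run FlashAttention at all, while the #ifdef FP16_MMA_AVAILABLE block additionally stubs out the tile variant on architectures where the tensor-core kernels are selected at runtime anyway, cutting compile time and binary size. A sketch of the pattern, assuming NO_DEVICE_CODE and the availability macros from common.cuh:

    #include "common.cuh"   // assumed to provide NO_DEVICE_CODE and *_AVAILABLE macros

    // Sketch (illustrative kernel): compile an empty body for variants that
    // can never be launched on the target architecture.
    static __global__ void tile_kernel_sketch(float * dst) {
    #ifdef FP16_MMA_AVAILABLE
        NO_DEVICE_CODE;     // tensor cores present: the MMA kernel is used instead
        return;
    #endif // FP16_MMA_AVAILABLE
        dst[threadIdx.x] = 0.0f;   // stand-in for the real tile implementation
    }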
| @@ -48,7 +48,12 @@ static __global__ void flash_attn_tile_ext_f32( | |||||||
|     NO_DEVICE_CODE; |     NO_DEVICE_CODE; | ||||||
|     return; |     return; | ||||||
| #endif // FLASH_ATTN_AVAILABLE | #endif // FLASH_ATTN_AVAILABLE | ||||||
|  |  | ||||||
|     // Skip unused kernel variants for faster compilation: |     // Skip unused kernel variants for faster compilation: | ||||||
|  | #ifdef FP16_MMA_AVAILABLE | ||||||
|  |     NO_DEVICE_CODE; | ||||||
|  |     return; | ||||||
|  | #endif // FP16_MMA_AVAILABLE | ||||||
|     if (use_logit_softcap && !(D == 128 || D == 256)) { |     if (use_logit_softcap && !(D == 128 || D == 256)) { | ||||||
|         NO_DEVICE_CODE; |         NO_DEVICE_CODE; | ||||||
|         return; |         return; | ||||||
| @@ -289,14 +294,16 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * | |||||||
|         case  64: { |         case  64: { | ||||||
|             constexpr int    D             = 64; |             constexpr int    D             = 64; | ||||||
|             constexpr int    nwarps        = 8; |             constexpr int    nwarps        = 8; | ||||||
|  |             constexpr size_t nbytes_shared = 0; | ||||||
|             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>; |             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>; | ||||||
|             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); |             launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); | ||||||
|         } break; |         } break; | ||||||
|         case 128: { |         case 128: { | ||||||
|             constexpr int    D             = 128; |             constexpr int    D             = 128; | ||||||
|             constexpr int    nwarps        = 8; |             constexpr int    nwarps        = 8; | ||||||
|  |             constexpr size_t nbytes_shared = 0; | ||||||
|             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>; |             fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>; | ||||||
|             launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); |             launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true); | ||||||
|         } break; |         } break; | ||||||
|         default: { |         default: { | ||||||
|             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); |             GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128."); | ||||||
|   | |||||||
| @@ -42,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16( | |||||||
|         const int ne2, |         const int ne2, | ||||||
|         const int ne3) { |         const int ne3) { | ||||||
| #ifdef FP16_AVAILABLE | #ifdef FP16_AVAILABLE | ||||||
|  |  | ||||||
|  | #ifndef FLASH_ATTN_AVAILABLE | ||||||
|  |     NO_DEVICE_CODE; | ||||||
|  |     return; | ||||||
|  | #endif // FLASH_ATTN_AVAILABLE | ||||||
|  |  | ||||||
|     // Skip unused kernel variants for faster compilation: |     // Skip unused kernel variants for faster compilation: | ||||||
|     if (use_logit_softcap && !(D == 128 || D == 256)) { |     if (use_logit_softcap && !(D == 128 || D == 256)) { | ||||||
|         NO_DEVICE_CODE; |         NO_DEVICE_CODE; | ||||||
| @@ -303,7 +309,8 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, | |||||||
|     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>; |     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>; | ||||||
|     constexpr bool need_f16_K = D != 128; |     constexpr bool need_f16_K = D != 128; | ||||||
|     constexpr bool need_f16_V = D != 128 && D != 64; |     constexpr bool need_f16_V = D != 128 && D != 64; | ||||||
|     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); |     constexpr size_t nbytes_shared = 0; | ||||||
|  |     launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); | ||||||
| } | } | ||||||
|  |  | ||||||
| template <int D, ggml_type type_K, ggml_type type_V> | template <int D, ggml_type type_K, ggml_type type_V> | ||||||
|   | |||||||
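The need_f16_K / need_f16_V flags passed to the launcher indicate whether K and V must first be converted to f16 in a separate pass: the vec kernels ship dedicated dequantization paths for the most common head sizes (D == 128, and additionally D == 64 for V), so for those the data can be read in its native quantized type. A compile-time sketch of the same decision (helper names hypothetical):

    // Hypothetical helpers mirroring the constexpr flags above.
    template <int D> constexpr bool need_f16_K() { return D != 128; }
    template <int D> constexpr bool need_f16_V() { return D != 128 && D != 64; }

    static_assert(!need_f16_K<128>(), "D=128 reads K in its native type");
    static_assert(!need_f16_V< 64>(), "D=64 reads V in its native type");
    static_assert( need_f16_K< 64>(), "other head sizes need an f16 copy of K");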
| @@ -41,6 +41,11 @@ static __global__ void flash_attn_vec_ext_f32( | |||||||
|         const int ne1, |         const int ne1, | ||||||
|         const int ne2, |         const int ne2, | ||||||
|         const int ne3) { |         const int ne3) { | ||||||
|  | #ifndef FLASH_ATTN_AVAILABLE | ||||||
|  |     NO_DEVICE_CODE; | ||||||
|  |     return; | ||||||
|  | #endif // FLASH_ATTN_AVAILABLE | ||||||
|  |  | ||||||
|     // Skip unused kernel variants for faster compilation: |     // Skip unused kernel variants for faster compilation: | ||||||
|     if (use_logit_softcap && !(D == 128 || D == 256)) { |     if (use_logit_softcap && !(D == 128 || D == 256)) { | ||||||
|         NO_DEVICE_CODE; |         NO_DEVICE_CODE; | ||||||
| @@ -284,7 +289,8 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, | |||||||
|     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>; |     fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>; | ||||||
|     constexpr bool need_f16_K = D != 128; |     constexpr bool need_f16_K = D != 128; | ||||||
|     constexpr bool need_f16_V = D != 128 && D != 64; |     constexpr bool need_f16_V = D != 128 && D != 64; | ||||||
|     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V); |     constexpr size_t nbytes_shared = 0; | ||||||
|  |     launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V); | ||||||
| } | } | ||||||
|  |  | ||||||
| template <int D, ggml_type type_K, ggml_type type_V> | template <int D, ggml_type type_K, ggml_type type_V> | ||||||
|   | |||||||
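Across all of these call sites, launch_fattn now takes cols_per_block and a KQ stride as template parameters (-1 where the kernel has no fixed stride) and an explicit nbytes_shared argument in place of the old runtime cols_per_block. The tile and vec kernels use only static shared memory, hence nbytes_shared = 0, while the new MMA kernel requests a real dynamic allocation. A hedged sketch of what a launcher can do with that byte count (illustrative wrapper, not the actual launch_fattn from fattn-common.cuh):

    #include <cuda_runtime.h>

    // Illustrative: launch a kernel with an explicit dynamic shared memory size.
    static void launch_with_shared(void (*kernel)(), dim3 grid, dim3 block, size_t nbytes_shared) {
        if (nbytes_shared > 48u*1024u) {
            // Above 48 KiB the kernel must opt in explicitly (Volta and newer):
            cudaFuncSetAttribute((const void *) kernel,
                cudaFuncAttributeMaxDynamicSharedMemorySize, (int) nbytes_shared);
        }
        kernel<<<grid, block, nbytes_shared>>>();
    }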
ggml/src/ggml-cuda/fattn-wmma-f16.cu (new file, 648 lines)
							| @@ -0,0 +1,648 @@ | |||||||
|  | // Old and deprecated WMMA FlashAttention implementation. | ||||||
|  | // It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing. | ||||||
|  | // Long-term the WMMA code should be replaced with a dedicated Volta implementation. | ||||||
|  |  | ||||||
|  | #include "common.cuh" | ||||||
|  | #include "fattn-common.cuh" | ||||||
|  | #include "fattn-wmma-f16.cuh" | ||||||
|  |  | ||||||
|  | #ifdef FP16_MMA_AVAILABLE | ||||||
|  | #include <mma.h> | ||||||
|  | #endif // FP16_MMA_AVAILABLE | ||||||
|  |  | ||||||
|  | // D == head size, VKQ_stride == num VKQ rows calculated in parallel: | ||||||
|  | template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap> | ||||||
|  | #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) | ||||||
|  | __launch_bounds__(nwarps*WARP_SIZE, 1) | ||||||
|  | #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) | ||||||
|  | static __global__ void flash_attn_ext_f16( | ||||||
|  |         const char * __restrict__ Q, | ||||||
|  |         const char * __restrict__ K, | ||||||
|  |         const char * __restrict__ V, | ||||||
|  |         const char * __restrict__ mask, | ||||||
|  |         float      * __restrict__ dst, | ||||||
|  |         float2     * __restrict__ dst_meta, | ||||||
|  |         const float scale, | ||||||
|  |         const float max_bias, | ||||||
|  |         const float m0, | ||||||
|  |         const float m1, | ||||||
|  |         const uint32_t n_head_log2, | ||||||
|  |         const float logit_softcap, | ||||||
|  |         const int ne00, | ||||||
|  |         const int ne01, | ||||||
|  |         const int ne02, | ||||||
|  |         const int ne03, | ||||||
|  |         const int ne10, | ||||||
|  |         const int ne11, | ||||||
|  |         const int ne12, | ||||||
|  |         const int ne13, | ||||||
|  |         const int ne31, | ||||||
|  |         const int nb31, | ||||||
|  |         const int nb01, | ||||||
|  |         const int nb02, | ||||||
|  |         const int nb03, | ||||||
|  |         const int nb11, | ||||||
|  |         const int nb12, | ||||||
|  |         const int nb13, | ||||||
|  |         const int nb21, | ||||||
|  |         const int nb22, | ||||||
|  |         const int nb23, | ||||||
|  |         const int ne0, | ||||||
|  |         const int ne1, | ||||||
|  |         const int ne2, | ||||||
|  |         const int ne3) { | ||||||
|  | #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA | ||||||
|  |     // Skip unused kernel variants for faster compilation: | ||||||
|  |     if (use_logit_softcap && !(D == 128 || D == 256)) { | ||||||
|  |         NO_DEVICE_CODE; | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // In this kernel Q, K, V are matrices while i, j, k are matrix indices. | ||||||
|  |  | ||||||
|  |     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. | ||||||
|  |     const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel. | ||||||
|  |  | ||||||
|  |     static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); | ||||||
|  |     static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); | ||||||
|  |     constexpr int frag_m = ncols == 8 ? 32 : 16; | ||||||
|  |     constexpr int frag_n = ncols == 8 ?  8 : 16; | ||||||
|  |     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); | ||||||
|  |     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K; | ||||||
|  |     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V; | ||||||
|  |     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b; | ||||||
|  |     typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ; | ||||||
|  |     typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ; | ||||||
|  |  | ||||||
|  |     constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel. | ||||||
|  |     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. | ||||||
|  |     static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); | ||||||
|  |  | ||||||
|  |     // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: | ||||||
|  |     constexpr int D_padded = D + 8; | ||||||
|  |     constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; | ||||||
|  |     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); | ||||||
|  |  | ||||||
|  |     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. | ||||||
|  |     const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0); | ||||||
|  |     const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio)); | ||||||
|  |     const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape | ||||||
|  |     const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0; | ||||||
|  |     const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2); | ||||||
|  |  | ||||||
|  |     const int stride_Q  = nb01 / sizeof(float); | ||||||
|  |     const int stride_KV = nb11 / sizeof(half); | ||||||
|  |  | ||||||
|  |     const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); | ||||||
|  |     const half  slopeh = __float2half(slopef); | ||||||
|  |     const half2 slope2 = make_half2(slopef, slopef); | ||||||
|  |  | ||||||
|  |     const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap); | ||||||
|  |  | ||||||
|  |     frag_b Q_b[D/16][ncols/frag_n]; | ||||||
|  |  | ||||||
|  |     // A single buffer for temporarily holding tiles of KQ and VKQ parts: | ||||||
|  |     constexpr int mem_KQ = ncols*kqs_padded*kqar; | ||||||
|  |     constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; | ||||||
|  |     __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; | ||||||
|  |     float * KQ_f = (float *) KQ; | ||||||
|  |     half2 * KQ2 = (half2 *) KQ; | ||||||
|  |  | ||||||
|  |     float    KQ_rowsum_f[ncols/nwarps] = {0.0f}; | ||||||
|  |     float       KQ_max_f[ncols/nwarps]; | ||||||
|  |     float KQ_max_scale_f[ncols/nwarps] = {0.0f}; | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |     for (int j = 0; j < ncols/nwarps; ++j) { | ||||||
|  |         KQ_max_f[j] = -FLT_MAX/2.0f; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; | ||||||
|  |     half2       KQ_max_h2[ncols/nwarps]; | ||||||
|  |     half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |     for (int j = 0; j < ncols/nwarps; ++j) { | ||||||
|  |         KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. | ||||||
|  |     half2 * VKQ2 = (half2 *) VKQ; | ||||||
|  | #pragma unroll | ||||||
|  |     for (int j0 = 0; j0 < ncols; j0 += nwarps) { | ||||||
|  |         const int j = j0 + threadIdx.y; | ||||||
|  | #pragma unroll | ||||||
|  |         for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { | ||||||
|  |             const int i = i0 + threadIdx.x; | ||||||
|  |             if (i0 + WARP_SIZE > D/2 && i >= D/2) { | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |             VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     // Convert Q to half and apply scale, temporarily store in KQ: | ||||||
|  | #pragma unroll | ||||||
|  |     for (int j0 = 0; j0 < ncols; j0 += nwarps) { | ||||||
|  |         const int j = j0 + threadIdx.y; | ||||||
|  | #pragma unroll | ||||||
|  |         for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { | ||||||
|  |             const int i = i0 + threadIdx.x; | ||||||
|  |             if (i0 + WARP_SIZE > D && i >= D) { | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |             KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     __syncthreads(); | ||||||
|  |  | ||||||
|  |     // Load Q into tensor core fragments/registers since it will be used frequently: | ||||||
|  | #pragma unroll | ||||||
|  |     for (int i0 = 0; i0 < D; i0 += 16) { | ||||||
|  | #pragma unroll | ||||||
|  |         for (int j0 = 0; j0 < ncols; j0 += frag_n) { | ||||||
|  |             nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     __syncthreads(); | ||||||
|  |  | ||||||
|  |     // Iterate over ne11 == previous tokens: | ||||||
|  |     for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { | ||||||
|  |         // Calculate tile of KQ: | ||||||
|  | #pragma unroll | ||||||
|  |         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { | ||||||
|  |             frag_c_KQ KQ_c[ncols/frag_n]; | ||||||
|  | #pragma unroll | ||||||
|  |             for (int j = 0; j < ncols/frag_n; ++j) { | ||||||
|  |                 nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); | ||||||
|  |             } | ||||||
|  | #pragma unroll | ||||||
|  |             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { | ||||||
|  |                 frag_a_K K_a; | ||||||
|  |                 nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); | ||||||
|  | #pragma unroll | ||||||
|  |                 for (int j = 0; j < ncols/frag_n; ++j) { | ||||||
|  |                     nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  | #pragma unroll | ||||||
|  |             for (int j0 = 0; j0 < ncols; j0 += frag_n) { | ||||||
|  |                 nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         __syncthreads(); | ||||||
|  |  | ||||||
|  |         // Calculate softmax for each KQ column using the current max. value. | ||||||
|  |         // The divisor is stored in KQ_rowsum and will be applied at the end. | ||||||
|  | #pragma unroll | ||||||
|  |         for (int j0 = 0; j0 < ncols; j0 += nwarps) { | ||||||
|  |             const int j = j0 + threadIdx.y; | ||||||
|  |  | ||||||
|  |             if (std::is_same<KQ_acc_t, float>::value) { | ||||||
|  |                 float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; | ||||||
|  | #pragma unroll | ||||||
|  |                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { | ||||||
|  |                     const int k = k0 + threadIdx.x; | ||||||
|  |  | ||||||
|  |                     KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; | ||||||
|  |  | ||||||
|  |                     if (use_logit_softcap) { | ||||||
|  |                         KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 float KQ_max_new = KQ_max_f[j0/nwarps]; | ||||||
|  | #pragma unroll | ||||||
|  |                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { | ||||||
|  |                     const int k = k0 + threadIdx.x; | ||||||
|  |  | ||||||
|  |                     KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; | ||||||
|  |                     KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); | ||||||
|  |                 } | ||||||
|  |                 KQ_max_new = warp_reduce_max(KQ_max_new); | ||||||
|  |  | ||||||
|  |                 const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; | ||||||
|  |                 KQ_max_scale_f[j0/nwarps] = expf(diff); | ||||||
|  |                 if (diff <= SOFTMAX_FTZ_THRESHOLD) { | ||||||
|  |                     KQ_max_scale_f[j0/nwarps] = 0.0f; | ||||||
|  |                 } | ||||||
|  |                 KQ_max_f[j0/nwarps] = KQ_max_new; | ||||||
|  |  | ||||||
|  |                 float KQ_rowsum_add = 0.0f; | ||||||
|  | #pragma unroll | ||||||
|  |                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { | ||||||
|  |                     const int k = k0 + threadIdx.x; | ||||||
|  |  | ||||||
|  |                     const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; | ||||||
|  |                     KQ_f_tmp[k0/WARP_SIZE] = expf(diff); | ||||||
|  |                     if (diff <= SOFTMAX_FTZ_THRESHOLD) { | ||||||
|  |                         KQ_f_tmp[k0/WARP_SIZE] = 0.0f; | ||||||
|  |                     } | ||||||
|  |                     KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; | ||||||
|  |                     KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; | ||||||
|  |                 } | ||||||
|  |                 KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); | ||||||
|  |  | ||||||
|  |                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max: | ||||||
|  |                 KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; | ||||||
|  |             } else { | ||||||
|  |                 half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; | ||||||
|  | #pragma unroll | ||||||
|  |                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { | ||||||
|  |                     const int k = k0 + threadIdx.x; | ||||||
|  |  | ||||||
|  |                     KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; | ||||||
|  |  | ||||||
|  |                     if (use_logit_softcap) { | ||||||
|  |                         // There is no dedicated hyperbolic tangent function for half2, so tanh(x) is computed as (exp(2x) - 1)/(exp(2x) + 1): | ||||||
|  |                         KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); | ||||||
|  |                         KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) | ||||||
|  |                                                /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); | ||||||
|  |  | ||||||
|  |                         KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2; | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 half2 KQ_max_new = KQ_max_h2[j0/nwarps]; | ||||||
|  | #pragma unroll | ||||||
|  |                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { | ||||||
|  |                     const int k = k0 + threadIdx.x; | ||||||
|  |  | ||||||
|  |                     KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); | ||||||
|  |                     KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); | ||||||
|  |                 } | ||||||
|  |                 KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); | ||||||
|  |                 const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; | ||||||
|  |                 KQ_max_scale_h2[j0/nwarps] = h2exp(diff); | ||||||
|  |                 const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); | ||||||
|  |                 *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; | ||||||
|  |                 KQ_max_h2[j0/nwarps] = KQ_max_new; | ||||||
|  |  | ||||||
|  |                 half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); | ||||||
|  | #pragma unroll | ||||||
|  |                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { | ||||||
|  |                     const int k = k0 + threadIdx.x; | ||||||
|  |  | ||||||
|  |                     const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; | ||||||
|  |                     KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); | ||||||
|  |                     const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); | ||||||
|  |                     *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; | ||||||
|  |                     KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; | ||||||
|  |                     KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; | ||||||
|  |                 } | ||||||
|  |                 KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); | ||||||
|  |  | ||||||
|  |                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max: | ||||||
|  |                 KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         __syncthreads(); | ||||||
|  |  | ||||||
|  |         frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; | ||||||
|  | #pragma unroll | ||||||
|  |         for (int j0 = 0; j0 < ncols; j0 += frag_n) { | ||||||
|  | #pragma unroll | ||||||
|  |             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { | ||||||
|  |                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16; | ||||||
|  |                 nvcuda::wmma::load_matrix_sync( | ||||||
|  |                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], | ||||||
|  |                     KQ + j0*(kqar*kqs_padded) + k, | ||||||
|  |                     kqar*kqs_padded); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; | ||||||
|  | #pragma unroll | ||||||
|  |         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { | ||||||
|  | #pragma unroll | ||||||
|  |             for (int j = 0; j < ncols/frag_n; ++j) { | ||||||
|  |                 nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); | ||||||
|  |             } | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { | ||||||
|  |                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16; | ||||||
|  |  | ||||||
|  |                 frag_a_V v_a; | ||||||
|  |                 nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); | ||||||
|  | #pragma unroll | ||||||
|  |                 for (int j = 0; j < ncols/frag_n; ++j) { | ||||||
|  |                     nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); | ||||||
|  |                 } | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         __syncthreads(); | ||||||
|  |  | ||||||
|  |         const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); | ||||||
|  | #pragma unroll | ||||||
|  |         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { | ||||||
|  | #pragma unroll | ||||||
|  |             for (int j0 = 0; j0 < ncols; j0 += frag_n) { | ||||||
|  |                 nvcuda::wmma::store_matrix_sync( | ||||||
|  |                     KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), | ||||||
|  |                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], | ||||||
|  |                     D_padded, nvcuda::wmma::mem_col_major); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         __syncthreads(); | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |         for (int j0 = 0; j0 < ncols; j0 += nwarps) { | ||||||
|  |             const int j = j0 + threadIdx.y; | ||||||
|  |  | ||||||
|  |             half2 VKQ_scale; | ||||||
|  |             if (std::is_same<KQ_acc_t, float>::value) { | ||||||
|  |                 VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); | ||||||
|  |             } else { | ||||||
|  |                 VKQ_scale = KQ_max_scale_h2[j0/nwarps]; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { | ||||||
|  |                 const int i = i0 + threadIdx.x; | ||||||
|  |                 if (i0 + WARP_SIZE > D/2 && i >= D/2) { | ||||||
|  |                     break; | ||||||
|  |                 } | ||||||
|  |  | ||||||
|  |                 half2 VKQ_add = make_half2(0.0f, 0.0f); | ||||||
|  | #pragma unroll | ||||||
|  |                 for (int l = 0; l < VKQ_ratio; ++l) { | ||||||
|  |                     VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; | ||||||
|  |                 } | ||||||
|  |                 VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         __syncthreads(); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |     for (int j0 = 0; j0 < ncols; j0 += nwarps) { | ||||||
|  |         const int j_VKQ = j0 + threadIdx.y; | ||||||
|  |         if (ic0 + j_VKQ >= ne01) { | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |         const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; | ||||||
|  |  | ||||||
|  |         float KQ_rowsum_j; | ||||||
|  |         if (std::is_same<KQ_acc_t, float>::value) { | ||||||
|  |             KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; | ||||||
|  |         } else { | ||||||
|  |             KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); | ||||||
|  |         } | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |         for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { | ||||||
|  |             const int i = i0 + threadIdx.x; | ||||||
|  |             if (i0 + WARP_SIZE > D && i >= D) { | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |             float dst_val = VKQ[j_VKQ*D_padded + i]; | ||||||
|  |             if (parallel_blocks == 1) { | ||||||
|  |                 dst_val /= KQ_rowsum_j; | ||||||
|  |             } | ||||||
|  |             dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         if (parallel_blocks == 1 || threadIdx.x != 0) { | ||||||
|  |             continue; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         float2 dst_meta_val; | ||||||
|  |         if (std::is_same<KQ_acc_t, float>::value) { | ||||||
|  |             dst_meta_val.x = KQ_max_f[j0/nwarps]; | ||||||
|  |         } else { | ||||||
|  |             dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); | ||||||
|  |         } | ||||||
|  |         dst_meta_val.y = KQ_rowsum_j; | ||||||
|  |         dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; | ||||||
|  |     } | ||||||
|  | #else | ||||||
|  |    NO_DEVICE_CODE; | ||||||
|  | #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA | ||||||
|  | } | ||||||
|  |  | ||||||
|  | constexpr int get_max_power_of_2(int x) { | ||||||
|  |     return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static_assert(get_max_power_of_2(1) == 1, "Test failed."); | ||||||
|  | static_assert(get_max_power_of_2(2) == 2, "Test failed."); | ||||||
|  | static_assert(get_max_power_of_2(4) == 4, "Test failed."); | ||||||
|  | static_assert(get_max_power_of_2(6) == 2, "Test failed."); | ||||||
|  |  | ||||||
|  | // Number of VKQ rows calculated in parallel: | ||||||
|  | constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { | ||||||
|  |     return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed."); | ||||||
|  | static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed."); | ||||||
|  | static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); | ||||||
|  | static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed."); | ||||||
|  | static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed."); | ||||||
|  | static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed."); | ||||||
|  | static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed."); | ||||||
|  | static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed."); | ||||||
|  | static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed."); | ||||||
|  |  | ||||||
|  | template <int D, int cols_per_block, typename KQ_acc_t> | ||||||
|  | void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||||||
|  |     const ggml_tensor * KQV = dst; | ||||||
|  |     const ggml_tensor * Q   = dst->src[0]; | ||||||
|  |  | ||||||
|  |     constexpr int nwarps = 4; | ||||||
|  |  | ||||||
|  |     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; | ||||||
|  |     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; | ||||||
|  |     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; | ||||||
|  |  | ||||||
|  |     float logit_softcap; | ||||||
|  |     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); | ||||||
|  |  | ||||||
|  |     if (4*blocks_num_pb1 < 2*nsm) { | ||||||
|  |         constexpr int parallel_blocks = 4; | ||||||
|  |         fattn_kernel_t fattn_kernel; | ||||||
|  |         if (logit_softcap == 0.0f) { | ||||||
|  |             constexpr bool use_logit_softcap = false; | ||||||
|  |             fattn_kernel = flash_attn_ext_f16< | ||||||
|  |                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; | ||||||
|  |         } else { | ||||||
|  |             constexpr bool use_logit_softcap = true; | ||||||
|  |             fattn_kernel = flash_attn_ext_f16< | ||||||
|  |                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; | ||||||
|  |         } | ||||||
|  |         launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true); | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |     if (2*blocks_num_pb1 < 2*nsm) { | ||||||
|  |         constexpr int parallel_blocks = 2; | ||||||
|  |         fattn_kernel_t fattn_kernel; | ||||||
|  |         if (logit_softcap == 0.0f) { | ||||||
|  |             constexpr bool use_logit_softcap = false; | ||||||
|  |             fattn_kernel = flash_attn_ext_f16< | ||||||
|  |                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; | ||||||
|  |         } else { | ||||||
|  |             constexpr bool use_logit_softcap = true; | ||||||
|  |             fattn_kernel = flash_attn_ext_f16< | ||||||
|  |                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; | ||||||
|  |         } | ||||||
|  |         launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true); | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |     constexpr int parallel_blocks = 1; | ||||||
|  |     fattn_kernel_t fattn_kernel; | ||||||
|  |     if (logit_softcap == 0.0f) { | ||||||
|  |         constexpr bool use_logit_softcap = false; | ||||||
|  |         fattn_kernel = flash_attn_ext_f16< | ||||||
|  |             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; | ||||||
|  |     } else { | ||||||
|  |         constexpr bool use_logit_softcap = true; | ||||||
|  |         fattn_kernel = flash_attn_ext_f16< | ||||||
|  |             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; | ||||||
|  |     } | ||||||
|  |     launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||||||
|  |     const ggml_tensor * KQV = dst; | ||||||
|  |     const ggml_tensor * Q   = dst->src[0]; | ||||||
|  |  | ||||||
|  |     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); | ||||||
|  |  | ||||||
|  |     if (prec != GGML_PREC_DEFAULT) { | ||||||
|  |         if (Q->ne[1] <= 32 || Q->ne[0] > 128) { | ||||||
|  |             constexpr int cols_per_block = 16; | ||||||
|  |             switch (Q->ne[0]) { | ||||||
|  |                 case 64: | ||||||
|  |                     ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); | ||||||
|  |                     break; | ||||||
|  |                 case 80: | ||||||
|  |                     ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); | ||||||
|  |                     break; | ||||||
|  |                 case 96: | ||||||
|  |                     ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); | ||||||
|  |                     break; | ||||||
|  |                 case 112: | ||||||
|  |                     ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); | ||||||
|  |                     break; | ||||||
|  |                 case 128: | ||||||
|  |                     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); | ||||||
|  |                     break; | ||||||
|  |                 case 256: | ||||||
|  |                     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); | ||||||
|  |                     break; | ||||||
|  |                 default: | ||||||
|  |                     GGML_ABORT("fatal error"); | ||||||
|  |                     break; | ||||||
|  |             } | ||||||
|  |         } else { | ||||||
|  |             constexpr int cols_per_block = 32; | ||||||
|  |             switch (Q->ne[0]) { | ||||||
|  |                 case 64: | ||||||
|  |                     ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); | ||||||
|  |                     break; | ||||||
|  |                 case 80: | ||||||
|  |                     ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); | ||||||
|  |                     break; | ||||||
|  |                 case 96: | ||||||
|  |                     ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); | ||||||
|  |                     break; | ||||||
|  |                 case 112: | ||||||
|  |                     ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); | ||||||
|  |                     break; | ||||||
|  |                 case 128: | ||||||
|  |                     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); | ||||||
|  |                     break; | ||||||
|  |                 // case 256: | ||||||
|  |                 //     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); | ||||||
|  |                 //     break; | ||||||
|  |                 default: | ||||||
|  |                     GGML_ABORT("fatal error"); | ||||||
|  |                     break; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { | ||||||
|  |         constexpr int cols_per_block = 8; | ||||||
|  |         switch (Q->ne[0]) { | ||||||
|  |             case 64: | ||||||
|  |                 ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); | ||||||
|  |                 break; | ||||||
|  |             case 96: | ||||||
|  |                 ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); | ||||||
|  |                 break; | ||||||
|  |             case 128: | ||||||
|  |                 ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); | ||||||
|  |                 break; | ||||||
|  |             case 256: | ||||||
|  |                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); | ||||||
|  |                 break; | ||||||
|  |             default: | ||||||
|  |                 GGML_ABORT("fatal error"); | ||||||
|  |                 break; | ||||||
|  |         } | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (Q->ne[1] <= 32) { | ||||||
|  |         constexpr int cols_per_block = 16; | ||||||
|  |         switch (Q->ne[0]) { | ||||||
|  |             case 64: | ||||||
|  |                 ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); | ||||||
|  |                 break; | ||||||
|  |             case 80: | ||||||
|  |                 ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); | ||||||
|  |                 break; | ||||||
|  |             case 96: | ||||||
|  |                 ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); | ||||||
|  |                 break; | ||||||
|  |             case 112: | ||||||
|  |                 ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); | ||||||
|  |                 break; | ||||||
|  |             case 128: | ||||||
|  |                 ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); | ||||||
|  |                 break; | ||||||
|  |             case 256: | ||||||
|  |                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); | ||||||
|  |                 break; | ||||||
|  |             default: | ||||||
|  |                 GGML_ABORT("fatal error"); | ||||||
|  |                 break; | ||||||
|  |         } | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     constexpr int cols_per_block = 32; | ||||||
|  |     switch (Q->ne[0]) { | ||||||
|  |         case 64: | ||||||
|  |             ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); | ||||||
|  |             break; | ||||||
|  |         case 80: | ||||||
|  |             ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); | ||||||
|  |             break; | ||||||
|  |         case 96: | ||||||
|  |             ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); | ||||||
|  |             break; | ||||||
|  |         case 112: | ||||||
|  |             ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); | ||||||
|  |             break; | ||||||
|  |         case 128: | ||||||
|  |             ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); | ||||||
|  |             break; | ||||||
|  |         case 256: | ||||||
|  |             ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); | ||||||
|  |             break; | ||||||
|  |         default: | ||||||
|  |             GGML_ABORT("fatal error"); | ||||||
|  |             break; | ||||||
|  |     } | ||||||
|  | } | ||||||
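For reference, the running-max bookkeeping in the kernel above (KQ_max, KQ_max_scale, KQ_rowsum) is the standard online softmax used by FlashAttention. Per row, with running maximum m, running denominator s, and running output accumulator o, each new tile of scores x_k (with value rows v_k) is folded in as:

    m' = \max\bigl(m, \max_k x_k\bigr), \qquad
    s' = e^{m - m'}\, s + \sum_k e^{x_k - m'}, \qquad
    o' = e^{m - m'}\, o + \sum_k e^{x_k - m'}\, v_k

The final result is o/s, applied at the very end (the dst_val /= KQ_rowsum_j step, or deferred to the cross-block reduction via dst_meta when parallel_blocks > 1). The factor e^{m - m'} is KQ_max_scale, and flushing it to zero below SOFTMAX_FTZ_THRESHOLD skips contributions that would underflow anyway.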
| @@ -1,543 +1,3 @@ | |||||||
| #include "common.cuh" | #include "common.cuh" | ||||||
| #include "fattn-common.cuh" |  | ||||||
|  |  | ||||||
| #ifdef FP16_MMA_AVAILABLE | void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); | ||||||
| #include <mma.h> |  | ||||||
| #endif // FP16_MMA_AVAILABLE |  | ||||||
|  |  | ||||||
| // D == head size, VKQ_stride == num VKQ rows calculated in parallel: |  | ||||||
| template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap> |  | ||||||
| #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) |  | ||||||
| __launch_bounds__(nwarps*WARP_SIZE, 1) |  | ||||||
| #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) |  | ||||||
| static __global__ void flash_attn_ext_f16( |  | ||||||
|         const char * __restrict__ Q, |  | ||||||
|         const char * __restrict__ K, |  | ||||||
|         const char * __restrict__ V, |  | ||||||
|         const char * __restrict__ mask, |  | ||||||
|         float      * __restrict__ dst, |  | ||||||
|         float2     * __restrict__ dst_meta, |  | ||||||
|         const float scale, |  | ||||||
|         const float max_bias, |  | ||||||
|         const float m0, |  | ||||||
|         const float m1, |  | ||||||
|         const uint32_t n_head_log2, |  | ||||||
|         const float logit_softcap, |  | ||||||
|         const int ne00, |  | ||||||
|         const int ne01, |  | ||||||
|         const int ne02, |  | ||||||
|         const int ne03, |  | ||||||
|         const int ne10, |  | ||||||
|         const int ne11, |  | ||||||
|         const int ne12, |  | ||||||
|         const int ne13, |  | ||||||
|         const int ne31, |  | ||||||
|         const int nb31, |  | ||||||
|         const int nb01, |  | ||||||
|         const int nb02, |  | ||||||
|         const int nb03, |  | ||||||
|         const int nb11, |  | ||||||
|         const int nb12, |  | ||||||
|         const int nb13, |  | ||||||
|         const int nb21, |  | ||||||
|         const int nb22, |  | ||||||
|         const int nb23, |  | ||||||
|         const int ne0, |  | ||||||
|         const int ne1, |  | ||||||
|         const int ne2, |  | ||||||
|         const int ne3) { |  | ||||||
| #ifdef FP16_MMA_AVAILABLE |  | ||||||
|     // Skip unused kernel variants for faster compilation: |  | ||||||
|     if (use_logit_softcap && !(D == 128 || D == 256)) { |  | ||||||
|         NO_DEVICE_CODE; |  | ||||||
|         return; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     //In this kernel Q, K, V are matrices while i, j, k are matrix indices. |  | ||||||
|  |  | ||||||
|     const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on. |  | ||||||
|     const int ip  =        blockIdx.x % parallel_blocks;  // Index in group of blocks running for the same column in parallel. |  | ||||||
|  |  | ||||||
|     static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE."); |  | ||||||
|     static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16."); |  | ||||||
|     constexpr int frag_m = ncols == 8 ? 32 : 16; |  | ||||||
|     constexpr int frag_n = ncols == 8 ?  8 : 16; |  | ||||||
|     static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); |  | ||||||
|     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K; |  | ||||||
|     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V; |  | ||||||
|     typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b,    frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b; |  | ||||||
|     typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t>                      frag_c_KQ; |  | ||||||
|     typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half>                          frag_c_VKQ; |  | ||||||
|  |  | ||||||
|     constexpr int KQ_stride_tc  = nwarps*frag_m; // Number of KQ rows calculated in parallel. |  | ||||||
|     constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. |  | ||||||
|     static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps."); |  | ||||||
|  |  | ||||||
|     // Pad internal representation of KQ, KQV to reduce shared memory bank conflicts: |  | ||||||
|     constexpr int D_padded = D + 8; |  | ||||||
|     constexpr int kqs_padded = FATTN_KQ_STRIDE + 8; |  | ||||||
|     constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half); |  | ||||||
|  |  | ||||||
|     const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. |  | ||||||
|     const float * Q_f   = (const float *) (Q + nb02* blockIdx.y              + nb01*ic0); |  | ||||||
|     const half  * K_h   = (const half  *) (K + nb12*(blockIdx.y / gqa_ratio)); |  | ||||||
|     const half  * V_h   = (const half  *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape |  | ||||||
|     const half  * maskh = (const half  *)  mask + (nb31/sizeof(half))* ic0; |  | ||||||
|     const half2 * mask2 = (const half2 *)  mask + (nb31/sizeof(half))*(ic0/2); |  | ||||||
|  |  | ||||||
|     const int stride_Q  = nb01 / sizeof(float); |  | ||||||
|     const int stride_KV = nb11 / sizeof(half); |  | ||||||
|  |  | ||||||
|     const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1); |  | ||||||
|     const half  slopeh = __float2half(slopef); |  | ||||||
|     const half2 slope2 = make_half2(slopef, slopef); |  | ||||||
|  |  | ||||||
|     const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap); |  | ||||||
|  |  | ||||||
|     frag_b Q_b[D/16][ncols/frag_n]; |  | ||||||
|  |  | ||||||
|     // A single buffer for temporarily holding tiles of KQ and VKQ parts: |  | ||||||
|     constexpr int mem_KQ = ncols*kqs_padded*kqar; |  | ||||||
|     constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded; |  | ||||||
|     __shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts]; |  | ||||||
|     float * KQ_f = (float *) KQ; |  | ||||||
|     half2 * KQ2 = (half2 *) KQ; |  | ||||||
|  |  | ||||||
|     float    KQ_rowsum_f[ncols/nwarps] = {0.0f}; |  | ||||||
|     float       KQ_max_f[ncols/nwarps]; |  | ||||||
|     float KQ_max_scale_f[ncols/nwarps] = {0.0f}; |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|     for (int j = 0; j < ncols/nwarps; ++j) { |  | ||||||
|         KQ_max_f[j] = -FLT_MAX/2.0f; |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     half2    KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}}; |  | ||||||
|     half2       KQ_max_h2[ncols/nwarps]; |  | ||||||
|     half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}}; |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|     for (int j = 0; j < ncols/nwarps; ++j) { |  | ||||||
|         KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. |  | ||||||
|     half2 * VKQ2 = (half2 *) VKQ; |  | ||||||
| #pragma unroll |  | ||||||
|     for (int j0 = 0; j0 < ncols; j0 += nwarps) { |  | ||||||
|         const int j = j0 + threadIdx.y; |  | ||||||
| #pragma unroll |  | ||||||
|         for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { |  | ||||||
|             const int i = i0 + threadIdx.x; |  | ||||||
|             if (i0 + WARP_SIZE > D/2 && i >= D/2) { |  | ||||||
|                 break; |  | ||||||
|             } |  | ||||||
|             VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     // Convert Q to half and apply scale, temporarily store in KQ: |  | ||||||
| #pragma unroll |  | ||||||
|     for (int j0 = 0; j0 < ncols; j0 += nwarps) { |  | ||||||
|         const int j = j0 + threadIdx.y; |  | ||||||
| #pragma unroll |  | ||||||
|         for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { |  | ||||||
|             const int i = i0 + threadIdx.x; |  | ||||||
|             if (i0 + WARP_SIZE > D && i >= D) { |  | ||||||
|                 break; |  | ||||||
|             } |  | ||||||
|             KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     __syncthreads(); |  | ||||||
|  |  | ||||||
|     // Load Q into tensor core fragments/registers since it will be used frequently: |  | ||||||
| #pragma unroll |  | ||||||
|     for (int i0 = 0; i0 < D; i0 += 16) { |  | ||||||
| #pragma unroll |  | ||||||
|         for (int j0 = 0; j0 < ncols; j0 += frag_n) { |  | ||||||
|             nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
|  |  | ||||||
|     __syncthreads(); |  | ||||||
|  |  | ||||||
|     // Iterate over ne11 == previous tokens: |  | ||||||
|     for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) { |  | ||||||
|         // Calculate tile of KQ: |  | ||||||
| #pragma unroll |  | ||||||
|         for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) { |  | ||||||
|             frag_c_KQ KQ_c[ncols/frag_n]; |  | ||||||
| #pragma unroll |  | ||||||
|             for (int j = 0; j < ncols/frag_n; ++j) { |  | ||||||
|                 nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f); |  | ||||||
|             } |  | ||||||
| #pragma unroll |  | ||||||
|             for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { |  | ||||||
|                 frag_a_K K_a; |  | ||||||
|                 nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int j = 0; j < ncols/frag_n; ++j) { |  | ||||||
|                     nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| #pragma unroll |  | ||||||
|             for (int j0 = 0; j0 < ncols; j0 += frag_n) { |  | ||||||
|                 nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         __syncthreads(); |  | ||||||
|  |  | ||||||
|         // Calculate softmax for each KQ column using the current max. value. |  | ||||||
|         // The divisor is stored in KQ_rowsum and will be applied at the end. |  | ||||||
| #pragma unroll |  | ||||||
|         for (int j0 = 0; j0 < ncols; j0 += nwarps) { |  | ||||||
|             const int j = j0 + threadIdx.y; |  | ||||||
|  |  | ||||||
|             if (std::is_same<KQ_acc_t, float>::value) { |  | ||||||
|                 float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE]; |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { |  | ||||||
|                     const int k = k0 + threadIdx.x; |  | ||||||
|  |  | ||||||
|                     KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k]; |  | ||||||
|  |  | ||||||
|                     if (use_logit_softcap) { |  | ||||||
|                         KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 float KQ_max_new = KQ_max_f[j0/nwarps]; |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { |  | ||||||
|                     const int k = k0 + threadIdx.x; |  | ||||||
|  |  | ||||||
|                     KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f; |  | ||||||
|                     KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]); |  | ||||||
|                 } |  | ||||||
|                 KQ_max_new = warp_reduce_max(KQ_max_new); |  | ||||||
|  |  | ||||||
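|                 // diff <= 0 always holds here; scale factors below the FTZ threshold are |  | ||||||
|                 // flushed to zero to avoid numerical issues from vanishingly small values. |  | ||||||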
|                 const float diff = KQ_max_f[j0/nwarps] - KQ_max_new; |  | ||||||
|                 KQ_max_scale_f[j0/nwarps] = expf(diff); |  | ||||||
|                 if (diff <= SOFTMAX_FTZ_THRESHOLD) { |  | ||||||
|                     KQ_max_scale_f[j0/nwarps] = 0.0f; |  | ||||||
|                 } |  | ||||||
|                 KQ_max_f[j0/nwarps] = KQ_max_new; |  | ||||||
|  |  | ||||||
|                 float KQ_rowsum_add = 0.0f; |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) { |  | ||||||
|                     const int k = k0 + threadIdx.x; |  | ||||||
|  |  | ||||||
|                     const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps]; |  | ||||||
|                     KQ_f_tmp[k0/WARP_SIZE] = expf(diff); |  | ||||||
|                     if (diff <= SOFTMAX_FTZ_THRESHOLD) { |  | ||||||
|                         KQ_f_tmp[k0/WARP_SIZE] = 0.0f; |  | ||||||
|                     } |  | ||||||
|                     KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE]; |  | ||||||
|                     KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE]; |  | ||||||
|                 } |  | ||||||
|                 KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); |  | ||||||
|  |  | ||||||
|                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max: |  | ||||||
|                 KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add; |  | ||||||
|             } else { |  | ||||||
|                 half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)]; |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { |  | ||||||
|                     const int k = k0 + threadIdx.x; |  | ||||||
|  |  | ||||||
|                     KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k]; |  | ||||||
|  |  | ||||||
|                     if (use_logit_softcap) { |  | ||||||
|                         // There is no dedicated hyperbolic tangent function for half2. |  | ||||||
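|                         // Instead, use the identity tanh(x) = (exp(2x) - 1) / (exp(2x) + 1): |  | ||||||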
|                         KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f)); |  | ||||||
|                         KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f)) |  | ||||||
|                                                /(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f)); |  | ||||||
|  |  | ||||||
|                         KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2; |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 half2 KQ_max_new = KQ_max_h2[j0/nwarps]; |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { |  | ||||||
|                     const int k = k0 + threadIdx.x; |  | ||||||
|  |  | ||||||
|                     KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f); |  | ||||||
|                     KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]); |  | ||||||
|                 } |  | ||||||
|                 KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new)))); |  | ||||||
|                 const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new; |  | ||||||
|                 KQ_max_scale_h2[j0/nwarps] = h2exp(diff); |  | ||||||
|                 const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); |  | ||||||
|                 *((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask; |  | ||||||
|                 KQ_max_h2[j0/nwarps] = KQ_max_new; |  | ||||||
|  |  | ||||||
|                 half2 KQ_rowsum_add = make_half2(0.0f, 0.0f); |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) { |  | ||||||
|                     const int k = k0 + threadIdx.x; |  | ||||||
|  |  | ||||||
|                     const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps]; |  | ||||||
|                     KQ2_tmp[k0/WARP_SIZE] = h2exp(diff); |  | ||||||
|                     const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD)); |  | ||||||
|                     *((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask; |  | ||||||
|                     KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE]; |  | ||||||
|                     KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE]; |  | ||||||
|                 } |  | ||||||
|                 KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add); |  | ||||||
|  |  | ||||||
|                 // Scale previous KQ_rowsum to account for a potential increase in KQ_max: |  | ||||||
|                 KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         __syncthreads(); |  | ||||||
|  |  | ||||||
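|         // Load the softmax(KQ) values as B operands; the warps are split into VKQ_ratio |  | ||||||
|         // groups that each handle interleaved 16-row slices of the FATTN_KQ_STRIDE dimension. |  | ||||||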
|         frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n]; |  | ||||||
| #pragma unroll |  | ||||||
|         for (int j0 = 0; j0 < ncols; j0 += frag_n) { |  | ||||||
| #pragma unroll |  | ||||||
|             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { |  | ||||||
|                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16; |  | ||||||
|                 nvcuda::wmma::load_matrix_sync( |  | ||||||
|                     KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], |  | ||||||
|                     KQ + j0*(kqar*kqs_padded) + k, |  | ||||||
|                     kqar*kqs_padded); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n]; |  | ||||||
| #pragma unroll |  | ||||||
|         for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) { |  | ||||||
| #pragma unroll |  | ||||||
|             for (int j = 0; j < ncols/frag_n; ++j) { |  | ||||||
|                 nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f); |  | ||||||
|             } |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|             for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) { |  | ||||||
|                 const int k = k0 + (threadIdx.y % VKQ_ratio)*16; |  | ||||||
|  |  | ||||||
|                 frag_a_V v_a; |  | ||||||
|                 nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int j = 0; j < ncols/frag_n; ++j) { |  | ||||||
|                     nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         __syncthreads(); |  | ||||||
|  |  | ||||||
|         const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded); |  | ||||||
| #pragma unroll |  | ||||||
|         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) { |  | ||||||
| #pragma unroll |  | ||||||
|             for (int j0 = 0; j0 < ncols; j0 += frag_n) { |  | ||||||
|                 nvcuda::wmma::store_matrix_sync( |  | ||||||
|                     KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), |  | ||||||
|                     VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], |  | ||||||
|                     D_padded, nvcuda::wmma::mem_col_major); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         __syncthreads(); |  | ||||||
|  |  | ||||||
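|         // Sum the partial VKQ results of the VKQ_ratio warp groups and add them to the |  | ||||||
|         // previous accumulator, rescaled to account for a potentially updated maximum. |  | ||||||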
| #pragma unroll |  | ||||||
|         for (int j0 = 0; j0 < ncols; j0 += nwarps) { |  | ||||||
|             const int j = j0 + threadIdx.y; |  | ||||||
|  |  | ||||||
|             half2 VKQ_scale; |  | ||||||
|             if (std::is_same<KQ_acc_t, float>::value) { |  | ||||||
|                 VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]); |  | ||||||
|             } else { |  | ||||||
|                 VKQ_scale = KQ_max_scale_h2[j0/nwarps]; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|             for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { |  | ||||||
|                 const int i = i0 + threadIdx.x; |  | ||||||
|                 if (i0 + WARP_SIZE > D/2 && i >= D/2) { |  | ||||||
|                     break; |  | ||||||
|                 } |  | ||||||
|  |  | ||||||
|                 half2 VKQ_add = make_half2(0.0f, 0.0f); |  | ||||||
| #pragma unroll |  | ||||||
|                 for (int l = 0; l < VKQ_ratio; ++l) { |  | ||||||
|                     VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i]; |  | ||||||
|                 } |  | ||||||
|                 VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         __syncthreads(); |  | ||||||
|     } |  | ||||||
|  |  | ||||||
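|     // Write the results to VRAM. With parallel_blocks == 1 the normalization by the softmax |  | ||||||
|     // divisor happens here; otherwise it is applied later, when the partial results of the |  | ||||||
|     // parallel blocks are combined using the maxima and row sums stored in dst_meta. |  | ||||||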
| #pragma unroll |  | ||||||
|     for (int j0 = 0; j0 < ncols; j0 += nwarps) { |  | ||||||
|         const int j_VKQ = j0 + threadIdx.y; |  | ||||||
|         if (ic0 + j_VKQ >= ne01) { |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
|         const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip; |  | ||||||
|  |  | ||||||
|         float KQ_rowsum_j; |  | ||||||
|         if (std::is_same<KQ_acc_t, float>::value) { |  | ||||||
|             KQ_rowsum_j = KQ_rowsum_f[j0/nwarps]; |  | ||||||
|         } else { |  | ||||||
|             KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
| #pragma unroll |  | ||||||
|         for (int i0 = 0; i0 < D; i0 += WARP_SIZE) { |  | ||||||
|             const int i = i0 + threadIdx.x; |  | ||||||
|             if (i0 + WARP_SIZE > D && i >= D) { |  | ||||||
|                 break; |  | ||||||
|             } |  | ||||||
|             float dst_val = VKQ[j_VKQ*D_padded + i]; |  | ||||||
|             if (parallel_blocks == 1) { |  | ||||||
|                 dst_val /= KQ_rowsum_j; |  | ||||||
|             } |  | ||||||
|             dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         if (parallel_blocks == 1 || threadIdx.x != 0) { |  | ||||||
|             continue; |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         float2 dst_meta_val; |  | ||||||
|         if (std::is_same<KQ_acc_t, float>::value) { |  | ||||||
|             dst_meta_val.x = KQ_max_f[j0/nwarps]; |  | ||||||
|         } else { |  | ||||||
|             dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]); |  | ||||||
|         } |  | ||||||
|         dst_meta_val.y = KQ_rowsum_j; |  | ||||||
|         dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val; |  | ||||||
|     } |  | ||||||
| #else |  | ||||||
|     NO_DEVICE_CODE; |  | ||||||
| #endif // FP16_MMA_AVAILABLE |  | ||||||
| } |  | ||||||
|  |  | ||||||
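| // Returns the largest power of 2 that divides x, e.g. get_max_power_of_2(6) == 2: |  | ||||||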
| constexpr int get_max_power_of_2(int x) { |  | ||||||
|     return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| static_assert(get_max_power_of_2(1) == 1, "Test failed."); |  | ||||||
| static_assert(get_max_power_of_2(2) == 2, "Test failed."); |  | ||||||
| static_assert(get_max_power_of_2(4) == 4, "Test failed."); |  | ||||||
| static_assert(get_max_power_of_2(6) == 2, "Test failed."); |  | ||||||
|  |  | ||||||
| // Number of VKQ rows calculated in parallel: |  | ||||||
| constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) { |  | ||||||
|     return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m; |  | ||||||
| } |  | ||||||
|  |  | ||||||
| static_assert(get_VKQ_stride(128, 1, 32) ==  32, "Test failed."); |  | ||||||
| static_assert(get_VKQ_stride(128, 2, 32) ==  64, "Test failed."); |  | ||||||
| static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed."); |  | ||||||
| static_assert(get_VKQ_stride( 64, 1, 32) ==  32, "Test failed."); |  | ||||||
| static_assert(get_VKQ_stride( 64, 2, 32) ==  64, "Test failed."); |  | ||||||
| static_assert(get_VKQ_stride( 64, 4, 32) ==  64, "Test failed."); |  | ||||||
| static_assert(get_VKQ_stride( 80, 1, 16) ==  16, "Test failed."); |  | ||||||
| static_assert(get_VKQ_stride( 80, 2, 16) ==  16, "Test failed."); |  | ||||||
| static_assert(get_VKQ_stride( 80, 4, 16) ==  16, "Test failed."); |  | ||||||
|  |  | ||||||
| template <int D, int cols_per_block, typename KQ_acc_t> |  | ||||||
| void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |  | ||||||
|     const ggml_tensor * KQV = dst; |  | ||||||
|     const ggml_tensor * Q   = dst->src[0]; |  | ||||||
|  |  | ||||||
|     constexpr int nwarps = 4; |  | ||||||
|  |  | ||||||
|     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16; |  | ||||||
|     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3]; |  | ||||||
|     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm; |  | ||||||
|  |  | ||||||
|     float logit_softcap; |  | ||||||
|     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); |  | ||||||
|  |  | ||||||
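|     // If there are too few blocks to saturate the streaming multiprocessors, split |  | ||||||
|     // each attention slice across multiple parallel blocks along the KV dimension: |  | ||||||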
|     if (4*blocks_num_pb1 < 2*nsm) { |  | ||||||
|         constexpr int parallel_blocks = 4; |  | ||||||
|         fattn_kernel_t fattn_kernel; |  | ||||||
|         if (logit_softcap == 0.0f) { |  | ||||||
|             constexpr bool use_logit_softcap = false; |  | ||||||
|             fattn_kernel = flash_attn_ext_f16< |  | ||||||
|                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; |  | ||||||
|         } else { |  | ||||||
|             constexpr bool use_logit_softcap = true; |  | ||||||
|             fattn_kernel = flash_attn_ext_f16< |  | ||||||
|                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; |  | ||||||
|         } |  | ||||||
|         launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); |  | ||||||
|         return; |  | ||||||
|     } |  | ||||||
|     if (2*blocks_num_pb1 < 2*nsm) { |  | ||||||
|         constexpr int parallel_blocks = 2; |  | ||||||
|         fattn_kernel_t fattn_kernel; |  | ||||||
|         if (logit_softcap == 0.0f) { |  | ||||||
|             constexpr bool use_logit_softcap = false; |  | ||||||
|             fattn_kernel = flash_attn_ext_f16< |  | ||||||
|                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; |  | ||||||
|         } else { |  | ||||||
|             constexpr bool use_logit_softcap = true; |  | ||||||
|             fattn_kernel = flash_attn_ext_f16< |  | ||||||
|                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; |  | ||||||
|         } |  | ||||||
|         launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); |  | ||||||
|         return; |  | ||||||
|     } |  | ||||||
|     constexpr int parallel_blocks = 1; |  | ||||||
|     fattn_kernel_t fattn_kernel; |  | ||||||
|     if (logit_softcap == 0.0f) { |  | ||||||
|         constexpr bool use_logit_softcap = false; |  | ||||||
|         fattn_kernel = flash_attn_ext_f16< |  | ||||||
|             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; |  | ||||||
|     } else { |  | ||||||
|         constexpr bool use_logit_softcap = true; |  | ||||||
|         fattn_kernel = flash_attn_ext_f16< |  | ||||||
|             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>; |  | ||||||
|     } |  | ||||||
|     launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true); |  | ||||||
| } |  | ||||||
|  |  | ||||||
| #define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t)                         \ |  | ||||||
|     template void ggml_cuda_flash_attn_ext_wmma_f16_case                              \ |  | ||||||
|     <D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ |  | ||||||
|  |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE(112, 16, float); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE(128, 16, float); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE(256, 16, float); |  | ||||||
|  |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE(112, 32, float); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE(128, 32, float); |  | ||||||
| // extern DECL_FATTN_WMMA_F16_CASE(256, 32, float); |  | ||||||
|  |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE( 64,  8, half); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE( 96,  8, half); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE(128,  8, half); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE(256,  8, half); |  | ||||||
|  |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE(112, 16, half); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE(128, 16, half); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE(256, 16, half); |  | ||||||
|  |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE(112, 32, half); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE(128, 32, half); |  | ||||||
| extern DECL_FATTN_WMMA_F16_CASE(256, 32, half); |  | ||||||
|   | |||||||
| @@ -1,5 +1,6 @@ | |||||||
| #include "common.cuh" | #include "common.cuh" | ||||||
| #include "fattn-common.cuh" | #include "fattn-common.cuh" | ||||||
|  | #include "fattn-mma-f16.cuh" | ||||||
| #include "fattn-tile-f16.cuh" | #include "fattn-tile-f16.cuh" | ||||||
| #include "fattn-tile-f32.cuh" | #include "fattn-tile-f32.cuh" | ||||||
| #include "fattn-vec-f16.cuh" | #include "fattn-vec-f16.cuh" | ||||||
| @@ -7,144 +8,56 @@ | |||||||
| #include "fattn-wmma-f16.cuh" | #include "fattn-wmma-f16.cuh" | ||||||
| #include "fattn.cuh" | #include "fattn.cuh" | ||||||
|  |  | ||||||
| #include <cstdint> | template <int cols_per_block> | ||||||
|  | static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||||||
| static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { |  | ||||||
|     const ggml_tensor * KQV = dst; |  | ||||||
|     const ggml_tensor * Q = dst->src[0]; |     const ggml_tensor * Q = dst->src[0]; | ||||||
|  |  | ||||||
|     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); |  | ||||||
|  |  | ||||||
|     if (prec != GGML_PREC_DEFAULT) { |  | ||||||
|         if (Q->ne[1] <= 32 || Q->ne[0] > 128) { |  | ||||||
|             constexpr int cols_per_block = 16; |  | ||||||
|     switch (Q->ne[0]) { |     switch (Q->ne[0]) { | ||||||
|         case 64: |         case 64: | ||||||
|                     ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); |             ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst); | ||||||
|             break; |             break; | ||||||
|         case 80: |         case 80: | ||||||
|                     ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); |             ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst); | ||||||
|             break; |             break; | ||||||
|         case 96: |         case 96: | ||||||
|                     ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); |             ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst); | ||||||
|             break; |             break; | ||||||
|         case 112: |         case 112: | ||||||
|                     ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); |             ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst); | ||||||
|             break; |             break; | ||||||
|         case 128: |         case 128: | ||||||
|                     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); |             ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst); | ||||||
|             break; |             break; | ||||||
|         case 256: |         case 256: | ||||||
|                     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); |             ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst); | ||||||
|             break; |             break; | ||||||
|         default: |         default: | ||||||
|             GGML_ABORT("fatal error"); |             GGML_ABORT("fatal error"); | ||||||
|             break; |             break; | ||||||
|     } |     } | ||||||
|         } else { |  | ||||||
|             constexpr int cols_per_block = 32; |  | ||||||
|             switch (Q->ne[0]) { |  | ||||||
|                 case 64: |  | ||||||
|                     ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst); |  | ||||||
|                     break; |  | ||||||
|                 case 80: |  | ||||||
|                     ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst); |  | ||||||
|                     break; |  | ||||||
|                 case 96: |  | ||||||
|                     ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst); |  | ||||||
|                     break; |  | ||||||
|                 case 112: |  | ||||||
|                     ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst); |  | ||||||
|                     break; |  | ||||||
|                 case 128: |  | ||||||
|                     ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst); |  | ||||||
|                     break; |  | ||||||
|                 // case 256: |  | ||||||
|                 //     ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst); |  | ||||||
|                 //     break; |  | ||||||
|                 default: |  | ||||||
|                     GGML_ABORT("fatal error"); |  | ||||||
|                     break; |  | ||||||
|             } |  | ||||||
| } | } | ||||||
|  |  | ||||||
|  | static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||||||
|  |     const ggml_tensor * Q = dst->src[0]; | ||||||
|  |  | ||||||
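|  |     // Choose cols_per_block based on the number of Q columns (Q->ne[1]), capped at 64: | ||||||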
|  |     if (Q->ne[1] <= 8) { | ||||||
|  |         ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst); | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) { |     if (Q->ne[1] <= 16) { | ||||||
|         constexpr int cols_per_block = 8; |         ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst); | ||||||
|         switch (Q->ne[0]) { |  | ||||||
|             case 64: |  | ||||||
|                 ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); |  | ||||||
|                 break; |  | ||||||
|             case 96: |  | ||||||
|                 ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); |  | ||||||
|                 break; |  | ||||||
|             case 128: |  | ||||||
|                 ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); |  | ||||||
|                 break; |  | ||||||
|             case 256: |  | ||||||
|                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); |  | ||||||
|                 break; |  | ||||||
|             default: |  | ||||||
|                 GGML_ABORT("fatal error"); |  | ||||||
|                 break; |  | ||||||
|         } |  | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (Q->ne[1] <= 32) { |     if (Q->ne[1] <= 32) { | ||||||
|         constexpr int cols_per_block = 16; |         ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst); | ||||||
|         switch (Q->ne[0]) { |  | ||||||
|             case 64: |  | ||||||
|                 ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); |  | ||||||
|                 break; |  | ||||||
|             case 80: |  | ||||||
|                 ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); |  | ||||||
|                 break; |  | ||||||
|             case 96: |  | ||||||
|                 ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); |  | ||||||
|                 break; |  | ||||||
|             case 112: |  | ||||||
|                 ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); |  | ||||||
|                 break; |  | ||||||
|             case 128: |  | ||||||
|                 ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); |  | ||||||
|                 break; |  | ||||||
|             case 256: |  | ||||||
|                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); |  | ||||||
|                 break; |  | ||||||
|             default: |  | ||||||
|                 GGML_ABORT("fatal error"); |  | ||||||
|                 break; |  | ||||||
|         } |  | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     constexpr int cols_per_block = 32; |     ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst); | ||||||
|     switch (Q->ne[0]) { |  | ||||||
|         case 64: |  | ||||||
|             ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst); |  | ||||||
|             break; |  | ||||||
|         case 80: |  | ||||||
|             ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst); |  | ||||||
|             break; |  | ||||||
|         case 96: |  | ||||||
|             ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst); |  | ||||||
|             break; |  | ||||||
|         case 112: |  | ||||||
|             ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst); |  | ||||||
|             break; |  | ||||||
|         case 128: |  | ||||||
|             ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst); |  | ||||||
|             break; |  | ||||||
|         case 256: |  | ||||||
|             ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst); |  | ||||||
|             break; |  | ||||||
|         default: |  | ||||||
|             GGML_ABORT("fatal error"); |  | ||||||
|             break; |  | ||||||
|     } |  | ||||||
| } | } | ||||||
|  |  | ||||||
| #define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \ | #define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \ | ||||||
|     if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \ |     if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \ | ||||||
|         ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \ |         ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \ | ||||||
| @@ -322,12 +235,20 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst | |||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (!fp16_mma_available(cc)) { |     if (!new_mma_available(cc)) { | ||||||
|  |         if (prec == GGML_PREC_DEFAULT) { | ||||||
|             if (Q->ne[1] <= 8) { |             if (Q->ne[1] <= 8) { | ||||||
|                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); |                 ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); | ||||||
|             } else { |             } else { | ||||||
|                 ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); |                 ggml_cuda_flash_attn_ext_tile_f16(ctx, dst); | ||||||
|             } |             } | ||||||
|  |         } else { | ||||||
|  |             if (Q->ne[1] <= 8) { | ||||||
|  |                 ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); | ||||||
|  |             } else { | ||||||
|  |                 ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); | ||||||
|  |             } | ||||||
|  |         } | ||||||
|         return; |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -341,5 +262,10 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst | |||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     // The MMA implementation needs Turing or newer, use the old WMMA code for Volta: | ||||||
|  |     if (cc == GGML_CUDA_CC_VOLTA) { | ||||||
|         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); |         ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); | ||||||
|  |         return; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); | ||||||
|  | } | ||||||
|   | |||||||
| @@ -1,11 +1,67 @@ | |||||||
|  | // This file contains primitives that expose the tensor core PTX instructions for CUDA code. | ||||||
|  | // The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. | ||||||
|  | // The documentation for the PTX instructions can be found under: | ||||||
|  | //   https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction | ||||||
|  | // | ||||||
|  | // Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C. | ||||||
|  | // A is a row-major matrix with shape I x K. | ||||||
|  | // B is a column-major matrix with shape K x J. | ||||||
|  | // C is a column-major matrix with shape I x J. | ||||||
|  | // Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements. | ||||||
|  | // The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile. | ||||||
|  | // All matrix tiles have ne physical 32 bit elements per thread. | ||||||
|  | // | ||||||
|  | // As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. | ||||||
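|  | // | ||||||
|  | // Minimal usage sketch (illustrative only; As/Bs are assumed to be 16-byte-aligned half2 tiles in shared memory): | ||||||
|  | //   mma_A_I16K8<half2> A; | ||||||
|  | //   mma_B_J8K8<half2>  B; | ||||||
|  | //   mma_C_I16J8<float> C; | ||||||
|  | //   A.load_ldmatrix(As, stride); | ||||||
|  | //   B.load_ldmatrix(Bs, stride); | ||||||
|  | //   C.mma(A, B); // C += A @ B with fp32 accumulation | ||||||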
|  |  | ||||||
| #include "common.cuh" | #include "common.cuh" | ||||||
|  |  | ||||||
| struct mma_int_A_I16K4 { |  | ||||||
|  | #if CUDART_VERSION >= 11800 | ||||||
|  |  | ||||||
|  | static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { | ||||||
|  |     int ret = 0; | ||||||
|  |  | ||||||
|  | #ifdef NEW_MMA_AVAILABLE | ||||||
|  |     asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;" | ||||||
|  |         : "+r"(ret) : "r"(x)); | ||||||
|  | #else | ||||||
|  |     NO_DEVICE_CODE; | ||||||
|  | #endif // defined(NEW_MMA_AVAILABLE) | ||||||
|  |     return ret; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #else | ||||||
|  |  | ||||||
|  | static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) { | ||||||
|  |     // Transpose an 8x8 tile of 16 bit elements within the warp, i.e. convert it from row-major to column-major layout. | ||||||
|  |     const int src_i_low  = 2 * (threadIdx.x % 4); | ||||||
|  |     const int src_i_high = src_i_low + 1; | ||||||
|  |     const int src_j      = threadIdx.x / 4; | ||||||
|  |  | ||||||
|  |     const int src_laneid_low  = src_i_low  * 4 + src_j / 2; | ||||||
|  |     const int src_laneid_high = src_i_high * 4 + src_j / 2; | ||||||
|  |  | ||||||
|  |     const int shift_low  = ((src_j + 0) % 2) * 16; | ||||||
|  |     const int shift_high = ((src_j + 1) % 2) * 16; | ||||||
|  |  | ||||||
|  |     const int ret_low  = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low,  WARP_SIZE) >> shift_low)  & 0x0000FFFF; | ||||||
|  |     const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000; | ||||||
|  |  | ||||||
|  |     return ret_low | ret_high; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #endif // CUDART_VERSION >= 11800 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | template <typename T> | ||||||
|  | struct mma_A_I16K4 { | ||||||
|  |     static_assert(sizeof(T) == 4, "bad type size"); | ||||||
|  |  | ||||||
|     static constexpr int I  = 16; |     static constexpr int I  = 16; | ||||||
|     static constexpr int K  = 4; |     static constexpr int K  = 4; | ||||||
|     static constexpr int ne = 2; |     static constexpr int ne = 2; | ||||||
|  |  | ||||||
|     int x[ne] = {0}; |     T x[ne]; | ||||||
|  |  | ||||||
|     static __device__ __forceinline__ int get_i(const int l) { |     static __device__ __forceinline__ int get_i(const int l) { | ||||||
|         const int ret = (l%2) * (I/2) + threadIdx.x / K; |         const int ret = (l%2) * (I/2) + threadIdx.x / K; | ||||||
| @@ -21,27 +77,35 @@ struct mma_int_A_I16K4 { | |||||||
|         return ret; |         return ret; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { |     __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { | ||||||
| #if defined(INT8_MMA_AVAILABLE) |  | ||||||
|         const int * xs = xs0 + (threadIdx.x%I)*stride; |  | ||||||
|         asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" |  | ||||||
|             : "+r"(x[0]), "+r"(x[1]) |  | ||||||
|             : "l"(xs)); |  | ||||||
| #else |  | ||||||
| #pragma unroll | #pragma unroll | ||||||
|         for (int l = 0; l < ne; ++l) { |         for (int l = 0; l < ne; ++l) { | ||||||
|             x[l] = xs0[get_i(l)*stride + get_k(l)]; |             x[l] = xs0[get_i(l)*stride + get_k(l)]; | ||||||
|         } |         } | ||||||
| #endif // defined(INT8_MMA_AVAILABLE) |     } | ||||||
|  |  | ||||||
|  |     __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { | ||||||
|  | #ifdef NEW_MMA_AVAILABLE | ||||||
|  |         int * xi = (int *) x; | ||||||
|  |         const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride; | ||||||
|  |         asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" | ||||||
|  |             : "+r"(xi[0]), "+r"(xi[1]) | ||||||
|  |             : "l"(xs)); | ||||||
|  | #else | ||||||
|  |         load_generic(xs0, stride); | ||||||
|  | #endif // NEW_MMA_AVAILABLE | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct mma_int_A_I16K8 { | template <typename T> | ||||||
|  | struct mma_A_I16K8 { | ||||||
|  |     static_assert(sizeof(T) == 4, "bad type size"); | ||||||
|  |  | ||||||
|     static constexpr int I  = 16; |     static constexpr int I  = 16; | ||||||
|     static constexpr int K  = 8; |     static constexpr int K  = 8; | ||||||
|     static constexpr int ne = 4; |     static constexpr int ne = 4; | ||||||
|  |  | ||||||
|     int x[ne] = {0}; |     T x[ne]; | ||||||
|  |  | ||||||
|     static __device__ __forceinline__ int get_i(const int l) { |     static __device__ __forceinline__ int get_i(const int l) { | ||||||
|         const int ret = (l%2) * (I/2) + threadIdx.x / (K/2); |         const int ret = (l%2) * (I/2) + threadIdx.x / (K/2); | ||||||
| @@ -57,31 +121,62 @@ struct mma_int_A_I16K8 { | |||||||
|         return ret; |         return ret; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { |     __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { | ||||||
| #if defined(INT8_MMA_AVAILABLE) |  | ||||||
|         const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); |  | ||||||
|         asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" |  | ||||||
|             : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) |  | ||||||
|             : "l"(xs)); |  | ||||||
| #else |  | ||||||
| #pragma unroll | #pragma unroll | ||||||
|         for (int l = 0; l < ne; ++l) { |         for (int l = 0; l < ne; ++l) { | ||||||
|             x[l] = xs0[get_i(l)*stride + get_k(l)]; |             x[l] = xs0[get_i(l)*stride + get_k(l)]; | ||||||
|         } |         } | ||||||
| #endif // defined(INT8_MMA_AVAILABLE) |  | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) { |     __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { | ||||||
|         ((mma_int_A_I16K4 *) x)[0].load(xs0, stride); | #ifdef NEW_MMA_AVAILABLE | ||||||
|  |         int * xi = (int *) x; | ||||||
|  |         const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); | ||||||
|  |         asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" | ||||||
|  |             : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) | ||||||
|  |             : "l"(xs)); | ||||||
|  | #else | ||||||
|  |         GGML_UNUSED(xs0); | ||||||
|  |         GGML_UNUSED(stride); | ||||||
|  |         NO_DEVICE_CODE; | ||||||
|  | #endif // NEW_MMA_AVAILABLE | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     __device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) { | ||||||
|  | #ifdef NEW_MMA_AVAILABLE | ||||||
|  |         int * xi = (int *) x; | ||||||
|  |         const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2); | ||||||
|  |         asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];" | ||||||
|  |             : "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3]) | ||||||
|  |             : "l"(xs)); | ||||||
|  | #else | ||||||
|  |         GGML_UNUSED(xs0); | ||||||
|  |         GGML_UNUSED(stride); | ||||||
|  |         NO_DEVICE_CODE; | ||||||
|  | #endif // NEW_MMA_AVAILABLE | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     __device__ __forceinline__ void transpose() { | ||||||
|  |         int * xi  = (int *) x; | ||||||
|  |         xi[0] = ggml_cuda_movmatrix(xi[0]); | ||||||
|  |  | ||||||
|  |         const int tmp = ggml_cuda_movmatrix(xi[1]); | ||||||
|  |         xi[1] = ggml_cuda_movmatrix(xi[2]); | ||||||
|  |         xi[2] = tmp; | ||||||
|  |  | ||||||
|  |         xi[3] = ggml_cuda_movmatrix(xi[3]); | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct mma_int_B_J8K4 { | template <typename T> | ||||||
|  | struct mma_B_J8K4 { | ||||||
|  |     static_assert(sizeof(T) == 4, "bad type size"); | ||||||
|  |  | ||||||
|     static constexpr int J  = 8; |     static constexpr int J  = 8; | ||||||
|     static constexpr int K  = 4; |     static constexpr int K  = 4; | ||||||
|     static constexpr int ne = 1; |     static constexpr int ne = 1; | ||||||
|  |  | ||||||
|     int x[ne] = {0}; |     T x[ne]; | ||||||
|  |  | ||||||
|     static __device__ __forceinline__ int get_j(const int /* l */) { |     static __device__ __forceinline__ int get_j(const int /* l */) { | ||||||
|         const int ret = threadIdx.x / K; |         const int ret = threadIdx.x / K; | ||||||
| @@ -97,27 +192,34 @@ struct mma_int_B_J8K4 { | |||||||
|         return ret; |         return ret; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { |     __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { | ||||||
| #if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster |  | ||||||
|         const int * xs = xs0 + (threadIdx.x%J)*stride; |  | ||||||
|         asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];" |  | ||||||
|             : "+r"(x[0]) |  | ||||||
|             : "l"(xs)); |  | ||||||
| #else |  | ||||||
| #pragma unroll | #pragma unroll | ||||||
|         for (int l = 0; l < ne; ++l) { |         for (int l = 0; l < ne; ++l) { | ||||||
|             x[l] = xs0[get_j(l)*stride + get_k(l)]; |             x[l] = xs0[get_j(l)*stride + get_k(l)]; | ||||||
|         } |         } | ||||||
| #endif // defined(INT8_MMA_AVAILABLE) |     } | ||||||
|  |  | ||||||
|  |     __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { | ||||||
|  | #ifdef NEW_MMA_AVAILABLE | ||||||
|  |         int * xi = (int *) x; | ||||||
|  |         const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride; | ||||||
|  |         asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];" | ||||||
|  |             : "+r"(xi[0]) : "l"(xs)); | ||||||
|  | #else | ||||||
|  |         load_generic(xs0, stride); | ||||||
|  | #endif // NEW_MMA_AVAILABLE | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct mma_int_B_J8K8 { | template <typename T> | ||||||
|  | struct mma_B_J8K8 { | ||||||
|  |     static_assert(sizeof(T) == 4, "bad type size"); | ||||||
|  |  | ||||||
|     static constexpr int J  = 8; |     static constexpr int J  = 8; | ||||||
|     static constexpr int K  = 8; |     static constexpr int K  = 8; | ||||||
|     static constexpr int ne = 2; |     static constexpr int ne = 2; | ||||||
|  |  | ||||||
|     int x[ne] = {0}; |     T x[ne]; | ||||||
|  |  | ||||||
|     static __device__ __forceinline__ int get_j(const int /* l */) { |     static __device__ __forceinline__ int get_j(const int /* l */) { | ||||||
|         const int ret = threadIdx.x / (K/2); |         const int ret = threadIdx.x / (K/2); | ||||||
| @@ -133,22 +235,31 @@ struct mma_int_B_J8K8 { | |||||||
|         return ret; |         return ret; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     __device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) { |     __device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) { | ||||||
| #if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster |  | ||||||
|         const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K; |  | ||||||
|         asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" |  | ||||||
|             : "+r"(x[0]), "+r"(x[1]) |  | ||||||
|             : "l"(xs)); |  | ||||||
| #else |  | ||||||
| #pragma unroll | #pragma unroll | ||||||
|         for (int l = 0; l < ne; ++l) { |         for (int l = 0; l < ne; ++l) { | ||||||
|             x[l] = xs0[get_j(l)*stride + get_k(l)]; |             x[l] = xs0[get_j(l)*stride + get_k(l)]; | ||||||
|         } |         } | ||||||
| #endif // defined(INT8_MMA_AVAILABLE) |     } | ||||||
|  |  | ||||||
|  |     __device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) { | ||||||
|  | #ifdef NEW_MMA_AVAILABLE | ||||||
|  |         int * xi = (int *) x; | ||||||
|  |         const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K; | ||||||
|  |         asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];" | ||||||
|  |             : "+r"(xi[0]), "+r"(xi[1]) | ||||||
|  |             : "l"(xs)); | ||||||
|  | #else | ||||||
|  |         load_generic(xs0, stride); | ||||||
|  | #endif // NEW_MMA_AVAILABLE | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
|  |  | ||||||
| struct mma_int_C_I16J8 { | template <typename T> | ||||||
|  | struct mma_C_I16J8 {}; | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | struct mma_C_I16J8<int> { | ||||||
|     static constexpr int I  = 16; |     static constexpr int I  = 16; | ||||||
|     static constexpr int J  = 8; |     static constexpr int J  = 8; | ||||||
|     static constexpr int ne = 4; |     static constexpr int ne = 4; | ||||||
| @@ -169,8 +280,8 @@ struct mma_int_C_I16J8 { | |||||||
|         return ret; |         return ret; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) { |     __device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) { | ||||||
| #ifdef INT8_MMA_AVAILABLE | #ifdef NEW_MMA_AVAILABLE | ||||||
| #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | ||||||
|         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" |         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" | ||||||
|             : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) |             : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) | ||||||
| @@ -188,11 +299,11 @@ struct mma_int_C_I16J8 { | |||||||
|         GGML_UNUSED(mma_A); |         GGML_UNUSED(mma_A); | ||||||
|         GGML_UNUSED(mma_B); |         GGML_UNUSED(mma_B); | ||||||
|         NO_DEVICE_CODE; |         NO_DEVICE_CODE; | ||||||
| #endif // INT8_MMA_AVAILABLE | #endif // NEW_MMA_AVAILABLE | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) { |     __device__ __forceinline__ void mma(const mma_A_I16K8<int> & mma_A, const mma_B_J8K8<int> & mma_B) { | ||||||
| #ifdef INT8_MMA_AVAILABLE | #ifdef NEW_MMA_AVAILABLE | ||||||
| #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | ||||||
|         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" |         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" | ||||||
|             : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) |             : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3]) | ||||||
| @@ -216,6 +327,132 @@ struct mma_int_C_I16J8 { | |||||||
|         GGML_UNUSED(mma_A); |         GGML_UNUSED(mma_A); | ||||||
|         GGML_UNUSED(mma_B); |         GGML_UNUSED(mma_B); | ||||||
|         NO_DEVICE_CODE; |         NO_DEVICE_CODE; | ||||||
| #endif // INT8_MMA_AVAILABLE | #endif // NEW_MMA_AVAILABLE | ||||||
|  |     } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | struct mma_C_I16J8<half2> { | ||||||
|  |     static constexpr int I  = 16; | ||||||
|  |     static constexpr int J  = 4; | ||||||
|  |     static constexpr int ne = 2; | ||||||
|  |  | ||||||
|  |     half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}}; | ||||||
|  |  | ||||||
|  |     static __device__ __forceinline__ int get_i(const int l) { | ||||||
|  |         const int ret = l * (I/2) + threadIdx.x / J; | ||||||
|  |         GGML_CUDA_ASSUME(ret >= 0); | ||||||
|  |         GGML_CUDA_ASSUME(ret <  I); | ||||||
|  |         return ret; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     static __device__ __forceinline__ int get_j(const int /* l */) { | ||||||
|  |         const int ret = threadIdx.x % J; | ||||||
|  |         GGML_CUDA_ASSUME(ret >= 0); | ||||||
|  |         GGML_CUDA_ASSUME(ret <  J); | ||||||
|  |         return ret; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) { | ||||||
|  | #ifdef NEW_MMA_AVAILABLE | ||||||
|  |         int * Axi = (int *) mma_A.x; | ||||||
|  |         int * Bxi = (int *) mma_B.x; | ||||||
|  |         int * xi  = (int *) x; | ||||||
|  | #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | ||||||
|  |         asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};" | ||||||
|  |             : "+r"(xi[0]), "+r"(xi[1]) | ||||||
|  |             : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); | ||||||
|  | #else | ||||||
|  |         // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: | ||||||
|  |         asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" | ||||||
|  |             : "+r"(xi[0]), "+r"(xi[1]) | ||||||
|  |             : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); | ||||||
|  |         asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};" | ||||||
|  |             : "+r"(xi[0]), "+r"(xi[1]) | ||||||
|  |             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); | ||||||
|  | #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | ||||||
|  | #else | ||||||
|  |         GGML_UNUSED(mma_A); | ||||||
|  |         GGML_UNUSED(mma_B); | ||||||
|  |         NO_DEVICE_CODE; | ||||||
|  | #endif // NEW_MMA_AVAILABLE | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() { | ||||||
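|  |         // Transpose each 8x8 sub-tile of 16 bit elements with movmatrix so that | ||||||
|  |         // this C accumulator tile can be reused directly as a B operand: | ||||||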
|  |         mma_B_J8K8<half2> mma_B; | ||||||
|  |  | ||||||
|  |         int * xi   = (int *) x; | ||||||
|  |         int * Bxi  = (int *) mma_B.x; | ||||||
|  |         Bxi[0] = ggml_cuda_movmatrix(xi[0]); | ||||||
|  |         Bxi[1] = ggml_cuda_movmatrix(xi[1]); | ||||||
|  |  | ||||||
|  |         return mma_B; | ||||||
|  |     } | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template <> | ||||||
|  | struct mma_C_I16J8<float> { | ||||||
|  |     static constexpr int I  = 16; | ||||||
|  |     static constexpr int J  = 8; | ||||||
|  |     static constexpr int ne = 4; | ||||||
|  |  | ||||||
|  |     float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f}; | ||||||
|  |  | ||||||
|  |     static __device__ __forceinline__ int get_i(const int l) { | ||||||
|  |         const int ret = (l/2) * (I/2) + threadIdx.x / (J/2); | ||||||
|  |         GGML_CUDA_ASSUME(ret >= 0); | ||||||
|  |         GGML_CUDA_ASSUME(ret <  I); | ||||||
|  |         return ret; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     static __device__ __forceinline__ int get_j(const int l) { | ||||||
|  |         const int ret = 2 * (threadIdx.x % (J/2)) + l%2; | ||||||
|  |         GGML_CUDA_ASSUME(ret >= 0); | ||||||
|  |         GGML_CUDA_ASSUME(ret <  J); | ||||||
|  |         return ret; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     __device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) { | ||||||
|  | #ifdef NEW_MMA_AVAILABLE | ||||||
|  |         int * Axi = (int *) mma_A.x; | ||||||
|  |         int * Bxi = (int *) mma_B.x; | ||||||
|  |         int * xi  = (int *) x; | ||||||
|  | #if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | ||||||
|  |         asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};" | ||||||
|  |             : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) | ||||||
|  |             : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1])); | ||||||
|  | #else | ||||||
|  |         // On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead: | ||||||
|  |         asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" | ||||||
|  |             : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) | ||||||
|  |             : "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0])); | ||||||
|  |         asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};" | ||||||
|  |             : "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3]) | ||||||
|  |             : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1])); | ||||||
|  | #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | ||||||
|  | #else | ||||||
|  |         GGML_UNUSED(mma_A); | ||||||
|  |         GGML_UNUSED(mma_B); | ||||||
|  |         NO_DEVICE_CODE; | ||||||
|  | #endif // NEW_MMA_AVAILABLE | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     __device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() { | ||||||
|  |         mma_B_J8K8<half2> mma_B; | ||||||
|  |         mma_B.x[0] = make_half2(x[0], x[1]); | ||||||
|  |         mma_B.x[1] = make_half2(x[2], x[3]); | ||||||
|  |  | ||||||
|  |         int * Bxi  = (int *) mma_B.x; | ||||||
|  |         Bxi[0] = ggml_cuda_movmatrix(Bxi[0]); | ||||||
|  |         Bxi[1] = ggml_cuda_movmatrix(Bxi[1]); | ||||||
|  |  | ||||||
|  |         return mma_B; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     __device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) { | ||||||
|  | #pragma unroll | ||||||
|  |         for (int l = 0; l < ne; ++l) { | ||||||
|  |             x[l] = xs0[get_j(l)*stride + get_i(l)]; | ||||||
|  |         } | ||||||
|     } |     } | ||||||
| }; | }; | ||||||
|   | |||||||
| @@ -132,7 +132,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { | |||||||
|         return false; |         return false; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     if (int8_mma_available(cc)) { |     if (new_mma_available(cc)) { | ||||||
|         return true; |         return true; | ||||||
|     } |     } | ||||||
|  |  | ||||||
|   | |||||||
										
											
File diff suppressed because it is too large
							| @@ -0,0 +1,10 @@ | |||||||
|  | // This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
|  | #include "../fattn-mma-f16.cuh" | ||||||
|  |  | ||||||
|  | DECL_FATTN_MMA_F16_CASE(64, 16); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(80, 16); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(96, 16); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(112, 16); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(128, 16); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(256, 16); | ||||||
| @@ -0,0 +1,10 @@ | |||||||
|  | // This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
|  | #include "../fattn-mma-f16.cuh" | ||||||
|  |  | ||||||
|  | DECL_FATTN_MMA_F16_CASE(64, 32); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(80, 32); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(96, 32); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(112, 32); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(128, 32); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(256, 32); | ||||||
| @@ -0,0 +1,10 @@ | |||||||
|  | // This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
|  | #include "../fattn-mma-f16.cuh" | ||||||
|  |  | ||||||
|  | DECL_FATTN_MMA_F16_CASE(64, 64); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(80, 64); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(96, 64); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(112, 64); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(128, 64); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(256, 64); | ||||||
| @@ -0,0 +1,10 @@ | |||||||
|  | // This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
|  | #include "../fattn-mma-f16.cuh" | ||||||
|  |  | ||||||
|  | DECL_FATTN_MMA_F16_CASE(64, 8); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(80, 8); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(96, 8); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(112, 8); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(128, 8); | ||||||
|  | DECL_FATTN_MMA_F16_CASE(256, 8); | ||||||
| @@ -1,10 +0,0 @@ | |||||||
| // This file has been autogenerated by generate_cu_files.py, do not edit manually. |  | ||||||
|  |  | ||||||
| #include "../fattn-wmma-f16.cuh" |  | ||||||
|  |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(64, 16, float); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(80, 16, float); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(96, 16, float); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(112, 16, float); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(128, 16, float); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(256, 16, float); |  | ||||||
| @@ -1,9 +0,0 @@ | |||||||
| // This file has been autogenerated by generate_cu_files.py, do not edit manually. |  | ||||||
|  |  | ||||||
| #include "../fattn-wmma-f16.cuh" |  | ||||||
|  |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(64, 32, float); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(80, 32, float); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(96, 32, float); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(112, 32, float); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(128, 32, float); |  | ||||||
| @@ -1,10 +0,0 @@ | |||||||
| // This file has been autogenerated by generate_cu_files.py, do not edit manually. |  | ||||||
|  |  | ||||||
| #include "../fattn-wmma-f16.cuh" |  | ||||||
|  |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(64, 16, half); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(80, 16, half); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(96, 16, half); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(112, 16, half); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(128, 16, half); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(256, 16, half); |  | ||||||
| @@ -1,10 +0,0 @@ | |||||||
| // This file has been autogenerated by generate_cu_files.py, do not edit manually. |  | ||||||
|  |  | ||||||
| #include "../fattn-wmma-f16.cuh" |  | ||||||
|  |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(64, 32, half); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(80, 32, half); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(96, 32, half); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(112, 32, half); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(128, 32, half); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(256, 32, half); |  | ||||||
| @@ -1,8 +0,0 @@ | |||||||
| // This file has been autogenerated by generate_cu_files.py, do not edit manually. |  | ||||||
|  |  | ||||||
| #include "../fattn-wmma-f16.cuh" |  | ||||||
|  |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(64, 8, half); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(96, 8, half); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(128, 8, half); |  | ||||||
| DECL_FATTN_WMMA_F16_CASE(256, 8, half); |  | ||||||
| @@ -12,13 +12,13 @@ SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.p | |||||||
| DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v}); | DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v}); | ||||||
| """ | """ | ||||||
|  |  | ||||||
| SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. | SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. | ||||||
|  |  | ||||||
| #include "../fattn-wmma-f16.cuh" | #include "../fattn-mma-f16.cuh" | ||||||
|  |  | ||||||
| """ | """ | ||||||
|  |  | ||||||
| SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n" | SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n" | ||||||
|  |  | ||||||
| TYPES_MMQ = [ | TYPES_MMQ = [ | ||||||
|     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", |     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", | ||||||
| @@ -57,20 +57,12 @@ for vkq_size in [16, 32]: | |||||||
|                 with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: |                 with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f: | ||||||
|                     f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v)) |                     f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v)) | ||||||
|  |  | ||||||
| for kq_acc_t in ["half", "float"]: | for cols_per_block in [8, 16, 32, 64]: | ||||||
|     for cols_per_block in [8, 16, 32]: |     with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f: | ||||||
|         if kq_acc_t == "float" and cols_per_block == 8: |         f.write(SOURCE_FATTN_MMA_START) | ||||||
|             continue |  | ||||||
|  |  | ||||||
|         with open(f"fattn-wmma-f16-instance-kq{kq_acc_t}-cpb{cols_per_block}.cu", "w") as f: |  | ||||||
|             f.write(SOURCE_FATTN_WMMA_START) |  | ||||||
|  |  | ||||||
|         for head_size in [64, 80, 96, 112, 128, 256]: |         for head_size in [64, 80, 96, 112, 128, 256]: | ||||||
|                 if cols_per_block == 8 and head_size % 32 != 0: # wmma fragment is 8x32 |             f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size)) | ||||||
|                     continue |  | ||||||
|                 if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance |  | ||||||
|                     continue |  | ||||||
|                 f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size)) |  | ||||||
|  |  | ||||||
| for type in TYPES_MMQ: | for type in TYPES_MMQ: | ||||||
|     with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: |     with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: | ||||||
|   | |||||||
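Counting the cases makes the simplification concrete: the deleted wmma scheme emitted 27 instantiations across 5 files once its exclusions are applied, while the new loop emits a uniform 24 across 4.

    old (wmma): half/8 -> 4, half/16 -> 6, half/32 -> 6, float/16 -> 6, float/32 -> 5  =  27 cases, 5 files
    new (mma):  4 cols_per_block values x 6 head sizes                                  =  24 cases, 4 files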
ggml/src/ggml-cuda/vendors/hip.h (vendored, 1 change)
							| @@ -25,6 +25,7 @@ | |||||||
| #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice | #define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice | ||||||
| #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite | #define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite | ||||||
| #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} | #define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }} | ||||||
|  | #define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width) | ||||||
| #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) | #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) | ||||||
| #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 | #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6 | ||||||
| #define cublasCreate hipblasCreate | #define cublasCreate hipblasCreate | ||||||
|   | |||||||
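The added line is the HIP side of the shuffle shim: CUDA's masked __shfl_sync maps onto HIP's unmasked __shfl with the mask argument dropped, so warp-shuffle code written against the _sync intrinsics compiles unchanged on AMD. A minimal usage sketch (kernel name and values illustrative only):

// Broadcast lane 0's value across a 32-lane group. Under HIP the macro
// above turns this call into __shfl(v, 0, 32); the 0xFFFFFFFF mask vanishes.
__global__ void broadcast_lane0(float * out) {
    const float v  = (float) threadIdx.x;
    const float v0 = __shfl_sync(0xFFFFFFFF, v, 0, 32);
    out[threadIdx.x] = v0; // every lane now holds lane 0's value
}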
| @@ -50,7 +50,7 @@ file(GLOB   GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") | |||||||
| list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") | list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") | ||||||
|  |  | ||||||
| file(GLOB   GGML_SOURCES_ROCM "../ggml-cuda/*.cu") | file(GLOB   GGML_SOURCES_ROCM "../ggml-cuda/*.cu") | ||||||
| file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu") | file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") | ||||||
| list(APPEND GGML_SOURCES_ROCM ${SRCS}) | list(APPEND GGML_SOURCES_ROCM ${SRCS}) | ||||||
| file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu") | file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu") | ||||||
| list(APPEND GGML_SOURCES_ROCM ${SRCS}) | list(APPEND GGML_SOURCES_ROCM ${SRCS}) | ||||||
|   | |||||||
| @@ -29,7 +29,7 @@ if (MUSAToolkit_FOUND) | |||||||
|     list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h") |     list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h") | ||||||
|  |  | ||||||
|     file(GLOB   GGML_SOURCES_MUSA "../ggml-cuda/*.cu") |     file(GLOB   GGML_SOURCES_MUSA "../ggml-cuda/*.cu") | ||||||
|     file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu") |     file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") | ||||||
|     list(APPEND GGML_SOURCES_MUSA ${SRCS}) |     list(APPEND GGML_SOURCES_MUSA ${SRCS}) | ||||||
|     file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu") |     file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu") | ||||||
|     list(APPEND GGML_SOURCES_MUSA ${SRCS}) |     list(APPEND GGML_SOURCES_MUSA ${SRCS}) | ||||||
|   | |||||||
Johannes Gäßler