mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	CUDA: app option to compile without FlashAttention (#12025)
This commit is contained in:
		
							
								
								
									
										12
									
								
								Makefile
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								Makefile
									
									
									
									
									
								
							| @@ -680,6 +680,10 @@ ifdef GGML_CUDA_CCBIN | |||||||
| 	MK_NVCCFLAGS += -ccbin $(GGML_CUDA_CCBIN) | 	MK_NVCCFLAGS += -ccbin $(GGML_CUDA_CCBIN) | ||||||
| endif # GGML_CUDA_CCBIN | endif # GGML_CUDA_CCBIN | ||||||
|  |  | ||||||
|  | ifdef GGML_CUDA_NO_FA | ||||||
|  | 	MK_NVCCFLAGS += -DGGML_CUDA_NO_FA | ||||||
|  | endif # GGML_CUDA_NO_FA | ||||||
|  |  | ||||||
| ifdef GGML_CUDA_FA_ALL_QUANTS | ifdef GGML_CUDA_FA_ALL_QUANTS | ||||||
| 	MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS | 	MK_NVCCFLAGS += -DGGML_CUDA_FA_ALL_QUANTS | ||||||
| endif # GGML_CUDA_FA_ALL_QUANTS | endif # GGML_CUDA_FA_ALL_QUANTS | ||||||
| @@ -800,6 +804,10 @@ ifdef GGML_CUDA_NO_PEER_COPY | |||||||
| 	HIPFLAGS += -DGGML_CUDA_NO_PEER_COPY | 	HIPFLAGS += -DGGML_CUDA_NO_PEER_COPY | ||||||
| endif # GGML_CUDA_NO_PEER_COPY | endif # GGML_CUDA_NO_PEER_COPY | ||||||
|  |  | ||||||
|  | ifdef GGML_CUDA_NO_FA | ||||||
|  | 	HIPFLAGS += -DGGML_CUDA_NO_FA | ||||||
|  | endif # GGML_CUDA_NO_FA | ||||||
|  |  | ||||||
| 	OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o | 	OBJ_GGML_EXT += ggml/src/ggml-cuda/ggml-cuda.o | ||||||
| 	OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu)) | 	OBJ_GGML_EXT += $(patsubst %.cu,%.o,$(wildcard ggml/src/ggml-cuda/*.cu)) | ||||||
| 	OBJ_GGML_EXT += $(OBJ_CUDA_TMPL) | 	OBJ_GGML_EXT += $(OBJ_CUDA_TMPL) | ||||||
| @@ -876,6 +884,10 @@ ifdef GGML_CUDA_NO_PEER_COPY | |||||||
| 	MUSAFLAGS += -DGGML_CUDA_NO_PEER_COPY | 	MUSAFLAGS += -DGGML_CUDA_NO_PEER_COPY | ||||||
| endif # GGML_CUDA_NO_PEER_COPY | endif # GGML_CUDA_NO_PEER_COPY | ||||||
|  |  | ||||||
|  | ifdef GGML_CUDA_NO_FA | ||||||
|  | 	MUSAFLAGS += -DGGML_CUDA_NO_FA | ||||||
|  | endif # GGML_CUDA_NO_FA | ||||||
|  |  | ||||||
| ifdef GGML_CUDA_FA_ALL_QUANTS | ifdef GGML_CUDA_FA_ALL_QUANTS | ||||||
| 	MUSAFLAGS += -DGGML_CUDA_FA_ALL_QUANTS | 	MUSAFLAGS += -DGGML_CUDA_FA_ALL_QUANTS | ||||||
| endif # GGML_CUDA_FA_ALL_QUANTS | endif # GGML_CUDA_FA_ALL_QUANTS | ||||||
|   | |||||||
| @@ -151,6 +151,7 @@ set   (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING | |||||||
|                                             "ggml: max. batch size for using peer access") |                                             "ggml: max. batch size for using peer access") | ||||||
| option(GGML_CUDA_NO_PEER_COPY               "ggml: do not use peer to peer copies"            OFF) | option(GGML_CUDA_NO_PEER_COPY               "ggml: do not use peer to peer copies"            OFF) | ||||||
| option(GGML_CUDA_NO_VMM                     "ggml: do not try to use CUDA VMM"                OFF) | option(GGML_CUDA_NO_VMM                     "ggml: do not try to use CUDA VMM"                OFF) | ||||||
|  | option(GGML_CUDA_FA                         "ggml: compile ggml FlashAttention CUDA kernels"  ON) | ||||||
| option(GGML_CUDA_FA_ALL_QUANTS              "ggml: compile all quants for FlashAttention"     OFF) | option(GGML_CUDA_FA_ALL_QUANTS              "ggml: compile all quants for FlashAttention"     OFF) | ||||||
| option(GGML_CUDA_GRAPHS                     "ggml: use CUDA graphs (llama.cpp only)"          ${GGML_CUDA_GRAPHS_DEFAULT}) | option(GGML_CUDA_GRAPHS                     "ggml: use CUDA graphs (llama.cpp only)"          ${GGML_CUDA_GRAPHS_DEFAULT}) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -69,6 +69,10 @@ if (CUDAToolkit_FOUND) | |||||||
|         add_compile_definitions(GGML_CUDA_NO_VMM) |         add_compile_definitions(GGML_CUDA_NO_VMM) | ||||||
|     endif() |     endif() | ||||||
|  |  | ||||||
|  |     if (NOT GGML_CUDA_FA) | ||||||
|  |         add_compile_definitions(GGML_CUDA_NO_FA) | ||||||
|  |     endif() | ||||||
|  |  | ||||||
|     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) |     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) | ||||||
|         add_compile_definitions(GGML_CUDA_F16) |         add_compile_definitions(GGML_CUDA_F16) | ||||||
|     endif() |     endif() | ||||||
|   | |||||||
| @@ -204,9 +204,9 @@ typedef float2 dfloat2; | |||||||
| #define CP_ASYNC_AVAILABLE | #define CP_ASYNC_AVAILABLE | ||||||
| #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE | ||||||
|  |  | ||||||
| #if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) | #if !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) | ||||||
| #define FLASH_ATTN_AVAILABLE | #define FLASH_ATTN_AVAILABLE | ||||||
| #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) | #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1) | ||||||
|  |  | ||||||
| static bool fp16_available(const int cc) { | static bool fp16_available(const int cc) { | ||||||
|     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL; |     return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL; | ||||||
|   | |||||||
| @@ -839,10 +839,7 @@ static __global__ void flash_attn_ext_f16( | |||||||
|         const int ne1, |         const int ne1, | ||||||
|         const int ne2, |         const int ne2, | ||||||
|         const int ne3) { |         const int ne3) { | ||||||
| #ifndef NEW_MMA_AVAILABLE | #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) | ||||||
|     NO_DEVICE_CODE; |  | ||||||
|     return; |  | ||||||
| #endif // NEW_MMA_AVAILABLE |  | ||||||
|  |  | ||||||
|     // Skip unused kernel variants for faster compilation: |     // Skip unused kernel variants for faster compilation: | ||||||
|     if (use_logit_softcap && !(D == 128 || D == 256)) { |     if (use_logit_softcap && !(D == 128 || D == 256)) { | ||||||
| @@ -933,6 +930,9 @@ static __global__ void flash_attn_ext_f16( | |||||||
|     flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup> |     flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup> | ||||||
|         (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, |         (Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap, | ||||||
|          ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); |          ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); | ||||||
|  | #else | ||||||
|  |     NO_DEVICE_CODE; | ||||||
|  | #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) | ||||||
| } | } | ||||||
|  |  | ||||||
| template <int D, int ncols1, int ncols2> | template <int D, int ncols1, int ncols2> | ||||||
|   | |||||||
| @@ -44,12 +44,7 @@ static __global__ void flash_attn_tile_ext_f16( | |||||||
|         const int ne1, |         const int ne1, | ||||||
|         const int ne2, |         const int ne2, | ||||||
|         const int ne3) { |         const int ne3) { | ||||||
| #ifdef FP16_AVAILABLE | #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) | ||||||
|  |  | ||||||
| #ifndef FLASH_ATTN_AVAILABLE |  | ||||||
|     NO_DEVICE_CODE; |  | ||||||
|     return; |  | ||||||
| #endif // FLASH_ATTN_AVAILABLE |  | ||||||
|  |  | ||||||
|     // Skip unused kernel variants for faster compilation: |     // Skip unused kernel variants for faster compilation: | ||||||
| #ifdef FP16_MMA_AVAILABLE | #ifdef FP16_MMA_AVAILABLE | ||||||
| @@ -290,7 +285,7 @@ static __global__ void flash_attn_tile_ext_f16( | |||||||
|     } |     } | ||||||
| #else | #else | ||||||
|    NO_DEVICE_CODE; |    NO_DEVICE_CODE; | ||||||
| #endif // FP16_AVAILABLE | #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) | ||||||
| } | } | ||||||
|  |  | ||||||
| template <int cols_per_block, int parallel_blocks, bool use_logit_softcap> | template <int cols_per_block, int parallel_blocks, bool use_logit_softcap> | ||||||
|   | |||||||
| @@ -44,10 +44,7 @@ static __global__ void flash_attn_tile_ext_f32( | |||||||
|         const int ne1, |         const int ne1, | ||||||
|         const int ne2, |         const int ne2, | ||||||
|         const int ne3) { |         const int ne3) { | ||||||
| #ifndef FLASH_ATTN_AVAILABLE | #ifdef FLASH_ATTN_AVAILABLE | ||||||
|     NO_DEVICE_CODE; |  | ||||||
|     return; |  | ||||||
| #endif // FLASH_ATTN_AVAILABLE |  | ||||||
|  |  | ||||||
|     // Skip unused kernel variants for faster compilation: |     // Skip unused kernel variants for faster compilation: | ||||||
| #ifdef FP16_MMA_AVAILABLE | #ifdef FP16_MMA_AVAILABLE | ||||||
| @@ -285,6 +282,9 @@ static __global__ void flash_attn_tile_ext_f32( | |||||||
|             dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); |             dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j); | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  | #else | ||||||
|  |     NO_DEVICE_CODE; | ||||||
|  | #endif // FLASH_ATTN_AVAILABLE | ||||||
| } | } | ||||||
|  |  | ||||||
| template <int cols_per_block, int parallel_blocks, bool use_logit_softcap> | template <int cols_per_block, int parallel_blocks, bool use_logit_softcap> | ||||||
|   | |||||||
| @@ -41,12 +41,7 @@ static __global__ void flash_attn_vec_ext_f16( | |||||||
|         const int ne1, |         const int ne1, | ||||||
|         const int ne2, |         const int ne2, | ||||||
|         const int ne3) { |         const int ne3) { | ||||||
| #ifdef FP16_AVAILABLE | #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) | ||||||
|  |  | ||||||
| #ifndef FLASH_ATTN_AVAILABLE |  | ||||||
|     NO_DEVICE_CODE; |  | ||||||
|     return; |  | ||||||
| #endif // FLASH_ATTN_AVAILABLE |  | ||||||
|  |  | ||||||
|     // Skip unused kernel variants for faster compilation: |     // Skip unused kernel variants for faster compilation: | ||||||
|     if (use_logit_softcap && !(D == 128 || D == 256)) { |     if (use_logit_softcap && !(D == 128 || D == 256)) { | ||||||
| @@ -300,7 +295,7 @@ static __global__ void flash_attn_vec_ext_f16( | |||||||
|     } |     } | ||||||
| #else | #else | ||||||
|    NO_DEVICE_CODE; |    NO_DEVICE_CODE; | ||||||
| #endif // FP16_AVAILABLE | #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) | ||||||
| } | } | ||||||
|  |  | ||||||
| template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> | template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> | ||||||
|   | |||||||
| @@ -41,10 +41,7 @@ static __global__ void flash_attn_vec_ext_f32( | |||||||
|         const int ne1, |         const int ne1, | ||||||
|         const int ne2, |         const int ne2, | ||||||
|         const int ne3) { |         const int ne3) { | ||||||
| #ifndef FLASH_ATTN_AVAILABLE | #ifdef FLASH_ATTN_AVAILABLE | ||||||
|     NO_DEVICE_CODE; |  | ||||||
|     return; |  | ||||||
| #endif // FLASH_ATTN_AVAILABLE |  | ||||||
|  |  | ||||||
|     // Skip unused kernel variants for faster compilation: |     // Skip unused kernel variants for faster compilation: | ||||||
|     if (use_logit_softcap && !(D == 128 || D == 256)) { |     if (use_logit_softcap && !(D == 128 || D == 256)) { | ||||||
| @@ -281,6 +278,9 @@ static __global__ void flash_attn_vec_ext_f32( | |||||||
|     if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { |     if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) { | ||||||
|         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); |         dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]); | ||||||
|     } |     } | ||||||
|  | #else | ||||||
|  |     NO_DEVICE_CODE; | ||||||
|  | #endif // FLASH_ATTN_AVAILABLE | ||||||
| } | } | ||||||
|  |  | ||||||
| template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> | template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> | ||||||
|   | |||||||
| @@ -51,7 +51,7 @@ static __global__ void flash_attn_ext_f16( | |||||||
|         const int ne1, |         const int ne1, | ||||||
|         const int ne2, |         const int ne2, | ||||||
|         const int ne3) { |         const int ne3) { | ||||||
| #if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA | #if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA | ||||||
|     // Skip unused kernel variants for faster compilation: |     // Skip unused kernel variants for faster compilation: | ||||||
|     if (use_logit_softcap && !(D == 128 || D == 256)) { |     if (use_logit_softcap && !(D == 128 || D == 256)) { | ||||||
|         NO_DEVICE_CODE; |         NO_DEVICE_CODE; | ||||||
| @@ -425,7 +425,7 @@ static __global__ void flash_attn_ext_f16( | |||||||
|     } |     } | ||||||
| #else | #else | ||||||
|    NO_DEVICE_CODE; |    NO_DEVICE_CODE; | ||||||
| #endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA | #endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA | ||||||
| } | } | ||||||
|  |  | ||||||
| constexpr int get_max_power_of_2(int x) { | constexpr int get_max_power_of_2(int x) { | ||||||
|   | |||||||
| @@ -3203,7 +3203,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g | |||||||
|         case GGML_OP_FLASH_ATTN_EXT: { |         case GGML_OP_FLASH_ATTN_EXT: { | ||||||
| #ifndef FLASH_ATTN_AVAILABLE | #ifndef FLASH_ATTN_AVAILABLE | ||||||
|             return false; |             return false; | ||||||
| #endif | #endif // FLASH_ATTN_AVAILABLE | ||||||
|             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) { |             if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) { | ||||||
|                 return false; |                 return false; | ||||||
|             } |             } | ||||||
|   | |||||||
| @@ -107,6 +107,10 @@ if (GGML_HIP_NO_VMM) | |||||||
|     add_compile_definitions(GGML_HIP_NO_VMM) |     add_compile_definitions(GGML_HIP_NO_VMM) | ||||||
| endif() | endif() | ||||||
|  |  | ||||||
|  | if (NOT GGML_CUDA_FA) | ||||||
|  |     add_compile_definitions(GGML_CUDA_NO_FA) | ||||||
|  | endif() | ||||||
|  |  | ||||||
| if (CXX_IS_HIPCC) | if (CXX_IS_HIPCC) | ||||||
|     set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) |     set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX) | ||||||
|     target_link_libraries(ggml-hip PRIVATE hip::device) |     target_link_libraries(ggml-hip PRIVATE hip::device) | ||||||
|   | |||||||
| @@ -83,6 +83,10 @@ if (MUSAToolkit_FOUND) | |||||||
|         add_compile_definitions(GGML_CUDA_NO_VMM) |         add_compile_definitions(GGML_CUDA_NO_VMM) | ||||||
|     endif() |     endif() | ||||||
|  |  | ||||||
|  |     if (NOT GGML_CUDA_FA) | ||||||
|  |         add_compile_definitions(GGML_CUDA_NO_FA) | ||||||
|  |     endif() | ||||||
|  |  | ||||||
|     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) |     if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) | ||||||
|         add_compile_definitions(GGML_CUDA_F16) |         add_compile_definitions(GGML_CUDA_F16) | ||||||
|     endif() |     endif() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Johannes Gäßler
					Johannes Gäßler