mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)

	metal : use F16 precision in FA kernels
ggml-ci
Makefile
@@ -876,6 +876,11 @@ endif # GGML_HIPBLAS
 
 ifdef GGML_METAL
 	MK_CPPFLAGS += -DGGML_USE_METAL
+
+ifdef GGML_METAL_FORCE_FATTN_PREC_F16
+	MK_CPPFLAGS += -DGGML_METAL_FORCE_FATTN_PREC_F16
+endif # GGML_METAL_FORCE_FATTN_PREC_F16
+
 	MK_LDFLAGS  += -framework Foundation -framework Metal -framework MetalKit
 	OBJ_GGML	+= ggml/src/ggml-metal.o
 ifdef GGML_METAL_NDEBUG
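Build-time usage of the new flag (illustrative; assumes the standard llama.cpp Make targets):

    make GGML_METAL_FORCE_FATTN_PREC_F16=1 llama-cli

Defining the variable on the make command line satisfies the ifdef above, which appends -DGGML_METAL_FORCE_FATTN_PREC_F16 to MK_CPPFLAGS.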
@@ -256,6 +256,9 @@ static ggml_type ggml_type_from_name(const std::string & s) {
     if (s == "f16") {
         return GGML_TYPE_F16;
     }
+    if (s == "bf16") {
+        return GGML_TYPE_BF16;
+    }
     if (s == "q8_0") {
         return GGML_TYPE_Q8_0;
     }
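With "bf16" recognized by the type-name parser, a BF16 KV cache can now be requested by name. A hedged invocation sketch (assuming the usual llama-bench cache-type flags, which presumably route through this parser):

    ./llama-bench -m model.gguf -fa 1 -ctk bf16 -ctv bf16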
@@ -153,6 +153,7 @@ option(GGML_VULKAN_VALIDATE                 "ggml: enable Vulkan validation"
 option(GGML_VULKAN_RUN_TESTS                "ggml: run Vulkan tests"                          OFF)
 option(GGML_KOMPUTE                         "ggml: use Kompute"                               OFF)
 option(GGML_METAL                           "ggml: use Metal"                                 ${GGML_METAL_DEFAULT})
+option(GGML_METAL_FORCE_FATTN_PREC_F16      "ggml: force F16 accumulators for FA kernels"     OFF)
 option(GGML_METAL_NDEBUG                    "ggml: disable Metal debugging"                   OFF)
 option(GGML_METAL_SHADER_DEBUG              "ggml: compile Metal with -fno-fast-math"         OFF)
 option(GGML_METAL_EMBED_LIBRARY             "ggml: embed Metal library"                       ${GGML_METAL})
@@ -58,6 +58,10 @@ if (GGML_METAL)
         add_compile_definitions(GGML_METAL_NDEBUG)
     endif()
 
+    if (GGML_METAL_FORCE_FATTN_PREC_F16)
+        add_compile_definitions(GGML_METAL_FORCE_FATTN_PREC_F16)
+    endif()
+
     # copy ggml-common.h and ggml-metal.metal to bin directory
     configure_file(ggml-common.h    ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h    COPYONLY)
     configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)
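The equivalent CMake configuration (illustrative; standard llama.cpp build layout assumed):

    cmake -B build -DGGML_METAL=ON -DGGML_METAL_FORCE_FATTN_PREC_F16=ON
    cmake --build build --config Release

The option() declared above surfaces the switch; this block forwards it to the compiler as the same preprocessor definition the Makefile path uses.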
@@ -269,6 +269,12 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,
@@ -300,12 +306,14 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
@@ -585,6 +593,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
             struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
             id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
             kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
+            GGML_LOG_INFO("%s: loaded %-40s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+                    (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                    (int) kernel->pipeline.threadExecutionWidth); \
             [metal_function release]; \
             if (error) { \
                 GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
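The added GGML_LOG_INFO makes every pipeline load visible at init time. An illustrative line the format string above could emit (pointer and limits are made-up values):

    ggml_metal_init: loaded kernel_flash_attn_ext_bf16_h128              0x14a60c000 | th_max = 1024 | th_width =   32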
@@ -777,6 +788,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64,       flash_attn_ext_bf16_h64,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80,       flash_attn_ext_bf16_h80,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96,       flash_attn_ext_bf16_h96,        has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112,      flash_attn_ext_bf16_h112,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128,      flash_attn_ext_bf16_h128,       has_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256,      flash_attn_ext_bf16_h256,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,       flash_attn_ext_q4_0_h64,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,       flash_attn_ext_q4_0_h80,        has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,       flash_attn_ext_q4_0_h96,        has_simdgroup_mm);
@@ -808,12 +825,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,      flash_attn_ext_q8_0_h128,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,      flash_attn_ext_q8_0_h256,       has_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128,  flash_attn_ext_vec_bf16_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,  flash_attn_ext_vec_q4_0_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,  flash_attn_ext_vec_q4_1_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,  flash_attn_ext_vec_q5_0_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,  flash_attn_ext_vec_q5_1_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,  flash_attn_ext_vec_q8_0_h128,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256,  flash_attn_ext_vec_bf16_h256,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,  flash_attn_ext_vec_q4_0_h256,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,  flash_attn_ext_vec_q4_1_h256,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,  flash_attn_ext_vec_q5_0_h256,   has_simdgroup_reduction);
@@ -1111,7 +1130,7 @@ static void ggml_metal_encode_node(
     const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
     const uint64_t nb21 = src2 ? src2->nb[1] : 0;
     const uint64_t nb22 = src2 ? src2->nb[2] : 0;
-    const uint64_t nb23 = src2 ? src2->nb[3] : 0;
+    const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
 
     const int64_t  ne0  =  dst ?  dst->ne[0] : 0;
     const int64_t  ne1  =  dst ?  dst->ne[1] : 0;
@@ -3033,6 +3052,23 @@ static void ggml_metal_encode_node(
                                               }
                                 }
                             } break;
+                        case GGML_TYPE_BF16:
+                            {
+                                switch (ne00) {
+                                    case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H64 ].pipeline; break;
+                                    case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H80 ].pipeline; break;
+                                    case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H96 ].pipeline; break;
+                                    case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H112].pipeline; break;
+                                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H128].pipeline; break;
+                                    case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_BF16_H256].pipeline; break;
+                                    default:
+                                              {
+                                                  GGML_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                                  GGML_LOG_ERROR("add template specialization for this size\n");
+                                                  GGML_ABORT("add template specialization for this size");
+                                              }
+                                }
+                            } break;
                         case GGML_TYPE_Q4_0:
                             {
                                 switch (ne00) {
@@ -3133,6 +3169,7 @@ static void ggml_metal_encode_node(
                             {
                                 switch (src1->type) {
                                     case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
+                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H128].pipeline; break;
                                     case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128].pipeline; break;
                                     case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128].pipeline; break;
                                     case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128].pipeline; break;
@@ -3150,6 +3187,7 @@ static void ggml_metal_encode_node(
                             {
                                 switch (src1->type) {
                                     case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                                    case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H256].pipeline; break;
                                     case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256].pipeline; break;
                                     case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256].pipeline; break;
                                     case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256].pipeline; break;
@@ -3194,18 +3232,15 @@ static void ggml_metal_encode_node(
                 [encoder setBytes:&nb11          length:sizeof(uint64_t)      atIndex:14];
                 [encoder setBytes:&nb12          length:sizeof(uint64_t)      atIndex:15];
                 [encoder setBytes:&nb13          length:sizeof(uint64_t)      atIndex:16];
-                [encoder setBytes:&nb21          length:sizeof(uint64_t)      atIndex:17];
-                [encoder setBytes:&nb22          length:sizeof(uint64_t)      atIndex:18];
-                [encoder setBytes:&nb23          length:sizeof(uint64_t)      atIndex:19];
-                [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:20];
-                [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:21];
-                [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:22];
-                [encoder setBytes:&scale         length:sizeof(   float)      atIndex:23];
-                [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:24];
-                [encoder setBytes:&m0            length:sizeof(m0)            atIndex:25];
-                [encoder setBytes:&m1            length:sizeof(m1)            atIndex:26];
-                [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:27];
-                [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
+                [encoder setBytes:&nb31          length:sizeof(uint64_t)      atIndex:17];
+                [encoder setBytes:&ne1           length:sizeof( int64_t)      atIndex:18];
+                [encoder setBytes:&ne2           length:sizeof( int64_t)      atIndex:19];
+                [encoder setBytes:&scale         length:sizeof(   float)      atIndex:20];
+                [encoder setBytes:&max_bias      length:sizeof(   float)      atIndex:21];
+                [encoder setBytes:&m0            length:sizeof(m0)            atIndex:22];
+                [encoder setBytes:&m1            length:sizeof(m1)            atIndex:23];
+                [encoder setBytes:&n_head_log2   length:sizeof(n_head_log2)   atIndex:24];
+                [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25];
 
                 if (!use_vec_kernel) {
                     // half8x8 kernel
@@ -3216,11 +3251,14 @@ static void ggml_metal_encode_node(
                     GGML_ASSERT(nqptg  % 8  == 0);
                     GGML_ASSERT(ncpsg  % 32 == 0);
 
+                    // 2*(2*ncpsg + nqptg)*(nsg)
+                    // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
+                    //
                     // 16*32*(nsg)
                     // the shared memory needed for the simdgroups to load the KV cache
                     // each thread loads (dequantizes) 16 head elements, there are 32 threads in th SG
                     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16))
 
                     int64_t nsgmax = 2;
 
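To make the revised budget concrete, here is a small self-contained C sketch that evaluates the new FATTN_SMEM formula for the head sizes handled above (nqptg = 8 and ncpsg = 32 are assumptions consistent with the asserts; nsg = 4 is an illustrative simdgroup count):

    #include <stdio.h>
    #include <stdint.h>

    /* same rounding helper as in ggml.h */
    #define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))

    /* the updated macro from the hunk above, with its implicit inputs made explicit */
    static int64_t fattn_smem(int64_t ne00, int64_t nqptg, int64_t ncpsg, int64_t nsg) {
        return GGML_PAD((nqptg*(ne00 + 2*(2*ncpsg + nqptg)*nsg) + 16*32*nsg)*(int64_t)(sizeof(float)/2), 16);
    }

    int main(void) {
        const int64_t heads[] = { 64, 80, 96, 112, 128, 256 }; /* head sizes with kernels above */
        for (int i = 0; i < 6; ++i) {
            printf("ne00 = %3lld -> threadgroup memory = %6lld bytes\n",
                   (long long) heads[i], (long long) fattn_smem(heads[i], 8, 32, 4));
        }
        return 0;
    }

For ne00 = 128 this gives 8*(128 + 2*72*4) + 2048 = 7680 half-sized elements, i.e. 15360 bytes of threadgroup memory.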
@@ -3256,10 +3294,10 @@ static void ggml_metal_encode_node(
                     // for each query, we load it as f16 in shared memory (ne00)
                     // and store the attention scores (nqptg x ncpsg) as f32
                     //
-                    // 2*ne00*(nsg)
-                    // each simdgroup has a full f32 head vector in shared mem to accumulate results
+                    // ne00*(nsg)
+                    // each simdgroup has a full f16 head vector in shared mem to accumulate results
                     //
-#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*ncpsg*(nsg)) + 2*ne00*(nsg))*(sizeof(float)/2), 16))
+#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 4*ncpsg*(nsg)) + ne00*(nsg))*(sizeof(float)/2), 16))
 
                     int64_t nsgmax = 2;
 
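The vector-kernel budget shrinks correspondingly: the accumulator head vector is now kept in f16 (ne00*(nsg) instead of 2*ne00*(nsg)), while the per-chunk score/mask storage grows (4*ncpsg instead of 2*ncpsg). A worked instance, assuming nqptg = 1 and ncpsg = 32 for the vec path:

    /* ne00 = 128, nsg = 4:
       (1*(128 + 4*32*4) + 128*4) * (sizeof(float)/2)
     = (128 + 512 + 512) * 2
     = 2304 bytes (already 16-byte aligned, so GGML_PAD is a no-op here) */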
(diff for one file suppressed because it is too large)
@@ -3745,7 +3745,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                     for (int nh : { 32, }) {
                         for (int kv : { 512, 1024, }) {
                             for (int nb : { 1, 3, 32, 35, }) {
-                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
                                     test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, mask, max_bias, logit_softcap, type_KV));
                                 }
                             }
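Adding GGML_TYPE_BF16 to the KV-type loop folds the new kernels into the existing flash-attention test matrix, so they can be checked with the standard backend test binary. A hedged invocation (assuming test-backend-ops' usual op filter):

    ./build/bin/test-backend-ops test -o FLASH_ATTN_EXT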
Georgi Gerganov