ggml-metal.m
@@ -184,9 +184,9 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
+  //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,     // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
-    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
+  //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -634,9 +634,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        ctx->support_simdgroup_mm);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    ctx->support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                  cpy_f32_q8_0,                   true);
@@ -770,6 +770,9 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
+            if (op->src[0]->ne[0] == 256) {
+                return false;
+            }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
@@ -2573,7 +2576,7 @@ static enum ggml_status ggml_metal_graph_compute(
                                 case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
                                 case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
                                 case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                              //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
                                 default:
                                           {
                                               GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
@@ -2586,7 +2589,7 @@ static enum ggml_status ggml_metal_graph_compute(
 
                             switch (ne00) {
                                 case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                                case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                              //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
                                 default:
                                           {
                                               GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);

ggml-metal.metal
@@ -2418,7 +2418,7 @@ template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f
 template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96>;
 template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112>;
 template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128>;
-template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
+//template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256>;
 
 template<int64_t D, int64_t Q = 1, int64_t C = 32> // head size, queries per threadgroup, cache items per threadgroup
 kernel void kernel_flash_attn_ext_vec_f16(
@@ -2696,7 +2696,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
 }
 
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
-template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
+//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;
 
 kernel void kernel_cpy_f16_f16(
         device  const half * src0,
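
Taken together, the patch disables the head-size-256 Flash Attention path on Metal end to end: the H256 kernel enum entries, registrations, pipeline-selection cases, and shader template instantiations are commented out, and ggml_metal_supports_op now rejects GGML_OP_FLASH_ATTN_EXT when the head size (op->src[0]->ne[0]) is 256. Below is a minimal, self-contained sketch of how that support gate behaves. The struct and function names are hypothetical stand-ins rather than the real ggml API, kept only to illustrate the control flow; that a rejected op ends up on another backend is an assumption about how the surrounding backend scheduler uses this check.

// Hypothetical sketch of the support gate added in ggml_metal_supports_op above.
// None of these names are real ggml types; they only mirror the control flow:
// a Flash Attention op whose head size (op->src[0]->ne[0]) is 256 is reported
// as unsupported, so the caller is expected to route it to another backend.
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

struct fa_op_info {
    int64_t head_size;            // stands in for op->src[0]->ne[0]
};

static bool metal_supports_flash_attn_ext(const struct fa_op_info * op, bool support_simdgroup_mm) {
    if (op->head_size == 256) {
        return false;             // https://github.com/ggerganov/llama.cpp/issues/7261
    }
    // same caveat as the patch notes: this is over-restricted for the vec kernels
    return support_simdgroup_mm;
}

int main(void) {
    const struct fa_op_info h128 = { .head_size = 128 };
    const struct fa_op_info h256 = { .head_size = 256 };
    printf("HS=128 supported: %d\n", metal_supports_flash_attn_ext(&h128, true)); // prints 1
    printf("HS=256 supported: %d\n", metal_supports_flash_attn_ext(&h256, true)); // prints 0
    return 0;
}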