mirror of https://github.com/ggml-org/llama.cpp.git
metal : fix kernel requirements (#15983)

* metal : fix kernel requirements

ggml-ci

* cont : fix supports_op

* cont : fix supports_op for ARGMAX
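Note on the pattern being fixed (a rough sketch, not the actual ggml-metal code): the third argument of GGML_METAL_ADD_KERNEL is a per-kernel requirement, and a kernel should only be registered when the device meets it. The bug was that several kernels which internally use simdgroup reductions passed `true` unconditionally. The names below mirror the diff; the struct and helper are illustrative only.

#include <stdbool.h>
#include <stdio.h>

struct kernel_pipeline {
    const char * name;
    bool         available;   // false => supports_op must also reject the op
};

// Sketch of the gating: register the kernel only if the device requirement is met.
static void add_kernel(struct kernel_pipeline * p, const char * name, bool supported) {
    p->name      = name;
    p->available = supported; // the real code would build the Metal pipeline only if supported
    if (!supported) {
        printf("skipping %s (device requirement not met)\n", name);
    }
}

int main(void) {
    bool has_simdgroup_reduction = false;   // e.g. a GPU without simdgroup reduction support

    struct kernel_pipeline argmax, ssm_scan;
    add_kernel(&argmax,   "kernel_argmax",       has_simdgroup_reduction); // was `true` before the fix
    add_kernel(&ssm_scan, "kernel_ssm_scan_f32", has_simdgroup_reduction); // was `true` before the fix
    return 0;
}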
@@ -1219,10 +1219,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,                 set_rows_iq4_nl,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM,                         l2_norm,                         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                      group_norm,                      has_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                            norm,                            true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                            norm,                            has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                    ssm_conv_f32,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                    ssm_scan_f32,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,              ssm_scan_f32_group,              true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                    ssm_scan_f32,                    has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP,              ssm_scan_f32_group,              has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32,                   rwkv_wkv6_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32,                   rwkv_wkv7_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                  mul_mv_f32_f32,                  has_simdgroup_reduction);
@@ -1443,9 +1443,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU_OAI,                      swiglu_oai,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_ERF,                       geglu_erf,                       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU_QUICK,                     geglu_quick,                     true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,                            mean,                            true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                        sum_rows,                        has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN,                            mean,                            has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                          argmax,                          has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,                 pool_2d_avg_f32,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,                 pool_2d_max_f32,                 true);
     }
@@ -1982,7 +1982,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
         case GGML_OP_L2_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ARGMAX:
-            return true;
+            return has_simdgroup_reduction;
         case GGML_OP_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
         case GGML_OP_ROPE:
@@ -2028,6 +2028,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             return has_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
+            return has_simdgroup_reduction;
         case GGML_OP_RWKV_WKV6:
         case GGML_OP_RWKV_WKV7:
             return true;