mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-29 08:41:22 +00:00 
			
		
		
		
	metal : add BF16 support (#8439)
* ggml : add initial BF16 support ggml-ci * metal : add mul_mat_id BF16 support ggml-ci * metal : check for bfloat support on the Metal device ggml-ci * metal : better var names [no ci] * metal : do not build bfloat kernels when not supported ggml-ci * metal : try to fix BF16 support check ggml-ci * metal : this should correctly check bfloat support
This commit is contained in:
		| @@ -1003,6 +1003,9 @@ static ggml_type kv_cache_type_from_str(const std::string & s) { | ||||
|     if (s == "f16") { | ||||
|         return GGML_TYPE_F16; | ||||
|     } | ||||
|     if (s == "bf16") { | ||||
|         return GGML_TYPE_BF16; | ||||
|     } | ||||
|     if (s == "q8_0") { | ||||
|         return GGML_TYPE_Q8_0; | ||||
|     } | ||||
|   | ||||
| @@ -36,16 +36,18 @@ static struct ggml_backend_metal_device_context { | ||||
|     id<MTLDevice> mtl_device; | ||||
|     int           mtl_device_ref_count; | ||||
|  | ||||
|     bool support_simdgroup_reduction; | ||||
|     bool support_simdgroup_mm; | ||||
|     bool has_simdgroup_reduction; | ||||
|     bool has_simdgroup_mm; | ||||
|     bool has_bfloat; | ||||
|  | ||||
|     char name[128]; | ||||
| } g_ggml_ctx_dev_main = { | ||||
|     /*.mtl_device                  =*/ nil, | ||||
|     /*.mtl_device_ref_count        =*/ 0, | ||||
|     /*.support_simdgroup_reduction =*/ false, | ||||
|     /*.support_simdgroup_mm        =*/ false, | ||||
|     /*.name                        =*/ "", | ||||
|     /*.mtl_device              =*/ nil, | ||||
|     /*.mtl_device_ref_count    =*/ 0, | ||||
|     /*.has_simdgroup_reduction =*/ false, | ||||
|     /*.has_simdgroup_mm        =*/ false, | ||||
|     /*.has_bfloat              =*/ false, | ||||
|     /*.name                    =*/ "", | ||||
| }; | ||||
|  | ||||
| // acquire | ||||
| @@ -55,10 +57,13 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev | ||||
|     if (ctx->mtl_device == nil) { | ||||
|         ctx->mtl_device = MTLCreateSystemDefaultDevice(); | ||||
|  | ||||
|         ctx->support_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; | ||||
|         ctx->support_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; | ||||
|         ctx->has_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; | ||||
|         ctx->has_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; | ||||
|  | ||||
|         ctx->support_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; | ||||
|         ctx->has_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7]; | ||||
|  | ||||
|         ctx->has_bfloat  = [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML]; | ||||
|         ctx->has_bfloat |= [ctx->mtl_device supportsFamily:MTLGPUFamilyApple6]; | ||||
|  | ||||
|         strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1); | ||||
|     } | ||||
| @@ -120,6 +125,7 @@ enum ggml_metal_kernel_type { | ||||
|     GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, | ||||
|     GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, | ||||
|     GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16, | ||||
|     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, | ||||
|     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, | ||||
|     GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, | ||||
| @@ -146,10 +152,14 @@ enum ggml_metal_kernel_type { | ||||
|     GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, | ||||
| @@ -170,10 +180,11 @@ enum ggml_metal_kernel_type { | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, | ||||
|   //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, | ||||
|   //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, | ||||
|   //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, | ||||
|   //GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, | ||||
| @@ -195,6 +206,7 @@ enum ggml_metal_kernel_type { | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, | ||||
| @@ -216,6 +228,7 @@ enum ggml_metal_kernel_type { | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, | ||||
| @@ -300,8 +313,11 @@ enum ggml_metal_kernel_type { | ||||
|     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256, | ||||
|     GGML_METAL_KERNEL_TYPE_CPY_F32_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_CPY_F32_F16, | ||||
|     GGML_METAL_KERNEL_TYPE_CPY_F32_BF16, | ||||
|     GGML_METAL_KERNEL_TYPE_CPY_F16_F16, | ||||
|     GGML_METAL_KERNEL_TYPE_CPY_F16_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, | ||||
|     GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, | ||||
|     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, | ||||
|     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, | ||||
|     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, | ||||
| @@ -480,7 +496,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de | ||||
|                 // dictionary of preprocessor macros | ||||
|                 NSMutableDictionary * prep = [NSMutableDictionary dictionary]; | ||||
|  | ||||
|                 MTLCompileOptions* options = [MTLCompileOptions new]; | ||||
|                 MTLCompileOptions * options = [MTLCompileOptions new]; | ||||
|                 options.preprocessorMacros = prep; | ||||
|  | ||||
|                 //[options setFastMathEnabled:false]; | ||||
| @@ -530,9 +546,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     GGML_LOG_INFO("%s: simdgroup reduction support   = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false"); | ||||
|     GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false"); | ||||
|     GGML_LOG_INFO("%s: hasUnifiedMemory              = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false"); | ||||
|     GGML_LOG_INFO("%s: simdgroup reduction   = %s\n", __func__, ctx_dev->has_simdgroup_reduction     ? "true" : "false"); | ||||
|     GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm            ? "true" : "false"); | ||||
|     GGML_LOG_INFO("%s: bfloat                = %s\n", __func__, ctx_dev->has_bfloat                  ? "true" : "false"); | ||||
|     GGML_LOG_INFO("%s: hasUnifiedMemory      = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false"); | ||||
|  | ||||
|     ctx->capture_next_compute = false; | ||||
|     ctx->capture_started = false; | ||||
| @@ -578,8 +595,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de | ||||
|             GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \ | ||||
|         } | ||||
|  | ||||
|         const bool support_simdgroup_mm        = ctx_dev->support_simdgroup_mm; | ||||
|         const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction; | ||||
|         const bool has_simdgroup_mm        = ctx_dev->has_simdgroup_mm; | ||||
|         const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction; | ||||
|         const bool has_bfloat              = ctx_dev->has_bfloat; | ||||
|  | ||||
|         // simd_sum and simd_max requires MTLGPUFamilyApple7 | ||||
|  | ||||
| @@ -607,14 +625,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,                  gelu_quick_4,                   true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                          silu,                           true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                        silu_4,                         true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                  soft_max_f16,                   support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                soft_max_f16_4,                 support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                  soft_max_f32,                   support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,                soft_max_f32_4,                 support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                  soft_max_f16,                   has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                soft_max_f16_4,                 has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                  soft_max_f32,                   has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,                soft_max_f32_4,                 has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,                 diag_mask_inf,                  true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,               diag_mask_inf_8,                true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,                  get_rows_f32,                   true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16,                  get_rows_f16,                   true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16,                 get_rows_bf16,                  has_bfloat); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0,                 get_rows_q4_0,                  true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1,                 get_rows_q4_1,                  true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0,                 get_rows_q5_0,                  true); | ||||
| @@ -635,101 +654,108 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,               get_rows_iq4_nl,                true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,               get_rows_iq4_xs,                true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,                  get_rows_i32,                   true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                          norm,                           true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                  ssm_conv_f32,                   true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                  ssm_scan_f32,                   true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                mul_mv_f16_f16,                 support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,           mul_mv_f16_f32_1row,            support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,             mul_mv_f16_f32_l4,              support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,               mul_mv_q4_0_f32,                support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,               mul_mv_q4_1_f32,                support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,               mul_mv_q5_0_f32,                support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,               mul_mv_q5_1_f32,                support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,               mul_mv_q8_0_f32,                support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,               mul_mv_q2_K_f32,                support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,               mul_mv_q3_K_f32,                support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,               mul_mv_q4_K_f32,                support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,               mul_mv_q5_K_f32,                support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,               mul_mv_q6_K_f32,                support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,            mul_mv_iq2_xxs_f32,             support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,             mul_mv_iq2_xs_f32,              support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,            mul_mv_iq3_xxs_f32,             support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,              mul_mv_iq3_s_f32,               support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,              mul_mv_iq2_s_f32,               support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,              mul_mv_iq1_s_f32,               support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,              mul_mv_iq1_m_f32,               support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,             mul_mv_iq4_nl_f32,              support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,             mul_mv_iq4_xs_f32,              support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,             mul_mv_id_f32_f32,              support_simdgroup_reduction); | ||||
|       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,             mul_mv_id_f16_f16,              support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,             mul_mv_id_f16_f32,              support_simdgroup_reduction); | ||||
|       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,        mul_mv_id_f16_f32_1row,         support_simdgroup_reduction); | ||||
|       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,          mul_mv_id_f16_f32_l4,           support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,            mul_mv_id_q4_0_f32,             support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,            mul_mv_id_q4_1_f32,             support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,            mul_mv_id_q5_0_f32,             support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,            mul_mv_id_q5_1_f32,             support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,            mul_mv_id_q8_0_f32,             support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,            mul_mv_id_q2_K_f32,             support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,            mul_mv_id_q3_K_f32,             support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,            mul_mv_id_q4_K_f32,             support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,            mul_mv_id_q5_K_f32,             support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,            mul_mv_id_q6_K_f32,             support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,         mul_mv_id_iq2_xxs_f32,          support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,          mul_mv_id_iq2_xs_f32,           support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,         mul_mv_id_iq3_xxs_f32,          support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,           mul_mv_id_iq3_s_f32,            support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,           mul_mv_id_iq2_s_f32,            support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,           mul_mv_id_iq1_s_f32,            support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,           mul_mv_id_iq1_m_f32,            support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,          mul_mv_id_iq4_nl_f32,           support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,          mul_mv_id_iq4_xs_f32,           support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,                mul_mm_f32_f32,                 support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,                mul_mm_f16_f32,                 support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,               mul_mm_q4_0_f32,                support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,               mul_mm_q4_1_f32,                support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,               mul_mm_q5_0_f32,                support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,               mul_mm_q5_1_f32,                support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,               mul_mm_q8_0_f32,                support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,               mul_mm_q2_K_f32,                support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,               mul_mm_q3_K_f32,                support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,               mul_mm_q4_K_f32,                support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,               mul_mm_q5_K_f32,                support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,               mul_mm_q6_K_f32,                support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,            mul_mm_iq2_xxs_f32,             support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,             mul_mm_iq2_xs_f32,              support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,            mul_mm_iq3_xxs_f32,             support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,              mul_mm_iq3_s_f32,               support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,              mul_mm_iq2_s_f32,               support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,              mul_mm_iq1_s_f32,               support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,              mul_mm_iq1_m_f32,               support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,             mul_mm_iq4_nl_f32,              support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,             mul_mm_iq4_xs_f32,              support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,             mul_mm_id_f32_f32,              support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,             mul_mm_id_f16_f32,              support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,            mul_mm_id_q4_0_f32,             support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,            mul_mm_id_q4_1_f32,             support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,            mul_mm_id_q5_0_f32,             support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,            mul_mm_id_q5_1_f32,             support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,            mul_mm_id_q8_0_f32,             support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,            mul_mm_id_q2_K_f32,             support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,            mul_mm_id_q3_K_f32,             support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,            mul_mm_id_q4_K_f32,             support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,            mul_mm_id_q5_K_f32,             support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,            mul_mm_id_q6_K_f32,             support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,         mul_mm_id_iq2_xxs_f32,          support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,          mul_mm_id_iq2_xs_f32,           support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,         mul_mm_id_iq3_xxs_f32,          support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,           mul_mm_id_iq3_s_f32,            support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,           mul_mm_id_iq2_s_f32,            support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,           mul_mm_id_iq1_s_f32,            support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,           mul_mm_id_iq1_m_f32,            support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,          mul_mm_id_iq4_nl_f32,           support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32,               mul_mv_bf16_f32,                has_simdgroup_reduction && has_bfloat); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW,          mul_mv_bf16_f32_1row,           has_simdgroup_reduction && has_bfloat); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4,            mul_mv_bf16_f32_l4,             has_simdgroup_reduction && has_bfloat); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16,              mul_mv_bf16_bf16,               has_simdgroup_reduction && has_bfloat); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,           mul_mv_f16_f32_1row,            has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,             mul_mv_f16_f32_l4,              has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                mul_mv_f16_f16,                 has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,               mul_mv_q4_0_f32,                has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,               mul_mv_q4_1_f32,                has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,               mul_mv_q5_0_f32,                has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,               mul_mv_q5_1_f32,                has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,               mul_mv_q8_0_f32,                has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,               mul_mv_q2_K_f32,                has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,               mul_mv_q3_K_f32,                has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,               mul_mv_q4_K_f32,                has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,               mul_mv_q5_K_f32,                has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,               mul_mv_q6_K_f32,                has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,            mul_mv_iq2_xxs_f32,             has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,             mul_mv_iq2_xs_f32,              has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,            mul_mv_iq3_xxs_f32,             has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,              mul_mv_iq3_s_f32,               has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,              mul_mv_iq2_s_f32,               has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,              mul_mv_iq1_s_f32,               has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,              mul_mv_iq1_m_f32,               has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,             mul_mv_iq4_nl_f32,              has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,             mul_mv_iq4_xs_f32,              has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,             mul_mv_id_f32_f32,              has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,             mul_mv_id_f16_f32,              has_simdgroup_reduction); | ||||
|       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,        mul_mv_id_f16_f32_1row,         has_simdgroup_reduction); | ||||
|       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,          mul_mv_id_f16_f32_l4,           has_simdgroup_reduction); | ||||
|       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,             mul_mv_id_f16_f16,              has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32,            mul_mv_id_bf16_f32,             has_simdgroup_reduction && has_bfloat); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,            mul_mv_id_q4_0_f32,             has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,            mul_mv_id_q4_1_f32,             has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,            mul_mv_id_q5_0_f32,             has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,            mul_mv_id_q5_1_f32,             has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,            mul_mv_id_q8_0_f32,             has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,            mul_mv_id_q2_K_f32,             has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,            mul_mv_id_q3_K_f32,             has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,            mul_mv_id_q4_K_f32,             has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,            mul_mv_id_q5_K_f32,             has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,            mul_mv_id_q6_K_f32,             has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,         mul_mv_id_iq2_xxs_f32,          has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,          mul_mv_id_iq2_xs_f32,           has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,         mul_mv_id_iq3_xxs_f32,          has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,           mul_mv_id_iq3_s_f32,            has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,           mul_mv_id_iq2_s_f32,            has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,           mul_mv_id_iq1_s_f32,            has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,           mul_mv_id_iq1_m_f32,            has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,          mul_mv_id_iq4_nl_f32,           has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,          mul_mv_id_iq4_xs_f32,           has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,                mul_mm_f32_f32,                 has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,                mul_mm_f16_f32,                 has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32,               mul_mm_bf16_f32,                has_simdgroup_mm && has_bfloat); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,               mul_mm_q4_0_f32,                has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,               mul_mm_q4_1_f32,                has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,               mul_mm_q5_0_f32,                has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,               mul_mm_q5_1_f32,                has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,               mul_mm_q8_0_f32,                has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,               mul_mm_q2_K_f32,                has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,               mul_mm_q3_K_f32,                has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,               mul_mm_q4_K_f32,                has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,               mul_mm_q5_K_f32,                has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,               mul_mm_q6_K_f32,                has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,            mul_mm_iq2_xxs_f32,             has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,             mul_mm_iq2_xs_f32,              has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,            mul_mm_iq3_xxs_f32,             has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,              mul_mm_iq3_s_f32,               has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,              mul_mm_iq2_s_f32,               has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,              mul_mm_iq1_s_f32,               has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,              mul_mm_iq1_m_f32,               has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,             mul_mm_iq4_nl_f32,              has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,             mul_mm_iq4_xs_f32,              has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,             mul_mm_id_f32_f32,              has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,             mul_mm_id_f16_f32,              has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32,            mul_mm_id_bf16_f32,             has_simdgroup_mm && has_bfloat); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,            mul_mm_id_q4_0_f32,             has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,            mul_mm_id_q4_1_f32,             has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,            mul_mm_id_q5_0_f32,             has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,            mul_mm_id_q5_1_f32,             has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,            mul_mm_id_q8_0_f32,             has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,            mul_mm_id_q2_K_f32,             has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,            mul_mm_id_q3_K_f32,             has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,            mul_mm_id_q4_K_f32,             has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,            mul_mm_id_q5_K_f32,             has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,            mul_mm_id_q6_K_f32,             has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,         mul_mm_id_iq2_xxs_f32,          has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,          mul_mm_id_iq2_xs_f32,           has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,         mul_mm_id_iq3_xxs_f32,          has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,           mul_mm_id_iq3_s_f32,            has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,           mul_mm_id_iq2_s_f32,            has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,           mul_mm_id_iq1_s_f32,            has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,           mul_mm_id_iq1_m_f32,            has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,          mul_mm_id_iq4_nl_f32,           has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,                 rope_norm_f32,                  true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                 rope_norm_f16,                  true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                 rope_neox_f32,                  true); | ||||
| @@ -745,58 +771,61 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,          argsort_f32_i32_desc,           true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,                leaky_relu_f32,                 true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,        flash_attn_ext_f16_h64,         support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,        flash_attn_ext_f16_h80,         support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,       flash_attn_ext_q4_0_h64,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,       flash_attn_ext_q4_0_h80,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,       flash_attn_ext_q4_0_h96,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,      flash_attn_ext_q4_0_h112,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,      flash_attn_ext_q4_0_h128,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,      flash_attn_ext_q4_0_h256,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,       flash_attn_ext_q4_1_h64,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,       flash_attn_ext_q4_1_h80,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,       flash_attn_ext_q4_1_h96,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,      flash_attn_ext_q4_1_h112,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,      flash_attn_ext_q4_1_h128,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,      flash_attn_ext_q4_1_h256,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,       flash_attn_ext_q5_0_h64,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,       flash_attn_ext_q5_0_h80,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,       flash_attn_ext_q5_0_h96,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,      flash_attn_ext_q5_0_h112,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,      flash_attn_ext_q5_0_h128,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,      flash_attn_ext_q5_0_h256,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,       flash_attn_ext_q5_1_h64,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,       flash_attn_ext_q5_1_h80,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,       flash_attn_ext_q5_1_h96,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,      flash_attn_ext_q5_1_h112,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,      flash_attn_ext_q5_1_h128,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,      flash_attn_ext_q5_1_h256,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,       flash_attn_ext_q8_0_h64,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,       flash_attn_ext_q8_0_h80,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,       flash_attn_ext_q8_0_h96,        support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,      flash_attn_ext_q8_0_h112,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,      flash_attn_ext_q8_0_h128,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,      flash_attn_ext_q8_0_h256,       support_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,  flash_attn_ext_vec_q4_0_h128,   support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,  flash_attn_ext_vec_q4_1_h128,   support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,  flash_attn_ext_vec_q5_0_h128,   support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,  flash_attn_ext_vec_q5_1_h128,   support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,  flash_attn_ext_vec_q8_0_h128,   support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,  flash_attn_ext_vec_q4_0_h256,   support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,  flash_attn_ext_vec_q4_1_h256,   support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,  flash_attn_ext_vec_q5_0_h256,   support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,  flash_attn_ext_vec_q5_1_h256,   support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,  flash_attn_ext_vec_q8_0_h256,   support_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,        flash_attn_ext_f16_h64,         has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,        flash_attn_ext_f16_h80,         has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H64,       flash_attn_ext_q4_0_h64,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H80,       flash_attn_ext_q4_0_h80,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H96,       flash_attn_ext_q4_0_h96,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H112,      flash_attn_ext_q4_0_h112,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H128,      flash_attn_ext_q4_0_h128,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_0_H256,      flash_attn_ext_q4_0_h256,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H64,       flash_attn_ext_q4_1_h64,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H80,       flash_attn_ext_q4_1_h80,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H96,       flash_attn_ext_q4_1_h96,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H112,      flash_attn_ext_q4_1_h112,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H128,      flash_attn_ext_q4_1_h128,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q4_1_H256,      flash_attn_ext_q4_1_h256,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H64,       flash_attn_ext_q5_0_h64,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H80,       flash_attn_ext_q5_0_h80,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H96,       flash_attn_ext_q5_0_h96,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H112,      flash_attn_ext_q5_0_h112,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H128,      flash_attn_ext_q5_0_h128,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_0_H256,      flash_attn_ext_q5_0_h256,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H64,       flash_attn_ext_q5_1_h64,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H80,       flash_attn_ext_q5_1_h80,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H96,       flash_attn_ext_q5_1_h96,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H112,      flash_attn_ext_q5_1_h112,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H128,      flash_attn_ext_q5_1_h128,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q5_1_H256,      flash_attn_ext_q5_1_h256,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H64,       flash_attn_ext_q8_0_h64,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H80,       flash_attn_ext_q8_0_h80,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H96,       flash_attn_ext_q8_0_h96,        has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H112,      flash_attn_ext_q8_0_h112,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H128,      flash_attn_ext_q8_0_h128,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,      flash_attn_ext_q8_0_h256,       has_simdgroup_mm); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H128,  flash_attn_ext_vec_q4_0_h128,   has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H128,  flash_attn_ext_vec_q4_1_h128,   has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H128,  flash_attn_ext_vec_q5_0_h128,   has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H128,  flash_attn_ext_vec_q5_1_h128,   has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H128,  flash_attn_ext_vec_q8_0_h128,   has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H256,  flash_attn_ext_vec_q4_0_h256,   has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H256,  flash_attn_ext_vec_q4_1_h256,   has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,  flash_attn_ext_vec_q5_0_h256,   has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,  flash_attn_ext_vec_q5_1_h256,   has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,  flash_attn_ext_vec_q8_0_h256,   has_simdgroup_reduction); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,                  cpy_f32_bf16,                   has_bfloat); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,                   cpy_f16_f32,                    true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,                  cpy_bf16_f32,                   has_bfloat); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,                 cpy_bf16_bf16,                  has_bfloat); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                  cpy_f32_q8_0,                   true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,                  cpy_f32_q4_0,                   true); | ||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,                  cpy_f32_q4_1,                   true); | ||||
| @@ -886,15 +915,18 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * t, size_t * offs | ||||
| } | ||||
|  | ||||
| static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) { | ||||
|     for (size_t i = 0, n = 3; i < n; ++i) { | ||||
|         if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) { | ||||
|             return false; | ||||
|     const bool has_simdgroup_mm        = ctx_dev->has_simdgroup_mm; | ||||
|     const bool has_simdgroup_reduction = ctx_dev->has_simdgroup_reduction; | ||||
|     const bool has_bfloat              = ctx_dev->has_bfloat; | ||||
|  | ||||
|     if (!has_bfloat) { | ||||
|         for (size_t i = 0, n = 3; i < n; ++i) { | ||||
|             if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) { | ||||
|                 return false; | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     const bool support_simdgroup_mm        = ctx_dev->support_simdgroup_mm; | ||||
|     const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction; | ||||
|  | ||||
|     switch (op->op) { | ||||
|         case GGML_OP_UNARY: | ||||
|             switch (ggml_get_unary_op(op)) { | ||||
| @@ -932,7 +964,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex | ||||
|         case GGML_OP_SOFT_MAX: | ||||
|         case GGML_OP_RMS_NORM: | ||||
|         case GGML_OP_GROUP_NORM: | ||||
|             return support_simdgroup_reduction; | ||||
|             return has_simdgroup_reduction; | ||||
|         case GGML_OP_NORM: | ||||
|         case GGML_OP_ROPE: | ||||
|             return true; | ||||
| @@ -952,13 +984,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex | ||||
|             if (op->src[1]->type != op->src[2]->type) { | ||||
|                 return false; | ||||
|             } | ||||
|             return support_simdgroup_mm; // TODO: over-restricted for vec-kernels | ||||
|             return has_simdgroup_mm; // TODO: over-restricted for vec-kernels | ||||
|         case GGML_OP_SSM_CONV: | ||||
|         case GGML_OP_SSM_SCAN: | ||||
|             return true; | ||||
|         case GGML_OP_MUL_MAT: | ||||
|         case GGML_OP_MUL_MAT_ID: | ||||
|             return support_simdgroup_reduction && | ||||
|             return has_simdgroup_reduction && | ||||
|                 (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32); | ||||
|         case GGML_OP_CPY: | ||||
|         case GGML_OP_DUP: | ||||
| @@ -969,6 +1001,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex | ||||
|                         switch (op->type) { | ||||
|                            case GGML_TYPE_F32: | ||||
|                            case GGML_TYPE_F16: | ||||
|                            case GGML_TYPE_BF16: | ||||
|                            case GGML_TYPE_Q8_0: | ||||
|                            case GGML_TYPE_Q4_0: | ||||
|                            case GGML_TYPE_Q4_1: | ||||
| @@ -981,10 +1014,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex | ||||
|                         } | ||||
|                     case GGML_TYPE_F16: | ||||
|                         switch (op->type) { | ||||
|                            case GGML_TYPE_F32: | ||||
|                            case GGML_TYPE_F16: | ||||
|                             case GGML_TYPE_F32: | ||||
|                             case GGML_TYPE_F16: | ||||
|                                 return true; | ||||
|                            default: | ||||
|                             default: | ||||
|                                 return false; | ||||
|                         } | ||||
|                     case GGML_TYPE_BF16: | ||||
|                         switch (op->type) { | ||||
|                             case GGML_TYPE_F32: | ||||
|                             case GGML_TYPE_BF16: | ||||
|                                 return true; | ||||
|                             default: | ||||
|                                 return false; | ||||
|                         } | ||||
|                     default: | ||||
| @@ -1855,6 +1896,7 @@ static void ggml_metal_encode_node( | ||||
|                             switch (src0->type) { | ||||
|                                 case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break; | ||||
|                                 case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break; | ||||
|                                 case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8  == 0); break; | ||||
|                                 default: break; | ||||
|                             } | ||||
|  | ||||
| @@ -1863,6 +1905,7 @@ static void ggml_metal_encode_node( | ||||
|                             switch (src0->type) { | ||||
|                                 case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32    ].pipeline; break; | ||||
|                                 case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32    ].pipeline; break; | ||||
|                                 case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32   ].pipeline; break; | ||||
|                                 case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32   ].pipeline; break; | ||||
|                                 case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32   ].pipeline; break; | ||||
|                                 case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32   ].pipeline; break; | ||||
| @@ -1940,6 +1983,25 @@ static void ggml_metal_encode_node( | ||||
|                                             nrows = 4; | ||||
|                                         } | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_BF16: | ||||
|                                     { | ||||
|                                         nth0 = 32; | ||||
|                                         nth1 = 1; | ||||
|                                         if (src1t == GGML_TYPE_F32) { | ||||
|                                             if (ne11 * ne12 < 4) { | ||||
|                                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; | ||||
|                                             } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { | ||||
|                                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; | ||||
|                                                 nrows = ne11; | ||||
|                                             } else { | ||||
|                                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; | ||||
|                                                 nrows = 4; | ||||
|                                             } | ||||
|                                         } else { | ||||
|                                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; | ||||
|                                             nrows = 4; | ||||
|                                         } | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_Q4_0: | ||||
|                                     { | ||||
|                                         nth0 = 8; | ||||
| @@ -2158,12 +2220,12 @@ static void ggml_metal_encode_node( | ||||
|                 if ([device supportsFamily:MTLGPUFamilyApple7] && | ||||
|                         ne00 % 32 == 0 && ne00 >= 64 && | ||||
|                         dst_rows > dst_rows_min) { | ||||
|  | ||||
|                     // some Metal matrix data types require aligned pointers | ||||
|                     // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) | ||||
|                     switch (src0->type) { | ||||
|                         case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break; | ||||
|                         case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8  == 0); break; | ||||
|                         case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break; | ||||
|                         case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break; | ||||
|                         case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8  == 0); break; | ||||
|                         default: break; | ||||
|                     } | ||||
|  | ||||
| @@ -2172,6 +2234,7 @@ static void ggml_metal_encode_node( | ||||
|                     switch (src0->type) { | ||||
|                         case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32    ].pipeline; break; | ||||
|                         case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32    ].pipeline; break; | ||||
|                         case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F32   ].pipeline; break; | ||||
|                         case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32   ].pipeline; break; | ||||
|                         case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32   ].pipeline; break; | ||||
|                         case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32   ].pipeline; break; | ||||
| @@ -2241,6 +2304,13 @@ static void ggml_metal_encode_node( | ||||
|                                 nth1 = 1; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_BF16: | ||||
|                             { | ||||
|                                 GGML_ASSERT(src1t == GGML_TYPE_F32); | ||||
|                                 nth0 = 32; | ||||
|                                 nth1 = 1; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q4_0: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
| @@ -2438,6 +2508,7 @@ static void ggml_metal_encode_node( | ||||
|                 switch (src0->type) { | ||||
|                     case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F32    ].pipeline; break; | ||||
|                     case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_F16    ].pipeline; break; | ||||
|                     case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_BF16   ].pipeline; break; | ||||
|                     case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0   ].pipeline; break; | ||||
|                     case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1   ].pipeline; break; | ||||
|                     case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0   ].pipeline; break; | ||||
| @@ -3237,6 +3308,7 @@ static void ggml_metal_encode_node( | ||||
|                             switch (dstt) { | ||||
|                                 case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break; | ||||
|                                 case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break; | ||||
|                                 case GGML_TYPE_BF16:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break; | ||||
|                                 case GGML_TYPE_Q8_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break; | ||||
|                                 case GGML_TYPE_Q4_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break; | ||||
|                                 case GGML_TYPE_Q4_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break; | ||||
| @@ -3254,6 +3326,14 @@ static void ggml_metal_encode_node( | ||||
|                                 default: GGML_ABORT("not implemented"); | ||||
|                             }; | ||||
|                         } break; | ||||
|                     case GGML_TYPE_BF16: | ||||
|                         { | ||||
|                             switch (dstt) { | ||||
|                                 case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break; | ||||
|                                 case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break; | ||||
|                                 default: GGML_ASSERT(false && "not implemented"); | ||||
|                             }; | ||||
|                         } break; | ||||
|                     default: GGML_ABORT("not implemented"); | ||||
|                 } | ||||
|  | ||||
|   | ||||
| @@ -12,6 +12,20 @@ using namespace metal; | ||||
|  | ||||
| #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 | ||||
|  | ||||
| // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf | ||||
| // | ||||
| // cmd: | ||||
| //   .../usr/bin/metal -dM -E -c                             ggml/src/ggml-metal.metal | ||||
| //   .../usr/bin/metal -dM -E -c -target air64-apple-ios14.0 ggml/src/ggml-metal.metal | ||||
| // | ||||
| #if __METAL_VERSION__ < 310 | ||||
| #define GGML_METAL_NO_BFLOAT | ||||
| #endif | ||||
|  | ||||
| #if !defined(GGML_METAL_NO_BFLOAT) | ||||
| typedef matrix<bfloat, 4, 4> bfloat4x4; | ||||
| #endif | ||||
|  | ||||
| constexpr constant static float kvalues_iq4nl_f[16] = { | ||||
|     -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f | ||||
| }; | ||||
| @@ -27,6 +41,13 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) | ||||
|     reg = (type4x4)(*src); | ||||
| } | ||||
|  | ||||
| #if !defined(GGML_METAL_NO_BFLOAT) | ||||
| template <typename type4x4> | ||||
| void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { | ||||
|     reg = (type4x4)(*src); | ||||
| } | ||||
| #endif | ||||
|  | ||||
| template <typename type4x4> | ||||
| void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { | ||||
|     device const uint16_t * qs = ((device const uint16_t *)xb + 1); | ||||
| @@ -2041,6 +2062,10 @@ typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t; | ||||
| template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t kernel_mul_mv<float,  float4,  float,  float4>; | ||||
| template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   float,  float4>; | ||||
| template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   half,   half4>; | ||||
| #if !defined(GGML_METAL_NO_BFLOAT) | ||||
| template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float,  float4>; | ||||
| template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>; | ||||
| #endif | ||||
|  | ||||
| template<typename T, typename T4> | ||||
| kernel void kernel_mul_mv_1row( | ||||
| @@ -2110,6 +2135,9 @@ kernel void kernel_mul_mv_1row( | ||||
| typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t; | ||||
|  | ||||
| template [[host_name("kernel_mul_mv_f16_f32_1row")]]  kernel mul_mv_1row_t kernel_mul_mv_1row<half,   half4>; | ||||
| #if !defined(GGML_METAL_NO_BFLOAT) | ||||
| template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>; | ||||
| #endif | ||||
|  | ||||
| // Assumes row size (ne00) is a multiple of 4 | ||||
| template<typename T, typename T4> | ||||
| @@ -2169,6 +2197,9 @@ kernel void kernel_mul_mv_l4( | ||||
| typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t; | ||||
|  | ||||
| template [[host_name("kernel_mul_mv_f16_f32_l4")]]  kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>; | ||||
| #if !defined(GGML_METAL_NO_BFLOAT) | ||||
| template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>; | ||||
| #endif | ||||
|  | ||||
| static float rope_yarn_ramp(const float low, const float high, const int i0) { | ||||
|     const float y = (i0 / 2 - low) / max(0.001f, high - low); | ||||
| @@ -3565,10 +3596,17 @@ kernel void kernel_cpy( | ||||
|  | ||||
| typedef decltype(kernel_cpy<float, float>) kernel_cpy_t; | ||||
|  | ||||
| template [[host_name("kernel_cpy_f32_f32")]]  kernel kernel_cpy_t kernel_cpy<float,  float>; | ||||
| template [[host_name("kernel_cpy_f32_f16")]]  kernel kernel_cpy_t kernel_cpy<float,  half>; | ||||
| template [[host_name("kernel_cpy_f16_f16")]]  kernel kernel_cpy_t kernel_cpy<half,   half>; | ||||
| template [[host_name("kernel_cpy_f16_f32")]]  kernel kernel_cpy_t kernel_cpy<half,   float>; | ||||
| template [[host_name("kernel_cpy_f32_f32")]]   kernel kernel_cpy_t kernel_cpy<float,  float>; | ||||
| template [[host_name("kernel_cpy_f32_f16")]]   kernel kernel_cpy_t kernel_cpy<float,  half>; | ||||
| #if !defined(GGML_METAL_NO_BFLOAT) | ||||
| template [[host_name("kernel_cpy_f32_bf16")]]  kernel kernel_cpy_t kernel_cpy<float,  bfloat>; | ||||
| #endif | ||||
| template [[host_name("kernel_cpy_f16_f32")]]   kernel kernel_cpy_t kernel_cpy<half,   float>; | ||||
| template [[host_name("kernel_cpy_f16_f16")]]   kernel kernel_cpy_t kernel_cpy<half,   half>; | ||||
| #if !defined(GGML_METAL_NO_BFLOAT) | ||||
| template [[host_name("kernel_cpy_bf16_f32")]]  kernel kernel_cpy_t kernel_cpy<bfloat, float>; | ||||
| template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>; | ||||
| #endif | ||||
|  | ||||
| kernel void kernel_cpy_f32_q8_0( | ||||
|         device const float * src0, | ||||
| @@ -6473,6 +6511,9 @@ typedef decltype(kernel_get_rows_f<float>) get_rows_f_t; | ||||
|  | ||||
| template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_f_t kernel_get_rows_f<float>; | ||||
| template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_f_t kernel_get_rows_f<half>; | ||||
| #if !defined(GGML_METAL_NO_BFLOAT) | ||||
| template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>; | ||||
| #endif | ||||
|  | ||||
| typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t; | ||||
|  | ||||
| @@ -6504,6 +6545,9 @@ typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, de | ||||
|  | ||||
| template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32>; | ||||
| template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16>; | ||||
| #if !defined(GGML_METAL_NO_BFLOAT) | ||||
| template [[host_name("kernel_mul_mm_bf16_f32")]]    kernel mat_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4,     1,     dequantize_bf16>; | ||||
| #endif | ||||
| template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0>; | ||||
| template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1>; | ||||
| template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0>; | ||||
| @@ -6532,6 +6576,9 @@ typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t; | ||||
|  | ||||
| template [[host_name("kernel_mul_mm_id_f32_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<float4x4,      1,     dequantize_f32>; | ||||
| template [[host_name("kernel_mul_mm_id_f16_f32")]]     kernel mat_mm_id_t kernel_mul_mm_id<half4x4,       1,     dequantize_f16>; | ||||
| #if !defined(GGML_METAL_NO_BFLOAT) | ||||
| template [[host_name("kernel_mul_mm_id_bf16_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<bfloat4x4,     1,     dequantize_bf16>; | ||||
| #endif | ||||
| template [[host_name("kernel_mul_mm_id_q4_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_0,    2,     dequantize_q4_0>; | ||||
| template [[host_name("kernel_mul_mm_id_q4_1_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q4_1,    2,     dequantize_q4_1>; | ||||
| template [[host_name("kernel_mul_mm_id_q5_0_f32")]]    kernel mat_mm_id_t kernel_mul_mm_id<block_q5_0,    2,     dequantize_q5_0>; | ||||
| @@ -6755,6 +6802,9 @@ typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float | ||||
|  | ||||
| template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>; | ||||
| template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>; | ||||
| #if !defined(GGML_METAL_NO_BFLOAT) | ||||
| template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>; | ||||
| #endif | ||||
| template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>; | ||||
| template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; | ||||
| template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; | ||||
|   | ||||
| @@ -3599,7 +3599,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() { | ||||
|             for (int n_mats : {4}) { | ||||
|                 for (int n_used : {2}) { | ||||
|                     for (bool b : {false}) { | ||||
|                         for (int n : {1}) { | ||||
|                         for (int n : {1, 32}) { | ||||
|                             int m = 512; | ||||
|                             int k = 256; | ||||
|                             test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, n_used, b, m, n, k)); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov