mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	ggml : add mrope kernel for metal (#13457)
This commit is contained in:
		| @@ -207,6 +207,10 @@ typedef struct { | |||||||
|     float    attn_factor; |     float    attn_factor; | ||||||
|     float    beta_fast; |     float    beta_fast; | ||||||
|     float    beta_slow; |     float    beta_slow; | ||||||
|  |     int32_t  sect_0; | ||||||
|  |     int32_t  sect_1; | ||||||
|  |     int32_t  sect_2; | ||||||
|  |     int32_t  sect_3; | ||||||
| } ggml_metal_kargs_rope; | } ggml_metal_kargs_rope; | ||||||
|  |  | ||||||
| typedef struct { | typedef struct { | ||||||
|   | |||||||
| @@ -332,6 +332,10 @@ enum ggml_metal_kernel_type { | |||||||
|     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, |     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16, | ||||||
|     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, |     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32, | ||||||
|     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, |     GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, | ||||||
|  |     GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, | ||||||
|  |     GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, | ||||||
|  |     GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, | ||||||
|  |     GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, | ||||||
|     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, |     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, | ||||||
|     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, |     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, | ||||||
|     GGML_METAL_KERNEL_TYPE_IM2COL_F16, |     GGML_METAL_KERNEL_TYPE_IM2COL_F16, | ||||||
| @@ -1275,6 +1279,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de | |||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,            mul_mm_id_iq4_xs_f16,            has_simdgroup_mm); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F16,            mul_mm_id_iq4_xs_f16,            has_simdgroup_mm); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,                   rope_norm_f32,                   true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,                   rope_norm_f32,                   true); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                   rope_norm_f16,                   true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                   rope_norm_f16,                   true); | ||||||
|  |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,                  rope_multi_f32,                  true); | ||||||
|  |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,                  rope_multi_f16,                  true); | ||||||
|  |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,                 rope_vision_f32,                 true); | ||||||
|  |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,                 rope_vision_f16,                 true); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                   rope_neox_f32,                   true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                   rope_neox_f32,                   true); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                   rope_neox_f16,                   true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                   rope_neox_f16,                   true); | ||||||
|         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                      im2col_f16,                      true); |         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                      im2col_f16,                      true); | ||||||
| @@ -1637,16 +1645,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex | |||||||
|         case GGML_OP_NORM: |         case GGML_OP_NORM: | ||||||
|             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); |             return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); | ||||||
|         case GGML_OP_ROPE: |         case GGML_OP_ROPE: | ||||||
|             { |  | ||||||
|                 const int mode = ((const int32_t *) op->op_params)[2]; |  | ||||||
|                 if (mode & GGML_ROPE_TYPE_MROPE) { |  | ||||||
|                     return false; |  | ||||||
|                 } |  | ||||||
|                 if (mode & GGML_ROPE_TYPE_VISION) { |  | ||||||
|                     return false; |  | ||||||
|                 } |  | ||||||
|             return true; |             return true; | ||||||
|             } |  | ||||||
|         case GGML_OP_IM2COL: |         case GGML_OP_IM2COL: | ||||||
|             return op->src[0]->type == GGML_TYPE_F16; |             return op->src[0]->type == GGML_TYPE_F16; | ||||||
|         case GGML_OP_POOL_1D: |         case GGML_OP_POOL_1D: | ||||||
| @@ -3826,6 +3825,7 @@ static bool ggml_metal_encode_node( | |||||||
|             } break; |             } break; | ||||||
|         case GGML_OP_ROPE: |         case GGML_OP_ROPE: | ||||||
|             { |             { | ||||||
|  |  | ||||||
|                 // make sure we have one or more position id(ne10) per token(ne02) |                 // make sure we have one or more position id(ne10) per token(ne02) | ||||||
|                 GGML_ASSERT(ne10 % ne02 == 0); |                 GGML_ASSERT(ne10 % ne02 == 0); | ||||||
|                 GGML_ASSERT(ne10 >= ne02); |                 GGML_ASSERT(ne10 >= ne02); | ||||||
| @@ -3853,19 +3853,41 @@ static bool ggml_metal_encode_node( | |||||||
|                 memcpy(&beta_slow,   (const int32_t *) dst->op_params + 10, sizeof(float)); |                 memcpy(&beta_slow,   (const int32_t *) dst->op_params + 10, sizeof(float)); | ||||||
|  |  | ||||||
|                 const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX; |                 const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX; | ||||||
|  |                 const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE; | ||||||
|  |                 const bool is_vision = mode == GGML_ROPE_TYPE_VISION; | ||||||
|  |  | ||||||
|  |                 // mrope | ||||||
|  |                 const int sect_0 = ((const int32_t *) dst->op_params)[11]; | ||||||
|  |                 const int sect_1 = ((const int32_t *) dst->op_params)[12]; | ||||||
|  |                 const int sect_2 = ((const int32_t *) dst->op_params)[13]; | ||||||
|  |                 const int sect_3 = ((const int32_t *) dst->op_params)[14]; | ||||||
|  |  | ||||||
|                 id<MTLComputePipelineState> pipeline = nil; |                 id<MTLComputePipelineState> pipeline = nil; | ||||||
|  |  | ||||||
|                 if (!is_neox) { |                 if (is_neox) { | ||||||
|                     switch (src0->type) { |                     switch (src0->type) { | ||||||
|                         case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; |                         case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; | ||||||
|                         case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; |                         case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; | ||||||
|  |                         default: GGML_ABORT("fatal error"); | ||||||
|  |                     }; | ||||||
|  |                 } else if (is_mrope && !is_vision) { | ||||||
|  |                     GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token | ||||||
|  |                     switch (src0->type) { | ||||||
|  |                         case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break; | ||||||
|  |                         case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break; | ||||||
|  |                         default: GGML_ABORT("fatal error"); | ||||||
|  |                     }; | ||||||
|  |                 } else if (is_vision) { | ||||||
|  |                     GGML_ASSERT(ne10*4 >= ne02); // need at least 4 pos per token | ||||||
|  |                     switch (src0->type) { | ||||||
|  |                         case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break; | ||||||
|  |                         case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break; | ||||||
|                         default: GGML_ABORT("fatal error"); |                         default: GGML_ABORT("fatal error"); | ||||||
|                     }; |                     }; | ||||||
|                 } else { |                 } else { | ||||||
|                     switch (src0->type) { |                     switch (src0->type) { | ||||||
|                         case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; |                         case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; | ||||||
|                         case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; |                         case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; | ||||||
|                         default: GGML_ABORT("fatal error"); |                         default: GGML_ABORT("fatal error"); | ||||||
|                     }; |                     }; | ||||||
|                 } |                 } | ||||||
| @@ -3896,6 +3918,10 @@ static bool ggml_metal_encode_node( | |||||||
|                     /*.attn_factor =*/ attn_factor, |                     /*.attn_factor =*/ attn_factor, | ||||||
|                     /*.beta_fast   =*/ beta_fast, |                     /*.beta_fast   =*/ beta_fast, | ||||||
|                     /*.beta_slow   =*/ beta_slow, |                     /*.beta_slow   =*/ beta_slow, | ||||||
|  |                     /* sect_0      =*/ sect_0, | ||||||
|  |                     /* sect_1      =*/ sect_1, | ||||||
|  |                     /* sect_2      =*/ sect_2, | ||||||
|  |                     /* sect_3      =*/ sect_3, | ||||||
|                 }; |                 }; | ||||||
|  |  | ||||||
|                 [encoder setComputePipelineState:pipeline]; |                 [encoder setComputePipelineState:pipeline]; | ||||||
|   | |||||||
| @@ -2713,8 +2713,148 @@ kernel void kernel_rope_neox( | |||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | template<typename T> | ||||||
|  | kernel void kernel_rope_multi( | ||||||
|  |         constant ggml_metal_kargs_rope & args, | ||||||
|  |         device const char * src0, | ||||||
|  |         device const char * src1, | ||||||
|  |         device const char * src2, | ||||||
|  |         device       char * dst, | ||||||
|  |         ushort  tiitg[[thread_index_in_threadgroup]], | ||||||
|  |         ushort3 tptg [[threads_per_threadgroup]], | ||||||
|  |         uint3   tgpig[[threadgroup_position_in_grid]]) { | ||||||
|  |     const int i3 = tgpig[2]; | ||||||
|  |     const int i2 = tgpig[1]; | ||||||
|  |     const int i1 = tgpig[0]; | ||||||
|  |  | ||||||
|  |     float corr_dims[2]; | ||||||
|  |     rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); | ||||||
|  |  | ||||||
|  |     device const int32_t * pos = (device const int32_t *) src1; | ||||||
|  |  | ||||||
|  |     const float inv_ndims = -1.f/args.n_dims; | ||||||
|  |  | ||||||
|  |     float cos_theta; | ||||||
|  |     float sin_theta; | ||||||
|  |  | ||||||
|  |     for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { | ||||||
|  |         if (i0 < args.n_dims) { | ||||||
|  |             const int ic = i0/2; | ||||||
|  |  | ||||||
|  |             // mrope theta calculations | ||||||
|  |             // note: the rest is the same as kernel_rope_neox | ||||||
|  |             const int sect_dims = args.sect_0 + args.sect_1 + args.sect_2 + args.sect_3; | ||||||
|  |             const int sec_w01   = args.sect_0 + args.sect_1;               // end of section 1 | ||||||
|  |             const int sec_w012  = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2 | ||||||
|  |             const int sector    = ic % sect_dims; | ||||||
|  |  | ||||||
|  |             float theta_base; | ||||||
|  |             if (sector < args.sect_0) { | ||||||
|  |                 theta_base = (float) pos[i2]; | ||||||
|  |             } else if (sector < sec_w01) { | ||||||
|  |                 theta_base = (float) pos[i2 + args.ne02]; | ||||||
|  |             } else if (sector < sec_w012) { | ||||||
|  |                 theta_base = (float) pos[i2 + args.ne02 * 2]; | ||||||
|  |             } else { | ||||||
|  |                 theta_base = (float) pos[i2 + args.ne02 * 3]; | ||||||
|  |             } | ||||||
|  |             // end of mrope | ||||||
|  |  | ||||||
|  |             const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); | ||||||
|  |  | ||||||
|  |             const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; | ||||||
|  |  | ||||||
|  |             rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); | ||||||
|  |  | ||||||
|  |             device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); | ||||||
|  |             device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0); | ||||||
|  |  | ||||||
|  |             const float x0 = src[0]; | ||||||
|  |             const float x1 = src[args.n_dims/2]; | ||||||
|  |  | ||||||
|  |             dst_data[0]             = x0*cos_theta - x1*sin_theta; | ||||||
|  |             dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; | ||||||
|  |         } else { | ||||||
|  |             device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); | ||||||
|  |             device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0); | ||||||
|  |  | ||||||
|  |             dst_data[0] = src[0]; | ||||||
|  |             dst_data[1] = src[1]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template<typename T> | ||||||
|  | kernel void kernel_rope_vision( | ||||||
|  |         constant ggml_metal_kargs_rope & args, | ||||||
|  |         device const char * src0, | ||||||
|  |         device const char * src1, | ||||||
|  |         device const char * src2, | ||||||
|  |         device       char * dst, | ||||||
|  |         ushort  tiitg[[thread_index_in_threadgroup]], | ||||||
|  |         ushort3 tptg [[threads_per_threadgroup]], | ||||||
|  |         uint3   tgpig[[threadgroup_position_in_grid]]) { | ||||||
|  |     const int i3 = tgpig[2]; | ||||||
|  |     const int i2 = tgpig[1]; | ||||||
|  |     const int i1 = tgpig[0]; | ||||||
|  |  | ||||||
|  |     float corr_dims[2]; | ||||||
|  |     rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); | ||||||
|  |  | ||||||
|  |     device const int32_t * pos = (device const int32_t *) src1; | ||||||
|  |  | ||||||
|  |     const float inv_ndims = -1.f/args.n_dims; | ||||||
|  |  | ||||||
|  |     float cos_theta; | ||||||
|  |     float sin_theta; | ||||||
|  |  | ||||||
|  |     for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { | ||||||
|  |         if (i0 < 2*args.n_dims) { // different from kernel_rope_multi | ||||||
|  |             const int ic = i0/2; | ||||||
|  |  | ||||||
|  |             // mrope theta calculations (only support 2 dimensions) | ||||||
|  |             const int sect_dims = args.sect_0 + args.sect_1; | ||||||
|  |             const int sector    = ic % sect_dims; | ||||||
|  |  | ||||||
|  |             float p; | ||||||
|  |             float theta_base; | ||||||
|  |             if (sector < args.sect_1) { | ||||||
|  |                 p = (float) sector; | ||||||
|  |                 theta_base = (float) pos[i2]; | ||||||
|  |             } else { | ||||||
|  |                 p = (float) sector - args.sect_0; | ||||||
|  |                 theta_base = (float) pos[i2 + args.ne02]; | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             const float theta = theta_base * pow(args.freq_base, 2.0f * inv_ndims * p); | ||||||
|  |             // end of mrope | ||||||
|  |  | ||||||
|  |             const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; | ||||||
|  |  | ||||||
|  |             rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); | ||||||
|  |  | ||||||
|  |             device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); | ||||||
|  |             device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + ic*args.nb0); | ||||||
|  |  | ||||||
|  |             const float x0 = src[0]; | ||||||
|  |             const float x1 = src[args.n_dims]; // different from kernel_rope_multi | ||||||
|  |  | ||||||
|  |             dst_data[0]           = x0*cos_theta - x1*sin_theta; | ||||||
|  |             dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; // different from kernel_rope_multi | ||||||
|  |         } else { | ||||||
|  |             device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); | ||||||
|  |             device       T * dst_data  = (device T *)( dst + i3*args.nb3  + i2*args.nb2  + i1*args.nb1  + i0*args.nb0); | ||||||
|  |  | ||||||
|  |             dst_data[0] = src[0]; | ||||||
|  |             dst_data[1] = src[1]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
| typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t; | typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t; | ||||||
| typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t; | typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t; | ||||||
|  | typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t; | ||||||
|  | typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t; | ||||||
|  |  | ||||||
| template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>; | template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>; | ||||||
| template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>; | template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>; | ||||||
| @@ -2722,6 +2862,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_ | |||||||
| template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>; | template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>; | ||||||
| template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>; | template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>; | ||||||
|  |  | ||||||
|  | template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>; | ||||||
|  | template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>; | ||||||
|  |  | ||||||
|  | template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>; | ||||||
|  | template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>; | ||||||
|  |  | ||||||
| typedef void (im2col_t)( | typedef void (im2col_t)( | ||||||
|         device const float * x, |         device const float * x, | ||||||
|         device        char * dst, |         device        char * dst, | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Xuan-Son Nguyen
					Xuan-Son Nguyen