mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : add rope_f16 kernel + optimize cpy kernels
This commit is contained in:
		
							
								
								
									
										36
									
								
								ggml-metal.m
									
									
									
									
									
								
							
							
						
						
									
										36
									
								
								ggml-metal.m
									
									
									
									
									
								
							| @@ -100,7 +100,8 @@ struct ggml_metal_context { | ||||
|     GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32); | ||||
|     GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32); | ||||
|     GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32); | ||||
|     GGML_METAL_DECL_KERNEL(rope); | ||||
|     GGML_METAL_DECL_KERNEL(rope_f32); | ||||
|     GGML_METAL_DECL_KERNEL(rope_f16); | ||||
|     GGML_METAL_DECL_KERNEL(alibi_f32); | ||||
|     GGML_METAL_DECL_KERNEL(cpy_f32_f16); | ||||
|     GGML_METAL_DECL_KERNEL(cpy_f32_f32); | ||||
| @@ -261,7 +262,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { | ||||
|         GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); | ||||
|         GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); | ||||
|         GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); | ||||
|         GGML_METAL_ADD_KERNEL(rope); | ||||
|         GGML_METAL_ADD_KERNEL(rope_f32); | ||||
|         GGML_METAL_ADD_KERNEL(rope_f16); | ||||
|         GGML_METAL_ADD_KERNEL(alibi_f32); | ||||
|         GGML_METAL_ADD_KERNEL(cpy_f32_f16); | ||||
|         GGML_METAL_ADD_KERNEL(cpy_f32_f32); | ||||
| @@ -335,7 +337,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { | ||||
|     GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); | ||||
|     GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); | ||||
|     GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); | ||||
|     GGML_METAL_DEL_KERNEL(rope); | ||||
|     GGML_METAL_DEL_KERNEL(rope_f32); | ||||
|     GGML_METAL_DEL_KERNEL(rope_f16); | ||||
|     GGML_METAL_DEL_KERNEL(alibi_f32); | ||||
|     GGML_METAL_DEL_KERNEL(cpy_f32_f16); | ||||
|     GGML_METAL_DEL_KERNEL(cpy_f32_f32); | ||||
| @@ -870,7 +873,7 @@ void ggml_metal_graph_compute( | ||||
|                         } break; | ||||
|                     case GGML_OP_SOFT_MAX: | ||||
|                         { | ||||
|                             const int nth = 32; | ||||
|                             const int nth = MIN(32, ne00); | ||||
|  | ||||
|                             if (ne00%4 == 0) { | ||||
|                                 [encoder setComputePipelineState:ctx->pipeline_soft_max_4]; | ||||
| @@ -1134,7 +1137,7 @@ void ggml_metal_graph_compute( | ||||
|                             float eps; | ||||
|                             memcpy(&eps, dst->op_params, sizeof(float)); | ||||
|  | ||||
|                             const int nth = 512; | ||||
|                             const int nth = MIN(512, ne00); | ||||
|  | ||||
|                             [encoder setComputePipelineState:ctx->pipeline_rms_norm]; | ||||
|                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; | ||||
| @@ -1153,7 +1156,7 @@ void ggml_metal_graph_compute( | ||||
|                             float eps; | ||||
|                             memcpy(&eps, dst->op_params, sizeof(float)); | ||||
|  | ||||
|                             const int nth = 256; | ||||
|                             const int nth = MIN(256, ne00); | ||||
|  | ||||
|                             [encoder setComputePipelineState:ctx->pipeline_norm]; | ||||
|                             [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0]; | ||||
| @@ -1171,6 +1174,8 @@ void ggml_metal_graph_compute( | ||||
|                         { | ||||
|                             GGML_ASSERT((src0t == GGML_TYPE_F32)); | ||||
|  | ||||
|                             const int nth = MIN(1024, ne00); | ||||
|  | ||||
|                             const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); | ||||
|                             const int n_head = ((int32_t *) dst->op_params)[1]; | ||||
|                             float max_bias; | ||||
| @@ -1204,15 +1209,15 @@ void ggml_metal_graph_compute( | ||||
|                             [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17]; | ||||
|                             [encoder setBytes:&m0  length:sizeof(    float) atIndex:18]; | ||||
|  | ||||
|                             const int nth = 32; | ||||
|  | ||||
|                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; | ||||
|                         } break; | ||||
|                     case GGML_OP_ROPE: | ||||
|                         { | ||||
|                             GGML_ASSERT(ne10 == ne02); | ||||
|  | ||||
|                             //const int n_past = ((int32_t *) dst->op_params)[0]; | ||||
|                             const int nth = MIN(1024, ne00); | ||||
|  | ||||
|                             const int n_past = ((int32_t *) dst->op_params)[0]; | ||||
|                             const int n_dims = ((int32_t *) dst->op_params)[1]; | ||||
|                             const int mode   = ((int32_t *) dst->op_params)[2]; | ||||
|  | ||||
| @@ -1221,7 +1226,12 @@ void ggml_metal_graph_compute( | ||||
|                             memcpy(&freq_base,  (int32_t *) dst->op_params + 4, sizeof(float)); | ||||
|                             memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); | ||||
|  | ||||
|                             [encoder setComputePipelineState:ctx->pipeline_rope]; | ||||
|                             switch (src0->type) { | ||||
|                                 case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_rope_f32]; break; | ||||
|                                 case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_rope_f16]; break; | ||||
|                                 default: GGML_ASSERT(false); | ||||
|                             }; | ||||
|  | ||||
|                             [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0]; | ||||
|                             [encoder setBuffer:id_src1 offset:offs_src1        atIndex:1]; | ||||
|                             [encoder setBuffer:id_dst  offset:offs_dst         atIndex:2]; | ||||
| @@ -1241,19 +1251,19 @@ void ggml_metal_graph_compute( | ||||
|                             [encoder setBytes:&nb1     length:sizeof(uint64_t) atIndex:16]; | ||||
|                             [encoder setBytes:&nb2     length:sizeof(uint64_t) atIndex:17]; | ||||
|                             [encoder setBytes:&nb3     length:sizeof(uint64_t) atIndex:18]; | ||||
|                             //[encoder setBytes:&n_past  length:sizeof(     int) atIndex:19]; | ||||
|                             [encoder setBytes:&n_past  length:sizeof(     int) atIndex:19]; | ||||
|                             [encoder setBytes:&n_dims  length:sizeof(     int) atIndex:20]; | ||||
|                             [encoder setBytes:&mode    length:sizeof(     int) atIndex:21]; | ||||
|                             [encoder setBytes:&freq_base  length:sizeof(float) atIndex:22]; | ||||
|                             [encoder setBytes:&freq_scale length:sizeof(float) atIndex:23]; | ||||
|  | ||||
|                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; | ||||
|                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; | ||||
|                         } break; | ||||
|                     case GGML_OP_DUP: | ||||
|                     case GGML_OP_CPY: | ||||
|                     case GGML_OP_CONT: | ||||
|                         { | ||||
|                             const int nth = 32; | ||||
|                             const int nth = MIN(1024, ne00); | ||||
|  | ||||
|                             switch (src0t) { | ||||
|                                 case GGML_TYPE_F32: | ||||
|   | ||||
| @@ -853,6 +853,36 @@ kernel void kernel_alibi_f32( | ||||
|     } | ||||
| } | ||||
|  | ||||
| typedef void (rope_t)( | ||||
|         device const    void * src0, | ||||
|         device const int32_t * src1, | ||||
|         device         float * dst, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne01, | ||||
|         constant     int64_t & ne02, | ||||
|         constant     int64_t & ne03, | ||||
|         constant    uint64_t & nb00, | ||||
|         constant    uint64_t & nb01, | ||||
|         constant    uint64_t & nb02, | ||||
|         constant    uint64_t & nb03, | ||||
|         constant     int64_t & ne0, | ||||
|         constant     int64_t & ne1, | ||||
|         constant     int64_t & ne2, | ||||
|         constant     int64_t & ne3, | ||||
|         constant    uint64_t & nb0, | ||||
|         constant    uint64_t & nb1, | ||||
|         constant    uint64_t & nb2, | ||||
|         constant    uint64_t & nb3, | ||||
|         constant         int & n_past, | ||||
|         constant         int & n_dims, | ||||
|         constant         int & mode, | ||||
|         constant       float & freq_base, | ||||
|         constant       float & freq_scale, | ||||
|         uint  tiitg[[thread_index_in_threadgroup]], | ||||
|         uint3 tptg[[threads_per_threadgroup]], | ||||
|         uint3 tgpig[[threadgroup_position_in_grid]]); | ||||
|  | ||||
| template<typename T> | ||||
| kernel void kernel_rope( | ||||
|         device const    void * src0, | ||||
|         device const int32_t * src1, | ||||
| @@ -901,11 +931,11 @@ kernel void kernel_rope( | ||||
|             const float cos_theta = cos(theta); | ||||
|             const float sin_theta = sin(theta); | ||||
|  | ||||
|             device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||
|             device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0); | ||||
|             device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||
|             device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0); | ||||
|  | ||||
|             const float x0 = src[0]; | ||||
|             const float x1 = src[1]; | ||||
|             const T x0 = src[0]; | ||||
|             const T x1 = src[1]; | ||||
|  | ||||
|             dst_data[0] = x0*cos_theta - x1*sin_theta; | ||||
|             dst_data[1] = x0*sin_theta + x1*cos_theta; | ||||
| @@ -920,8 +950,8 @@ kernel void kernel_rope( | ||||
|  | ||||
|                 const int64_t i0 = ib*n_dims + ic/2; | ||||
|  | ||||
|                 device const float * const src = (device float *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||
|                 device       float * dst_data  = (device float *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0); | ||||
|                 device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); | ||||
|                 device       T * dst_data  = (device T *)((device char *)  dst + i3*nb3  + i2*nb2  + i1*nb1  + i0*nb0); | ||||
|  | ||||
|                 const float x0 = src[0]; | ||||
|                 const float x1 = src[n_dims/2]; | ||||
| @@ -933,6 +963,9 @@ kernel void kernel_rope( | ||||
|     } | ||||
| } | ||||
|  | ||||
| template [[host_name("kernel_rope_f32")]] kernel rope_t kernel_rope<float>; | ||||
| template [[host_name("kernel_rope_f16")]] kernel rope_t kernel_rope<half>; | ||||
|  | ||||
| kernel void kernel_cpy_f16_f16( | ||||
|         device const half * src0, | ||||
|         device       half * dst, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov