mirror of https://github.com/ggml-org/llama.cpp.git
synced 2025-10-30 08:42:00 +00:00
metal : copy kernels for quant to F32/F16 conversions (#12017)

metal: use dequantize_q templates

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Gian-Carlo Pascutto, committed by GitHub

parent 34a846b584
commit 58d07a8043
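This commit adds the reverse direction to the existing Metal copy kernels: alongside the established F32 -> quant conversions (cpy_f32_q5_0 and friends), GGML_OP_CPY can now dequantize Q4_0, Q4_1, Q5_0, Q5_1 and Q8_0 tensors directly to F32 or F16 on the GPU. The sketch below shows the operation these kernels accelerate, expressed with the public ggml C API; it is a minimal illustration, and the tensor shape and graph setup are placeholders rather than part of this commit.

#include "ggml.h"

// Build a graph node that converts a Q4_0 tensor to F32 via GGML_OP_CPY.
// On Apple GPUs the Metal backend can now execute this node itself
// (kernel_cpy_q4_0_f32) instead of rejecting it.
static struct ggml_tensor * dequantize_via_cpy(struct ggml_context * ctx,
                                               struct ggml_cgraph   * gf,
                                               struct ggml_tensor   * src_q) {
    // destination: same shape as the quantized source, but plain F32
    struct ggml_tensor * dst = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,
                                                  src_q->ne[0], src_q->ne[1]);

    // GGML_OP_CPY with a quantized src and an F32/F16 dst is exactly
    // what the new kernels implement
    struct ggml_tensor * out = ggml_cpy(ctx, src_q, dst);

    ggml_build_forward_expand(gf, out);
    return out;
}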
@@ -407,6 +407,16 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
     GGML_METAL_KERNEL_TYPE_SQRT,
@@ -1012,6 +1022,16 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,                  cpy_f32_q5_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,                  cpy_f32_q5_1,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32,                  cpy_q4_0_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16,                  cpy_q4_0_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32,                  cpy_q4_1_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16,                  cpy_q4_1_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32,                  cpy_q5_0_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16,                  cpy_q5_0_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32,                  cpy_q5_1_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16,                  cpy_q5_1_f16,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32,                  cpy_q8_0_f32,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16,                  cpy_q8_0_f16,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQRT,                          sqrt,                           true);
@@ -1287,6 +1307,18 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                             default:
                                 return false;
                         }
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_Q8_0:
+                        switch (op->type) {
+                            case GGML_TYPE_F32:
+                            case GGML_TYPE_F16:
+                                return true;
+                            default:
+                                return false;
+                        }
                     default:
                         return false;
                 };
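With these new cases, ggml_metal_supports_op advertises GGML_OP_CPY from the five block-quantized types to F32/F16 as supported, which is what keeps such copies on the GPU instead of being handed to a fallback backend. A minimal probe of that predicate, assuming the public ggml-backend API (construction of the CPY node is elided):

#include "ggml.h"
#include "ggml-backend.h"

// True when `backend` can execute `node` natively; the backend scheduler
// consults the same predicate when assigning graph nodes to backends.
static bool can_run_natively(ggml_backend_t backend, const struct ggml_tensor * node) {
    return ggml_backend_supports_op(backend, node);
}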
@@ -3899,10 +3931,6 @@ static void ggml_metal_encode_node(
         case GGML_OP_CPY:
         case GGML_OP_CONT:
             {
-                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
-
-                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
-
                 id<MTLComputePipelineState> pipeline = nil;

                 switch (src0t) {
@@ -3936,7 +3964,47 @@ static void ggml_metal_encode_node(
                             switch (dstt) {
                                 case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_F32].pipeline; break;
                                 case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16].pipeline; break;
-                                default: GGML_ASSERT(false && "not implemented");
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q4_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q4_1:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q4_1_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q5_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q5_1:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q5_1_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
+                            };
+                        } break;
+                    case GGML_TYPE_Q8_0:
+                        {
+                            switch (dstt) {
+                                case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F32].pipeline; break;
+                                case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_Q8_0_F16].pipeline; break;
+                                default: GGML_ABORT("not implemented");
                             };
                         } break;
                     default: GGML_ABORT("not implemented");
@@ -3966,7 +4034,11 @@ static void ggml_metal_encode_node(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];

+                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+                int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
+
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+
             } break;
         case GGML_OP_SET:
             {
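The assert and the threadgroup-size computation deleted in the earlier hunk reappear here, after pipeline selection, with unchanged meaning: one thread per src0 block, capped at Metal's 1024-threads-per-threadgroup limit. For quantized sources the block size is now greater than one, so nth counts blocks rather than elements. A worked example in plain C (the row length is illustrative; the Q4_0 block size of 32 is ggml's QK4_0):

#include <stdio.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int ne00      = 4096; // elements per src0 row (illustrative)
    const int blck_size = 32;   // ggml_blck_size(GGML_TYPE_Q4_0) == QK4_0
    // one thread per quantized block, capped at the threadgroup limit
    const int nth = MIN(1024, ne00/blck_size);
    printf("threads per threadgroup: %d\n", nth); // prints 128
    return 0;
}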
@@ -4341,6 +4341,49 @@ kernel void kernel_cpy_f32_iq4_nl(
     }
 }

+template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_cpy_q_f32(
+        constant ggml_metal_kargs_cpy & args,
+        device  const char * src0,
+        device        char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = tgpig[0];
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
+
+    device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
+    device       T4x4    * dst_data = (device       T4x4    *)(dst  +  i3*args.nb3  +  i2*args.nb2  +  i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
+        T4x4 temp;
+        dequantize_func(src_data + i00/nl, i00%nl, temp);
+        dst_data[i00] = temp;
+    }
+}
+
+typedef decltype(kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>) cpy_q_f_t;
+
+template [[host_name("kernel_cpy_q4_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_cpy_q4_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_cpy_q5_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_cpy_q5_1_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_cpy_q8_0_f32")]] kernel cpy_q_f_t kernel_cpy_q_f32<float4x4, block_q8_0, 2, dequantize_q8_0>;
+
+template [[host_name("kernel_cpy_q4_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_cpy_q4_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_cpy_q5_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_cpy_q5_1_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q5_1, 2, dequantize_q5_1>;
+template [[host_name("kernel_cpy_q8_0_f16")]] kernel cpy_q_f_t kernel_cpy_q_f32<half4x4, block_q8_0, 2, dequantize_q8_0>;
+
 kernel void kernel_concat(
     constant ggml_metal_kargs_concat & args,
     device  const char * src0,
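In kernel_cpy_q_f32, each threadgroup handles one source row: n is the flat element index at the start of row i01, then decomposed into destination coordinates so the copy also works when the src and dst shapes differ. Each loop iteration emits one T4x4, i.e. 16 values, and nl is the number of 16-element groups per quantized block — 2 for every type instantiated here, since all their blocks hold 32 elements. The division and modulus in the dequantize call select the block and the half-block. The same mapping sketched in plain C (the loop bound is illustrative):

#include <stdio.h>

int main(void) {
    const int nl = 2; // 16-element groups per 32-element block
    for (int i00 = 0; i00 < 6; ++i00) {
        const int block = i00 / nl; // which block_q to read
        const int il    = i00 % nl; // which half is dequantized into the 4x4
        printf("group %d -> block %d, half %d\n", i00, block, il);
    }
    return 0;
}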