mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : refactor mat-vec code (#12569)
* metal : refactor mat-vec code ggml-ci * metal : rename all_sum -> sum_all ggml-ci * metal : fix comments [no ci] * metal : fix nr constant [no ci] * metal : mv q6_K support nr0 > 1 ggml-ci * metal : reduce register pressure ggml-ci * metal : fix typo [no ci] * metal : reduce register pressure ggml-ci
This commit is contained in:
		| @@ -1,6 +1,70 @@ | ||||
| #ifndef GGML_METAL_IMPL | ||||
| #define GGML_METAL_IMPL | ||||
|  | ||||
| // kernel parameters for mat-vec threadgroups | ||||
| // | ||||
| // N_R0: number of src0 rows to process per simdgroup | ||||
| // N_SG: number of simdgroups per threadgroup | ||||
| // | ||||
| // TODO: for optimal performance, these parameters should become functions of the device and work size | ||||
|  | ||||
| #define N_R0_Q4_0 4 | ||||
| #define N_SG_Q4_0 2 | ||||
|  | ||||
| #define N_R0_Q4_1 4 | ||||
| #define N_SG_Q4_1 2 | ||||
|  | ||||
| #define N_R0_Q5_0 4 | ||||
| #define N_SG_Q5_0 2 | ||||
|  | ||||
| #define N_R0_Q5_1 4 | ||||
| #define N_SG_Q5_1 2 | ||||
|  | ||||
| #define N_R0_Q8_0 4 | ||||
| #define N_SG_Q8_0 2 | ||||
|  | ||||
| #define N_R0_Q2_K 4 | ||||
| #define N_SG_Q2_K 2 | ||||
|  | ||||
| #define N_R0_Q3_K 2 | ||||
| #define N_SG_Q3_K 2 | ||||
|  | ||||
| #define N_R0_Q4_K 4 | ||||
| #define N_SG_Q4_K 2 | ||||
|  | ||||
| #define N_R0_Q5_K 2 | ||||
| #define N_SG_Q5_K 2 | ||||
|  | ||||
| #define N_R0_Q6_K 1 | ||||
| #define N_SG_Q6_K 2 | ||||
|  | ||||
| #define N_R0_IQ1_S 4 | ||||
| #define N_SG_IQ1_S 2 | ||||
|  | ||||
| #define N_R0_IQ1_M 4 | ||||
| #define N_SG_IQ1_M 2 | ||||
|  | ||||
| #define N_R0_IQ2_XXS 4 | ||||
| #define N_SG_IQ2_XXS 2 | ||||
|  | ||||
| #define N_R0_IQ2_XS 4 | ||||
| #define N_SG_IQ2_XS 2 | ||||
|  | ||||
| #define N_R0_IQ2_S 4 | ||||
| #define N_SG_IQ2_S 2 | ||||
|  | ||||
| #define N_R0_IQ3_XXS 4 | ||||
| #define N_SG_IQ3_XXS 2 | ||||
|  | ||||
| #define N_R0_IQ3_S 4 | ||||
| #define N_SG_IQ3_S 2 | ||||
|  | ||||
| #define N_R0_IQ4_NL 2 | ||||
| #define N_SG_IQ4_NL 2 | ||||
|  | ||||
| #define N_R0_IQ4_XS 2 | ||||
| #define N_SG_IQ4_XS 2 | ||||
|  | ||||
| // kernel argument structs | ||||
| // | ||||
| // - element counters (e.g. ne00) typically use int32_t to reduce register usage | ||||
|   | ||||
| @@ -2561,171 +2561,180 @@ static void ggml_metal_encode_node( | ||||
|                     [encoder setThreadgroupMemoryLength:8192 atIndex:0]; | ||||
|                     [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; | ||||
|                 } else { | ||||
|                     int nth0 = 32; | ||||
|                     int nth1 = 1; | ||||
|                     int nrows = 1; | ||||
|                     //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); | ||||
|  | ||||
|                     id<MTLComputePipelineState> pipeline = nil; | ||||
|  | ||||
|                     int nsg = 0; // number of simdgroups | ||||
|                     int nr0 = 0; // number of src0 rows per simdgroup | ||||
|                     int nr1 = 1; // number of src1 rows per threadgroup | ||||
|  | ||||
|                     size_t smem = 0; // shared memory | ||||
|  | ||||
|                     // use custom matrix x vector kernel | ||||
|                     switch (src0t) { | ||||
|                         case GGML_TYPE_F32: | ||||
|                             { | ||||
|                                 GGML_ASSERT(src1t == GGML_TYPE_F32); | ||||
|                                 nsg = 1; | ||||
|                                 nr0 = 1; | ||||
|                                 nr1 = 4; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; | ||||
|                                 nrows = 4; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_F16: | ||||
|                             { | ||||
|                                 nth0 = 32; | ||||
|                                 nth1 = 1; | ||||
|                                 nsg = 1; | ||||
|                                 nr0 = 1; | ||||
|                                 if (src1t == GGML_TYPE_F32) { | ||||
|                                     if (ne11 * ne12 < 4) { | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; | ||||
|                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; | ||||
|                                         nrows = ne11; | ||||
|                                         nr1 = ne11; | ||||
|                                     } else { | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; | ||||
|                                         nrows = 4; | ||||
|                                         nr1 = 4; | ||||
|                                     } | ||||
|                                 } else { | ||||
|                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; | ||||
|                                     nrows = 4; | ||||
|                                     nr1 = 4; | ||||
|                                 } | ||||
|                             } break; | ||||
|                         case GGML_TYPE_BF16: | ||||
|                             { | ||||
|                                 nth0 = 32; | ||||
|                                 nth1 = 1; | ||||
|                                 nsg = 1; | ||||
|                                 nr0 = 1; | ||||
|                                 if (src1t == GGML_TYPE_F32) { | ||||
|                                     if (ne11 * ne12 < 4) { | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; | ||||
|                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; | ||||
|                                         nrows = ne11; | ||||
|                                         nr1 = ne11; | ||||
|                                     } else { | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; | ||||
|                                         nrows = 4; | ||||
|                                         nr1 = 4; | ||||
|                                     } | ||||
|                                 } else { | ||||
|                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; | ||||
|                                     nrows = 4; | ||||
|                                     nr1 = 4; | ||||
|                                 } | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q4_0: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 nsg = N_SG_Q4_0; | ||||
|                                 nr0 = N_R0_Q4_0; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q4_1: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 nsg = N_SG_Q4_1; | ||||
|                                 nr0 = N_R0_Q4_1; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q5_0: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 nsg = N_SG_Q5_0; | ||||
|                                 nr0 = N_R0_Q5_0; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q5_1: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 nsg = N_SG_Q5_1; | ||||
|                                 nr0 = N_R0_Q5_1; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q8_0: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 nsg = N_SG_Q8_0; | ||||
|                                 nr0 = N_R0_Q8_0; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q2_K: | ||||
|                             { | ||||
|                                 nth0 = 2; | ||||
|                                 nth1 = 32; | ||||
|                                 nsg = N_SG_Q2_K; | ||||
|                                 nr0 = N_R0_Q2_K; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q3_K: | ||||
|                             { | ||||
|                                 nth0 = 2; | ||||
|                                 nth1 = 32; | ||||
|                                 nsg = N_SG_Q3_K; | ||||
|                                 nr0 = N_R0_Q3_K; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q4_K: | ||||
|                             { | ||||
|                                 nth0 = 4; //1; | ||||
|                                 nth1 = 8; //32; | ||||
|                                 nsg = N_SG_Q4_K; | ||||
|                                 nr0 = N_R0_Q4_K; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q5_K: | ||||
|                             { | ||||
|                                 nth0 = 2; | ||||
|                                 nth1 = 32; | ||||
|                                 nsg = N_SG_Q5_K; | ||||
|                                 nr0 = N_R0_Q5_K; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q6_K: | ||||
|                             { | ||||
|                                 nth0 = 2; | ||||
|                                 nth1 = 32; | ||||
|                                 nsg = N_SG_Q6_K; | ||||
|                                 nr0 = N_R0_Q6_K; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ2_XXS: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ2_XXS; | ||||
|                                 nr0 = N_R0_IQ2_XXS; | ||||
|                                 smem = 256*8+128; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ2_XS: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ2_XS; | ||||
|                                 nr0 = N_R0_IQ2_XS; | ||||
|                                 smem = 512*8+128; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ3_XXS: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ3_XXS; | ||||
|                                 nr0 = N_R0_IQ3_XXS; | ||||
|                                 smem = 256*4+128; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ3_S: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ3_S; | ||||
|                                 nr0 = N_R0_IQ3_S; | ||||
|                                 smem = 512*4; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ2_S: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ2_S; | ||||
|                                 nr0 = N_R0_IQ2_S; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ1_S: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ1_S; | ||||
|                                 nr0 = N_R0_IQ1_S; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ1_M: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ1_M; | ||||
|                                 nr0 = N_R0_IQ1_M; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ4_NL: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ4_NL; | ||||
|                                 nr0 = N_R0_IQ4_NL; | ||||
|                                 smem = 32*sizeof(float); | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ4_XS: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ4_XS; | ||||
|                                 nr0 = N_R0_IQ4_XS; | ||||
|                                 smem = 32*sizeof(float); | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; | ||||
|                             } break; | ||||
|                         default: | ||||
| @@ -2762,41 +2771,10 @@ static void ggml_metal_encode_node( | ||||
|                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; | ||||
|                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3]; | ||||
|  | ||||
|                     if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 || | ||||
|                         src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K || | ||||
|                         src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { | ||||
|                         const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; | ||||
|                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { | ||||
|                         const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; | ||||
|                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { | ||||
|                         const int mem_size = 32*sizeof(float); | ||||
|                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_Q4_K) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_Q3_K) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_Q5_K) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_Q6_K) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } else { | ||||
|                         const int64_t ny = (ne11 + nrows - 1)/nrows; | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     if (smem > 0) { | ||||
|                         [encoder setThreadgroupMemoryLength:smem atIndex:0]; | ||||
|                     } | ||||
|                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (ne11 + nr1 - 1)/nr1, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; | ||||
|                 } | ||||
|             } break; | ||||
|         case GGML_OP_MUL_MAT_ID: | ||||
| @@ -2902,146 +2880,155 @@ static void ggml_metal_encode_node( | ||||
|  | ||||
|                     [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; | ||||
|                 } else { | ||||
|                     int nth0 = 32; | ||||
|                     int nth1 = 1; | ||||
|                     int nrows = 1; | ||||
|                     //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); | ||||
|  | ||||
|                     id<MTLComputePipelineState> pipeline = nil; | ||||
|  | ||||
|                     int nsg = 0; // number of simdgroups | ||||
|                     int nr0 = 0; // number of src0 rows per simdgroup | ||||
|                     int nr1 = 1; // number of src1 rows per threadgroup | ||||
|  | ||||
|                     size_t smem = 0; // shared memory | ||||
|  | ||||
|                     // use custom matrix x vector kernel | ||||
|                     switch (src0t) { | ||||
|                         case GGML_TYPE_F32: | ||||
|                             { | ||||
|                                 GGML_ASSERT(src1t == GGML_TYPE_F32); | ||||
|                                 nsg = 1; | ||||
|                                 nr0 = 1; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_F16: | ||||
|                             { | ||||
|                                 GGML_ASSERT(src1t == GGML_TYPE_F32); | ||||
|                                 nth0 = 32; | ||||
|                                 nth1 = 1; | ||||
|                                 nsg = 1; | ||||
|                                 nr0 = 1; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_BF16: | ||||
|                             { | ||||
|                                 GGML_ASSERT(src1t == GGML_TYPE_F32); | ||||
|                                 nth0 = 32; | ||||
|                                 nth1 = 1; | ||||
|                                 nsg = 1; | ||||
|                                 nr0 = 1; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_BF16_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q4_0: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 nsg = N_SG_Q4_0; | ||||
|                                 nr0 = N_R0_Q4_0; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q4_1: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 nsg = N_SG_Q4_1; | ||||
|                                 nr0 = N_R0_Q4_1; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q5_0: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 nsg = N_SG_Q5_0; | ||||
|                                 nr0 = N_R0_Q5_0; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q5_1: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 nsg = N_SG_Q5_1; | ||||
|                                 nr0 = N_R0_Q5_1; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q8_0: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 nsg = N_SG_Q8_0; | ||||
|                                 nr0 = N_R0_Q8_0; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q2_K: | ||||
|                             { | ||||
|                                 nth0 = 2; | ||||
|                                 nth1 = 32; | ||||
|                                 nsg = N_SG_Q2_K; | ||||
|                                 nr0 = N_R0_Q2_K; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q3_K: | ||||
|                             { | ||||
|                                 nth0 = 2; | ||||
|                                 nth1 = 32; | ||||
|                                 nsg = N_SG_Q3_K; | ||||
|                                 nr0 = N_R0_Q3_K; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q4_K: | ||||
|                             { | ||||
|                                 nth0 = 4; //1; | ||||
|                                 nth1 = 8; //32; | ||||
|                                 nsg = N_SG_Q4_K; | ||||
|                                 nr0 = N_R0_Q4_K; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q5_K: | ||||
|                             { | ||||
|                                 nth0 = 2; | ||||
|                                 nth1 = 32; | ||||
|                                 nsg = N_SG_Q5_K; | ||||
|                                 nr0 = N_R0_Q5_K; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q6_K: | ||||
|                             { | ||||
|                                 nth0 = 2; | ||||
|                                 nth1 = 32; | ||||
|                                 nsg = N_SG_Q6_K; | ||||
|                                 nr0 = N_R0_Q6_K; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ2_XXS: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ2_XXS; | ||||
|                                 nr0 = N_R0_IQ2_XXS; | ||||
|                                 smem = 256*8+128; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ2_XS: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ2_XS; | ||||
|                                 nr0 = N_R0_IQ2_XS; | ||||
|                                 smem = 512*8+128; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ3_XXS: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ3_XXS; | ||||
|                                 nr0 = N_R0_IQ3_XXS; | ||||
|                                 smem = 256*4+128; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ3_S: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ3_S; | ||||
|                                 nr0 = N_R0_IQ3_S; | ||||
|                                 smem = 512*4; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ2_S: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ2_S; | ||||
|                                 nr0 = N_R0_IQ2_S; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ1_S: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ1_S; | ||||
|                                 nr0 = N_R0_IQ1_S; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ1_M: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ1_M; | ||||
|                                 nr0 = N_R0_IQ1_M; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ4_NL: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ4_NL; | ||||
|                                 nr0 = N_R0_IQ4_NL; | ||||
|                                 smem = 32*sizeof(float); | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ4_XS: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 nsg = N_SG_IQ4_XS; | ||||
|                                 nr0 = N_R0_IQ4_XS; | ||||
|                                 smem = 32*sizeof(float); | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32].pipeline; | ||||
|                             } break; | ||||
|                         default: | ||||
| @@ -3052,7 +3039,7 @@ static void ggml_metal_encode_node( | ||||
|                     }; | ||||
|  | ||||
|                     if (ggml_is_quantized(src0t)) { | ||||
|                         GGML_ASSERT(ne00 >= nth0*nth1); | ||||
|                         GGML_ASSERT(ne00 >= nsg*nr0); | ||||
|                     } | ||||
|  | ||||
|                     ggml_metal_kargs_mul_mv_id args = { | ||||
| @@ -3085,43 +3072,12 @@ static void ggml_metal_encode_node( | ||||
|                     [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; | ||||
|  | ||||
|                     const int64_t _ne1 = 1; | ||||
|                     const int tgz = dst_rows; | ||||
|                     const int64_t ne123 = dst_rows; | ||||
|  | ||||
|                     if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 || | ||||
|                             src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K || | ||||
|                             src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { | ||||
|                         const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; | ||||
|                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { | ||||
|                         const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; | ||||
|                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { | ||||
|                         const int mem_size = 32*sizeof(float); | ||||
|                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_Q4_K) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_Q3_K) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_Q5_K) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_Q6_K) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } else { | ||||
|                         const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1 | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     if (smem > 0) { | ||||
|                         [encoder setThreadgroupMemoryLength:smem atIndex:0]; | ||||
|                     } | ||||
|                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; | ||||
|                 } | ||||
|             } break; | ||||
|         case GGML_OP_GET_ROWS: | ||||
|   | ||||
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov