mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)

	metal : minor code formatting
		| @@ -1951,316 +1951,316 @@ static void ggml_metal_encode_node( | ||||
|                         } | ||||
| #endif | ||||
|  | ||||
|                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs | ||||
|                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel | ||||
|                         if ([device supportsFamily:MTLGPUFamilyApple7] && | ||||
|                                 !ggml_is_transposed(src0) && | ||||
|                                 !ggml_is_transposed(src1) && | ||||
|                                 src1t == GGML_TYPE_F32 && | ||||
|                                 ne00 % 32 == 0 && ne00 >= 64 && | ||||
|                                 (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { | ||||
|                             //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); | ||||
|                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs | ||||
|                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel | ||||
|                 if ([device supportsFamily:MTLGPUFamilyApple7] && | ||||
|                         !ggml_is_transposed(src0) && | ||||
|                         !ggml_is_transposed(src1) && | ||||
|                         src1t == GGML_TYPE_F32 && | ||||
|                         ne00 % 32 == 0 && ne00 >= 64 && | ||||
|                         (ne11 > ne11_mm_min || (ggml_is_quantized(src0t) && ne12 > 1))) { | ||||
|                     //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); | ||||
|  | ||||
|                             // some Metal matrix data types require aligned pointers | ||||
|                             // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) | ||||
|                             switch (src0->type) { | ||||
|                                 case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break; | ||||
|                                 case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break; | ||||
|                                 case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8  == 0); break; | ||||
|                                 default: break; | ||||
|                             } | ||||
|                     // some Metal matrix data types require aligned pointers | ||||
|                     // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) | ||||
|                     switch (src0->type) { | ||||
|                         case GGML_TYPE_F32:  GGML_ASSERT(nb01 % 16 == 0); break; | ||||
|                         case GGML_TYPE_F16:  GGML_ASSERT(nb01 % 8  == 0); break; | ||||
|                         case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8  == 0); break; | ||||
|                         default: break; | ||||
|                     } | ||||
|  | ||||
|                             id<MTLComputePipelineState> pipeline = nil; | ||||
|                     id<MTLComputePipelineState> pipeline = nil; | ||||
|  | ||||
|                             switch (src0->type) { | ||||
|                                 case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32    ].pipeline; break; | ||||
|                                 case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32    ].pipeline; break; | ||||
|                                 case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32   ].pipeline; break; | ||||
|                                 case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32   ].pipeline; break; | ||||
|                                 case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32   ].pipeline; break; | ||||
|                                 case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32   ].pipeline; break; | ||||
|                                 case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32   ].pipeline; break; | ||||
|                                 case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32   ].pipeline; break; | ||||
|                                 case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32   ].pipeline; break; | ||||
|                                 case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32   ].pipeline; break; | ||||
|                                 case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32   ].pipeline; break; | ||||
|                                 case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32   ].pipeline; break; | ||||
|                                 case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32   ].pipeline; break; | ||||
|                                 case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; | ||||
|                                 case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; | ||||
|                                 case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; | ||||
|                                 case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32  ].pipeline; break; | ||||
|                                 case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32  ].pipeline; break; | ||||
|                                 case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break; | ||||
|                                 case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32  ].pipeline; break; | ||||
|                                 case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; | ||||
|                                 case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; | ||||
|                                 default: GGML_ABORT("MUL MAT-MAT not implemented"); | ||||
|                             } | ||||
|                     switch (src0->type) { | ||||
|                         case GGML_TYPE_F32:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32    ].pipeline; break; | ||||
|                         case GGML_TYPE_F16:     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32    ].pipeline; break; | ||||
|                         case GGML_TYPE_BF16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_BF16_F32   ].pipeline; break; | ||||
|                         case GGML_TYPE_Q4_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32   ].pipeline; break; | ||||
|                         case GGML_TYPE_Q4_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32   ].pipeline; break; | ||||
|                         case GGML_TYPE_Q5_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32   ].pipeline; break; | ||||
|                         case GGML_TYPE_Q5_1:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32   ].pipeline; break; | ||||
|                         case GGML_TYPE_Q8_0:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32   ].pipeline; break; | ||||
|                         case GGML_TYPE_Q2_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32   ].pipeline; break; | ||||
|                         case GGML_TYPE_Q3_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32   ].pipeline; break; | ||||
|                         case GGML_TYPE_Q4_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32   ].pipeline; break; | ||||
|                         case GGML_TYPE_Q5_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32   ].pipeline; break; | ||||
|                         case GGML_TYPE_Q6_K:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32   ].pipeline; break; | ||||
|                         case GGML_TYPE_IQ2_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32].pipeline; break; | ||||
|                         case GGML_TYPE_IQ2_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32 ].pipeline; break; | ||||
|                         case GGML_TYPE_IQ3_XXS: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32].pipeline; break; | ||||
|                         case GGML_TYPE_IQ3_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32  ].pipeline; break; | ||||
|                         case GGML_TYPE_IQ2_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32  ].pipeline; break; | ||||
|                         case GGML_TYPE_IQ1_S:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32  ].pipeline; break; | ||||
|                         case GGML_TYPE_IQ1_M:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32  ].pipeline; break; | ||||
|                         case GGML_TYPE_IQ4_NL:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32 ].pipeline; break; | ||||
|                         case GGML_TYPE_IQ4_XS:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32 ].pipeline; break; | ||||
|                         default: GGML_ABORT("MUL MAT-MAT not implemented"); | ||||
|                     } | ||||
|  | ||||
|                             ggml_metal_kargs_mul_mm args = { | ||||
|                                 /*.ne00 =*/ ne00, | ||||
|                                 /*.ne02 =*/ ne02, | ||||
|                                 /*.nb01 =*/ nb01, | ||||
|                                 /*.nb02 =*/ nb02, | ||||
|                                 /*.nb03 =*/ nb03, | ||||
|                                 /*.ne12 =*/ ne12, | ||||
|                                 /*.nb10 =*/ nb10, | ||||
|                                 /*.nb11 =*/ nb11, | ||||
|                                 /*.nb12 =*/ nb12, | ||||
|                                 /*.nb13 =*/ nb13, | ||||
|                                 /*.ne0  =*/ ne0, | ||||
|                                 /*.ne1  =*/ ne1, | ||||
|                                 /*.r2   =*/ r2, | ||||
|                                 /*.r3   =*/ r3, | ||||
|                             }; | ||||
|                     ggml_metal_kargs_mul_mm args = { | ||||
|                         /*.ne00 =*/ ne00, | ||||
|                         /*.ne02 =*/ ne02, | ||||
|                         /*.nb01 =*/ nb01, | ||||
|                         /*.nb02 =*/ nb02, | ||||
|                         /*.nb03 =*/ nb03, | ||||
|                         /*.ne12 =*/ ne12, | ||||
|                         /*.nb10 =*/ nb10, | ||||
|                         /*.nb11 =*/ nb11, | ||||
|                         /*.nb12 =*/ nb12, | ||||
|                         /*.nb13 =*/ nb13, | ||||
|                         /*.ne0  =*/ ne0, | ||||
|                         /*.ne1  =*/ ne1, | ||||
|                         /*.r2   =*/ r2, | ||||
|                         /*.r3   =*/ r3, | ||||
|                     }; | ||||
|  | ||||
|                             [encoder setComputePipelineState:pipeline]; | ||||
|                             [encoder setBytes:&args    length:sizeof(args) atIndex:0]; | ||||
|                             [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1]; | ||||
|                             [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2]; | ||||
|                             [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3]; | ||||
|                     [encoder setComputePipelineState:pipeline]; | ||||
|                     [encoder setBytes:&args    length:sizeof(args) atIndex:0]; | ||||
|                     [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1]; | ||||
|                     [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2]; | ||||
|                     [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3]; | ||||
|  | ||||
|                             [encoder setThreadgroupMemoryLength:8192 atIndex:0]; | ||||
|                             [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; | ||||
|                         } else { | ||||
|                             int nth0 = 32; | ||||
|                             int nth1 = 1; | ||||
|                             int nrows = 1; | ||||
|                             //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); | ||||
|                     [encoder setThreadgroupMemoryLength:8192 atIndex:0]; | ||||
|                     [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; | ||||
|                 } else { | ||||
|                     int nth0 = 32; | ||||
|                     int nth1 = 1; | ||||
|                     int nrows = 1; | ||||
|                     //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); | ||||
|  | ||||
|                             id<MTLComputePipelineState> pipeline = nil; | ||||
|                     id<MTLComputePipelineState> pipeline = nil; | ||||
|  | ||||
|                             // use custom matrix x vector kernel | ||||
|                             switch (src0t) { | ||||
|                                 case GGML_TYPE_F32: | ||||
|                                     { | ||||
|                                         GGML_ASSERT(src1t == GGML_TYPE_F32); | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; | ||||
|                     // use custom matrix x vector kernel | ||||
|                     switch (src0t) { | ||||
|                         case GGML_TYPE_F32: | ||||
|                             { | ||||
|                                 GGML_ASSERT(src1t == GGML_TYPE_F32); | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32].pipeline; | ||||
|                                 nrows = 4; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_F16: | ||||
|                             { | ||||
|                                 nth0 = 32; | ||||
|                                 nth1 = 1; | ||||
|                                 if (src1t == GGML_TYPE_F32) { | ||||
|                                     if (ne11 * ne12 < 4) { | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; | ||||
|                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; | ||||
|                                         nrows = ne11; | ||||
|                                     } else { | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; | ||||
|                                         nrows = 4; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_F16: | ||||
|                                     { | ||||
|                                         nth0 = 32; | ||||
|                                         nth1 = 1; | ||||
|                                         if (src1t == GGML_TYPE_F32) { | ||||
|                                             if (ne11 * ne12 < 4) { | ||||
|                                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW].pipeline; | ||||
|                                             } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { | ||||
|                                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4].pipeline; | ||||
|                                                 nrows = ne11; | ||||
|                                             } else { | ||||
|                                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32].pipeline; | ||||
|                                                 nrows = 4; | ||||
|                                             } | ||||
|                                         } else { | ||||
|                                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; | ||||
|                                             nrows = 4; | ||||
|                                         } | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_BF16: | ||||
|                                     { | ||||
|                                         nth0 = 32; | ||||
|                                         nth1 = 1; | ||||
|                                         if (src1t == GGML_TYPE_F32) { | ||||
|                                             if (ne11 * ne12 < 4) { | ||||
|                                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; | ||||
|                                             } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { | ||||
|                                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; | ||||
|                                                 nrows = ne11; | ||||
|                                             } else { | ||||
|                                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; | ||||
|                                                 nrows = 4; | ||||
|                                             } | ||||
|                                         } else { | ||||
|                                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; | ||||
|                                             nrows = 4; | ||||
|                                         } | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_Q4_0: | ||||
|                                     { | ||||
|                                         nth0 = 8; | ||||
|                                         nth1 = 8; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_Q4_1: | ||||
|                                     { | ||||
|                                         nth0 = 8; | ||||
|                                         nth1 = 8; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_Q5_0: | ||||
|                                     { | ||||
|                                         nth0 = 8; | ||||
|                                         nth1 = 8; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_Q5_1: | ||||
|                                     { | ||||
|                                         nth0 = 8; | ||||
|                                         nth1 = 8; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_Q8_0: | ||||
|                                     { | ||||
|                                         nth0 = 8; | ||||
|                                         nth1 = 8; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_Q2_K: | ||||
|                                     { | ||||
|                                         nth0 = 2; | ||||
|                                         nth1 = 32; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_Q3_K: | ||||
|                                     { | ||||
|                                         nth0 = 2; | ||||
|                                         nth1 = 32; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_Q4_K: | ||||
|                                     { | ||||
|                                         nth0 = 4; //1; | ||||
|                                         nth1 = 8; //32; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_Q5_K: | ||||
|                                     { | ||||
|                                         nth0 = 2; | ||||
|                                         nth1 = 32; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_Q6_K: | ||||
|                                     { | ||||
|                                         nth0 = 2; | ||||
|                                         nth1 = 32; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_IQ2_XXS: | ||||
|                                     { | ||||
|                                         nth0 = 4; | ||||
|                                         nth1 = 16; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_IQ2_XS: | ||||
|                                     { | ||||
|                                         nth0 = 4; | ||||
|                                         nth1 = 16; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_IQ3_XXS: | ||||
|                                     { | ||||
|                                         nth0 = 4; | ||||
|                                         nth1 = 16; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_IQ3_S: | ||||
|                                     { | ||||
|                                         nth0 = 4; | ||||
|                                         nth1 = 16; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_IQ2_S: | ||||
|                                     { | ||||
|                                         nth0 = 4; | ||||
|                                         nth1 = 16; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_IQ1_S: | ||||
|                                     { | ||||
|                                         nth0 = 4; | ||||
|                                         nth1 = 16; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_IQ1_M: | ||||
|                                     { | ||||
|                                         nth0 = 4; | ||||
|                                         nth1 = 16; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_IQ4_NL: | ||||
|                                     { | ||||
|                                         nth0 = 4; | ||||
|                                         nth1 = 16; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 case GGML_TYPE_IQ4_XS: | ||||
|                                     { | ||||
|                                         nth0 = 4; | ||||
|                                         nth1 = 16; | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; | ||||
|                                     } break; | ||||
|                                 default: | ||||
|                                     { | ||||
|                                         GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t); | ||||
|                                         GGML_ABORT("not implemented"); | ||||
|                                     } | ||||
|                             }; | ||||
|                                 } else { | ||||
|                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16].pipeline; | ||||
|                                     nrows = 4; | ||||
|                                 } | ||||
|                             } break; | ||||
|                         case GGML_TYPE_BF16: | ||||
|                             { | ||||
|                                 nth0 = 32; | ||||
|                                 nth1 = 1; | ||||
|                                 if (src1t == GGML_TYPE_F32) { | ||||
|                                     if (ne11 * ne12 < 4) { | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_1ROW].pipeline; | ||||
|                                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32_L4].pipeline; | ||||
|                                         nrows = ne11; | ||||
|                                     } else { | ||||
|                                         pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_F32].pipeline; | ||||
|                                         nrows = 4; | ||||
|                                     } | ||||
|                                 } else { | ||||
|                                     pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_BF16_BF16].pipeline; | ||||
|                                     nrows = 4; | ||||
|                                 } | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q4_0: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q4_1: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q5_0: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q5_1: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q8_0: | ||||
|                             { | ||||
|                                 nth0 = 8; | ||||
|                                 nth1 = 8; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q2_K: | ||||
|                             { | ||||
|                                 nth0 = 2; | ||||
|                                 nth1 = 32; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q3_K: | ||||
|                             { | ||||
|                                 nth0 = 2; | ||||
|                                 nth1 = 32; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q4_K: | ||||
|                             { | ||||
|                                 nth0 = 4; //1; | ||||
|                                 nth1 = 8; //32; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q5_K: | ||||
|                             { | ||||
|                                 nth0 = 2; | ||||
|                                 nth1 = 32; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_Q6_K: | ||||
|                             { | ||||
|                                 nth0 = 2; | ||||
|                                 nth1 = 32; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ2_XXS: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ2_XS: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ3_XXS: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ3_S: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ2_S: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ1_S: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ1_M: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ4_NL: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32].pipeline; | ||||
|                             } break; | ||||
|                         case GGML_TYPE_IQ4_XS: | ||||
|                             { | ||||
|                                 nth0 = 4; | ||||
|                                 nth1 = 16; | ||||
|                                 pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32].pipeline; | ||||
|                             } break; | ||||
|                         default: | ||||
|                             { | ||||
|                                 GGML_LOG_ERROR("Asserting on type %d\n", (int)src0t); | ||||
|                                 GGML_ABORT("not implemented"); | ||||
|                             } | ||||
|                     }; | ||||
|  | ||||
|                             ggml_metal_kargs_mul_mv args = { | ||||
|                                 /*.ne00 =*/ ne00, | ||||
|                                 /*.ne01 =*/ ne01, | ||||
|                                 /*.ne02 =*/ ne02, | ||||
|                                 /*.nb00 =*/ nb00, | ||||
|                                 /*.nb01 =*/ nb01, | ||||
|                                 /*.nb02 =*/ nb02, | ||||
|                                 /*.nb03 =*/ nb03, | ||||
|                                 /*.ne10 =*/ ne10, | ||||
|                                 /*.ne11 =*/ ne11, | ||||
|                                 /*.ne12 =*/ ne12, | ||||
|                                 /*.nb10 =*/ nb10, | ||||
|                                 /*.nb11 =*/ nb11, | ||||
|                                 /*.nb12 =*/ nb12, | ||||
|                                 /*.nb13 =*/ nb13, | ||||
|                                 /*.ne0  =*/ ne0, | ||||
|                                 /*.ne1  =*/ ne1, | ||||
|                                 /*.r2   =*/ r2, | ||||
|                                 /*.r3   =*/ r3, | ||||
|                             }; | ||||
|                     ggml_metal_kargs_mul_mv args = { | ||||
|                         /*.ne00 =*/ ne00, | ||||
|                         /*.ne01 =*/ ne01, | ||||
|                         /*.ne02 =*/ ne02, | ||||
|                         /*.nb00 =*/ nb00, | ||||
|                         /*.nb01 =*/ nb01, | ||||
|                         /*.nb02 =*/ nb02, | ||||
|                         /*.nb03 =*/ nb03, | ||||
|                         /*.ne10 =*/ ne10, | ||||
|                         /*.ne11 =*/ ne11, | ||||
|                         /*.ne12 =*/ ne12, | ||||
|                         /*.nb10 =*/ nb10, | ||||
|                         /*.nb11 =*/ nb11, | ||||
|                         /*.nb12 =*/ nb12, | ||||
|                         /*.nb13 =*/ nb13, | ||||
|                         /*.ne0  =*/ ne0, | ||||
|                         /*.ne1  =*/ ne1, | ||||
|                         /*.r2   =*/ r2, | ||||
|                         /*.r3   =*/ r3, | ||||
|                     }; | ||||
|  | ||||
|                             [encoder setComputePipelineState:pipeline]; | ||||
|                             [encoder setBytes:&args length:sizeof(args) atIndex:0]; | ||||
|                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; | ||||
|                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; | ||||
|                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3]; | ||||
|                     [encoder setComputePipelineState:pipeline]; | ||||
|                     [encoder setBytes:&args length:sizeof(args) atIndex:0]; | ||||
|                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; | ||||
|                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; | ||||
|                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3]; | ||||
|  | ||||
|                             if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 || | ||||
|                                 src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K || | ||||
|                                 src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { | ||||
|                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                             } | ||||
|                             else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { | ||||
|                                 const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; | ||||
|                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; | ||||
|                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                             } | ||||
|                             else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { | ||||
|                                 const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; | ||||
|                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; | ||||
|                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                             } | ||||
|                             else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { | ||||
|                                 const int mem_size = 32*sizeof(float); | ||||
|                                 [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; | ||||
|                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                             } | ||||
|                             else if (src0t == GGML_TYPE_Q4_K) { | ||||
|                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                             } | ||||
|                             else if (src0t == GGML_TYPE_Q3_K) { | ||||
|                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                             } | ||||
|                             else if (src0t == GGML_TYPE_Q5_K) { | ||||
|                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                             } | ||||
|                             else if (src0t == GGML_TYPE_Q6_K) { | ||||
|                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                             } else { | ||||
|                                 const int64_t ny = (ne11 + nrows - 1)/nrows; | ||||
|                                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                             } | ||||
|                         } | ||||
|                     if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 || | ||||
|                         src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K || | ||||
|                         src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) { | ||||
|                         const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128; | ||||
|                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) { | ||||
|                         const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4; | ||||
|                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) { | ||||
|                         const int mem_size = 32*sizeof(float); | ||||
|                         [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_Q4_K) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_Q3_K) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_Q5_K) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                     else if (src0t == GGML_TYPE_Q6_K) { | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } else { | ||||
|                         const int64_t ny = (ne11 + nrows - 1)/nrows; | ||||
|                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; | ||||
|                     } | ||||
|                 } | ||||
|             } break; | ||||
|         case GGML_OP_MUL_MAT_ID: | ||||
|             { | ||||
|   | ||||
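
For reference, the mul_mm dispatch above launches one 128-thread threadgroup per 64x32 output tile, so the grid is simply a ceiling division of the output dimensions by the tile size. A minimal standalone sketch of that arithmetic (plain C, with assumed example sizes, not part of the commit):

    /* sketch of the threadgroup grid used by the mul_mm dispatch above
       (example sizes are assumptions, not values from this commit) */
    #include <stdio.h>

    int main(void) {
        const int ne01 = 4096;            /* rows of src0 -> output rows    (assumed) */
        const int ne11 = 100;             /* rows of src1 -> output columns (assumed) */
        const int ne12 = 1, ne13 = 1;     /* batch dimensions               (assumed) */

        /* each threadgroup produces a 64x32 tile of the result, so the grid
           is the ceiling division of the output shape by the tile size */
        const int tg_x = (ne11 + 31) / 32;
        const int tg_y = (ne01 + 63) / 64;
        const int tg_z = ne12 * ne13;

        printf("threadgroups: %d x %d x %d, 128 threads each\n", tg_x, tg_y, tg_z);
        return 0;
    }

The matrix-vector dispatches further down use the same ceiling-division pattern, only with per-type row counts per threadgroup ((ne01 + 7)/8, (ne01 + 3)/4, or (ne01 + 1)/2) instead of the 64-row tile.
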
| @@ -5447,12 +5447,12 @@ kernel void kernel_mul_mm( | ||||
|     const int im = tgpig.z; | ||||
|  | ||||
|     // if this block is of 64x32 shape or smaller | ||||
|     short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; | ||||
|     short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; | ||||
|     const short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; | ||||
|     const short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; | ||||
|  | ||||
|     // a thread shouldn't load data outside of the matrix | ||||
|     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; | ||||
|     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; | ||||
|     const short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; | ||||
|     const short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; | ||||
|  | ||||
|     simdgroup_T8x8     ma[4]; | ||||
|     simdgroup_float8x8 mb[2]; | ||||
| @@ -5467,20 +5467,23 @@ kernel void kernel_mul_mm( | ||||
|     const int i12 = im%args.ne12; | ||||
|     const int i13 = im/args.ne12; | ||||
|  | ||||
|     uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||||
|     short    offset1 = il/nl; | ||||
|     const uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||||
|     const short    offset1 = il/nl; | ||||
|  | ||||
|     device const block_q * x = (device const block_q *)(src0 | ||||
|         + args.nb01*(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1; | ||||
|  | ||||
|     device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1; | ||||
|     device const float   * y = (device const float   *)(src1 | ||||
|         + args.nb13*i13 | ||||
|         + args.nb12*i12 | ||||
|         + args.nb11*(r1 * BLOCK_SIZE_N + thread_col) | ||||
|         + args.nb11*(r1*BLOCK_SIZE_N + thread_col) | ||||
|         + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); | ||||
|  | ||||
|     for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { | ||||
|         // load data and store to threadgroup memory | ||||
|         T4x4 temp_a; | ||||
|         dequantize_func(x, il, temp_a); | ||||
|  | ||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|         #pragma unroll(16) | ||||
| @@ -5490,44 +5493,46 @@ kernel void kernel_mul_mm( | ||||
|             +                     (tiitg/THREAD_PER_ROW)%8  + (i&7)*8) = temp_a[i/4][i%4]; | ||||
|         } | ||||
|  | ||||
|         *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL)*8*32 + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y); | ||||
|         *(threadgroup float2x4 *)(sb + 32*8*(tiitg%THREAD_PER_COL) + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y); | ||||
|  | ||||
|         il = (il + 2 < nl) ? il + 2 : il % 2; | ||||
|         x  = (il < 2) ? x + (2+nl-1)/nl : x; | ||||
|         x  = (il < 2) ? x + (2 + nl - 1)/nl : x; | ||||
|         y += BLOCK_SIZE_K; | ||||
|  | ||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|         // load matrices from threadgroup memory and conduct outer products | ||||
|         threadgroup T     * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); | ||||
|         threadgroup float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); | ||||
|         threadgroup const T     * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2)); | ||||
|         threadgroup const float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2)); | ||||
|  | ||||
|         #pragma unroll(4) | ||||
|         for (short ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { | ||||
|         for (short ik = 0; ik < BLOCK_SIZE_K/8; ik++) { | ||||
|             #pragma unroll(4) | ||||
|             for (short i = 0; i < 4; i++) { | ||||
|                 simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); | ||||
|             } | ||||
|  | ||||
|             simdgroup_barrier(mem_flags::mem_none); | ||||
|  | ||||
|             #pragma unroll(2) | ||||
|             for (short i = 0; i < 2; i++) { | ||||
|                 simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); | ||||
|             } | ||||
|  | ||||
|             lsma += BLOCK_SIZE_M/SG_MAT_ROW * SG_MAT_SIZE; | ||||
|             lsmb += BLOCK_SIZE_N/SG_MAT_ROW * SG_MAT_SIZE; | ||||
|  | ||||
|             #pragma unroll(8) | ||||
|             for (short i = 0; i < 8; i++){ | ||||
|                 simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); | ||||
|             } | ||||
|  | ||||
|             lsma += (BLOCK_SIZE_M/SG_MAT_ROW)*SG_MAT_SIZE; | ||||
|             lsmb += (BLOCK_SIZE_N/SG_MAT_ROW)*SG_MAT_SIZE; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) { | ||||
|         device float * C = (device float *) dst + | ||||
|             (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) + \ | ||||
|             (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; | ||||
|             (BLOCK_SIZE_M * r0 + 32*(sgitg &  1)) + \ | ||||
|             (BLOCK_SIZE_N * r1 + 16*(sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; | ||||
|  | ||||
|         for (short i = 0; i < 8; i++) { | ||||
|             simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0); | ||||
| @@ -5536,7 +5541,7 @@ kernel void kernel_mul_mm( | ||||
|         // block is smaller than 64x32, we should avoid writing data outside of the matrix | ||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|         threadgroup float * temp_str = ((threadgroup float *) shmem) \ | ||||
|                                       + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M; | ||||
|                                      + 32*(sgitg&1) + (16*(sgitg >> 1))*BLOCK_SIZE_M; | ||||
|         for (short i = 0; i < 8; i++) { | ||||
|             simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); | ||||
|         } | ||||
|   | ||||
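
The n_rows/n_cols and thread_row/thread_col expressions near the top of kernel_mul_mm clamp the last tile so threads never index past the edge of the output matrix. A small standalone sketch of that clamping (plain C with assumed sizes; BLOCK_SIZE_M = 64 and BLOCK_SIZE_N = 32 as in the kernel):

    /* sketch of the edge-tile clamping done at the top of kernel_mul_mm
       (plain C rendering; names mirror the kernel, sizes are assumed) */
    #include <stdio.h>

    #define BLOCK_SIZE_M 64
    #define BLOCK_SIZE_N 32

    int main(void) {
        const int ne0 = 100;   /* output rows    (assumed example) */
        const int ne1 = 50;    /* output columns (assumed example) */
        const int r0  = 1;     /* tile index along ne0 */
        const int r1  = 1;     /* tile index along ne1 */

        /* the last tile may cover less than 64x32, so clamp to what remains */
        const short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (short)(ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
        const short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (short)(ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;

        printf("tile (%d,%d): %d rows x %d cols\n", r0, r1, n_rows, n_cols);  /* 36 rows x 18 cols */
        return 0;
    }

This is also why the kernel takes the slower path through threadgroup memory when a tile is partial: the simdgroup stores always write full 8x8 blocks, so staging them in shared memory lets the final copy write only the rows and columns that actually exist.
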