mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : improve F32, F16 and BF16 mat-vec multiplication
ggml-ci
This commit is contained in:
		| @@ -34,6 +34,10 @@ ggml_metal_pipelines_t ggml_metal_pipelines_init(void) { | |||||||
| } | } | ||||||
|  |  | ||||||
| void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) { | void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) { | ||||||
|  |     if (!ppls) { | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) { |     for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) { | ||||||
|         ggml_metal_pipeline_free(it->second); |         ggml_metal_pipeline_free(it->second); | ||||||
|     } |     } | ||||||
| @@ -467,37 +471,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_ | |||||||
|     // use custom matrix x vector kernel |     // use custom matrix x vector kernel | ||||||
|     switch (tsrc0) { |     switch (tsrc0) { | ||||||
|         case GGML_TYPE_F32: |         case GGML_TYPE_F32: | ||||||
|             { |  | ||||||
|                 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); |  | ||||||
|  |  | ||||||
|                 nsg = 1; |  | ||||||
|                 nr0 = 1; |  | ||||||
|                 nr1 = 4; |  | ||||||
|                 if (ne00 == 4) { |  | ||||||
|                     nr0 = 32; |  | ||||||
|                     suffix = "_c4"; |  | ||||||
|                 } |  | ||||||
|             } break; |  | ||||||
|         case GGML_TYPE_F16: |         case GGML_TYPE_F16: | ||||||
|         case GGML_TYPE_BF16: |         case GGML_TYPE_BF16: | ||||||
|             { |             { | ||||||
|                 nsg = 1; |                 if (ne00 == 4) { | ||||||
|                 nr0 = 1; |                     nsg = 1; | ||||||
|                 if (op->src[1]->type == GGML_TYPE_F32) { |                     nr0 = 32; | ||||||
|                     if (ne00 == 4) { |  | ||||||
|                         nr0 = 32; |  | ||||||
|                         nr1 = 4; |  | ||||||
|                         suffix = "_c4"; |  | ||||||
|                     } else if (ne11 * ne12 < 4) { |  | ||||||
|                         suffix = "_1row"; |  | ||||||
|                     } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { |  | ||||||
|                         suffix = "_l4"; |  | ||||||
|                         nr1 = ne11; |  | ||||||
|                     } else { |  | ||||||
|                         nr1 = 4; |  | ||||||
|                     } |  | ||||||
|                 } else { |  | ||||||
|                     nr1 = 4; |                     nr1 = 4; | ||||||
|  |                     suffix = "_c4"; | ||||||
|  |                 } else if (ne00 % 4 == 0) { | ||||||
|  |                     nsg = N_SG_F; | ||||||
|  |                     nr0 = N_R0_F; | ||||||
|  |                     nr1 = 1; | ||||||
|  |                     smem = 32*sizeof(float)*N_R0_F; | ||||||
|  |                     suffix = "_4"; | ||||||
|  |                 } else { | ||||||
|  |                     nsg = N_SG_F; | ||||||
|  |                     nr0 = N_R0_F; | ||||||
|  |                     nr1 = 1; | ||||||
|  |                     smem = 32*sizeof(float)*N_R0_F; | ||||||
|                 } |                 } | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
| @@ -689,25 +681,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra | |||||||
|     const ggml_type tsrc0 = op->src[0]->type; |     const ggml_type tsrc0 = op->src[0]->type; | ||||||
|     const ggml_type tsrc1 = op->src[1]->type; |     const ggml_type tsrc1 = op->src[1]->type; | ||||||
|  |  | ||||||
|  |     const char * suffix = ""; | ||||||
|  |  | ||||||
|         // use custom matrix x vector kernel |         // use custom matrix x vector kernel | ||||||
|     switch (tsrc0) { |     switch (tsrc0) { | ||||||
|         case GGML_TYPE_F32: |         case GGML_TYPE_F32: | ||||||
|             { |  | ||||||
|                 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); |  | ||||||
|                 nsg = 1; |  | ||||||
|                 nr0 = 1; |  | ||||||
|             } break; |  | ||||||
|         case GGML_TYPE_F16: |         case GGML_TYPE_F16: | ||||||
|             { |  | ||||||
|                 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); |  | ||||||
|                 nsg = 1; |  | ||||||
|                 nr0 = 1; |  | ||||||
|             } break; |  | ||||||
|         case GGML_TYPE_BF16: |         case GGML_TYPE_BF16: | ||||||
|             { |             { | ||||||
|                 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); |                 if (ne00 % 4 == 0) { | ||||||
|                 nsg = 1; |                     nsg = N_SG_F; | ||||||
|                 nr0 = 1; |                     nr0 = N_R0_F; | ||||||
|  |                     nr1 = 1; | ||||||
|  |                     smem = 32*sizeof(float)*N_R0_F; | ||||||
|  |                     suffix = "_4"; | ||||||
|  |                 } else { | ||||||
|  |                     nsg = N_SG_F; | ||||||
|  |                     nr0 = N_R0_F; | ||||||
|  |                     nr1 = 1; | ||||||
|  |                     smem = 32*sizeof(float)*N_R0_F; | ||||||
|  |                 } | ||||||
|             } break; |             } break; | ||||||
|         case GGML_TYPE_Q4_0: |         case GGML_TYPE_Q4_0: | ||||||
|             { |             { | ||||||
| @@ -824,7 +817,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra | |||||||
|             } |             } | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     snprintf(base, 256, "kernel_mul_mv_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1)); |     snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix); | ||||||
|     snprintf(name, 256, "%s", base); |     snprintf(name, 256, "%s", base); | ||||||
|  |  | ||||||
|     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); |     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); | ||||||
|   | |||||||
| @@ -8,6 +8,9 @@ | |||||||
| // | // | ||||||
| // TODO: for optimal performance, become function of the device and work size | // TODO: for optimal performance, become function of the device and work size | ||||||
|  |  | ||||||
|  | #define N_R0_F 2 | ||||||
|  | #define N_SG_F 4 | ||||||
|  |  | ||||||
| #define N_R0_Q4_0 4 | #define N_R0_Q4_0 4 | ||||||
| #define N_SG_Q4_0 2 | #define N_SG_Q4_0 2 | ||||||
|  |  | ||||||
|   | |||||||
| @@ -1564,7 +1564,10 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { | |||||||
|  |  | ||||||
|         ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); |         ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||||||
|  |  | ||||||
|         if (op->src[0]->type == GGML_TYPE_Q8_0) { |         if (op->src[0]->type == GGML_TYPE_F32 || | ||||||
|  |             op->src[0]->type == GGML_TYPE_F16 || | ||||||
|  |             op->src[0]->type == GGML_TYPE_BF16 || | ||||||
|  |             op->src[0]->type == GGML_TYPE_Q8_0) { | ||||||
|             ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1); |             ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1); | ||||||
|         } else { |         } else { | ||||||
|             ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1); |             ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1); | ||||||
| @@ -1772,7 +1775,10 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { | |||||||
|  |  | ||||||
|         ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); |         ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); | ||||||
|  |  | ||||||
|         if (op->src[0]->type == GGML_TYPE_Q8_0) { |         if (op->src[0]->type == GGML_TYPE_F32 || | ||||||
|  |             op->src[0]->type == GGML_TYPE_F16 || | ||||||
|  |             op->src[0]->type == GGML_TYPE_BF16 || | ||||||
|  |             op->src[0]->type == GGML_TYPE_Q8_0) { | ||||||
|             ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1); |             ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1); | ||||||
|         } else { |         } else { | ||||||
|             ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1); |             ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1); | ||||||
|   | |||||||
| @@ -3404,104 +3404,211 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4 | |||||||
| template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>; | template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>; | ||||||
| template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>; | template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>; | ||||||
|  |  | ||||||
| #define N_MV_T_T 4 | template<typename T0, typename T1, short NR0, short NSG, short NW, typename args_t> | ||||||
|  | void kernel_mul_mv_t_t_impl( | ||||||
| template<typename T0, typename T04, typename T1, typename T14, typename args_t> |  | ||||||
| void kernel_mul_mv_impl( |  | ||||||
|         args_t args, |         args_t args, | ||||||
|         device const char * src0, |         device const char * src0, | ||||||
|         device const char * src1, |         device const char * src1, | ||||||
|         device       char * dst, |         device       char * dst, | ||||||
|  |         threadgroup  char * shmem, | ||||||
|         uint3  tgpig, |         uint3  tgpig, | ||||||
|         ushort tiisg) { |         ushort tiisg, | ||||||
|     const int r0 = tgpig.x; |         ushort sgitg) { | ||||||
|     const int rb = tgpig.y*N_MV_T_T; |     constexpr short NB = 32; | ||||||
|  |     constexpr short NF = 8; | ||||||
|  |  | ||||||
|  |     const int nb = args.ne00/NB; | ||||||
|  |  | ||||||
|  |     const int r0 = tgpig.x*NR0; | ||||||
|  |     const int r1 = tgpig.y; | ||||||
|     const int im = tgpig.z; |     const int im = tgpig.z; | ||||||
|  |  | ||||||
|     const uint i12 = im%args.ne12; |     const uint i12 = im%args.ne12; | ||||||
|     const uint i13 = im/args.ne12; |     const uint i13 = im/args.ne12; | ||||||
|  |  | ||||||
|     const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; |   //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||||||
|  |     const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13; | ||||||
|  |  | ||||||
|     device const T0 * x = (device const T0 *) (src0 + offset0); |   //device const T0 * x = (device const T0 *) (src0 + offset0); | ||||||
|  |     device const T1 * y = (device const T1 *) (src1 + offset1); | ||||||
|  |  | ||||||
|     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; |     // pointers to src0 rows | ||||||
|  |     device const T0 * ax [NR0]; | ||||||
|  |     FOR_UNROLL (short row = 0; row < NR0; ++row) { | ||||||
|  |         const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||||||
|  |  | ||||||
|     if (args.ne00 < 128) { |         ax[row] = (device const T0 *) ((device char *) src0 + offset0); | ||||||
|         for (int row = 0; row < N_MV_T_T; ++row) { |     } | ||||||
|             int r1 = rb + row; |  | ||||||
|             if (r1 >= args.ne11) { |  | ||||||
|                 break; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13; |     float sumf[NR0] = { 0.f }; | ||||||
|  |  | ||||||
|             device const T1 * y = (device const T1 *) (src1 + offset1); |     const short ix = tiisg/(NW/NF); | ||||||
|  |     const short il = tiisg%(NW/NF); | ||||||
|  |  | ||||||
|             float sumf = 0; |     const int ib0 = sgitg*NF + ix; | ||||||
|             for (int i = tiisg; i < args.ne00; i += 32) { |  | ||||||
|                 sumf += (T0) x[i] * (T1) y[i]; |  | ||||||
|             } |  | ||||||
|  |  | ||||||
|             float sum_all = simd_sum(sumf); |     T1 yl[NF]; | ||||||
|             if (tiisg == 0) { |  | ||||||
|                 dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; |     device const T1 * yb = y + (ib0*NB + il*NF); | ||||||
|             } |  | ||||||
|  |     for (int ib = ib0; ib < nb; ib += NSG*NF) { | ||||||
|  |         for (short i = 0; i < NF; ++i) { | ||||||
|  |             yl[i] = yb[i]; | ||||||
|         } |         } | ||||||
|     } else { |  | ||||||
|         device const T04 * x4 = (device const T04 *) x; |         for (short row = 0; row < NR0; row++) { | ||||||
|         for (int row = 0; row < N_MV_T_T; ++row) { |             device const T0 * xb = ax[row] + (ib*NB + il*NF); | ||||||
|             int r1 = rb + row; |  | ||||||
|             if (r1 >= args.ne11) { |             float sumq = 0.f; | ||||||
|                 break; |             FOR_UNROLL (short i = 0; i < NF; ++i) { | ||||||
|  |                 sumq += xb[i] * yl[i]; | ||||||
|             } |             } | ||||||
|  |  | ||||||
|             const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13; |             sumf[row] += sumq; | ||||||
|  |         } | ||||||
|  |  | ||||||
|             device const T1  * y  = (device const T1  *) (src1 + offset1); |         yb += NSG*NF*NW; | ||||||
|             device const T14 * y4 = (device const T14 *) y; |     } | ||||||
|  |  | ||||||
|             float sumf = 0; |     for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) { | ||||||
|             for (int i = tiisg; i < args.ne00/4; i += 32) { |         for (short row = 0; row < NR0; row++) { | ||||||
|                 sumf += dot((float4) x4[i], (float4) y4[i]); |             sumf[row] += ax[row][i] * y[i]; | ||||||
|             } |  | ||||||
|  |  | ||||||
|             float sum_all = simd_sum(sumf); |  | ||||||
|             if (tiisg == 0) { |  | ||||||
|                 for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); |  | ||||||
|                 dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; |  | ||||||
|             } |  | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
|  |     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||||||
|  |  | ||||||
|  |     helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); | ||||||
| } | } | ||||||
|  |  | ||||||
| template<typename T0, typename T04, typename T1, typename T14> | template<typename T0, typename T1, short NR0, short NSG, short NW> | ||||||
| kernel void kernel_mul_mv( | kernel void kernel_mul_mv_t_t( | ||||||
|         constant ggml_metal_kargs_mul_mv & args, |         constant ggml_metal_kargs_mul_mv & args, | ||||||
|         device const char * src0, |         device const char * src0, | ||||||
|         device const char * src1, |         device const char * src1, | ||||||
|         device       char * dst, |         device       char * dst, | ||||||
|  |         threadgroup  char * shmem [[threadgroup(0)]], | ||||||
|         uint3  tgpig[[threadgroup_position_in_grid]], |         uint3  tgpig[[threadgroup_position_in_grid]], | ||||||
|         ushort tiisg[[thread_index_in_simdgroup]]) { |         ushort tiisg[[thread_index_in_simdgroup]], | ||||||
|     kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>( |         ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||||||
|         args, |     kernel_mul_mv_t_t_impl<T0, T1, NR0, NSG, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||||||
|         src0, |  | ||||||
|         src1, |  | ||||||
|         dst, |  | ||||||
|         tgpig, |  | ||||||
|         tiisg); |  | ||||||
| } | } | ||||||
|  |  | ||||||
| typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t; | typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F, N_SG_F, N_SIMDWIDTH>) mul_mv_t_t; | ||||||
|  |  | ||||||
| template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t kernel_mul_mv<float,  float4,  float,  float4>; | template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F, N_SG_F, N_SIMDWIDTH>; | ||||||
| template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   float,  float4>; | template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  float, N_R0_F, N_SG_F, N_SIMDWIDTH>; | ||||||
| template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   half,   half4>; | template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,  half,  N_R0_F, N_SG_F, N_SIMDWIDTH>; | ||||||
| #if defined(GGML_METAL_HAS_BF16) | #if defined(GGML_METAL_HAS_BF16) | ||||||
| template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float,  float4>; | template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float,  N_R0_F, N_SG_F, N_SIMDWIDTH>; | ||||||
| template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>; | template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F, N_SG_F, N_SIMDWIDTH>; | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
|  | template<typename T0, typename T04, typename T1, typename T14, short NR0, short NSG, short NW, typename args_t> | ||||||
|  | void kernel_mul_mv_t_t_4_impl( | ||||||
|  |         args_t args, | ||||||
|  |         device const char * src0, | ||||||
|  |         device const char * src1, | ||||||
|  |         device       char * dst, | ||||||
|  |         threadgroup  char * shmem, | ||||||
|  |         uint3  tgpig, | ||||||
|  |         ushort tiisg, | ||||||
|  |         ushort sgitg) { | ||||||
|  |     constexpr short NB  = 32; | ||||||
|  |     constexpr short NF  = 16; | ||||||
|  |     constexpr short NF4 = NF/4; | ||||||
|  |  | ||||||
|  |     const int nb = args.ne00/NB; | ||||||
|  |  | ||||||
|  |     const int r0 = tgpig.x*NR0; | ||||||
|  |     const int r1 = tgpig.y; | ||||||
|  |     const int im = tgpig.z; | ||||||
|  |  | ||||||
|  |     const uint i12 = im%args.ne12; | ||||||
|  |     const uint i13 = im/args.ne12; | ||||||
|  |  | ||||||
|  |   //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||||||
|  |     const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13; | ||||||
|  |  | ||||||
|  |     device const T1  * y  = (device const T1  *) (src1 + offset1); | ||||||
|  |     device const T14 * y4 = (device const T14 *) (src1 + offset1); | ||||||
|  |  | ||||||
|  |     // pointers to src0 rows | ||||||
|  |     device const T0  * ax [NR0]; | ||||||
|  |     device const T04 * ax4[NR0]; | ||||||
|  |     FOR_UNROLL (short row = 0; row < NR0; ++row) { | ||||||
|  |         const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; | ||||||
|  |  | ||||||
|  |         ax [row] = (device const T0  *) ((device char *) src0 + offset0); | ||||||
|  |         ax4[row] = (device const T04 *) ((device char *) src0 + offset0); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     float sumf[NR0] = { 0.f }; | ||||||
|  |  | ||||||
|  |     const short ix = tiisg/(NW/NF); | ||||||
|  |     const short il = tiisg%(NW/NF); | ||||||
|  |  | ||||||
|  |     const int ib0 = sgitg*NF + ix; | ||||||
|  |  | ||||||
|  |     T14 yl4[NF4]; | ||||||
|  |  | ||||||
|  |     device const T14 * yb4 = y4 + (ib0*NB + il*NF)/4; | ||||||
|  |  | ||||||
|  |     for (int ib = ib0; ib < nb; ib += NSG*NF) { | ||||||
|  |         for (short i = 0; i < NF4; ++i) { | ||||||
|  |             yl4[i] = yb4[i]; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         for (short row = 0; row < NR0; row++) { | ||||||
|  |             device const T04 * xb4 = ax4[row] + (ib*NB + il*NF)/4; | ||||||
|  |  | ||||||
|  |             float sumq = 0.f; | ||||||
|  |             FOR_UNROLL (short i = 0; i < NF4; ++i) { | ||||||
|  |                 sumq += dot(float4(xb4[i]), float4(yl4[i])); | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             sumf[row] += sumq; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         yb4 += NSG*NF*NW/4; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) { | ||||||
|  |         for (short row = 0; row < NR0; row++) { | ||||||
|  |             sumf[row] += ax[row][i] * y[i]; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; | ||||||
|  |  | ||||||
|  |     helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template<typename T0, typename T04, typename T1, typename T14, short NR0, short NSG, short NW> | ||||||
|  | kernel void kernel_mul_mv_t_t_4( | ||||||
|  |         constant ggml_metal_kargs_mul_mv & args, | ||||||
|  |         device const char * src0, | ||||||
|  |         device const char * src1, | ||||||
|  |         device       char * dst, | ||||||
|  |         threadgroup  char * shmem [[threadgroup(0)]], | ||||||
|  |         uint3  tgpig[[threadgroup_position_in_grid]], | ||||||
|  |         ushort tiisg[[thread_index_in_simdgroup]], | ||||||
|  |         ushort sgitg[[simdgroup_index_in_threadgroup]]) { | ||||||
|  |     kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, NSG, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F, N_SG_F, N_SIMDWIDTH>) mul_mv_t_t_4; | ||||||
|  |  | ||||||
|  | template [[host_name("kernel_mul_mv_f32_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F, N_SG_F, N_SIMDWIDTH>; | ||||||
|  | template [[host_name("kernel_mul_mv_f16_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  float, float4, N_R0_F, N_SG_F, N_SIMDWIDTH>; | ||||||
|  | template [[host_name("kernel_mul_mv_f16_f16_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,  half4,  half,  half4,  N_R0_F, N_SG_F, N_SIMDWIDTH>; | ||||||
|  | #if defined(GGML_METAL_HAS_BF16) | ||||||
|  | template [[host_name("kernel_mul_mv_bf16_f32_4")]]  kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float,  float4,  N_R0_F, N_SG_F, N_SIMDWIDTH>; | ||||||
|  | template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F, N_SG_F, N_SIMDWIDTH>; | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | #define N_MV_T_T 4 | ||||||
|  |  | ||||||
| template<typename T04, typename T14, typename args_t> | template<typename T04, typename T14, typename args_t> | ||||||
| void kernel_mul_mv_c4_impl( | void kernel_mul_mv_c4_impl( | ||||||
|         args_t args, |         args_t args, | ||||||
| @@ -3562,112 +3669,10 @@ typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t; | |||||||
|  |  | ||||||
| template [[host_name("kernel_mul_mv_f32_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<float4,  float4>; | template [[host_name("kernel_mul_mv_f32_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<float4,  float4>; | ||||||
| template [[host_name("kernel_mul_mv_f16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4,   float4>; | template [[host_name("kernel_mul_mv_f16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4,   float4>; | ||||||
|  | template [[host_name("kernel_mul_mv_f16_f16_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4,   half4>; | ||||||
| #if defined(GGML_METAL_HAS_BF16) | #if defined(GGML_METAL_HAS_BF16) | ||||||
| template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>; | template [[host_name("kernel_mul_mv_bf16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>; | ||||||
| #endif | template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, bfloat4>; | ||||||
|  |  | ||||||
| template<typename T, typename T4> |  | ||||||
| kernel void kernel_mul_mv_1row( |  | ||||||
|         constant ggml_metal_kargs_mul_mv & args, |  | ||||||
|         device const char * src0, |  | ||||||
|         device const char * src1, |  | ||||||
|         device       char * dst, |  | ||||||
|         uint3  tgpig[[threadgroup_position_in_grid]], |  | ||||||
|         ushort tiisg[[thread_index_in_simdgroup]]) { |  | ||||||
|  |  | ||||||
|     const int r0 = tgpig.x; |  | ||||||
|     const int r1 = tgpig.y; |  | ||||||
|     const int im = tgpig.z; |  | ||||||
|  |  | ||||||
|     const uint i12 = im%args.ne12; |  | ||||||
|     const uint i13 = im/args.ne12; |  | ||||||
|  |  | ||||||
|     const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; |  | ||||||
|     const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13; |  | ||||||
|  |  | ||||||
|     device const T     * x = (device const T     *) (src0 + offset0); |  | ||||||
|     device const float * y = (device const float *) (src1 + offset1); |  | ||||||
|  |  | ||||||
|     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; |  | ||||||
|  |  | ||||||
|     float sumf = 0; |  | ||||||
|     if (args.ne00 < 128) { |  | ||||||
|         for (int i = tiisg; i < args.ne00; i += 32) { |  | ||||||
|             sumf += (float) x[i] * (float) y[i]; |  | ||||||
|         } |  | ||||||
|         float sum_all = simd_sum(sumf); |  | ||||||
|         if (tiisg == 0) { |  | ||||||
|             dst_f32[r0] = sum_all; |  | ||||||
|         } |  | ||||||
|     } else { |  | ||||||
|         device const T4     * x4 = (device const T4     *) x; |  | ||||||
|         device const float4 * y4 = (device const float4 *) y; |  | ||||||
|  |  | ||||||
|         for (int i = tiisg; i < args.ne00/4; i += 32) { |  | ||||||
|             sumf += dot((float4) x4[i], y4[i]); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         float sum_all = simd_sum(sumf); |  | ||||||
|  |  | ||||||
|         if (tiisg == 0) { |  | ||||||
|             for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]); |  | ||||||
|             dst_f32[r0] = sum_all; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t; |  | ||||||
|  |  | ||||||
| template [[host_name("kernel_mul_mv_f16_f32_1row")]]  kernel mul_mv_1row_t kernel_mul_mv_1row<half,   half4>; |  | ||||||
| #if defined(GGML_METAL_HAS_BF16) |  | ||||||
| template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>; |  | ||||||
| #endif |  | ||||||
|  |  | ||||||
| // Assumes row size (ne00) is a multiple of 4 |  | ||||||
| template<typename T, typename T4> |  | ||||||
| kernel void kernel_mul_mv_l4( |  | ||||||
|         constant ggml_metal_kargs_mul_mv & args, |  | ||||||
|         device const char * src0, |  | ||||||
|         device const char * src1, |  | ||||||
|         device       char * dst, |  | ||||||
|         uint3  tgpig[[threadgroup_position_in_grid]], |  | ||||||
|         ushort tiisg[[thread_index_in_simdgroup]]) { |  | ||||||
|  |  | ||||||
|     const int nrows = args.ne11; |  | ||||||
|     const int r0 = tgpig.x; |  | ||||||
|     const int im = tgpig.z; |  | ||||||
|  |  | ||||||
|     const uint i12 = im%args.ne12; |  | ||||||
|     const uint i13 = im/args.ne12; |  | ||||||
|  |  | ||||||
|     const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; |  | ||||||
|  |  | ||||||
|     device const T4 * x4 = (device const T4 *) (src0 + offset0); |  | ||||||
|  |  | ||||||
|     device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; |  | ||||||
|  |  | ||||||
|     for (int r1 = 0; r1 < nrows; ++r1) { |  | ||||||
|         const uint64_t offset1 = r1*args.nb11 + (i12   )*args.nb12 + (i13   )*args.nb13; |  | ||||||
|  |  | ||||||
|         device const float4 * y4 = (device const float4 *) (src1 + offset1); |  | ||||||
|  |  | ||||||
|         float sumf = 0; |  | ||||||
|         for (int i = tiisg; i < args.ne00/4; i += 32) { |  | ||||||
|             sumf += dot((float4) x4[i], y4[i]); |  | ||||||
|         } |  | ||||||
|  |  | ||||||
|         float sum_all = simd_sum(sumf); |  | ||||||
|         if (tiisg == 0) { |  | ||||||
|             dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  |  | ||||||
| typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t; |  | ||||||
|  |  | ||||||
| template [[host_name("kernel_mul_mv_f16_f32_l4")]]  kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>; |  | ||||||
| #if defined(GGML_METAL_HAS_BF16) |  | ||||||
| template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>; |  | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| static float rope_yarn_ramp(const float low, const float high, const int i0) { | static float rope_yarn_ramp(const float low, const float high, const int i0) { | ||||||
| @@ -8314,7 +8319,7 @@ void mmv_fn( | |||||||
|     impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); |     impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); | ||||||
| } | } | ||||||
|  |  | ||||||
| typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t; | typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, N_SG_F, N_SIMDWIDTH, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t; | ||||||
|  |  | ||||||
| template<mul_mv_impl_fn_t impl_fn> | template<mul_mv_impl_fn_t impl_fn> | ||||||
| kernel void kernel_mul_mv_id( | kernel void kernel_mul_mv_id( | ||||||
| @@ -8379,13 +8384,21 @@ kernel void kernel_mul_mv_id( | |||||||
|         sgitg); |         sgitg); | ||||||
| } | } | ||||||
|  |  | ||||||
| typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t; | typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SG_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_t; | ||||||
|  |  | ||||||
| template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>; | typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SG_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_4_t; | ||||||
| template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>; |  | ||||||
|  | template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SG_F, N_SIMDWIDTH>>>; | ||||||
|  | template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half,  float, N_R0_F, N_SG_F, N_SIMDWIDTH>>>; | ||||||
| #if defined(GGML_METAL_HAS_BF16) | #if defined(GGML_METAL_HAS_BF16) | ||||||
| template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>; | template [[host_name("kernel_mul_mv_id_bf16_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F, N_SG_F, N_SIMDWIDTH>>>; | ||||||
| #endif | #endif | ||||||
|  | template [[host_name("kernel_mul_mv_id_f32_f32_4")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SG_F, N_SIMDWIDTH>>>; | ||||||
|  | template [[host_name("kernel_mul_mv_id_f16_f32_4")]]   kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half,  half4,  float, float4, N_R0_F, N_SG_F, N_SIMDWIDTH>>>; | ||||||
|  | #if defined(GGML_METAL_HAS_BF16) | ||||||
|  | template [[host_name("kernel_mul_mv_id_bf16_f32_4")]]  kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F, N_SG_F, N_SIMDWIDTH>>>; | ||||||
|  | #endif | ||||||
|  |  | ||||||
| template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH>>>; | template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH>>>; | ||||||
|  |  | ||||||
| template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH>>>; | template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH>>>; | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov