@@ -193,16 +193,16 @@ enum ggml_metal_kernel_type {
   //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,     // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
   //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
-    GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
+    GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
+    GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
     GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
-    GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
-    GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
     GGML_METAL_KERNEL_TYPE_CONCAT,
     GGML_METAL_KERNEL_TYPE_SQR,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
@@ -651,14 +651,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
   //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,                   cpy_f16_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,                  cpy_f32_q8_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,                  cpy_f32_q4_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,                  cpy_f32_q4_1,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,                  cpy_f32_q5_0,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,                  cpy_f32_q5_1,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,                cpy_f32_iq4_nl,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32,                   cpy_f16_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT,                        concat,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR,                           sqr,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
@@ -810,8 +810,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
                 switch (op->src[0]->type) {
                     case GGML_TYPE_F32:
                         switch (op->type) {
-                           case GGML_TYPE_F16:
                            case GGML_TYPE_F32:
+                           case GGML_TYPE_F16:
                            case GGML_TYPE_Q8_0:
                            case GGML_TYPE_Q4_0:
                            case GGML_TYPE_Q4_1:
@@ -824,8 +824,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
                         }
                     case GGML_TYPE_F16:
                         switch (op->type) {
-                           case GGML_TYPE_F16:
                            case GGML_TYPE_F32:
+                           case GGML_TYPE_F16:
                                 return true;
                            default:
                                 return false;
@@ -837,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
-                return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
+                return op->ne[3] == 1;
             }
         default:
             return false;
@@ -2775,8 +2775,8 @@ static enum ggml_status ggml_metal_graph_compute(
                                     GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);

                                     switch (dstt) {
-                                        case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline;  break;
                                         case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
+                                        case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
                                         case GGML_TYPE_Q8_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
                                         case GGML_TYPE_Q4_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
                                         case GGML_TYPE_Q4_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@@ -2789,8 +2789,8 @@ static enum ggml_status ggml_metal_graph_compute(
                             case GGML_TYPE_F16:
                                 {
                                     switch (dstt) {
-                                        case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
                                         case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+                                        case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
                                         default: GGML_ASSERT(false && "not implemented");
                                     };
                                 } break;

@@ -1219,9 +1219,10 @@ kernel void kernel_mul_mv_q8_0_f32(
     kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }

-#define N_F32_F32 4
+#define N_MV_T_T 4

-void kernel_mul_mv_f32_f32_impl(
+template<typename T0, typename T04, typename T1, typename T14>
+void kernel_mul_mv_impl(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1243,9 +1244,8 @@ void kernel_mul_mv_f32_f32_impl(
                    uint      r3,
                    uint3     tgpig,
                    uint      tiisg) {
-
     const int64_t r0 = tgpig.x;
-    const int64_t rb = tgpig.y*N_F32_F32;
+    const int64_t rb = tgpig.y*N_MV_T_T;
     const int64_t im = tgpig.z;

     const uint i12 = im%ne12;
@@ -1253,20 +1253,20 @@ void kernel_mul_mv_f32_f32_impl(

     const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;

-    device const float * x = (device const float *) (src0 + offset0);
+    device const T0 * x = (device const T0 *) (src0 + offset0);

     if (ne00 < 128) {
-        for (int row = 0; row < N_F32_F32; ++row) {
+        for (int row = 0; row < N_MV_T_T; ++row) {
             int r1 = rb + row;
             if (r1 >= ne11) {
                 break;
             }

-            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+            device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);

             float sumf = 0;
             for (int i = tiisg; i < ne00; i += 32) {
-                sumf += (float) x[i] * (float) y[i];
+                sumf += (T0) x[i] * (T1) y[i];
             }

             float all_sum = simd_sum(sumf);
@@ -1275,32 +1275,32 @@ void kernel_mul_mv_f32_f32_impl(
             }
         }
     } else {
-        device const float4 * x4 = (device const float4 *)x;
-        for (int row = 0; row < N_F32_F32; ++row) {
+        device const T04 * x4 = (device const T04 *) x;
+        for (int row = 0; row < N_MV_T_T; ++row) {
             int r1 = rb + row;
             if (r1 >= ne11) {
                 break;
             }

-            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
-            device const float4 * y4 = (device const float4 *) y;
+            device const T1  * y  = (device const T1  *) (src1 + r1*nb11 + im*nb12);
+            device const T14 * y4 = (device const T14 *) y;

             float sumf = 0;
             for (int i = tiisg; i < ne00/4; i += 32) {
-                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+                for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
             }

             float all_sum = simd_sum(sumf);
             if (tiisg == 0) {
-                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
                 dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
             }
         }
     }
 }

-[[host_name("kernel_mul_mv_f32_f32")]]
-kernel void kernel_mul_mv_f32_f32(
+template<typename T0, typename T04, typename T1, typename T14>
+kernel void kernel_mul_mv(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1322,90 +1322,38 @@ kernel void kernel_mul_mv_f32_f32(
         constant   uint    & r3,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]]) {
-    kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+    kernel_mul_mv_impl<T0, T04, T1, T14>(
+        src0,
+        src1,
+        dst,
+        ne00,
+        ne01,
+        ne02,
+        nb00,
+        nb01,
+        nb02,
+        ne10,
+        ne11,
+        ne12,
+        nb10,
+        nb11,
+        nb12,
+        ne0,
+        ne1,
+        r2,
+        r3,
+        tgpig,
+        tiisg);
 }

-#define N_F16_F16 4
+typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;

-kernel void kernel_mul_mv_f16_f16(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]]) {
+template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t kernel_mul_mv<float,  float4,  float,  float4>;
+template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   float,  float4>;
+template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   half,   half4>;

-    const int64_t r0 = tgpig.x;
-    const int64_t rb = tgpig.y*N_F16_F16;
-    const int64_t im = tgpig.z;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
-
-    device const half * x = (device const half *) (src0 + offset0);
-
-    if (ne00 < 128) {
-        for (int row = 0; row < N_F16_F16; ++row) {
-            int r1 = rb + row;
-            if (r1 >= ne11) {
-                break;
-            }
-
-            device const half * y = (device const half *) (src1 + r1*nb11 + im*nb12);
-
-            float sumf = 0;
-            for (int i = tiisg; i < ne00; i += 32) {
-                sumf += (half) x[i] * (half) y[i];
-            }
-
-            float all_sum = simd_sum(sumf);
-            if (tiisg == 0) {
-                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
-            }
-        }
-    } else {
-        device const half4 * x4 = (device const half4 *)x;
-        for (int row = 0; row < N_F16_F16; ++row) {
-            int r1 = rb + row;
-            if (r1 >= ne11) {
-                break;
-            }
-
-            device const half  * y  = (device const half  *) (src1 + r1*nb11 + im*nb12);
-            device const half4 * y4 = (device const half4 *) y;
-
-            float sumf = 0;
-            for (int i = tiisg; i < ne00/4; i += 32) {
-                for (int k = 0; k < 4; ++k) sumf += (half) x4[i][k] * y4[i][k];
-            }
-
-            float all_sum = simd_sum(sumf);
-            if (tiisg == 0) {
-                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (half) x[i] * y[i];
-                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
-            }
-        }
-    }
-}
-
-void kernel_mul_mv_f16_f32_1row_impl(
+template<typename T, typename T4>
+kernel void kernel_mul_mv_1row(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1437,7 +1385,7 @@ void kernel_mul_mv_f16_f32_1row_impl(

     const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;

-    device const half  * x = (device const half  *) (src0 + offset0);
+    device const T     * x = (device const T     *) (src0 + offset0);
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

     float sumf = 0;
@@ -1450,153 +1398,29 @@ void kernel_mul_mv_f16_f32_1row_impl(
             dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
         }
     } else {
-        device const half4  * x4 = (device const half4  *) x;
-        device const float4 * y4 = (device const float4 *) y;
-        for (int i = tiisg; i < ne00/4; i += 32) {
-            for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
-        }
-        float all_sum = simd_sum(sumf);
-        if (tiisg == 0) {
-            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
-            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
-        }
-    }
-}
-
-[[host_name("kernel_mul_mv_f16_f32_1row")]]
-kernel void kernel_mul_mv_f16_f32_1row(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint  tiisg[[thread_index_in_simdgroup]]) {
-    kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
-}
-
-#define N_F16_F32 4
-
-void kernel_mul_mv_f16_f32_impl(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-                   int64_t   ne00,
-                   int64_t   ne01,
-                   int64_t   ne02,
-                  uint64_t   nb00,
-                  uint64_t   nb01,
-                  uint64_t   nb02,
-                   int64_t   ne10,
-                   int64_t   ne11,
-                   int64_t   ne12,
-                  uint64_t   nb10,
-                  uint64_t   nb11,
-                  uint64_t   nb12,
-                   int64_t   ne0,
-                   int64_t   ne1,
-                   uint      r2,
-                   uint      r3,
-                   uint3     tgpig,
-                   uint      tiisg) {
-
-    const int64_t r0 = tgpig.x;
-    const int64_t rb = tgpig.y*N_F16_F32;
-    const int64_t im = tgpig.z;
-
-    const uint i12 = im%ne12;
-    const uint i13 = im/ne12;
-
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
-
-    device const half * x = (device const half *) (src0 + offset0);
-
-    if (ne00 < 128) {
-        for (int row = 0; row < N_F16_F32; ++row) {
-            int r1 = rb + row;
-            if (r1 >= ne11) {
-                break;
-            }
-
-            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
-
-            float sumf = 0;
-            for (int i = tiisg; i < ne00; i += 32) {
-                sumf += (float) x[i] * (float) y[i];
-            }
-
-            float all_sum = simd_sum(sumf);
-            if (tiisg == 0) {
-                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
-            }
-        }
-    } else {
-        device const half4 * x4 = (device const half4 *)x;
-        for (int row = 0; row < N_F16_F32; ++row) {
-            int r1 = rb + row;
-            if (r1 >= ne11) {
-                break;
-            }
-
-            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
+        device const T4     * x4 = (device const T4     *) x;
         device const float4 * y4 = (device const float4 *) y;

-            float sumf = 0;
         for (int i = tiisg; i < ne00/4; i += 32) {
-                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
         }

         float all_sum = simd_sum(sumf);

         if (tiisg == 0) {
-                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
             dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
         }
     }
-    }
 }

-[[host_name("kernel_mul_mv_f16_f32")]]
-kernel void kernel_mul_mv_f16_f32(
-        device const  char * src0,
-        device const  char * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant   int64_t & ne11,
-        constant   int64_t & ne12,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb12,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   uint    & r2,
-        constant   uint    & r3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
-    kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
-}
+typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
+
+template [[host_name("kernel_mul_mv_f16_f32_1row")]]  kernel mul_mv_1row_t kernel_mul_mv_1row<half,   half4>;

 // Assumes row size (ne00) is a multiple of 4
-kernel void kernel_mul_mv_f16_f32_l4(
+template<typename T, typename T4>
+kernel void kernel_mul_mv_l4(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1628,14 +1452,14 @@ kernel void kernel_mul_mv_f16_f32_l4(

     const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;

-    device const half4 * x4 = (device const half4 *) (src0 + offset0);
+    device const T4 * x4 = (device const T4 *) (src0 + offset0);

     for (int r1 = 0; r1 < nrows; ++r1) {
         device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);

         float sumf = 0;
         for (int i = tiisg; i < ne00/4; i += 32) {
-            for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
         }

         float all_sum = simd_sum(sumf);
@@ -1645,6 +1469,10 @@ kernel void kernel_mul_mv_f16_f32_l4(
     }
 }

+typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
+
+template [[host_name("kernel_mul_mv_f16_f32_l4")]]  kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
+
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
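Note on the pattern used throughout this change: a single templated Metal kernel is written once, and each concrete variant is produced by an explicit instantiation of the form template [[host_name("...")]] kernel ..., so the host keeps looking up the same exported function names as before. The following is a minimal, self-contained sketch of the idiom only; the kernel, its parameters, and the host_name strings here are hypothetical and are not part of this diff.

    #include <metal_stdlib>
    using namespace metal;

    // Toy copy-with-cast kernel, templated on the source and destination element
    // types. The buffers are passed as void * so that every specialization has the
    // same function type and a single typedef can describe all of them.
    template<typename TSrc, typename TDst>
    kernel void kernel_cast_copy(
            device const void * src [[buffer(0)]],
            device       void * dst [[buffer(1)]],
            uint tpig [[thread_position_in_grid]]) {
        device const TSrc * x = (device const TSrc *) src;
        device       TDst * y = (device       TDst *) dst;
        y[tpig] = (TDst) x[tpig];
    }

    typedef decltype(kernel_cast_copy<float, float>) cast_copy_t;

    // Each instantiation exports its own Metal function name via [[host_name]],
    // which the host fetches with newFunctionWithName: exactly as it did for the
    // hand-written variants.
    template [[host_name("kernel_cast_copy_f32_f16")]] kernel cast_copy_t kernel_cast_copy<float, half>;
    template [[host_name("kernel_cast_copy_f16_f32")]] kernel cast_copy_t kernel_cast_copy<half,  float>;

Sharing one function type is also why kernel_cpy takes void * for src0/dst below: typedef decltype(kernel_cpy<float, float>) has to match every instantiation that reuses it.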
@@ -2765,9 +2593,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
 //template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;

-kernel void kernel_cpy_f16_f16(
-        device  const half * src0,
-        device        half * dst,
+template<typename T0, typename T1>
+kernel void kernel_cpy(
+        device  const void * src0,
+        device        void * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
@@ -2798,138 +2627,20 @@ kernel void kernel_cpy_f16_f16(
     const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

-    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+    device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-        dst_data[i00] = src[0];
+        device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+        dst_data[i00] = (T1) src[0];
     }
 }

-kernel void kernel_cpy_f16_f32(
-        device  const half * src0,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
-    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-        dst_data[i00] = src[0];
-    }
-}
-
-kernel void kernel_cpy_f32_f16(
-        device const float * src0,
-        device        half * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
-    device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
-        dst_data[i00] = src[0];
-    }
-}
-
-kernel void kernel_cpy_f32_f32(
-        device const float * src0,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
-    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-
-    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-
-        dst_data[i00] = src[0];
-    }
-}
+typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
+
+template [[host_name("kernel_cpy_f32_f32")]]  kernel kernel_cpy_t kernel_cpy<float,  float>;
+template [[host_name("kernel_cpy_f32_f16")]]  kernel kernel_cpy_t kernel_cpy<float,  half>;
+template [[host_name("kernel_cpy_f16_f16")]]  kernel kernel_cpy_t kernel_cpy<half,   half>;
+template [[host_name("kernel_cpy_f16_f32")]]  kernel kernel_cpy_t kernel_cpy<half,   float>;

 kernel void kernel_cpy_f32_q8_0(
         device const float * src0,
@@ -5730,9 +5441,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
 }

 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
-kernel void kernel_get_rows(
+kernel void kernel_get_rows_q(
         device const  void * src0,
-        device const  char * src1,
+        device const  void * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
@@ -5745,27 +5456,24 @@ kernel void kernel_get_rows(
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
-    //const int64_t i = tgpig;
-    //const int64_t r = ((device int32_t *) src1)[i];
-
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;

-    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];

     const int64_t i02 = i11;

     for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
         float4x4 temp;
-        dequantize_func(
-            ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+        dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
         *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
     }
 }

-kernel void kernel_get_rows_f32(
+template<typename T>
+kernel void kernel_get_rows_f(
         device const  void * src0,
-        device const  char * src1,
+        device const  void * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
@@ -5781,47 +5489,19 @@ kernel void kernel_get_rows_f32(
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;

-    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];

     const int64_t i02 = i11;

     for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
-            ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
-    }
-}
-
-kernel void kernel_get_rows_f16(
-        device const  void * src0,
-        device const  char * src1,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant   int64_t & ne10,
-        constant  uint64_t & nb10,
-        constant  uint64_t & nb11,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        uint3                tgpig[[threadgroup_position_in_grid]],
-        uint                 tiitg[[thread_index_in_threadgroup]],
-        uint3                tptg [[threads_per_threadgroup]]) {
-    const int64_t i10 = tgpig.x;
-    const int64_t i11 = tgpig.y;
-
-    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
-
-    const int64_t i02 = i11;
-
-    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-        ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
-            ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+        ((      device float *) ((      device char *)  dst + i11*nb2  + i10*nb1))[ind] =
+        ((const device T     *) ((const device char *) src0 + i02*nb02 +  r*nb01))[ind];
     }
 }

 kernel void kernel_get_rows_i32(
         device const  void * src0,
-        device const  char * src1,
+        device const  void * src1,
         device     int32_t * dst,
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
@@ -5837,13 +5517,13 @@ kernel void kernel_get_rows_i32(
     const int64_t i10 = tgpig.x;
     const int64_t i11 = tgpig.y;

-    const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];

     const int64_t i02 = i11;

     for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-        ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
-            ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+        ((      device int32_t *) ((      device char *) dst  + i11*nb2 + i10*nb1))[ind] =
+        ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
     }
 }

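Note: get_rows is split along the same lines, into kernel_get_rows_q for the quantized block types (driven by dequantize_func) and the templated kernel_get_rows_f, where the only per-type work is a cast while gathering the rows selected by int32 indices in src1. The sketch below illustrates that float/half gather idea in isolation under simplifying assumptions (flat row-major source, no nb* strides, made-up names); it is not the kernel from this diff.

    #include <metal_stdlib>
    using namespace metal;

    // dst row i = src0 row rows[i], cast to float. One threadgroup per output row,
    // with the threads of the group striding across the row.
    template<typename T>
    kernel void gather_rows(
            device const void    * src0 [[buffer(0)]],  // row-major matrix of T
            device const int32_t * rows [[buffer(1)]],  // indices of the rows to gather
            device       float   * dst  [[buffer(2)]],  // gathered rows, converted to float
            constant     int64_t & ne00 [[buffer(3)]],  // row length in elements
            uint tgpig [[threadgroup_position_in_grid]],
            uint tiitg [[thread_index_in_threadgroup]],
            uint ntg   [[threads_per_threadgroup]]) {
        const int64_t r = rows[tgpig];

        device const T * src_row = (device const T *) src0 + r*ne00;
        device   float * dst_row = dst + tgpig*ne00;

        for (int64_t k = tiitg; k < ne00; k += ntg) {
            dst_row[k] = (float) src_row[k];
        }
    }

    typedef decltype(gather_rows<float>) gather_rows_t;

    template [[host_name("gather_rows_f32")]] kernel gather_rows_t gather_rows<float>;
    template [[host_name("gather_rows_f16")]] kernel gather_rows_t gather_rows<half>;

The quantized variant differs only in that it dequantizes 16-element blocks through dequantize_func instead of casting element by element.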
@@ -5860,8 +5540,8 @@ kernel void kernel_get_rows_i32(
 #define SG_MAT_ROW 8

 // each block_q contains 16*nl weights
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-void kernel_mul_mm_impl(device const  uchar * src0,
+template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_mul_mm(device const  uchar * src0,
                           device const  uchar * src1,
                           device        float * dst,
                           constant    int64_t & ne00,
@@ -5881,7 +5561,7 @@ void kernel_mul_mm_impl(device const  uchar * src0,
                           uint                  tiitg[[thread_index_in_threadgroup]],
                           uint                  sgitg[[simdgroup_index_in_threadgroup]]) {

-    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
+    threadgroup T     * sa = (threadgroup T     *)(shared_memory);
     threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

     const uint r0 = tgpig.y;
@@ -5896,7 +5576,7 @@ void kernel_mul_mm_impl(device const  uchar * src0,
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

-    simdgroup_half8x8  ma[4];
+    simdgroup_T8x8     ma[4];
     simdgroup_float8x8 mb[2];
     simdgroup_float8x8 c_res[8];
     for (int i = 0; i < 8; i++){
@@ -5919,7 +5599,7 @@ void kernel_mul_mm_impl(device const  uchar * src0,

     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
         // load data and store to threadgroup memory
-        half4x4 temp_a;
+        T4x4 temp_a;
         dequantize_func(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);

@@ -5939,7 +5619,7 @@ void kernel_mul_mm_impl(device const  uchar * src0,
         threadgroup_barrier(mem_flags::mem_threadgroup);

         // load matrices from threadgroup memory and conduct outer products
-        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+        threadgroup T     * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));

         #pragma unroll(4)
@@ -6115,48 +5795,6 @@ void kernel_mul_mm_id_impl(
     }
 }

-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-kernel void kernel_mul_mm(device const  uchar * src0,
-                          device const  uchar * src1,
-                          device        float * dst,
-                          constant    int64_t & ne00,
-                          constant    int64_t & ne02,
-                          constant   uint64_t & nb01,
-                          constant   uint64_t & nb02,
-                          constant    int64_t & ne12,
-                          constant   uint64_t & nb10,
-                          constant   uint64_t & nb11,
-                          constant   uint64_t & nb12,
-                          constant    int64_t & ne0,
-                          constant    int64_t & ne1,
-                          constant       uint & r2,
-                          constant       uint & r3,
-                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
-                          uint3                 tgpig[[threadgroup_position_in_grid]],
-                          uint                  tiitg[[thread_index_in_threadgroup]],
-                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mm_impl<block_q, nl, dequantize_func>(
-        src0,
-        src1,
-        dst,
-        ne00,
-        ne02,
-        nb01,
-        nb02,
-        ne12,
-        nb10,
-        nb11,
-        nb12,
-        ne0,
-        ne1,
-        r2,
-        r3,
-        shared_memory,
-        tgpig,
-        tiitg,
-        sgitg);
-}
-
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm_id(
         device const   uchar * src0s,
| @@ -6237,69 +5875,60 @@ kernel void kernel_mul_mm_id( | |||||||
| // get rows | // get rows | ||||||
| // | // | ||||||
|  |  | ||||||
| typedef void (get_rows_t)( | typedef decltype(kernel_get_rows_f<float>) get_rows_f_t; | ||||||
|         device const void * src0, |  | ||||||
|         device const char * src1, |  | ||||||
|         device      float * dst, |  | ||||||
|         constant  int64_t & ne00, |  | ||||||
|         constant uint64_t & nb01, |  | ||||||
|         constant uint64_t & nb02, |  | ||||||
|         constant  int64_t & ne10, |  | ||||||
|         constant uint64_t & nb10, |  | ||||||
|         constant uint64_t & nb11, |  | ||||||
|         constant uint64_t & nb1, |  | ||||||
|         constant uint64_t & nb2, |  | ||||||
|         uint3, uint, uint3); |  | ||||||
|  |  | ||||||
| //template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>; | template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_f_t kernel_get_rows_f<float>; | ||||||
| //template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>; | template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_f_t kernel_get_rows_f<half>; | ||||||
| template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>; |  | ||||||
| template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>; | typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t; | ||||||
| template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>; |  | ||||||
| template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>; | template [[host_name("kernel_get_rows_q4_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_0,    2, dequantize_q4_0>; | ||||||
| template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>; | template [[host_name("kernel_get_rows_q4_1")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_1,    2, dequantize_q4_1>; | ||||||
| template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>; | template [[host_name("kernel_get_rows_q5_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_0,    2, dequantize_q5_0>; | ||||||
| template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>; | template [[host_name("kernel_get_rows_q5_1")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_1,    2, dequantize_q5_1>; | ||||||
| template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>; | template [[host_name("kernel_get_rows_q8_0")]]    kernel get_rows_q_t kernel_get_rows_q<block_q8_0,    2, dequantize_q8_0>; | ||||||
| template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>; | template [[host_name("kernel_get_rows_q2_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q2_K,    QK_NL, dequantize_q2_K>; | ||||||
| template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>; | template [[host_name("kernel_get_rows_q3_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q3_K,    QK_NL, dequantize_q3_K>; | ||||||
| template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_t kernel_get_rows<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; | template [[host_name("kernel_get_rows_q4_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q4_K,    QK_NL, dequantize_q4_K>; | ||||||
| template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_t kernel_get_rows<block_iq2_xs,  QK_NL, dequantize_iq2_xs>; | template [[host_name("kernel_get_rows_q5_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q5_K,    QK_NL, dequantize_q5_K>; | ||||||
| template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_rows<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; | template [[host_name("kernel_get_rows_q6_K")]]    kernel get_rows_q_t kernel_get_rows_q<block_q6_K,    QK_NL, dequantize_q6_K>; | ||||||
| template [[host_name("kernel_get_rows_iq3_s")]]   kernel get_rows_t kernel_get_rows<block_iq3_s,   QK_NL, dequantize_iq3_s>; | template [[host_name("kernel_get_rows_iq2_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; | ||||||
| template [[host_name("kernel_get_rows_iq2_s")]]   kernel get_rows_t kernel_get_rows<block_iq2_s,   QK_NL, dequantize_iq2_s>; | template [[host_name("kernel_get_rows_iq2_xs")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq2_xs,  QK_NL, dequantize_iq2_xs>; | ||||||
| template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_t kernel_get_rows<block_iq1_s,   QK_NL, dequantize_iq1_s>; | template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_q_t kernel_get_rows_q<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; | ||||||
| template [[host_name("kernel_get_rows_iq1_m")]]   kernel get_rows_t kernel_get_rows<block_iq1_m,   QK_NL, dequantize_iq1_m>; | template [[host_name("kernel_get_rows_iq3_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq3_s,   QK_NL, dequantize_iq3_s>; | ||||||
| template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_t kernel_get_rows<block_iq4_nl,  2,     dequantize_iq4_nl>; | template [[host_name("kernel_get_rows_iq2_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq2_s,   QK_NL, dequantize_iq2_s>; | ||||||
| template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_t kernel_get_rows<block_iq4_xs,  QK_NL, dequantize_iq4_xs>; | template [[host_name("kernel_get_rows_iq1_s")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq1_s,   QK_NL, dequantize_iq1_s>; | ||||||
|  | template [[host_name("kernel_get_rows_iq1_m")]]   kernel get_rows_q_t kernel_get_rows_q<block_iq1_m,   QK_NL, dequantize_iq1_m>; | ||||||
|  | template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl,  2,     dequantize_iq4_nl>; | ||||||
|  | template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs,  QK_NL, dequantize_iq4_xs>; | ||||||
|  |  | ||||||
| // | // | ||||||
| // matrix-matrix multiplication | // matrix-matrix multiplication | ||||||
| // | // | ||||||
|  |  | ||||||
| typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t; | typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t; | ||||||
|  |  | ||||||
| template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<float4x4,      1,     dequantize_f32>; | template [[host_name("kernel_mul_mm_f32_f32")]]     kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   float4x4,      1,     dequantize_f32>; | ||||||
| template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half4x4,       1,     dequantize_f16>; | template [[host_name("kernel_mul_mm_f16_f32")]]     kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   half4x4,       1,     dequantize_f16>; | ||||||
| template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_0,    2,     dequantize_q4_0>; | template [[host_name("kernel_mul_mm_q4_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_0,    2,     dequantize_q4_0>; | ||||||
| template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_1,    2,     dequantize_q4_1>; | template [[host_name("kernel_mul_mm_q4_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_1,    2,     dequantize_q4_1>; | ||||||
| template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_0,    2,     dequantize_q5_0>; | template [[host_name("kernel_mul_mm_q5_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_0,    2,     dequantize_q5_0>; | ||||||
| template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_1,    2,     dequantize_q5_1>; | template [[host_name("kernel_mul_mm_q5_1_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_1,    2,     dequantize_q5_1>; | ||||||
| template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q8_0,    2,     dequantize_q8_0>; | template [[host_name("kernel_mul_mm_q8_0_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q8_0,    2,     dequantize_q8_0>; | ||||||
| template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q2_K,    QK_NL, dequantize_q2_K>; | template [[host_name("kernel_mul_mm_q2_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q2_K,    QK_NL, dequantize_q2_K>; | ||||||
| template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q3_K,    QK_NL, dequantize_q3_K>; | template [[host_name("kernel_mul_mm_q3_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q3_K,    QK_NL, dequantize_q3_K>; | ||||||
| template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q4_K,    QK_NL, dequantize_q4_K>; | template [[host_name("kernel_mul_mm_q4_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q4_K,    QK_NL, dequantize_q4_K>; | ||||||
| template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q5_K,    QK_NL, dequantize_q5_K>; | template [[host_name("kernel_mul_mm_q5_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q5_K,    QK_NL, dequantize_q5_K>; | ||||||
| template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<block_q6_K,    QK_NL, dequantize_q6_K>; | template [[host_name("kernel_mul_mm_q6_K_f32")]]    kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_q6_K,    QK_NL, dequantize_q6_K>; | ||||||
| template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; | template [[host_name("kernel_mul_mm_iq2_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_xxs, QK_NL, dequantize_iq2_xxs>; | ||||||
| template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq2_xs,  QK_NL, dequantize_iq2_xs>; | template [[host_name("kernel_mul_mm_iq2_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_xs,  QK_NL, dequantize_iq2_xs>; | ||||||
| template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; | template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq3_xxs, QK_NL, dequantize_iq3_xxs>; | ||||||
| template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq3_s,   QK_NL, dequantize_iq3_s>; | template [[host_name("kernel_mul_mm_iq3_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq3_s,   QK_NL, dequantize_iq3_s>; | ||||||
| template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq2_s,   QK_NL, dequantize_iq2_s>; | template [[host_name("kernel_mul_mm_iq2_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq2_s,   QK_NL, dequantize_iq2_s>; | ||||||
| template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_s,   QK_NL, dequantize_iq1_s>; | template [[host_name("kernel_mul_mm_iq1_s_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq1_s,   QK_NL, dequantize_iq1_s>; | ||||||
| template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm<block_iq1_m,   QK_NL, dequantize_iq1_m>; | template [[host_name("kernel_mul_mm_iq1_m_f32")]]   kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq1_m,   QK_NL, dequantize_iq1_m>; | ||||||
| template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_nl,  2,     dequantize_iq4_nl>; | template [[host_name("kernel_mul_mm_iq4_nl_f32")]]  kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq4_nl,  2,     dequantize_iq4_nl>; | ||||||
| template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<block_iq4_xs,  QK_NL, dequantize_iq4_xs>; | template [[host_name("kernel_mul_mm_iq4_xs_f32")]]  kernel mat_mm_t kernel_mul_mm<half,   half4x4,   simdgroup_half8x8,   block_iq4_xs,  QK_NL, dequantize_iq4_xs>; | ||||||
|  |  | ||||||
| // | // | ||||||
| // indirect matrix-matrix multiplication | // indirect matrix-matrix multiplication | ||||||
| @@ -6436,7 +6065,7 @@ void mmv_fn( | |||||||
|     impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); |     impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); | ||||||
| } | } | ||||||
|  |  | ||||||
| typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t; | typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t; | ||||||
|  |  | ||||||
| template<mul_mv_impl_fn_t impl_fn> | template<mul_mv_impl_fn_t impl_fn> | ||||||
| kernel void kernel_mul_mv_id( | kernel void kernel_mul_mv_id( | ||||||
| @@ -6514,10 +6143,10 @@ kernel void kernel_mul_mv_id( | |||||||
|         sgitg); |         sgitg); | ||||||
| } | } | ||||||
|  |  | ||||||
| typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t; | typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t; | ||||||
|  |  | ||||||
| template [[host_name("kernel_mul_mv_id_f32_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>; | template [[host_name("kernel_mul_mv_id_f32_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>; | ||||||
| template [[host_name("kernel_mul_mv_id_f16_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>; | template [[host_name("kernel_mul_mv_id_f16_f32")]]     kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>; | ||||||
| template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>; | template [[host_name("kernel_mul_mv_id_q8_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>; | ||||||
| template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; | template [[host_name("kernel_mul_mv_id_q4_0_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; | ||||||
| template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; | template [[host_name("kernel_mul_mv_id_q4_1_f32")]]    kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>; | ||||||