	metal : optimize ggml_mul_mat_id (faster Mixtral PP) (#4725)
* ggml : disable fast-math for Metal (cmake build only) ggml-ci
* metal : fix Metal API debug warnings
* cmake : add -fno-inline for Metal build (#4545)
* metal : fix API debug warnings
* metal : fix compile warnings
* metal : use uint64_t for strides
* cmake : rename option to LLAMA_METAL_SHADER_DEBUG
* metal : fix mat-vec Q8_0 kernel for BS > 1
* metal : normalize mat-vec kernel signatures
* cmake : respect LLAMA_QKK_64 option
* metal : fix mat-vec Q4_K kernel for QK_K == 64
* metal : optimizing ggml_mul_mat_id (wip)
* metal : minor fix
* metal : opt mul_mm_id
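For context, ggml_mul_mat_id is the indirect (mixture-of-experts) matrix multiply: an ids tensor selects, per row of src1, which expert matrix to apply. A minimal scalar sketch of the semantics, assuming contiguous row-major f32 tensors (names and layout here are our illustration, not ggml code):

    #include <stdint.h>

    // dst[i1,:] = as[ids[i1]] * src1[i1,:] for every row i1 of src1
    // as[e] : expert weight matrices, each ne01 x ne00 (row-major)
    // src1  : ne11 x ne00 activations; dst : ne11 x ne01 outputs
    static void mul_mat_id_ref(const float * const * as, const int32_t * ids,
                               const float * src1, float * dst,
                               int ne00, int ne01, int ne11) {
        for (int i1 = 0; i1 < ne11; i1++) {
            const float * a = as[ids[i1]];   // expert chosen for this row
            for (int i0 = 0; i0 < ne01; i0++) {
                float sum = 0.0f;
                for (int k = 0; k < ne00; k++) {
                    sum += a[(int64_t)i0*ne00 + k] * src1[(int64_t)i1*ne00 + k];
                }
                dst[(int64_t)i1*ne01 + i0] = sum;
            }
        }
    }

Before this commit the Metal backend effectively ran this one row at a time (the old code even carried a TODO saying so); the change below batches all rows routed to the same expert into a single matrix-matrix kernel, which is what speeds up Mixtral prompt processing.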
ggml-metal.m (31 lines changed):
@@ -1657,6 +1657,10 @@ void ggml_metal_graph_compute(
                                         }
                                 };
 
+                                if (ggml_is_quantized(src0t)) {
+                                    GGML_ASSERT(ne00 >= nth0*nth1);
+                                }
+
                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
@@ -1715,6 +1719,9 @@ void ggml_metal_graph_compute(
                             // TODO: make this more general
                             GGML_ASSERT(n_as <= 8);
 
+                            // max size of the src1ids array in the kernel stack
+                            GGML_ASSERT(ne11 <= 512);
+
                             struct ggml_tensor * src2 = gf->nodes[i]->src[2];
 
                             const int64_t  ne20 = src2 ? src2->ne[0] : 0;
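The new 512-row ceiling exists because the kernel added in ggml-metal.metal below gathers each expert's row indices into a fixed-size stack array. A hedged host-side sketch of the invariant being enforced (function name is ours):

    #include <assert.h>

    // Illustrative only: mirrors the GGML_ASSERT above. The Metal kernel
    // declares `short src1ids[512]` on its stack, so batches larger than
    // 512 rows would overflow it.
    static void check_mul_mm_id_batch(long ne11) {
        assert(ne11 <= 512 && "src1ids scratch holds at most 512 row indices");
    }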
@@ -1732,9 +1739,6 @@ void ggml_metal_graph_compute(
                             GGML_ASSERT(!ggml_is_transposed(src2));
                             GGML_ASSERT(!ggml_is_transposed(src1));
 
-                            GGML_ASSERT(ne20 % 32 == 0);
-                            // !!!!!!!!! TODO: this assert is probably required but not sure!
-                            //GGML_ASSERT(ne20 >= 64);
                             GGML_ASSERT(src1t == GGML_TYPE_F32);
 
                             const uint r2 = ne12/ne22;
@@ -1742,22 +1746,22 @@ void ggml_metal_graph_compute(
 
                             // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                             // to the matrix-vector kernel
-                            int ne11_mm_min = 1;
+                            int ne11_mm_min = n_as;
 
                             const int idx = ((int32_t *) dst->op_params)[0];
 
                             // batch size
                             GGML_ASSERT(ne01 == ne11);
 
-                            const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
-
                             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
                             // !!!
                             // TODO: for now, always use mat-vec kernels until we figure out how to improve the
                             //       indirect matrix multiplication
                             // !!!
-                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
+                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                                ne20 % 32 == 0 && ne20 >= 64 &&
+                                ne11 > ne11_mm_min) {
                                 switch (src2->type) {
                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32];  break;
                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32];  break;
@@ -1787,7 +1791,7 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:11];
                                 [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:12];
                                 [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
-                                [encoder setBytes:&_ne1    length:sizeof(_ne1) atIndex:14];
+                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
                                 [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:15];
                                 [encoder setBytes:&r2      length:sizeof(r2)   atIndex:16];
                                 [encoder setBytes:&r3      length:sizeof(r3)   atIndex:17];
@@ -1805,8 +1809,7 @@ void ggml_metal_graph_compute(
 
                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
 
-                                // TODO: processing one row at a time (ne11 -> 1) is not efficient
-                                [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                             } else {
                                 int nth0 = 32;
                                 int nth1 = 1;
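The dispatch change above is the core of the optimization: instead of launching one threadgroup slice per src1 row (the old layout with _ne1 = 1 and z = ne01*ne12*ne13), the grid now tiles the whole batch on x and enumerates experts on z. A sketch of the new grid arithmetic (function and type names are ours):

    // x: src1 rows in tiles of 32; y: output rows in tiles of 64;
    // z: one slice per (expert, batch) pair rather than per row.
    typedef struct { long x, y, z; } grid3;

    static grid3 mul_mm_id_grid(long ne11, long ne21, long n_as,
                                long ne12, long ne13) {
        grid3 g = { (ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13 };
        return g;
    }

For example, with ne11 = 512 rows and n_as = 8 experts this launches 16 column tiles across 8 expert slices, where the old path issued 512 single-row slices of one tile each.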
@@ -1889,11 +1892,17 @@ void ggml_metal_graph_compute(
                                         } break;
                                     default:
                                         {
-                                            GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+                                            GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
                                             GGML_ASSERT(false && "not implemented");
                                         }
                                 };
 
+                                if (ggml_is_quantized(src2t)) {
+                                    GGML_ASSERT(ne20 >= nth0*nth1);
+                                }
+
+                                const int64_t _ne1 = 1; // kernels needs a reference in constant memory
+
                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
ggml-metal.metal (205 lines changed):
@@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 #define N_SIMDGROUP 2  // number of SIMD groups in a thread group
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
-//      giard against the number of rows not being divisible by
+//      guard against the number of rows not being divisible by
 //      N_DST, so this is another explicit assumption of the implementation.
 template<typename block_q_type, int nr, int nsg, int nw>
 void mul_vec_q_n_f32_impl(
@@ -3973,6 +3973,131 @@ void kernel_mul_mm_impl(device const  uchar * src0,
     }
 }
 
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+void kernel_mul_mm_id_impl(
+        device const  uchar * src0,
+        device const  uchar * src1,
+        thread        short * src1ids,
+        device        float * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne02,
+        constant   uint64_t & nb01,
+        constant   uint64_t & nb02,
+        constant    int64_t & ne12,
+        constant   uint64_t & nb10,
+        constant   uint64_t & nb11,
+        constant   uint64_t & nb12,
+        constant    int64_t & ne0,
+                    int64_t   ne1,
+        constant       uint & r2,
+        constant       uint & r3,
+        threadgroup   uchar * shared_memory,
+        uint3                 tgpig[[threadgroup_position_in_grid]],
+        uint                  tiitg[[thread_index_in_threadgroup]],
+        uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
+    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+    const uint r0 = tgpig.y;
+    const uint r1 = tgpig.x;
+    const uint im = tgpig.z;
+
+    if (r1 * BLOCK_SIZE_N >= ne1) return;
+
+    // if this block is of 64x32 shape or smaller
+    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
+    // a thread shouldn't load data outside of the matrix
+    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+    simdgroup_half8x8  ma[4];
+    simdgroup_float8x8 mb[2];
+    simdgroup_float8x8 c_res[8];
+    for (int i = 0; i < 8; i++){
+        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    }
+
+    short il = (tiitg % THREAD_PER_ROW);
+
+    const uint i12 = im%ne12;
+    const uint i13 = im/ne12;
+
+    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
+    ushort offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const float   * y = (device const float   *)(src1
+        + nb12 * im
+        + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+
+    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+        // load data and store to threadgroup memory
+        half4x4 temp_a;
+        dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (int i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
+        }
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
+        il = (il + 2 < nl) ? il + 2 : il % 2;
+        x  = (il < 2) ? x + (2+nl-1)/nl : x;
+        y += BLOCK_SIZE_K;
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // load matrices from threadgroup memory and conduct outer products
+        threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
+        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+            for (int i = 0; i < 4; i++) {
+                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+            }
+            simdgroup_barrier(mem_flags::mem_none);
+            for (int i = 0; i < 2; i++) {
+                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+            }
+
+            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
+            for (int i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+            }
+        }
+    }
+
+    {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+        for (int i = 0; i < 8; i++) {
+            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
+        if (sgitg == 0) {
+            for (int i = 0; i < n_rows; i++) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                    *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+                }
+            }
+        }
+    }
+}
+
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm(device const  uchar * src0,
                           device const  uchar * src1,
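Since the grid is sized for the full batch but a given expert owns only ne1 <= ne11 rows, kernel_mul_mm_id_impl begins with an early-out for surplus column tiles. The guard in scalar form (BLOCK_SIZE_N = 32 is assumed here from the 32-wide dispatch tiles; the helper is our illustration):

    // A threadgroup at grid column r1 has work only if its first column
    // falls inside this expert's gathered rows.
    static int tile_is_live(unsigned r1, long ne1) {
        const long BLOCK_SIZE_N = 32;          // assumed column-tile width
        return (long)r1 * BLOCK_SIZE_N < ne1;  // mirrors the kernel's guard
    }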
@@ -4019,7 +4144,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
 kernel void kernel_mul_mm_id(
         device const   uchar * ids,
         device const   uchar * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
@@ -4048,18 +4173,28 @@ kernel void kernel_mul_mm_id(
         uint3                  tgpig[[threadgroup_position_in_grid]],
         uint                   tiitg[[thread_index_in_threadgroup]],
         uint                   sgitg[[simdgroup_index_in_threadgroup]]) {
-    device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+    device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
 
-    const int64_t bid = tgpig.z/(ne12*ne13);
+    // expert id
+    const int32_t id = tgpig.z/(ne12*ne13);
 
     tgpig.z = tgpig.z%(ne12*ne13);
 
-    const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
-
-    kernel_mul_mm_impl<block_q, nl, dequantize_func>(
-        src0[id],
-        src1 + bid*nb11,
-        (device float *) (dst + bid*nb1),
+    // row indices of src1 for expert id
+    int64_t _ne1 = 0;
+    short src1ids[512];
+
+    for (int64_t i1 = 0; i1 < ne1; i1++) {
+        if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
+            src1ids[_ne1++] = i1;
+        }
+    }
+
+    kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+        src0s[id],
+        src1,
+        src1ids,
+        dst,
         ne00,
         ne02,
         nb01,
@@ -4069,7 +4204,7 @@ kernel void kernel_mul_mm_id(
         nb11,
         nb12,
         ne0,
-        ne1,
+        _ne1,
         r2,
         r3,
         shared_memory,
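The wrapper above now does the routing itself: every threadgroup scans the ids tensor once and compacts the rows belonging to its expert before calling the tiled kernel. The same gather written as plain C for clarity (signature and names are our sketch; note that in the kernel nbi1 is a byte stride, while here it is treated as an element stride):

    #include <stdint.h>

    // Collect the src1 row indices assigned to `expert`. `idx` selects
    // the column of the ids tensor being routed on.
    static int64_t gather_rows_for_expert(const int32_t * ids, int64_t nbi1,
                                          int64_t idx, int64_t ne1,
                                          int32_t expert, int16_t src1ids[512]) {
        int64_t n = 0;
        for (int64_t i1 = 0; i1 < ne1; i1++) {
            if (ids[i1*nbi1 + idx] == expert) {
                src1ids[n++] = (int16_t) i1;
            }
        }
        return n;  // becomes _ne1, the per-expert row count
    }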
@@ -4158,7 +4293,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
 typedef void (mat_mm_id_t)(
         device const   uchar * ids,
         device const   uchar * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne02,
@@ -4207,7 +4342,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu
 kernel void kernel_mul_mv_id_f32_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4251,7 +4386,7 @@ kernel void kernel_mul_mv_id_f32_f32(
     kernel_mul_mv_f32_f32_impl(
         src0[id],
         src1 + bid*nb11,
-        (device float *) (dst + bid*nb1),
+        dst  + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4276,7 +4411,7 @@ kernel void kernel_mul_mv_id_f32_f32(
 kernel void kernel_mul_mv_id_f16_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4320,7 +4455,7 @@ kernel void kernel_mul_mv_id_f16_f32(
     kernel_mul_mv_f16_f32_impl(
         src0[id],
         src1 + bid*nb11,
-        (device float *) (dst + bid*nb1),
+        dst  + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4345,7 +4480,7 @@ kernel void kernel_mul_mv_id_f16_f32(
 kernel void kernel_mul_mv_id_q8_0_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4389,7 +4524,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
     kernel_mul_mv_q8_0_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4408,7 +4543,7 @@ kernel void kernel_mul_mv_id_q8_0_f32(
 kernel void kernel_mul_mv_id_q4_0_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4452,7 +4587,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
     mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4471,7 +4606,7 @@ kernel void kernel_mul_mv_id_q4_0_f32(
 kernel void kernel_mul_mv_id_q4_1_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4515,7 +4650,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4534,7 +4669,7 @@ kernel void kernel_mul_mv_id_q4_1_f32(
 kernel void kernel_mul_mv_id_q5_0_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4578,7 +4713,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
     mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4597,7 +4732,7 @@ kernel void kernel_mul_mv_id_q5_0_f32(
 kernel void kernel_mul_mv_id_q5_1_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4641,7 +4776,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
     mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4660,7 +4795,7 @@ kernel void kernel_mul_mv_id_q5_1_f32(
 kernel void kernel_mul_mv_id_q2_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4704,7 +4839,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
     kernel_mul_mv_q2_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4723,7 +4858,7 @@ kernel void kernel_mul_mv_id_q2_K_f32(
 kernel void kernel_mul_mv_id_q3_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4767,7 +4902,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
     kernel_mul_mv_q3_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4786,7 +4921,7 @@ kernel void kernel_mul_mv_id_q3_K_f32(
 kernel void kernel_mul_mv_id_q4_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4830,7 +4965,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
     kernel_mul_mv_q4_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4849,7 +4984,7 @@ kernel void kernel_mul_mv_id_q4_K_f32(
 kernel void kernel_mul_mv_id_q5_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4893,7 +5028,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
     kernel_mul_mv_q5_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
@@ -4912,7 +5047,7 @@ kernel void kernel_mul_mv_id_q5_K_f32(
 kernel void kernel_mul_mv_id_q6_K_f32(
         device const    char * ids,
         device const    char * src1,
-        device         uchar * dst,
+        device         float * dst,
         constant    uint64_t & nbi1,
         constant     int64_t & ne00,
         constant     int64_t & ne01,
@@ -4956,7 +5091,7 @@ kernel void kernel_mul_mv_id_q6_K_f32(
     kernel_mul_mv_q6_K_f32_impl(
         src0[id],
         (device const float *) (src1 + bid*nb11),
-        (device       float *) ( dst + bid*nb1),
+        dst + bid*ne0,
         ne00,
         ne01,
         ne02,
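All the mul_mv_id kernels above receive the same two-part change: dst is now typed `device float *` in the signature, and the per-batch offset is counted in elements (bid*ne0) rather than bytes (bid*nb1). For contiguous output rows the two are equivalent, since nb1 == ne0*sizeof(float); a sketch of that equivalence (helper names are ours):

    #include <stdint.h>

    static float * dst_row_old(uint8_t * dst_bytes, int64_t bid, uint64_t nb1) {
        return (float *) (dst_bytes + bid*nb1);  // byte arithmetic via uchar*
    }

    static float * dst_row_new(float * dst, int64_t bid, int64_t ne0) {
        return dst + bid*ne0;                    // element arithmetic via float*
    }

    // dst_row_old(p, b, ne0*sizeof(float)) == dst_row_new((float *) p, b, ne0)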
Author: Georgi Gerganov