mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : optimize ggml_mul_mat_id (faster Mixtral PP) (#4725)
* ggml : disable fast-math for Metal (cmake build only) ggml-ci * metal : fix Metal API debug warnings * cmake : add -fno-inline for Metal build (#4545) * metal : fix API debug warnings * metal : fix compile warnings * metal : use uint64_t for strides * cmake : rename option to LLAMA_METAL_SHADER_DEBUG * metal : fix mat-vec Q8_0 kernel for BS > 1 * metal : normalize mat-vec kernel signatures * cmake : respect LLAMA_QKK_64 option * metal : fix mat-vec Q4_K kernel for QK_K == 64 * metal : optimizing ggml_mul_mat_id (wip) * metal : minor fix * metal : opt mul_mm_id
This commit is contained in:
		
							
								
								
									
										31
									
								
								ggml-metal.m
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								ggml-metal.m
									
									
									
									
									
								
							| @@ -1657,6 +1657,10 @@ void ggml_metal_graph_compute( | ||||
|                                         } | ||||
|                                 }; | ||||
|  | ||||
|                                 if (ggml_is_quantized(src0t)) { | ||||
|                                     GGML_ASSERT(ne00 >= nth0*nth1); | ||||
|                                 } | ||||
|  | ||||
|                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; | ||||
|                                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; | ||||
|                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2]; | ||||
| @@ -1715,6 +1719,9 @@ void ggml_metal_graph_compute( | ||||
|                             // TODO: make this more general | ||||
|                             GGML_ASSERT(n_as <= 8); | ||||
|  | ||||
|                             // max size of the src1ids array in the kernel stack | ||||
|                             GGML_ASSERT(ne11 <= 512); | ||||
|  | ||||
|                             struct ggml_tensor * src2 = gf->nodes[i]->src[2]; | ||||
|  | ||||
|                             const int64_t  ne20 = src2 ? src2->ne[0] : 0; | ||||
| @@ -1732,9 +1739,6 @@ void ggml_metal_graph_compute( | ||||
|                             GGML_ASSERT(!ggml_is_transposed(src2)); | ||||
|                             GGML_ASSERT(!ggml_is_transposed(src1)); | ||||
|  | ||||
|                             GGML_ASSERT(ne20 % 32 == 0); | ||||
|                             // !!!!!!!!! TODO: this assert is probably required but not sure! | ||||
|                             //GGML_ASSERT(ne20 >= 64); | ||||
|                             GGML_ASSERT(src1t == GGML_TYPE_F32); | ||||
|  | ||||
|                             const uint r2 = ne12/ne22; | ||||
| @@ -1742,22 +1746,22 @@ void ggml_metal_graph_compute( | ||||
|  | ||||
|                             // find the break-even point where the matrix-matrix kernel becomes more efficient compared | ||||
|                             // to the matrix-vector kernel | ||||
|                             int ne11_mm_min = 1; | ||||
|                             int ne11_mm_min = n_as; | ||||
|  | ||||
|                             const int idx = ((int32_t *) dst->op_params)[0]; | ||||
|  | ||||
|                             // batch size | ||||
|                             GGML_ASSERT(ne01 == ne11); | ||||
|  | ||||
|                             const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory | ||||
|  | ||||
|                             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs | ||||
|                             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel | ||||
|                             // !!! | ||||
|                             // TODO: for now, always use mat-vec kernels until we figure out how to improve the | ||||
|                             //       indirect matrix multiplication | ||||
|                             // !!! | ||||
|                             if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) { | ||||
|                             if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && | ||||
|                                 ne20 % 32 == 0 && ne20 >= 64 && | ||||
|                                 ne11 > ne11_mm_min) { | ||||
|                                 switch (src2->type) { | ||||
|                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32];  break; | ||||
|                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32];  break; | ||||
| @@ -1787,7 +1791,7 @@ void ggml_metal_graph_compute( | ||||
|                                 [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:11]; | ||||
|                                 [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:12]; | ||||
|                                 [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13]; | ||||
|                                 [encoder setBytes:&_ne1    length:sizeof(_ne1) atIndex:14]; | ||||
|                                 [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14]; | ||||
|                                 [encoder setBytes:&nb1     length:sizeof(nb1)  atIndex:15]; | ||||
|                                 [encoder setBytes:&r2      length:sizeof(r2)   atIndex:16]; | ||||
|                                 [encoder setBytes:&r3      length:sizeof(r3)   atIndex:17]; | ||||
| @@ -1805,8 +1809,7 @@ void ggml_metal_graph_compute( | ||||
|  | ||||
|                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0]; | ||||
|  | ||||
|                                 // TODO: processing one row at a time (ne11 -> 1) is not efficient | ||||
|                                 [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; | ||||
|                                 [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; | ||||
|                             } else { | ||||
|                                 int nth0 = 32; | ||||
|                                 int nth1 = 1; | ||||
| @@ -1889,11 +1892,17 @@ void ggml_metal_graph_compute( | ||||
|                                         } break; | ||||
|                                     default: | ||||
|                                         { | ||||
|                                             GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); | ||||
|                                             GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); | ||||
|                                             GGML_ASSERT(false && "not implemented"); | ||||
|                                         } | ||||
|                                 }; | ||||
|  | ||||
|                                 if (ggml_is_quantized(src2t)) { | ||||
|                                     GGML_ASSERT(ne20 >= nth0*nth1); | ||||
|                                 } | ||||
|  | ||||
|                                 const int64_t _ne1 = 1; // kernels needs a reference in constant memory | ||||
|  | ||||
|                                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; | ||||
|                                 [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; | ||||
|                                 [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2]; | ||||
|   | ||||
							
								
								
									
										205
									
								
								ggml-metal.metal
									
									
									
									
									
								
							
							
						
						
									
										205
									
								
								ggml-metal.metal
									
									
									
									
									
								
							| @@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre | ||||
| #define N_SIMDGROUP 2  // number of SIMD groups in a thread group | ||||
| //Note: This is a template, but strictly speaking it only applies to | ||||
| //      quantizations where the block size is 32. It also does not | ||||
| //      giard against the number of rows not being divisible by | ||||
| //      guard against the number of rows not being divisible by | ||||
| //      N_DST, so this is another explicit assumption of the implementation. | ||||
| template<typename block_q_type, int nr, int nsg, int nw> | ||||
| void mul_vec_q_n_f32_impl( | ||||
| @@ -3973,6 +3973,131 @@ void kernel_mul_mm_impl(device const  uchar * src0, | ||||
|     } | ||||
| } | ||||
|  | ||||
| // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids | ||||
| template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)> | ||||
| void kernel_mul_mm_id_impl( | ||||
|         device const  uchar * src0, | ||||
|         device const  uchar * src1, | ||||
|         thread        short * src1ids, | ||||
|         device        float * dst, | ||||
|         constant    int64_t & ne00, | ||||
|         constant    int64_t & ne02, | ||||
|         constant   uint64_t & nb01, | ||||
|         constant   uint64_t & nb02, | ||||
|         constant    int64_t & ne12, | ||||
|         constant   uint64_t & nb10, | ||||
|         constant   uint64_t & nb11, | ||||
|         constant   uint64_t & nb12, | ||||
|         constant    int64_t & ne0, | ||||
|                     int64_t   ne1, | ||||
|         constant       uint & r2, | ||||
|         constant       uint & r3, | ||||
|         threadgroup   uchar * shared_memory, | ||||
|         uint3                 tgpig[[threadgroup_position_in_grid]], | ||||
|         uint                  tiitg[[thread_index_in_threadgroup]], | ||||
|         uint                  sgitg[[simdgroup_index_in_threadgroup]]) { | ||||
|  | ||||
|     threadgroup half  * sa = (threadgroup half  *)(shared_memory); | ||||
|     threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); | ||||
|  | ||||
|     const uint r0 = tgpig.y; | ||||
|     const uint r1 = tgpig.x; | ||||
|     const uint im = tgpig.z; | ||||
|  | ||||
|     if (r1 * BLOCK_SIZE_N >= ne1) return; | ||||
|  | ||||
|     // if this block is of 64x32 shape or smaller | ||||
|     short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; | ||||
|     short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; | ||||
|  | ||||
|     // a thread shouldn't load data outside of the matrix | ||||
|     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; | ||||
|     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; | ||||
|  | ||||
|     simdgroup_half8x8  ma[4]; | ||||
|     simdgroup_float8x8 mb[2]; | ||||
|     simdgroup_float8x8 c_res[8]; | ||||
|     for (int i = 0; i < 8; i++){ | ||||
|         c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f); | ||||
|     } | ||||
|  | ||||
|     short il = (tiitg % THREAD_PER_ROW); | ||||
|  | ||||
|     const uint i12 = im%ne12; | ||||
|     const uint i13 = im/ne12; | ||||
|  | ||||
|     uint   offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); | ||||
|     ushort offset1 = il/nl; | ||||
|  | ||||
|     device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; | ||||
|     device const float   * y = (device const float   *)(src1 | ||||
|         + nb12 * im | ||||
|         + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col] | ||||
|         + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); | ||||
|  | ||||
|     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { | ||||
|         // load data and store to threadgroup memory | ||||
|         half4x4 temp_a; | ||||
|         dequantize_func(x, il, temp_a); | ||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|         for (int i = 0; i < 16; i++) { | ||||
|             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ | ||||
|             +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ | ||||
|             +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4]; | ||||
|         } | ||||
|  | ||||
|         *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); | ||||
|  | ||||
|         il = (il + 2 < nl) ? il + 2 : il % 2; | ||||
|         x  = (il < 2) ? x + (2+nl-1)/nl : x; | ||||
|         y += BLOCK_SIZE_K; | ||||
|  | ||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|         // load matrices from threadgroup memory and conduct outer products | ||||
|         threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); | ||||
|         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); | ||||
|  | ||||
|         for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { | ||||
|             for (int i = 0; i < 4; i++) { | ||||
|                 simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); | ||||
|             } | ||||
|             simdgroup_barrier(mem_flags::mem_none); | ||||
|             for (int i = 0; i < 2; i++) { | ||||
|                 simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); | ||||
|             } | ||||
|  | ||||
|             lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; | ||||
|             lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; | ||||
|  | ||||
|             for (int i = 0; i < 8; i++){ | ||||
|                 simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     { | ||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|         threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ | ||||
|                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; | ||||
|         for (int i = 0; i < 8; i++) { | ||||
|             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); | ||||
|         } | ||||
|  | ||||
|         threadgroup_barrier(mem_flags::mem_threadgroup); | ||||
|  | ||||
|         device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0; | ||||
|         if (sgitg == 0) { | ||||
|             for (int i = 0; i < n_rows; i++) { | ||||
|                 for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { | ||||
|                     *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
|  | ||||
| template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)> | ||||
| kernel void kernel_mul_mm(device const  uchar * src0, | ||||
|                           device const  uchar * src1, | ||||
| @@ -4019,7 +4144,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_ | ||||
| kernel void kernel_mul_mm_id( | ||||
|         device const   uchar * ids, | ||||
|         device const   uchar * src1, | ||||
|         device         uchar * dst, | ||||
|         device         float * dst, | ||||
|         constant    uint64_t & nbi1, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne02, | ||||
| @@ -4048,18 +4173,28 @@ kernel void kernel_mul_mm_id( | ||||
|         uint3                  tgpig[[threadgroup_position_in_grid]], | ||||
|         uint                   tiitg[[thread_index_in_threadgroup]], | ||||
|         uint                   sgitg[[simdgroup_index_in_threadgroup]]) { | ||||
|     device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; | ||||
|     device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; | ||||
|  | ||||
|     const int64_t bid = tgpig.z/(ne12*ne13); | ||||
|     // expert id | ||||
|     const int32_t id = tgpig.z/(ne12*ne13); | ||||
|  | ||||
|     tgpig.z = tgpig.z%(ne12*ne13); | ||||
|  | ||||
|     const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx]; | ||||
|     // row indices of src1 for expert id | ||||
|     int64_t _ne1 = 0; | ||||
|     short src1ids[512]; | ||||
|  | ||||
|     kernel_mul_mm_impl<block_q, nl, dequantize_func>( | ||||
|         src0[id], | ||||
|         src1 + bid*nb11, | ||||
|         (device float *) (dst + bid*nb1), | ||||
|     for (int64_t i1 = 0; i1 < ne1; i1++) { | ||||
|         if (((device int32_t *) (ids + i1*nbi1))[idx] == id) { | ||||
|             src1ids[_ne1++] = i1; | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     kernel_mul_mm_id_impl<block_q, nl, dequantize_func>( | ||||
|         src0s[id], | ||||
|         src1, | ||||
|         src1ids, | ||||
|         dst, | ||||
|         ne00, | ||||
|         ne02, | ||||
|         nb01, | ||||
| @@ -4069,7 +4204,7 @@ kernel void kernel_mul_mm_id( | ||||
|         nb11, | ||||
|         nb12, | ||||
|         ne0, | ||||
|         ne1, | ||||
|         _ne1, | ||||
|         r2, | ||||
|         r3, | ||||
|         shared_memory, | ||||
| @@ -4158,7 +4293,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<b | ||||
| typedef void (mat_mm_id_t)( | ||||
|         device const   uchar * ids, | ||||
|         device const   uchar * src1, | ||||
|         device         uchar * dst, | ||||
|         device         float * dst, | ||||
|         constant    uint64_t & nbi1, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne02, | ||||
| @@ -4207,7 +4342,7 @@ template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mu | ||||
| kernel void kernel_mul_mv_id_f32_f32( | ||||
|         device const    char * ids, | ||||
|         device const    char * src1, | ||||
|         device         uchar * dst, | ||||
|         device         float * dst, | ||||
|         constant    uint64_t & nbi1, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne01, | ||||
| @@ -4251,7 +4386,7 @@ kernel void kernel_mul_mv_id_f32_f32( | ||||
|     kernel_mul_mv_f32_f32_impl( | ||||
|         src0[id], | ||||
|         src1 + bid*nb11, | ||||
|         (device float *) (dst + bid*nb1), | ||||
|         dst  + bid*ne0, | ||||
|         ne00, | ||||
|         ne01, | ||||
|         ne02, | ||||
| @@ -4276,7 +4411,7 @@ kernel void kernel_mul_mv_id_f32_f32( | ||||
| kernel void kernel_mul_mv_id_f16_f32( | ||||
|         device const    char * ids, | ||||
|         device const    char * src1, | ||||
|         device         uchar * dst, | ||||
|         device         float * dst, | ||||
|         constant    uint64_t & nbi1, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne01, | ||||
| @@ -4320,7 +4455,7 @@ kernel void kernel_mul_mv_id_f16_f32( | ||||
|     kernel_mul_mv_f16_f32_impl( | ||||
|         src0[id], | ||||
|         src1 + bid*nb11, | ||||
|         (device float *) (dst + bid*nb1), | ||||
|         dst  + bid*ne0, | ||||
|         ne00, | ||||
|         ne01, | ||||
|         ne02, | ||||
| @@ -4345,7 +4480,7 @@ kernel void kernel_mul_mv_id_f16_f32( | ||||
| kernel void kernel_mul_mv_id_q8_0_f32( | ||||
|         device const    char * ids, | ||||
|         device const    char * src1, | ||||
|         device         uchar * dst, | ||||
|         device         float * dst, | ||||
|         constant    uint64_t & nbi1, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne01, | ||||
| @@ -4389,7 +4524,7 @@ kernel void kernel_mul_mv_id_q8_0_f32( | ||||
|     kernel_mul_mv_q8_0_f32_impl( | ||||
|         src0[id], | ||||
|         (device const float *) (src1 + bid*nb11), | ||||
|         (device       float *) ( dst + bid*nb1), | ||||
|         dst + bid*ne0, | ||||
|         ne00, | ||||
|         ne01, | ||||
|         ne02, | ||||
| @@ -4408,7 +4543,7 @@ kernel void kernel_mul_mv_id_q8_0_f32( | ||||
| kernel void kernel_mul_mv_id_q4_0_f32( | ||||
|         device const    char * ids, | ||||
|         device const    char * src1, | ||||
|         device         uchar * dst, | ||||
|         device         float * dst, | ||||
|         constant    uint64_t & nbi1, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne01, | ||||
| @@ -4452,7 +4587,7 @@ kernel void kernel_mul_mv_id_q4_0_f32( | ||||
|     mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>( | ||||
|         src0[id], | ||||
|         (device const float *) (src1 + bid*nb11), | ||||
|         (device       float *) ( dst + bid*nb1), | ||||
|         dst + bid*ne0, | ||||
|         ne00, | ||||
|         ne01, | ||||
|         ne02, | ||||
| @@ -4471,7 +4606,7 @@ kernel void kernel_mul_mv_id_q4_0_f32( | ||||
| kernel void kernel_mul_mv_id_q4_1_f32( | ||||
|         device const    char * ids, | ||||
|         device const    char * src1, | ||||
|         device         uchar * dst, | ||||
|         device         float * dst, | ||||
|         constant    uint64_t & nbi1, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne01, | ||||
| @@ -4515,7 +4650,7 @@ kernel void kernel_mul_mv_id_q4_1_f32( | ||||
|     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>( | ||||
|         src0[id], | ||||
|         (device const float *) (src1 + bid*nb11), | ||||
|         (device       float *) ( dst + bid*nb1), | ||||
|         dst + bid*ne0, | ||||
|         ne00, | ||||
|         ne01, | ||||
|         ne02, | ||||
| @@ -4534,7 +4669,7 @@ kernel void kernel_mul_mv_id_q4_1_f32( | ||||
| kernel void kernel_mul_mv_id_q5_0_f32( | ||||
|         device const    char * ids, | ||||
|         device const    char * src1, | ||||
|         device         uchar * dst, | ||||
|         device         float * dst, | ||||
|         constant    uint64_t & nbi1, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne01, | ||||
| @@ -4578,7 +4713,7 @@ kernel void kernel_mul_mv_id_q5_0_f32( | ||||
|     mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>( | ||||
|         src0[id], | ||||
|         (device const float *) (src1 + bid*nb11), | ||||
|         (device       float *) ( dst + bid*nb1), | ||||
|         dst + bid*ne0, | ||||
|         ne00, | ||||
|         ne01, | ||||
|         ne02, | ||||
| @@ -4597,7 +4732,7 @@ kernel void kernel_mul_mv_id_q5_0_f32( | ||||
| kernel void kernel_mul_mv_id_q5_1_f32( | ||||
|         device const    char * ids, | ||||
|         device const    char * src1, | ||||
|         device         uchar * dst, | ||||
|         device         float * dst, | ||||
|         constant    uint64_t & nbi1, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne01, | ||||
| @@ -4641,7 +4776,7 @@ kernel void kernel_mul_mv_id_q5_1_f32( | ||||
|     mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>( | ||||
|         src0[id], | ||||
|         (device const float *) (src1 + bid*nb11), | ||||
|         (device       float *) ( dst + bid*nb1), | ||||
|         dst + bid*ne0, | ||||
|         ne00, | ||||
|         ne01, | ||||
|         ne02, | ||||
| @@ -4660,7 +4795,7 @@ kernel void kernel_mul_mv_id_q5_1_f32( | ||||
| kernel void kernel_mul_mv_id_q2_K_f32( | ||||
|         device const    char * ids, | ||||
|         device const    char * src1, | ||||
|         device         uchar * dst, | ||||
|         device         float * dst, | ||||
|         constant    uint64_t & nbi1, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne01, | ||||
| @@ -4704,7 +4839,7 @@ kernel void kernel_mul_mv_id_q2_K_f32( | ||||
|     kernel_mul_mv_q2_K_f32_impl( | ||||
|         src0[id], | ||||
|         (device const float *) (src1 + bid*nb11), | ||||
|         (device       float *) ( dst + bid*nb1), | ||||
|         dst + bid*ne0, | ||||
|         ne00, | ||||
|         ne01, | ||||
|         ne02, | ||||
| @@ -4723,7 +4858,7 @@ kernel void kernel_mul_mv_id_q2_K_f32( | ||||
| kernel void kernel_mul_mv_id_q3_K_f32( | ||||
|         device const    char * ids, | ||||
|         device const    char * src1, | ||||
|         device         uchar * dst, | ||||
|         device         float * dst, | ||||
|         constant    uint64_t & nbi1, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne01, | ||||
| @@ -4767,7 +4902,7 @@ kernel void kernel_mul_mv_id_q3_K_f32( | ||||
|     kernel_mul_mv_q3_K_f32_impl( | ||||
|         src0[id], | ||||
|         (device const float *) (src1 + bid*nb11), | ||||
|         (device       float *) ( dst + bid*nb1), | ||||
|         dst + bid*ne0, | ||||
|         ne00, | ||||
|         ne01, | ||||
|         ne02, | ||||
| @@ -4786,7 +4921,7 @@ kernel void kernel_mul_mv_id_q3_K_f32( | ||||
| kernel void kernel_mul_mv_id_q4_K_f32( | ||||
|         device const    char * ids, | ||||
|         device const    char * src1, | ||||
|         device         uchar * dst, | ||||
|         device         float * dst, | ||||
|         constant    uint64_t & nbi1, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne01, | ||||
| @@ -4830,7 +4965,7 @@ kernel void kernel_mul_mv_id_q4_K_f32( | ||||
|     kernel_mul_mv_q4_K_f32_impl( | ||||
|         src0[id], | ||||
|         (device const float *) (src1 + bid*nb11), | ||||
|         (device       float *) ( dst + bid*nb1), | ||||
|         dst + bid*ne0, | ||||
|         ne00, | ||||
|         ne01, | ||||
|         ne02, | ||||
| @@ -4849,7 +4984,7 @@ kernel void kernel_mul_mv_id_q4_K_f32( | ||||
| kernel void kernel_mul_mv_id_q5_K_f32( | ||||
|         device const    char * ids, | ||||
|         device const    char * src1, | ||||
|         device         uchar * dst, | ||||
|         device         float * dst, | ||||
|         constant    uint64_t & nbi1, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne01, | ||||
| @@ -4893,7 +5028,7 @@ kernel void kernel_mul_mv_id_q5_K_f32( | ||||
|     kernel_mul_mv_q5_K_f32_impl( | ||||
|         src0[id], | ||||
|         (device const float *) (src1 + bid*nb11), | ||||
|         (device       float *) ( dst + bid*nb1), | ||||
|         dst + bid*ne0, | ||||
|         ne00, | ||||
|         ne01, | ||||
|         ne02, | ||||
| @@ -4912,7 +5047,7 @@ kernel void kernel_mul_mv_id_q5_K_f32( | ||||
| kernel void kernel_mul_mv_id_q6_K_f32( | ||||
|         device const    char * ids, | ||||
|         device const    char * src1, | ||||
|         device         uchar * dst, | ||||
|         device         float * dst, | ||||
|         constant    uint64_t & nbi1, | ||||
|         constant     int64_t & ne00, | ||||
|         constant     int64_t & ne01, | ||||
| @@ -4956,7 +5091,7 @@ kernel void kernel_mul_mv_id_q6_K_f32( | ||||
|     kernel_mul_mv_q6_K_f32_impl( | ||||
|         src0[id], | ||||
|         (device const float *) (src1 + bid*nb11), | ||||
|         (device       float *) ( dst + bid*nb1), | ||||
|         dst + bid*ne0, | ||||
|         ne00, | ||||
|         ne01, | ||||
|         ne02, | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov