|
|
|
|
@@ -2843,7 +2843,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|
|
|
|
return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<short NR0, short NW>
|
|
|
|
|
template<short NR0>
|
|
|
|
|
static inline void helper_mv_reduce_and_write(
|
|
|
|
|
device float * dst_f32,
|
|
|
|
|
float sumf[NR0],
|
|
|
|
|
@@ -2852,6 +2852,8 @@ static inline void helper_mv_reduce_and_write(
|
|
|
|
|
ushort tiisg,
|
|
|
|
|
ushort sgitg,
|
|
|
|
|
threadgroup char * shmem) {
|
|
|
|
|
constexpr short NW = N_SIMDWIDTH;
|
|
|
|
|
|
|
|
|
|
threadgroup float * shmem_f32[NR0];
|
|
|
|
|
|
|
|
|
|
for (short row = 0; row < NR0; ++row) {
|
|
|
|
|
@@ -2883,9 +2885,10 @@ static inline void helper_mv_reduce_and_write(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
|
|
|
|
|
// Function constant at index FC_MUL_MV + 0.
// The mul_mv implementations copy it into a local `NSG`; the ext kernels use that value as the
// per-threadgroup multiplier when deriving the row index from tgpig.x and sgitg.
// NOTE(review): presumably the number of simdgroups per threadgroup - confirm against the host-side dispatch.
constant short FC_mul_mv_nsg [[function_constant(FC_MUL_MV + 0)]];
|
|
|
|
|
// Function constant at index FC_MUL_MV + 1.
// Read as `nxpsg` by the kernel_mul_mv_ext_* implementations, where it splits the lane index
// within a simdgroup: tx = tiisg % nxpsg, ty = tiisg / nxpsg.
constant short FC_mul_mv_nxpsg [[function_constant(FC_MUL_MV + 1)]];
|
|
|
|
|
|
|
|
|
|
template<typename block_q_type, short NR0, short NW, typename args_t>
|
|
|
|
|
template<typename block_q_type, short NR0, typename args_t>
|
|
|
|
|
void mul_vec_q_n_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -2897,6 +2900,7 @@ void mul_vec_q_n_f32_impl(
|
|
|
|
|
ushort sgitg) {
|
|
|
|
|
const short NSG = FC_mul_mv_nsg;
|
|
|
|
|
|
|
|
|
|
constexpr short NW = N_SIMDWIDTH;
|
|
|
|
|
constexpr short NQ = 16;
|
|
|
|
|
|
|
|
|
|
const int nb = args.ne00/QK4_0;
|
|
|
|
|
@@ -2961,7 +2965,7 @@ void mul_vec_q_n_f32_impl(
|
|
|
|
|
|
|
|
|
|
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
|
|
|
|
|
|
|
|
|
//helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
|
|
|
|
//helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
|
|
|
|
|
|
|
|
|
for (int row = 0; row < NR0; ++row) {
|
|
|
|
|
const float tot = simd_sum(sumf[row]);
|
|
|
|
|
@@ -2981,7 +2985,7 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel void kernel_mul_mv_q4_1_f32(
|
|
|
|
|
@@ -2993,7 +2997,7 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel void kernel_mul_mv_q5_0_f32(
|
|
|
|
|
@@ -3005,7 +3009,7 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
kernel void kernel_mul_mv_q5_1_f32(
|
|
|
|
|
@@ -3017,10 +3021,10 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<short NR0, short NW, typename args_t>
|
|
|
|
|
template<short NR0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_q8_0_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -3032,6 +3036,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
|
|
|
ushort sgitg) {
|
|
|
|
|
const short NSG = FC_mul_mv_nsg;
|
|
|
|
|
|
|
|
|
|
constexpr short NW = N_SIMDWIDTH;
|
|
|
|
|
constexpr short NQ = 8;
|
|
|
|
|
|
|
|
|
|
const int nb = args.ne00/QK8_0;
|
|
|
|
|
@@ -3090,7 +3095,7 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|
|
|
|
|
|
|
|
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
|
|
|
|
|
|
|
|
|
helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
|
|
|
|
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
[[host_name("kernel_mul_mv_q8_0_f32")]]
|
|
|
|
|
@@ -3103,12 +3108,12 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// mat-vec kernel processing in chunks of float4
|
|
|
|
|
// chpb - chunks per quantization block
|
|
|
|
|
template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
|
|
|
|
|
template<short r1ptg, typename q_t, short chpb, void (*deq_t4)(device const q_t *, short, thread float4 &) >
|
|
|
|
|
void kernel_mul_mv_ext_q4_f32_impl(
|
|
|
|
|
constant ggml_metal_kargs_mul_mv_ext & args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -3117,6 +3122,9 @@ void kernel_mul_mv_ext_q4_f32_impl(
|
|
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
const short NSG = FC_mul_mv_nsg;
|
|
|
|
|
const short nxpsg = FC_mul_mv_nxpsg;
|
|
|
|
|
|
|
|
|
|
const short chpt = 4; // chunks per thread
|
|
|
|
|
|
|
|
|
|
//const short nxpsg = (32);
|
|
|
|
|
@@ -3125,7 +3133,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
|
|
|
|
|
const short tx = tiisg%nxpsg;
|
|
|
|
|
const short ty = tiisg/nxpsg;
|
|
|
|
|
|
|
|
|
|
const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
|
|
|
|
|
const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
|
|
|
|
|
const int i11 = tgpig.y*r1ptg;
|
|
|
|
|
const int i1m = tgpig.z;
|
|
|
|
|
|
|
|
|
|
@@ -3208,7 +3216,7 @@ void kernel_mul_mv_ext_q4_f32_impl(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// mat-vec kernel processing in chunks of float4x4
|
|
|
|
|
template<short nxpsg, short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
|
|
|
|
|
template<short r1ptg, typename q_t, short chpb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &) >
|
|
|
|
|
void kernel_mul_mv_ext_q4x4_f32_impl(
|
|
|
|
|
constant ggml_metal_kargs_mul_mv_ext & args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -3217,6 +3225,9 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
|
|
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
const short NSG = FC_mul_mv_nsg;
|
|
|
|
|
const short nxpsg = FC_mul_mv_nxpsg;
|
|
|
|
|
|
|
|
|
|
const short chpt = 1;
|
|
|
|
|
|
|
|
|
|
//const short nxpsg = (32);
|
|
|
|
|
@@ -3225,7 +3236,7 @@ void kernel_mul_mv_ext_q4x4_f32_impl(
|
|
|
|
|
const short tx = tiisg%nxpsg;
|
|
|
|
|
const short ty = tiisg/nxpsg;
|
|
|
|
|
|
|
|
|
|
const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
|
|
|
|
|
const int i01 = tgpig.x*(nypsg*NSG) + nypsg*sgitg + ty;
|
|
|
|
|
const int i11 = tgpig.y*r1ptg;
|
|
|
|
|
const int i1m = tgpig.z;
|
|
|
|
|
|
|
|
|
|
@@ -3322,12 +3333,7 @@ kernel void kernel_mul_mv_ext_q4_f32_disp(
|
|
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
switch (args.nxpsg) {
|
|
|
|
|
case 4: kernel_mul_mv_ext_q4_f32_impl<4, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
|
|
|
|
case 8: kernel_mul_mv_ext_q4_f32_impl<8, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
|
|
|
|
case 16: kernel_mul_mv_ext_q4_f32_impl<16, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
|
|
|
|
case 32: kernel_mul_mv_ext_q4_f32_impl<32, r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
|
|
|
|
}
|
|
|
|
|
kernel_mul_mv_ext_q4_f32_impl<r1ptg, q_t, epb/4, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<short r1ptg, typename q_t, short epb, void (*deq_t4x4)(device const q_t *, short, thread float4x4 &)>
|
|
|
|
|
@@ -3339,12 +3345,7 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp(
|
|
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
switch (args.nxpsg) {
|
|
|
|
|
case 4: kernel_mul_mv_ext_q4x4_f32_impl<4, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
|
|
|
|
case 8: kernel_mul_mv_ext_q4x4_f32_impl<8, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
|
|
|
|
case 16: kernel_mul_mv_ext_q4x4_f32_impl<16, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
|
|
|
|
case 32: kernel_mul_mv_ext_q4x4_f32_impl<32, r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
|
|
|
|
|
}
|
|
|
|
|
kernel_mul_mv_ext_q4x4_f32_impl<r1ptg, q_t, epb/16, deq_t4x4>(args, src0, src1, dst, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Function type of kernel_mul_mv_ext_q4_f32_disp, obtained with decltype from the
// <2, block_q8_0, 32, dequantize_q8_0_t4> specialization.
// NOTE(review): presumably named by the explicit `template [[host_name(...)]] kernel` instantiations,
// the same way mul_mv_ext_q4x4_f32_t is - the q4 instantiations are not visible from here.
typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
|
|
|
|
|
@@ -3410,7 +3411,7 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
|
|
|
|
|
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
|
|
|
|
|
|
|
|
|
|
template<typename T0, typename T1, short NR0, short NW, typename args_t>
|
|
|
|
|
template<typename T0, typename T1, short NR0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_t_t_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -3422,6 +3423,7 @@ void kernel_mul_mv_t_t_impl(
|
|
|
|
|
ushort sgitg) {
|
|
|
|
|
const short NSG = FC_mul_mv_nsg;
|
|
|
|
|
|
|
|
|
|
constexpr short NW = N_SIMDWIDTH;
|
|
|
|
|
constexpr short NB = 32;
|
|
|
|
|
constexpr short NF = 8;
|
|
|
|
|
|
|
|
|
|
@@ -3486,10 +3488,10 @@ void kernel_mul_mv_t_t_impl(
|
|
|
|
|
|
|
|
|
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
|
|
|
|
|
|
|
|
|
helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
|
|
|
|
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<typename T0, typename T1, short NR0, short NW>
|
|
|
|
|
template<typename T0, typename T1, short NR0>
|
|
|
|
|
kernel void kernel_mul_mv_t_t(
|
|
|
|
|
constant ggml_metal_kargs_mul_mv & args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -3499,20 +3501,20 @@ kernel void kernel_mul_mv_t_t(
|
|
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
kernel_mul_mv_t_t_impl<T0, T1, NR0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_t_t_impl<T0, T1, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F, N_SIMDWIDTH>) mul_mv_t_t;
|
|
|
|
|
// Function type of kernel_mul_mv_t_t, obtained with decltype from the <half, half, N_R0_F>
// specialization. The explicit `template [[host_name(...)]] kernel mul_mv_t_t ...` instantiations
// use it to declare each exported kernel.
typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F>) mul_mv_t_t;
|
|
|
|
|
|
|
|
|
|
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F, N_SIMDWIDTH>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float, N_R0_F, N_SIMDWIDTH>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half, N_R0_F, N_SIMDWIDTH>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<float, float, N_R0_F>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, float, N_R0_F>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<half, half, N_R0_F>;
|
|
|
|
|
#if defined(GGML_METAL_HAS_BF16)
|
|
|
|
|
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float, N_R0_F, N_SIMDWIDTH>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F, N_SIMDWIDTH>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float, N_R0_F>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F>;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
template<typename T0, typename T04, typename T1, typename T14, short NR0, short NW, typename args_t>
|
|
|
|
|
template<typename T0, typename T04, typename T1, typename T14, short NR0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_t_t_4_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -3524,6 +3526,7 @@ void kernel_mul_mv_t_t_4_impl(
|
|
|
|
|
ushort sgitg) {
|
|
|
|
|
const short NSG = FC_mul_mv_nsg;
|
|
|
|
|
|
|
|
|
|
constexpr short NW = N_SIMDWIDTH;
|
|
|
|
|
constexpr short NB = 32;
|
|
|
|
|
constexpr short NF = 16;
|
|
|
|
|
constexpr short NF4 = NF/4;
|
|
|
|
|
@@ -3591,10 +3594,10 @@ void kernel_mul_mv_t_t_4_impl(
|
|
|
|
|
|
|
|
|
|
device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
|
|
|
|
|
|
|
|
|
|
helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
|
|
|
|
helper_mv_reduce_and_write<NR0>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<typename T0, typename T04, typename T1, typename T14, short NR0, short NW>
|
|
|
|
|
template<typename T0, typename T04, typename T1, typename T14, short NR0>
|
|
|
|
|
kernel void kernel_mul_mv_t_t_4(
|
|
|
|
|
constant ggml_metal_kargs_mul_mv & args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -3604,17 +3607,17 @@ kernel void kernel_mul_mv_t_t_4(
|
|
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F, N_SIMDWIDTH>) mul_mv_t_t_4;
|
|
|
|
|
// Function type of kernel_mul_mv_t_t_4, obtained with decltype from the
// <half, half4, half, half4, N_R0_F> specialization. The explicit
// `template [[host_name(...)]] kernel mul_mv_t_t_4 ...` instantiations use it to declare each exported kernel.
typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>) mul_mv_t_t_4;
|
|
|
|
|
|
|
|
|
|
template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4, N_R0_F, N_SIMDWIDTH>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F, N_SIMDWIDTH>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float, float4, float, float4, N_R0_F>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, float, float4, N_R0_F>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F>;
|
|
|
|
|
#if defined(GGML_METAL_HAS_BF16)
|
|
|
|
|
template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4, N_R0_F, N_SIMDWIDTH>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F, N_SIMDWIDTH>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float, float4, N_R0_F>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F>;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
#define N_MV_T_T 4
|
|
|
|
|
@@ -5966,7 +5969,7 @@ kernel void kernel_concat(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_q2_K_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -6068,10 +6071,10 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_q2_K_f32_impl<N_R0_Q2_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_q3_K_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -6233,10 +6236,10 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_q3_K_f32_impl<N_R0_Q3_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_q4_K_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -6248,9 +6251,9 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
|
|
|
ushort sgitg) {
|
|
|
|
|
const short NSG = FC_mul_mv_nsg;
|
|
|
|
|
|
|
|
|
|
const uint16_t kmask1 = 0x3f3f;
|
|
|
|
|
const uint16_t kmask2 = 0x0f0f;
|
|
|
|
|
const uint16_t kmask3 = 0xc0c0;
|
|
|
|
|
constexpr uint16_t kmask1 = 0x3f3f;
|
|
|
|
|
constexpr uint16_t kmask2 = 0x0f0f;
|
|
|
|
|
constexpr uint16_t kmask3 = 0xc0c0;
|
|
|
|
|
|
|
|
|
|
const short ix = tiisg/8; // 0...3
|
|
|
|
|
const short it = tiisg%8; // 0...7
|
|
|
|
|
@@ -6309,7 +6312,7 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
|
|
|
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
|
|
|
|
|
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
|
|
|
|
|
|
|
|
|
|
for (short i = 0; i < 4; ++i) {
|
|
|
|
|
FOR_UNROLL (short i = 0; i < 4; ++i) {
|
|
|
|
|
acc1[0] += yl[2*i + 0] * (q1[i] & 0x000F);
|
|
|
|
|
acc1[1] += yl[2*i + 1] * (q1[i] & 0x0F00);
|
|
|
|
|
acc1[2] += yl[2*i + 8] * (q1[i] & 0x00F0);
|
|
|
|
|
@@ -6320,14 +6323,11 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|
|
|
|
acc2[3] += yh[2*i + 9] * (q2[i] & 0xF000);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
float dall = dh[0];
|
|
|
|
|
float dmin = dh[1];
|
|
|
|
|
|
|
|
|
|
sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
|
|
|
|
|
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
|
|
|
|
|
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
|
|
|
|
|
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
|
|
|
|
|
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
|
|
|
|
sumf[row] += dh[0] * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] +
|
|
|
|
|
(acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f +
|
|
|
|
|
(acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] +
|
|
|
|
|
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
|
|
|
|
|
dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
|
|
|
|
|
|
|
|
|
q1 += args.nb01/2;
|
|
|
|
|
sc += args.nb01/2;
|
|
|
|
|
@@ -6357,10 +6357,10 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_q4_K_f32_impl<N_R0_Q4_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_q5_K_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -6393,9 +6393,9 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
|
|
|
|
|
|
|
|
float yl[16], yh[16];
|
|
|
|
|
|
|
|
|
|
const uint16_t kmask1 = 0x3f3f;
|
|
|
|
|
const uint16_t kmask2 = 0x0f0f;
|
|
|
|
|
const uint16_t kmask3 = 0xc0c0;
|
|
|
|
|
constexpr uint16_t kmask1 = 0x3f3f;
|
|
|
|
|
constexpr uint16_t kmask2 = 0x0f0f;
|
|
|
|
|
constexpr uint16_t kmask3 = 0xc0c0;
|
|
|
|
|
|
|
|
|
|
const short tid = tiisg/4;
|
|
|
|
|
const short ix = tiisg%4;
|
|
|
|
|
@@ -6441,7 +6441,7 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
|
|
|
|
|
|
|
|
float4 acc1 = {0.f};
|
|
|
|
|
float4 acc2 = {0.f};
|
|
|
|
|
for (short l = 0; l < 8; ++l) {
|
|
|
|
|
FOR_UNROLL (short l = 0; l < 8; ++l) {
|
|
|
|
|
uint8_t h = qh[l];
|
|
|
|
|
acc1[0] += yl[l+0] * (q1[l] & 0x0F);
|
|
|
|
|
acc1[1] += yl[l+8] * (q1[l] & 0xF0);
|
|
|
|
|
@@ -6452,13 +6452,12 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|
|
|
|
acc2[2] += h & hm3 ? yh[l+0] : 0.f;
|
|
|
|
|
acc2[3] += h & hm4 ? yh[l+8] : 0.f;
|
|
|
|
|
}
|
|
|
|
|
const float dall = dh[0];
|
|
|
|
|
const float dmin = dh[1];
|
|
|
|
|
sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
|
|
|
|
|
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
|
|
|
|
|
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
|
|
|
|
|
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
|
|
|
|
|
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
|
|
|
|
|
|
|
|
|
sumf[row] += dh[0] * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
|
|
|
|
|
sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
|
|
|
|
|
sc8[4] * (acc1[2] + 16.f*acc2[2]) +
|
|
|
|
|
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
|
|
|
|
|
dh[1] * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
|
|
|
|
|
|
|
|
|
|
q1 += args.nb01;
|
|
|
|
|
qh += args.nb01;
|
|
|
|
|
@@ -6489,10 +6488,10 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_q5_K_f32_impl<N_R0_Q5_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_q6_K_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -6504,10 +6503,10 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
|
|
|
ushort sgitg) {
|
|
|
|
|
const short NSG = FC_mul_mv_nsg;
|
|
|
|
|
|
|
|
|
|
const uint8_t kmask1 = 0x03;
|
|
|
|
|
const uint8_t kmask2 = 0x0C;
|
|
|
|
|
const uint8_t kmask3 = 0x30;
|
|
|
|
|
const uint8_t kmask4 = 0xC0;
|
|
|
|
|
constexpr uint8_t kmask1 = 0x03;
|
|
|
|
|
constexpr uint8_t kmask2 = 0x0C;
|
|
|
|
|
constexpr uint8_t kmask3 = 0x30;
|
|
|
|
|
constexpr uint8_t kmask4 = 0xC0;
|
|
|
|
|
|
|
|
|
|
const int nb = args.ne00/QK_K;
|
|
|
|
|
|
|
|
|
|
@@ -6558,18 +6557,16 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (short row = 0; row < nr0; ++row) {
|
|
|
|
|
const float dall = dh[0];
|
|
|
|
|
|
|
|
|
|
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
|
|
|
|
|
|
|
|
|
for (short l = 0; l < 4; ++l) {
|
|
|
|
|
FOR_UNROLL (short l = 0; l < 4; ++l) {
|
|
|
|
|
sums[0] += yl[4*l + 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
|
|
|
|
sums[1] += yl[4*l + 1] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
|
|
|
|
sums[2] += yl[4*l + 2] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
|
|
|
|
sums[3] += yl[4*l + 3] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
sumf[row] += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
|
|
|
|
sumf[row] += dh[0] * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
|
|
|
|
|
|
|
|
|
q1 += args.nb01;
|
|
|
|
|
q2 += args.nb01;
|
|
|
|
|
@@ -6599,12 +6596,12 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_q6_K_f32_impl<N_R0_Q6_K, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// ======================= "True" 2-bit
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_iq2_xxs_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -6709,10 +6706,10 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
|
|
|
|
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_iq2_xs_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -6828,10 +6825,10 @@ kernel void kernel_mul_mv_iq2_xs_f32(
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_iq2_xs_f32_impl<N_R0_IQ2_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_iq3_xxs_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -6940,10 +6937,10 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_iq3_s_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -7052,10 +7049,10 @@ kernel void kernel_mul_mv_iq3_s_f32(
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_iq3_s_f32_impl<N_R0_IQ3_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_iq2_s_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -7165,10 +7162,10 @@ kernel void kernel_mul_mv_iq2_s_f32(
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_iq2_s_f32_impl<N_R0_IQ2_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_iq1_s_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -7264,10 +7261,10 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_iq1_s_f32_impl<N_R0_IQ1_S, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_iq1_m_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -7373,10 +7370,10 @@ kernel void kernel_mul_mv_iq1_m_f32(
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_iq1_m_f32_impl<N_R0_IQ1_M, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_iq4_nl_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -7480,10 +7477,10 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_iq4_nl_f32_impl<N_R0_IQ4_NL, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_iq4_xs_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -7587,10 +7584,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_iq4_xs_f32_impl<N_R0_IQ4_XS, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<int nr0, int nw, typename args_t>
|
|
|
|
|
template<int nr0, typename args_t>
|
|
|
|
|
void kernel_mul_mv_mxfp4_f32_impl(
|
|
|
|
|
args_t args,
|
|
|
|
|
device const char * src0,
|
|
|
|
|
@@ -7677,7 +7674,7 @@ kernel void kernel_mul_mv_mxfp4_f32(
|
|
|
|
|
ushort tiisg[[thread_index_in_simdgroup]],
|
|
|
|
|
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
|
|
|
|
|
|
|
|
|
kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
|
|
|
|
@@ -8353,7 +8350,7 @@ void mmv_fn(
|
|
|
|
|
impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, N_SIMDWIDTH, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
|
|
|
|
|
typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
|
|
|
|
|
|
|
|
|
|
template<mul_mv_impl_fn_t impl_fn>
|
|
|
|
|
kernel void kernel_mul_mv_id(
|
|
|
|
|
@@ -8418,44 +8415,44 @@ kernel void kernel_mul_mv_id(
|
|
|
|
|
sgitg);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_t;
|
|
|
|
|
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>) kernel_mul_mv_id_t;
|
|
|
|
|
|
|
|
|
|
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_4_t;
|
|
|
|
|
typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>) kernel_mul_mv_id_4_t;
|
|
|
|
|
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half, float, N_R0_F, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half, float, N_R0_F>>>;
|
|
|
|
|
#if defined(GGML_METAL_HAS_BF16)
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F>>>;
|
|
|
|
|
#endif
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half, half4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half, half4, float, float4, N_R0_F>>>;
|
|
|
|
|
#if defined(GGML_METAL_HAS_BF16)
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F>>>;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0>>>;
|
|
|
|
|
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q4_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_1, N_R0_Q4_1>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q5_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_0, N_R0_Q5_0>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q5_1_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q5_1, N_R0_Q5_1>>>;
|
|
|
|
|
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_mxfp4_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_mxfp4_f32_impl<N_R0_MXFP4>>>;
|
|
|
|
|
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS, N_SIMDWIDTH>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q2_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q2_K_f32_impl <N_R0_Q2_K>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q3_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q3_K_f32_impl <N_R0_Q3_K>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q4_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q4_K_f32_impl <N_R0_Q4_K>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q5_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q5_K_f32_impl <N_R0_Q5_K>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_q6_K_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q6_K_f32_impl <N_R0_Q6_K>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq1_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_s_f32_impl <N_R0_IQ1_S>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq1_m_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq1_m_f32_impl <N_R0_IQ1_M>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq2_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xxs_f32_impl<N_R0_IQ2_XXS>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq2_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_xs_f32_impl <N_R0_IQ2_XS>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq3_xxs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_xxs_f32_impl<N_R0_IQ3_XXS>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq3_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq3_s_f32_impl <N_R0_IQ3_S>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq2_s_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl <N_R0_IQ2_S>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl <N_R0_IQ4_NL>>>;
|
|
|
|
|
template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl <N_R0_IQ4_XS>>>;
|
|
|
|
|
|
|
|
|
|
kernel void kernel_pool_2d_max_f32(
|
|
|
|
|
constant ggml_metal_kargs_pool_2d & args,
|
|
|
|
|
|