metal : improve F32, F16 and BF16 mat-vec multiplication
ggml-ci
@@ -34,6 +34,10 @@ ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
 }
 
 void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {
     if (!ppls) {
         return;
     }
+
+    for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) {
+        ggml_metal_pipeline_free(it->second);
+    }
@@ -467,37 +471,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
     // use custom matrix x vector kernel
     switch (tsrc0) {
         case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
-
-                nsg = 1;
-                nr0 = 1;
-                nr1 = 4;
-                if (ne00 == 4) {
-                    nr0 = 32;
-                    suffix = "_c4";
-                }
-            } break;
         case GGML_TYPE_F16:
         case GGML_TYPE_BF16:
             {
-                nsg = 1;
-                nr0 = 1;
-                if (op->src[1]->type == GGML_TYPE_F32) {
-                    if (ne00 == 4) {
-                        nr0 = 32;
-                        nr1 = 4;
-                        suffix = "_c4";
-                    } else if (ne11 * ne12 < 4) {
-                        suffix = "_1row";
-                    } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                        suffix = "_l4";
-                        nr1 = ne11;
-                    } else {
-                        nr1 = 4;
-                    }
-                } else {
+                if (ne00 == 4) {
+                    nsg = 1;
+                    nr0 = 32;
+                    nr1 = 4;
+                    suffix = "_c4";
+                } else if (ne00 % 4 == 0) {
+                    nsg = N_SG_F;
+                    nr0 = N_R0_F;
+                    nr1 = 1;
+                    smem = 32*sizeof(float)*N_R0_F;
+                    suffix = "_4";
+                } else {
+                    nsg = N_SG_F;
+                    nr0 = N_R0_F;
+                    nr1 = 1;
+                    smem = 32*sizeof(float)*N_R0_F;
                 }
             } break;
         case GGML_TYPE_Q4_0:
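The per-type special cases above collapse into one rule for all three float types: ne00 == 4 keeps the tiny _c4 kernel, any other row size that is a multiple of 4 takes the new vectorized _4 path, and the remaining sizes fall back to the scalar variant of the same kernel; the old _1row and _l4 heuristics go away entirely (their kernels are deleted further down). A compact restatement of that decision as a hypothetical helper, for illustration only:

    // not part of the diff: mirrors the F32/F16/BF16 branch of the switch above
    const char * pick_suffix(int ne00) {
        if (ne00 == 4)     return "_c4"; // the whole row is a single 4-vector
        if (ne00 % 4 == 0) return "_4";  // vectorized 4-wide loads
        return "";                       // scalar loads
    }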
@@ -689,25 +681,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
     const ggml_type tsrc0 = op->src[0]->type;
     const ggml_type tsrc1 = op->src[1]->type;
 
+    const char * suffix = "";
+
     // use custom matrix x vector kernel
     switch (tsrc0) {
         case GGML_TYPE_F32:
-            {
-                GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
-                nsg = 1;
-                nr0 = 1;
-            } break;
         case GGML_TYPE_F16:
-            {
-                GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
-                nsg = 1;
-                nr0 = 1;
-            } break;
         case GGML_TYPE_BF16:
             {
                 GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
-                nsg = 1;
-                nr0 = 1;
+
+                if (ne00 % 4 == 0) {
+                    nsg = N_SG_F;
+                    nr0 = N_R0_F;
+                    nr1 = 1;
+                    smem = 32*sizeof(float)*N_R0_F;
+                    suffix = "_4";
+                } else {
+                    nsg = N_SG_F;
+                    nr0 = N_R0_F;
+                    nr1 = 1;
+                    smem = 32*sizeof(float)*N_R0_F;
+                }
             } break;
         case GGML_TYPE_Q4_0:
             {
@@ -824,7 +817,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
             }
     };
 
-    snprintf(base, 256, "kernel_mul_mv_id_%s_%s",   ggml_type_name(tsrc0), ggml_type_name(tsrc1));
+    snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
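Together with the suffix introduced above, the name lookup now resolves, for example, BF16 x F32 with ne00 % 4 == 0 to kernel_mul_mv_id_bf16_f32_4 (instantiated near the end of this diff). A standalone sketch of that name construction, mirroring the snprintf above:

    #include <cstdio>

    int main() {
        char base[256];
        const char * tsrc0  = "bf16"; // stand-ins for the ggml_type_name(...) results
        const char * tsrc1  = "f32";
        const char * suffix = "_4";   // chosen when ne00 % 4 == 0
        std::snprintf(base, sizeof(base), "kernel_mul_mv_id_%s_%s%s", tsrc0, tsrc1, suffix);
        std::printf("%s\n", base);    // prints: kernel_mul_mv_id_bf16_f32_4
    }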
@@ -8,6 +8,9 @@
 //
 // TODO: for optimal performance, become function of the device and work size
 
+#define N_R0_F 2
+#define N_SG_F 4
+
 #define N_R0_Q4_0 4
 #define N_SG_Q4_0 2
 
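N_R0_F is the number of src0 rows each threadgroup of the new F-type kernels produces, and N_SG_F is the number of simdgroups that cooperate on those rows. The threadgroup memory set in the pipeline selection above (smem = 32*sizeof(float)*N_R0_F) reserves one float per lane of a 32-wide simdgroup for each row, which these defaults put at 256 bytes; a quick sanity check (the underscored names are local stand-ins, not from the diff):

    constexpr int N_R0_F_ = 2;
    constexpr int smem_   = 32*sizeof(float)*N_R0_F_;
    static_assert(smem_ == 256, "threadgroup memory per threadgroup");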
@@ -1564,7 +1564,10 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-    if (op->src[0]->type == GGML_TYPE_Q8_0) {
+    if (op->src[0]->type == GGML_TYPE_F32 ||
+        op->src[0]->type == GGML_TYPE_F16 ||
+        op->src[0]->type == GGML_TYPE_BF16 ||
+        op->src[0]->type == GGML_TYPE_Q8_0) {
         ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0 - 1)/(nr0)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
     } else {
         ggml_metal_encoder_dispatch_threadgroups(enc, ((ne01 + nr0*nsg - 1)/(nr0*nsg)), ((ne11 + nr1 - 1)/nr1), ne12*ne13, 32, nsg, 1);
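F32, F16 and BF16 now join Q8_0 on the first dispatch path: all nsg simdgroups of a threadgroup cooperate on the same nr0 rows and split the reduction dimension between them, so the grid divides ne01 only by nr0; the remaining quantized kernels keep nr0 rows per simdgroup and divide by nr0*nsg. A small standalone sketch of the two formulas, with illustrative values that are not taken from the diff:

    #include <cstdio>

    int main() {
        const int ne01 = 4096;       // number of src0 rows
        const int nr0  = 2, nsg = 4; // mirroring N_R0_F and N_SG_F

        const int tg_f = (ne01 + nr0 - 1)/nr0;           // float/Q8_0 path: 2048 groups
        const int tg_q = (ne01 + nr0*nsg - 1)/(nr0*nsg); // other quantized path: 512 groups

        std::printf("%d %d\n", tg_f, tg_q);
    }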
@@ -1772,7 +1775,10 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-    if (op->src[0]->type == GGML_TYPE_Q8_0) {
+    if (op->src[0]->type == GGML_TYPE_F32 ||
+        op->src[0]->type == GGML_TYPE_F16 ||
+        op->src[0]->type == GGML_TYPE_BF16 ||
+        op->src[0]->type == GGML_TYPE_Q8_0) {
         ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0 - 1)/(nr0), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
     } else {
         ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nr0*nsg - 1)/(nr0*nsg), (_ne1 + nr1 - 1)/nr1, ne123, 32, nsg, 1);
@@ -3404,104 +3404,211 @@ template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_3")]] kernel mul_mv_ext_q4x4
 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_4")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<4, block_q6_K, 256, dequantize_q6_K>;
 template [[host_name("kernel_mul_mv_ext_q6_K_f32_r1_5")]] kernel mul_mv_ext_q4x4_f32_t kernel_mul_mv_ext_q4x4_f32_disp<5, block_q6_K, 256, dequantize_q6_K>;
 
-#define N_MV_T_T 4
-
-template<typename T0, typename T04, typename T1, typename T14, typename args_t>
-void kernel_mul_mv_impl(
+template<typename T0, typename T1, short NR0, short NSG, short NW, typename args_t>
+void kernel_mul_mv_t_t_impl(
         args_t args,
         device const char * src0,
         device const char * src1,
         device       char * dst,
+        threadgroup  char * shmem,
         uint3  tgpig,
-        ushort tiisg) {
-    const int r0 = tgpig.x;
-    const int rb = tgpig.y*N_MV_T_T;
+        ushort tiisg,
+        ushort sgitg) {
+    constexpr short NB = 32;
+    constexpr short NF = 8;
+
+    const int nb = args.ne00/NB;
+
+    const int r0 = tgpig.x*NR0;
+    const int r1 = tgpig.y;
     const int im = tgpig.z;
 
     const uint i12 = im%args.ne12;
     const uint i13 = im/args.ne12;
 
-    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
 
-    device const T0 * x = (device const T0 *) (src0 + offset0);
+    //device const T0 * x = (device const T0 *) (src0 + offset0);
+    device const T1 * y = (device const T1 *) (src1 + offset1);
 
-    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+    // pointers to src0 rows
+    device const T0 * ax [NR0];
+    FOR_UNROLL (short row = 0; row < NR0; ++row) {
+        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
 
-    if (args.ne00 < 128) {
-        for (int row = 0; row < N_MV_T_T; ++row) {
-            int r1 = rb + row;
-            if (r1 >= args.ne11) {
-                break;
-            }
+        ax[row] = (device const T0 *) ((device char *) src0 + offset0);
+    }
 
-            const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+    float sumf[NR0] = { 0.f };
 
-            device const T1 * y = (device const T1 *) (src1 + offset1);
+    const short ix = tiisg/(NW/NF);
+    const short il = tiisg%(NW/NF);
 
-            float sumf = 0;
-            for (int i = tiisg; i < args.ne00; i += 32) {
-                sumf += (T0) x[i] * (T1) y[i];
-            }
+    const int ib0 = sgitg*NF + ix;
 
-            float sum_all = simd_sum(sumf);
-            if (tiisg == 0) {
-                dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
-            }
-        }
-    } else {
-        device const T04 * x4 = (device const T04 *) x;
-        for (int row = 0; row < N_MV_T_T; ++row) {
-            int r1 = rb + row;
-            if (r1 >= args.ne11) {
-                break;
-            }
+    T1 yl[NF];
 
-            const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+    device const T1 * yb = y + (ib0*NB + il*NF);
 
-            device const T1  * y  = (device const T1  *) (src1 + offset1);
-            device const T14 * y4 = (device const T14 *) y;
+    for (int ib = ib0; ib < nb; ib += NSG*NF) {
+        for (short i = 0; i < NF; ++i) {
+            yl[i] = yb[i];
+        }
 
-            float sumf = 0;
-            for (int i = tiisg; i < args.ne00/4; i += 32) {
-                sumf += dot((float4) x4[i], (float4) y4[i]);
-            }
+        for (short row = 0; row < NR0; row++) {
+            device const T0 * xb = ax[row] + (ib*NB + il*NF);
 
-            float sum_all = simd_sum(sumf);
-            if (tiisg == 0) {
-                for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
-                dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
-            }
-        }
-    }
+            float sumq = 0.f;
+            FOR_UNROLL (short i = 0; i < NF; ++i) {
+                sumq += xb[i] * yl[i];
+            }
+
+            sumf[row] += sumq;
+        }
+
+        yb += NSG*NF*NW;
+    }
+
+    for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
+        for (short row = 0; row < NR0; row++) {
+            sumf[row] += ax[row][i] * y[i];
+        }
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
 }
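The rewritten kernel gives each threadgroup NR0 rows and lets the NSG simdgroups split each row into 32-wide blocks: tiisg decomposes into a block offset ix and a chunk index il, so every thread loads NF = 8 consecutive src1 values into yl once and reuses them against all NR0 rows, while the scalar tail loop picks up row lengths that are not a multiple of NB. A hypothetical host-side re-creation of that thread-to-element mapping, checking that one pass covers each element of a row exactly once (sizes are illustrative):

    #include <cassert>
    #include <vector>

    int main() {
        const int NW = 32, NF = 8, NSG = 4; // simd width, fetch width, simdgroups
        const int NB = 32, ne00 = 1024;     // block size, row length
        const int nb = ne00/NB;

        std::vector<int> hits(ne00, 0);
        for (int sgitg = 0; sgitg < NSG; ++sgitg) {
            for (int tiisg = 0; tiisg < NW; ++tiisg) {
                const int ix = tiisg/(NW/NF); // starting block within the stripe
                const int il = tiisg%(NW/NF); // 8-element chunk inside a block
                for (int ib = sgitg*NF + ix; ib < nb; ib += NSG*NF) {
                    for (int i = 0; i < NF; ++i) {
                        hits[ib*NB + il*NF + i]++;
                    }
                }
            }
        }
        for (int v : hits) assert(v == 1);  // nothing skipped, nothing read twice
        return 0;
    }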
-template<typename T0, typename T04, typename T1, typename T14>
-kernel void kernel_mul_mv(
+template<typename T0, typename T1, short NR0, short NSG, short NW>
+kernel void kernel_mul_mv_t_t(
         constant ggml_metal_kargs_mul_mv & args,
         device const char * src0,
         device const char * src1,
         device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]]) {
-    kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
-        args,
-        src0,
-        src1,
-        dst,
-        tgpig,
-        tiisg);
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_t_t_impl<T0, T1, NR0, NSG, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
+typedef decltype(kernel_mul_mv_t_t<half, half, N_R0_F, N_SG_F, N_SIMDWIDTH>) mul_mv_t_t;
 
-template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t kernel_mul_mv<float,  float4,  float,  float4>;
-template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   float,  float4>;
-template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t kernel_mul_mv<half,   half4,   half,   half4>;
+template [[host_name("kernel_mul_mv_f32_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<float,  float,  N_R0_F, N_SG_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_f16_f32")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,   float,  N_R0_F, N_SG_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_f16_f16")]]   kernel mul_mv_t_t kernel_mul_mv_t_t<half,   half,   N_R0_F, N_SG_F, N_SIMDWIDTH>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, float,  float4>;
-template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<bfloat, bfloat4, bfloat, bfloat4>;
+template [[host_name("kernel_mul_mv_bf16_f32")]]  kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, float,  N_R0_F, N_SG_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t<bfloat, bfloat, N_R0_F, N_SG_F, N_SIMDWIDTH>;
 #endif
 
+template<typename T0, typename T04, typename T1, typename T14, short NR0, short NSG, short NW, typename args_t>
+void kernel_mul_mv_t_t_4_impl(
+        args_t args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem,
+        uint3  tgpig,
+        ushort tiisg,
+        ushort sgitg) {
+    constexpr short NB = 32;
+    constexpr short NF = 16;
+    constexpr short NF4 = NF/4;
+
+    const int nb = args.ne00/NB;
+
+    const int r0 = tgpig.x*NR0;
+    const int r1 = tgpig.y;
+    const int im = tgpig.z;
+
+    const uint i12 = im%args.ne12;
+    const uint i13 = im/args.ne12;
+
+    //const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
+
+    device const T1  * y  = (device const T1  *) (src1 + offset1);
+    device const T14 * y4 = (device const T14 *) (src1 + offset1);
+
+    // pointers to src0 rows
+    device const T0  * ax [NR0];
+    device const T04 * ax4[NR0];
+    FOR_UNROLL (short row = 0; row < NR0; ++row) {
+        const uint64_t offset0 = (r0 + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+
+        ax [row] = (device const T0  *) ((device char *) src0 + offset0);
+        ax4[row] = (device const T04 *) ((device char *) src0 + offset0);
+    }
+
+    float sumf[NR0] = { 0.f };
+
+    const short ix = tiisg/(NW/NF);
+    const short il = tiisg%(NW/NF);
+
+    const int ib0 = sgitg*NF + ix;
+
+    T14 yl4[NF4];
+
+    device const T14 * yb4 = y4 + (ib0*NB + il*NF)/4;
+
+    for (int ib = ib0; ib < nb; ib += NSG*NF) {
+        for (short i = 0; i < NF4; ++i) {
+            yl4[i] = yb4[i];
+        }
+
+        for (short row = 0; row < NR0; row++) {
+            device const T04 * xb4 = ax4[row] + (ib*NB + il*NF)/4;
+
+            float sumq = 0.f;
+            FOR_UNROLL (short i = 0; i < NF4; ++i) {
+                sumq += dot(float4(xb4[i]), float4(yl4[i]));
+            }
+
+            sumf[row] += sumq;
+        }
+
+        yb4 += NSG*NF*NW/4;
+    }
+
+    for (int i = nb*NB + sgitg*NW + tiisg; i < args.ne00; i += NW*NSG) {
+        for (short row = 0; row < NR0; row++) {
+            sumf[row] += ax[row][i] * y[i];
+        }
+    }
+
+    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+    helper_mv_reduce_and_write<NR0, NW>(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem);
+}
+
+template<typename T0, typename T04, typename T1, typename T14, short NR0, short NSG, short NW>
+kernel void kernel_mul_mv_t_t_4(
+        constant ggml_metal_kargs_mul_mv & args,
+        device const char * src0,
+        device const char * src1,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3  tgpig[[threadgroup_position_in_grid]],
+        ushort tiisg[[thread_index_in_simdgroup]],
+        ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+    kernel_mul_mv_t_t_4_impl<T0, T04, T1, T14, NR0, NSG, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+typedef decltype(kernel_mul_mv_t_t_4<half, half4, half, half4, N_R0_F, N_SG_F, N_SIMDWIDTH>) mul_mv_t_t_4;
+
+template [[host_name("kernel_mul_mv_f32_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<float,  float4,  float,  float4,  N_R0_F, N_SG_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_f16_f32_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,   half4,   float,  float4,  N_R0_F, N_SG_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_f16_f16_4")]]   kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<half,   half4,   half,   half4,   N_R0_F, N_SG_F, N_SIMDWIDTH>;
+#if defined(GGML_METAL_HAS_BF16)
+template [[host_name("kernel_mul_mv_bf16_f32_4")]]  kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, float,  float4,  N_R0_F, N_SG_F, N_SIMDWIDTH>;
+template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4<bfloat, bfloat4, bfloat, bfloat4, N_R0_F, N_SG_F, N_SIMDWIDTH>;
+#endif
+
+#define N_MV_T_T 4
+
 template<typename T04, typename T14, typename args_t>
 void kernel_mul_mv_c4_impl(
         args_t args,
@@ -3562,112 +3669,10 @@ typedef decltype(kernel_mul_mv_c4<half4, half4>) mul_mv_c4_t;
 
 template [[host_name("kernel_mul_mv_f32_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<float4,  float4>;
 template [[host_name("kernel_mul_mv_f16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4,   float4>;
 template [[host_name("kernel_mul_mv_f16_f16_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<half4,   half4>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
-#endif
-
-template<typename T, typename T4>
-kernel void kernel_mul_mv_1row(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]]) {
-
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    const int im = tgpig.z;
-
-    const uint i12 = im%args.ne12;
-    const uint i13 = im/args.ne12;
-
-    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
-    const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
-
-    device const T     * x = (device const T     *) (src0 + offset0);
-    device const float * y = (device const float *) (src1 + offset1);
-
-    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
-
-    float sumf = 0;
-    if (args.ne00 < 128) {
-        for (int i = tiisg; i < args.ne00; i += 32) {
-            sumf += (float) x[i] * (float) y[i];
-        }
-        float sum_all = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst_f32[r0] = sum_all;
-        }
-    } else {
-        device const T4     * x4 = (device const T4     *) x;
-        device const float4 * y4 = (device const float4 *) y;
-
-        for (int i = tiisg; i < args.ne00/4; i += 32) {
-            sumf += dot((float4) x4[i], y4[i]);
-        }
-
-        float sum_all = simd_sum(sumf);
-
-        if (tiisg == 0) {
-            for (int i = 4*(args.ne00/4); i < args.ne00; ++i) sum_all += (float) (x[i] * y[i]);
-            dst_f32[r0] = sum_all;
-        }
-    }
-}
-
-typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;
-
-template [[host_name("kernel_mul_mv_f16_f32_1row")]]  kernel mul_mv_1row_t kernel_mul_mv_1row<half,   half4>;
-#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<bfloat, bfloat4>;
-#endif
-
-// Assumes row size (ne00) is a multiple of 4
-template<typename T, typename T4>
-kernel void kernel_mul_mv_l4(
-        constant ggml_metal_kargs_mul_mv & args,
-        device const char * src0,
-        device const char * src1,
-        device       char * dst,
-        uint3  tgpig[[threadgroup_position_in_grid]],
-        ushort tiisg[[thread_index_in_simdgroup]]) {
-
-    const int nrows = args.ne11;
-    const int r0 = tgpig.x;
-    const int im = tgpig.z;
-
-    const uint i12 = im%args.ne12;
-    const uint i13 = im/args.ne12;
-
-    const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
-
-    device const T4 * x4 = (device const T4 *) (src0 + offset0);
-
-    device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
-
-    for (int r1 = 0; r1 < nrows; ++r1) {
-        const uint64_t offset1 = r1*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;
-
-        device const float4 * y4 = (device const float4 *) (src1 + offset1);
-
-        float sumf = 0;
-        for (int i = tiisg; i < args.ne00/4; i += 32) {
-            sumf += dot((float4) x4[i], y4[i]);
-        }
-
-        float sum_all = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst_f32[(uint64_t)r1*args.ne0 + r0] = sum_all;
-        }
-    }
-}
-
-typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
-
-template [[host_name("kernel_mul_mv_f16_f32_l4")]]  kernel mul_mv_l4_t kernel_mul_mv_l4<half,   half4>;
-#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_bf16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<bfloat, bfloat4>;
+template [[host_name("kernel_mul_mv_bf16_f32_c4")]]  kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, float4>;
+template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4<bfloat4, bfloat4>;
 #endif
 
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
@@ -8314,7 +8319,7 @@ void mmv_fn(
     impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
 }
 
-typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
+typedef decltype(mmv_fn<kernel_mul_mv_t_t_impl<half, half, N_R0_F, N_SG_F, N_SIMDWIDTH, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
 
 template<mul_mv_impl_fn_t impl_fn>
 kernel void kernel_mul_mv_id(
@@ -8379,13 +8384,21 @@ kernel void kernel_mul_mv_id(
         sgitg);
 }
 
-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SG_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_t;
 
-template [[host_name("kernel_mul_mv_id_f32_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
-template [[host_name("kernel_mul_mv_id_f16_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SG_F, N_SIMDWIDTH>>>) kernel_mul_mv_id_4_t;
+
+template [[host_name("kernel_mul_mv_id_f32_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<float, float, N_R0_F, N_SG_F, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<half, float, N_R0_F, N_SG_F, N_SIMDWIDTH>>>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_mul_mv_id_bf16_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<bfloat, bfloat4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_bf16_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_impl<bfloat, float, N_R0_F, N_SG_F, N_SIMDWIDTH>>>;
 #endif
+template [[host_name("kernel_mul_mv_id_f32_f32_4")]]  kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<float, float4, float, float4, N_R0_F, N_SG_F, N_SIMDWIDTH>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32_4")]]  kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<half, half4, float, float4, N_R0_F, N_SG_F, N_SIMDWIDTH>>>;
+#if defined(GGML_METAL_HAS_BF16)
+template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_t_t_4_impl<bfloat, bfloat4, float, float4, N_R0_F, N_SG_F, N_SIMDWIDTH>>>;
+#endif
 
 template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl<N_R0_Q8_0, N_SG_Q8_0, N_SIMDWIDTH>>>;
 
 template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_R0_Q4_0, N_SG_Q4_0, N_SIMDWIDTH>>>;