mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-28 08:31:25 +00:00
metal : improve F32, F16 and BF16 mat-vec multiplication (#16057)
* metal : improve F32, F16 and BF16 mat-vec multiplication ggml-ci * metal : make the NSG a function constant in mul_mv kernels ggml-ci
This commit is contained in:
@@ -34,6 +34,10 @@ ggml_metal_pipelines_t ggml_metal_pipelines_init(void) {
|
||||
}
|
||||
|
||||
void ggml_metal_pipelines_free(ggml_metal_pipelines_t ppls) {
|
||||
if (!ppls) {
|
||||
return;
|
||||
}
|
||||
|
||||
for (auto it = ppls->data.begin(); it != ppls->data.end(); ++it) {
|
||||
ggml_metal_pipeline_free(it->second);
|
||||
}
|
||||
@@ -467,37 +471,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
|
||||
// use custom matrix x vector kernel
|
||||
switch (tsrc0) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||
|
||||
nsg = 1;
|
||||
nr0 = 1;
|
||||
nr1 = 4;
|
||||
if (ne00 == 4) {
|
||||
nr0 = 32;
|
||||
suffix = "_c4";
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
nsg = 1;
|
||||
nr0 = 1;
|
||||
if (op->src[1]->type == GGML_TYPE_F32) {
|
||||
if (ne00 == 4) {
|
||||
nr0 = 32;
|
||||
nr1 = 4;
|
||||
suffix = "_c4";
|
||||
} else if (ne11 * ne12 < 4) {
|
||||
suffix = "_1row";
|
||||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||
suffix = "_l4";
|
||||
nr1 = ne11;
|
||||
} else {
|
||||
nr1 = 4;
|
||||
}
|
||||
} else {
|
||||
if (ne00 == 4) {
|
||||
nsg = 1;
|
||||
nr0 = 32;
|
||||
nr1 = 4;
|
||||
suffix = "_c4";
|
||||
} else if (ne00 % 4 == 0) {
|
||||
nsg = N_SG_F;
|
||||
nr0 = N_R0_F;
|
||||
nr1 = 1;
|
||||
smem = 32*sizeof(float)*N_R0_F;
|
||||
suffix = "_4";
|
||||
} else {
|
||||
nsg = N_SG_F;
|
||||
nr0 = N_R0_F;
|
||||
nr1 = 1;
|
||||
smem = 32*sizeof(float)*N_R0_F;
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
@@ -623,7 +615,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
ggml_metal_pipeline_set_nr0 (res, nr0);
|
||||
ggml_metal_pipeline_set_nr1 (res, nr1);
|
||||
@@ -689,25 +687,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
|
||||
const ggml_type tsrc0 = op->src[0]->type;
|
||||
const ggml_type tsrc1 = op->src[1]->type;
|
||||
|
||||
const char * suffix = "";
|
||||
|
||||
// use custom matrix x vector kernel
|
||||
switch (tsrc0) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||
nsg = 1;
|
||||
nr0 = 1;
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||
nsg = 1;
|
||||
nr0 = 1;
|
||||
} break;
|
||||
case GGML_TYPE_BF16:
|
||||
{
|
||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||
nsg = 1;
|
||||
nr0 = 1;
|
||||
if (ne00 % 4 == 0) {
|
||||
nsg = N_SG_F;
|
||||
nr0 = N_R0_F;
|
||||
nr1 = 1;
|
||||
smem = 32*sizeof(float)*N_R0_F;
|
||||
suffix = "_4";
|
||||
} else {
|
||||
nsg = N_SG_F;
|
||||
nr0 = N_R0_F;
|
||||
nr1 = 1;
|
||||
smem = 32*sizeof(float)*N_R0_F;
|
||||
}
|
||||
} break;
|
||||
case GGML_TYPE_Q4_0:
|
||||
{
|
||||
@@ -824,7 +823,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
|
||||
}
|
||||
};
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mv_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
|
||||
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
@@ -832,7 +831,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
ggml_metal_pipeline_set_nr0 (res, nr0);
|
||||
ggml_metal_pipeline_set_nr1 (res, nr1);
|
||||
|
||||
Reference in New Issue
Block a user