mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
metal : refactor kernel loading (#15964)
* metal : refactor bin kernels loading ggml-ci * metal : refactor rms kernel loading ggml-ci * ci : try to add memory leaks check ggml-ci * ci : try to enable memory leak detection for Mac * cont : seems to be working
This commit is contained in:
@@ -232,28 +232,6 @@ struct ggml_metal_kernel {
|
||||
@end
|
||||
|
||||
enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_ADD,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_FUSE_2,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_FUSE_3,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_FUSE_4,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_FUSE_5,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_FUSE_6,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_FUSE_7,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_FUSE_8,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8,
|
||||
GGML_METAL_KERNEL_TYPE_SUB,
|
||||
GGML_METAL_KERNEL_TYPE_SUB_ROW_C4,
|
||||
GGML_METAL_KERNEL_TYPE_MUL,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_ROW_C4,
|
||||
GGML_METAL_KERNEL_TYPE_DIV,
|
||||
GGML_METAL_KERNEL_TYPE_DIV_ROW_C4,
|
||||
GGML_METAL_KERNEL_TYPE_ADD_ID,
|
||||
GGML_METAL_KERNEL_TYPE_REPEAT_F32,
|
||||
GGML_METAL_KERNEL_TYPE_REPEAT_F16,
|
||||
@@ -319,9 +297,6 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
|
||||
GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
|
||||
GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
|
||||
GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL,
|
||||
GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD,
|
||||
GGML_METAL_KERNEL_TYPE_L2_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
||||
GGML_METAL_KERNEL_TYPE_NORM,
|
||||
@@ -1177,28 +1152,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
|
||||
// simd_sum and simd_max requires MTLGPUFamilyApple7
|
||||
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_2, add_fuse_2, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_3, add_fuse_3, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_4, add_fuse_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_5, add_fuse_5, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_6, add_fuse_6, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_7, add_fuse_7, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_FUSE_8, add_fuse_8, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4, add_row_c4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2, add_row_c4_fuse_2, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3, add_row_c4_fuse_3, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4, add_row_c4_fuse_4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5, add_row_c4_fuse_5, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6, add_row_c4_fuse_6, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7, add_row_c4_fuse_7, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8, add_row_c4_fuse_8, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB, sub, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUB_ROW_C4, sub_row_c4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW_C4, mul_row_c4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW_C4, div_row_c4, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ID, add_id, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F32, repeat_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REPEAT_F16, repeat_f16, true);
|
||||
@@ -1264,9 +1217,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL, rms_norm_mul, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD, rms_norm_mul_add, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
||||
@@ -1722,6 +1672,73 @@ static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext_vec_re
|
||||
GGML_UNUSED(op);
|
||||
}
|
||||
|
||||
static id<MTLComputePipelineState> ggml_metal_get_pipeline_bin(
|
||||
ggml_backend_t backend, enum ggml_op op,
|
||||
int32_t n_fuse,
|
||||
bool row) {
|
||||
struct ggml_backend_metal_context * ctx = backend->context;
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
@autoreleasepool {
|
||||
const char * op_str = "undefined";
|
||||
switch (op) {
|
||||
case GGML_OP_ADD: op_str = "add"; break;
|
||||
case GGML_OP_SUB: op_str = "sub"; break;
|
||||
case GGML_OP_MUL: op_str = "mul"; break;
|
||||
case GGML_OP_DIV: op_str = "div"; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
|
||||
if (row) {
|
||||
snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
|
||||
} else {
|
||||
snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
|
||||
}
|
||||
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
|
||||
if (res) {
|
||||
// kernel found
|
||||
return res;
|
||||
}
|
||||
|
||||
return ggml_metal_compile_kernel(backend, base, name, nil);
|
||||
}
|
||||
}
|
||||
|
||||
static id<MTLComputePipelineState> ggml_metal_get_pipeline_rms_norm(
|
||||
ggml_backend_t backend, struct ggml_tensor * op,
|
||||
int32_t n_fuse) {
|
||||
struct ggml_backend_metal_context * ctx = backend->context;
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
@autoreleasepool {
|
||||
switch (n_fuse) {
|
||||
case 1: snprintf(base, 256, "kernel_rms_norm"); break;
|
||||
case 2: snprintf(base, 256, "kernel_rms_norm_mul"); break;
|
||||
case 3: snprintf(base, 256, "kernel_rms_norm_mul_add"); break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
id<MTLComputePipelineState> res = ggml_metal_get_kernel(ctx, name);
|
||||
if (res) {
|
||||
// kernel found
|
||||
return res;
|
||||
}
|
||||
|
||||
return ggml_metal_compile_kernel(backend, base, name, nil);
|
||||
}
|
||||
|
||||
GGML_UNUSED(op);
|
||||
}
|
||||
|
||||
static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
|
||||
GGML_LOG_INFO("%s: deallocating\n", __func__);
|
||||
|
||||
@@ -2359,8 +2376,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
|
||||
|
||||
bool bcast_row = false;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
ggml_metal_kargs_bin args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
@@ -2441,55 +2456,19 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
|
||||
}
|
||||
}
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
|
||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||
|
||||
// src1 is a row
|
||||
GGML_ASSERT(ne11 == 1);
|
||||
|
||||
switch (dst->op) {
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
switch (n_fuse) {
|
||||
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4 ].pipeline; break;
|
||||
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_2].pipeline; break;
|
||||
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_3].pipeline; break;
|
||||
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_4].pipeline; break;
|
||||
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_5].pipeline; break;
|
||||
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_6].pipeline; break;
|
||||
case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_7].pipeline; break;
|
||||
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW_C4_FUSE_8].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW_C4].pipeline; break;
|
||||
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_ROW_C4].pipeline; break;
|
||||
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV_ROW_C4].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
}
|
||||
pipeline = ggml_metal_get_pipeline_bin(backend, dst->op, n_fuse, true);
|
||||
|
||||
bcast_row = true;
|
||||
} else {
|
||||
switch (dst->op) {
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
switch (n_fuse) {
|
||||
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD ].pipeline; break;
|
||||
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_2].pipeline; break;
|
||||
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_3].pipeline; break;
|
||||
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_4].pipeline; break;
|
||||
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_5].pipeline; break;
|
||||
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_6].pipeline; break;
|
||||
case 7: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_7].pipeline; break;
|
||||
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_FUSE_8].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB].pipeline; break;
|
||||
case GGML_OP_MUL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL].pipeline; break;
|
||||
case GGML_OP_DIV: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIV].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
}
|
||||
pipeline = ggml_metal_get_pipeline_bin(backend, dst->op, n_fuse, false);
|
||||
}
|
||||
|
||||
if (n_fuse > 1) {
|
||||
@@ -2650,8 +2629,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
|
||||
ggml_metal_encode_concurrency_reset(ctx_enc);
|
||||
}
|
||||
|
||||
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
|
||||
|
||||
ggml_metal_kargs_bin args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
@@ -2681,6 +2658,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
|
||||
/*.o1 =*/ { offs_src1},
|
||||
};
|
||||
|
||||
//const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
|
||||
const id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_bin(backend, GGML_OP_ADD, 1, false);
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
@@ -4659,14 +4639,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
|
||||
}
|
||||
}
|
||||
|
||||
id<MTLComputePipelineState> pipeline;
|
||||
|
||||
switch (n_fuse) {
|
||||
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM ].pipeline; break;
|
||||
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL ].pipeline; break;
|
||||
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM_MUL_ADD].pipeline; break;
|
||||
default: GGML_ABORT("unsupported n_fuse = %d\n", n_fuse);
|
||||
}
|
||||
const id<MTLComputePipelineState> pipeline = ggml_metal_get_pipeline_rms_norm(backend, node, n_fuse);
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
|
||||
@@ -928,7 +928,7 @@ kernel void kernel_add_fuse_impl(
|
||||
|
||||
typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t;
|
||||
|
||||
template [[host_name("kernel_add")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
|
||||
template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>;
|
||||
template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>;
|
||||
template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>;
|
||||
template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>;
|
||||
@@ -937,7 +937,7 @@ template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_
|
||||
template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>;
|
||||
template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>;
|
||||
|
||||
kernel void kernel_sub(
|
||||
kernel void kernel_sub_fuse_1(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
@@ -963,7 +963,7 @@ kernel void kernel_sub(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mul(
|
||||
kernel void kernel_mul_fuse_1(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
@@ -996,7 +996,7 @@ kernel void kernel_mul(
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_div(
|
||||
kernel void kernel_div_fuse_1(
|
||||
constant ggml_metal_kargs_bin & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
@@ -1096,23 +1096,17 @@ kernel void kernel_add_row_c4_fuse_impl(
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
|
||||
const uint nb = args.ne00/4;
|
||||
const uint i = tpig % nb;
|
||||
|
||||
device const float4 * src0_row = (device const float4 *) (src0);
|
||||
device float4 * dst_row = (device float4 *) (dst);
|
||||
|
||||
device const float4 * src1_row[F];
|
||||
for (short j = 0; j < F; ++j) {
|
||||
src1_row[j] = (device const float4 *) (src1 + args.o1[j]);
|
||||
}
|
||||
|
||||
float4 res = src0_row[tpig];
|
||||
|
||||
#pragma unroll(F)
|
||||
for (short j = 0; j < F; ++j) {
|
||||
res += src1_row[j][i];
|
||||
res += ((device const float4 *) (src1 + args.o1[j]))[i];
|
||||
}
|
||||
|
||||
dst_row[tpig] = res;
|
||||
@@ -1120,7 +1114,7 @@ kernel void kernel_add_row_c4_fuse_impl(
|
||||
|
||||
typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t;
|
||||
|
||||
template [[host_name("kernel_add_row_c4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
|
||||
template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>;
|
||||
template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>;
|
||||
template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>;
|
||||
template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>;
|
||||
@@ -1160,7 +1154,7 @@ kernel void kernel_sub_row_c4_fuse_impl(
|
||||
|
||||
typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t;
|
||||
|
||||
template [[host_name("kernel_sub_row_c4")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
|
||||
template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>;
|
||||
|
||||
template <short F>
|
||||
kernel void kernel_mul_row_c4_fuse_impl(
|
||||
@@ -1193,7 +1187,7 @@ kernel void kernel_mul_row_c4_fuse_impl(
|
||||
|
||||
typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t;
|
||||
|
||||
template [[host_name("kernel_mul_row_c4")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
|
||||
template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>;
|
||||
|
||||
template <short F>
|
||||
kernel void kernel_div_row_c4_fuse_impl(
|
||||
@@ -1226,7 +1220,7 @@ kernel void kernel_div_row_c4_fuse_impl(
|
||||
|
||||
typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t;
|
||||
|
||||
template [[host_name("kernel_div_row_c4")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
|
||||
template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>;
|
||||
|
||||
kernel void kernel_scale(
|
||||
device const float * src0,
|
||||
|
||||
Reference in New Issue
Block a user