metal : fuse non-sequential nodes (#16102)

* metal : fuse non-sequential nodes

* cont : add comment

* cont : simplify bounds checks
This commit is contained in:
Georgi Gerganov
2025-09-28 09:34:05 +03:00
committed by GitHub
parent 1384abf8b8
commit 3b53634fe3
3 changed files with 161 additions and 135 deletions

View File

@@ -567,13 +567,13 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
ctx->debug_graph,
ctx->debug_fusion);
for (int idx = idx_start; idx < idx_end;) {
for (int idx = 0; idx < ggml_metal_op_n_nodes(ctx_op); ++idx) {
const int res = ggml_metal_op_encode(ctx_op, idx);
if (res == 0) {
break;
}
idx += res;
idx += res - 1;
}
ggml_metal_op_free(ctx_op);

View File

@@ -24,22 +24,88 @@ static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
}
// Encoding context for a [idx_start, idx_end) slice of a ggml graph on Metal.
// Owns the command encoder and the concurrency mem-range tracker for the slice,
// and keeps a pre-filtered list of the slice's non-empty node indices.
struct ggml_metal_op {
// Build the op context: grab the device library, start an encoder on the
// given command buffer, and pre-filter the node indices of the slice.
ggml_metal_op(
ggml_metal_device_t dev,
ggml_metal_cmd_buf_t cmd_buf,
ggml_cgraph * gf,
int idx_start,
int idx_end,
bool use_fusion,
bool use_concurrency,
bool use_capture,
int debug_graph,
int debug_fusion) {
this->dev = dev;
this->lib = ggml_metal_device_get_library(dev);
this->enc = ggml_metal_encoder_init(cmd_buf, use_concurrency);
this->mem_ranges = ggml_mem_ranges_init(debug_graph);
this->idx_start = idx_start;
this->idx_end = idx_end;
this->use_fusion = use_fusion;
this->use_concurrency = use_concurrency;
this->use_capture = use_capture;
this->debug_graph = debug_graph;
this->debug_fusion = debug_fusion;
this->gf = gf;
idxs.reserve(gf->n_nodes);
// filter empty nodes
// TODO: this can be removed when the allocator starts filtering them earlier
// https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830
for (int i = idx_start; i < idx_end; i++) {
if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) {
idxs.push_back(i);
}
}
}
// End and release the encoder and the mem-range tracker acquired in the
// constructor. Callers must not call ggml_metal_encoder_end_encoding /
// _free on this->enc themselves — the destructor owns that cleanup.
~ggml_metal_op() {
ggml_metal_encoder_end_encoding(this->enc);
ggml_metal_encoder_free(this->enc);
ggml_mem_ranges_free(this->mem_ranges);
}
// Number of non-empty nodes in the slice (size of the filtered index list).
int n_nodes() const {
return idxs.size();
}
// i-th non-empty node of the slice; i indexes into idxs, not the raw graph.
ggml_tensor * node(int i) const {
assert(i >= 0 && i < (int) idxs.size());
return ggml_graph_node(gf, idxs[i]);
}
// Check whether the n_ops filtered nodes starting at i0 can be fused.
// Bounds are checked against the filtered list; the actual fusion check is
// delegated to ggml_can_fuse_ext with the raw graph indices, which is what
// allows fusing across skipped (empty) nodes.
bool can_fuse(int i0, const ggml_op * ops, int n_ops) const {
assert(use_fusion);
assert(i0 >= 0 && i0 < n_nodes());
if (i0 + n_ops > n_nodes()) {
return false;
}
return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops);
}
ggml_metal_device_t dev;
ggml_metal_library_t lib;
ggml_metal_encoder_t enc;
ggml_mem_ranges_t mem_ranges;
// NOTE(review): gf/idx_start/idx_end are declared both here and in the
// private section below — this looks like a diff/merge artifact; one of
// the two sets should be removed (the private set matches the accessors).
ggml_cgraph * gf;
int idx_start;
int idx_end;
bool use_fusion;
bool use_concurrency;
bool use_capture;
int debug_graph;
int debug_fusion;
private:
ggml_cgraph * gf;
int idx_start;
int idx_end;
// non-empty node indices
std::vector<int> idxs;
};
ggml_metal_op_t ggml_metal_op_init(
@@ -53,34 +119,29 @@ ggml_metal_op_t ggml_metal_op_init(
bool use_capture,
int debug_graph,
int debug_fusion) {
ggml_metal_op_t res = new ggml_metal_op();
*res = {
/*.dev =*/ dev,
/*.lib =*/ ggml_metal_device_get_library(dev),
/*.enc =*/ ggml_metal_encoder_init(cmd_buf, use_concurrency),
/*.mem_ranges =*/ ggml_mem_ranges_init(debug_graph),
/*.gf =*/ gf,
/*.idx_start =*/ idx_start,
/*.idx_end =*/ idx_end,
/*.use_fusion =*/ use_fusion,
/*.use_concurrency =*/ use_concurrency,
/*.use_capture =*/ use_capture,
/*.debug_graph =*/ debug_graph,
/*.debug_fusion =*/ debug_fusion,
};
ggml_metal_op_t res = new ggml_metal_op(
dev,
cmd_buf,
gf,
idx_start,
idx_end,
use_fusion,
use_concurrency,
use_capture,
debug_graph,
debug_fusion);
return res;
}
// Destroy an op context created by ggml_metal_op_init.
// The ~ggml_metal_op destructor already ends the encoder, frees it, and
// releases the mem-range tracker, so repeating those calls here would
// end/free the encoder twice — delete alone runs the destructor and is
// the complete cleanup.
void ggml_metal_op_free(ggml_metal_op_t ctx) {
delete ctx;
}
// C-style accessor: number of non-empty nodes tracked by the op context.
int ggml_metal_op_n_nodes(ggml_metal_op_t ctx) {
const int n_nodes = ctx->n_nodes();
return n_nodes;
}
static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) {
if (!ctx->mem_ranges) {
return true;
@@ -110,10 +171,7 @@ static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor
}
static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
struct ggml_cgraph * gf = ctx->gf;
struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
struct ggml_tensor * node = nodes[0];
struct ggml_tensor * node = ctx->node(idx);
//GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
@@ -129,6 +187,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
case GGML_OP_PERMUTE:
{
// noop -> next node
if (ctx->debug_graph > 0) {
GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), "(noop)");
}
} return 1;
default:
{
@@ -352,7 +413,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
// update the mem ranges in the encoding context
for (int i = 0; i < n_fuse; ++i) {
if (!ggml_metal_op_concurrency_add(ctx, nodes[i])) {
if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) {
ggml_metal_op_concurrency_reset(ctx);
}
}
@@ -362,11 +423,11 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) {
if (ctx->use_capture) {
ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ggml_graph_node(ctx->gf, idx)));
ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx)));
}
int res = ggml_metal_op_encode_impl(ctx, idx);
if (idx + res > ctx->idx_end) {
if (idx + res > ctx->n_nodes()) {
GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
"https://github.com/ggml-org/llama.cpp/pull/14849");
}
@@ -379,8 +440,7 @@ int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -438,8 +498,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -483,8 +542,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -594,8 +652,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -634,8 +691,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -674,8 +730,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -703,8 +758,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -774,8 +828,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -838,8 +891,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -876,8 +928,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -939,8 +990,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -1030,8 +1080,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -1076,8 +1125,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -1170,8 +1218,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -1212,8 +1259,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -1286,8 +1332,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -1347,8 +1392,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -1589,8 +1633,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) {
}
int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -1783,8 +1826,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -1856,8 +1898,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
}
int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -2176,16 +2217,11 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const int idx_end = ctx->idx_end;
const bool use_fusion = ctx->use_fusion;
const int debug_fusion = ctx->debug_fusion;
@@ -2258,22 +2294,25 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
// note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops
// across splits. idx_end indicates the last node in the current split
for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
break;
}
if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
ggml_tensor * f0 = ctx->node(idx + n_fuse);
ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
if (f0 != f1->src[0]) {
break;
}
// b[0] === b[1] === ...
if (!ggml_are_same_layout(ops[n_fuse]->src[1], ops[n_fuse + 1]->src[1])) {
if (!ggml_are_same_layout(f0->src[1], f1->src[1])) {
break;
}
// only fuse ops if src1 is in the same Metal buffer
ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]);
if (bid_fuse.metal != bid_src1.metal) {
break;
}
@@ -2309,10 +2348,10 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
}
if (n_fuse > 1) {
bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
for (int i = 1; i < n_fuse; ++i) {
if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
ggml_metal_op_concurrency_reset(ctx);
break;
@@ -2344,8 +2383,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -2393,8 +2431,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -2445,20 +2482,15 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
const int idx_end = ctx->idx_end;
const bool use_fusion = ctx->use_fusion;
const int debug_fusion = ctx->debug_fusion;
ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
@@ -2499,38 +2531,41 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
fops[1] = GGML_OP_MUL;
fops[2] = GGML_OP_ADD;
for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
break;
}
if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
ggml_tensor * f0 = ctx->node(idx + n_fuse);
ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
if (f0 != f1->src[0]) {
break;
}
if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) {
if (f1->src[1]->ne[0] != op->ne[0]) {
break;
}
if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) {
if (!ggml_is_contiguous_rows(f1->src[1])) {
break;
}
if (ops[n_fuse + 1]->type != GGML_TYPE_F32) {
if (f1->type != GGML_TYPE_F32) {
break;
}
//ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
//ctx->fuse_cnt[f1->op]++;
bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]);
args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1];
args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2];
args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3];
args.nef1[n_fuse + 1] = f1->src[1]->ne[1];
args.nef2[n_fuse + 1] = f1->src[1]->ne[2];
args.nef3[n_fuse + 1] = f1->src[1]->ne[3];
args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1];
args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2];
args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3];
args.nbf1[n_fuse + 1] = f1->src[1]->nb[1];
args.nbf2[n_fuse + 1] = f1->src[1]->nb[2];
args.nbf3[n_fuse + 1] = f1->src[1]->nb[3];
}
++n_fuse;
@@ -2546,10 +2581,10 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
}
if (n_fuse > 1) {
bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
for (int i = 1; i < n_fuse; ++i) {
if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
ggml_metal_op_concurrency_reset(ctx);
break;
@@ -2585,8 +2620,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -2681,8 +2715,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -2752,8 +2785,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -2798,8 +2830,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -2852,8 +2883,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -2897,8 +2927,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -2944,8 +2973,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -2985,8 +3013,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -3020,8 +3047,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -3060,8 +3086,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
@@ -3103,8 +3128,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
}
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
ggml_cgraph * gf = ctx->gf;
ggml_tensor * op = ggml_graph_node(gf, idx);
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;

View File

@@ -22,6 +22,8 @@ ggml_metal_op_t ggml_metal_op_init(
void ggml_metal_op_free(ggml_metal_op_t ctx);
int ggml_metal_op_n_nodes(ggml_metal_op_t ctx);
int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx);
//