mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
metal : fuse non-sequential nodes (#16102)
* metal : fuse non-sequential nodes * cont : add comment * cont : simplify bounds checks
This commit is contained in:
@@ -567,13 +567,13 @@ void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
|
||||
ctx->debug_graph,
|
||||
ctx->debug_fusion);
|
||||
|
||||
for (int idx = idx_start; idx < idx_end;) {
|
||||
for (int idx = 0; idx < ggml_metal_op_n_nodes(ctx_op); ++idx) {
|
||||
const int res = ggml_metal_op_encode(ctx_op, idx);
|
||||
if (res == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
idx += res;
|
||||
idx += res - 1;
|
||||
}
|
||||
|
||||
ggml_metal_op_free(ctx_op);
|
||||
|
||||
@@ -24,22 +24,88 @@ static ggml_metal_buffer_id ggml_metal_get_buffer_id(const ggml_tensor * t) {
|
||||
}
|
||||
|
||||
struct ggml_metal_op {
|
||||
ggml_metal_op(
|
||||
ggml_metal_device_t dev,
|
||||
ggml_metal_cmd_buf_t cmd_buf,
|
||||
ggml_cgraph * gf,
|
||||
int idx_start,
|
||||
int idx_end,
|
||||
bool use_fusion,
|
||||
bool use_concurrency,
|
||||
bool use_capture,
|
||||
int debug_graph,
|
||||
int debug_fusion) {
|
||||
this->dev = dev;
|
||||
this->lib = ggml_metal_device_get_library(dev);
|
||||
this->enc = ggml_metal_encoder_init(cmd_buf, use_concurrency);
|
||||
this->mem_ranges = ggml_mem_ranges_init(debug_graph);
|
||||
this->idx_start = idx_start;
|
||||
this->idx_end = idx_end;
|
||||
this->use_fusion = use_fusion;
|
||||
this->use_concurrency = use_concurrency;
|
||||
this->use_capture = use_capture;
|
||||
this->debug_graph = debug_graph;
|
||||
this->debug_fusion = debug_fusion;
|
||||
this->gf = gf;
|
||||
|
||||
idxs.reserve(gf->n_nodes);
|
||||
|
||||
// filter empty nodes
|
||||
// TODO: this can be removed when the allocator starts filtering them earlier
|
||||
// https://github.com/ggml-org/llama.cpp/pull/16130#issuecomment-3327905830
|
||||
for (int i = idx_start; i < idx_end; i++) {
|
||||
if (!ggml_op_is_empty(gf->nodes[i]->op) && !ggml_is_empty(gf->nodes[i])) {
|
||||
idxs.push_back(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
~ggml_metal_op() {
|
||||
ggml_metal_encoder_end_encoding(this->enc);
|
||||
ggml_metal_encoder_free(this->enc);
|
||||
ggml_mem_ranges_free(this->mem_ranges);
|
||||
}
|
||||
|
||||
int n_nodes() const {
|
||||
return idxs.size();
|
||||
}
|
||||
|
||||
ggml_tensor * node(int i) const {
|
||||
assert(i >= 0 && i < (int) idxs.size());
|
||||
return ggml_graph_node(gf, idxs[i]);
|
||||
}
|
||||
|
||||
bool can_fuse(int i0, const ggml_op * ops, int n_ops) const {
|
||||
assert(use_fusion);
|
||||
assert(i0 >= 0 && i0 < n_nodes());
|
||||
|
||||
if (i0 + n_ops > n_nodes()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return ggml_can_fuse_ext(gf, idxs.data() + i0, ops, n_ops);
|
||||
}
|
||||
|
||||
ggml_metal_device_t dev;
|
||||
ggml_metal_library_t lib;
|
||||
ggml_metal_encoder_t enc;
|
||||
ggml_mem_ranges_t mem_ranges;
|
||||
|
||||
ggml_cgraph * gf;
|
||||
|
||||
int idx_start;
|
||||
int idx_end;
|
||||
|
||||
bool use_fusion;
|
||||
bool use_concurrency;
|
||||
bool use_capture;
|
||||
|
||||
int debug_graph;
|
||||
int debug_fusion;
|
||||
|
||||
private:
|
||||
ggml_cgraph * gf;
|
||||
|
||||
int idx_start;
|
||||
int idx_end;
|
||||
|
||||
// non-empty node indices
|
||||
std::vector<int> idxs;
|
||||
};
|
||||
|
||||
ggml_metal_op_t ggml_metal_op_init(
|
||||
@@ -53,34 +119,29 @@ ggml_metal_op_t ggml_metal_op_init(
|
||||
bool use_capture,
|
||||
int debug_graph,
|
||||
int debug_fusion) {
|
||||
ggml_metal_op_t res = new ggml_metal_op();
|
||||
|
||||
*res = {
|
||||
/*.dev =*/ dev,
|
||||
/*.lib =*/ ggml_metal_device_get_library(dev),
|
||||
/*.enc =*/ ggml_metal_encoder_init(cmd_buf, use_concurrency),
|
||||
/*.mem_ranges =*/ ggml_mem_ranges_init(debug_graph),
|
||||
/*.gf =*/ gf,
|
||||
/*.idx_start =*/ idx_start,
|
||||
/*.idx_end =*/ idx_end,
|
||||
/*.use_fusion =*/ use_fusion,
|
||||
/*.use_concurrency =*/ use_concurrency,
|
||||
/*.use_capture =*/ use_capture,
|
||||
/*.debug_graph =*/ debug_graph,
|
||||
/*.debug_fusion =*/ debug_fusion,
|
||||
};
|
||||
ggml_metal_op_t res = new ggml_metal_op(
|
||||
dev,
|
||||
cmd_buf,
|
||||
gf,
|
||||
idx_start,
|
||||
idx_end,
|
||||
use_fusion,
|
||||
use_concurrency,
|
||||
use_capture,
|
||||
debug_graph,
|
||||
debug_fusion);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
void ggml_metal_op_free(ggml_metal_op_t ctx) {
|
||||
ggml_metal_encoder_end_encoding(ctx->enc);
|
||||
ggml_metal_encoder_free(ctx->enc);
|
||||
ggml_mem_ranges_free(ctx->mem_ranges);
|
||||
|
||||
delete ctx;
|
||||
}
|
||||
|
||||
int ggml_metal_op_n_nodes(ggml_metal_op_t ctx) {
|
||||
return ctx->n_nodes();
|
||||
}
|
||||
|
||||
static bool ggml_metal_op_concurrency_reset(ggml_metal_op_t ctx) {
|
||||
if (!ctx->mem_ranges) {
|
||||
return true;
|
||||
@@ -110,10 +171,7 @@ static bool ggml_metal_op_concurrency_add(ggml_metal_op_t ctx, const ggml_tensor
|
||||
}
|
||||
|
||||
static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
struct ggml_cgraph * gf = ctx->gf;
|
||||
|
||||
struct ggml_tensor ** nodes = ggml_graph_nodes(gf) + idx;
|
||||
struct ggml_tensor * node = nodes[0];
|
||||
struct ggml_tensor * node = ctx->node(idx);
|
||||
|
||||
//GGML_LOG_INFO("%s: encoding node %3d, op = %8s\n", __func__, idx, ggml_op_name(node->op));
|
||||
|
||||
@@ -129,6 +187,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
case GGML_OP_PERMUTE:
|
||||
{
|
||||
// noop -> next node
|
||||
if (ctx->debug_graph > 0) {
|
||||
GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(node->op), "(noop)");
|
||||
}
|
||||
} return 1;
|
||||
default:
|
||||
{
|
||||
@@ -352,7 +413,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
// update the mem ranges in the encoding context
|
||||
for (int i = 0; i < n_fuse; ++i) {
|
||||
if (!ggml_metal_op_concurrency_add(ctx, nodes[i])) {
|
||||
if (!ggml_metal_op_concurrency_add(ctx, ctx->node(idx + i))) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
}
|
||||
}
|
||||
@@ -362,11 +423,11 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) {
|
||||
if (ctx->use_capture) {
|
||||
ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ggml_graph_node(ctx->gf, idx)));
|
||||
ggml_metal_encoder_debug_group_push(ctx->enc, ggml_op_desc(ctx->node(idx)));
|
||||
}
|
||||
|
||||
int res = ggml_metal_op_encode_impl(ctx, idx);
|
||||
if (idx + res > ctx->idx_end) {
|
||||
if (idx + res > ctx->n_nodes()) {
|
||||
GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
|
||||
"https://github.com/ggml-org/llama.cpp/pull/14849");
|
||||
}
|
||||
@@ -379,8 +440,7 @@ int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -438,8 +498,7 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -483,8 +542,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -594,8 +652,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -634,8 +691,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -674,8 +730,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -703,8 +758,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -774,8 +828,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -838,8 +891,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -876,8 +928,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -939,8 +990,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -1030,8 +1080,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -1076,8 +1125,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -1170,8 +1218,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -1212,8 +1259,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -1286,8 +1332,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -1347,8 +1392,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -1589,8 +1633,7 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const ggml_tensor * op) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -1783,8 +1826,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_add_id(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -1856,8 +1898,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -2176,16 +2217,11 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
|
||||
ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
const int idx_end = ctx->idx_end;
|
||||
|
||||
const bool use_fusion = ctx->use_fusion;
|
||||
|
||||
const int debug_fusion = ctx->debug_fusion;
|
||||
@@ -2258,22 +2294,25 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
// note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing ops
|
||||
// across splits. idx_end indicates the last node in the current split
|
||||
for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
||||
if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
|
||||
for (n_fuse = 0; n_fuse <= 6; ++n_fuse) {
|
||||
if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
|
||||
ggml_tensor * f0 = ctx->node(idx + n_fuse);
|
||||
ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
|
||||
|
||||
if (f0 != f1->src[0]) {
|
||||
break;
|
||||
}
|
||||
|
||||
// b[0] === b[1] === ...
|
||||
if (!ggml_are_same_layout(ops[n_fuse]->src[1], ops[n_fuse + 1]->src[1])) {
|
||||
if (!ggml_are_same_layout(f0->src[1], f1->src[1])) {
|
||||
break;
|
||||
}
|
||||
|
||||
// only fuse ops if src1 is in the same Metal buffer
|
||||
ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
|
||||
ggml_metal_buffer_id bid_fuse = ggml_metal_get_buffer_id(f1->src[1]);
|
||||
if (bid_fuse.metal != bid_src1.metal) {
|
||||
break;
|
||||
}
|
||||
@@ -2309,10 +2348,10 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
if (n_fuse > 1) {
|
||||
bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
|
||||
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
|
||||
|
||||
for (int i = 1; i < n_fuse; ++i) {
|
||||
if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
|
||||
if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
break;
|
||||
@@ -2344,8 +2383,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -2393,8 +2431,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -2445,20 +2482,15 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
const int idx_end = ctx->idx_end;
|
||||
|
||||
const bool use_fusion = ctx->use_fusion;
|
||||
|
||||
const int debug_fusion = ctx->debug_fusion;
|
||||
|
||||
ggml_tensor ** ops = ggml_graph_nodes(gf) + idx;
|
||||
|
||||
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
|
||||
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
|
||||
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
|
||||
@@ -2499,38 +2531,41 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
||||
fops[1] = GGML_OP_MUL;
|
||||
fops[2] = GGML_OP_ADD;
|
||||
|
||||
for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) {
|
||||
if (!ggml_can_fuse(gf, idx + n_fuse, fops + n_fuse, 2)) {
|
||||
for (n_fuse = 0; n_fuse <= 1; ++n_fuse) {
|
||||
if (!ctx->can_fuse(idx + n_fuse, fops + n_fuse, 2)) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (ops[n_fuse] != ops[n_fuse + 1]->src[0]) {
|
||||
ggml_tensor * f0 = ctx->node(idx + n_fuse);
|
||||
ggml_tensor * f1 = ctx->node(idx + n_fuse + 1);
|
||||
|
||||
if (f0 != f1->src[0]) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (ops[n_fuse + 1]->src[1]->ne[0] != op->ne[0]) {
|
||||
if (f1->src[1]->ne[0] != op->ne[0]) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (!ggml_is_contiguous_rows(ops[n_fuse + 1]->src[1])) {
|
||||
if (!ggml_is_contiguous_rows(f1->src[1])) {
|
||||
break;
|
||||
}
|
||||
|
||||
if (ops[n_fuse + 1]->type != GGML_TYPE_F32) {
|
||||
if (f1->type != GGML_TYPE_F32) {
|
||||
break;
|
||||
}
|
||||
|
||||
//ctx->fuse_cnt[ops[n_fuse + 1]->op]++;
|
||||
//ctx->fuse_cnt[f1->op]++;
|
||||
|
||||
bid_fuse[n_fuse] = ggml_metal_get_buffer_id(ops[n_fuse + 1]->src[1]);
|
||||
bid_fuse[n_fuse] = ggml_metal_get_buffer_id(f1->src[1]);
|
||||
|
||||
args.nef1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[1];
|
||||
args.nef2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[2];
|
||||
args.nef3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->ne[3];
|
||||
args.nef1[n_fuse + 1] = f1->src[1]->ne[1];
|
||||
args.nef2[n_fuse + 1] = f1->src[1]->ne[2];
|
||||
args.nef3[n_fuse + 1] = f1->src[1]->ne[3];
|
||||
|
||||
args.nbf1[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[1];
|
||||
args.nbf2[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[2];
|
||||
args.nbf3[n_fuse + 1] = ops[n_fuse + 1]->src[1]->nb[3];
|
||||
args.nbf1[n_fuse + 1] = f1->src[1]->nb[1];
|
||||
args.nbf2[n_fuse + 1] = f1->src[1]->nb[2];
|
||||
args.nbf3[n_fuse + 1] = f1->src[1]->nb[3];
|
||||
}
|
||||
|
||||
++n_fuse;
|
||||
@@ -2546,10 +2581,10 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
if (n_fuse > 1) {
|
||||
bid_dst = ggml_metal_get_buffer_id(ops[n_fuse - 1]);
|
||||
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
|
||||
|
||||
for (int i = 1; i < n_fuse; ++i) {
|
||||
if (!ggml_metal_op_concurrency_check(ctx, ops[i])) {
|
||||
if (!ggml_metal_op_concurrency_check(ctx, ctx->node(idx + i))) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
break;
|
||||
@@ -2585,8 +2620,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -2681,8 +2715,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -2752,8 +2785,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -2798,8 +2830,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -2852,8 +2883,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -2897,8 +2927,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -2944,8 +2973,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -2985,8 +3013,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -3020,8 +3047,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -3060,8 +3086,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
@@ -3103,8 +3128,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
|
||||
}
|
||||
|
||||
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_cgraph * gf = ctx->gf;
|
||||
ggml_tensor * op = ggml_graph_node(gf, idx);
|
||||
ggml_tensor * op = ctx->node(idx);
|
||||
|
||||
ggml_metal_library_t lib = ctx->lib;
|
||||
ggml_metal_encoder_t enc = ctx->enc;
|
||||
|
||||
@@ -22,6 +22,8 @@ ggml_metal_op_t ggml_metal_op_init(
|
||||
|
||||
void ggml_metal_op_free(ggml_metal_op_t ctx);
|
||||
|
||||
int ggml_metal_op_n_nodes(ggml_metal_op_t ctx);
|
||||
|
||||
int ggml_metal_op_encode(ggml_metal_op_t ctx, int idx);
|
||||
|
||||
//
|
||||
|
||||
Reference in New Issue
Block a user