metal : allow ops to run concurrently (#15929)

* metal : run graphs ops concurrently

ggml-ci

* cont : add flags for debugging and disabling concurrency

ggml-ci

* cont : refactor and handle fusing

ggml-ci

* cont : simplify - no need to use GPU address

ggml-ci

* cont : prepare mem ranges for reuse + add ggml-metal-common.cpp

ggml-ci

* cont : avoid redundant keywords in cpp [no ci]

* metal : reorder graph for better concurrency

ggml-ci

* metal : fix race on mem pool buffers

ggml-ci

* cont : add env GGML_METAL_GRAPH_OPTIMIZE_DISABLE

ggml-ci

* cont : refactor, optimize, add comments

ggml-ci

* cont : refactor ggml-metal.m

ggml-ci

* minor : update logs [no ci]
Georgi Gerganov
2025-09-13 13:54:28 +03:00
committed by GitHub
parent 84d7b2fca1
commit f161463a54
4 changed files with 719 additions and 38 deletions

ggml/src/ggml-metal/CMakeLists.txt

@@ -6,6 +6,7 @@ message(STATUS "Metal framework found")
ggml_add_backend_library(ggml-metal
ggml-metal.m
ggml-metal-common.cpp
)
target_link_libraries(ggml-metal PRIVATE

ggml/src/ggml-metal/ggml-metal-common.cpp

@@ -0,0 +1,445 @@
#include "ggml-metal-common.h"
#include "ggml-impl.h"
#include <vector>
struct ggml_mem_range {
uint64_t pb; // buffer id
uint64_t p0; // begin
uint64_t p1; // end
ggml_mem_range_type pt;
};
struct ggml_mem_ranges {
std::vector<ggml_mem_range> ranges;
int debug = 0;
};
struct ggml_mem_ranges * ggml_mem_ranges_init(int debug) {
auto * res = new ggml_mem_ranges;
res->ranges.reserve(256);
res->debug = debug;
return res;
}
void ggml_mem_ranges_free(ggml_mem_ranges * mrs) {
delete mrs;
}
void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) {
mrs->ranges.clear();
}
static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mrp) {
mrs->ranges.push_back(mrp);
return true;
}
static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggml_mem_range_type pt) {
// always use the base tensor
tensor = tensor->view_src ? tensor->view_src : tensor;
GGML_ASSERT(!tensor->view_src);
ggml_mem_range mrp;
if (tensor->buffer) {
// when the tensor is allocated, use the actual memory address range of the buffer
mrp = {
/*.pb =*/ (uint64_t) tensor->buffer,
/*.p0 =*/ (uint64_t) tensor->data,
/*.p1 =*/ (uint64_t) tensor->data + ggml_nbytes(tensor),
/*.pt =*/ pt,
};
} else {
// otherwise, the tensor ptr is used as a unique id of the memory ranges
// that the tensor will be using when it is allocated
mrp = {
/*.pb =*/ (uint64_t) tensor,
/*.p0 =*/ 0, //
/*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
/*.pt =*/ pt,
};
}
return mrp;
}
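A minimal sketch of the two cases above (the editor's illustration, not part of this commit), assuming only the public ggml API and the new header: a view resolves to its base tensor's range, and tensors that are not yet backed by a buffer are identified by their pointer, so unrelated tensors never collide while aliasing ones always do.
#include <cstdio>
#include "ggml.h"
#include "ggml-metal-common.h"
int main() {
    struct ggml_init_params params = { /*.mem_size =*/ 16*1024*1024, /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true };
    struct ggml_context * ctx = ggml_init(params);
    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024); // not allocated -> identified by its ptr
    struct ggml_tensor * v = ggml_view_1d(ctx, a, 512, 0);                 // aliases a
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024); // unrelated tensor
    struct ggml_mem_ranges * mrs = ggml_mem_ranges_init(0);
    ggml_mem_ranges_add(mrs, a);                      // record the ranges used by a
    printf("b: %d\n", ggml_mem_ranges_check(mrs, b)); // 1 -> no overlap, can run concurrently
    printf("v: %d\n", ggml_mem_ranges_check(mrs, v)); // 0 -> overlaps a, needs a barrier first
    ggml_mem_ranges_free(mrs);
    ggml_free(ctx);
    return 0;
}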
static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) {
return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC);
}
static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) {
return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
}
static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor);
ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor);
if (mrs->debug > 2) {
GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1);
}
return ggml_mem_ranges_add(mrs, mrp);
}
static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor);
ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor);
if (mrs->debug > 2) {
GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1);
}
return ggml_mem_ranges_add(mrs, mrp);
}
bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (tensor->src[i]) {
ggml_mem_ranges_add_src(mrs, tensor->src[i]);
}
}
return ggml_mem_ranges_add_dst(mrs, tensor);
}
static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mrp) {
for (size_t i = 0; i < mrs->ranges.size(); i++) {
const auto & cmp = mrs->ranges[i];
if (mrp.pb != cmp.pb) {
continue;
}
if (mrp.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
continue;
}
if (mrp.p0 < cmp.p1 && mrp.p1 >= cmp.p0) {
if (mrs->debug > 2) {
GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
__func__,
mrp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
mrp.pb, mrp.p0, mrp.p1,
cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
cmp.pb, cmp.p0, cmp.p1);
}
return false;
}
}
return true;
}
static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor);
ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor);
const bool res = ggml_mem_ranges_check(mrs, mrp);
return res;
}
static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor);
ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor);
const bool res = ggml_mem_ranges_check(mrs, mrp);
return res;
}
bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (tensor->src[i]) {
if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
return false;
}
}
}
return ggml_mem_ranges_check_dst(mrs, tensor);
}
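To make the overlap rule above concrete, here is a small standalone illustration (the struct and function names are the editor's, not ggml's): two accesses conflict only if they are in the same buffer, their byte ranges overlap, and at least one of them is a write. The interval test mirrors the conservative check used above.
#include <cstdint>
#include <cstdio>
struct acc { uint64_t buf, p0, p1; bool is_dst; };
static bool conflict(const acc & a, const acc & b) {
    if (a.buf != b.buf)         { return false; } // different buffers never conflict
    if (!a.is_dst && !b.is_dst) { return false; } // two reads are always safe
    return a.p0 < b.p1 && a.p1 >= b.p0;           // conservative interval overlap
}
int main() {
    acc r0 = { 1,  0, 100, false }; // op A reads  buf#1 [  0, 100)
    acc r1 = { 1, 50, 150, false }; // op B reads  buf#1 [ 50, 150)
    acc w2 = { 1, 50, 150, true  }; // op C writes buf#1 [ 50, 150)
    printf("A|B: %d\n", conflict(r0, r1)); // 0 -> A and B can run concurrently
    printf("A|C: %d\n", conflict(r0, w2)); // 1 -> C must wait behind a barrier
    return 0;
}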
// TODO: move to ggml.h?
static bool is_empty(ggml_op op) {
switch (op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_TRANSPOSE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
return true;
default:
return false;
}
}
struct node_info {
ggml_tensor * node;
std::vector<ggml_tensor *> fused;
ggml_op op() const {
return node->op;
}
const ggml_tensor * dst() const {
return fused.empty() ? node : fused.back();
}
bool is_empty() const {
return ::is_empty(node->op);
}
void add_fused(ggml_tensor * t) {
fused.push_back(t);
}
};
static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
// helper to add node src and dst ranges
const auto & h_add = [](ggml_mem_ranges * mrs, const node_info & node) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node.node->src[i]) {
if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
return false;
}
}
}
for (const auto * fused : node.fused) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (fused->src[i]) {
if (!ggml_mem_ranges_add_src(mrs, fused->src[i])) {
return false;
}
}
}
}
return ggml_mem_ranges_add_dst(mrs, node.dst());
};
// helper to check if a node can run concurrently with the existing set of nodes
const auto & h_check = [](const ggml_mem_ranges * mrs, const node_info & node) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node.node->src[i]) {
if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
return false;
}
}
}
for (const auto * fused : node.fused) {
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (fused->src[i]) {
if (!ggml_mem_ranges_check_src(mrs, fused->src[i])) {
return false;
}
}
}
}
return ggml_mem_ranges_check_dst(mrs, node.dst());
};
// perform reorders only across these types of ops
// can be expanded when needed
// IMPORTANT: do not add ops such as GGML_OP_CPY or GGML_OP_SET_ROWS
// the dependencies from such ops are not always represented in the graph
const auto & h_safe = [](ggml_op op) {
switch (op) {
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
case GGML_OP_ROPE:
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_DIV:
case GGML_OP_GLU:
case GGML_OP_SCALE:
case GGML_OP_GET_ROWS:
return true;
default:
return is_empty(op);
}
};
const int n = nodes.size();
std::vector<int> res;
res.reserve(n);
std::vector<bool> used(n, false);
ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0);
ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0);
for (int i0 = 0; i0 < n; i0++) {
if (used[i0]) {
continue;
}
const auto & node0 = nodes[i0];
// the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e. reset mrs0)
// but before we do that, we look forward for other nodes that can still be added to the concurrent set mrs0
//
// note: we can always add empty nodes to the concurrent set as they don't read or write anything
if (!node0.is_empty() && !h_check(mrs0, node0)) {
// this will hold the set of memory ranges from the nodes that haven't been processed yet
// if a node is not concurrent with this set, we cannot reorder it
ggml_mem_ranges_reset(mrs1);
// initialize it with the current node
h_add(mrs1, node0);
// number of nodes to look forward when searching for a concurrent node
constexpr int N_FORWARD = 8;
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
if (used[i1]) {
continue;
}
const auto & node1 = nodes[i1];
// disallow reordering of certain ops
if (!h_safe(node1.op())) {
break;
}
const bool is_empty = node1.is_empty();
// to add a concurrent node, it has to be:
// + empty or concurrent with all nodes in the existing concurrent set (mrs0)
// + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
// add the node to the existing concurrent set (i.e. reorder it for early execution)
h_add(mrs0, node1);
res.push_back(i1);
// mark as used, so we skip re-processing it later
used[i1] = true;
} else {
// expand the set of nodes that haven't been processed yet
h_add(mrs1, node1);
}
}
// finalize the concurrent set and begin a new one
ggml_mem_ranges_reset(mrs0);
}
// expand the concurrent set with the current node
{
h_add(mrs0, node0);
res.push_back(i0);
}
}
ggml_mem_ranges_free(mrs0);
ggml_mem_ranges_free(mrs1);
return res;
}
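The effect of the reorder can be seen in a toy, standalone sketch (again the editor's, not ggml code; it drops fusion, the N_FORWARD window and the op whitelist, and reduces each node to explicit read/write id sets): an independent op is pulled forward into the current concurrent set, but only if it also does not conflict with any earlier op it would jump over.
#include <cstdio>
#include <set>
#include <vector>
struct toy_node { const char * name; std::set<int> reads, writes; };
static bool conflicts(const toy_node & a, const toy_node & b) {
    for (int w : a.writes) { if (b.reads.count(w) || b.writes.count(w)) return true; }
    for (int w : b.writes) { if (a.reads.count(w)) return true; }
    return false;
}
int main() {
    // B depends on A's output (id 1); C only reads id 0 and is independent of both
    std::vector<toy_node> nodes = { { "A", {0}, {1} }, { "B", {1}, {2} }, { "C", {0}, {3} } };
    std::vector<int>  order;
    std::vector<bool> used(nodes.size(), false);
    for (size_t i0 = 0; i0 < nodes.size(); ++i0) {
        if (used[i0]) continue;
        order.push_back((int) i0);
        used[i0] = true;
        std::vector<int> pending; // skipped nodes not yet emitted (plays the role of mrs1)
        for (size_t i1 = i0 + 1; i1 < nodes.size(); ++i1) {
            if (used[i1]) continue;
            bool ok = true;
            // must not conflict with anything emitted so far (a conservative stand-in for mrs0) ...
            for (int j : order)   { if (conflicts(nodes[j], nodes[i1])) { ok = false; break; } }
            // ... nor with any earlier node it would jump over
            for (int j : pending) { if (!ok) break; if (conflicts(nodes[j], nodes[i1])) { ok = false; break; } }
            if (ok) { order.push_back((int) i1); used[i1] = true; } else { pending.push_back((int) i1); }
        }
    }
    for (int j : order) printf("%s ", nodes[j].name); // prints: A C B
    printf("\n");
    return 0;
}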
void ggml_metal_graph_optimize(ggml_cgraph * gf) {
constexpr int MAX_FUSE = 16;
const int n = gf->n_nodes;
enum ggml_op ops[MAX_FUSE];
std::vector<node_info> nodes;
nodes.reserve(gf->n_nodes);
// fuse nodes:
// we don't want the reorder to break fusing, so we first pack all fusable tensors
// and perform the reorder over the fused nodes. after the reorder is done, we unfuse
for (int i = 0; i < n; i++) {
node_info node = {
/*.node =*/ gf->nodes[i],
/*.fused =*/ {},
};
// fuse only ops that start with these operations
// can be expanded when needed
if (node.op() == GGML_OP_ADD ||
node.op() == GGML_OP_RMS_NORM) {
ops[0] = node.op();
int f = i + 1;
while (f < n && f < i + MAX_FUSE) {
// conservatively allow fusing only these ops
// can be expanded when needed
if (gf->nodes[f]->op != GGML_OP_ADD &&
gf->nodes[f]->op != GGML_OP_MUL &&
gf->nodes[f]->op != GGML_OP_RMS_NORM) {
break;
}
ops[f - i] = gf->nodes[f]->op;
f++;
}
f -= i;
for (; f > 1; f--) {
if (ggml_can_fuse(gf, i, ops, f)) {
break;
}
}
// add the fused tensors into the node info so we can unfuse them later
for (int k = 1; k < f; k++) {
++i;
// the .dst() becomes the last fused tensor
node.add_fused(gf->nodes[i]);
}
}
nodes.push_back(std::move(node));
}
// reorder to improve concurrency
#if 1
const auto order = ggml_metal_graph_optimize_reorder(nodes);
#else
std::vector<int> order(nodes.size());
for (size_t i = 0; i < nodes.size(); i++) {
order[i] = i;
}
#endif
// unfuse
{
int j = 0;
for (const auto i : order) {
const auto & node = nodes[i];
gf->nodes[j++] = node.node;
for (auto * fused : node.fused) {
gf->nodes[j++] = fused;
}
}
}
}

ggml/src/ggml-metal/ggml-metal-common.h

@@ -0,0 +1,52 @@
// helper functions for ggml-metal that are too difficult to implement in Objective-C
#pragma once
#include <stdbool.h>
#ifdef __cplusplus
extern "C" {
#endif
struct ggml_tensor;
struct ggml_cgraph;
enum ggml_mem_range_type {
MEM_RANGE_TYPE_SRC = 0,
MEM_RANGE_TYPE_DST = 1,
};
// a helper object that can be used for reordering operations to improve concurrency
//
// the fundamental idea is that a set of tasks (either ggml ops, or something else) can run concurrently if they
// don't write to memory that is being read or written by another task in the set
//
// with this structure, we can add tasks to the set, recording the memory ranges that they read and write. we can also check if a new task
// can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the
// tasks already in the set)
//
struct ggml_mem_ranges;
struct ggml_mem_ranges * ggml_mem_ranges_init(int debug);
void ggml_mem_ranges_free(struct ggml_mem_ranges * mrs);
// remove all ranges from the set
void ggml_mem_ranges_reset(struct ggml_mem_ranges * mrs);
// add src or dst ranges to track
bool ggml_mem_ranges_add(struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor);
// return false if:
// - new src range overlaps with any existing dst range
// - new dst range overlaps with any existing range (src or dst)
bool ggml_mem_ranges_check(const struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor);
// reorder the nodes in the graph to improve concurrency, while respecting fusion
//
// note: this implementation is generic and not specific to metal
// if it proves to work well, we can start using it for other backends in the future
void ggml_metal_graph_optimize(struct ggml_cgraph * gf);
#ifdef __cplusplus
}
#endif
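As a usage sketch (the editor's, mirroring what ggml-metal.m does below; the function name is made up), a backend can walk the graph, emit a memory barrier whenever a node conflicts with the ops already in flight, and otherwise keep encoding into the same concurrent set. The real implementation additionally handles fused nodes and inserts extra resets around mem pool allocations.
#include "ggml.h"
#include "ggml-metal-common.h"
static void encode_with_auto_barriers(struct ggml_cgraph * gf) {
    struct ggml_mem_ranges * mrs = ggml_mem_ranges_init(/*debug =*/ 0);
    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        struct ggml_tensor * node = ggml_graph_node(gf, i);
        if (!ggml_mem_ranges_check(mrs, node)) {
            // conflict with an op already in flight -> emit a backend memory barrier here,
            // then start a new concurrent set
            ggml_mem_ranges_reset(mrs);
        }
        // ... encode the kernel(s) for this node ...
        // record the node's src/dst ranges so that later nodes are checked against it
        ggml_mem_ranges_add(mrs, node);
    }
    ggml_mem_ranges_free(mrs);
}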

ggml/src/ggml-metal/ggml-metal.m

@@ -3,6 +3,7 @@
#import "ggml-impl.h"
#import "ggml-backend-impl.h"
#import "ggml-metal-impl.h"
#import "ggml-metal-common.h"
#import <Foundation/Foundation.h>
@@ -61,8 +62,11 @@ static struct ggml_backend_metal_device_context {
bool has_bfloat;
bool use_bfloat;
bool use_fusion;
bool use_concurrency;
bool use_shared_buffers;
bool use_graph_optimize;
int debug_graph;
int debug_fusion;
// how many times a given op was fused
@@ -83,7 +87,10 @@ static struct ggml_backend_metal_device_context {
/*.has_bfloat =*/ false,
/*.use_bfloat =*/ false,
/*.use_fusion =*/ true,
/*.use_concurrency =*/ true,
/*.use_shared_buffers =*/ true,
/*.use_graph_optimize =*/ true,
/*.debug_graph =*/ 0,
/*.debug_fusion =*/ 0,
/*.fuse_cnt =*/ { 0 },
/*.max_size =*/ 0,
@@ -124,7 +131,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
#else
ctx->use_bfloat = false;
#endif
ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
ctx->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;
{
const char * val = getenv("GGML_METAL_GRAPH_DEBUG");
ctx->debug_graph = val ? atoi(val) : 0;
}
{
const char * val = getenv("GGML_METAL_FUSION_DEBUG");
@@ -137,6 +151,12 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
ctx->use_shared_buffers = false;
}
ctx->use_graph_optimize = true;
if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) {
ctx->use_graph_optimize = false;
}
memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
ctx->max_size = ctx->mtl_device.maxBufferLength;
@@ -628,7 +648,7 @@ static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
@end
//
// ggml_metal_mem_pool
// ggml_metal_mem_pool [TAG_MEM_POOL_REMOVE]
//
struct ggml_metal_mem_pool {
@@ -791,6 +811,9 @@ struct ggml_metal_command_buffer {
// each command buffer has a memory pool from which it can allocate temporary buffers during the compute
struct ggml_metal_mem_pool * mem_pool;
// used to enable concurrent execution of ops in the command buffers
struct ggml_mem_ranges * mem_ranges;
};
struct ggml_backend_metal_context {
@@ -1091,7 +1114,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
GGML_LOG_INFO("%s: use fusion = %s\n", __func__, ctx_dev->use_fusion ? "true" : "false");
GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, ctx_dev->use_concurrency ? "true" : "false");
GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, ctx_dev->use_shared_buffers ? "true" : "false");
GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, ctx_dev->use_graph_optimize ? "true" : "false");
GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
ctx->capture_next_compute = false;
@@ -1105,6 +1130,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
ctx->cmd_bufs[i].mem_pool->device = device;
if (ctx_dev->use_concurrency) {
ctx->cmd_bufs[i].mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph);
}
}
ctx->cmd_bufs_ext = [[NSMutableArray alloc] init];
@@ -1715,6 +1744,10 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
}
ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
if (ctx->cmd_bufs[i].mem_ranges) {
ggml_mem_ranges_free(ctx->cmd_bufs[i].mem_ranges);
}
}
[ctx->cmd_bufs_ext removeAllObjects];
@@ -2071,12 +2104,51 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
}
}
static int ggml_metal_encode_node(
ggml_backend_t backend,
int idx,
int idx_end,
id<MTLComputeCommandEncoder> encoder,
struct ggml_metal_mem_pool * mem_pool) {
struct ggml_metal_encode_context {
ggml_backend_t backend;
id<MTLComputeCommandEncoder> encoder;
struct ggml_metal_mem_pool * mem_pool;
struct ggml_mem_ranges * mem_ranges;
};
static bool ggml_metal_encode_concurrency_reset(struct ggml_metal_encode_context * ctx) {
if (!ctx->mem_ranges) {
return true;
}
[ctx->encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
ggml_mem_ranges_reset(ctx->mem_ranges);
return true;
}
static bool ggml_metal_encode_concurrency_check(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
if (!ctx->mem_ranges) {
return false;
}
return ggml_mem_ranges_check(ctx->mem_ranges, node);
}
static bool ggml_metal_encode_concurrency_add(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
if (!ctx->mem_ranges) {
return true;
}
return ggml_mem_ranges_add(ctx->mem_ranges, node);
}
static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
ggml_backend_t backend = ctx_enc->backend;
id<MTLComputeCommandEncoder> encoder = ctx_enc->encoder;
struct ggml_metal_mem_pool * mem_pool = ctx_enc->mem_pool;
struct ggml_backend_metal_context * ctx = backend->context;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
@@ -2159,38 +2231,71 @@ static int ggml_metal_encode_node(
const uint64_t nb2 = dst ? dst->nb[2] : 0;
const uint64_t nb3 = dst ? dst->nb[3] : 0;
size_t offs_src[GGML_MAX_SRC];
id<MTLBuffer> id_src[GGML_MAX_SRC];
enum ggml_type srct[GGML_MAX_SRC];
for (int i = 0; i < GGML_MAX_SRC; i++) {
offs_src[i] = 0;
id_src[i] = node->src[i] ? ggml_metal_get_buffer(node->src[i], &offs_src[i]) : nil;
srct[i] = node->src[i] ? node->src[i]->type : GGML_TYPE_COUNT;
}
// TODO: tmp shorthands - remove
size_t offs_src0 = offs_src[0];
size_t offs_src1 = offs_src[1];
size_t offs_src2 = offs_src[2];
id<MTLBuffer> id_src0 = id_src[0];
id<MTLBuffer> id_src1 = id_src[1];
id<MTLBuffer> id_src2 = id_src[2];
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
size_t offs_src0 = 0;
size_t offs_src1 = 0;
size_t offs_src2 = 0;
size_t offs_dst = 0;
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
int n_fuse = 1;
// check if the current node can run concurrently with other nodes before it
// the condition is that:
// - the current node cannot write to any previous src or dst ranges
// - the current node cannot read from any previous dst ranges
//
// if the condition is not satisfied, we put a memory barrier and clear all ranges
// otherwise, we add the new ranges to the encoding context and process the node concurrently
//
{
const bool is_concurrent = ggml_metal_encode_concurrency_check(ctx_enc, node);
if (!is_concurrent) {
ggml_metal_encode_concurrency_reset(ctx_enc);
}
if (ctx_dev->debug_graph > 0) {
GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(dst->op), is_concurrent ? "(concurrent)" : "");
}
if (ctx_dev->debug_graph > 1) {
if (src0) {
GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
ggml_is_contiguous(src0), src0->name);
}
if (src1) {
GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
ggml_is_contiguous(src1), src1->name);
}
if (dst) {
GGML_LOG_DEBUG("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
dst->name);
}
}
}
id<MTLDevice> device = ctx_dev->mtl_device;
@@ -2389,6 +2494,14 @@ static int ggml_metal_encode_node(
if (n_fuse > 1) {
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
for (int i = 1; i < n_fuse; ++i) {
if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
ggml_metal_encode_concurrency_reset(ctx_enc);
break;
}
}
}
[encoder setComputePipelineState:pipeline];
@@ -2533,6 +2646,8 @@ static int ggml_metal_encode_node(
const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
ggml_metal_encode_concurrency_reset(ctx_enc);
}
const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;
@@ -3997,6 +4112,12 @@ static int ggml_metal_encode_node(
default: break;
}
// TODO: using mem pool allocations with concurrency enabled is not safe because the mem pool
// reuses buffers. this can result in two concurrent MUL_MAT_ID ops using the same mem pool buffer.
// so we add this extra barrier to prevent the race.
// the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
ggml_metal_encode_concurrency_reset(ctx_enc);
// tokens per expert
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
@@ -4057,6 +4178,9 @@ static int ggml_metal_encode_node(
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
}
// this barrier is always needed because the next kernel has to wait for the id maps to be computed
ggml_metal_encode_concurrency_reset(ctx_enc);
{
id<MTLComputePipelineState> pipeline = nil;
@@ -4525,6 +4649,14 @@ static int ggml_metal_encode_node(
if (n_fuse > 1) {
id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
for (int i = 1; i < n_fuse; ++i) {
if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
ggml_metal_encode_concurrency_reset(ctx_enc);
break;
}
}
}
id<MTLComputePipelineState> pipeline;
@@ -4668,7 +4800,6 @@ static int ggml_metal_encode_node(
} break;
case GGML_OP_ROPE:
{
// make sure we have one or more position ids (ne10) per token (ne02)
GGML_ASSERT(ne10 % ne02 == 0);
GGML_ASSERT(ne10 >= ne02);
@@ -5427,6 +5558,10 @@ static int ggml_metal_encode_node(
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
// using mem pool allocations with concurrency enabled is not safe [TAG_MEM_POOL_REMOVE]
// still, we assume that concurrent FA won't happen before we do the refactor
//ggml_metal_encode_concurrency_reset(ctx_enc);
const int32_t nrows = ne1*ne2*ne3;
// temp buffer for writing the results from each workgroup
@@ -5447,6 +5582,8 @@ static int ggml_metal_encode_node(
[encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
ggml_metal_encode_concurrency_reset(ctx_enc);
// reduce the results from the workgroups
{
ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
@@ -5709,6 +5846,19 @@ static int ggml_metal_encode_node(
}
}
if (ctx_dev->debug_graph > 0) {
if (n_fuse > 1) {
GGML_LOG_DEBUG("%s: fuse %d ops\n", __func__, n_fuse);
}
}
// update the mem ranges in the encoding context
for (int i = 0; i < n_fuse; ++i) {
if (!ggml_metal_encode_concurrency_add(ctx_enc, nodes[i])) {
ggml_metal_encode_concurrency_reset(ctx_enc);
}
}
return n_fuse;
}
@@ -5719,7 +5869,7 @@ static enum ggml_status ggml_metal_graph_compute(
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
// number of nodes encoded by the main thread (empirically determined)
const int n_main = 128;
const int n_main = 64;
// number of threads in addition to the main thread
const int n_cb = ctx->n_cb;
@@ -5774,6 +5924,7 @@ static enum ggml_status ggml_metal_graph_compute(
// cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed
// TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences
// https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009
// [TAG_MEM_POOL_REMOVE]
//id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
[cmd_buf retain];
@@ -6547,6 +6698,18 @@ static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend,
return ggml_metal_graph_compute(backend, cgraph);
}
static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
//const int64_t t_start = ggml_time_us();
if (ctx_dev->use_graph_optimize) {
ggml_metal_graph_optimize(cgraph);
}
//printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0);
}
static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
GGML_ASSERT(ggml_backend_is_metal(backend));
@@ -6575,10 +6738,23 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
struct ggml_mem_ranges * mem_ranges = ctx->cmd_bufs[cb_idx].mem_ranges;
ggml_metal_mem_pool_reset(mem_pool);
id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
if (mem_ranges) {
ggml_mem_ranges_reset(mem_ranges);
}
id<MTLComputeCommandEncoder> encoder;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
if (ctx_dev->use_concurrency) {
encoder = [cmd_buf computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
} else {
encoder = [cmd_buf computeCommandEncoder];
}
int node_start = 0;
int node_end = n_nodes_0;
@@ -6590,12 +6766,19 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
const bool should_capture = ctx->capture_next_compute;
struct ggml_metal_encode_context ctx_enc = {
/*.backend =*/ backend,
/*.encoder =*/ encoder,
/*.mem_pool =*/ mem_pool,
/*.mem_ranges =*/ mem_ranges,
};
for (int idx = node_start; idx < node_end;) {
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
}
const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
const int res = ggml_metal_encode_node(&ctx_enc, idx, node_end);
if (idx + res > node_end) {
GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
"https://github.com/ggml-org/llama.cpp/pull/14849");
@@ -6638,7 +6821,7 @@ static struct ggml_backend_i ggml_backend_metal_i = {
// https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
/* .optimize_graph = */ ggml_backend_metal_graph_optimize,
};
static ggml_guid_t ggml_backend_metal_guid(void) {