metal : reuse graphs

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-07-07 21:19:58 +03:00
parent 0d2038f90a
commit bf8b39015f

View File

@@ -821,13 +821,23 @@ struct ggml_backend_metal_context {
// the callback given to the thread pool
void (^encode_async)(size_t ith);
void (^encode_next)(void);
// n_cb command buffers + 1 used by the main thread
struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
struct ggml_metal_command_buffer cmd_bufs_next[2];
// abort ggml_metal_graph_compute if callback returns true
ggml_abort_callback abort_callback;
void * abort_callback_data;
// reuse info
int i_next;
int n_nodes_max;
int n_nodes_prev;
struct ggml_tensor * cg_nodes;
};
// MSL code
@@ -1084,6 +1094,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
ctx->gf = nil;
ctx->encode_async = nil;
ctx->encode_next = nil;
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
ctx->cmd_bufs[i].obj = nil;
@@ -1091,6 +1102,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
ctx->cmd_bufs[i].mem_pool->device = device;
}
for (int i = 0; i < 2; ++i) {
ctx->cmd_bufs_next[i].obj = nil;
ctx->cmd_bufs_next[i].mem_pool = ggml_metal_mem_pool_init();
ctx->cmd_bufs_next[i].mem_pool->device = device;
}
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
if (@available(macOS 10.12, iOS 16.0, *)) {
GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
@@ -1521,6 +1539,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
}
ctx->i_next = 0;
ctx->n_nodes_max = 16384;
ctx->n_nodes_prev = -1;
ctx->cg_nodes = ggml_aligned_malloc(ctx->n_nodes_max * sizeof(struct ggml_tensor));
return ctx;
}
@@ -1532,6 +1557,7 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
}
Block_release(ctx->encode_async);
Block_release(ctx->encode_next);
[ctx->queue release];
@@ -1541,8 +1567,13 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
}
ggml_metal_mem_pool_free(ctx->cmd_bufs_next[0].mem_pool);
ggml_metal_mem_pool_free(ctx->cmd_bufs_next[1].mem_pool);
dispatch_release(ctx->d_queue);
ggml_aligned_free(ctx->cg_nodes, ctx->n_nodes_max * sizeof(struct ggml_tensor));
free(ctx);
}
@@ -5448,6 +5479,39 @@ static enum ggml_status ggml_metal_graph_compute(
struct ggml_backend_metal_context * ctx = backend->context;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
//const int64_t t_start = ggml_time_us();
/////////////////////////////////////////////////////
// hacky way to determine that the graph is the same as the previous one
//
bool can_reuse = true;
if (gf->n_nodes > ctx->n_nodes_max) {
can_reuse = false;
}
if (gf->n_nodes != ctx->n_nodes_prev) {
can_reuse = false;
}
if (can_reuse) {
for (int i = 0; i < gf->n_nodes; ++i) {
if (memcmp(gf->nodes[i], ctx->cg_nodes + i, sizeof(struct ggml_tensor)) != 0) {
can_reuse = false;
break;
}
}
}
if (!can_reuse) {
ctx->n_nodes_prev = gf->n_nodes;
for (int i = 0; i < gf->n_nodes; ++i) {
memcpy(ctx->cg_nodes + i, gf->nodes[i], sizeof(struct ggml_tensor));
}
}
//////////////////////////////////////////////////////
// number of nodes encoded by the main thread (empirically determined)
const int n_main = 128;
@@ -5492,78 +5556,126 @@ static enum ggml_status ggml_metal_graph_compute(
}
}
// the main thread commits the first few commands immediately
// cmd_buf[n_cb]
{
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
ctx->cmd_bufs[n_cb].obj = cmd_buf;
if (!can_reuse) {
// the main thread commits the first few commands immediately
// cmd_buf[n_cb]
{
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
ctx->cmd_bufs[n_cb].obj = cmd_buf;
[cmd_buf enqueue];
ctx->encode_async(n_cb);
}
// prepare the rest of the command buffers asynchronously
// cmd_buf[0.. n_cb)
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
ctx->cmd_bufs[cb_idx].obj = cmd_buf;
// always enqueue the first two command buffers
// enqueue all of the command buffers if we don't need to abort
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[cmd_buf enqueue];
ctx->encode_async(n_cb);
}
}
dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
// prepare the rest of the command buffers asynchronously
// cmd_buf[0.. n_cb)
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
ctx->cmd_bufs[cb_idx].obj = cmd_buf;
// wait for completion and check status of each command buffer
// needed to detect if the device ran out-of-memory for example (#1881)
{
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
[cmd_buf waitUntilCompleted];
// always enqueue the first two command buffers
// enqueue all of the command buffers if we don't need to abort
if (cb_idx < 2 || ctx->abort_callback == NULL) {
[cmd_buf enqueue];
}
}
MTLCommandBufferStatus status = [cmd_buf status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
if (status == MTLCommandBufferStatusError) {
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
// encode the command buffer for the next iter while the GPU has already started
{
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
[cmd_buf retain];
if (ctx->cmd_bufs_next[ctx->i_next].obj != nil) {
[ctx->cmd_bufs_next[ctx->i_next].obj release];
}
ctx->cmd_bufs_next[ctx->i_next].obj = cmd_buf;
ctx->encode_next();
}
// wait for completion and check status of each command buffer
// needed to detect if the device ran out-of-memory for example (#1881)
{
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
[cmd_buf waitUntilCompleted];
MTLCommandBufferStatus status = [cmd_buf status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
if (status == MTLCommandBufferStatusError) {
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
}
return GGML_STATUS_FAILED;
}
}
for (int i = 0; i < n_cb; ++i) {
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
[cmd_buf waitUntilCompleted];
MTLCommandBufferStatus status = [cmd_buf status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
if (status == MTLCommandBufferStatusError) {
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
}
return GGML_STATUS_FAILED;
}
return GGML_STATUS_FAILED;
}
}
for (int i = 0; i < n_cb; ++i) {
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
[cmd_buf waitUntilCompleted];
MTLCommandBufferStatus status = [cmd_buf status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
if (status == MTLCommandBufferStatusError) {
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
if (!next_buffer) {
continue;
}
return GGML_STATUS_FAILED;
const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
if (next_queued) {
continue;
}
if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
return GGML_STATUS_ABORTED;
}
[next_buffer commit];
}
} else {
struct ggml_metal_command_buffer cmd_buf_cur = ctx->cmd_bufs_next[(ctx->i_next + 1)%2];
// directly submit the command buffer that we have prepared in the previous iteration
[ctx->cmd_bufs_next[(ctx->i_next + 1)%2].obj commit];
// encode the command buffer for the next iter
{
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
[cmd_buf retain];
if (ctx->cmd_bufs_next[ctx->i_next].obj != nil) {
[ctx->cmd_bufs_next[ctx->i_next].obj release];
}
ctx->cmd_bufs_next[ctx->i_next].obj = cmd_buf;
ctx->encode_next();
}
id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
if (!next_buffer) {
continue;
}
// wait for completion and check status of each command buffer
// needed to detect if the device ran out-of-memory for example (#1881)
{
id<MTLCommandBuffer> cmd_buf = cmd_buf_cur.obj;
[cmd_buf waitUntilCompleted];
const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
if (next_queued) {
continue;
}
MTLCommandBufferStatus status = [cmd_buf status];
if (status != MTLCommandBufferStatusCompleted) {
GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, ctx->i_next, status);
if (status == MTLCommandBufferStatusError) {
GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
}
if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
return GGML_STATUS_ABORTED;
return GGML_STATUS_FAILED;
}
}
[next_buffer commit];
}
if (!should_capture && ctx->capture_started) {
@@ -5572,6 +5684,8 @@ static enum ggml_status ggml_metal_graph_compute(
}
}
//printf(" time = %.3f ms\n", (float)(ggml_time_us() - t_start)/1000.0f);
return GGML_STATUS_SUCCESS;
}
@@ -5919,6 +6033,10 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
Block_release(ctx->encode_async);
}
if (ctx->encode_next) {
Block_release(ctx->encode_next);
}
ctx->encode_async = Block_copy(^(size_t iter) {
const int cb_idx = iter;
const int n_cb_l = ctx->n_cb;
@@ -5967,6 +6085,40 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
[cmd_buf commit];
}
});
ctx->encode_next = Block_copy(^(void) {
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_next[ctx->i_next].obj;
id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
int node_start = 0;
int node_end = ctx->gf->n_nodes;
const bool should_capture = ctx->capture_next_compute;
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs_next[ctx->i_next].mem_pool;
ggml_metal_mem_pool_reset(mem_pool);
for (int idx = node_start; idx < node_end; ++idx) {
if (should_capture) {
[encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
}
const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
if (should_capture) {
[encoder popDebugGroup];
}
if (!res) {
break;
}
}
[encoder endEncoding];
ctx->i_next = (ctx->i_next + 1) % 2;
});
}
static struct ggml_backend_i ggml_backend_metal_i = {