metal : reuse graphs
ggml-ci
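In short: the Metal backend now keeps a byte-for-byte snapshot of the previous graph's nodes (cg_nodes) and compares the incoming graph against it with memcmp. When the graph is unchanged, re-encoding is skipped entirely: the command buffer that was already encoded during the previous iteration is committed directly, and the buffer for the following iteration is encoded into a second slot (cmd_bufs_next[2], rotated via i_next) while the GPU is still running.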
@@ -821,13 +821,23 @@ struct ggml_backend_metal_context {
 
     // the callback given to the thread pool
     void (^encode_async)(size_t ith);
+    void (^encode_next)(void);
 
     // n_cb command buffers + 1 used by the main thread
     struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+    struct ggml_metal_command_buffer cmd_bufs_next[2];
 
     // abort ggml_metal_graph_compute if callback returns true
     ggml_abort_callback abort_callback;
     void * abort_callback_data;
+
+    // reuse info
+    int i_next;
+
+    int n_nodes_max;
+    int n_nodes_prev;
+
+    struct ggml_tensor * cg_nodes;
 };
 
 // MSL code
@@ -1084,6 +1094,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
 
     ctx->gf = nil;
     ctx->encode_async = nil;
+    ctx->encode_next = nil;
     for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
         ctx->cmd_bufs[i].obj = nil;
 
@@ -1091,6 +1102,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         ctx->cmd_bufs[i].mem_pool->device = device;
     }
 
+    for (int i = 0; i < 2; ++i) {
+        ctx->cmd_bufs_next[i].obj = nil;
+
+        ctx->cmd_bufs_next[i].mem_pool = ggml_metal_mem_pool_init();
+        ctx->cmd_bufs_next[i].mem_pool->device = device;
+    }
+
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
         GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
@@ -1521,6 +1539,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
     }
 
+    ctx->i_next = 0;
+
+    ctx->n_nodes_max = 16384;
+    ctx->n_nodes_prev = -1;
+
+    ctx->cg_nodes = ggml_aligned_malloc(ctx->n_nodes_max * sizeof(struct ggml_tensor));
+
     return ctx;
 }
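For scale, a back-of-the-envelope sketch of what this fixed-capacity snapshot costs. The per-node size below is an assumed placeholder: the real value is sizeof(struct ggml_tensor), which depends on the ggml build, but is on the order of a few hundred bytes.

    #include <stdio.h>

    int main(void) {
        const size_t n_nodes_max   = 16384; // capacity chosen in the hunk above
        const size_t sizeof_tensor = 368;   // ASSUMED stand-in for sizeof(struct ggml_tensor)

        // about 6 MiB per context with these numbers
        printf("cg_nodes snapshot: %.2f MiB\n",
               (double)(n_nodes_max * sizeof_tensor) / (1024.0 * 1024.0));
        return 0;
    }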
@@ -1532,6 +1557,7 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
     }
 
     Block_release(ctx->encode_async);
+    Block_release(ctx->encode_next);
 
     [ctx->queue release];
 
@@ -1541,8 +1567,13 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
         ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
     }
 
+    ggml_metal_mem_pool_free(ctx->cmd_bufs_next[0].mem_pool);
+    ggml_metal_mem_pool_free(ctx->cmd_bufs_next[1].mem_pool);
+
     dispatch_release(ctx->d_queue);
 
+    ggml_aligned_free(ctx->cg_nodes, ctx->n_nodes_max * sizeof(struct ggml_tensor));
+
     free(ctx);
 }
 
@@ -5448,6 +5479,39 @@ static enum ggml_status ggml_metal_graph_compute(
     struct ggml_backend_metal_context * ctx = backend->context;
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
+    //const int64_t t_start = ggml_time_us();
+
+    /////////////////////////////////////////////////////
+    // hacky way to determine that the graph is the same as the previous one
+    //
+    bool can_reuse = true;
+
+    if (gf->n_nodes > ctx->n_nodes_max) {
+        can_reuse = false;
+    }
+
+    if (gf->n_nodes != ctx->n_nodes_prev) {
+        can_reuse = false;
+    }
+
+    if (can_reuse) {
+        for (int i = 0; i < gf->n_nodes; ++i) {
+            if (memcmp(gf->nodes[i], ctx->cg_nodes + i, sizeof(struct ggml_tensor)) != 0) {
+                can_reuse = false;
+                break;
+            }
+        }
+    }
+
+    if (!can_reuse) {
+        ctx->n_nodes_prev = gf->n_nodes;
+
+        for (int i = 0; i < gf->n_nodes; ++i) {
+            memcpy(ctx->cg_nodes + i, gf->nodes[i], sizeof(struct ggml_tensor));
+        }
+    }
+    //////////////////////////////////////////////////////
+
     // number of nodes encoded by the main thread (empirically determined)
     const int n_main = 128;
 
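The check above, distilled into a standalone helper. This is a minimal sketch with stand-in types: ggml-metal.m operates on the real ggml_tensor and ggml_cgraph structs, and node/graph/reuse_ctx below are illustrative only.

    #include <stdbool.h>
    #include <string.h>

    struct node  { int dummy[64]; };                     // stand-in for struct ggml_tensor
    struct graph { int n_nodes; struct node ** nodes; }; // stand-in for struct ggml_cgraph

    struct reuse_ctx {
        int n_nodes_max;        // capacity of the snapshot buffer
        int n_nodes_prev;       // node count of the previous graph (-1 = none yet)
        struct node * cg_nodes; // byte-for-byte copies of the previous graph's nodes
    };

    // the graph is reusable only if it fits the snapshot buffer, has the same
    // node count as last time, and every node struct is byte-identical
    static bool graph_can_reuse(const struct graph * gf, const struct reuse_ctx * ctx) {
        if (gf->n_nodes > ctx->n_nodes_max || gf->n_nodes != ctx->n_nodes_prev) {
            return false;
        }
        for (int i = 0; i < gf->n_nodes; ++i) {
            if (memcmp(gf->nodes[i], &ctx->cg_nodes[i], sizeof(struct node)) != 0) {
                return false;
            }
        }
        return true;
    }

Because the whole node struct is compared, including op parameters, shapes, and data pointers, any metadata change defeats the reuse, which is why the source comment calls the approach hacky.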
@@ -5492,78 +5556,126 @@ static enum ggml_status ggml_metal_graph_compute(
         }
     }
 
-    // the main thread commits the first few commands immediately
-    // cmd_buf[n_cb]
-    {
-        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
-        ctx->cmd_bufs[n_cb].obj = cmd_buf;
+    if (!can_reuse) {
+        // the main thread commits the first few commands immediately
+        // cmd_buf[n_cb]
+        {
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->cmd_bufs[n_cb].obj = cmd_buf;
 
-        [cmd_buf enqueue];
-        ctx->encode_async(n_cb);
-    }
+            [cmd_buf enqueue];
+            ctx->encode_async(n_cb);
+        }
 
-    // prepare the rest of the command buffers asynchronously
-    // cmd_buf[0.. n_cb)
-    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
-        ctx->cmd_bufs[cb_idx].obj = cmd_buf;
+        // prepare the rest of the command buffers asynchronously
+        // cmd_buf[0.. n_cb)
+        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->cmd_bufs[cb_idx].obj = cmd_buf;
 
-        // always enqueue the first two command buffers
-        // enqueue all of the command buffers if we don't need to abort
-        if (cb_idx < 2 || ctx->abort_callback == NULL) {
-            [cmd_buf enqueue];
+            // always enqueue the first two command buffers
+            // enqueue all of the command buffers if we don't need to abort
+            if (cb_idx < 2 || ctx->abort_callback == NULL) {
+                [cmd_buf enqueue];
+            }
         }
-    }
 
-    dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
+        dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
 
-    // wait for completion and check status of each command buffer
-    // needed to detect if the device ran out-of-memory for example (#1881)
-    {
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
-        [cmd_buf waitUntilCompleted];
+        // encode the command buffer for the next iter while the GPU has already started
+        {
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            [cmd_buf retain];
+            if (ctx->cmd_bufs_next[ctx->i_next].obj != nil) {
+                [ctx->cmd_bufs_next[ctx->i_next].obj release];
+            }
+            ctx->cmd_bufs_next[ctx->i_next].obj = cmd_buf;
 
-        MTLCommandBufferStatus status = [cmd_buf status];
-        if (status != MTLCommandBufferStatusCompleted) {
-            GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
-            if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+            ctx->encode_next();
+        }
+
+        // wait for completion and check status of each command buffer
+        // needed to detect if the device ran out-of-memory for example (#1881)
+        {
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+            [cmd_buf waitUntilCompleted];
+
+            MTLCommandBufferStatus status = [cmd_buf status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+                }
+
+                return GGML_STATUS_FAILED;
             }
+        }
 
-            return GGML_STATUS_FAILED;
+        for (int i = 0; i < n_cb; ++i) {
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+            [cmd_buf waitUntilCompleted];
+
+            MTLCommandBufferStatus status = [cmd_buf status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+                }
+
+                return GGML_STATUS_FAILED;
+            }
+
+            id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
+            if (!next_buffer) {
+                continue;
+            }
+
+            const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
+            if (next_queued) {
+                continue;
+            }
+
+            if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
+                GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
+                return GGML_STATUS_ABORTED;
+            }
+
+            [next_buffer commit];
         }
-    }
+    } else {
+        struct ggml_metal_command_buffer cmd_buf_cur = ctx->cmd_bufs_next[(ctx->i_next + 1)%2];
 
-    for (int i = 0; i < n_cb; ++i) {
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
-        [cmd_buf waitUntilCompleted];
+        // directly submit the command buffer that we have prepared in the previous iteration
+        [ctx->cmd_bufs_next[(ctx->i_next + 1)%2].obj commit];
 
-        MTLCommandBufferStatus status = [cmd_buf status];
-        if (status != MTLCommandBufferStatusCompleted) {
-            GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
-            if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
-            }
+        // encode the command buffer for the next iter
+        {
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            [cmd_buf retain];
+            if (ctx->cmd_bufs_next[ctx->i_next].obj != nil) {
+                [ctx->cmd_bufs_next[ctx->i_next].obj release];
+            }
+            ctx->cmd_bufs_next[ctx->i_next].obj = cmd_buf;
 
-            return GGML_STATUS_FAILED;
-        }
+            ctx->encode_next();
+        }
 
-        id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
-        if (!next_buffer) {
-            continue;
-        }
+        // wait for completion and check status of each command buffer
+        // needed to detect if the device ran out-of-memory for example (#1881)
+        {
+            id<MTLCommandBuffer> cmd_buf = cmd_buf_cur.obj;
+            [cmd_buf waitUntilCompleted];
 
-        const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
-        if (next_queued) {
-            continue;
-        }
+            MTLCommandBufferStatus status = [cmd_buf status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, ctx->i_next, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+                }
 
-        if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
-            GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
-            return GGML_STATUS_ABORTED;
+                return GGML_STATUS_FAILED;
+            }
         }
-
-        [next_buffer commit];
     }
 
     if (!should_capture && ctx->capture_started) {
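The else branch above implements a two-slot ping-pong: commit the buffer encoded during the previous iteration, then immediately encode the other slot for the following one. A self-contained sketch of just the slot bookkeeping (no Metal calls; assume slot (i_next + 1) % 2 was already encoded by an earlier, non-reused iteration):

    #include <stdio.h>

    int main(void) {
        int i_next = 0; // mirrors ctx->i_next

        for (int iter = 0; iter < 4; ++iter) {
            // the slot encoded during the previous iteration is committed now
            const int i_prev = (i_next + 1) % 2;
            printf("iter %d: commit slot %d, encode slot %d\n", iter, i_prev, i_next);

            // ctx->encode_next() flips the index after encoding slot i_next
            i_next = (i_next + 1) % 2;
        }
        return 0;
    }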
@@ -5572,6 +5684,8 @@ static enum ggml_status ggml_metal_graph_compute(
         }
     }
 
+    //printf(" time = %.3f ms\n", (float)(ggml_time_us() - t_start)/1000.0f);
+
     return GGML_STATUS_SUCCESS;
 }
@@ -5919,6 +6033,10 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         Block_release(ctx->encode_async);
     }
 
+    if (ctx->encode_next) {
+        Block_release(ctx->encode_next);
+    }
+
     ctx->encode_async = Block_copy(^(size_t iter) {
         const int cb_idx = iter;
         const int n_cb_l = ctx->n_cb;
@@ -5967,6 +6085,40 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
             [cmd_buf commit];
         }
     });
+
+    ctx->encode_next = Block_copy(^(void) {
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_next[ctx->i_next].obj;
+
+        id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
+
+        int node_start = 0;
+        int node_end   = ctx->gf->n_nodes;
+
+        const bool should_capture = ctx->capture_next_compute;
+
+        struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs_next[ctx->i_next].mem_pool;
+        ggml_metal_mem_pool_reset(mem_pool);
+
+        for (int idx = node_start; idx < node_end; ++idx) {
+            if (should_capture) {
+                [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
+            }
+
+            const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
+
+            if (should_capture) {
+                [encoder popDebugGroup];
+            }
+
+            if (!res) {
+                break;
+            }
+        }
+
+        [encoder endEncoding];
+
+        ctx->i_next = (ctx->i_next + 1) % 2;
+    });
 }
 
 static struct ggml_backend_i ggml_backend_metal_i = {
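A note on the design, derived from the hunks above: unlike encode_async, which splits the graph's nodes across n_cb command buffers so that encoding can proceed in parallel on the thread pool, encode_next encodes the entire graph serially into a single command buffer. On the reuse path that serial cost is hidden, because the buffer for iteration N+1 is encoded while the GPU is still executing the buffer committed for iteration N; the final ctx->i_next = (ctx->i_next + 1) % 2 flips the two slots for the next call.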