diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
index 40fc315e82..de3dcb6075 100644
--- a/ggml/src/ggml-metal/ggml-metal.m
+++ b/ggml/src/ggml-metal/ggml-metal.m
@@ -821,13 +821,23 @@ struct ggml_backend_metal_context {
 
     // the callback given to the thread pool
     void (^encode_async)(size_t ith);
+    void (^encode_next)(void);
 
     // n_cb command buffers + 1 used by the main thread
    struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+    struct ggml_metal_command_buffer cmd_bufs_next[2];
 
     // abort ggml_metal_graph_compute if callback returns true
     ggml_abort_callback abort_callback;
     void * abort_callback_data;
+
+    // reuse info
+    int i_next;
+
+    int n_nodes_max;
+    int n_nodes_prev;
+
+    struct ggml_tensor * cg_nodes;
 };
 
 // MSL code
@@ -1084,6 +1094,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     ctx->gf = nil;
     ctx->encode_async = nil;
+    ctx->encode_next = nil;
 
     for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
         ctx->cmd_bufs[i].obj = nil;
@@ -1091,6 +1102,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
         ctx->cmd_bufs[i].mem_pool->device = device;
     }
 
+    for (int i = 0; i < 2; ++i) {
+        ctx->cmd_bufs_next[i].obj = nil;
+
+        ctx->cmd_bufs_next[i].mem_pool = ggml_metal_mem_pool_init();
+        ctx->cmd_bufs_next[i].mem_pool->device = device;
+    }
+
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
         GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
@@ -1521,6 +1539,13 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32, pool_2d_max_f32, true);
     }
 
+    ctx->i_next = 0;
+
+    ctx->n_nodes_max = 16384;
+    ctx->n_nodes_prev = -1;
+
+    ctx->cg_nodes = ggml_aligned_malloc(ctx->n_nodes_max * sizeof(struct ggml_tensor));
+
     return ctx;
 }
 
@@ -1532,6 +1557,7 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
     }
 
     Block_release(ctx->encode_async);
+    Block_release(ctx->encode_next);
 
     [ctx->queue release];
 
@@ -1541,8 +1567,13 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
         ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
     }
 
+    ggml_metal_mem_pool_free(ctx->cmd_bufs_next[0].mem_pool);
+    ggml_metal_mem_pool_free(ctx->cmd_bufs_next[1].mem_pool);
+
     dispatch_release(ctx->d_queue);
 
+    ggml_aligned_free(ctx->cg_nodes, ctx->n_nodes_max * sizeof(struct ggml_tensor));
+
     free(ctx);
 }
 
@@ -5448,6 +5479,39 @@ static enum ggml_status ggml_metal_graph_compute(
     struct ggml_backend_metal_context * ctx = backend->context;
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
+    //const int64_t t_start = ggml_time_us();
+
+    /////////////////////////////////////////////////////
+    // hacky way to determine that the graph is the same as the previous one
+    //
+    bool can_reuse = true;
+
+    if (gf->n_nodes > ctx->n_nodes_max) {
+        can_reuse = false;
+    }
+
+    if (gf->n_nodes != ctx->n_nodes_prev) {
+        can_reuse = false;
+    }
+
+    if (can_reuse) {
+        for (int i = 0; i < gf->n_nodes; ++i) {
+            if (memcmp(gf->nodes[i], ctx->cg_nodes + i, sizeof(struct ggml_tensor)) != 0) {
+                can_reuse = false;
+                break;
+            }
+        }
+    }
+
+    if (!can_reuse) {
+        ctx->n_nodes_prev = gf->n_nodes;
+
+        for (int i = 0; i < gf->n_nodes; ++i) {
+            memcpy(ctx->cg_nodes + i, gf->nodes[i], sizeof(struct ggml_tensor));
+        }
+    }
+    //////////////////////////////////////////////////////
+
     // number of nodes encoded by the main thread (empirically determined)
     const int n_main = 128;
 
@@ -5492,78 +5556,126 @@ static enum ggml_status ggml_metal_graph_compute(
         }
     }
 
-    // the main thread commits the first few commands immediately
-    // cmd_buf[n_cb]
-    {
-        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
-        ctx->cmd_bufs[n_cb].obj = cmd_buf;
+    if (!can_reuse) {
+        // the main thread commits the first few commands immediately
+        // cmd_buf[n_cb]
+        {
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->cmd_bufs[n_cb].obj = cmd_buf;
 
-        [cmd_buf enqueue];
-        ctx->encode_async(n_cb);
-    }
-
-    // prepare the rest of the command buffers asynchronously
-    // cmd_buf[0.. n_cb)
-    for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
-        ctx->cmd_bufs[cb_idx].obj = cmd_buf;
-
-        // always enqueue the first two command buffers
-        // enqueue all of the command buffers if we don't need to abort
-        if (cb_idx < 2 || ctx->abort_callback == NULL) {
-            [cmd_buf enqueue];
+            [cmd_buf enqueue];
+            ctx->encode_async(n_cb);
         }
-    }
 
-    dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
+        // prepare the rest of the command buffers asynchronously
+        // cmd_buf[0.. n_cb)
+        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            ctx->cmd_bufs[cb_idx].obj = cmd_buf;
 
-    // wait for completion and check status of each command buffer
-    // needed to detect if the device ran out-of-memory for example (#1881)
-    {
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
-        [cmd_buf waitUntilCompleted];
+            // always enqueue the first two command buffers
+            // enqueue all of the command buffers if we don't need to abort
+            if (cb_idx < 2 || ctx->abort_callback == NULL) {
+                [cmd_buf enqueue];
+            }
+        }
 
-        MTLCommandBufferStatus status = [cmd_buf status];
-        if (status != MTLCommandBufferStatusCompleted) {
-            GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
-            if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+        dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);
+
+        // encode the command buffer for the next iter while the GPU has already started
+        {
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            [cmd_buf retain];
+            if (ctx->cmd_bufs_next[ctx->i_next].obj != nil) {
+                [ctx->cmd_bufs_next[ctx->i_next].obj release];
+            }
+            ctx->cmd_bufs_next[ctx->i_next].obj = cmd_buf;
+
+            ctx->encode_next();
+        }
+
+        // wait for completion and check status of each command buffer
+        // needed to detect if the device ran out-of-memory for example (#1881)
+        {
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+            [cmd_buf waitUntilCompleted];
+
+            MTLCommandBufferStatus status = [cmd_buf status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+                }
+
+                return GGML_STATUS_FAILED;
+            }
+        }
+
+        for (int i = 0; i < n_cb; ++i) {
+            id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+            [cmd_buf waitUntilCompleted];
+
+            MTLCommandBufferStatus status = [cmd_buf status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+                }
+
+                return GGML_STATUS_FAILED;
             }
 
-            return GGML_STATUS_FAILED;
-        }
-    }
-
-    for (int i = 0; i < n_cb; ++i) {
-        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
-        [cmd_buf waitUntilCompleted];
-
-        MTLCommandBufferStatus status = [cmd_buf status];
-        if (status != MTLCommandBufferStatusCompleted) {
-            GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
-            if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+            id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
+            if (!next_buffer) {
+                continue;
             }
 
-            return GGML_STATUS_FAILED;
+            const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
+            if (next_queued) {
+                continue;
+            }
+
+            if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
+                GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
+                return GGML_STATUS_ABORTED;
+            }
+
+            [next_buffer commit];
+        }
+    } else {
+        struct ggml_metal_command_buffer cmd_buf_cur = ctx->cmd_bufs_next[(ctx->i_next + 1)%2];
+
+        // directly submit the command buffer that we have prepared in the previous iteration
+        [ctx->cmd_bufs_next[(ctx->i_next + 1)%2].obj commit];
+
+        // encode the command buffer for the next iter
+        {
+            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+            [cmd_buf retain];
+            if (ctx->cmd_bufs_next[ctx->i_next].obj != nil) {
+                [ctx->cmd_bufs_next[ctx->i_next].obj release];
+            }
+            ctx->cmd_bufs_next[ctx->i_next].obj = cmd_buf;
+
+            ctx->encode_next();
         }
 
-        id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
-        if (!next_buffer) {
-            continue;
-        }
+        // wait for completion and check status of each command buffer
+        // needed to detect if the device ran out-of-memory for example (#1881)
+        {
+            id<MTLCommandBuffer> cmd_buf = cmd_buf_cur.obj;
+            [cmd_buf waitUntilCompleted];
 
-        const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
-        if (next_queued) {
-            continue;
-        }
+            MTLCommandBufferStatus status = [cmd_buf status];
+            if (status != MTLCommandBufferStatusCompleted) {
+                GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, ctx->i_next, status);
+                if (status == MTLCommandBufferStatusError) {
+                    GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
+                }
 
-        if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
-            GGML_LOG_INFO("%s: command buffer %d aborted", __func__, i);
-            return GGML_STATUS_ABORTED;
+                return GGML_STATUS_FAILED;
+            }
         }
-
-        [next_buffer commit];
     }
 
     if (!should_capture && ctx->capture_started) {
@@ -5572,6 +5684,8 @@ static enum ggml_status ggml_metal_graph_compute(
         }
     }
 
+    //printf(" time = %.3f ms\n", (float)(ggml_time_us() - t_start)/1000.0f);
+
     return GGML_STATUS_SUCCESS;
 }
 
@@ -5919,6 +6033,10 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         Block_release(ctx->encode_async);
     }
 
+    if (ctx->encode_next) {
+        Block_release(ctx->encode_next);
+    }
+
     ctx->encode_async = Block_copy(^(size_t iter) {
         const int cb_idx = iter;
         const int n_cb_l = ctx->n_cb;
@@ -5967,6 +6085,40 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
             [cmd_buf commit];
         }
     });
+
+    ctx->encode_next = Block_copy(^(void) {
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs_next[ctx->i_next].obj;
+
+        id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
+
+        int node_start = 0;
+        int node_end   = ctx->gf->n_nodes;
+
+        const bool should_capture = ctx->capture_next_compute;
+
+        struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs_next[ctx->i_next].mem_pool;
+        ggml_metal_mem_pool_reset(mem_pool);
+
+        for (int idx = node_start; idx < node_end; ++idx) {
+            if (should_capture) {
+                [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
+            }
+
+            const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
+
+            if (should_capture) {
+                [encoder popDebugGroup];
+            }
+
+            if (!res) {
+                break;
+            }
+        }
+
+        [encoder endEncoding];
+
+        ctx->i_next = (ctx->i_next + 1) % 2;
+    });
 }
 
 static struct ggml_backend_i ggml_backend_metal_i = {
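
Note on the graph-reuse check: the memcmp fingerprint in ggml_metal_graph_compute pays off because comparing n_nodes tensor structs is far cheaper than re-encoding the whole command buffer. A minimal standalone sketch of the same test follows (plain C; graph_is_same and snapshot are illustrative names, not ggml API; assumes only ggml.h for struct ggml_tensor):

    #include <stdbool.h>
    #include <string.h>

    #include "ggml.h"

    // true when every node struct is byte-identical to the cached copy of
    // the previous graph, so the previously encoded commands are still valid
    static bool graph_is_same(struct ggml_tensor ** nodes, int n_nodes,
                              const struct ggml_tensor * snapshot,
                              int n_prev, int n_max) {
        if (n_nodes > n_max || n_nodes != n_prev) {
            return false;
        }

        for (int i = 0; i < n_nodes; ++i) {
            // any change in op, shape, strides, srcs or the data pointer
            // changes the struct bytes and disables reuse
            if (memcmp(nodes[i], snapshot + i, sizeof(struct ggml_tensor)) != 0) {
                return false;
            }
        }

        return true;
    }

Because the struct bytes include the data and src pointers, reuse only triggers when the allocator keeps the graph at the same addresses across iterations, which is the common case during token-by-token decoding.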
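
Note on cmd_bufs_next: the two slots form a ping-pong. Each call to ggml_metal_graph_compute commits the buffer that encode_next filled on the previous call, then immediately encodes the other slot for the next call, overlapping CPU encoding with GPU execution. A toy model of just the rotation (plain C; submit/encode/wait_done are stand-ins for commit, the encode_next block, and waitUntilCompleted):

    #include <stdio.h>

    static void submit(int slot)    { printf("commit slot %d\n", slot); }
    static void encode(int slot)    { printf("encode slot %d\n", slot); }
    static void wait_done(int slot) { printf("wait   slot %d\n", slot); }

    int main(void) {
        int i_next = 0; // slot the next encode writes into (ctx->i_next)

        for (int iter = 0; iter < 4; ++iter) {
            const int i_cur = (i_next + 1) % 2; // encoded on the previous iteration

            submit(i_cur);             // GPU starts on already-encoded work
            encode(i_next);            // CPU encodes ahead for the next iteration
            i_next = (i_next + 1) % 2; // the real encode_next block advances this itself

            wait_done(i_cur);          // status is checked only after encoding ahead
        }

        return 0;
    }

The first reuse iteration depends on the non-reuse path having already encoded one slot ahead; the sketch skips that warm-up step.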