mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-17 11:37:10 +00:00
601 lines
21 KiB
Objective-C
601 lines
21 KiB
Objective-C
#import "ggml-metal-context.h"
|
|
|
|
#import "ggml-impl.h"
|
|
#import "ggml-backend-impl.h"
|
|
|
|
#import "ggml-metal-impl.h"
|
|
#import "ggml-metal-common.h"
|
|
#import "ggml-metal-ops.h"
|
|
|
|
#import <Foundation/Foundation.h>
|
|
|
|
#import <Metal/Metal.h>
|
|
|
|
// Foundation ships its own MIN/MAX macros; replace them with plain local definitions.
// note: like any function-like macro, both arguments may be evaluated more than once.
#undef MIN
#undef MAX
#define MIN(a, b) (((a) < (b)) ? (a) : (b))
#define MAX(a, b) (((a) > (b)) ? (a) : (b))

// upper bound on the number of MTLCommandBuffer objects used to submit a graph for processing
// (the main thread uses one additional buffer on top of these)
#define GGML_METAL_MAX_COMMAND_BUFFERS 8
|
|
|
|
// One slot of the graph-submission command buffers.
// `obj` holds a reference retained by ggml_metal_graph_compute(); it is released when the slot is reused by the
// next graph or in ggml_metal_free().
struct ggml_metal_command_buffer {
    id<MTLCommandBuffer> obj;
};
|
|
|
|
// Per-backend Metal context: device/queue handles, configuration read in ggml_metal_init(), and the command
// buffers used to submit graphs and asynchronous tensor copies.
struct ggml_metal {
    id<MTLDevice> device;
    id<MTLCommandQueue> queue; // currently a pointer to the device queue, but might become separate queue [TAG_QUEUE_PER_BACKEND]

    ggml_metal_device_t dev;
    ggml_metal_library_t lib;

    // concurrent dispatch queue on which the helper command buffers are encoded (via dispatch_apply)
    dispatch_queue_t d_queue;

    // additional, inference-time compiled pipelines
    ggml_metal_pipelines_t pipelines_ext;

    // use_bfloat comes from the device properties; the other three default to true and are turned off by the
    // GGML_METAL_FUSION_DISABLE, GGML_METAL_CONCURRENCY_DISABLE and GGML_METAL_GRAPH_OPTIMIZE_DISABLE env vars
    bool use_bfloat;
    bool use_fusion;
    bool use_concurrency;
    bool use_graph_optimize;

    // debug levels parsed with atoi() from GGML_METAL_GRAPH_DEBUG / GGML_METAL_FUSION_DEBUG (0 when unset)
    int debug_graph;
    int debug_fusion;

    // how many times a given op was fused
    // (printed by ggml_metal_free() when debug_fusion > 0; presumably incremented by the op encoders - verify)
    uint64_t fuse_cnt[GGML_OP_COUNT];

    // capture state
    // capture_next_compute: requested via ggml_metal_capture_next_compute()
    // capture_started:      a GPU capture into capture_scope is in progress
    bool capture_next_compute;
    bool capture_started;

    id<MTLCaptureScope> capture_scope;

    // command buffer state
    int n_cb; // number of extra threads used to submit the command buffers
    int n_nodes_0; // number of nodes submitted by the main thread
    int n_nodes_1; // remaining number of nodes submitted by the n_cb threads
    int n_nodes_per_cb;

    // graph currently being encoded
    struct ggml_cgraph * gf;

    // the callback given to the thread pool
    // (a heap copy made with Block_copy in ggml_metal_set_n_cb, released with Block_release)
    void (^encode_async)(size_t ith);

    // n_cb command buffers + 1 used by the main thread
    struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];

    // extra command buffers for things like getting, setting and copying tensors
    // (each entry carries one extra retain of its own, dropped in ggml_metal_synchronize / ggml_metal_free)
    NSMutableArray * cmd_bufs_ext;

    // the last command buffer queued into the Metal queue with operations relevant to the current Metal backend
    // (not retained here: it always points into cmd_bufs or cmd_bufs_ext, which own the reference)
    id<MTLCommandBuffer> cmd_buf_last;

    // abort ggml_metal_graph_compute if callback returns true
    ggml_abort_callback abort_callback;
    void * abort_callback_data;
};
|
|
|
|
// Create a Metal backend context for the device `dev`.
//
// Picks up the device's command queue and Metal library (compiling the library on the fly if the device has no
// precompiled one), reads the GGML_METAL_* environment variables that control fusion, concurrency, graph
// optimization and debug output, and prepares the command-buffer bookkeeping used by ggml_metal_graph_compute().
//
// Returns NULL on failure (allocation, missing command queue, no usable library); nothing is leaked on those paths.
ggml_metal_t ggml_metal_init(ggml_metal_device_t dev) {
    GGML_LOG_INFO("%s: allocating\n", __func__);

#if TARGET_OS_OSX && !GGML_METAL_NDEBUG
    // Show all the Metal device instances in the system
    NSArray * devices = MTLCopyAllDevices();
    for (id<MTLDevice> device in devices) {
        GGML_LOG_INFO("%s: found device: %s\n", __func__, [[device name] UTF8String]);
    }
    [devices release]; // since it was created by a *Copy* C method
#endif

    // init context (calloc zero-fills every counter, flag and pointer)
    ggml_metal_t res = calloc(1, sizeof(struct ggml_metal));
    if (res == NULL) {
        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
        return NULL;
    }

    res->device = ggml_metal_device_get_obj(dev);

    GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[res->device name] UTF8String]);

    // TODO: would it be better to have one queue for the backend and one queue for the device?
    //       the graph encoders and async ops would use the backend queue while the sync ops would use the device queue?
    //res->queue = [device newCommandQueue]; [TAG_QUEUE_PER_BACKEND]
    res->queue = ggml_metal_device_get_queue(dev);
    if (res->queue == nil) {
        GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);

        // the context was already allocated - do not leak it
        free(res);

        return NULL;
    }

    res->dev = dev;
    res->lib = ggml_metal_device_get_library(dev);
    if (res->lib == NULL) {
        GGML_LOG_WARN("%s: the device does not have a precompiled Metal library - this is unexpected\n", __func__);
        GGML_LOG_WARN("%s: will try to compile it on the fly\n", __func__);

        res->lib = ggml_metal_library_init(dev);
        if (res->lib == NULL) {
            GGML_LOG_ERROR("%s: error: failed to initialize the Metal library\n", __func__);

            free(res);

            return NULL;
        }
    }

    const struct ggml_metal_device_props * props_dev = ggml_metal_device_get_props(dev);

    res->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);

    res->use_bfloat      = props_dev->has_bfloat;
    res->use_fusion      = getenv("GGML_METAL_FUSION_DISABLE") == NULL;
    res->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == NULL;

    {
        const char * val = getenv("GGML_METAL_GRAPH_DEBUG");
        res->debug_graph = val ? atoi(val) : 0;
    }

    {
        const char * val = getenv("GGML_METAL_FUSION_DEBUG");
        res->debug_fusion = val ? atoi(val) : 0;
    }

    res->use_graph_optimize = getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") == NULL;

    memset(res->fuse_cnt, 0, sizeof(res->fuse_cnt));

    GGML_LOG_INFO("%s: use bfloat         = %s\n", __func__, res->use_bfloat         ? "true" : "false");
    GGML_LOG_INFO("%s: use fusion         = %s\n", __func__, res->use_fusion         ? "true" : "false");
    GGML_LOG_INFO("%s: use concurrency    = %s\n", __func__, res->use_concurrency    ? "true" : "false");
    GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, res->use_graph_optimize ? "true" : "false");

    res->capture_next_compute = false;
    res->capture_started = false;
    res->capture_scope = nil;

    res->gf = nil;
    res->encode_async = nil;

    // cmd_bufs has GGML_METAL_MAX_COMMAND_BUFFERS + 1 slots (the extra one is used by the main thread)
    for (int i = 0; i <= GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
        res->cmd_bufs[i].obj = nil;
    }

    res->cmd_bufs_ext = [[NSMutableArray alloc] init];

    res->cmd_buf_last = nil;

    res->pipelines_ext = ggml_metal_pipelines_init();

    return res;
}
|
|
|
|
// Release a context created by ggml_metal_init(), together with everything it owns: the graph command buffers,
// the extra (async copy) command buffers, the extra pipelines, the capture scope, the encoder block and the
// dispatch queue. Passing NULL is a no-op.
void ggml_metal_free(ggml_metal_t ctx) {
    if (ctx == NULL) {
        return;
    }

    GGML_LOG_INFO("%s: deallocating\n", __func__);

    // cmd_bufs has GGML_METAL_MAX_COMMAND_BUFFERS + 1 slots: slot n_cb belongs to the main thread, so when
    // n_cb == GGML_METAL_MAX_COMMAND_BUFFERS the last slot also holds a retained buffer
    for (int i = 0; i <= GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
        [ctx->cmd_bufs[i].obj release]; // messaging nil is a no-op
        ctx->cmd_bufs[i].obj = nil;
    }

    // every extra command buffer was retained once more after being added to the array - drop that reference,
    // then let the array drop its own
    for (id<MTLCommandBuffer> cmd_buf in ctx->cmd_bufs_ext) {
        [cmd_buf release];
    }

    [ctx->cmd_bufs_ext removeAllObjects];
    [ctx->cmd_bufs_ext release];

    if (ctx->pipelines_ext) {
        ggml_metal_pipelines_free(ctx->pipelines_ext);
        ctx->pipelines_ext = nil;
    }

    if (ctx->debug_fusion > 0) {
        GGML_LOG_DEBUG("%s: fusion stats:\n", __func__);
        for (int i = 0; i < GGML_OP_COUNT; i++) {
            if (ctx->fuse_cnt[i] == 0) {
                continue;
            }

            // note: cannot use ggml_log here
            GGML_LOG_DEBUG("%s: - %s: %" PRIu64 "\n", __func__, ggml_op_name((enum ggml_op) i), ctx->fuse_cnt[i]);
        }
    }

    // the capture scope comes from a `new...` method, so the context owns a reference to it
    [ctx->capture_scope release];
    ctx->capture_scope = nil;

    if (ctx->encode_async) {
        Block_release(ctx->encode_async);
        ctx->encode_async = nil;
    }

    //[ctx->queue release]; // [TAG_QUEUE_PER_BACKEND]

    dispatch_release(ctx->d_queue);

    free(ctx);
}
|
|
|
|
// Block until the work submitted through this context has finished, then verify it.
//
// Waits for cmd_buf_last (the last command buffer this context queued), aborts the process if any graph command
// buffer in cmd_bufs[0..n_cb] (slot n_cb is the main thread's) or any extra command buffer did not complete,
// and finally drops the extra command buffers created by the async tensor get/set helpers.
void ggml_metal_synchronize(ggml_metal_t ctx) {
    // wait for any backend operations to finish
    id<MTLCommandBuffer> pending = ctx->cmd_buf_last;
    if (pending) {
        [pending waitUntilCompleted];
        ctx->cmd_buf_last = nil;
    }

    // check status of all graph command buffers (inclusive upper bound: the main thread's slot)
    const int n_slots = ctx->n_cb + 1;
    for (int slot = 0; slot < n_slots; ++slot) {
        id<MTLCommandBuffer> buf = ctx->cmd_bufs[slot].obj;
        if (buf == nil) {
            continue;
        }

        const MTLCommandBufferStatus st = [buf status];
        if (st == MTLCommandBufferStatusCompleted) {
            continue;
        }

        GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, slot, (int) st);
        if (st == MTLCommandBufferStatusError) {
            GGML_LOG_ERROR("error: %s\n", [[buf error].localizedDescription UTF8String]);
        }
        GGML_ABORT("fatal error");
    }

    // verify and release the extra command buffers: each one holds a retain of our own besides the array's
    NSUInteger ext_idx = 0;
    for (id<MTLCommandBuffer> buf in ctx->cmd_bufs_ext) {
        const MTLCommandBufferStatus st = [buf status];
        if (st != MTLCommandBufferStatusCompleted) {
            GGML_LOG_ERROR("%s: error: command buffer %d failed with status %d\n", __func__, (int) ext_idx, (int) st);
            if (st == MTLCommandBufferStatusError) {
                GGML_LOG_ERROR("error: %s\n", [[buf error].localizedDescription UTF8String]);
            }
            GGML_ABORT("fatal error");
        }

        [buf release];
        ++ext_idx;
    }

    [ctx->cmd_bufs_ext removeAllObjects];
}
|
|
|
|
// Look up the Metal buffer (and offset) backing tensor `t`; views are resolved through their view_src.
// Returns { nil, 0 } when `t` is NULL or has no backend buffer, so callers can report the missing buffer
// instead of dereferencing a NULL buffer.
static struct ggml_metal_buffer_id ggml_metal_get_buffer_id(const struct ggml_tensor * t) {
    if (!t) {
        return (struct ggml_metal_buffer_id) { nil, 0 };
    }

    ggml_backend_buffer_t buffer = t->view_src ? t->view_src->buffer : t->buffer;
    if (buffer == NULL) {
        // tensor (or the tensor it views) has not been allocated in any backend buffer
        return (struct ggml_metal_buffer_id) { nil, 0 };
    }

    return ggml_metal_buffer_get_id(buffer->context, t);
}
|
|
|
|
// Queue an asynchronous copy of `size` bytes from host memory `data` into `tensor`, starting `offset` bytes into
// the tensor's data.
//
// The bytes are first copied into a new shared MTLBuffer, so the caller may reuse `data` as soon as this returns.
// The blit is committed to ctx->queue without waiting; the command buffer is recorded in cmd_bufs_ext and becomes
// cmd_buf_last, so ggml_metal_synchronize() waits for it and checks its status.
void ggml_metal_set_tensor_async(ggml_metal_t ctx, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
    @autoreleasepool {
        // wrap the source data into a Metal buffer
        id<MTLBuffer> buf_src = [ctx->device newBufferWithBytes:data
                                                         length:size
                                                        options:MTLResourceStorageModeShared];

        GGML_ASSERT(buf_src);

        struct ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(tensor);
        if (bid_dst.metal == nil) {
            GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
        }

        bid_dst.offs += offset;

        // queue the copy operation into the queue of the Metal context
        // this will be queued at the end, after any currently ongoing GPU operations
        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:buf_src
                   sourceOffset:0
                       toBuffer:bid_dst.metal
              destinationOffset:bid_dst.offs
                           size:size];

        [encoder endEncoding];

        // buf_src comes from a `new...` method, so this function owns a reference to it. The command buffer was
        // created with unretained references and does not keep buf_src alive, so the reference must survive until
        // the blit has run: drop it from the completion handler (the handler itself keeps buf_src alive until then).
        [cmd_buf addCompletedHandler:^(id<MTLCommandBuffer> cb) {
            (void) cb;
            [buf_src release];
        }];

        [cmd_buf commit];

        // do not wait here for completion
        //[cmd_buf waitUntilCompleted];

        // instead, remember a reference to the command buffer and wait for it later if needed
        [ctx->cmd_bufs_ext addObject:cmd_buf];
        ctx->cmd_buf_last = cmd_buf;

        // extra reference, released in ggml_metal_synchronize() / ggml_metal_free()
        [cmd_buf retain];
    }
}
|
|
|
|
// Queue an asynchronous copy of `size` bytes from `tensor` (starting `offset` bytes into its data) into host
// memory `data`.
//
// `data` is wrapped without copying (newBufferWithBytesNoCopy), so it must stay valid until the copy has
// completed, e.g. until ggml_metal_synchronize() returns. If Metal refuses to wrap `data` (it imposes alignment
// requirements on such memory), buf_dst is nil and the assert below fires.
// The blit is committed without waiting and tracked in cmd_bufs_ext / cmd_buf_last for ggml_metal_synchronize().
void ggml_metal_get_tensor_async(ggml_metal_t ctx, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
    @autoreleasepool {
        id<MTLBuffer> buf_dst = [ctx->device newBufferWithBytesNoCopy:data
                                                               length:size
                                                              options:MTLResourceStorageModeShared
                                                          deallocator:nil];

        GGML_ASSERT(buf_dst);

        struct ggml_metal_buffer_id bid_src = ggml_metal_get_buffer_id(tensor);
        if (bid_src.metal == nil) {
            GGML_ABORT("%s: failed to find buffer for tensor '%s'\n", __func__, tensor->name);
        }

        bid_src.offs += offset;

        // queue the copy operation into the queue of the Metal context
        // this will be queued at the end, after any currently ongoing GPU operations
        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
        id<MTLBlitCommandEncoder> encoder = [cmd_buf blitCommandEncoder];

        [encoder copyFromBuffer:bid_src.metal
                   sourceOffset:bid_src.offs
                       toBuffer:buf_dst
              destinationOffset:0
                           size:size];

        [encoder endEncoding];

        // buf_dst comes from a `new...` method, so this function owns a reference to it. The command buffer was
        // created with unretained references and does not keep buf_dst alive, so release it only once the GPU is
        // done with the copy (the handler itself keeps buf_dst alive until then). `data` is not freed by this,
        // since the buffer was created without a deallocator.
        [cmd_buf addCompletedHandler:^(id<MTLCommandBuffer> cb) {
            (void) cb;
            [buf_dst release];
        }];

        [cmd_buf commit];

        // do not wait here for completion
        //[cmd_buf waitUntilCompleted];

        // instead, remember a reference to the command buffer and wait for it later if needed
        [ctx->cmd_bufs_ext addObject:cmd_buf];
        ctx->cmd_buf_last = cmd_buf;

        // extra reference, released in ggml_metal_synchronize() / ggml_metal_free()
        [cmd_buf retain];
    }
}
|
|
|
|
// Encode the compute graph `gf` into Metal command buffers and submit it to the GPU.
//
// The calling thread encodes and commits the first n_nodes_0 nodes into cmd_bufs[n_cb]. While the GPU works on
// them, the remaining nodes are split over cmd_bufs[0..n_cb) and encoded in parallel on ctx->d_queue through
// ctx->encode_async (installed by ggml_metal_set_n_cb()).
//
// Normally the function returns as soon as everything is submitted; use ggml_metal_synchronize() to wait for the
// results. It only blocks when it has to:
//   - to finish a GPU capture that was started by an earlier call, or
//   - when an abort callback is installed and n_cb > 2: the command buffers after the first two are then held
//     back and committed here one at a time, consulting the callback before each of them.
//
// Returns GGML_STATUS_FAILED if a command buffer waited for here did not complete, GGML_STATUS_ABORTED if the
// abort callback asked to stop, GGML_STATUS_SUCCESS otherwise.
enum ggml_status ggml_metal_graph_compute(ggml_metal_t ctx, struct ggml_cgraph * gf) {
    // number of nodes encoded by the main thread (empirically determined)
    const int n_main = 64;

    // number of threads in addition to the main thread
    const int n_cb = ctx->n_cb;

    // the per-command-buffer encoder is installed by ggml_metal_set_n_cb()
    GGML_ASSERT(ctx->encode_async != nil);

    // submit the ggml compute graph to the GPU by creating command buffers and encoding the ops in them
    // the first n_nodes_0 are encoded and submitted for processing directly by the calling thread
    // while these nodes are processing, we start n_cb threads to enqueue the rest of the nodes
    // each thread creates its own command buffer and enqueues the ops in parallel
    //
    // tests on M1 Pro and M2 Ultra using LLaMA models, show that optimal values for n_cb are 1 or 2

    @autoreleasepool {
        ctx->gf = gf;

        // without helper command buffers the main thread has to encode the whole graph - otherwise every node
        // past the first n_main would be skipped (and the split below would divide by zero)
        ctx->n_nodes_0 = n_cb > 0 ? MIN(n_main, gf->n_nodes) : gf->n_nodes;
        ctx->n_nodes_1 = gf->n_nodes - ctx->n_nodes_0;

        ctx->n_nodes_per_cb = n_cb > 0 ? (ctx->n_nodes_1 + n_cb - 1) / n_cb : 0;

        const bool use_capture = ctx->capture_next_compute;
        if (use_capture) {
            // note: capture_next_compute is cleared only after the graph has been encoded (below), because
            // encode_async forwards it to the op encoders

            // make sure all previous computations have finished before starting the capture
            if (ctx->cmd_buf_last) {
                [ctx->cmd_buf_last waitUntilCompleted];
                ctx->cmd_buf_last = nil;
            }

            if (!ctx->capture_started) {
                MTLCaptureManager * capture_manager = [MTLCaptureManager sharedCaptureManager];

                // create capture scope (owned by ctx: it comes from a `new...` method)
                ctx->capture_scope = [capture_manager newCaptureScopeWithDevice:ctx->device];

                MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
                descriptor.captureObject = ctx->capture_scope;
                descriptor.destination   = MTLCaptureDestinationGPUTraceDocument;
                descriptor.outputURL     = [NSURL fileURLWithPath:@"/tmp/perf-metal.gputrace"];

                NSError * error = nil;
                if (![capture_manager startCaptureWithDescriptor:descriptor error:&error]) {
                    GGML_LOG_ERROR("%s: error: unable to start capture '%s'\n", __func__, [[error localizedDescription] UTF8String]);

                    // do not leak the scope - a fresh one is created on the next attempt
                    [ctx->capture_scope release];
                    ctx->capture_scope = nil;
                } else {
                    [ctx->capture_scope beginScope];
                    ctx->capture_started = true;
                }

                // balance `new` (this file manages references manually)
                [descriptor release];
            }
        }

        // the main thread commits the first few commands immediately
        // cmd_buf[n_cb]
        {
            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
            [cmd_buf retain];

            if (ctx->cmd_bufs[n_cb].obj) {
                [ctx->cmd_bufs[n_cb].obj release];
            }
            ctx->cmd_bufs[n_cb].obj = cmd_buf;

            [cmd_buf enqueue];

            ctx->encode_async(n_cb);
        }

        // remember the command buffer for the next iteration
        ctx->cmd_buf_last = ctx->cmd_bufs[n_cb].obj;

        // with an abort callback installed, only the first two helper command buffers are enqueued (here) and
        // committed (in encode_async) right away; the others are held back and must be committed further below
        const bool defer_commit = ctx->abort_callback != NULL && n_cb > 2;

        // prepare the rest of the command buffers asynchronously (optional)
        // cmd_buf[0.. n_cb)
        for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
            id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
            [cmd_buf retain];

            if (ctx->cmd_bufs[cb_idx].obj) {
                [ctx->cmd_bufs[cb_idx].obj release];
            }
            ctx->cmd_bufs[cb_idx].obj = cmd_buf;

            // always enqueue the first two command buffers
            // enqueue all of the command buffers if we don't need to abort
            if (cb_idx < 2 || ctx->abort_callback == NULL) {
                [cmd_buf enqueue];

                // update the pointer to the last queued command buffer
                // this is needed to implement synchronize()
                ctx->cmd_buf_last = cmd_buf;
            }
        }

        dispatch_apply(n_cb, ctx->d_queue, ctx->encode_async);

        // the encoders have seen the capture request - consume it
        if (use_capture) {
            ctx->capture_next_compute = false;
        }

        // for debugging: block until graph is computed
        //[ctx->cmd_buf_last waitUntilCompleted];

        // a capture started by an earlier call is finished here, after this graph has been computed
        const bool stop_capture = !use_capture && ctx->capture_started;

        // otherwise, we leave the graph to compute asynchronously - unless some command buffers were held back,
        // because nothing else would ever commit them
        if (stop_capture || defer_commit) {
            // wait for completion and check status of each command buffer
            // needed to detect if the device ran out-of-memory for example (#1881)
            {
                id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
                [cmd_buf waitUntilCompleted];

                MTLCommandBufferStatus status = [cmd_buf status];
                if (status != MTLCommandBufferStatusCompleted) {
                    GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, (unsigned long) status);
                    if (status == MTLCommandBufferStatusError) {
                        GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                    }

                    return GGML_STATUS_FAILED;
                }
            }

            for (int i = 0; i < n_cb; ++i) {
                id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
                [cmd_buf waitUntilCompleted];

                MTLCommandBufferStatus status = [cmd_buf status];
                if (status != MTLCommandBufferStatusCompleted) {
                    GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, (unsigned long) status);
                    if (status == MTLCommandBufferStatusError) {
                        GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
                    }

                    return GGML_STATUS_FAILED;
                }

                id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
                if (!next_buffer) {
                    continue;
                }

                const bool next_queued = ([next_buffer status] != MTLCommandBufferStatusNotEnqueued);
                if (next_queued) {
                    continue;
                }

                if (ctx->abort_callback && ctx->abort_callback(ctx->abort_callback_data)) {
                    GGML_LOG_INFO("%s: command buffer %d aborted\n", __func__, i);

                    // the held-back command buffers from here on will never be committed: drop them, so that
                    // ggml_metal_synchronize() (which skips empty slots) does not report them as failed
                    for (int j = i + 1; j < n_cb; ++j) {
                        [ctx->cmd_bufs[j].obj release];
                        ctx->cmd_bufs[j].obj = nil;
                    }

                    return GGML_STATUS_ABORTED;
                }

                [next_buffer commit];
            }

            if (stop_capture) {
                [ctx->capture_scope endScope];
                [[MTLCaptureManager sharedCaptureManager] stopCapture];

                // the capture is over: release the scope and let the following graphs run asynchronously again
                [ctx->capture_scope release];
                ctx->capture_scope = nil;
                ctx->capture_started = false;
            }
        }
    }

    return GGML_STATUS_SUCCESS;
}
|
|
|
|
// Run ggml_graph_optimize() on `gf`, unless graph optimization is disabled for this context
// (ctx->use_graph_optimize is false when GGML_METAL_GRAPH_OPTIMIZE_DISABLE is set at init time).
void ggml_metal_graph_optimize(ggml_metal_t ctx, struct ggml_cgraph * gf) {
    if (!ctx->use_graph_optimize) {
        return;
    }

    //const int64_t t_start = ggml_time_us();

    ggml_graph_optimize(gf);

    //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0);
}
|
|
|
|
// Set the number of helper command buffers (encoder threads) used per graph and (re)install ctx->encode_async,
// the per-command-buffer encoder that ggml_metal_graph_compute() runs.
//
// n_cb is clamped to [1, GGML_METAL_MAX_COMMAND_BUFFERS]. cmd_bufs has GGML_METAL_MAX_COMMAND_BUFFERS + 1
// slots indexed by n_cb, so a negative value would index out of bounds. n_cb == 0 is raised to 1 so that the
// nodes beyond the main thread's share are always spread over at least one helper command buffer.
void ggml_metal_set_n_cb(ggml_metal_t ctx, int n_cb) {
    const int n_cb_new = MAX(1, MIN(n_cb, GGML_METAL_MAX_COMMAND_BUFFERS));

    if (ctx->n_cb != n_cb_new) {
        ctx->n_cb = n_cb_new;

        if (ctx->n_cb > 2) {
            GGML_LOG_WARN("%s: n_cb = %d, using n_cb > 2 is not recommended and can degrade the performance in some cases\n", __func__, ctx->n_cb);
        }
    }

    if (ctx->encode_async) {
        Block_release(ctx->encode_async);
    }

    // Encoder for one command buffer. The calling thread invokes it with iter == n_cb (nodes [0, n_nodes_0)).
    // dispatch_apply invokes it with iter in [0, n_cb), and each of those encodes a chunk of at most
    // n_nodes_per_cb of the remaining n_nodes_1 nodes, with the last one taking everything up to the end.
    // The split parameters are read from ctx at call time, so they always describe the graph being computed.
    ctx->encode_async = Block_copy(^(size_t iter) {
        const int cb_idx = (int) iter;

        const int n_cb_l = ctx->n_cb;

        const int n_nodes_0 = ctx->n_nodes_0;
        const int n_nodes_1 = ctx->n_nodes_1;

        const int n_nodes_per_cb = ctx->n_nodes_per_cb;

        int idx_start = 0;
        int idx_end   = n_nodes_0;

        if (cb_idx < n_cb_l) {
            // clamp the start as well as the end: with fewer remaining nodes than command buffers, the trailing
            // buffers would otherwise get a start index past their end index (they encode nothing instead)
            idx_start = n_nodes_0 + MIN(cb_idx * n_nodes_per_cb, n_nodes_1);
            idx_end   = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
        }

        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;

        ggml_metal_op_t ctx_op = ggml_metal_op_init(
            ctx->dev,
            cmd_buf,
            ctx->gf,
            idx_start,
            idx_end,
            ctx->use_fusion,
            ctx->use_concurrency,
            ctx->capture_next_compute,
            ctx->debug_graph,
            ctx->debug_fusion);

        for (int idx = 0; idx < ggml_metal_op_n_nodes(ctx_op); ++idx) {
            // ggml_metal_op_encode returns how many nodes it consumed (more than one when ops were fused); 0 stops
            const int res = ggml_metal_op_encode(ctx_op, idx);
            if (res == 0) {
                break;
            }

            idx += res - 1;
        }

        ggml_metal_op_free(ctx_op);

        // with an abort callback, buffers from index 2 on are committed later by ggml_metal_graph_compute
        if (cb_idx < 2 || ctx->abort_callback == NULL) {
            [cmd_buf commit];
        }
    });
}
|
|
|
|
// Install (or clear, by passing NULL) the callback that ggml_metal_graph_compute() consults before committing
// command buffers it held back; when it returns true the compute stops with GGML_STATUS_ABORTED.
// `user_data` is handed back to the callback unchanged.
void ggml_metal_set_abort_callback(ggml_metal_t ctx, ggml_abort_callback abort_callback, void * user_data) {
    ctx->abort_callback_data = user_data;
    ctx->abort_callback      = abort_callback;
}
|
|
|
|
// Report whether the context's device supports Apple GPU family `family`.
// `family` is 1-based: 1 maps to MTLGPUFamilyApple1, 2 to the next family value, and so on.
bool ggml_metal_supports_family(ggml_metal_t ctx, int family) {
    GGML_ASSERT(ctx->device != nil);

    const MTLGPUFamily gpu_family = (MTLGPUFamily) (MTLGPUFamilyApple1 + family - 1);

    return [ctx->device supportsFamily:gpu_family];
}
|
|
|
|
// Request a GPU capture: the next ggml_metal_graph_compute() call starts an MTLCaptureManager capture
// (written to /tmp/perf-metal.gputrace) if none is running yet; a later compute call ends it.
void ggml_metal_capture_next_compute(ggml_metal_t ctx) {
    ctx->capture_next_compute = true;
}
|