metal : remove memory pools (#15966)

* metal : remove mem pool usage

ggml-ci

* metal : remove mem pool implementation

ggml-ci

* metal : take into account the actual allocated memory of the tensor

ggml-ci

* cont : use ggml_backend_buft_get_alloc_size

ggml-ci

* cont : improve, comments

ggml-ci

* cont : add functions for the extra tensor sizes

* metal : add comments

ggml-ci

* metal : implement .get_alloc_size for the rest of the buffer types

ggml-ci

* metal : remove ggml_metal_heap

ggml-ci
This commit is contained in:
Georgi Gerganov
2025-09-14 22:02:32 +03:00
committed by GitHub
parent 0fa154e350
commit 9dcd200d57
2 changed files with 144 additions and 402 deletions

View File

@@ -1,9 +1,12 @@
#include "ggml-metal-common.h" #include "ggml-metal-common.h"
#include "ggml-impl.h" #include "ggml-impl.h"
#include "ggml-backend-impl.h"
#include <vector> #include <vector>
// represents a memory range (i.e. an interval from a starting address p0 to an ending address p1 in a given buffer pb)
// the type indicates whether it is a source range (i.e. ops read data from it) or a destination range (i.e. ops write data to it)
struct ggml_mem_range { struct ggml_mem_range {
uint64_t pb; // buffer id uint64_t pb; // buffer id
@@ -36,8 +39,8 @@ void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) {
mrs->ranges.clear(); mrs->ranges.clear();
} }
static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mrp) { static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mr) {
mrs->ranges.push_back(mrp); mrs->ranges.push_back(mr);
return true; return true;
} }
@@ -48,20 +51,24 @@ static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggm
GGML_ASSERT(!tensor->view_src); GGML_ASSERT(!tensor->view_src);
ggml_mem_range mrp; ggml_mem_range mr;
if (tensor->buffer) { if (tensor->buffer) {
// when the tensor is allocated, use the actual memory address range of the buffer // when the tensor is allocated, use the actual memory address range in the buffer
mrp = { //
// take the actual allocated size with ggml_backend_buft_get_alloc_size()
// this can be larger than the tensor size if the buffer type allocates extra memory
// ref: https://github.com/ggml-org/llama.cpp/pull/15966
mr = {
/*.pb =*/ (uint64_t) tensor->buffer, /*.pb =*/ (uint64_t) tensor->buffer,
/*.p0 =*/ (uint64_t) tensor->data, /*.p0 =*/ (uint64_t) tensor->data,
/*.p1 =*/ (uint64_t) tensor->data + ggml_nbytes(tensor), /*.p1 =*/ (uint64_t) tensor->data + ggml_backend_buft_get_alloc_size(tensor->buffer->buft, tensor),
/*.pt =*/ pt, /*.pt =*/ pt,
}; };
} else { } else {
// otherwise, the tensor ptr is used as an unique id of the memory ranges // otherwise, the pointer address is used as an unique id of the memory ranges
// that the tensor will be using when it is allocated // that the tensor will be using when it is allocated
mrp = { mr = {
/*.pb =*/ (uint64_t) tensor, /*.pb =*/ (uint64_t) tensor,
/*.p0 =*/ 0, // /*.p0 =*/ 0, //
/*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
@@ -69,7 +76,7 @@ static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggm
}; };
}; };
return mrp; return mr;
} }
static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) { static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) {
@@ -83,25 +90,25 @@ static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor)
static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) { static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor); GGML_ASSERT(tensor);
ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor); ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
if (mrs->debug > 2) { if (mrs->debug > 2) {
GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1); GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
} }
return ggml_mem_ranges_add(mrs, mrp); return ggml_mem_ranges_add(mrs, mr);
} }
static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) { static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor); GGML_ASSERT(tensor);
ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor); ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
if (mrs->debug > 2) { if (mrs->debug > 2) {
GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1); GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mr.pb, mr.p0, mr.p1);
} }
return ggml_mem_ranges_add(mrs, mrp); return ggml_mem_ranges_add(mrs, mr);
} }
bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) { bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
@@ -114,24 +121,26 @@ bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
return ggml_mem_ranges_add_dst(mrs, tensor); return ggml_mem_ranges_add_dst(mrs, tensor);
} }
static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mrp) { static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr) {
for (size_t i = 0; i < mrs->ranges.size(); i++) { for (size_t i = 0; i < mrs->ranges.size(); i++) {
const auto & cmp = mrs->ranges[i]; const auto & cmp = mrs->ranges[i];
if (mrp.pb != cmp.pb) { // two memory ranges cannot intersect if they are in different buffers
if (mr.pb != cmp.pb) {
continue; continue;
} }
if (mrp.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) { // intersecting source ranges are allowed
if (mr.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
continue; continue;
} }
if (mrp.p0 < cmp.p1 && mrp.p1 >= cmp.p0) { if (mr.p0 < cmp.p1 && mr.p1 >= cmp.p0) {
if (mrs->debug > 2) { if (mrs->debug > 2) {
GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n", GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
__func__, __func__,
mrp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst", mr.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
mrp.pb, mrp.p0, mrp.p1, mr.pb, mr.p0, mr.p1,
cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst", cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
cmp.pb, cmp.p0, cmp.p1); cmp.pb, cmp.p0, cmp.p1);
} }
@@ -146,9 +155,9 @@ static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mr
static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) { static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor); GGML_ASSERT(tensor);
ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor); ggml_mem_range mr = ggml_mem_range_from_tensor_src(tensor);
const bool res = ggml_mem_ranges_check(mrs, mrp); const bool res = ggml_mem_ranges_check(mrs, mr);
return res; return res;
} }
@@ -156,9 +165,9 @@ static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_te
static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) { static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
GGML_ASSERT(tensor); GGML_ASSERT(tensor);
ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor); ggml_mem_range mr = ggml_mem_range_from_tensor_dst(tensor);
const bool res = ggml_mem_ranges_check(mrs, mrp); const bool res = ggml_mem_ranges_check(mrs, mr);
return res; return res;
} }
@@ -222,6 +231,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
} }
} }
// keep track of the sources of the fused nodes as well
for (const auto * fused : node.fused) { for (const auto * fused : node.fused) {
for (int i = 0; i < GGML_MAX_SRC; i++) { for (int i = 0; i < GGML_MAX_SRC; i++) {
if (fused->src[i]) { if (fused->src[i]) {
@@ -290,7 +300,10 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
std::vector<bool> used(n, false); std::vector<bool> used(n, false);
// the memory ranges for the set of currently concurrent nodes
ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0); ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0);
// the memory ranges for the set of nodes that haven't been processed yet, when looking forward for a node to reorder
ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0); ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0);
for (int i0 = 0; i0 < n; i0++) { for (int i0 = 0; i0 < n; i0++) {
@@ -329,7 +342,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
const bool is_empty = node1.is_empty(); const bool is_empty = node1.is_empty();
// to add a concurrent node, it has to be: // to reorder a node and add it to the concurrent set, it has to be:
// + empty or concurrent with all nodes in the existing concurrent set (mrs0) // + empty or concurrent with all nodes in the existing concurrent set (mrs0)
// + concurrent with all nodes prior to it that haven't been processed yet (mrs1) // + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) { if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
@@ -419,8 +432,8 @@ void ggml_metal_graph_optimize(ggml_cgraph * gf) {
nodes.push_back(std::move(node)); nodes.push_back(std::move(node));
} }
// reorder to improve concurrency
#if 1 #if 1
// reorder to improve concurrency
const auto order = ggml_metal_graph_optimize_reorder(nodes); const auto order = ggml_metal_graph_optimize_reorder(nodes);
#else #else
std::vector<int> order(nodes.size()); std::vector<int> order(nodes.size());

View File

@@ -532,261 +532,9 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_COUNT GGML_METAL_KERNEL_TYPE_COUNT
}; };
//
// ggml_metal_heap
//
struct ggml_metal_heap {
// number of times the heap was unused
int n_unused;
// total number of buffer allocations in this heap across all computes
int64_t n_alloc;
// current offset in the heap - we reset this after each node in order to reuse the memory
size_t offs;
// the currently allocated MTLBuffer objects in this heap
id<MTLHeap> obj;
NSMutableArray * bufs;
};
static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
desc.storageMode = MTLStorageModePrivate;
desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
desc.type = MTLHeapTypePlacement;
desc.size = size;
heap->n_unused = 0;
heap->n_alloc = 0;
heap->obj = [device newHeapWithDescriptor:desc];
if (!heap->obj) {
GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
free(heap);
return false;
}
[desc release];
heap->bufs = [[NSMutableArray alloc] init];
return heap;
}
static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
heap->offs = 0;
// count how many graph computes the heap ended up being unused
if ([heap->bufs count] > 0) {
heap->n_unused = 0;
} else {
heap->n_unused++;
}
for (id<MTLBuffer> buf in heap->bufs) {
[buf release];
}
[heap->bufs removeAllObjects];
// tell the OS that it can reuse this memory if needed
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
[heap->obj setPurgeableState:MTLPurgeableStateVolatile];
}
static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
if (heap == nil) {
return;
}
ggml_metal_heap_reset(heap);
[heap->obj release];
[heap->bufs release];
free(heap);
}
@interface ggml_metal_heap_ptr : NSObject
@property (nonatomic, assign) struct ggml_metal_heap * data;
@end
@implementation ggml_metal_heap_ptr
@end
//
// ggml_metal_mem_pool [TAG_MEM_POOL_REMOVE]
//
struct ggml_metal_mem_pool {
id<MTLDevice> device;
int n_heaps; // total number of heaps ever created (including those that were removed)
NSMutableArray * heaps;
NSMutableArray * heaps_to_remove;
};
static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) {
struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool));
mem_pool->n_heaps = 0;
mem_pool->heaps = [[NSMutableArray alloc] init];
mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
return mem_pool;
}
static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) {
GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
size_t size_all = 0;
size_t size_cur = 0;
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data);
GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc);
GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused);
GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]);
if ([ptr.data->bufs count] > 0) {
size_cur += [ptr.data->obj size];
}
size_all += [ptr.data->obj size];
ggml_metal_heap_free(ptr.data);
[ptr release];
}
[mem_pool->heaps release];
[mem_pool->heaps_to_remove release];
if (size_all > 0) {
GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
}
free(mem_pool);
}
static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
struct ggml_metal_heap * heap = ptr.data;
ggml_metal_heap_reset(heap);
// if the heap hasn't been used for a while, remove it
if (heap->n_unused >= 128) {
[mem_pool->heaps_to_remove addObject:@(i)];
}
}
if (mem_pool->heaps_to_remove.count > 0) {
// remove in reverse order
for (NSUInteger i = [mem_pool->heaps_to_remove count] - 1; ; --i) {
NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
struct ggml_metal_heap * heap = ptr.data;
ggml_metal_heap_free(heap);
[mem_pool->heaps removeObjectAtIndex:index];
[ptr release];
if (i == 0) {
break;
}
}
[mem_pool->heaps_to_remove removeAllObjects];
}
}
static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
ptr.data->offs = 0;
}
}
static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
const size_t alignment = 256;
const size_t size_aligned = GGML_PAD(size, alignment);
// try one of the existing heaps
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
struct ggml_metal_heap * heap = ptr.data;
if (heap->offs + size_aligned <= [heap->obj size]) {
// if this is the first buffer in the heap for the current command buffer, tell the OS that
// it cannot free the memory used by the heap
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
if ([heap->bufs count] == 0) {
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
}
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
if (buf == nil) {
GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
return nil;
}
heap->n_alloc++;
heap->offs += size_aligned;
[heap->bufs addObject:buf];
return buf;
}
}
// create a new heap that can fit this buffer
ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new];
struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned);
if (heap == NULL) {
GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
return NULL;
}
//GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
heap_ptr.data = heap;
ggml_metal_heap_reset(heap);
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
if (buf == nil) {
GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
return NULL;
}
heap->n_alloc++;
heap->offs += size_aligned;
[heap->bufs addObject:buf];
[mem_pool->heaps addObject:heap_ptr];
mem_pool->n_heaps++;
return buf;
}
struct ggml_metal_command_buffer { struct ggml_metal_command_buffer {
id<MTLCommandBuffer> obj; id<MTLCommandBuffer> obj;
// each command buffer has a memory pool from which it can allocate temporary buffers during the compute
struct ggml_metal_mem_pool * mem_pool;
// used to enable concurrent execution of ops in the command buffers // used to enable concurrent execution of ops in the command buffers
struct ggml_mem_ranges * mem_ranges; struct ggml_mem_ranges * mem_ranges;
}; };
@@ -1103,9 +851,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) { for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
ctx->cmd_bufs[i].obj = nil; ctx->cmd_bufs[i].obj = nil;
ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
ctx->cmd_bufs[i].mem_pool->device = device;
if (ctx_dev->use_concurrency) { if (ctx_dev->use_concurrency) {
ctx->cmd_bufs[i].mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph); ctx->cmd_bufs[i].mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph);
} }
@@ -1510,6 +1255,52 @@ static id<MTLComputePipelineState> ggml_metal_compile_kernel(ggml_backend_t back
return res; return res;
} }
// tokens per expert
static size_t ggml_metal_mul_mat_id_extra_tpe(const struct ggml_tensor * op) {
assert(op->op == GGML_OP_MUL_MAT_ID);
const int64_t ne02 = op->src[0]->ne[2]; // n_expert
return ggml_type_size(GGML_TYPE_I32)*ne02;
}
// id map [n_tokens, n_expert]
static size_t ggml_metal_mul_mat_id_extra_ids(const struct ggml_tensor * op) {
assert(op->op == GGML_OP_MUL_MAT_ID);
const int64_t ne02 = op->src[0]->ne[2]; // n_expert
const int64_t ne21 = op->src[2]->ne[1]; // n_token
return ggml_type_size(GGML_TYPE_I32)*ne02*ne21;
}
// return true if we should use the FA vector kernel for this op
static bool ggml_metal_flash_attn_ext_use_vec(const struct ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
const int64_t ne00 = op->src[0]->ne[0]; // head size
const int64_t ne01 = op->src[0]->ne[1]; // batch size
// use vec kernel if the batch size is small and if the head size is supported
return (ne01 < 20) && (ne00 % 32 == 0);
}
static size_t ggml_metal_flash_attn_ext_extra_tmp(const struct ggml_tensor * op) {
assert(op->op == GGML_OP_FLASH_ATTN_EXT);
const int64_t nwg = 32;
const int64_t ne01 = op->src[0]->ne[1];
const int64_t ne02 = op->src[0]->ne[2];
const int64_t ne03 = op->src[0]->ne[3];
const int64_t ne20 = op->src[2]->ne[0];
// temp buffer for writing the results from each workgroup
// - ne20: the size of the Value head
// - + 2: the S and M values for each intermediate result
return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2));
}
static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext( static id<MTLComputePipelineState> ggml_metal_get_pipeline_flash_attn_ext(
ggml_backend_t backend, struct ggml_tensor * op, ggml_backend_t backend, struct ggml_tensor * op,
bool has_mask, bool has_mask,
@@ -1760,8 +1551,6 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
[ctx->cmd_bufs[i].obj release]; [ctx->cmd_bufs[i].obj release];
} }
ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
if (ctx->cmd_bufs[i].mem_ranges) { if (ctx->cmd_bufs[i].mem_ranges) {
ggml_mem_ranges_free(ctx->cmd_bufs[i].mem_ranges); ggml_mem_ranges_free(ctx->cmd_bufs[i].mem_ranges);
} }
@@ -2127,8 +1916,6 @@ struct ggml_metal_encode_context {
id<MTLComputeCommandEncoder> encoder; id<MTLComputeCommandEncoder> encoder;
struct ggml_metal_mem_pool * mem_pool;
struct ggml_mem_ranges * mem_ranges; struct ggml_mem_ranges * mem_ranges;
}; };
@@ -2165,8 +1952,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
id<MTLComputeCommandEncoder> encoder = ctx_enc->encoder; id<MTLComputeCommandEncoder> encoder = ctx_enc->encoder;
struct ggml_metal_mem_pool * mem_pool = ctx_enc->mem_pool;
struct ggml_backend_metal_context * ctx = backend->context; struct ggml_backend_metal_context * ctx = backend->context;
struct ggml_backend_metal_device_context * ctx_dev = backend->device->context; struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
@@ -2207,8 +1992,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
GGML_ABORT("unsupported op"); GGML_ABORT("unsupported op");
} }
ggml_metal_mem_pool_clear(mem_pool);
const int64_t ne00 = src0 ? src0->ne[0] : 0; const int64_t ne00 = src0 ? src0->ne[0] : 0;
const int64_t ne01 = src0 ? src0->ne[1] : 0; const int64_t ne01 = src0 ? src0->ne[1] : 0;
const int64_t ne02 = src0 ? src0->ne[2] : 0; const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -2522,7 +2305,6 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
/*.nb02 =*/ nb02, /*.nb02 =*/ nb02,
/*.nb11 =*/ nb11, /*.nb11 =*/ nb11,
/*.nb21 =*/ nb21, /*.nb21 =*/ nb21,
}; };
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
@@ -3167,54 +2949,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
// use this branch to test the ggml_metal_mem_pool functionality
#if 0
// cpy to tmp buffer in MTLHeap
id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
if (!h_src0) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
return 0;
}
offs_src0 = 0;
ggml_metal_kargs_cpy args_cpy = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne00,
/*.ne1 =*/ ne01,
/*.ne2 =*/ ne02,
/*.ne3 =*/ ne03,
/*.nb0 =*/ nb00,
/*.nb1 =*/ nb01,
/*.nb2 =*/ nb02,
/*.nb3 =*/ nb03,
};
if (src0->type == GGML_TYPE_F16) {
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
} else {
[encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
}
[encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:h_src0 offset:0 atIndex:2];
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
#else
id<MTLBuffer> h_src0 = id_src0; id<MTLBuffer> h_src0 = id_src0;
#endif
// softmax // softmax
ggml_metal_kargs_soft_max args = { ggml_metal_kargs_soft_max args = {
@@ -4093,28 +3829,9 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
default: break; default: break;
} }
// TODO: using mem pool allocations with enabled concurrency is not safe because the mem pool // extra buffers for intermediate id mapping
// reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer. size_t offs_tpe = offs_dst + ggml_nbytes(dst);
// so we add this extra barrier to prevent the race. size_t offs_ids = offs_tpe + ggml_metal_mul_mat_id_extra_tpe(dst);
// the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
ggml_metal_encode_concurrency_reset(ctx_enc);
// tokens per expert
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
if (!h_tpe) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tpe);
return 0;
}
// id map
// [n_tokens, n_expert]
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
if (!h_ids) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
return 0;
}
{ {
ggml_metal_kargs_mul_mm_id_map0 args = { ggml_metal_kargs_mul_mm_id_map0 args = {
@@ -4152,8 +3869,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
[encoder setComputePipelineState:pipeline]; [encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0]; [encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:1]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
[encoder setBuffer: h_tpe offset:0 atIndex:2]; [encoder setBuffer:id_dst offset:offs_tpe atIndex:2];
[encoder setBuffer: h_ids offset:0 atIndex:3]; [encoder setBuffer:id_dst offset:offs_ids atIndex:3];
[encoder setThreadgroupMemoryLength:smem atIndex:0]; [encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
@@ -4215,8 +3932,8 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
[encoder setBytes:&args length:sizeof(args) atIndex:0]; [encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer: h_tpe offset:0 atIndex:3]; [encoder setBuffer:id_dst offset:offs_tpe atIndex:3];
[encoder setBuffer: h_ids offset:0 atIndex:4]; [encoder setBuffer:id_dst offset:offs_ids atIndex:4];
[encoder setBuffer:id_dst offset:offs_dst atIndex:5]; [encoder setBuffer:id_dst offset:offs_dst atIndex:5];
[encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder setThreadgroupMemoryLength:8192 atIndex:0];
@@ -5306,8 +5023,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
GGML_ASSERT(ne01 < 65536); GGML_ASSERT(ne01 < 65536);
// use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size if (!ggml_metal_flash_attn_ext_use_vec(dst)) {
if (ne01 >= 20 || (ne00 % 32 != 0)) {
// half8x8 kernel // half8x8 kernel
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !! const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !!
@@ -5532,34 +5248,20 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31)); GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
// using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE] // write the results from each workgroup into a temp buffer
// still, we assume that concurrent FA won't happen before we do the refactor const size_t offs_tmp = offs_dst + ggml_nbytes(dst);
//ggml_metal_encode_concurrency_reset(ctx_enc); [encoder setBuffer:id_dst offset:offs_tmp atIndex:6];
const int32_t nrows = ne1*ne2*ne3;
// temp buffer for writing the results from each workgroup
// - ne20: the size of the head vector
// - + 2: the S and M values for each intermediate result
const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
if (!h_tmp) {
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
return 0;
}
//printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
//printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
[encoder setBuffer:h_tmp offset:0 atIndex:6];
[encoder setThreadgroupMemoryLength:smem atIndex:0]; [encoder setThreadgroupMemoryLength:smem atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
// sync the 2 kernels
ggml_metal_encode_concurrency_reset(ctx_enc); ggml_metal_encode_concurrency_reset(ctx_enc);
// reduce the results from the workgroups // reduce the results from the workgroups
{ {
const int32_t nrows = ne1*ne2*ne3;
ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = { ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {
nrows, nrows,
}; };
@@ -5568,7 +5270,7 @@ static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, in
[encoder setComputePipelineState:pipeline0]; [encoder setComputePipelineState:pipeline0];
[encoder setBytes:&args0 length:sizeof(args0) atIndex:0]; [encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
[encoder setBuffer:h_tmp offset:0 atIndex:1]; [encoder setBuffer:id_dst offset:offs_tmp atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
//printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20); //printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
@@ -5895,12 +5597,7 @@ static enum ggml_status ggml_metal_graph_compute(
// the main thread commits the first few commands immediately // the main thread commits the first few commands immediately
// cmd_buf[n_cb] // cmd_buf[n_cb]
{ {
// cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
// TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences
// https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009
// [TAG_MEM_POOL_REMOVE]
//id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
[cmd_buf retain]; [cmd_buf retain];
if (ctx->cmd_bufs[n_cb].obj) { if (ctx->cmd_bufs[n_cb].obj) {
@@ -5919,8 +5616,7 @@ static enum ggml_status ggml_metal_graph_compute(
// prepare the rest of the command buffers asynchronously (optional) // prepare the rest of the command buffers asynchronously (optional)
// cmd_buf[0.. n_cb) // cmd_buf[0.. n_cb)
for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) { for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
//id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences]; id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
[cmd_buf retain]; [cmd_buf retain];
if (ctx->cmd_bufs[cb_idx].obj) { if (ctx->cmd_bufs[cb_idx].obj) {
@@ -6377,6 +6073,31 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
return ggml_backend_buffer_init(buft, buf_i, ctx, size); return ggml_backend_buffer_init(buft, buf_i, ctx, size);
} }
static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
size_t res = ggml_nbytes(tensor);
// some operations require additional memory for fleeting data:
switch (tensor->op) {
case GGML_OP_MUL_MAT_ID:
{
res += ggml_metal_mul_mat_id_extra_tpe(tensor);
res += ggml_metal_mul_mat_id_extra_ids(tensor);
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
if (ggml_metal_flash_attn_ext_use_vec(tensor)) {
res += ggml_metal_flash_attn_ext_extra_tmp(tensor);
}
} break;
default:
break;
}
return res;
GGML_UNUSED(buft);
}
// default (shared) buffer type // default (shared) buffer type
static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) { static const char * ggml_backend_metal_buffer_type_shared_get_name(ggml_backend_buffer_type_t buft) {
@@ -6401,6 +6122,10 @@ static size_t ggml_backend_metal_buffer_type_shared_get_max_size(ggml_backend_bu
return max_size; return max_size;
} }
static size_t ggml_backend_metal_buffer_type_shared_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
}
static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) { static bool ggml_backend_metal_buffer_type_shared_is_host(ggml_backend_buffer_type_t buft) {
return false; return false;
@@ -6414,7 +6139,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_shared(void) {
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer, /* .alloc_buffer = */ ggml_backend_metal_buffer_type_shared_alloc_buffer,
/* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment, /* .get_alignment = */ ggml_backend_metal_buffer_type_shared_get_alignment,
/* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size, /* .get_max_size = */ ggml_backend_metal_buffer_type_shared_get_max_size,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes /* .get_alloc_size = */ ggml_backend_metal_buffer_type_shared_get_alloc_size,
/* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host, /* .is_host = */ ggml_backend_metal_buffer_type_shared_is_host,
}, },
/* .device = */ &g_ggml_backend_metal_device, /* .device = */ &g_ggml_backend_metal_device,
@@ -6448,6 +6173,10 @@ static size_t ggml_backend_metal_buffer_type_private_get_max_size(ggml_backend_b
return max_size; return max_size;
} }
static size_t ggml_backend_metal_buffer_type_private_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
}
static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) { static bool ggml_backend_metal_buffer_type_private_is_host(ggml_backend_buffer_type_t buft) {
return false; return false;
@@ -6461,7 +6190,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_private(void) {
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer, /* .alloc_buffer = */ ggml_backend_metal_buffer_type_private_alloc_buffer,
/* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment, /* .get_alignment = */ ggml_backend_metal_buffer_type_private_get_alignment,
/* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size, /* .get_max_size = */ ggml_backend_metal_buffer_type_private_get_max_size,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes /* .get_alloc_size = */ ggml_backend_metal_buffer_type_private_get_alloc_size,
/* .is_host = */ ggml_backend_metal_buffer_type_private_is_host, /* .is_host = */ ggml_backend_metal_buffer_type_private_is_host,
}, },
/* .device = */ &g_ggml_backend_metal_device, /* .device = */ &g_ggml_backend_metal_device,
@@ -6496,6 +6225,10 @@ static size_t ggml_backend_metal_buffer_type_mapped_get_max_size(ggml_backend_bu
return max_size; return max_size;
} }
static size_t ggml_backend_metal_buffer_type_mapped_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
return ggml_backend_metal_buffer_type_get_alloc_size(buft, tensor);
}
static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) { static bool ggml_backend_metal_buffer_type_mapped_is_host(ggml_backend_buffer_type_t buft) {
return false; return false;
@@ -6511,7 +6244,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_buffer_type_mapped(void) {
/* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer, /* .alloc_buffer = */ ggml_backend_metal_buffer_type_mapped_alloc_buffer,
/* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment, /* .get_alignment = */ ggml_backend_metal_buffer_type_mapped_get_alignment,
/* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size, /* .get_max_size = */ ggml_backend_metal_buffer_type_mapped_get_max_size,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes /* .get_alloc_size = */ ggml_backend_metal_buffer_type_mapped_get_alloc_size,
/* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host, /* .is_host = */ ggml_backend_metal_buffer_type_mapped_is_host,
}, },
/* .device = */ &g_ggml_backend_metal_device, /* .device = */ &g_ggml_backend_metal_device,
@@ -6711,11 +6444,8 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
const int n_nodes_per_cb = ctx->n_nodes_per_cb; const int n_nodes_per_cb = ctx->n_nodes_per_cb;
id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj; id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
struct ggml_mem_ranges * mem_ranges = ctx->cmd_bufs[cb_idx].mem_ranges; struct ggml_mem_ranges * mem_ranges = ctx->cmd_bufs[cb_idx].mem_ranges;
ggml_metal_mem_pool_reset(mem_pool);
if (mem_ranges) { if (mem_ranges) {
ggml_mem_ranges_reset(mem_ranges); ggml_mem_ranges_reset(mem_ranges);
} }
@@ -6743,7 +6473,6 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
struct ggml_metal_encode_context ctx_enc = { struct ggml_metal_encode_context ctx_enc = {
/*.backend =*/ backend, /*.backend =*/ backend,
/*.encoder =*/ encoder, /*.encoder =*/ encoder,
/*.mem_pool =*/ mem_pool,
/*.mem_ranges =*/ mem_ranges, /*.mem_ranges =*/ mem_ranges,
}; };