mirror of https://github.com/ggml-org/llama.cpp.git
metal : allow ops to run concurrently (#15929)
* metal : run graphs ops concurrently ggml-ci
* cont : add flags for debugging and disabling concurrency ggml-ci
* cont : refactor and handle fusing ggml-ci
* cont : simplify - no need to use GPU address ggml-ci
* cont : prepare mem ranges for reuse + add ggml-metal-common.cpp ggml-ci
* cont : avoid redundant keywords in cpp [no ci]
* metal : reorder graph for better concurrency ggml-ci
* metal : fix race on mem pool buffers ggml-ci
* cont : add env GGML_METAL_GRAPH_OPTIMIZE_DISABLE ggml-ci
* cont : refactor, optimize, add comments ggml-ci
* cont : refactor ggml-metal.m ggml-ci
* minor : update logs [no ci]
ggml/src/ggml-metal/CMakeLists.txt
@@ -6,6 +6,7 @@ message(STATUS "Metal framework found")
 ggml_add_backend_library(ggml-metal
                          ggml-metal.m
+                         ggml-metal-common.cpp
                         )

 target_link_libraries(ggml-metal PRIVATE
ggml/src/ggml-metal/ggml-metal-common.cpp (new file, 445 lines)
#include "ggml-metal-common.h"

#include "ggml-impl.h"

#include <vector>

struct ggml_mem_range {
    uint64_t pb; // buffer id

    uint64_t p0; // begin
    uint64_t p1; // end

    ggml_mem_range_type pt;
};

struct ggml_mem_ranges {
    std::vector<ggml_mem_range> ranges;

    int debug = 0;
};

struct ggml_mem_ranges * ggml_mem_ranges_init(int debug) {
    auto * res = new ggml_mem_ranges;

    res->ranges.reserve(256);
    res->debug = debug;

    return res;
}

void ggml_mem_ranges_free(ggml_mem_ranges * mrs) {
    delete mrs;
}

void ggml_mem_ranges_reset(ggml_mem_ranges * mrs) {
    mrs->ranges.clear();
}

static bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, ggml_mem_range mrp) {
    mrs->ranges.push_back(mrp);

    return true;
}

static ggml_mem_range ggml_mem_range_from_tensor(const ggml_tensor * tensor, ggml_mem_range_type pt) {
    // always use the base tensor
    tensor = tensor->view_src ? tensor->view_src : tensor;

    GGML_ASSERT(!tensor->view_src);

    ggml_mem_range mrp;

    if (tensor->buffer) {
        // when the tensor is allocated, use the actual memory address range of the buffer
        mrp = {
            /*.pb =*/ (uint64_t) tensor->buffer,
            /*.p0 =*/ (uint64_t) tensor->data,
            /*.p1 =*/ (uint64_t) tensor->data + ggml_nbytes(tensor),
            /*.pt =*/ pt,
        };
    } else {
        // otherwise, the tensor ptr is used as an unique id of the memory ranges
        // that the tensor will be using when it is allocated
        mrp = {
            /*.pb =*/ (uint64_t) tensor,
            /*.p0 =*/ 0,    //
            /*.p1 =*/ 1024, // [0, 1024) is a dummy range, not used
            /*.pt =*/ pt,
        };
    };

    return mrp;
}

static ggml_mem_range ggml_mem_range_from_tensor_src(const ggml_tensor * tensor) {
    return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_SRC);
}

static ggml_mem_range ggml_mem_range_from_tensor_dst(const ggml_tensor * tensor) {
    return ggml_mem_range_from_tensor(tensor, MEM_RANGE_TYPE_DST);
}

static bool ggml_mem_ranges_add_src(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
    GGML_ASSERT(tensor);

    ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor);

    if (mrs->debug > 2) {
        GGML_LOG_DEBUG("%s: add src range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1);
    }

    return ggml_mem_ranges_add(mrs, mrp);
}

static bool ggml_mem_ranges_add_dst(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
    GGML_ASSERT(tensor);

    ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor);

    if (mrs->debug > 2) {
        GGML_LOG_DEBUG("%s: add dst range buf=%lld, [%lld, %lld)\n", __func__, mrp.pb, mrp.p0, mrp.p1);
    }

    return ggml_mem_ranges_add(mrs, mrp);
}

bool ggml_mem_ranges_add(ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
    for (int i = 0; i < GGML_MAX_DIMS; i++) {
        if (tensor->src[i]) {
            ggml_mem_ranges_add_src(mrs, tensor->src[i]);
        }
    }

    return ggml_mem_ranges_add_dst(mrs, tensor);
}

static bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, ggml_mem_range mrp) {
    for (size_t i = 0; i < mrs->ranges.size(); i++) {
        const auto & cmp = mrs->ranges[i];

        if (mrp.pb != cmp.pb) {
            continue;
        }

        if (mrp.pt == MEM_RANGE_TYPE_SRC && cmp.pt == MEM_RANGE_TYPE_SRC) {
            continue;
        }

        if (mrp.p0 < cmp.p1 && mrp.p1 >= cmp.p0) {
            if (mrs->debug > 2) {
                GGML_LOG_DEBUG("%s: the %s range buf=%lld, [%lld, %lld) overlaps with a previous %s range buf=%lld, [%lld, %lld)\n",
                        __func__,
                        mrp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
                        mrp.pb, mrp.p0, mrp.p1,
                        cmp.pt == MEM_RANGE_TYPE_SRC ? "src" : "dst",
                        cmp.pb, cmp.p0, cmp.p1);
            }

            return false;
        }
    }

    return true;
}
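// illustration of the overlap test above (hypothetical values, same buffer id):
//
//   new dst [  0, 128) vs existing src [ 64, 256)  ->  0 < 256 && 128 >= 64   -> overlap  -> not concurrent
//   new src [  0, 128) vs existing dst [256, 512)  ->  0 < 512 && 128 >= 256  -> disjoint -> concurrent
//   new dst [  0, 128) vs existing dst [128, 256)  ->  0 < 256 && 128 >= 128  -> touching ranges count as overlapping (conservative)
//
// src-vs-src pairs are skipped entirely because concurrent reads are always safe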
static bool ggml_mem_ranges_check_src(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
    GGML_ASSERT(tensor);

    ggml_mem_range mrp = ggml_mem_range_from_tensor_src(tensor);

    const bool res = ggml_mem_ranges_check(mrs, mrp);

    return res;
}

static bool ggml_mem_ranges_check_dst(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
    GGML_ASSERT(tensor);

    ggml_mem_range mrp = ggml_mem_range_from_tensor_dst(tensor);

    const bool res = ggml_mem_ranges_check(mrs, mrp);

    return res;
}

bool ggml_mem_ranges_check(const ggml_mem_ranges * mrs, const ggml_tensor * tensor) {
    for (int i = 0; i < GGML_MAX_DIMS; i++) {
        if (tensor->src[i]) {
            if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) {
                return false;
            }
        }
    }

    return ggml_mem_ranges_check_dst(mrs, tensor);
}

// TODO: move to ggml.h?
static bool is_empty(ggml_op op) {
    switch (op) {
        case GGML_OP_NONE:
        case GGML_OP_RESHAPE:
        case GGML_OP_TRANSPOSE:
        case GGML_OP_VIEW:
        case GGML_OP_PERMUTE:
            return true;
        default:
            return false;
    }
}

struct node_info {
    ggml_tensor * node;

    std::vector<ggml_tensor *> fused;

    ggml_op op() const {
        return node->op;
    }

    const ggml_tensor * dst() const {
        return fused.empty() ? node : fused.back();
    }

    bool is_empty() const {
        return ::is_empty(node->op);
    }

    void add_fused(ggml_tensor * t) {
        fused.push_back(t);
    }
};

static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node_info> & nodes) {
    // helper to add node src and dst ranges
    const auto & h_add = [](ggml_mem_ranges * mrs, const node_info & node) {
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            if (node.node->src[i]) {
                if (!ggml_mem_ranges_add_src(mrs, node.node->src[i])) {
                    return false;
                }
            }
        }

        for (const auto * fused : node.fused) {
            for (int i = 0; i < GGML_MAX_SRC; i++) {
                if (fused->src[i]) {
                    if (!ggml_mem_ranges_add_src(mrs, fused->src[i])) {
                        return false;
                    }
                }
            }
        }

        return ggml_mem_ranges_add_dst(mrs, node.dst());
    };

    // helper to check if a node can run concurrently with the existing set of nodes
    const auto & h_check = [](const ggml_mem_ranges * mrs, const node_info & node) {
        for (int i = 0; i < GGML_MAX_SRC; i++) {
            if (node.node->src[i]) {
                if (!ggml_mem_ranges_check_src(mrs, node.node->src[i])) {
                    return false;
                }
            }
        }

        for (const auto * fused : node.fused) {
            for (int i = 0; i < GGML_MAX_SRC; i++) {
                if (fused->src[i]) {
                    if (!ggml_mem_ranges_check_src(mrs, fused->src[i])) {
                        return false;
                    }
                }
            }
        }

        return ggml_mem_ranges_check_dst(mrs, node.dst());
    };

    // perform reorders only across these types of ops
    // can be expanded when needed
    // IMPORTANT: do not add ops such as GGML_OP_CPY or GGML_OP_SET_ROWS
    //            the dependencies from such ops are not always represented in the graph
    const auto & h_safe = [](ggml_op op) {
        switch (op) {
            case GGML_OP_MUL_MAT:
            case GGML_OP_MUL_MAT_ID:
            case GGML_OP_ROPE:
            case GGML_OP_NORM:
            case GGML_OP_RMS_NORM:
            case GGML_OP_GROUP_NORM:
            case GGML_OP_SUM_ROWS:
            case GGML_OP_MUL:
            case GGML_OP_ADD:
            case GGML_OP_DIV:
            case GGML_OP_GLU:
            case GGML_OP_SCALE:
            case GGML_OP_GET_ROWS:
                return true;
            default:
                return is_empty(op);
        }
    };

    const int n = nodes.size();

    std::vector<int> res;
    res.reserve(n);

    std::vector<bool> used(n, false);

    ggml_mem_ranges * mrs0 = ggml_mem_ranges_init(0);
    ggml_mem_ranges * mrs1 = ggml_mem_ranges_init(0);

    for (int i0 = 0; i0 < n; i0++) {
        if (used[i0]) {
            continue;
        }

        const auto & node0 = nodes[i0];

        // the node is not concurrent with the existing concurrent set, so we have to "put a barrier" (i.e reset mrs0)
        // but before we do that, look forward for some other nodes that can be added to the concurrent set mrs0
        //
        // note: we can always add empty nodes to the concurrent set as they don't read nor write anything
        if (!node0.is_empty() && !h_check(mrs0, node0)) {
            // this will hold the set of memory ranges from the nodes that haven't been processed yet
            // if a node is not concurrent with this set, we cannot reorder it
            ggml_mem_ranges_reset(mrs1);

            // initialize it with the current node
            h_add(mrs1, node0);

            // that many nodes forward to search for a concurrent node
            constexpr int N_FORWARD = 8;

            for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
                if (used[i1]) {
                    continue;
                }

                const auto & node1 = nodes[i1];

                // disallow reordering of certain ops
                if (!h_safe(node1.op())) {
                    break;
                }

                const bool is_empty = node1.is_empty();

                // to add a concurrent node, it has to be:
                // + empty or concurrent with all nodes in the existing concurrent set (mrs0)
                // + concurrent with all nodes prior to it that haven't been processed yet (mrs1)
                if ((is_empty || h_check(mrs0, node1)) && h_check(mrs1, node1)) {
                    // add the node to the existing concurrent set (i.e. reorder it for early execution)
                    h_add(mrs0, node1);
                    res.push_back(i1);

                    // mark as used, so we skip re-processing it later
                    used[i1] = true;
                } else {
                    // expand the set of nodes that haven't been processed yet
                    h_add(mrs1, node1);
                }
            }

            // finalize the concurrent set and begin a new one
            ggml_mem_ranges_reset(mrs0);
        }

        // expand the concurrent set with the current node
        {
            h_add(mrs0, node0);
            res.push_back(i0);
        }
    }

    ggml_mem_ranges_free(mrs0);
    ggml_mem_ranges_free(mrs1);

    return res;
}
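// hypothetical trace of the reorder above (ops and dependencies invented for illustration):
//
//   input order : n0 = MUL_MAT(a, b)
//                 n1 = ADD(n0, c)     // reads n0 -> not concurrent with {n0}
//                 n2 = MUL_MAT(a, d)  // independent of n0 and n1
//                 n3 = ADD(n2, c)     // reads n2
//
//   while processing n1, the look-ahead finds n2, which conflicts neither with the
//   current concurrent set {n0} (mrs0) nor with the pending node n1 (mrs1), so n2 is
//   pulled forward:
//
//   output order: n0, n2, n1, n3
//
//   at encode time n0 and n2 can then run concurrently, with a single barrier before n1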
void ggml_metal_graph_optimize(ggml_cgraph * gf) {
    constexpr int MAX_FUSE = 16;

    const int n = gf->n_nodes;

    enum ggml_op ops[MAX_FUSE];

    std::vector<node_info> nodes;
    nodes.reserve(gf->n_nodes);

    // fuse nodes:
    // we don't want to make reorders that break fusing, so we first pack all fusable tensors
    // and perform the reorder over the fused nodes. after the reorder is done, we unfuse
    for (int i = 0; i < n; i++) {
        node_info node = {
            /*.node  =*/ gf->nodes[i],
            /*.fused =*/ {},
        };

        // fuse only ops that start with these operations
        // can be expanded when needed
        if (node.op() == GGML_OP_ADD ||
            node.op() == GGML_OP_RMS_NORM) {
            ops[0] = node.op();

            int f = i + 1;
            while (f < n && f < i + MAX_FUSE) {
                // conservatively allow fusing only these ops
                // can be expanded when needed
                if (gf->nodes[f]->op != GGML_OP_ADD &&
                    gf->nodes[f]->op != GGML_OP_MUL &&
                    gf->nodes[f]->op != GGML_OP_RMS_NORM) {
                    break;
                }
                ops[f - i] = gf->nodes[f]->op;
                f++;
            }

            f -= i;
            for (; f > 1; f--) {
                if (ggml_can_fuse(gf, i, ops, f)) {
                    break;
                }
            }

            // add the fused tensors into the node info so we can unfuse them later
            for (int k = 1; k < f; k++) {
                ++i;

                // the .dst() becomes the last fused tensor
                node.add_fused(gf->nodes[i]);
            }
        }

        nodes.push_back(std::move(node));
    }

    // reorder to improve concurrency
#if 1
    const auto order = ggml_metal_graph_optimize_reorder(nodes);
#else
    std::vector<int> order(nodes.size());
    for (size_t i = 0; i < nodes.size(); i++) {
        order[i] = i;
    }
#endif

    // unfuse
    {
        int j = 0;
        for (const auto i : order) {
            const auto & node = nodes[i];

            gf->nodes[j++] = node.node;

            for (auto * fused : node.fused) {
                gf->nodes[j++] = fused;
            }
        }
    }
}
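The optimizer can be exercised in isolation on an ordinary ggml graph. Below is a minimal sketch (not part of the commit) that builds a toy graph with one dependent and one independent chain and runs ggml_metal_graph_optimize over it; it assumes the file is compiled inside the ggml source tree so that the internal header ggml-metal-common.h is reachable. Following the reorder rules above, the independent MUL_MAT should be moved ahead of the dependent SCALE so that the two matrix multiplications can overlap.

#include "ggml.h"
#include "ggml-metal-common.h"

#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
    struct ggml_tensor * c = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);

    struct ggml_tensor * x = ggml_mul_mat(ctx, a, b);   // chain 1
    struct ggml_tensor * y = ggml_scale  (ctx, x, 2.0f); // depends on x
    struct ggml_tensor * z = ggml_mul_mat(ctx, a, c);   // independent of x and y

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, y);
    ggml_build_forward_expand(gf, z); // node order before optimization: x, y, z

    ggml_metal_graph_optimize(gf);    // expected order afterwards: x, z, y

    for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) {
        printf("node %d: %s\n", i, ggml_op_name(ggml_graph_node(gf, i)->op));
    }

    ggml_free(ctx);
    return 0;
}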
ggml/src/ggml-metal/ggml-metal-common.h (new file, 52 lines)
// helper functions for ggml-metal that are too difficult to implement in Objective-C

#pragma once

#include <stdbool.h>

#ifdef __cplusplus
extern "C" {
#endif

struct ggml_tensor;
struct ggml_cgraph;

enum ggml_mem_range_type {
    MEM_RANGE_TYPE_SRC = 0,
    MEM_RANGE_TYPE_DST = 1,
};

// a helper object that can be used for reordering operations to improve concurrency
//
// the fundamental idea is that a set of tasks (either ggml ops, or something else) can run concurrently if they
// don't write to a memory that is being read by another task or written to by another task in the set
//
// with this structure, we can add tasks to the set, setting memory constraints. we can also check if a new task
// can be added to the set without violating the constraints (i.e. if it can be executed concurrently with the
// tasks already in the set)
//
struct ggml_mem_ranges;

struct ggml_mem_ranges * ggml_mem_ranges_init(int debug);
void ggml_mem_ranges_free(struct ggml_mem_ranges * mrs);

// remove all ranges from the set
void ggml_mem_ranges_reset(struct ggml_mem_ranges * mrs);

// add src or dst ranges to track
bool ggml_mem_ranges_add(struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor);

// return false if:
//   - new src range overlaps with any existing dst range
//   - new dst range overlaps with any existing range (src or dst)
bool ggml_mem_ranges_check(const struct ggml_mem_ranges * mrs, const struct ggml_tensor * tensor);

// reorder the nodes in the graph to improve concurrency, while respecting fusion
//
// note: this implementation is generic and not specific to metal
//       if it proves to work well, we can start using it for other backends in the future
void ggml_metal_graph_optimize(struct ggml_cgraph * gf);

#ifdef __cplusplus
}
#endif
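A small usage sketch of this API (not part of the commit, again assuming the internal header is reachable): record the ranges of one op, then ask whether other ops could run concurrently with it. With unallocated tensors the tracker falls back to using the tensor pointers as buffer ids, which is enough for this illustration.

#include "ggml.h"
#include "ggml-metal-common.h"

#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 256);
    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 256);

    struct ggml_tensor * x = ggml_mul(ctx, a, b); // writes x
    struct ggml_tensor * y = ggml_add(ctx, x, b); // reads x   -> conflicts with x
    struct ggml_tensor * z = ggml_add(ctx, a, b); // reads a,b -> no conflict with x

    struct ggml_mem_ranges * mrs = ggml_mem_ranges_init(/*debug =*/ 0);

    ggml_mem_ranges_add(mrs, x); // track the srcs and dst of x

    printf("y can run concurrently with x: %s\n", ggml_mem_ranges_check(mrs, y) ? "yes" : "no"); // no
    printf("z can run concurrently with x: %s\n", ggml_mem_ranges_check(mrs, z) ? "yes" : "no"); // yes

    ggml_mem_ranges_free(mrs);
    ggml_free(ctx);
    return 0;
}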
ggml/src/ggml-metal/ggml-metal.m

@@ -3,6 +3,7 @@
 #import "ggml-impl.h"
 #import "ggml-backend-impl.h"
 #import "ggml-metal-impl.h"
+#import "ggml-metal-common.h"
 
 #import <Foundation/Foundation.h>
 

@@ -61,8 +62,11 @@ static struct ggml_backend_metal_device_context {
     bool has_bfloat;
     bool use_bfloat;
     bool use_fusion;
+    bool use_concurrency;
     bool use_shared_buffers;
+    bool use_graph_optimize;
 
+    int debug_graph;
     int debug_fusion;
 
     // how many times a given op was fused

@@ -83,7 +87,10 @@ static struct ggml_backend_metal_device_context {
     /*.has_bfloat =*/ false,
     /*.use_bfloat =*/ false,
     /*.use_fusion =*/ true,
+    /*.use_concurrency =*/ true,
     /*.use_shared_buffers =*/ true,
+    /*.use_graph_optimize =*/ true,
+    /*.debug_graph =*/ 0,
     /*.debug_fusion =*/ 0,
     /*.fuse_cnt =*/ { 0 },
     /*.max_size =*/ 0,

@@ -124,7 +131,14 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
 #else
     ctx->use_bfloat = false;
 #endif
 
     ctx->use_fusion = getenv("GGML_METAL_FUSION_DISABLE") == nil;
+    ctx->use_concurrency = getenv("GGML_METAL_CONCURRENCY_DISABLE") == nil;
 
+    {
+        const char * val = getenv("GGML_METAL_GRAPH_DEBUG");
+        ctx->debug_graph = val ? atoi(val) : 0;
+    }
+
     {
         const char * val = getenv("GGML_METAL_FUSION_DEBUG");

@@ -137,6 +151,12 @@ static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_dev
         ctx->use_shared_buffers = false;
     }
 
+    ctx->use_graph_optimize = true;
+
+    if (getenv("GGML_METAL_GRAPH_OPTIMIZE_DISABLE") != NULL) {
+        ctx->use_graph_optimize = false;
+    }
+
     memset(ctx->fuse_cnt, 0, sizeof(ctx->fuse_cnt));
 
     ctx->max_size = ctx->mtl_device.maxBufferLength;

@@ -628,7 +648,7 @@ static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
 @end
 
 //
-// ggml_metal_mem_pool
+// ggml_metal_mem_pool [TAG_MEM_POOL_REMOVE]
 //
 
 struct ggml_metal_mem_pool {

@@ -791,6 +811,9 @@ struct ggml_metal_command_buffer {
 
     // each command buffer has a memory pool from which it can allocate temporary buffers during the compute
     struct ggml_metal_mem_pool * mem_pool;
+
+    // used to enable concurrent execution of ops in the command buffers
+    struct ggml_mem_ranges * mem_ranges;
 };
 
 struct ggml_backend_metal_context {

@@ -1091,7 +1114,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     GGML_LOG_INFO("%s: has bfloat = %s\n", __func__, ctx_dev->has_bfloat ? "true" : "false");
     GGML_LOG_INFO("%s: use bfloat = %s\n", __func__, ctx_dev->use_bfloat ? "true" : "false");
     GGML_LOG_INFO("%s: use fusion = %s\n", __func__, ctx_dev->use_fusion ? "true" : "false");
+    GGML_LOG_INFO("%s: use concurrency = %s\n", __func__, ctx_dev->use_concurrency ? "true" : "false");
     GGML_LOG_INFO("%s: use shared buffers = %s\n", __func__, ctx_dev->use_shared_buffers ? "true" : "false");
+    GGML_LOG_INFO("%s: use graph optimize = %s\n", __func__, ctx_dev->use_graph_optimize ? "true" : "false");
     GGML_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
 
     ctx->capture_next_compute = false;

@@ -1105,6 +1130,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
 
         ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
         ctx->cmd_bufs[i].mem_pool->device = device;
+
+        if (ctx_dev->use_concurrency) {
+            ctx->cmd_bufs[i].mem_ranges = ggml_mem_ranges_init(ctx_dev->debug_graph);
+        }
     }
 
     ctx->cmd_bufs_ext = [[NSMutableArray alloc] init];

@@ -1715,6 +1744,10 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
         }
 
         ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
+
+        if (ctx->cmd_bufs[i].mem_ranges) {
+            ggml_mem_ranges_free(ctx->cmd_bufs[i].mem_ranges);
+        }
     }
 
     [ctx->cmd_bufs_ext removeAllObjects];

@@ -2071,12 +2104,51 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     }
 }
 
-static int ggml_metal_encode_node(
-        ggml_backend_t backend,
-        int idx,
-        int idx_end,
-        id<MTLComputeCommandEncoder> encoder,
-        struct ggml_metal_mem_pool * mem_pool) {
+struct ggml_metal_encode_context {
+    ggml_backend_t backend;
+
+    id<MTLComputeCommandEncoder> encoder;
+
+    struct ggml_metal_mem_pool * mem_pool;
+
+    struct ggml_mem_ranges * mem_ranges;
+};
+
+static bool ggml_metal_encode_concurrency_reset(struct ggml_metal_encode_context * ctx) {
+    if (!ctx->mem_ranges) {
+        return true;
+    }
+
+    [ctx->encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
+
+    ggml_mem_ranges_reset(ctx->mem_ranges);
+
+    return true;
+}
+
+static bool ggml_metal_encode_concurrency_check(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    if (!ctx->mem_ranges) {
+        return false;
+    }
+
+    return ggml_mem_ranges_check(ctx->mem_ranges, node);
+}
+
+static bool ggml_metal_encode_concurrency_add(struct ggml_metal_encode_context * ctx, const struct ggml_tensor * node) {
+    if (!ctx->mem_ranges) {
+        return true;
+    }
+
+    return ggml_mem_ranges_add(ctx->mem_ranges, node);
+}
+
+static int ggml_metal_encode_node(struct ggml_metal_encode_context * ctx_enc, int idx, int idx_end) {
+    ggml_backend_t backend = ctx_enc->backend;
+
+    id<MTLComputeCommandEncoder> encoder = ctx_enc->encoder;
+
+    struct ggml_metal_mem_pool * mem_pool = ctx_enc->mem_pool;
+
     struct ggml_backend_metal_context * ctx = backend->context;
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;

@@ -2159,38 +2231,71 @@ static int ggml_metal_encode_node(
     const uint64_t nb2 = dst ? dst->nb[2] : 0;
     const uint64_t nb3 = dst ? dst->nb[3] : 0;
 
+    size_t offs_src[GGML_MAX_SRC];
+
+    id<MTLBuffer> id_src[GGML_MAX_SRC];
+
+    enum ggml_type srct[GGML_MAX_SRC];
+
+    for (int i = 0; i < GGML_MAX_SRC; i++) {
+        offs_src[i] = 0;
+        id_src[i] = node->src[i] ? ggml_metal_get_buffer(node->src[i], &offs_src[i]) : nil;
+        srct[i] = node->src[i] ? node->src[i]->type : GGML_TYPE_COUNT;
+    }
+
+    // TODO: tmp shorthands - remove
+    size_t offs_src0 = offs_src[0];
+    size_t offs_src1 = offs_src[1];
+    size_t offs_src2 = offs_src[2];
+
+    id<MTLBuffer> id_src0 = id_src[0];
+    id<MTLBuffer> id_src1 = id_src[1];
+    id<MTLBuffer> id_src2 = id_src[2];
+
     const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
     const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
     const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT;
     const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
 
-    size_t offs_src0 = 0;
-    size_t offs_src1 = 0;
-    size_t offs_src2 = 0;
     size_t offs_dst = 0;
 
-    id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
-    id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
-    id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
     id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(dst, &offs_dst) : nil;
 
     int n_fuse = 1;
 
-#if 0
-    GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+    // check if the current node can run concurrently with other nodes before it
+    // the condition is that:
+    // - the current node cannot write to any previous src or dst ranges
+    // - the current node cannot read from any previous dst ranges
+    //
+    // if the condition is not satisfied, we put a memory barrier and clear all ranges
+    // otherwise, we add the new ranges to the encoding context and process the node concurrently
+    //
+    {
+        const bool is_concurrent = ggml_metal_encode_concurrency_check(ctx_enc, node);
+
+        if (!is_concurrent) {
+            ggml_metal_encode_concurrency_reset(ctx_enc);
+        }
+
+        if (ctx_dev->debug_graph > 0) {
+            GGML_LOG_DEBUG("%s: node[%5d] - %-12s %s\n", __func__, idx, ggml_op_name(dst->op), is_concurrent ? "(concurrent)" : "");
+        }
+        if (ctx_dev->debug_graph > 1) {
             if (src0) {
-                GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+                GGML_LOG_DEBUG("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
                         ggml_is_contiguous(src0), src0->name);
             }
             if (src1) {
-                GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
                         ggml_is_contiguous(src1), src1->name);
             }
             if (dst) {
-                GGML_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                GGML_LOG_DEBUG("%s: dst - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
                         dst->name);
             }
-#endif
+        }
+    }
 
     id<MTLDevice> device = ctx_dev->mtl_device;
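Distilled to its essence, the per-node protocol implemented by these helpers is: check the node against the ranges recorded so far, fall back to a barrier plus a reset of the set when the check fails, and record the node's ranges after encoding it. A compact sketch of that pattern (not the actual backend code) using only the C API; the Metal barrier itself is backend-specific and only indicated by a comment, and the function name is made up for illustration:

#include "ggml.h"
#include "ggml-metal-common.h"

static void encode_nodes_with_tracking(struct ggml_cgraph * gf, struct ggml_mem_ranges * mrs) {
    for (int i = 0; i < ggml_graph_n_nodes(gf); ++i) {
        struct ggml_tensor * node = ggml_graph_node(gf, i);

        if (!ggml_mem_ranges_check(mrs, node)) {
            // the Metal backend emits [encoder memoryBarrierWithScope:MTLBarrierScopeBuffers] here
            ggml_mem_ranges_reset(mrs);
        }

        // ... dispatch the kernel(s) for this node on the concurrent encoder ...

        if (!ggml_mem_ranges_add(mrs, node)) {
            // if the ranges cannot be tracked, also fall back to a barrier
            ggml_mem_ranges_reset(mrs);
        }
    }
}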
@@ -2389,6 +2494,14 @@ static int ggml_metal_encode_node(
 
                 if (n_fuse > 1) {
                     id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+
+                    for (int i = 1; i < n_fuse; ++i) {
+                        if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
+                            ggml_metal_encode_concurrency_reset(ctx_enc);
+
+                            break;
+                        }
+                    }
                 }
 
                 [encoder setComputePipelineState:pipeline];

@@ -2533,6 +2646,8 @@ static int ggml_metal_encode_node(
                     const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+
+                    ggml_metal_encode_concurrency_reset(ctx_enc);
                 }
 
                 const id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline;

@@ -3997,6 +4112,12 @@ static int ggml_metal_encode_node(
                     default: break;
                 }
 
+                // TODO: using mem pool allocations with enabled concurrency is not safe because the mem pool
+                //       reuses buffers. this can result in 2 concurrent MUL_MAT_ID ops using the same mem pool buffer.
+                //       so we add this extra barrier to prevent the race.
+                //       the correct solution is to remove mem pools and then remove this barrier [TAG_MEM_POOL_REMOVE]
+                ggml_metal_encode_concurrency_reset(ctx_enc);
+
                 // tokens per expert
                 const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
                 id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);

@@ -4057,6 +4178,9 @@ static int ggml_metal_encode_node(
                     [encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
                 }
 
+                // this barrier is always needed because the next kernel has to wait for the id maps to be computed
+                ggml_metal_encode_concurrency_reset(ctx_enc);
+
                 {
                     id<MTLComputePipelineState> pipeline = nil;

@@ -4525,6 +4649,14 @@ static int ggml_metal_encode_node(
 
                 if (n_fuse > 1) {
                     id_dst = ggml_metal_get_buffer(nodes[n_fuse - 1], &offs_dst);
+
+                    for (int i = 1; i < n_fuse; ++i) {
+                        if (!ggml_metal_encode_concurrency_check(ctx_enc, nodes[i])) {
+                            ggml_metal_encode_concurrency_reset(ctx_enc);
+
+                            break;
+                        }
+                    }
                 }
 
                 id<MTLComputePipelineState> pipeline;

@@ -4668,7 +4800,6 @@ static int ggml_metal_encode_node(
             } break;
         case GGML_OP_ROPE:
             {
-
                 // make sure we have one or more position id(ne10) per token(ne02)
                 GGML_ASSERT(ne10 % ne02 == 0);
                 GGML_ASSERT(ne10 >= ne02);

@@ -5427,6 +5558,10 @@ static int ggml_metal_encode_node(
                 GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
                 GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
 
+                // using mem pool allocations with enabled concurrency is not safe [TAG_MEM_POOL_REMOVE]
+                // still, we assume that concurrent FA won't happen before we do the refactor
+                //ggml_metal_encode_concurrency_reset(ctx_enc);
+
                 const int32_t nrows = ne1*ne2*ne3;
 
                 // temp buffer for writing the results from each workgroup

@@ -5447,6 +5582,8 @@ static int ggml_metal_encode_node(
                     [encoder setThreadgroupMemoryLength:smem atIndex:0];
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
 
+                    ggml_metal_encode_concurrency_reset(ctx_enc);
+
                     // reduce the results from the workgroups
                     {
                         ggml_metal_kargs_flash_attn_ext_vec_reduce args0 = {

@@ -5709,6 +5846,19 @@ static int ggml_metal_encode_node(
            }
    }
 
+    if (ctx_dev->debug_graph > 0) {
+        if (n_fuse > 1) {
+            GGML_LOG_DEBUG("%s: fuse %d ops\n", __func__, n_fuse);
+        }
+    }
+
+    // update the mem ranges in the encoding context
+    for (int i = 0; i < n_fuse; ++i) {
+        if (!ggml_metal_encode_concurrency_add(ctx_enc, nodes[i])) {
+            ggml_metal_encode_concurrency_reset(ctx_enc);
+        }
+    }
+
     return n_fuse;
 }

@@ -5719,7 +5869,7 @@ static enum ggml_status ggml_metal_graph_compute(
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
     // number of nodes encoded by the main thread (empirically determined)
-    const int n_main = 128;
+    const int n_main = 64;
 
     // number of threads in addition to the main thread
     const int n_cb = ctx->n_cb;

@@ -5774,6 +5924,7 @@ static enum ggml_status ggml_metal_graph_compute(
         // cannot use commandBufferWithUnretainedReferences because the buffers from the memory pool can get destroyed
         // TODO: when the memory pools are removed, we can again use commandBufferWithUnretainedReferences
         //       https://github.com/ggml-org/llama.cpp/pull/15832#discussion_r2334215009
+        //       [TAG_MEM_POOL_REMOVE]
         //id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
         id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBuffer];
         [cmd_buf retain];

@@ -6547,6 +6698,18 @@ static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend,
     return ggml_metal_graph_compute(backend, cgraph);
 }
 
+static void ggml_backend_metal_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+    //const int64_t t_start = ggml_time_us();
+
+    if (ctx_dev->use_graph_optimize) {
+        ggml_metal_graph_optimize(cgraph);
+    }
+
+    //printf("%s: graph optimize took %.3f ms\n", __func__, (ggml_time_us() - t_start) / 1000.0);
+}
+
 static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
     GGML_ASSERT(ggml_backend_is_metal(backend));

@@ -6575,10 +6738,23 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
         struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+        struct ggml_mem_ranges * mem_ranges = ctx->cmd_bufs[cb_idx].mem_ranges;
 
         ggml_metal_mem_pool_reset(mem_pool);
 
-        id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
+        if (mem_ranges) {
+            ggml_mem_ranges_reset(mem_ranges);
+        }
+
+        id<MTLComputeCommandEncoder> encoder;
+
+        struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+        if (ctx_dev->use_concurrency) {
+            encoder = [cmd_buf computeCommandEncoderWithDispatchType: MTLDispatchTypeConcurrent];
+        } else {
+            encoder = [cmd_buf computeCommandEncoder];
+        }
 
         int node_start = 0;
         int node_end = n_nodes_0;

@@ -6590,12 +6766,19 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
 
         const bool should_capture = ctx->capture_next_compute;
 
+        struct ggml_metal_encode_context ctx_enc = {
+            /*.backend =*/ backend,
+            /*.encoder =*/ encoder,
+            /*.mem_pool =*/ mem_pool,
+            /*.mem_ranges =*/ mem_ranges,
+        };
+
         for (int idx = node_start; idx < node_end;) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool);
+            const int res = ggml_metal_encode_node(&ctx_enc, idx, node_end);
             if (idx + res > node_end) {
                 GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. this indicates a bug in the fusion logic %s",
                     "https://github.com/ggml-org/llama.cpp/pull/14849");

@@ -6638,7 +6821,7 @@ static struct ggml_backend_i ggml_backend_metal_i = {
     // https://developer.apple.com/documentation/metal/mtlcommandbuffer#Synchronizing-Passes-with-Events
     /* .event_record = */ NULL,
     /* .event_wait = */ NULL,
-    /* .optimize_graph = */ NULL,
+    /* .optimize_graph = */ ggml_backend_metal_graph_optimize,
 };
 
 static ggml_guid_t ggml_backend_metal_guid(void) {