mirror of https://github.com/ggml-org/llama.cpp.git
synced 2025-11-13 10:57:15 +00:00
CUDA: fuse adds, fuse add with rms norm (#15631)
* CUDA: fused add with rms_norm_mul
* Non-broadcast fuse works
* Add fused adds
* format
* Remove n_fuse from template params
* Address review comments
* Move template inside binbcast
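The pattern being targeted is the RMS-norm + weight-scale (+ optional residual/bias add) sequence common in transformer blocks. For context, here is a minimal sketch, assuming the public ggml API, of a graph fragment that produces the RMS_NORM -> MUL -> ADD node sequence the pass can now match; the helper name and tensor names are illustrative only, not part of the commit:

    // Illustrative only (not part of this commit): builds the fusable sequence.
    static struct ggml_tensor * rms_norm_mul_add(struct ggml_context * ctx,
                                                 struct ggml_tensor * x,   // activations
                                                 struct ggml_tensor * w,   // norm weights
                                                 struct ggml_tensor * b) { // tensor to add
        struct ggml_tensor * n = ggml_rms_norm(ctx, x, 1e-6f); // GGML_OP_RMS_NORM
        struct ggml_tensor * m = ggml_mul(ctx, n, w);          // GGML_OP_MUL
        return ggml_add(ctx, m, b);                            // GGML_OP_ADD
    }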
@@ -2821,9 +2821,14 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         return false;
     }
 
-    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
         const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
         const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
+        const ggml_tensor *add      = nullptr;
+
+        if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
+            add = cgraph->nodes[node_idx+2];
+        }
 
         GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
         GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
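Note on the hunk above: the candidate nodes are consecutive in the graph, so the optional ADD is fetched from the slot after the MUL. Schematically (indices relative to node_idx):

    // Node layout assumed by the check above:
    //   cgraph->nodes[node_idx]       GGML_OP_RMS_NORM   y = rms_norm(x)
    //   cgraph->nodes[node_idx + 1]   GGML_OP_MUL        m = y * w
    //   cgraph->nodes[node_idx + 2]   GGML_OP_ADD        r = m + b   (3-op pattern only)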
@@ -2835,6 +2840,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
             return false;
         }
 
+        if (add && (add->src[0]->type != GGML_TYPE_F32 ||
+                    add->src[1]->type != GGML_TYPE_F32 ||
+                    add->type != GGML_TYPE_F32) ) {
+            return false;
+        }
+
         //if rms norm is the B operand, then we don't handle broadcast
         if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[0])) {
             return false;
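The fused kernels are F32-only, so the new guard rejects the pattern when any operand or the result of the trailing ADD is not F32. A hypothetical helper, equivalent to the new condition, just to spell it out (not in the commit):

    // Hypothetical rewrite of the new type guard:
    static bool fused_add_is_f32(const ggml_tensor * add) {
        return add == nullptr ||                        // 2-op pattern: nothing to check
               (add->src[0]->type == GGML_TYPE_F32 &&
                add->src[1]->type == GGML_TYPE_F32 &&
                add->type         == GGML_TYPE_F32);
    }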
@@ -2845,6 +2856,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
             return false;
         }
 
+        if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
+            return false;
+        }
+
         return true;
     }
 
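ggml_is_contiguous_rows is a weaker requirement than ggml_is_contiguous: each row must be dense, but the rows themselves may be strided, which is presumably what allows the second ADD operand to be a strided view while the first must be fully contiguous. A sketch of the distinction, assuming ggml_view_2d semantics and an existing ctx:

    // Hypothetical example: row-contiguous but not fully contiguous.
    struct ggml_tensor * t = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 4);
    // View rows 0 and 2: each row stays dense, but rows sit 2*nb[1] apart.
    struct ggml_tensor * v = ggml_view_2d(ctx, t, 4, 2, 2*t->nb[1], 0);
    // ggml_is_contiguous(v)      -> false
    // ggml_is_contiguous_rows(v) -> true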
@@ -2891,7 +2906,46 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
         static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
         if (!disable_fusion) {
-            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
+
+            if (node->op == GGML_OP_ADD) {
+                int n_fuse = 0;
+                ggml_op ops[8];
+                std::fill(ops, ops + 8, GGML_OP_ADD);
+
+                for (; n_fuse <= 6; ++n_fuse){
+                    if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
+                        break;
+                    }
+                    if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
+                        break;
+                    }
+                    if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
+                        break;
+                    }
+                }
+
+                n_fuse++;
+
+                if (n_fuse > 1) {
+                    for (int j = 0; j < n_fuse - 1; ++j) {
+                        node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
+                    }
+                    cgraph->nodes[i + n_fuse - 1]->data = node->data;
+                    ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
+                    i += n_fuse - 1;
+
+                    continue;
+                }
+            }
+
+
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
+                ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+                i += 2;
+                continue;
+            }
+
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
                 ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
                 i++;
                 continue;
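Worked example for the ADD-chain fusion above, using hypothetical tensors: for the graph t1 = a + b; t2 = t1 + c; t3 = t2 + d, the loop advances twice and breaks at n_fuse = 2 (there is no fourth ADD to pair t3 with); the post-loop increment then gives n_fuse = 3 fused adds over 4 operands:

    // Trace of the loop for the chain t1 = a + b; t2 = t1 + c; t3 = t2 + d:
    //   n_fuse = 0: t1/t2 fusable, t2->src[0] == t1, layouts of b and c match -> continue
    //   n_fuse = 1: t2/t3 fusable, t3->src[0] == t2, layouts of c and d match -> continue
    //   n_fuse = 2: nodes[i+2]/nodes[i+3] are not a fusable ADD pair -> break
    //   n_fuse++  -> 3
    // Operand gathering performed by the body above:
    //   node->src[2] = c;                          // from t2
    //   node->src[3] = d;                          // from t3
    //   cgraph->nodes[i + 2]->data = node->data;   // t3 aliases the fused output
    //   ggml_cuda_op_fused_add(*cuda_ctx, node, /*n_fuse=*/3);  // one launch: a + b + c + d
    //   i += 2;                                    // skip the two consumed ADD nodes

The data aliasing is what keeps downstream consumers of t3 correct: they read the fused result from the buffer they already point at.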