CUDA: fuse adds, fuse add with rms norm (#15631)

* CUDA: fused add with rms_norm_mul

* Non-broadcast fuse works

* Add fused adds

* format

* Remove n_fuse from template params

* Address review comments

* Move template inside binbcast
This commit is contained in:
Aman Gupta
2025-08-29 11:35:58 +08:00
committed by GitHub
parent e8d99dd0b6
commit 009b709d6e
5 changed files with 501 additions and 190 deletions

View File

@@ -2821,9 +2821,14 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
const ggml_tensor *add = nullptr;
if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
add = cgraph->nodes[node_idx+1];
}
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
@@ -2835,6 +2840,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
if (add && (add->src[0]->type != GGML_TYPE_F32 ||
add->src[1]->type != GGML_TYPE_F32 ||
add->type != GGML_TYPE_F32) ) {
return false;
}
//if rms norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
return false;
@@ -2845,6 +2856,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false;
}
if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
return false;
}
return true;
}
@@ -2891,7 +2906,46 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
if (node->op == GGML_OP_ADD) {
int n_fuse = 0;
ggml_op ops[8];
std::fill(ops, ops + 8, GGML_OP_ADD);
for (; n_fuse <= 6; ++n_fuse){
if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
break;
}
if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
break;
}
if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
break;
}
}
n_fuse++;
if (n_fuse > 1) {
for (int j = 0; j < n_fuse - 1; ++j) {
node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
}
cgraph->nodes[i + n_fuse - 1]->data = node->data;
ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
i += n_fuse - 1;
continue;
}
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
i += 2;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;