CUDA: fix bug in rms_norm fusion (#15660)

* CUDA: fix bug in rms_norm fusion

* Fix bug for OP_REPEAT

* Fix index for add
This commit is contained in:
Aman Gupta
2025-08-29 21:30:06 +08:00
committed by GitHub
parent 60e5eee31f
commit 81017865ee
3 changed files with 51 additions and 23 deletions

View File

@@ -2827,7 +2827,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
const ggml_tensor *add = nullptr;
if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
add = cgraph->nodes[node_idx+1];
add = cgraph->nodes[node_idx+2];
}
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);