CUDA: fix bug in rms_norm fusion (#15660)

* CUDA: fix bug in rms_norm fusion

* Fix bug for OP_REPEAT

* Fix index for add
This commit is contained in:
Aman Gupta
2025-08-29 21:30:06 +08:00
committed by GitHub
parent 60e5eee31f
commit 81017865ee
3 changed files with 51 additions and 23 deletions

View File

@@ -57,7 +57,11 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
const int i10 = i0 % ne10; const int i10 = i0 % ne10;
float result = src0_row ? (float) src0_row[i0] : 0.0f; float result = src0_row ? (float) src0_row[i0] : 0.0f;
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10]);
}
dst_row[i0] = (dst_t) result; dst_row[i0] = (dst_t) result;
} }
@@ -96,7 +100,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
const int i10 = i0 % ne10; const int i10 = i0 % ne10;
float result = src0_row ? (float) src0_row[i0] : 0.0f; float result = src0_row ? (float) src0_row[i0] : 0.0f;
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10]);
}
dst_row[i0] = (dst_t) result; dst_row[i0] = (dst_t) result;
} }
@@ -231,23 +239,43 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
if (block_nums.z > 65535) { if (block_nums.z > 65535) {
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t> if constexpr (sizeof...(I) > 0) {
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
ne0, ne1, ne2, ne3, <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
/* s0, */ s1, s2, s3, ne10, ne11, ne12, ne13,
/* s00,*/ s01, s02, s03, /* s0, */ s1, s2, s3,
/* s10,*/ s11, s12,s13, /* s00,*/ s01, s02, s03,
(const src1_t *) dst->src[I + 1]->data...); /* s10,*/ s11, s12,s13,
(const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13);
}
} else { } else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t> if constexpr (sizeof...(I) > 0) {
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd, k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
ne0, ne1, ne2, ne3, <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
/* s0, */ s1, s2, s3, ne10, ne11, ne12, ne13,
/* s00,*/ s01, s02, s03, /* s0, */ s1, s2, s3,
/* s10,*/ s11, s12,s13, /* s00,*/ s01, s02, s03,
(const src1_t *) dst->src[I + 1]->data...); /* s10,*/ s11, s12,s13,
(const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13);
}
} }
} }
} }
@@ -327,7 +355,7 @@ static void ggml_cuda_op_bin_bcast(
} }
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream()); ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat, 0>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
} }
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View File

@@ -2827,7 +2827,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
const ggml_tensor *add = nullptr; const ggml_tensor *add = nullptr;
if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) { if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
add = cgraph->nodes[node_idx+1]; add = cgraph->nodes[node_idx+2];
} }
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);

View File

@@ -127,6 +127,7 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
const int add_nrows = 0, const int add_nrows = 0,
const int add_nchannels = 0, const int add_nchannels = 0,
const int add_nsamples = 0) { const int add_nsamples = 0) {
const int nrows = gridDim.x; const int nrows = gridDim.x;
const int nchannels = gridDim.y; const int nchannels = gridDim.y;
@@ -135,6 +136,8 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
const int sample = blockIdx.z; const int sample = blockIdx.z;
const int tid = threadIdx.x; const int tid = threadIdx.x;
static_assert(!do_add || do_multiply, "fusing add is not supported without multiplying");
x += sample*stride_sample + channel*stride_channel + row*stride_row; x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols; dst += ((sample*nchannels + channel)*nrows + row)*ncols;
@@ -185,9 +188,6 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
} else if constexpr (do_multiply) { } else if constexpr (do_multiply) {
const int mul_col = col % mul_ncols; const int mul_col = col % mul_ncols;
dst[col] = scale * x[col] * mul[mul_col]; dst[col] = scale * x[col] * mul[mul_col];
} else if constexpr (do_add) {
const int add_col = col % add_ncols;
dst[col] += add[add_col];
} else { } else {
dst[col] = scale * x[col]; dst[col] = scale * x[col];
} }