CUDA: fuse adds, fuse add with rms norm (#15631)

* CUDA: fused add with rms_norm_mul

* Non-broadcast fuse works

* Add fused adds

* format

* Remove n_fuse from template params

* Address review comments

* Move template inside binbcast
Author:    Aman Gupta
Date:      2025-08-29 11:35:58 +08:00
Committed: GitHub
Commit:    009b709d6e (parent e8d99dd0b6)
5 changed files with 501 additions and 190 deletions
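Before the per-file diffs, a minimal standalone CUDA sketch (an assumed illustration, not code from this commit) of the idea behind the fused-add path: one kernel walks the output once and folds any number of source tensors into it through a C++17 parameter pack, which is what lets a chain of consecutive GGML_OP_ADD nodes (up to 8 in this commit) collapse into a single launch instead of one launch per add.

// fused_add_sketch.cu (hypothetical example, compile with nvcc -std=c++17)
#include <cstdio>
#include <cuda_runtime.h>

template <typename... Srcs>
__global__ void fused_add(float * dst, int n, const float * first, Srcs... rest) {
    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    float acc = first[i];
    // C++17 fold expression: accumulate every extra source in the same pass
    ((acc += rest[i]), ...);
    dst[i] = acc;
}

int main() {
    const int n = 256;
    float *a, *b, *c, *out;
    cudaMallocManaged(&a, n * sizeof(float));
    cudaMallocManaged(&b, n * sizeof(float));
    cudaMallocManaged(&c, n * sizeof(float));
    cudaMallocManaged(&out, n * sizeof(float));
    for (int i = 0; i < n; ++i) { a[i] = 1.0f; b[i] = 2.0f; c[i] = 3.0f; }

    fused_add<<<(n + 127) / 128, 128>>>(out, n, a, b, c);  // out = a + b + c in one launch
    cudaDeviceSynchronize();
    printf("out[0] = %f\n", out[0]);  // expected 6.0

    cudaFree(a); cudaFree(b); cudaFree(c); cudaFree(out);
    return 0;
}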

File: ggml/src/ggml-cuda/binbcast.cu

@@ -1,5 +1,6 @@
 #include "binbcast.cuh"
 #include <cstdint>
+#include <utility>
 
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
@@ -22,13 +23,16 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
     return a / b;
 }
 
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
 static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0, int ne1, int ne2, int ne3,
-        int ne10, int ne11, int ne12, int ne13,
-        /*int s0, */ int s1, int s2, int s3,
-        /*int s00,*/ int s01, int s02, int s03,
-        /*int s10,*/ int s11, int s12, int s13) {
+                                   const int ne0, const int ne1, const int ne2, const int ne3,
+                                   const int ne10, const int ne11, const int ne12, const int ne13,
+                                   /*int s0, */ const int s1, const int s2, const int s3,
+                                   /*int s00,*/ const int s01, const int s02, const int s03,
+                                   /*int s10,*/ const int s11, const int s12, const int s13,
+                                   src1_ptrs... src1s) {
     const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
     const int i1  = (blockDim.y*blockIdx.y + threadIdx.y);
     const int i2  = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
@@ -46,24 +50,27 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
     const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
 
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
+    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
     dst_t * dst_row = dst + i_dst;
 
     for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
         const int i10 = i0 % ne10;
-        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+
+        float result = src0_row ? (float) src0_row[i0] : 0.0f;
+        result = (..., (result = bin_op(result, (float) src1s[i_src1 + i10])));
+
+        dst_row[i0] = (dst_t) result;
     }
 }
 
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
 static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0, int ne1, int ne2, int ne3,
-        int ne10, int ne11, int ne12, int ne13,
-        /*int s0, */ int s1, int s2, int s3,
-        /*int s00,*/ int s01, int s02, int s03,
-        /*int s10,*/ int s11, int s12, int s13) {
+                                           const int ne0, const int ne1, const int ne2, const int ne3,
+                                           const int ne10, const int ne11, const int ne12, const int ne13,
+                                           /*int s0, */ const int s1, const int s2, const int s3,
+                                           /*int s00,*/ const int s01, const int s02, const int s03,
+                                           /*int s10,*/ const int s11, const int s12, const int s13,
+                                           src1_ptrs ... src1s) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     const int i3 = i/(ne2*ne1*ne0);
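The non-obvious line in the new kernels is the comma-operator fold `result = (..., (result = bin_op(result, (float) src1s[i_src1 + i10])));`. A small host-side sketch of the same pattern (assumed illustration, not part of the commit) shows that the pack is applied left to right and that a single-element pack degenerates to one application of bin_op:

// fold_sketch.cpp (hypothetical example, compile with -std=c++17)
#include <cstdio>

float op_add(const float a, const float b) { return a + b; }

// Same shape as the kernel's fold: each pack element updates `result` in turn.
template <float (*bin_op)(const float, const float), typename... Srcs>
float fold_bin_op(float result, Srcs... srcs) {
    result = (..., (result = bin_op(result, (float) srcs)));
    return result;
}

int main() {
    // Mirrors dst = src0 + src1a + src1b + src1c for one element: ((1 + 2) + 3) + 4
    printf("%g\n", fold_bin_op<op_add>(1.0f, 2.0f, 3.0f, 4.0f));  // 10
    // Single fused operand: one application of bin_op
    printf("%g\n", fold_bin_op<op_add>(1.0f, 2.0f));              // 3
    return 0;
}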
@@ -83,50 +90,21 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
     const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
     const size_t i_dst  =  i3*s3  +  i2*s2  +  i1*s1;
 
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
+    const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
     dst_t * dst_row = dst + i_dst;
 
     const int i10 = i0 % ne10;
-    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
-}
-
-template <typename T>
-static __global__ void k_repeat_back(
-    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
-    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
-
-    const int64_t tid0  = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
-    const int64_t tid1  = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
-    const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
-    const int64_t tid2  = tid23 % ne2;
-    const int64_t tid3  = tid23 / ne2;
-
-    if (tid0 >= ne0) {
-        return;
-    }
-
-    T sum = 0;
-    for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
-        for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
-            for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
-                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
-                    sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
-                }
-            }
-        }
-    }
-    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
-}
-
-template<float (*bin_op)(const float, const float)>
-struct bin_bcast_cuda {
-    template<typename src0_t, typename src1_t, typename dst_t>
-    void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
-            const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
-            cudaStream_t stream) {
+
+    float result = src0_row ? (float) src0_row[i0] : 0.0f;
+    result = (..., (result = bin_op(result, (float) src1s[i_src1 + i10])));
+
+    dst_row[i0] = (dst_t) result;
+}
+
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I>
+static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                                  const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+                                  cudaStream_t stream, std::index_sequence<I...>) {
     GGML_TENSOR_BINARY_OP_LOCALS
 
     int nr0 = ne10 / ne0;
@@ -136,7 +114,6 @@ struct bin_bcast_cuda {
     int nr[4] = { nr0, nr1, nr2, nr3 };
 
-    // collapse dimensions until first broadcast dimension
     int64_t cne[] = { ne0, ne1, ne2, ne3 };
     int64_t cne0[] = { ne00, ne01, ne02, ne03 };
     int64_t cne1[] = { ne10, ne11, ne12, ne13 };
@@ -248,33 +225,71 @@ struct bin_bcast_cuda {
             block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
             block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
 
-            dim3 block_nums(
-                (hne0 + block_dims.x - 1) / block_dims.x,
-                (ne1 + block_dims.y - 1) / block_dims.y,
-                (ne2*ne3 + block_dims.z - 1) / block_dims.z
-            );
+            dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x,
+                            (ne1 + block_dims.y - 1) / block_dims.y,
+                            (ne2 * ne3 + block_dims.z - 1) / block_dims.z);
+
             if (block_nums.z > 65535) {
-                // this is the maximum number of blocks in z dimension, fallback to 1D grid kernel
                 int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
-                k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
-                    src0_dd, src1_dd, dst_dd,
-                    ne0, ne1, ne2, ne3,
-                    ne10, ne11, ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00,*/ s01, s02, s03,
-                    /* s10, */ s11, s12, s13);
+                k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
+                    <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+                                                           ne0, ne1, ne2, ne3,
+                                                           ne10, ne11, ne12, ne13,
+                                                           /* s0, */ s1, s2, s3,
+                                                           /* s00,*/ s01, s02, s03,
+                                                           /* s10,*/ s11, s12, s13,
+                                                           (const src1_t *) dst->src[I + 1]->data...);
            } else {
-                k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
-                    src0_dd, src1_dd, dst_dd,
-                    ne0, ne1, ne2, ne3,
-                    ne10, ne11, ne12, ne13,
-                    /* s0, */ s1, s2, s3,
-                    /* s00,*/ s01, s02, s03,
-                    /* s10, */ s11, s12, s13);
+                k_bin_bcast<bin_op, src0_t, src1_t, dst_t>
+                    <<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd,
+                                                            ne0, ne1, ne2, ne3,
+                                                            ne10, ne11, ne12, ne13,
+                                                            /* s0, */ s1, s2, s3,
+                                                            /* s00,*/ s01, s02, s03,
+                                                            /* s10,*/ s11, s12, s13,
+                                                            (const src1_t *) dst->src[I + 1]->data...);
            }
        }
    }
+
+template <typename T>
+static __global__ void k_repeat_back(
+    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
+
+    const int64_t tid0  = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
+    const int64_t tid1  = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
+    const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
+    const int64_t tid2  = tid23 % ne2;
+    const int64_t tid3  = tid23 / ne2;
+
+    if (tid0 >= ne0) {
+        return;
+    }
+
+    T sum = 0;
+    for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
+        for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+            for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                    sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
+                }
+            }
+        }
+    }
+    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+}
+
+template <float (*bin_op)(const float, const float), int n_fuse = 1>
+struct bin_bcast_cuda {
+    template<typename src0_t, typename src1_t, typename dst_t>
+    void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
+                    const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
+                    cudaStream_t stream) {
+        launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(
+            src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{});
+    }
 };
 
 template <typename T>
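launch_bin_bcast_pack receives std::index_sequence<I...> so that `dst->src[I + 1]->data...` expands into one pointer argument per fused operand at compile time. The host-only sketch below (assumed names, not ggml code) shows the same expansion in isolation:

// index_sequence_sketch.cpp (hypothetical example, compile with -std=c++14 or later)
#include <cstdio>
#include <utility>

template <typename... Ptrs>
void kernel_stub(const float * src0, Ptrs... src1s) {
    // stands in for k_bin_bcast<bin_op, ...><<<...>>>(..., src1s...)
    printf("src0=%p with %zu fused src1 pointers\n", (void *) src0, sizeof...(src1s));
}

template <size_t... I>
void launch_pack(const float * const * srcs, std::index_sequence<I...>) {
    kernel_stub(srcs[0], srcs[I + 1]...);   // expands to srcs[1], srcs[2], ...
}

int main() {
    float a, b, c, d;
    const float * srcs[] = { &a, &b, &c, &d };
    launch_pack(srcs, std::make_index_sequence<3>{});  // fuse three extra operands
    return 0;
}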
@@ -331,6 +346,68 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
 
+template <float (*op)(const float, const float), int n_fuse>
+static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    cudaStream_t stream = ctx.stream();
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
+            (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
+            stream, std::make_index_sequence<n_fuse>{});
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
+            (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
+            stream, std::make_index_sequence<n_fuse>{});
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
+        launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
+            (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
+            stream, std::make_index_sequence<n_fuse>{});
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
+            (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
+            stream, std::make_index_sequence<n_fuse>{});
+    } else {
+        fprintf(stderr,
+                "%s: unsupported types for fusion: dst: %s, src0: %s, src1: %s\n",
+                __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ABORT("fatal error");
+    }
+}
+
+void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
+    GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
+
+    switch (n_fuse) {
+        case 2:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 2>(ctx, dst);
+            break;
+        case 3:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 3>(ctx, dst);
+            break;
+        case 4:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 4>(ctx, dst);
+            break;
+        case 5:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 5>(ctx, dst);
+            break;
+        case 6:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 6>(ctx, dst);
+            break;
+        case 7:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 7>(ctx, dst);
+            break;
+        case 8:
+            ggml_cuda_op_fused_binbcast_impl<op_add, 8>(ctx, dst);
+            break;
+        default:
+            GGML_ASSERT(false && "Unsupported n_fuse value");
+    }
+}
+
 void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];

File: ggml/src/ggml-cuda/binbcast.cuh

@@ -7,3 +7,5 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);

File: ggml/src/ggml-cuda/ggml-cuda.cu

@@ -2821,9 +2821,14 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         return false;
     }
 
-    if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
         const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
         const ggml_tensor *mul      = cgraph->nodes[node_idx+1];
+        const ggml_tensor *add      = nullptr;
+
+        if (ops.size() == 3 && ops.begin()[2] == GGML_OP_ADD) {
+            add = cgraph->nodes[node_idx+2];
+        }
 
         GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
         GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
@@ -2835,6 +2840,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
             return false;
         }
 
+        if (add && (add->src[0]->type != GGML_TYPE_F32 ||
+                    add->src[1]->type != GGML_TYPE_F32 ||
+                    add->type != GGML_TYPE_F32) ) {
+            return false;
+        }
+
         //if rms norm is the B operand, then we don't handle broadcast
         if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
             return false;
@@ -2845,6 +2856,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
             return false;
         }
 
+        if (add && (!ggml_is_contiguous(add->src[0]) || !ggml_is_contiguous_rows(add->src[1]))) {
+            return false;
+        }
+
         return true;
     }
@@ -2891,6 +2906,45 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
         static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
         if (!disable_fusion) {
+            if (node->op == GGML_OP_ADD) {
+                int n_fuse = 0;
+                ggml_op ops[8];
+                std::fill(ops, ops + 8, GGML_OP_ADD);
+
+                for (; n_fuse <= 6; ++n_fuse){
+                    if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
+                        break;
+                    }
+                    if (cgraph->nodes[i + n_fuse] != cgraph->nodes[i + n_fuse + 1]->src[0]) {
+                        break;
+                    }
+                    if (!ggml_are_same_layout(cgraph->nodes[i + n_fuse]->src[1], cgraph->nodes[i + n_fuse + 1]->src[1])) {
+                        break;
+                    }
+                }
+
+                n_fuse++;
+
+                if (n_fuse > 1) {
+                    for (int j = 0; j < n_fuse - 1; ++j) {
+                        node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
+                    }
+                    cgraph->nodes[i + n_fuse - 1]->data = node->data;
+                    ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
+                    i += n_fuse - 1;
+                    continue;
+                }
+            }
+
+            if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD }, {})) {
+                ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
+                i += 2;
+                continue;
+            }
+
             if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
                 ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
                 i++;

File: ggml/src/ggml-cuda/norm.cu

@@ -104,12 +104,29 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
     }
 }
 
-template <int block_size, bool do_multiply = false>
-static __global__ void rms_norm_f32(
-        const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
-        const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0,
-        const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0,
-        const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) {
+template <int block_size, bool do_multiply = false, bool do_add = false>
+static __global__ void rms_norm_f32(const float * x, float * dst,
+                                    const int ncols,
+                                    const int64_t stride_row,
+                                    const int64_t stride_channel,
+                                    const int64_t stride_sample,
+                                    const float eps,
+                                    const float * mul = nullptr,
+                                    const int64_t mul_stride_row = 0,
+                                    const int64_t mul_stride_channel = 0,
+                                    const int64_t mul_stride_sample = 0,
+                                    const int mul_ncols = 0,
+                                    const int mul_nrows = 0,
+                                    const int mul_nchannels = 0,
+                                    const int mul_nsamples = 0,
+                                    const float * add = nullptr,
+                                    const int64_t add_stride_row = 0,
+                                    const int64_t add_stride_channel = 0,
+                                    const int64_t add_stride_sample = 0,
+                                    const int add_ncols = 0,
+                                    const int add_nrows = 0,
+                                    const int add_nchannels = 0,
+                                    const int add_nsamples = 0) {
     const int nrows     = gridDim.x;
     const int nchannels = gridDim.y;
@@ -128,6 +145,13 @@ static __global__ void rms_norm_f32(
         mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
     }
 
+    if constexpr (do_add) {
+        const int add_row     = row     % add_nrows;
+        const int add_channel = channel % add_nchannels;
+        const int add_sample  = sample  % add_nsamples;
+        add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
+    }
+
     float tmp = 0.0f; // partial sum for thread in warp
 
     for (int col = tid; col < ncols; col += block_size) {
@@ -154,9 +178,16 @@ static __global__ void rms_norm_f32(
     const float scale = rsqrtf(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        if constexpr (do_multiply) {
+        if constexpr (do_multiply && do_add) {
+            const int mul_col = col % mul_ncols;
+            const int add_col = col % add_ncols;
+            dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
+        } else if constexpr (do_multiply) {
             const int mul_col = col % mul_ncols;
             dst[col] = scale * x[col] * mul[mul_col];
+        } else if constexpr (do_add) {
+            const int add_col = col % add_ncols;
+            dst[col] += add[add_col];
         } else {
             dst[col] = scale * x[col];
         }
@@ -331,23 +362,70 @@
     }
 }
 
-static void rms_norm_mul_f32_cuda(
-        const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
-        const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample,
-        const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample,
-        const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples,
-        const float eps, cudaStream_t stream) {
+static void rms_norm_mul_f32_cuda(const float * x,
+                                  const float * mul,
+                                  const float * add,
+                                  float * dst,
+                                  const int ncols,
+                                  const int nrows,
+                                  const int nchannels,
+                                  const int nsamples,
+                                  const int64_t stride_row,
+                                  const int64_t stride_channel,
+                                  const int64_t stride_sample,
+                                  const int64_t mul_stride_row,
+                                  const int64_t mul_stride_channel,
+                                  const int64_t mul_stride_sample,
+                                  const int mul_ncols,
+                                  const int mul_nrows,
+                                  const int mul_nchannels,
+                                  const int mul_nsamples,
+                                  const int64_t add_stride_row,
+                                  const int64_t add_stride_channel,
+                                  const int64_t add_stride_sample,
+                                  const int add_ncols,
+                                  const int add_nrows,
+                                  const int add_nchannels,
+                                  const int add_nsamples,
+                                  const float eps,
+                                  cudaStream_t stream) {
     const dim3 blocks_num(nrows, nchannels, nsamples);
     if (mul == nullptr) {
         rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
         return;
     }
+
+    if (add == nullptr) {
         if (ncols < 1024) {
             const dim3 block_dims(WARP_SIZE, 1, 1);
-            rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+            rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+                ncols, stride_row, stride_channel, stride_sample, eps,
+                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
         } else {
             const dim3 block_dims(1024, 1, 1);
-            rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+            rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+                ncols, stride_row, stride_channel, stride_sample, eps,
+                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
+        }
+    } else {
+        if (ncols < 1024) {
+            const dim3 block_dims(WARP_SIZE, 1, 1);
+            rms_norm_f32<WARP_SIZE, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+                ncols, stride_row, stride_channel, stride_sample, eps,
+                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+                add, add_stride_row, add_stride_channel, add_stride_sample,
+                add_ncols, add_nrows, add_nchannels, add_nsamples);
+        } else {
+            const dim3 block_dims(1024, 1, 1);
+            rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
+                ncols, stride_row, stride_channel, stride_sample, eps,
+                mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
+                mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+                add, add_stride_row, add_stride_channel, add_stride_sample,
+                add_ncols, add_nrows, add_nchannels, add_nsamples);
+        }
     }
 }
@@ -491,7 +569,102 @@ void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor *
     const int mul_nchannels = mul_src->ne[2];
     const int mul_nsamples  = mul_src->ne[3];
 
-    rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream);
+    rms_norm_mul_f32_cuda(src0_d, mul_d, nullptr, dst_d,
+                          ne00, ne01, ne02, ne03,
+                          /*s00*/ s01, s02, s03,
+                          /*mul_s00*/ mul_s01, mul_s02, mul_s03,
+                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+                          /*add_s00*/ 0, 0, 0,
+                          0, 0, 0, 0,
+                          eps, stream);
+}
+
+void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
+                                     ggml_tensor * dst,
+                                     ggml_tensor * mul_tensor,
+                                     ggml_tensor * add_tensor) {
+    const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0];
+    float eps = 0.0f;
+
+    memcpy(&eps, dst->op_params, sizeof(float));
+
+    const float * src0_d = (const float *) rms_norm_src->data;
+    const float * mul_d  = nullptr;
+    const ggml_tensor * mul_src = nullptr;
+
+    if (mul_tensor->src[0] == dst) {
+        mul_d   = (float *) mul_tensor->src[1]->data;
+        mul_src = mul_tensor->src[1];
+    } else if (mul_tensor->src[1] == dst) {
+        mul_d   = (float *) mul_tensor->src[0]->data;
+        mul_src = mul_tensor->src[0];
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    const float * add_d = nullptr;
+    const ggml_tensor * add_src = nullptr;
+
+    if (add_tensor->src[0] == mul_tensor) {
+        add_d   = (float *) add_tensor->src[1]->data;
+        add_src = add_tensor->src[1];
+    } else if (add_tensor->src[1] == mul_tensor) {
+        add_d   = (float *) add_tensor->src[0]->data;
+        add_src = add_tensor->src[0];
+    } else {
+        GGML_ASSERT(false);
+    }
+
+    float * dst_d = (float *) add_tensor->data;
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32);
+    GGML_ASSERT(add_tensor->type == GGML_TYPE_F32);
+    GGML_ASSERT(eps >= 0.0f);
+
+    const int64_t ne00 = rms_norm_src->ne[0];
+    const int64_t ne01 = rms_norm_src->ne[1];
+    const int64_t ne02 = rms_norm_src->ne[2];
+    const int64_t ne03 = rms_norm_src->ne[3];
+
+    const size_t ts0 = ggml_type_size(rms_norm_src->type);
+    GGML_ASSERT(rms_norm_src->nb[0] == ts0);
+    const int64_t s01 = rms_norm_src->nb[1] / ts0;
+    const int64_t s02 = rms_norm_src->nb[2] / ts0;
+    const int64_t s03 = rms_norm_src->nb[3] / ts0;
+
+    const size_t ts_mul = ggml_type_size(mul_src->type);
+    GGML_ASSERT(mul_src->nb[0] == ts_mul);
+    const int64_t mul_s01 = mul_src->nb[1] / ts_mul;
+    const int64_t mul_s02 = mul_src->nb[2] / ts_mul;
+    const int64_t mul_s03 = mul_src->nb[3] / ts_mul;
+
+    const int mul_ncols     = mul_src->ne[0];
+    const int mul_nrows     = mul_src->ne[1];
+    const int mul_nchannels = mul_src->ne[2];
+    const int mul_nsamples  = mul_src->ne[3];
+
+    const size_t ts_add = ggml_type_size(add_src->type);
+    GGML_ASSERT(add_src->nb[0] == ts_add);
+    const int64_t add_s01 = add_src->nb[1] / ts_add;
+    const int64_t add_s02 = add_src->nb[2] / ts_add;
+    const int64_t add_s03 = add_src->nb[3] / ts_add;
+
+    const int add_ncols     = add_src->ne[0];
+    const int add_nrows     = add_src->ne[1];
+    const int add_nchannels = add_src->ne[2];
+    const int add_nsamples  = add_src->ne[3];
+
+    rms_norm_mul_f32_cuda(src0_d, mul_d, add_d, dst_d,
+                          ne00, ne01, ne02, ne03,
+                          /*s00*/ s01, s02, s03,
+                          /*mul_s00*/ mul_s01, mul_s02, mul_s03,
+                          mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
+                          /*add_s00*/ add_s01, add_s02, add_s03,
+                          add_ncols, add_nrows, add_nchannels, add_nsamples,
+                          eps, stream);
 }
 
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

File: ggml/src/ggml-cuda/norm.cuh

@@ -8,6 +8,11 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor);
 
+void ggml_cuda_op_rms_norm_fused_add(ggml_backend_cuda_context & ctx,
+                                     ggml_tensor * dst,
+                                     ggml_tensor * mul_tensor,
+                                     ggml_tensor * add_tensor);
+
 void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);