mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-12 10:47:01 +00:00
vulkan: optimize rms_norm, and allow the work to spread across multiple SMs (#15281)
* vulkan: optimize rms_norm, and allow the work to spread across multiple SMs There are really two parts to this change: (1) Some optimizations similar to what we have in soft_max, to unroll with different numbers of iterations. (2) A fusion optimization where we detect add followed by rms_norm, and make the add shader atomically accumulate the values^2 into memory. Then the rms_norm shader can just load that sum. This allows the rms_norm to be parallelized across multiple workgroups, it just becomes a simple per-element multiply. The fusion optimization is currently only applied when the rms_norm is on a single vector. This previously always ran on a single SM. It could apply more broadly, but when there are other dimensions the work can already spread across SMs, and there would be some complexity to tracking multiple atomic sums. * Change add+rms_norm optimization to write out an array of partial sums rather than using atomic add, to make it deterministic. The rms_norm shader fetches a subgroup's worth in parallel and uses subgroupAdd to add them up. * complete rebase against fused adds - multi_add shader can also compute partial sums * fix validation errors * disable add_rms_fusion for Intel due to possible driver bug * resolve against #15489, sync after clearing partial sums
This commit is contained in:
@@ -1,20 +1,34 @@
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#if ADD_RMS
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#endif
|
||||
|
||||
#include "types.comp"
|
||||
#include "generic_binary_head.comp"
|
||||
|
||||
const uint num_threads = 256;
|
||||
|
||||
layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];};
|
||||
|
||||
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
#if ADD_RMS
|
||||
// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant
|
||||
shared FLOAT_TYPE sumsh[num_threads];
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
uint idx = get_idx();
|
||||
uint orig_idx = idx;
|
||||
|
||||
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
|
||||
const uint num_iter = 2;
|
||||
|
||||
FLOAT_TYPE sum_sq = 0;
|
||||
|
||||
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
||||
if (idx >= p.ne) {
|
||||
continue;
|
||||
@@ -22,8 +36,34 @@ void main() {
|
||||
uint i00, i01, i02, i03;
|
||||
get_indices(idx, i00, i01, i02, i03);
|
||||
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]);
|
||||
sum_sq += sum*sum;
|
||||
|
||||
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
|
||||
|
||||
idx += num_threads;
|
||||
}
|
||||
|
||||
#if ADD_RMS
|
||||
if (p.param3 != 0) {
|
||||
// reduce the sum within each subgroup, then across subgroups
|
||||
const uint NumSubgroups = num_threads / gl_SubgroupSize;
|
||||
sum_sq = subgroupAdd(sum_sq);
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
|
||||
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
|
||||
sum_sq += sumsh[gl_SubgroupID + s];
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
|
||||
partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -3,6 +3,10 @@
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_nonuniform_qualifier : enable
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#if ADD_RMS
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
#endif
|
||||
|
||||
#include "rte.comp"
|
||||
#include "types.comp"
|
||||
@@ -14,12 +18,16 @@ layout (push_constant) uniform parameter2
|
||||
uint ne20; uint ne21; uint ne22; uint ne23;
|
||||
|
||||
// strides for srcs+dst
|
||||
uint nb[8][4];
|
||||
uint nb[12][4];
|
||||
|
||||
uint rms_partials;
|
||||
} p;
|
||||
|
||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
|
||||
layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
|
||||
|
||||
layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
|
||||
|
||||
layout(constant_id = 0) const uint num_srcs = 2;
|
||||
|
||||
uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
|
||||
@@ -42,14 +50,22 @@ const uint num_threads = 256;
|
||||
|
||||
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
#if ADD_RMS
|
||||
// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant
|
||||
shared FLOAT_TYPE sumsh[num_threads];
|
||||
#endif
|
||||
|
||||
void main() {
|
||||
uint idx = get_idx();
|
||||
uint orig_idx = idx;
|
||||
|
||||
uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;
|
||||
|
||||
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
|
||||
const uint num_iter = 2;
|
||||
|
||||
FLOAT_TYPE sum_sq = 0;
|
||||
|
||||
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
||||
if (idx >= ne) {
|
||||
continue;
|
||||
@@ -61,8 +77,32 @@ void main() {
|
||||
[[unroll]] for (uint s = 0; s < num_srcs; ++s) {
|
||||
sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
|
||||
}
|
||||
sum_sq += sum*sum;
|
||||
d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
|
||||
|
||||
idx += num_threads;
|
||||
}
|
||||
|
||||
#if ADD_RMS
|
||||
if (p.rms_partials != 0) {
|
||||
// reduce the sum within each subgroup, then across subgroups
|
||||
const uint NumSubgroups = num_threads / gl_SubgroupSize;
|
||||
sum_sq = subgroupAdd(sum_sq);
|
||||
if (gl_SubgroupInvocationID == 0) {
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
[[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) {
|
||||
if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) {
|
||||
sum_sq += sumsh[gl_SubgroupID + s];
|
||||
sumsh[gl_SubgroupID] = sum_sq;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
|
||||
if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
|
||||
partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -10,9 +10,9 @@ layout (constant_id = 1) const bool do_multiply = false;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
shared FLOAT_TYPE sum[BLOCK_SIZE];
|
||||
shared FLOAT_TYPE sumsh[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
void rms_norm(uint num_iters) {
|
||||
const uint ncols = p.ne00;
|
||||
const uint nrows = gl_NumWorkGroups.x;
|
||||
const uint nchannels = gl_NumWorkGroups.y;
|
||||
@@ -30,38 +30,76 @@ void main() {
|
||||
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
|
||||
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
|
||||
|
||||
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_offset + col]);
|
||||
sum[tid] += xi * xi;
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
FLOAT_TYPE xi = FLOAT_TYPE(0);
|
||||
if (col < ncols) {
|
||||
xi = FLOAT_TYPE(data_a[a_offset + col]);
|
||||
}
|
||||
sum += xi * xi;
|
||||
}
|
||||
|
||||
sumsh[tid] = sum;
|
||||
// sum up partial sums and write back result
|
||||
barrier();
|
||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||
if (tid < s) {
|
||||
sum[tid] += sum[tid + s];
|
||||
sum += sumsh[tid + s];
|
||||
sumsh[tid] = sum;
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
sum = sumsh[0];
|
||||
|
||||
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
|
||||
const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
|
||||
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
|
||||
|
||||
if (do_multiply) {
|
||||
if (ncols > p.ne10) {
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
if (col >= ncols) {
|
||||
continue;
|
||||
}
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
if (col >= ncols) {
|
||||
continue;
|
||||
}
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||
[[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) {
|
||||
if (col >= ncols) {
|
||||
continue;
|
||||
}
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void main() {
|
||||
// instantiate the rms_norm function for several different
|
||||
// dimensions, to allow loop unrolling
|
||||
uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||
if (num_blocks > 32) {
|
||||
rms_norm(num_blocks);
|
||||
} else if (num_blocks > 16) {
|
||||
rms_norm(32);
|
||||
} else if (num_blocks > 8) {
|
||||
rms_norm(16);
|
||||
} else if (num_blocks > 4) {
|
||||
rms_norm(8);
|
||||
} else if (num_blocks == 4) {
|
||||
rms_norm(4);
|
||||
} else if (num_blocks == 3) {
|
||||
rms_norm(3);
|
||||
} else if (num_blocks == 2) {
|
||||
rms_norm(2);
|
||||
} else if (num_blocks == 1) {
|
||||
rms_norm(1);
|
||||
}
|
||||
}
|
||||
|
||||
65
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp
Normal file
65
ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp
Normal file
@@ -0,0 +1,65 @@
|
||||
#version 450
|
||||
|
||||
#include "generic_binary_head.comp"
|
||||
#include "types.comp"
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_KHR_shader_subgroup_arithmetic : enable
|
||||
#extension GL_KHR_shader_subgroup_basic : enable
|
||||
|
||||
#define BLOCK_SIZE 128
|
||||
|
||||
layout (constant_id = 1) const bool do_multiply = false;
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];};
|
||||
|
||||
shared FLOAT_TYPE sumsh[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const uint ncols = p.ne00;
|
||||
const uint nrows = gl_NumWorkGroups.x;
|
||||
const uint nchannels = gl_NumWorkGroups.y;
|
||||
|
||||
const uint row = 0;
|
||||
const uint channel = gl_WorkGroupID.y;
|
||||
const uint samp = gl_WorkGroupID.z;
|
||||
// The work is split across multiple workgroups in the x dimension. Each invocation
|
||||
// processes one element
|
||||
const uint tid = gl_GlobalInvocationID.x;
|
||||
|
||||
const uint stride_row = p.nb01;
|
||||
const uint stride_channel = p.nb02;
|
||||
const uint stride_sample = p.nb03;
|
||||
|
||||
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
|
||||
uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
|
||||
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
|
||||
|
||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp
|
||||
|
||||
uint32_t num_partials = p.param3;
|
||||
for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) {
|
||||
sum += partial_sums[i];
|
||||
}
|
||||
sum = subgroupAdd(sum);
|
||||
|
||||
uint col = tid;
|
||||
if (col >= ncols) {
|
||||
return;
|
||||
}
|
||||
|
||||
const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols);
|
||||
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
|
||||
|
||||
if (do_multiply) {
|
||||
if (ncols > p.ne10) {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
|
||||
} else {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
|
||||
}
|
||||
} else {
|
||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
||||
}
|
||||
}
|
||||
@@ -503,6 +503,7 @@ void process_shaders() {
|
||||
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
@@ -538,13 +539,15 @@ void process_shaders() {
|
||||
s += std::string(dst_f16 ? "_f16" : "_f32");
|
||||
return s;
|
||||
};
|
||||
for (std::string op : {"add", "sub", "mul", "div"}) {
|
||||
for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) {
|
||||
for (auto src0_f16 : {false, true}) {
|
||||
for (auto src1_f16 : {false, true}) {
|
||||
for (auto dst_f16 : {false, true}) {
|
||||
for (auto rte : {false, true}) {
|
||||
auto source = op == "add_rms" ? std::string("add") : op;
|
||||
auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
|
||||
string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
|
||||
auto add_rms = op == "add_rms" ? "1" : "0";
|
||||
string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}, {"ADD_RMS" , add_rms}});
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -687,7 +690,8 @@ void process_shaders() {
|
||||
|
||||
string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
||||
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
|
||||
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}});
|
||||
string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}});
|
||||
|
||||
for (auto &c : compiles) {
|
||||
c.wait();
|
||||
@@ -745,7 +749,7 @@ void write_output_files() {
|
||||
}
|
||||
|
||||
std::string suffixes[2] = {"_f32", "_f16"};
|
||||
for (const char *op : {"add", "sub", "mul", "div"}) {
|
||||
for (const char *op : {"add", "sub", "mul", "div", "add_rms"}) {
|
||||
fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
|
||||
fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
|
||||
std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
|
||||
|
||||
Reference in New Issue
Block a user