vulkan: fuse adds (#15252)

* vulkan: fuse adds

Fuse adds that have the same shape, which are common in MoE models.
It will currently fuse up to 6 adds, because we assume no more than
8 descriptors per dispatch. But this could be changed.

* check runtimeDescriptorArray feature

* disable multi_add for Intel due to likely driver bug
This commit is contained in:
Jeff Bolz
2025-08-16 11:48:22 -05:00
committed by GitHub
parent de2192794f
commit 1fe00296f5
6 changed files with 301 additions and 25 deletions

View File

@@ -2,6 +2,7 @@
#extension GL_EXT_control_flow_attributes : require
#include "rte.comp"
#include "utils.comp"
layout (push_constant) uniform parameter
{
@@ -28,25 +29,9 @@ uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }
uint get_doffset() { return p.misalign_offsets & 0xFF; }
// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
uint fastmod(uint a, uint b) {
if ((b & (b-1)) == 0) {
return a & (b-1);
}
return a % b;
}
uint fastdiv(uint a, uint b) {
return (a < b) ? 0 : (a / b);
}
void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) {
i03 = fastdiv(idx, (p.ne02*p.ne01*p.ne00));
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
i02 = fastdiv((idx - i03_offset), (p.ne01*p.ne00));
const uint i02_offset = i02*p.ne01*p.ne00;
i01 = (idx - i03_offset - i02_offset) / p.ne00;
i00 = idx - i03_offset - i02_offset - i01*p.ne00;
get_indices(idx, i00, i01, i02, i03, p.ne00, p.ne01, p.ne02, p.ne03);
}
uint src0_idx(uint i00, uint i01, uint i02, uint i03) {

View File

@@ -0,0 +1,68 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_nonuniform_qualifier : enable
#extension GL_EXT_control_flow_attributes : require
#include "rte.comp"
#include "types.comp"
#include "utils.comp"
layout (push_constant) uniform parameter2
{
// shape for dst
uint ne20; uint ne21; uint ne22; uint ne23;
// strides for srcs+dst
uint nb[8][4];
} p;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
layout(constant_id = 0) const uint num_srcs = 2;
uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0];
}
uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
uint nb20 = p.nb[num_srcs][0];
uint nb21 = p.nb[num_srcs][1];
uint nb22 = p.nb[num_srcs][2];
uint nb23 = p.nb[num_srcs][3];
return i03*nb23 + i02*nb22 + i01*nb21 + i00*nb20;
}
uint get_idx() {
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
}
const uint num_threads = 256;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() {
uint idx = get_idx();
uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03, p.ne20, p.ne21, p.ne22, p.ne23);
FLOAT_TYPE sum = FLOAT_TYPE(0);
[[unroll]] for (uint s = 0; s < num_srcs; ++s) {
sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
}
d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
idx += num_threads;
}
}

View File

@@ -0,0 +1,25 @@
#ifndef UTILS_COMP
#define UTILS_COMP
// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
uint fastmod(uint a, uint b) {
if ((b & (b-1)) == 0) {
return a & (b-1);
}
return a % b;
}
uint fastdiv(uint a, uint b) {
return (a < b) ? 0 : (a / b);
}
void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03, uint ne00, uint ne01, uint ne02, uint ne03) {
i03 = fastdiv(idx, (ne02*ne01*ne00));
const uint i03_offset = i03 * ne02*ne01*ne00;
i02 = fastdiv((idx - i03_offset), (ne01*ne00));
const uint i02_offset = i02*ne01*ne00;
i01 = (idx - i03_offset - i02_offset) / ne00;
i00 = idx - i03_offset - i02_offset - i01*ne00;
}
#endif // UTILS_COMP

View File

@@ -677,6 +677,8 @@ void process_shaders() {
string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
for (auto &c : compiles) {
c.wait();
}