mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-12 10:47:01 +00:00
* vulkan: fuse adds. Fuse adds that have the same shape, which are common in MoE models. It currently fuses up to 6 adds, because we assume no more than 8 descriptors per dispatch; this limit could be changed. * check runtimeDescriptorArray feature * disable multi_add for Intel due to a likely driver bug
69 lines
1.9 KiB
Plaintext
69 lines
1.9 KiB
Plaintext
#version 450
|
|
|
|
#extension GL_EXT_shader_16bit_storage : require
|
|
#extension GL_EXT_nonuniform_qualifier : enable
|
|
#extension GL_EXT_control_flow_attributes : require
|
|
|
|
#include "rte.comp"
|
|
#include "types.comp"
|
|
#include "utils.comp"
|
|
|
|
// Push-constant parameters for the fused multi-add.
layout (push_constant) uniform parameter2
{
    // Destination shape: element counts of dims 0..3.
    uint ne20, ne21, ne22, ne23;

    // Strides for dims 0..3, in array elements (they scale indices into
    // data_a[] / data_d[] directly). Rows 0..num_srcs-1 belong to the sources;
    // row num_srcs belongs to the destination. With 8 rows, num_srcs can be
    // at most 7.
    uint nb[8][4];
} p;
|
|
|
|
// Both buffer arrays alias the same descriptor-array binding (binding 0).
// main() reads entries 0..num_srcs-1 through a[] as sources and writes
// entry num_srcs through d[] as the destination.
layout (binding = 0) readonly  buffer A { A_TYPE data_a[]; } a[];
layout (binding = 0) writeonly buffer D { D_TYPE data_d[]; } d[];

// Specialization constant: number of source tensors summed per element
// (defaults to 2).
layout(constant_id = 0) const uint num_srcs = 2;
|
|
|
|
// Returns the linear element offset of coordinate (i00, i01, i02, i03) in
// tensor s, using the strides stored in row s of p.nb.
uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
    const uint off3 = i03 * p.nb[s][3];
    const uint off2 = i02 * p.nb[s][2];
    const uint off1 = i01 * p.nb[s][1];
    const uint off0 = i00 * p.nb[s][0];
    return off3 + off2 + off1 + off0;
}
|
|
|
|
// Returns the linear element offset of coordinate (i00, i01, i02, i03) in the
// destination tensor. The destination's strides are stored in p.nb at row
// num_srcs, right after the sources' rows. The offset is therefore the same
// stride-weighted sum that src_idx computes for that row.
uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
    return src_idx(num_srcs, i00, i01, i02, i03);
}
|
|
|
|
// Flattens the 3D global invocation ID into a linear starting element index.
// Each y step spans 512 elements, and each z step spans 512 * 512 = 262144.
// NOTE(review): these factors must match how the host sizes the dispatch
// (see the num_iter comment in main) — confirm against the host code.
uint get_idx() {
    const uvec3 gid = gl_GlobalInvocationID;
    return (gid.z * 512u + gid.y) * 512u + gid.x;
}
|
|
|
|
// Invocations per workgroup along x. main() advances its element index by
// this amount between iterations. Workgroup y/z sizes default to 1.
const uint num_threads = 256;
layout(local_size_x = num_threads) in;
|
|
|
|
// Fused element-wise add of num_srcs source tensors into the destination.
// Each invocation starts at get_idx() and handles up to num_iter elements,
// spaced num_threads apart. Elements at or beyond the destination element
// count are skipped.
void main() {
    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
    const uint num_iter = 2;

    const uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23;
    uint elem = get_idx();

    [[unroll]] for (uint it = 0; it < num_iter; ++it) {
        // The index only advances after an element has been processed, so once
        // it reaches ne every remaining iteration is skipped.
        if (elem < ne) {
            uint i00, i01, i02, i03;
            get_indices(elem, i00, i01, i02, i03, p.ne20, p.ne21, p.ne22, p.ne23);

            FLOAT_TYPE acc = FLOAT_TYPE(0);
            [[unroll]] for (uint k = 0; k < num_srcs; ++k) {
                acc += FLOAT_TYPE(a[k].data_a[src_idx(k, i00, i01, i02, i03)]);
            }
            // The destination buffer sits at descriptor index num_srcs.
            d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(acc);

            elem += num_threads;
        }
    }
}
|