#version 450

#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_nonuniform_qualifier : enable
#extension GL_EXT_control_flow_attributes : require

#include "rte.comp"
#include "types.comp"
#include "utils.comp"

// Element-wise sum of num_srcs source tensors into a single destination tensor.
//
// A_TYPE, D_TYPE and FLOAT_TYPE are not defined in this file; they are presumably
// supplied by the includes above or by the build - confirm there.

layout (push_constant) uniform parameter2
{
    // Shape of the destination tensor, one extent per dimension.
    uint ne20;
    uint ne21;
    uint ne22;
    uint ne23;

    // Strides for the sources and the destination: one row of four per tensor.
    // Row s belongs to source s, row num_srcs belongs to the destination.
    // The values are used directly as element offsets into data_a / data_d.
    uint nb[8][4];
} p;

// Two views declared on the same binding: sources are read through a[],
// the destination is written through d[].
layout (binding = 0) readonly buffer A { A_TYPE data_a[]; } a[];
layout (binding = 0) writeonly buffer D { D_TYPE data_d[]; } d[];

// Number of source tensors (specialization constant 0, default 2).
layout (constant_id = 0) const uint num_srcs = 2;

// Element offset of coordinate (i00, i01, i02, i03) inside source s,
// using stride row s of the push constants.
uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
    uint offset = i00 * p.nb[s][0];
    offset += i01 * p.nb[s][1];
    offset += i02 * p.nb[s][2];
    offset += i03 * p.nb[s][3];
    return offset;
}

// Element offset of coordinate (i00, i01, i02, i03) inside the destination,
// whose strides are stored in the row right after the last source.
uint dst_idx(uint i00, uint i01, uint i02, uint i03) {
    return i00 * p.nb[num_srcs][0]
         + i01 * p.nb[num_srcs][1]
         + i02 * p.nb[num_srcs][2]
         + i03 * p.nb[num_srcs][3];
}

// Flat element index of the first element handled by this invocation.
// NOTE(review): the 512 / 262144 spans have to agree with how the host sizes
// the dispatch; that code is not visible here - confirm against it.
uint get_idx() {
    const uint y_span = 512u;
    const uint z_span = 262144u; // 512 * 512

    const uvec3 gid = gl_GlobalInvocationID;
    return gid.z * z_span + gid.y * y_span + gid.x;
}

const uint num_threads = 256;

layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;

void main() {
    // Total number of destination elements.
    uint total = p.ne20 * p.ne21 * p.ne22 * p.ne23;

    // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
    const uint num_iter = 2;

    uint elem = get_idx();

    [[unroll]] for (uint iter = 0; iter < num_iter; ++iter) {
        // Out of range: nothing left for this invocation to do.
        if (elem >= total) {
            break;
        }

        // get_idx-style flat index -> 4D coordinate in the destination shape.
        // get_indices is not defined in this file (presumably utils.comp).
        uint c0, c1, c2, c3;
        get_indices(elem, c0, c1, c2, c3, p.ne20, p.ne21, p.ne22, p.ne23);

        // Accumulate in FLOAT_TYPE, convert to D_TYPE only when storing.
        FLOAT_TYPE acc = FLOAT_TYPE(0);
        [[unroll]] for (uint src = 0; src < num_srcs; ++src) {
            const uint src_off = src_idx(src, c0, c1, c2, c3);
            acc += FLOAT_TYPE(a[src].data_a[src_off]);
        }

        // The destination buffer sits right after the sources in the array.
        const uint dst_off = dst_idx(c0, c1, c2, c3);
        d[num_srcs].data_d[dst_off] = D_TYPE(acc);

        // Next element for this invocation is one workgroup width further on.
        elem += num_threads;
    }
}