vulkan : support ggml_mean (#15393)

* vulkan : support ggml_mean

* vulkan : support sum, sum_rows and mean with non-contiguous tensors

* vulkan : fix subbuffer size not accounting for misalign offset

* tests : add backend-op tests for non-contiguous sum_rows

* cuda : require contiguous src for SUM_ROWS, MEAN support
* sycl : require contiguous src for SUM, SUM_ROWS, ARGSORT support

* require ggml_contiguous_rows in supports_op and expect nb00=1 in the shader
This commit is contained in:
Acly
2025-08-23 08:35:21 +02:00
committed by GitHub
parent 330c3d2d21
commit 0a9b43e507
5 changed files with 135 additions and 18 deletions

View File

@@ -1,9 +1,9 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
@@ -11,16 +11,49 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (push_constant) uniform parameter
{
uint n_cols;
uint ne01, ne02;
uint nb01, nb02, nb03;
uint nb11, nb12, nb13;
float weight;
uint misalign_offsets;
uint ne0_12mp, ne0_12L;
uint ne0_1mp, ne0_1L;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
// see init_fastdiv_values in ggml-vulkan.cpp
uint fastdiv(uint n, uint mp, uint L) {
uint msbs, lsbs;
// msbs = mulhi(n, mp)
umulExtended(n, mp, msbs, lsbs);
return (msbs + n) >> L;
}
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint col = gl_LocalInvocationID.x;
const float weight = p.weight;
tmp[col] = FLOAT_TYPE(0.0f);
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
const uint i03_offset = i03 * p.ne01*p.ne02;
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
const uint i01 = row - i03_offset - i02*p.ne01;
for (uint i = col; i < p.KX; i += BLOCK_SIZE) {
tmp[col] += FLOAT_TYPE(data_a[row*p.KX + i]);
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
tmp[col] = FLOAT_TYPE(0.0);
for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) {
tmp[col] += FLOAT_TYPE(data_a[src_idx + i]);
}
barrier();
@@ -32,6 +65,6 @@ void main() {
}
if (col == 0) {
data_d[row] = D_TYPE(tmp[0]);
data_d[dst_idx] = D_TYPE(tmp[0] * weight);
}
}