vulkan: Handle src1 batch dimension in non-contiguous mat-vec-mul shader (#13191)

* vulkan: Handle src1 batch dimension in non-contiguous mat-vec-mul shader
This commit is contained in:
Jeff Bolz
2025-05-01 13:19:31 -05:00
committed by GitHub
parent b0ecbd434b
commit fc727bcdd5
2 changed files with 11 additions and 5 deletions

View File

@@ -21,7 +21,9 @@ layout (push_constant) uniform parameter
uint nrows_x;
uint row_stride_x;
uint channel_stride_x;
uint channel_stride_y;
uint channel_x_divisor;
uint ne12;
uint b_offset;
uint d_offset;
} p;
@@ -33,6 +35,7 @@ void main() {
const uint row_x = gl_GlobalInvocationID.y;
const uint channel = gl_GlobalInvocationID.z;
const uint channel_x = channel / p.channel_x_divisor;
const uint channel_y = channel % p.ne12;
const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x;
@@ -56,7 +59,7 @@ void main() {
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel*nrows_y + row_y;
const uint iy = channel_y*p.channel_stride_y + row_y;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@@ -72,7 +75,7 @@ void main() {
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel*nrows_y + row_y;
const uint iy = channel_y*p.channel_stride_y + row_y;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
@@ -89,7 +92,7 @@ void main() {
const uint row_y = col_x;
const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = channel*nrows_y + row_y;
const uint iy = channel_y*p.channel_stride_y + row_y;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);