vulkan: optimize mul_mat for small values of N (#10991)

Make the mul_mat_vec shaders support N>1 (as a spec constant, NUM_COLS) where
the batch_strides are overloaded to hold the row strides. Put the loads from the
B matrix in the innermost loop because it should cache better.

Share some code for reducing the result values to memory in mul_mat_vec_base.
This commit is contained in:
Jeff Bolz
2024-12-30 11:27:11 -06:00
committed by GitHub
parent c250ecb315
commit 716bd6dec3
9 changed files with 288 additions and 349 deletions

View File

@@ -83,3 +83,36 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
batch_idx * p.batch_stride_d;
#endif
}
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
layout (constant_id = 2) const uint NUM_COLS = 1;
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
// sum up partial sums and write back result
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[j][n][tid] = temp[j][n];
}
}
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[j][n][tid] += tmpsh[j][n][tid + s];
}
}
}
barrier();
}
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
}
}
}
}