Files
llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp
Ruben Ortlam 02c1813517 Vulkan: Add Integer Dot Product mul_mat_vec shader for legacy quants (#14903)
* vulkan: Add Integer Dot Product mul_mat_vec shader for legacy quants

* vulkan: use subgroup operations for quantize_q8_1 shader

* vulkan: add q8_1_x4 type with 128-bit alignment, use in mul_mat_vecq shader

* vulkan: use q8_1_x4 blocks in mul_mmq shader

* vulkan: do 8 calculations per invocation instead of 32 in mul_mat_vecq, similar to mul_mat_vec

* vulkan: tune mul_mat_vecq performance for Intel

* vulkan: fix quantizing issue when tensor is not divisible by 128

* vulkan: adapt integer dot mmv to mmv small m optimization (#15355)

* vulkan: allow all subgroup modes for mmv and mmvq

* vulkan: use prealloc intermediate reuse for mmvq path

* vulkan: tune mmvq for Intel, AMD GCN and Nvidia RTX 3090

* vulkan: adapt mmv quantize_y path to conditional sync logic

* vulkan: disable q8_0 mmvq on Nvidia

* vulkan: enable q8_0 on Nvidia pre-turing

* fix prealloc sync condition

* fix llvmpipe subgroup 8 issue
2025-09-01 16:19:07 +02:00

141 lines
4.4 KiB
Plaintext

#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_integer_dot_product : require
#define MMQ
#define B_TYPE block_q8_1_x4
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#define K_PER_ITER 8
#include "mul_mmq_funcs.comp"
// Offsets for the A, B and D operands. compute_outputs() fills them through get_offsets()
// and then divides a_offset by QUANT_K and b_offset by QUANT_K_Q8_1.
uint a_offset, b_offset, d_offset;
// Two int32_t words of B quants loaded by iter() and fed to dotPacked4x8EXT.
int32_t cache_b_qs[2];
// The `ds` entry of the B block currently cached in cache_b_qs; written by iter().
vec2 cache_b_ds;
// Adds one step of the integer dot product to the accumulators in temp.
//
//   temp      - partial sums indexed [column][row], updated in place
//   first_row - first row of the A operand to process
//   num_rows  - number of consecutive rows to process
//   tid       - invocation index (compute_outputs passes gl_LocalInvocationID.x)
//   i         - iteration offset; compute_outputs passes its loop counter times K_PER_ITER
//
// For every column j the B quants and their `ds` entry are loaded once into
// cache_b_qs / cache_b_ds and then reused for all rows.
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
    // Neither value depends on the column index, so both are computed once up front.
    const uint col     = i*BLOCK_SIZE + tid*K_PER_ITER;
    const uint qs_slot = tid % 4;

    [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
        // Preload data_b block. data_b is indexed in groups of four blocks:
        // the outer index selects the group, the inner index the block inside it.
        const uint b_blk     = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
        const uint b_outer   = b_blk / 4;
        const uint b_inner   = b_blk % 4;
        const uint b_qs_base = b_inner * 8;

        cache_b_ds = vec2(data_b[b_outer].ds[b_inner]);
#if QUANT_R == 2
        cache_b_qs[0] = data_b[b_outer].qs[b_qs_base + qs_slot];
        cache_b_qs[1] = data_b[b_outer].qs[b_qs_base + qs_slot + 4];
#else
        cache_b_qs[0] = data_b[b_outer].qs[b_qs_base + qs_slot * 2];
        cache_b_qs[1] = data_b[b_outer].qs[b_qs_base + qs_slot * 2 + 1];
#endif

        // Element offset of the current row of A; advanced by one row per pass.
        uint row_start = first_row*p.ncols;
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
            const uint a_blk = (row_start + col)/QUANT_K + a_offset;
            row_start += p.ncols;

#if QUANT_R == 2
            const i32vec2 a_qs = repack(a_blk, qs_slot);
            const int32_t q_sum = dotPacked4x8EXT(a_qs.x, cache_b_qs[0])
                                + dotPacked4x8EXT(a_qs.y, cache_b_qs[1]);
#else
            const int32_t a_qs_0 = repack(a_blk, qs_slot * 2);
            const int32_t a_qs_1 = repack(a_blk, qs_slot * 2 + 1);
            const int32_t q_sum = dotPacked4x8EXT(a_qs_0, cache_b_qs[0])
                                + dotPacked4x8EXT(a_qs_1, cache_b_qs[1]);
#endif

            // Scale the integer sum with the A-side factor(s) and the cached B-side `ds`.
#if QUANT_AUXF == 1
            temp[j][n] += mul_q8_1(q_sum, get_d(a_blk), cache_b_ds, 4);
#else
            temp[j][n] += mul_q8_1(q_sum, get_dm(a_blk), cache_b_ds, 4);
#endif
        }
    }
}
// Accumulates the dot products for num_rows rows starting at first_row and
// hands the per-invocation partial sums to reduce_result().
//
//   first_row - first row handled by this call
//   num_rows  - number of rows to process (main passes at most NUM_ROWS)
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
    const uint tid = gl_LocalInvocationID.x;

    // Fetch the operand offsets, then rescale the A and B offsets by their quant block sizes.
    get_offsets(a_offset, b_offset, d_offset);
    a_offset /= QUANT_K;
    b_offset /= QUANT_K_Q8_1;

    // Zero the accumulators that iter() adds into.
    FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
    [[unroll]] for (uint c = 0; c < NUM_COLS; ++c) {
        [[unroll]] for (uint r = 0; r < num_rows; ++r) {
            temp[c][r] = FLOAT_TYPE(0.0f);
        }
    }

    // Full passes over the columns, plus one extra pass for invocations whose
    // slice of the leftover columns is still inside p.ncols.
    uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
    if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
        num_iters += 1;
    }

    uint done = 0;

    // Stage 1: manually partially unrolled, four passes per trip.
    int unroll_count = 4;
    uint unrolled_iters = num_iters & ~(unroll_count - 1);
    while (done < unrolled_iters) {
        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
            iter(temp, first_row, num_rows, tid, done*K_PER_ITER);
            ++done;
        }
    }

    // Stage 2: same scheme, two passes per trip.
    unroll_count = 2;
    unrolled_iters = num_iters & ~(unroll_count - 1);

    // Only compiled when K_PER_ITER is 2: for an odd p.ncols the last pair of
    // passes is left to the single-step loop below.
#if K_PER_ITER == 2
    if ((p.ncols & 1) != 0 &&
        unrolled_iters == num_iters &&
        unrolled_iters > 0) {
        unrolled_iters -= unroll_count;
    }
#endif

    while (done < unrolled_iters) {
        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
            iter(temp, first_row, num_rows, tid, done*K_PER_ITER);
            ++done;
        }
    }

    // Stage 3: whatever passes remain, one at a time.
    while (done < num_iters) {
        iter(temp, first_row, num_rows, tid, done*K_PER_ITER);
        ++done;
    }

    reduce_result(temp, d_offset, first_row, num_rows, tid);
}
// Entry point: derives this workgroup's first row from its x/z workgroup IDs and
// processes up to NUM_ROWS rows, using p.stride_d as the upper row bound.
void main() {
    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

    // A complete group of NUM_ROWS rows fits below the bound: take the fixed-size path.
    if (first_row + NUM_ROWS <= p.stride_d) {
        compute_outputs(first_row, NUM_ROWS);
        return;
    }

    // Otherwise handle only the rows that remain, if there are any.
    if (first_row < p.stride_d) {
        compute_outputs(first_row, p.stride_d - first_row);
    }
}