mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	vulkan: Optimize some mat-vec mul quant shaders (#10296)
Compute two result elements per workgroup (for Q{4,5}_{0,1}). This reuses
the B loads across the rows and also reuses some addressing calculations.
This required manually partially unrolling the loop, since the compiler
is less willing to unroll outer loops.
Add bounds-checking on the last iteration of the loop. I think this was at
least partly broken before.
Optimize the Q4_K shader to vectorize most loads and reduce the number of
bit twiddling instructions.
			
			
This commit is contained in:
		@@ -3,54 +3,107 @@
 | 
			
		||||
#ifdef FLOAT16
 | 
			
		||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 | 
			
		||||
#endif
 | 
			
		||||
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
 | 
			
		||||
 | 
			
		||||
#extension GL_EXT_null_initializer : enable
 | 
			
		||||
 | 
			
		||||
#include "mul_mat_vec_base.comp"
 | 
			
		||||
 | 
			
		||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 | 
			
		||||
 | 
			
		||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
 | 
			
		||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
 | 
			
		||||
 | 
			
		||||
shared FLOAT_TYPE tmp[BLOCK_SIZE];
 | 
			
		||||
uint a_offset, b_offset, d_offset, y_offset;
 | 
			
		||||
 | 
			
		||||
void main() {
 | 
			
		||||
    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
 | 
			
		||||
    const uint tid = gl_LocalInvocationID.x;
 | 
			
		||||
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
 | 
			
		||||
 | 
			
		||||
    // There are not enough cols to use all threads
 | 
			
		||||
    if (tid >= p.ncols) {
 | 
			
		||||
        return;
 | 
			
		||||
void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
 | 
			
		||||
{
 | 
			
		||||
    const uint col = i*BLOCK_SIZE + 2*tid;
 | 
			
		||||
    const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
 | 
			
		||||
    const uint iybs = col - col%QUANT_K; // y block start index
 | 
			
		||||
 | 
			
		||||
    // Check if the second of the pair of elements is OOB, and don't fetch B or
 | 
			
		||||
    // accumulate it. We still fetch a pair of elements for A, which is fine for
 | 
			
		||||
    // quantized formats since they'll be within the same block. We should
 | 
			
		||||
    // probably skip fetching the second element for F16/F32, but as of now we
 | 
			
		||||
    // still do.
 | 
			
		||||
    const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
 | 
			
		||||
 | 
			
		||||
    FLOAT_TYPE b0 = 0, b1 = 0;
 | 
			
		||||
    b0 = FLOAT_TYPE(data_b[b_offset + iybs + iqs]);
 | 
			
		||||
    if (!OOB) {
 | 
			
		||||
        b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
 | 
			
		||||
    }
 | 
			
		||||
    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
 | 
			
		||||
        const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index
 | 
			
		||||
 | 
			
		||||
    const uint block_size = min(p.ncols, BLOCK_SIZE);
 | 
			
		||||
 | 
			
		||||
    uint a_offset, b_offset, d_offset;
 | 
			
		||||
    get_offsets(a_offset, b_offset, d_offset);
 | 
			
		||||
 | 
			
		||||
    const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
 | 
			
		||||
 | 
			
		||||
    tmp[tid] = FLOAT_TYPE(0.0f);
 | 
			
		||||
 | 
			
		||||
    [[unroll]] for (uint i = 0; i < p.ncols/block_size; i += 2) {
 | 
			
		||||
        const uint col = i*block_size + 2*tid;
 | 
			
		||||
        const uint ib = (row*p.ncols + col)/QUANT_K; // block index
 | 
			
		||||
        const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
 | 
			
		||||
        const uint iybs = col - col%QUANT_K; // y block start index
 | 
			
		||||
 | 
			
		||||
        vec2 v = dequantize(ib, iqs, a_offset / QUANT_K);
 | 
			
		||||
        const vec2 v = dequantize(ib, iqs, a_offset);
 | 
			
		||||
 | 
			
		||||
        // matrix multiplication
 | 
			
		||||
        tmp[tid] = fma(FLOAT_TYPE(v.x), FLOAT_TYPE(data_b[b_offset + iybs + iqs]), fma(FLOAT_TYPE(v.y), FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]), tmp[tid]));
 | 
			
		||||
        temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
 | 
			
		||||
        if (!OOB) {
 | 
			
		||||
            temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 | 
			
		||||
    const uint tid = gl_LocalInvocationID.x;
 | 
			
		||||
 | 
			
		||||
    get_offsets(a_offset, b_offset, d_offset);
 | 
			
		||||
    a_offset /= QUANT_K;
 | 
			
		||||
 | 
			
		||||
    y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
 | 
			
		||||
 | 
			
		||||
    FLOAT_TYPE temp[NUM_ROWS] = {};
 | 
			
		||||
 | 
			
		||||
    const int unroll_count = 8;
 | 
			
		||||
 | 
			
		||||
    const uint num_iters = (p.ncols >= 2*tid) ? ((p.ncols - 2*tid + BLOCK_SIZE - 1) / BLOCK_SIZE) : 0;
 | 
			
		||||
    const uint unrolled_iters = num_iters & ~(2*unroll_count - 1);
 | 
			
		||||
 | 
			
		||||
    uint i = 0;
 | 
			
		||||
    while (i < unrolled_iters) {
 | 
			
		||||
        // Manually partially unroll the loop
 | 
			
		||||
        [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
 | 
			
		||||
            iter(temp, first_row, num_rows, tid, i, false);
 | 
			
		||||
            i += 2;
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
    while (i < num_iters) {
 | 
			
		||||
        iter(temp, first_row, num_rows, tid, i, true);
 | 
			
		||||
        i += 2;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // sum up partial sums and write back result
 | 
			
		||||
    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
 | 
			
		||||
        tmpsh[n][tid] = temp[n];
 | 
			
		||||
    }
 | 
			
		||||
    barrier();
 | 
			
		||||
    [[unroll]] for (uint s = block_size/2; s > 0; s >>= 1) {
 | 
			
		||||
    [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
 | 
			
		||||
        if (tid < s) {
 | 
			
		||||
            tmp[tid] += tmp[tid + s];
 | 
			
		||||
            [[unroll]] for (uint n = 0; n < num_rows; ++n) {
 | 
			
		||||
                tmpsh[n][tid] += tmpsh[n][tid + s];
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        barrier();
 | 
			
		||||
    }
 | 
			
		||||
    if (tid == 0) {
 | 
			
		||||
        data_d[d_offset + row] = D_TYPE(tmp[0]);
 | 
			
		||||
        [[unroll]] for (uint n = 0; n < num_rows; ++n) {
 | 
			
		||||
            data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void main() {
 | 
			
		||||
    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
 | 
			
		||||
 | 
			
		||||
    // do NUM_ROWS at a time, unless there aren't enough remaining rows
 | 
			
		||||
    if (first_row + NUM_ROWS <= p.stride_d) {
 | 
			
		||||
        compute_outputs(first_row, NUM_ROWS);
 | 
			
		||||
    } else {
 | 
			
		||||
        compute_outputs(first_row, p.stride_d - first_row);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user