mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	mat vec double buffer (#12188)
This commit is contained in:
		@@ -5,23 +5,24 @@
 | 
			
		||||
 | 
			
		||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 | 
			
		||||
 | 
			
		||||
shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16];
 | 
			
		||||
shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16];
 | 
			
		||||
shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16];
 | 
			
		||||
shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16];
 | 
			
		||||
 | 
			
		||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 | 
			
		||||
uint csel = 0;
 | 
			
		||||
 | 
			
		||||
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
 | 
			
		||||
    const uint y_idx = i * QUANT_K + y_offset;
 | 
			
		||||
 | 
			
		||||
    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
 | 
			
		||||
        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
 | 
			
		||||
        csel ^= 1;
 | 
			
		||||
 | 
			
		||||
        barrier();
 | 
			
		||||
        if (!all_threads) { // when we don't have enough blocks to use all threads
 | 
			
		||||
            if (i < num_blocks_per_row) {
 | 
			
		||||
                const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
 | 
			
		||||
                sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
 | 
			
		||||
                sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
 | 
			
		||||
                sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
 | 
			
		||||
                sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
 | 
			
		||||
            }
 | 
			
		||||
            barrier();
 | 
			
		||||
 | 
			
		||||
@@ -29,8 +30,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
 | 
			
		||||
                continue;
 | 
			
		||||
        } else {
 | 
			
		||||
            const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
 | 
			
		||||
            sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
 | 
			
		||||
            sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
 | 
			
		||||
            sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
 | 
			
		||||
            sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
 | 
			
		||||
            barrier();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@@ -57,22 +58,22 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
 | 
			
		||||
            FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
 | 
			
		||||
            FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
 | 
			
		||||
            [[unroll]] for (int l = 0; l < 2; ++l) {
 | 
			
		||||
                sum1 = fma(FLOAT_TYPE(b0[l]),   sccache1[ix][    8*v_im] * qs_u32_0[l  ],
 | 
			
		||||
                       fma(FLOAT_TYPE(b16[l]),  sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2],
 | 
			
		||||
                       fma(FLOAT_TYPE(b32[l]),  sccache1[ix][2 + 8*v_im] * qs_u32_2[l  ],
 | 
			
		||||
                       fma(FLOAT_TYPE(b48[l]),  sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2],
 | 
			
		||||
                       fma(FLOAT_TYPE(b64[l]),  sccache1[ix][4 + 8*v_im] * qs_u32_4[l  ],
 | 
			
		||||
                       fma(FLOAT_TYPE(b80[l]),  sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2],
 | 
			
		||||
                       fma(FLOAT_TYPE(b96[l]),  sccache1[ix][6 + 8*v_im] * qs_u32_6[l  ],
 | 
			
		||||
                       fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
 | 
			
		||||
                sum2 = fma(FLOAT_TYPE(b0[l]),   sccache2[ix][    8*v_im],
 | 
			
		||||
                       fma(FLOAT_TYPE(b16[l]),  sccache2[ix][1 + 8*v_im],
 | 
			
		||||
                       fma(FLOAT_TYPE(b32[l]),  sccache2[ix][2 + 8*v_im],
 | 
			
		||||
                       fma(FLOAT_TYPE(b48[l]),  sccache2[ix][3 + 8*v_im],
 | 
			
		||||
                       fma(FLOAT_TYPE(b64[l]),  sccache2[ix][4 + 8*v_im],
 | 
			
		||||
                       fma(FLOAT_TYPE(b80[l]),  sccache2[ix][5 + 8*v_im],
 | 
			
		||||
                       fma(FLOAT_TYPE(b96[l]),  sccache2[ix][6 + 8*v_im],
 | 
			
		||||
                       fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2))))))));
 | 
			
		||||
                sum1 = fma(FLOAT_TYPE(b0[l]),   sccache1[csel][ix][    8*v_im] * qs_u32_0[l  ],
 | 
			
		||||
                       fma(FLOAT_TYPE(b16[l]),  sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2],
 | 
			
		||||
                       fma(FLOAT_TYPE(b32[l]),  sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l  ],
 | 
			
		||||
                       fma(FLOAT_TYPE(b48[l]),  sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2],
 | 
			
		||||
                       fma(FLOAT_TYPE(b64[l]),  sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l  ],
 | 
			
		||||
                       fma(FLOAT_TYPE(b80[l]),  sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2],
 | 
			
		||||
                       fma(FLOAT_TYPE(b96[l]),  sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l  ],
 | 
			
		||||
                       fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
 | 
			
		||||
                sum2 = fma(FLOAT_TYPE(b0[l]),   sccache2[csel][ix][    8*v_im],
 | 
			
		||||
                       fma(FLOAT_TYPE(b16[l]),  sccache2[csel][ix][1 + 8*v_im],
 | 
			
		||||
                       fma(FLOAT_TYPE(b32[l]),  sccache2[csel][ix][2 + 8*v_im],
 | 
			
		||||
                       fma(FLOAT_TYPE(b48[l]),  sccache2[csel][ix][3 + 8*v_im],
 | 
			
		||||
                       fma(FLOAT_TYPE(b64[l]),  sccache2[csel][ix][4 + 8*v_im],
 | 
			
		||||
                       fma(FLOAT_TYPE(b80[l]),  sccache2[csel][ix][5 + 8*v_im],
 | 
			
		||||
                       fma(FLOAT_TYPE(b96[l]),  sccache2[csel][ix][6 + 8*v_im],
 | 
			
		||||
                       fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
 | 
			
		||||
            }
 | 
			
		||||
            temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
 | 
			
		||||
        }
 | 
			
		||||
 
 | 
			
		||||
@@ -5,20 +5,21 @@
 | 
			
		||||
 | 
			
		||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 | 
			
		||||
 | 
			
		||||
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8];
 | 
			
		||||
shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8];
 | 
			
		||||
 | 
			
		||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 | 
			
		||||
uint csel = 0;
 | 
			
		||||
 | 
			
		||||
void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
 | 
			
		||||
    const uint y_idx = i * QUANT_K + y_offset;
 | 
			
		||||
 | 
			
		||||
    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
 | 
			
		||||
        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
 | 
			
		||||
        csel ^= 1;
 | 
			
		||||
 | 
			
		||||
        if (!all_threads) { // when we don't have enough blocks to use all threads
 | 
			
		||||
            barrier();
 | 
			
		||||
            if (i < num_blocks_per_row)
 | 
			
		||||
                sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
 | 
			
		||||
                sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
 | 
			
		||||
            barrier();
 | 
			
		||||
 | 
			
		||||
            if (i >= num_blocks_per_row)
 | 
			
		||||
@@ -40,8 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
 | 
			
		||||
        const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
 | 
			
		||||
 | 
			
		||||
        if (all_threads) {
 | 
			
		||||
            barrier();
 | 
			
		||||
            sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
 | 
			
		||||
            sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
 | 
			
		||||
            barrier();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@@ -59,14 +59,14 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
 | 
			
		||||
 | 
			
		||||
            FLOAT_TYPE sum = FLOAT_TYPE(0.0);
 | 
			
		||||
            [[unroll]] for (int l = 0; l < 2; ++l) {
 | 
			
		||||
                sum = fma(FLOAT_TYPE(  b0[l]) * sccache[ix][v_im][0], qs_u32_0[l  ] - hmk_0[l  ],
 | 
			
		||||
                      fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
 | 
			
		||||
                      fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l  ] - hmk_1[l  ],
 | 
			
		||||
                      fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
 | 
			
		||||
                      fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l  ] - hmk_2[l  ],
 | 
			
		||||
                      fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
 | 
			
		||||
                      fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l  ] - hmk_3[l  ],
 | 
			
		||||
                      fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
 | 
			
		||||
                sum = fma(FLOAT_TYPE(  b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l  ] - hmk_0[l  ],
 | 
			
		||||
                      fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
 | 
			
		||||
                      fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l  ] - hmk_1[l  ],
 | 
			
		||||
                      fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
 | 
			
		||||
                      fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l  ] - hmk_2[l  ],
 | 
			
		||||
                      fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
 | 
			
		||||
                      fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l  ] - hmk_3[l  ],
 | 
			
		||||
                      fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
 | 
			
		||||
            }
 | 
			
		||||
            temp[j][n] = fma(d, sum, temp[j][n]);
 | 
			
		||||
        }
 | 
			
		||||
 
 | 
			
		||||
@@ -6,20 +6,21 @@
 | 
			
		||||
 | 
			
		||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 | 
			
		||||
 | 
			
		||||
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
 | 
			
		||||
shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16];
 | 
			
		||||
 | 
			
		||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
 | 
			
		||||
uint csel = 0;
 | 
			
		||||
 | 
			
		||||
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
 | 
			
		||||
    const uint y_idx = i * QUANT_K + y_offset;
 | 
			
		||||
 | 
			
		||||
    [[unroll]] for (uint n = 0; n < num_rows; ++n) {
 | 
			
		||||
        const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
 | 
			
		||||
        csel ^= 1;
 | 
			
		||||
 | 
			
		||||
        if (!all_threads) { // when we don't have enough blocks to use all threads
 | 
			
		||||
            barrier();
 | 
			
		||||
            if (i < num_blocks_per_row)
 | 
			
		||||
                sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
 | 
			
		||||
                sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
 | 
			
		||||
            barrier();
 | 
			
		||||
 | 
			
		||||
            if (i >= num_blocks_per_row)
 | 
			
		||||
@@ -51,8 +52,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
 | 
			
		||||
        const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
 | 
			
		||||
 | 
			
		||||
        if (all_threads) {
 | 
			
		||||
            barrier();
 | 
			
		||||
            sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
 | 
			
		||||
            sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
 | 
			
		||||
            barrier();
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@@ -71,7 +71,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
 | 
			
		||||
                sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
 | 
			
		||||
                sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
 | 
			
		||||
            }
 | 
			
		||||
            temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
 | 
			
		||||
            temp[j][n] = fma(fma(sum[0], sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user