mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	vulkan: optimize coopmat2 dequant functions (#10855)
Change the code to do 16b loads when possible and extract the appropriate component late, so the code is effectively decoding a pair of elements and then selecting one. This can allow more commoning to happen in the compiler when neighboring elements are loaded.
This commit is contained in:
		| @@ -10,9 +10,10 @@ float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2 | ||||
|     const float16_t d = bl.block.d; | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint shift = (idx & 0x10) >> 2; | ||||
|     uint32_t qs = unpack8(uint32_t(bl.block.qs[(idx & 0xE) >> 1]))[idx & 1]; | ||||
|     uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]); | ||||
|     qs >>= shift; | ||||
|     qs &= 0xF; | ||||
|     qs &= 0x0F0F; | ||||
|     qs = unpack8(qs)[idx & 1]; | ||||
|     float16_t ret = (float16_t(qs) - float16_t(8)) * d; | ||||
|     return ret; | ||||
| } | ||||
| @@ -152,15 +153,17 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4 | ||||
|    block_q4_K block; | ||||
| }; | ||||
|  | ||||
| layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 { | ||||
|    block_q4_K_packed16 block; | ||||
| }; | ||||
|  | ||||
| float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
|  | ||||
|     const uint n = iqs / 64;                   // 0,1,2,3 | ||||
|     const uint b = (iqs % 64) / 32;            // 0,1 | ||||
|     const uint b = (idx & 0x20) >> 5;            // 0,1 | ||||
|     const uint is = (idx & 0xE0) >> 5;         // 0..7 | ||||
|     const uint qsi = n * 32 + (iqs % 32);      // 0..127 | ||||
|  | ||||
|     const f16vec2 loadd = bl.block.d; | ||||
|  | ||||
| @@ -184,9 +187,11 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2 | ||||
|     const float16_t d = loadd.x * float16_t(sc); | ||||
|     const float16_t m = loadd.y * float16_t(mbyte); | ||||
|  | ||||
|     uint32_t dmask = 0xF << (b * 4); | ||||
|     uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); | ||||
|     qs = (qs >> (b * 4)) & 0x0F0F; | ||||
|     qs = unpack8(qs)[idx & 1]; | ||||
|  | ||||
|     float16_t ret = d * float16_t((bl.block.qs[qsi    ] & dmask) >> (b * 4)) - m; | ||||
|     float16_t ret = d * float16_t(qs) - m; | ||||
|  | ||||
|     return ret; | ||||
| } | ||||
| @@ -195,18 +200,19 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5 | ||||
|    block_q5_K block; | ||||
| }; | ||||
|  | ||||
| layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 { | ||||
|    block_q5_K_packed16 block; | ||||
| }; | ||||
|  | ||||
| float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
|  | ||||
|     const uint n = iqs / 64;                   // 0,1,2,3 | ||||
|     const uint b = (iqs % 64) / 32;            // 0,1 | ||||
|     const uint b = (idx & 0x20) >> 5;          // 0,1 | ||||
|     const uint is = (idx & 0xE0) >> 5;         // 0..7 | ||||
|     const uint qsi = n * 32 + (iqs % 32);      // 0..127 | ||||
|     const uint qhi = (iqs % 32);               // 0..31 | ||||
|  | ||||
|     const uint8_t hm = uint8_t(1 << (iqs / 32)); | ||||
|     const uint32_t hm = 0x0101 << is; | ||||
|  | ||||
|     const f16vec2 loadd = bl.block.d; | ||||
|  | ||||
| @@ -230,9 +236,15 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2 | ||||
|     const float16_t d = loadd.x * float16_t(sc); | ||||
|     const float16_t m = loadd.y * float16_t(mbyte); | ||||
|  | ||||
|     uint32_t dmask = 0xF << (b * 4); | ||||
|     uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); | ||||
|     qh = qh & hm; | ||||
|     qh = unpack8(qh)[idx & 1]; | ||||
|  | ||||
|     float16_t ret = d * (float16_t((bl.block.qs[qsi    ] & dmask) >> (b * 4)) + float16_t((bl.block.qh[qhi    ] & hm) != 0 ? 16 : 0)) - m; | ||||
|     uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); | ||||
|     qs = (qs >> (b * 4)) & 0x0F0F; | ||||
|     qs = unpack8(qs)[idx & 1]; | ||||
|  | ||||
|     float16_t ret = d * (float16_t(qs) + (qh != 0 ? float16_t(16) : float16_t(0))) - m; | ||||
|  | ||||
|     return ret; | ||||
| } | ||||
| @@ -241,22 +253,30 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_ | ||||
|    block_q6_K block; | ||||
| }; | ||||
|  | ||||
| layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 { | ||||
|    block_q6_K_packed16 block; | ||||
| }; | ||||
|  | ||||
| float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) | ||||
| { | ||||
|     decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl); | ||||
|     const uint idx = coordInBlock[1]; | ||||
|     const uint iqs = idx; | ||||
|  | ||||
|     const uint n = iqs / 128;                   // 0,1 | ||||
|     const uint b = (iqs % 128) / 64;            // 0,1 | ||||
|     const uint is_b = (iqs % 32) / 16;          // 0,1 | ||||
|     const uint qhshift = ((iqs % 128) / 32) * 2;// 0,2,4,6 | ||||
|     const uint is = 8 * n + qhshift + is_b;     // 0..15 | ||||
|     const uint qsi = n * 64 + (iqs % 64);       // 0..127 | ||||
|     const uint qhi = n * 32 + (iqs % 32);       // 0..63 | ||||
|     const uint b = (idx & 0x40) >> 6;           // 0,1 | ||||
|     const uint qhshift = (idx & 0x60) >> 4;    // 0,2,4,6 | ||||
|     const uint is = (idx & 0xF0) >> 4;          // 0..15 | ||||
|  | ||||
|     const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); | ||||
|  | ||||
|     float16_t ret = dscale * float16_t(int8_t(((bl.block.ql[qsi    ] >> (b * 4)) & 0xF) | (((bl.block.qh[qhi    ] >> qhshift) & 3) << 4)) - 32); | ||||
|     uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]); | ||||
|     ql = (ql >> (b * 4)) & 0x0F0F; | ||||
|  | ||||
|     uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); | ||||
|     qh = ((qh >> qhshift) & 0x0303) << 4; | ||||
|  | ||||
|     int q = unpack8(ql | qh)[idx & 1]; | ||||
|  | ||||
|     float16_t ret = dscale * float16_t(q - 32); | ||||
|  | ||||
|     return ret; | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jeff Bolz
					Jeff Bolz