mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	metal : optimize dequant q6_K kernel (#11892)
This commit is contained in:
		@@ -373,24 +373,33 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 | 
			
		||||
template <typename type4x4>
 | 
			
		||||
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
 | 
			
		||||
    const half d_all = xb->d;
 | 
			
		||||
    device const uint8_t * ql = (device const uint8_t *)xb->ql;
 | 
			
		||||
    device const uint8_t * qh = (device const uint8_t *)xb->qh;
 | 
			
		||||
    device const uint16_t * ql = (device const uint16_t *)xb->ql;
 | 
			
		||||
    device const uint16_t * qh = (device const uint16_t *)xb->qh;
 | 
			
		||||
    device const int8_t * scales = (device const int8_t *)xb->scales;
 | 
			
		||||
 | 
			
		||||
    ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
 | 
			
		||||
    qh = qh + 32*(il/8) + 16*(il&1);
 | 
			
		||||
    ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1);
 | 
			
		||||
    qh = qh + 16*(il/8) + 8*(il&1);
 | 
			
		||||
    float sc = scales[(il%2) + 2 * ((il/2))];
 | 
			
		||||
    il = (il/2) & 3;
 | 
			
		||||
 | 
			
		||||
    const uint16_t  kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
 | 
			
		||||
    const uint16_t  kmask2 = il>1 ? 0xF0              : 0x0F;
 | 
			
		||||
    const float       coef = il>1 ? 1.f/16.f          : 1.f;
 | 
			
		||||
    const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303);
 | 
			
		||||
    const uint32_t kmask2 = il>1 ? 0xF0F0F0F0                       : 0x0F0F0F0F;
 | 
			
		||||
    const float ml = d_all * sc * 32.f;
 | 
			
		||||
    const float dl = d_all * sc * coef;
 | 
			
		||||
    for (int i = 0; i < 16; ++i) {
 | 
			
		||||
        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
 | 
			
		||||
                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
 | 
			
		||||
        reg[i/4][i%4] = dl * q - ml;
 | 
			
		||||
    const float dl0 = d_all * sc;
 | 
			
		||||
    const float dl1 = dl0 / 256.f;
 | 
			
		||||
    const float dl2 = dl0 / (256.f * 256.f);
 | 
			
		||||
    const float dl3 = dl0 / (256.f * 256.f * 256.f);
 | 
			
		||||
    const uint8_t shr_h = il>2 ? 2 : 0;
 | 
			
		||||
    const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4);
 | 
			
		||||
    const uint8_t shr_l = il>1 ? 4 : 0;
 | 
			
		||||
    for (int i = 0; i < 4; ++i) {
 | 
			
		||||
        const uint32_t  low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2;
 | 
			
		||||
        const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1;
 | 
			
		||||
        const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l);
 | 
			
		||||
        reg[i][0] = dl0 *  ((half)(q & 0xFF))       - ml;
 | 
			
		||||
        reg[i][1] = dl1 * ((float)(q & 0xFF00))     - ml;
 | 
			
		||||
        reg[i][2] = dl2 * ((float)(q & 0xFF0000))   - ml;
 | 
			
		||||
        reg[i][3] = dl3 * ((float)(q & 0xFF000000)) - ml;
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user