mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	metal : Q3_K speedup (#2995)
* Slightly faster Q3_K and Q5_K on metal * Another Q3_K speedup on metal Combined with previous commit, we are now +9.6% for TG. PP is not affected as this happens via the matrix multiplication templates. * Slowly progressing on Q3_K on metal We are now 13% faster than master * nother small improvement for Q3_K on metal --------- Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
This commit is contained in:
		
							
								
								
									
										135
									
								
								ggml-metal.metal
									
									
									
									
									
								
							
							
						
						
									
										135
									
								
								ggml-metal.metal
									
									
									
									
									
								
							@@ -1123,31 +1123,40 @@ kernel void kernel_mul_mat_q3_K_f32(
 | 
			
		||||
    device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
 | 
			
		||||
    device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 | 
			
		||||
 | 
			
		||||
    float yl[16];
 | 
			
		||||
    float yl[32];
 | 
			
		||||
 | 
			
		||||
    const uint16_t kmask1 = 0x0303;
 | 
			
		||||
    const uint16_t kmask1 = 0x3030;
 | 
			
		||||
    const uint16_t kmask2 = 0x0f0f;
 | 
			
		||||
 | 
			
		||||
    const int tid = tiisg/2;
 | 
			
		||||
    const int ix  = tiisg%2;
 | 
			
		||||
    const int ip  = tid/8;          // 0 or 1
 | 
			
		||||
    const int il  = tid/2 - 4*ip;   // 0...3
 | 
			
		||||
    const int tid = tiisg/4;
 | 
			
		||||
    const int ix  = tiisg%4;
 | 
			
		||||
    const int ip  = tid/4;          // 0 or 1
 | 
			
		||||
    const int il  = 2*((tid%4)/2);  // 0 or 2
 | 
			
		||||
    const int ir  = tid%2;
 | 
			
		||||
    const int n   = 8;
 | 
			
		||||
    const int l0  = n*ir;
 | 
			
		||||
 | 
			
		||||
    const uint16_t m1 = 1 << (4*ip + il);
 | 
			
		||||
    const uint16_t m2 = m1 << 8;
 | 
			
		||||
    // One would think that the Metal compiler would figure out that ip and il can only have
 | 
			
		||||
    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
 | 
			
		||||
    // with these two tales.
 | 
			
		||||
    //
 | 
			
		||||
    // Possible masks for the high bit
 | 
			
		||||
    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
 | 
			
		||||
                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
 | 
			
		||||
                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
 | 
			
		||||
                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
 | 
			
		||||
 | 
			
		||||
    // Possible masks for the low 2 bits
 | 
			
		||||
    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
 | 
			
		||||
 | 
			
		||||
    const ushort4 hm = mm[2*ip + il/2];
 | 
			
		||||
 | 
			
		||||
    const int shift = 2*il;
 | 
			
		||||
    const uint16_t qm1 = 0x0003 << shift;
 | 
			
		||||
    const uint16_t qm2 = 0x0300 << shift;
 | 
			
		||||
    const int32_t v1 = 4 << shift;
 | 
			
		||||
    const int32_t v2 = 1024 << shift;
 | 
			
		||||
    const float    v1 = il == 0 ? 4.f : 64.f;
 | 
			
		||||
    const float    v2 = 4.f * v1;
 | 
			
		||||
 | 
			
		||||
    const uint16_t s_shift1 = 4*ip;
 | 
			
		||||
    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
 | 
			
		||||
    const int ik = 4 + (il%2);
 | 
			
		||||
    const uint16_t s_shift2 = s_shift1 + il;
 | 
			
		||||
 | 
			
		||||
    const int q_offset = 32*ip + l0;
 | 
			
		||||
    const int y_offset = 128*ip + 32*il + l0;
 | 
			
		||||
@@ -1156,12 +1165,19 @@ kernel void kernel_mul_mat_q3_K_f32(
 | 
			
		||||
 | 
			
		||||
    device const float * y1 = yy + ix*QK_K + y_offset;
 | 
			
		||||
 | 
			
		||||
    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
 | 
			
		||||
    for (int i = ix; i < nb; i += 2) {
 | 
			
		||||
    uint32_t scales32, aux32;
 | 
			
		||||
    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
 | 
			
		||||
    thread const int8_t * scales = (thread const int8_t *)&scales32;
 | 
			
		||||
 | 
			
		||||
    float sumf1[2] = {0.f};
 | 
			
		||||
    float sumf2[2] = {0.f};
 | 
			
		||||
    for (int i = ix; i < nb; i += 4) {
 | 
			
		||||
 | 
			
		||||
        for (int l = 0; l < 8; ++l) {
 | 
			
		||||
            yl[l+0] = y1[l+ 0];
 | 
			
		||||
            yl[l+8] = y1[l+16];
 | 
			
		||||
            yl[l+ 0] = y1[l+ 0];
 | 
			
		||||
            yl[l+ 8] = y1[l+16];
 | 
			
		||||
            yl[l+16] = y1[l+32];
 | 
			
		||||
            yl[l+24] = y1[l+48];
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
 | 
			
		||||
@@ -1172,27 +1188,43 @@ kernel void kernel_mul_mat_q3_K_f32(
 | 
			
		||||
        for (int row = 0; row < 2; ++row) {
 | 
			
		||||
 | 
			
		||||
            const float d_all = (float)dh[0];
 | 
			
		||||
            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 | 
			
		||||
 | 
			
		||||
            float s1 = 0, s2 = 0;
 | 
			
		||||
            for (int l = 0; l < n; l += 2) {
 | 
			
		||||
                const uint16_t qs = q[l/2];
 | 
			
		||||
                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
 | 
			
		||||
                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
 | 
			
		||||
            }
 | 
			
		||||
            float d = d_all * (s1 + 1.f/256.f * s2);
 | 
			
		||||
            sumf1[row] += d * scales[0];
 | 
			
		||||
            sumf2[row] += d;
 | 
			
		||||
            scales16[0] = a[4];
 | 
			
		||||
            scales16[1] = a[5];
 | 
			
		||||
            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
 | 
			
		||||
            scales16[0] = a[il+0];
 | 
			
		||||
            scales16[1] = a[il+1];
 | 
			
		||||
            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
 | 
			
		||||
 | 
			
		||||
            s1 = s2 = 0;
 | 
			
		||||
            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
 | 
			
		||||
            for (int l = 0; l < n; l += 2) {
 | 
			
		||||
                const uint16_t qs = q[l/2+8];
 | 
			
		||||
                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
 | 
			
		||||
                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
 | 
			
		||||
                const int32_t qs = q[l/2];
 | 
			
		||||
                s1 += yl[l+0] * (qs & qm[il/2][0]);
 | 
			
		||||
                s2 += yl[l+1] * (qs & qm[il/2][1]);
 | 
			
		||||
                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
 | 
			
		||||
                s4 += yl[l+16] * (qs & qm[il/2][2]);
 | 
			
		||||
                s5 += yl[l+17] * (qs & qm[il/2][3]);
 | 
			
		||||
                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
 | 
			
		||||
            }
 | 
			
		||||
            d = d_all * (s1 + 1.f/256.f * s2);
 | 
			
		||||
            sumf1[row] += d * scales[1];
 | 
			
		||||
            sumf2[row] += d;
 | 
			
		||||
            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
 | 
			
		||||
            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
 | 
			
		||||
            sumf1[row] += d1 * (scales[0] - 32);
 | 
			
		||||
            sumf2[row] += d2 * (scales[2] - 32);
 | 
			
		||||
 | 
			
		||||
            s1 = s2 = s3 = s4 = s5 = s6 = 0;
 | 
			
		||||
            for (int l = 0; l < n; l += 2) {
 | 
			
		||||
                const int32_t qs = q[l/2+8];
 | 
			
		||||
                s1 += yl[l+8] * (qs & qm[il/2][0]);
 | 
			
		||||
                s2 += yl[l+9] * (qs & qm[il/2][1]);
 | 
			
		||||
                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
 | 
			
		||||
                s4 += yl[l+24] * (qs & qm[il/2][2]);
 | 
			
		||||
                s5 += yl[l+25] * (qs & qm[il/2][3]);
 | 
			
		||||
                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
 | 
			
		||||
            }
 | 
			
		||||
            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
 | 
			
		||||
            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
 | 
			
		||||
            sumf1[row] += d1 * (scales[1] - 32);
 | 
			
		||||
            sumf2[row] += d2 * (scales[3] - 32);
 | 
			
		||||
 | 
			
		||||
            q  += step;
 | 
			
		||||
            h  += step;
 | 
			
		||||
@@ -1201,17 +1233,20 @@ kernel void kernel_mul_mat_q3_K_f32(
 | 
			
		||||
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        y1 += 2 * QK_K;
 | 
			
		||||
        y1 += 4 * QK_K;
 | 
			
		||||
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    for (int row = 0; row < 2; ++row) {
 | 
			
		||||
        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
 | 
			
		||||
        const float tot = simd_sum(sumf);
 | 
			
		||||
        if (tiisg == 0) {
 | 
			
		||||
            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
 | 
			
		||||
        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
 | 
			
		||||
        sumf1[row] = simd_sum(sumf);
 | 
			
		||||
    }
 | 
			
		||||
    if (tiisg == 0) {
 | 
			
		||||
        for (int row = 0; row < 2; ++row) {
 | 
			
		||||
            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
#else
 | 
			
		||||
kernel void kernel_mul_mat_q3_K_f32(
 | 
			
		||||
@@ -1564,17 +1599,25 @@ kernel void kernel_mul_mat_q5_K_f32(
 | 
			
		||||
            sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
 | 
			
		||||
            sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 | 
			
		||||
 | 
			
		||||
            float4 acc = {0.f, 0.f, 0.f, 0.f};
 | 
			
		||||
            float4 acc1 = {0.f};
 | 
			
		||||
            float4 acc2 = {0.f};
 | 
			
		||||
            for (int l = 0; l < n; ++l) {
 | 
			
		||||
                uint8_t h = qh[l];
 | 
			
		||||
                acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
 | 
			
		||||
                acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
 | 
			
		||||
                acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
 | 
			
		||||
                acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
 | 
			
		||||
                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
 | 
			
		||||
                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
 | 
			
		||||
                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
 | 
			
		||||
                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
 | 
			
		||||
                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
 | 
			
		||||
                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
 | 
			
		||||
                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
 | 
			
		||||
                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
 | 
			
		||||
            }
 | 
			
		||||
            const float dall = dh[0];
 | 
			
		||||
            const float dmin = dh[1];
 | 
			
		||||
            sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
 | 
			
		||||
            sumf[row] += dall * (sc8[0] * (acc1[0] +  16.f*acc2[0]) +
 | 
			
		||||
                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
 | 
			
		||||
                                 sc8[4] * (acc1[2] +  16.f*acc2[2]) +
 | 
			
		||||
                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
 | 
			
		||||
                         dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 | 
			
		||||
 | 
			
		||||
            q1 += step;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user