mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	wip
This commit is contained in:
		@@ -2756,11 +2756,24 @@ kernel void kernel_leaky_relu_f32(
 | 
			
		||||
 | 
			
		||||
// ref: https://arxiv.org/pdf/2307.08691.pdf
 | 
			
		||||
template<
 | 
			
		||||
    typename block_q,
 | 
			
		||||
    short nl,
 | 
			
		||||
    void (*dequantize_func)(device const block_q *, short, thread half4x4 &),
 | 
			
		||||
    typename q_t,
 | 
			
		||||
    typename q4_t,
 | 
			
		||||
    typename q8x8_t,
 | 
			
		||||
    typename k_t,
 | 
			
		||||
    typename k4x4_t,
 | 
			
		||||
    typename k8x8_t,
 | 
			
		||||
    typename v_t,
 | 
			
		||||
    typename v4x4_t,
 | 
			
		||||
    typename v8x8_t,
 | 
			
		||||
    typename s_t,    // attention accumulation types
 | 
			
		||||
    typename s8x8_t,
 | 
			
		||||
    typename o_t,
 | 
			
		||||
    typename o8x8_t,
 | 
			
		||||
    typename block_q,
 | 
			
		||||
    short nl_k,
 | 
			
		||||
    void (*deq_k)(device const block_q *, short, thread k4x4_t &),
 | 
			
		||||
    short nl_v,
 | 
			
		||||
    void (*deq_v)(device const block_q *, short, thread v4x4_t &),
 | 
			
		||||
    short D,         // head size
 | 
			
		||||
    short Q  = 8,    // queries per threadgroup
 | 
			
		||||
    short KV = 8,    // key/value processed per each simdgroup
 | 
			
		||||
@@ -2819,15 +2832,19 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
    const short TS = T/SF;          // shared memory size per query in (s_t)
 | 
			
		||||
    const short T4 = T/4;           // shared memory size per query in (half4)
 | 
			
		||||
 | 
			
		||||
    threadgroup half  * sq  = (threadgroup half  *) (shared +               0*D); // holds the query data
 | 
			
		||||
    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +               0*D); // same as above but in half4
 | 
			
		||||
    threadgroup s_t   * ss  = (threadgroup s_t   *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
 | 
			
		||||
    threadgroup q_t  * sq  = (threadgroup q_t  *) (shared +               0*D); // holds the query data
 | 
			
		||||
    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared +               0*D); // same as above but in q4_t
 | 
			
		||||
    threadgroup o_t  * so  = (threadgroup o_t  *) (shared +               0*D); // reuse query data for accumulation
 | 
			
		||||
    threadgroup s_t  * ss  = (threadgroup s_t  *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
 | 
			
		||||
 | 
			
		||||
    threadgroup half    * skv  = (threadgroup half    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
 | 
			
		||||
    threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
 | 
			
		||||
    threadgroup k_t    * sk    = (threadgroup k_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
 | 
			
		||||
    threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
 | 
			
		||||
 | 
			
		||||
    threadgroup v_t    * sv    = (threadgroup v_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
 | 
			
		||||
    threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
 | 
			
		||||
 | 
			
		||||
    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
 | 
			
		||||
    simdgroup_half8x8 lo[D8];
 | 
			
		||||
    o8x8_t lo[D8];
 | 
			
		||||
 | 
			
		||||
    // load heads from Q to shared memory
 | 
			
		||||
    for (short j = sgitg; j < Q; j += nsg) {
 | 
			
		||||
@@ -2835,7 +2852,7 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
 | 
			
		||||
        for (short i = tiisg; i < D4; i += NW) {
 | 
			
		||||
            if (iq1 + j < ne01) {
 | 
			
		||||
                sq4[j*T4 + i] = (half4) q4[i];
 | 
			
		||||
                sq4[j*T4 + i] = (q4_t) q4[i];
 | 
			
		||||
            } else {
 | 
			
		||||
                sq4[j*T4 + i] = 0.0h;
 | 
			
		||||
            }
 | 
			
		||||
@@ -2844,7 +2861,7 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
 | 
			
		||||
    // zero out lo
 | 
			
		||||
    for (short i = 0; i < D8; ++i) {
 | 
			
		||||
        lo[i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
 | 
			
		||||
        lo[i] = make_filled_simdgroup_matrix<o_t, 8>((o_t) 0.0f);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // zero out shared memory SH
 | 
			
		||||
@@ -2883,7 +2900,7 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
        const short iv3 = iq3/rv3;
 | 
			
		||||
 | 
			
		||||
        // load the queries from shared memory into local memory
 | 
			
		||||
        simdgroup_half8x8 mq[D8];
 | 
			
		||||
        q8x8_t mq[D8];
 | 
			
		||||
 | 
			
		||||
        for (short i = 0; i < D8; ++i) {
 | 
			
		||||
            simdgroup_load(mq[i], sq + i*8, T);
 | 
			
		||||
@@ -2915,16 +2932,16 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
            // Q*K^T
 | 
			
		||||
            {
 | 
			
		||||
                for (short cc = 0; cc < C/8; ++cc) {
 | 
			
		||||
                    s8x8_t mqk = make_filled_simdgroup_matrix<s_t, 8>(0.0f);
 | 
			
		||||
                    s8x8_t mqk = make_filled_simdgroup_matrix<s_t, 8>((s_t) 0.0f);
 | 
			
		||||
 | 
			
		||||
                    // this is compile-time check, so it does not have runtime overhead
 | 
			
		||||
                    if (is_same<block_q, half4x4>::value) {
 | 
			
		||||
                    if (is_same<block_q, k4x4_t>::value) {
 | 
			
		||||
                        // we can read directly from global memory
 | 
			
		||||
                        device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
 | 
			
		||||
                        device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
 | 
			
		||||
 | 
			
		||||
                        for (short i = 0; i < D8; ++i) {
 | 
			
		||||
                            simdgroup_half8x8 mk;
 | 
			
		||||
                            simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
 | 
			
		||||
                            k8x8_t mk;
 | 
			
		||||
                            simdgroup_load(mk, pk + i*8, nb11/sizeof(k_t), 0, true); // transpose // TODO: use ne10
 | 
			
		||||
 | 
			
		||||
                            simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
 | 
			
		||||
                        }
 | 
			
		||||
@@ -2934,38 +2951,38 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
 | 
			
		||||
                            if (D16%4 == 0) {
 | 
			
		||||
                                // the head is evenly divisible by 4*16 = 64, so no need for bound checks
 | 
			
		||||
                                half4x4 tmp;
 | 
			
		||||
                                dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
 | 
			
		||||
                                skv4[4*ty + tx] = tmp;
 | 
			
		||||
                                k4x4_t tmp;
 | 
			
		||||
                                deq_k(pk4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
 | 
			
		||||
                                sk4x4[4*ty + tx] = tmp;
 | 
			
		||||
 | 
			
		||||
                                simdgroup_barrier(mem_flags::mem_threadgroup);
 | 
			
		||||
 | 
			
		||||
#pragma unroll
 | 
			
		||||
                                for (short k = 0; k < 4; ++k) {
 | 
			
		||||
                                    simdgroup_half8x8 mk;
 | 
			
		||||
                                    k8x8_t mk;
 | 
			
		||||
 | 
			
		||||
                                    simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
 | 
			
		||||
                                    simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
 | 
			
		||||
                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
 | 
			
		||||
 | 
			
		||||
                                    simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
 | 
			
		||||
                                    simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
 | 
			
		||||
                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
 | 
			
		||||
                                }
 | 
			
		||||
                            } else {
 | 
			
		||||
                                if (ii + tx < D16) {
 | 
			
		||||
                                    half4x4 tmp;
 | 
			
		||||
                                    dequantize_func(pk4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
 | 
			
		||||
                                    skv4[4*ty + tx] = tmp;
 | 
			
		||||
                                    k4x4_t tmp;
 | 
			
		||||
                                    deq_k(pk4 + (ii + tx)/nl_k, (ii + tx)%nl_k, tmp);
 | 
			
		||||
                                    sk4x4[4*ty + tx] = tmp;
 | 
			
		||||
                                }
 | 
			
		||||
 | 
			
		||||
                                simdgroup_barrier(mem_flags::mem_threadgroup);
 | 
			
		||||
 | 
			
		||||
                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
 | 
			
		||||
                                    simdgroup_half8x8 mk;
 | 
			
		||||
                                    k8x8_t mk;
 | 
			
		||||
 | 
			
		||||
                                    simdgroup_load(mk, skv + 16*k + 0*8, 4*16, 0, true); // transpose
 | 
			
		||||
                                    simdgroup_load(mk, sk + 16*k + 0*8, 4*16, 0, true); // transpose
 | 
			
		||||
                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 0], mk, mqk);
 | 
			
		||||
 | 
			
		||||
                                    simdgroup_load(mk, skv + 16*k + 1*8, 4*16, 0, true); // transpose
 | 
			
		||||
                                    simdgroup_load(mk, sk + 16*k + 1*8, 4*16, 0, true); // transpose
 | 
			
		||||
                                    simdgroup_multiply_accumulate(mqk, mq[2*(ii + k) + 1], mk, mqk);
 | 
			
		||||
                                }
 | 
			
		||||
                            }
 | 
			
		||||
@@ -2995,7 +3012,7 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
 | 
			
		||||
                    if (mask != q) {
 | 
			
		||||
                        // mqk = mqk + mask*slope
 | 
			
		||||
                        s += slope*mp[ic + j*nb31/sizeof(half) + tiisg];
 | 
			
		||||
                        s += slope*mp[ic + j*nb31/sizeof(half) + tiisg]; // TODO: use ne30
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
                    smax = simd_max(max(smax, s));
 | 
			
		||||
@@ -3037,13 +3054,13 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
                    s8x8_t ms;
 | 
			
		||||
                    simdgroup_load(ms, ss + 8*cc, TS, 0, false);
 | 
			
		||||
 | 
			
		||||
                    if (is_same<block_q, half4x4>::value) {
 | 
			
		||||
                    if (is_same<block_q, v4x4_t>::value) {
 | 
			
		||||
                        // we can read directly from global memory
 | 
			
		||||
                        device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
 | 
			
		||||
                        device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
 | 
			
		||||
#pragma unroll
 | 
			
		||||
                        for (short i = 0; i < D8; ++i) {
 | 
			
		||||
                            simdgroup_half8x8 mv;
 | 
			
		||||
                            simdgroup_load(mv, pv + i*8, nb21/sizeof(half), 0, false);
 | 
			
		||||
                            v8x8_t mv;
 | 
			
		||||
                            simdgroup_load(mv, pv + i*8, nb21/sizeof(v_t), 0, false); // TODO: use ne20
 | 
			
		||||
 | 
			
		||||
                            simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
 | 
			
		||||
                        }
 | 
			
		||||
@@ -3053,38 +3070,38 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
 | 
			
		||||
                            if (D16%4 == 0) {
 | 
			
		||||
                                // no need for bound checks
 | 
			
		||||
                                half4x4 tmp;
 | 
			
		||||
                                dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
 | 
			
		||||
                                skv4[4*ty + tx] = tmp;
 | 
			
		||||
                                v4x4_t tmp;
 | 
			
		||||
                                deq_v(pv4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
 | 
			
		||||
                                sv4x4[4*ty + tx] = tmp;
 | 
			
		||||
 | 
			
		||||
                                simdgroup_barrier(mem_flags::mem_threadgroup);
 | 
			
		||||
 | 
			
		||||
#pragma unroll
 | 
			
		||||
                                for (short k = 0; k < 4; ++k) {
 | 
			
		||||
                                    simdgroup_half8x8 mv;
 | 
			
		||||
                                    v8x8_t mv;
 | 
			
		||||
 | 
			
		||||
                                    simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
 | 
			
		||||
                                    simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
 | 
			
		||||
                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
 | 
			
		||||
 | 
			
		||||
                                    simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
 | 
			
		||||
                                    simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
 | 
			
		||||
                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
 | 
			
		||||
                                }
 | 
			
		||||
                            } else {
 | 
			
		||||
                                if (ii + tx < D16) {
 | 
			
		||||
                                    half4x4 tmp;
 | 
			
		||||
                                    dequantize_func(pv4 + (ii + tx)/nl, (ii + tx)%nl, tmp);
 | 
			
		||||
                                    skv4[4*ty + tx] = tmp;
 | 
			
		||||
                                    v4x4_t tmp;
 | 
			
		||||
                                    deq_v(pv4 + (ii + tx)/nl_v, (ii + tx)%nl_v, tmp);
 | 
			
		||||
                                    sv4x4[4*ty + tx] = tmp;
 | 
			
		||||
                                }
 | 
			
		||||
 | 
			
		||||
                                simdgroup_barrier(mem_flags::mem_threadgroup);
 | 
			
		||||
 | 
			
		||||
                                for (short k = 0; k < 4 && ii + k < D16; ++k) {
 | 
			
		||||
                                    simdgroup_half8x8 mv;
 | 
			
		||||
                                    v8x8_t mv;
 | 
			
		||||
 | 
			
		||||
                                    simdgroup_load(mv, skv + 16*k + 0*8, 4*16, 0, false);
 | 
			
		||||
                                    simdgroup_load(mv, sv + 16*k + 0*8, 4*16, 0, false);
 | 
			
		||||
                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 0], ms, mv, lo[2*(ii + k) + 0]);
 | 
			
		||||
 | 
			
		||||
                                    simdgroup_load(mv, skv + 16*k + 1*8, 4*16, 0, false);
 | 
			
		||||
                                    simdgroup_load(mv, sv + 16*k + 1*8, 4*16, 0, false);
 | 
			
		||||
                                    simdgroup_multiply_accumulate(lo[2*(ii + k) + 1], ms, mv, lo[2*(ii + k) + 1]);
 | 
			
		||||
                                }
 | 
			
		||||
                            }
 | 
			
		||||
@@ -3113,7 +3130,7 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
        // each simdgroup stores its output to shared memory, reusing sq
 | 
			
		||||
        if (sgitg == sg) {
 | 
			
		||||
            for (short i = 0; i < D8; ++i) {
 | 
			
		||||
                simdgroup_store(lo[i], sq + i*8, T, 0, false);
 | 
			
		||||
                simdgroup_store(lo[i], so + i*8, T, 0, false);
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
@@ -3146,7 +3163,7 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
 | 
			
		||||
            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
 | 
			
		||||
            {
 | 
			
		||||
                simdgroup_half8x8 t;
 | 
			
		||||
                o8x8_t t;
 | 
			
		||||
 | 
			
		||||
                s8x8_t ms0;
 | 
			
		||||
                s8x8_t ms1;
 | 
			
		||||
@@ -3155,7 +3172,7 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
                simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);
 | 
			
		||||
 | 
			
		||||
                for (short i = 0; i < D8; ++i) {
 | 
			
		||||
                    simdgroup_load    (t, sq + i*8, T, 0, false);
 | 
			
		||||
                    simdgroup_load    (t, so + i*8, T, 0, false);
 | 
			
		||||
                    simdgroup_multiply(t, ms1, t);
 | 
			
		||||
 | 
			
		||||
                    simdgroup_multiply_accumulate(lo[i], ms0, lo[i], t);
 | 
			
		||||
@@ -3167,7 +3184,7 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
    // store result to shared memory (reuse sq)
 | 
			
		||||
    if (sgitg == 0) {
 | 
			
		||||
        for (short i = 0; i < D8; ++i) {
 | 
			
		||||
            simdgroup_store(lo[i], sq + i*8, T, 0, false);
 | 
			
		||||
            simdgroup_store(lo[i], so + i*8, T, 0, false);
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
@@ -3187,55 +3204,75 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
 | 
			
		||||
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
 | 
			
		||||
#define S_T    half
 | 
			
		||||
#define S4_T   half4
 | 
			
		||||
#define S4x4_T half4x4
 | 
			
		||||
#define S8x8_T simdgroup_half8x8
 | 
			
		||||
 | 
			
		||||
#define FA_TYPES \
 | 
			
		||||
    half, half4,   simdgroup_half8x8, \
 | 
			
		||||
    half, half4x4, simdgroup_half8x8, \
 | 
			
		||||
    half, half4x4, simdgroup_half8x8, \
 | 
			
		||||
    half, simdgroup_half8x8,          \
 | 
			
		||||
    half, simdgroup_half8x8
 | 
			
		||||
#else
 | 
			
		||||
#define S_T    float
 | 
			
		||||
#define S4_T   float4
 | 
			
		||||
#define S4x4_T float4x4
 | 
			
		||||
#define S8x8_T simdgroup_float8x8
 | 
			
		||||
 | 
			
		||||
#define FA_TYPES \
 | 
			
		||||
    half,  half4,   simdgroup_half8x8, \
 | 
			
		||||
    half,  half4x4, simdgroup_half8x8, \
 | 
			
		||||
    half,  half4x4, simdgroup_half8x8, \
 | 
			
		||||
    float, simdgroup_float8x8,         \
 | 
			
		||||
    half,  simdgroup_half8x8
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 64>) flash_attn_ext_t;
 | 
			
		||||
typedef decltype(kernel_flash_attn_ext<FA_TYPES, half4x4, 1, dequantize_f16, 1, dequantize_f16, 64>) flash_attn_ext_t;
 | 
			
		||||
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4,     1, dequantize_f16,  S_T, S8x8_T, 64>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4,     1, dequantize_f16,  S_T, S8x8_T, 80>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4,     1, dequantize_f16,  S_T, S8x8_T, 96>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4,     1, dequantize_f16,  S_T, S8x8_T, 112>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4,     1, dequantize_f16,  S_T, S8x8_T, 128>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4,     1, dequantize_f16,  S_T, S8x8_T, 256>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_f16_h64" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  1, dequantize_f16,  64>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_f16_h80" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  1, dequantize_f16,  80>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_f16_h96" )]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  1, dequantize_f16,  96>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_f16_h112")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  1, dequantize_f16,  112>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_f16_h128")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  1, dequantize_f16,  128>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_f16_h256")]]  kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, half4x4,    1, dequantize_f16,  1, dequantize_f16,  256>;
 | 
			
		||||
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 64>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 80>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 96>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 112>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 128>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 256>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 64>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 80>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 96>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 112>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 128>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_0, 2, dequantize_q4_0, 2, dequantize_q4_0, 256>;
 | 
			
		||||
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 64>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 80>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 96>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 112>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 128>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 256>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 64>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 80>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 96>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 112>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 128>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q4_1, 2, dequantize_q4_1, 2, dequantize_q4_1, 256>;
 | 
			
		||||
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 64>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 80>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 96>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 112>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 128>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 256>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 64>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 80>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 96>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 112>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 128>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_0, 2, dequantize_q5_0, 2, dequantize_q5_0, 256>;
 | 
			
		||||
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 64>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 80>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 96>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 112>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 128>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 256>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 64>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 80>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 96>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 112>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 128>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q5_1, 2, dequantize_q5_1, 2, dequantize_q5_1, 256>;
 | 
			
		||||
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 64>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 80>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 96>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 112>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 128>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 256>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 64>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 80>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 96>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 112>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 128>;
 | 
			
		||||
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<FA_TYPES, block_q8_0, 2, dequantize_q8_0, 2, dequantize_q8_0, 256>;
 | 
			
		||||
 | 
			
		||||
#undef FA_TYPES
 | 
			
		||||
 | 
			
		||||
// NOTE: can use half instead of float precision for some extra perf
 | 
			
		||||
//       however, by default use F32 since the op should be mostly memory bandwidth bound
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user