mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	float -> half regs
This commit is contained in:
		@@ -2898,8 +2898,8 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
    threadgroup_barrier(mem_flags::mem_threadgroup);
 | 
			
		||||
 | 
			
		||||
    {
 | 
			
		||||
        float S[Q] = { [0 ... Q-1] = 0.0f };
 | 
			
		||||
        float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
 | 
			
		||||
        half S[Q] = { [0 ... Q-1] = 0.0f };
 | 
			
		||||
        half M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
 | 
			
		||||
 | 
			
		||||
        // thread indices inside the simdgroup
 | 
			
		||||
        // TODO: see if we can utilize quad-group functions for better performance
 | 
			
		||||
@@ -2934,14 +2934,14 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
 | 
			
		||||
        const bool has_mask = mask != q;
 | 
			
		||||
 | 
			
		||||
        float slope = 1.0f;
 | 
			
		||||
        half slope = 1.0f;
 | 
			
		||||
 | 
			
		||||
        // ALiBi
 | 
			
		||||
        if (max_bias > 0.0f) {
 | 
			
		||||
            const uint32_t h = iq2;
 | 
			
		||||
            const short h = iq2;
 | 
			
		||||
 | 
			
		||||
            const float base = h < n_head_log2 ? m0 : m1;
 | 
			
		||||
            const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 | 
			
		||||
            const half  base = h < n_head_log2 ? m0 : m1;
 | 
			
		||||
            const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 | 
			
		||||
 | 
			
		||||
            slope = pow(base, exph);
 | 
			
		||||
        }
 | 
			
		||||
@@ -3047,10 +3047,10 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
            // online softmax
 | 
			
		||||
            {
 | 
			
		||||
                for (short j = 0; j < Q; ++j) {
 | 
			
		||||
                    const float m = M[j];
 | 
			
		||||
                    const half m = M[j];
 | 
			
		||||
 | 
			
		||||
                    // scale and apply the logitcap / mask
 | 
			
		||||
                    float s = ss[j*TS + tiisg]*scale;
 | 
			
		||||
                    half s = ss[j*TS + tiisg]*scale;
 | 
			
		||||
 | 
			
		||||
                    if (logit_softcap != 0.0f) {
 | 
			
		||||
                        s = logit_softcap*precise::tanh(s);
 | 
			
		||||
@@ -3061,8 +3061,8 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
 | 
			
		||||
                    M[j] = simd_max(max(M[j], s));
 | 
			
		||||
 | 
			
		||||
                    const float ms = exp(m - M[j]);
 | 
			
		||||
                    const float vs = exp(s - M[j]);
 | 
			
		||||
                    const half ms = exp(m - M[j]);
 | 
			
		||||
                    const half vs = exp(s - M[j]);
 | 
			
		||||
 | 
			
		||||
                    S[j] = S[j]*ms + simd_sum(vs);
 | 
			
		||||
 | 
			
		||||
@@ -3163,8 +3163,8 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
 | 
			
		||||
    // reduce the warps sequentially
 | 
			
		||||
    for (short sg = 1; sg < nsg; ++sg) {
 | 
			
		||||
        float S = { 0.0f };
 | 
			
		||||
        float M = { -FLT_MAX/2 };
 | 
			
		||||
        half S = { 0.0f };
 | 
			
		||||
        half M = { -__FLT16_MAX__/2 };
 | 
			
		||||
 | 
			
		||||
        threadgroup_barrier(mem_flags::mem_threadgroup);
 | 
			
		||||
 | 
			
		||||
@@ -3180,16 +3180,16 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
        // the first simdgroup accumulates the results from the other simdgroups
 | 
			
		||||
        if (sgitg == 0) {
 | 
			
		||||
            for (short j = 0; j < Q; ++j) {
 | 
			
		||||
                const float S0 = ss[j*TS +         0];
 | 
			
		||||
                const float S1 = ss[j*TS + sg*SH + 0];
 | 
			
		||||
                const half S0 = ss[j*TS +         0];
 | 
			
		||||
                const half S1 = ss[j*TS + sg*SH + 0];
 | 
			
		||||
 | 
			
		||||
                const float M0 = ss[j*TS +         1];
 | 
			
		||||
                const float M1 = ss[j*TS + sg*SH + 1];
 | 
			
		||||
                const half M0 = ss[j*TS +         1];
 | 
			
		||||
                const half M1 = ss[j*TS + sg*SH + 1];
 | 
			
		||||
 | 
			
		||||
                M = max(M0, M1);
 | 
			
		||||
 | 
			
		||||
                const float ms0 = exp(M0 - M);
 | 
			
		||||
                const float ms1 = exp(M1 - M);
 | 
			
		||||
                const half ms0 = exp(M0 - M);
 | 
			
		||||
                const half ms1 = exp(M1 - M);
 | 
			
		||||
 | 
			
		||||
                S = S0*ms0 + S1*ms1;
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user