mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	clean-up
This commit is contained in:
		@@ -3251,6 +3251,9 @@ static void ggml_metal_encode_node(
 | 
			
		||||
                    GGML_ASSERT(nqptg  % 8  == 0);
 | 
			
		||||
                    GGML_ASSERT(ncpsg  % 32 == 0);
 | 
			
		||||
 | 
			
		||||
                    // 2*(2*ncpsg + nqptg)*(nsg)
 | 
			
		||||
                    // ncpsg soft_max values + ncpsg mask values + a diagonal scaling matrix (in float)
 | 
			
		||||
                    //
 | 
			
		||||
                    // 16*32*(nsg)
 | 
			
		||||
                    // the shared memory needed for the simdgroups to load the KV cache
 | 
			
		||||
                    // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG
 | 
			
		||||
 
 | 
			
		||||
@@ -2846,14 +2846,14 @@ kernel void kernel_flash_attn_ext(
 | 
			
		||||
    const short NW  = N_SIMDWIDTH;
 | 
			
		||||
    const short SH  = (2*C + Q); // shared memory per simdgroup in (half)
 | 
			
		||||
 | 
			
		||||
    const short TS = nsg*SH;   // shared memory size per query in (s_t)
 | 
			
		||||
    const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
 | 
			
		||||
    const short T  = D + 2*TS; // shared memory size per query in (half)
 | 
			
		||||
 | 
			
		||||
    threadgroup q_t  * sq  = (threadgroup q_t  *) (shared +              0*D); // holds the query data
 | 
			
		||||
    threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared +              0*D); // same as above but in q4_t
 | 
			
		||||
    threadgroup o_t  * so  = (threadgroup o_t  *) (shared +              0*D); // reuse query data for accumulation
 | 
			
		||||
    threadgroup o4_t * so4 = (threadgroup o4_t *) (shared +              0*D); // reuse query data for accumulation
 | 
			
		||||
    threadgroup s_t  * ss  = (threadgroup s_t  *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention and diagonal matrix
 | 
			
		||||
    threadgroup o4_t * so4 = (threadgroup o4_t *) (shared +              0*D); // same as above but in o4_t
 | 
			
		||||
    threadgroup s_t  * ss  = (threadgroup s_t  *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
 | 
			
		||||
 | 
			
		||||
    threadgroup k_t    * sk    = (threadgroup k_t    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
 | 
			
		||||
    threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user