mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	metal : use constexpr in FA kernels + fix typedef (#12659)
* metal : use constexpr in FA kernels

ggml-ci

* cont

ggml-ci

* cont : fix typedef

ggml-ci
This commit is contained in:
		@@ -3128,14 +3128,15 @@ kernel void kernel_flash_attn_ext(
 | 
				
			|||||||
    const int iq2 = tgpig[1];
 | 
					    const int iq2 = tgpig[1];
 | 
				
			||||||
    const int iq1 = tgpig[0]*Q;
 | 
					    const int iq1 = tgpig[0]*Q;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const short DK4  = DK/4;
 | 
					    constexpr short DK4  = DK/4;
 | 
				
			||||||
    const short DK8  = DK/8;
 | 
					    constexpr short DK8  = DK/8;
 | 
				
			||||||
    const short DK16 = DK/16;
 | 
					    constexpr short DK16 = DK/16;
 | 
				
			||||||
    const short DV4  = DV/4;
 | 
					    constexpr short DV4  = DV/4;
 | 
				
			||||||
    const short DV8  = DV/8;
 | 
					    constexpr short DV8  = DV/8;
 | 
				
			||||||
    const short DV16 = DV/16;
 | 
					    constexpr short DV16 = DV/16;
 | 
				
			||||||
    const short NW  = N_SIMDWIDTH;
 | 
					
 | 
				
			||||||
    const short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
 | 
					    constexpr short NW  = N_SIMDWIDTH;
 | 
				
			||||||
 | 
					    constexpr short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
 | 
					    const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
 | 
				
			||||||
    const short T  = DK + 2*TS; // shared memory size per query in (half)
 | 
					    const short T  = DK + 2*TS; // shared memory size per query in (half)
 | 
				
			||||||
@@ -3641,11 +3642,11 @@ kernel void kernel_flash_attn_ext_vec(
 | 
				
			|||||||
    const int iq2 = tgpig[1];
 | 
					    const int iq2 = tgpig[1];
 | 
				
			||||||
    const int iq1 = tgpig[0];
 | 
					    const int iq1 = tgpig[0];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const short DK4 = DK/4;
 | 
					    constexpr short DK4 = DK/4;
 | 
				
			||||||
    const short DV4 = DV/4;
 | 
					    constexpr short DV4 = DV/4;
 | 
				
			||||||
    const short NW  = N_SIMDWIDTH;
 | 
					    constexpr short NW  = N_SIMDWIDTH;
 | 
				
			||||||
    const short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
 | 
					    constexpr short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
 | 
				
			||||||
    const short SH  = 2*C;   // shared memory per simdgroup
 | 
					    constexpr short SH  = 2*C;   // shared memory per simdgroup
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const short T = DK + nsg*SH; // shared memory size per query in (half)
 | 
					    const short T = DK + nsg*SH; // shared memory size per query in (half)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -3956,7 +3957,7 @@ kernel void kernel_flash_attn_ext_vec(
 | 
				
			|||||||
    half,  half4, \
 | 
					    half,  half4, \
 | 
				
			||||||
           half4
 | 
					           half4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 128>) flash_attn_ext_vec_t;
 | 
					typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  128, 128, 4>;
 | 
					template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  128, 128, 4>;
 | 
				
			||||||
#if defined(GGML_METAL_USE_BF16)
 | 
					#if defined(GGML_METAL_USE_BF16)
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user