mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : use F16 precision in FA kernel
This commit is contained in:
		| @@ -12,6 +12,9 @@ | |||||||
| #define MIN(a, b) ((a) < (b) ? (a) : (b)) | #define MIN(a, b) ((a) < (b) ? (a) : (b)) | ||||||
| #define MAX(a, b) ((a) > (b) ? (a) : (b)) | #define MAX(a, b) ((a) > (b) ? (a) : (b)) | ||||||
|  |  | ||||||
|  | // TODO: for now, always use F32 for flash attention to avoid compiling 2 sets of kernels | ||||||
|  | #define GGML_METAL_FORCE_FATTN_PREC_F32 | ||||||
|  |  | ||||||
| // max memory buffers that can be mapped to the device | // max memory buffers that can be mapped to the device | ||||||
| #define GGML_METAL_MAX_BUFFERS 64 | #define GGML_METAL_MAX_BUFFERS 64 | ||||||
|  |  | ||||||
| @@ -496,6 +499,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de | |||||||
|                 // dictionary of preprocessor macros |                 // dictionary of preprocessor macros | ||||||
|                 NSMutableDictionary * prep = [NSMutableDictionary dictionary]; |                 NSMutableDictionary * prep = [NSMutableDictionary dictionary]; | ||||||
|  |  | ||||||
|  |                 // add GGML_METAL_FORCE_FATTN_PREC_F32 | ||||||
|  | #if defined(GGML_METAL_FORCE_FATTN_PREC_F32) | ||||||
|  |                 [prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F32"]; | ||||||
|  | #endif | ||||||
|  |  | ||||||
|                 MTLCompileOptions * options = [MTLCompileOptions new]; |                 MTLCompileOptions * options = [MTLCompileOptions new]; | ||||||
|                 options.preprocessorMacros = prep; |                 options.preprocessorMacros = prep; | ||||||
|  |  | ||||||
| @@ -3216,11 +3224,19 @@ static void ggml_metal_encode_node( | |||||||
|                     GGML_ASSERT(nqptg  % 8  == 0); |                     GGML_ASSERT(nqptg  % 8  == 0); | ||||||
|                     GGML_ASSERT(ncpsg  % 32 == 0); |                     GGML_ASSERT(ncpsg  % 32 == 0); | ||||||
|  |  | ||||||
|  | #ifdef GGML_METAL_FORCE_FATTN_PREC_F32 | ||||||
|  |                     const enum ggml_prec prec = GGML_PREC_DEFAULT; | ||||||
|  | #else | ||||||
|  |                     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst); | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  |                     const int64_t nhalfs = prec == GGML_PREC_DEFAULT ? 1 : 2; | ||||||
|  |  | ||||||
|                     // 16*32*(nsg) |                     // 16*32*(nsg) | ||||||
|                     // the shared memory needed for the simdgroups to load the KV cache |                     // the shared memory needed for the simdgroups to load the KV cache | ||||||
|                     // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG |                     // each thread loads (dequantizes) 16 head elements, there are 32 threads in the SG | ||||||
|                     // |                     // | ||||||
| #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + 2*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) | #define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(ne00 + nhalfs*(ncpsg + nqptg)*(nsg)) + 16*32*(nsg))*(sizeof(float)/2), 16)) | ||||||
|  |  | ||||||
|                     int64_t nsgmax = 2; |                     int64_t nsgmax = 2; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -2805,13 +2805,13 @@ kernel void kernel_flash_attn_ext( | |||||||
|     const short NW  = N_SIMDWIDTH; |     const short NW  = N_SIMDWIDTH; | ||||||
|     const short SH  = (C + Q); // shared memory per simdgroup in (half) |     const short SH  = (C + Q); // shared memory per simdgroup in (half) | ||||||
|  |  | ||||||
|     const short T  = D + 2*nsg*SH; // shared memory size per query in (half) |     const short T  = D + nsg*SH; // shared memory size per query in (half) | ||||||
|     const short TF = T/2;        // shared memory size per query in (float) |     const short TF = T;          // shared memory size per query in (float) | ||||||
|     const short T4 = T/4;        // shared memory size per query in (half4) |     const short T4 = T/4;        // shared memory size per query in (half4) | ||||||
|  |  | ||||||
|     threadgroup half  * sq  = (threadgroup half  *) (shared +              0*D); // holds the query data |     threadgroup half  * sq  = (threadgroup half  *) (shared +            0*D); // holds the query data | ||||||
|     threadgroup half4 * sq4 = (threadgroup half4 *) (shared +              0*D); // same as above but in half4 |     threadgroup half4 * sq4 = (threadgroup half4 *) (shared +            0*D); // same as above but in half4 | ||||||
|     threadgroup float * ss  = (threadgroup float *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix |     threadgroup half  * ss  = (threadgroup half  *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix | ||||||
|  |  | ||||||
|     threadgroup half    * skv  = (threadgroup half    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory |     threadgroup half    * skv  = (threadgroup half    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory | ||||||
|     threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4 |     threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4 | ||||||
| @@ -2840,7 +2840,7 @@ kernel void kernel_flash_attn_ext( | |||||||
|     // zero out shared memory SH |     // zero out shared memory SH | ||||||
|     for (short j = 0; j < Q; ++j) { |     for (short j = 0; j < Q; ++j) { | ||||||
|         for (short i = tiisg; i < SH; i += NW) { |         for (short i = tiisg; i < SH; i += NW) { | ||||||
|             ss[j*TF + i] = 0.0f; |             ss[j*TF + i] = 0.0h; | ||||||
|         } |         } | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -2905,7 +2905,7 @@ kernel void kernel_flash_attn_ext( | |||||||
|             // Q*K^T |             // Q*K^T | ||||||
|             { |             { | ||||||
|                 for (short cc = 0; cc < C/8; ++cc) { |                 for (short cc = 0; cc < C/8; ++cc) { | ||||||
|                     simdgroup_float8x8 mqk = make_filled_simdgroup_matrix<float, 8>(0.h); |                     simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, 8>(0.h); | ||||||
|  |  | ||||||
|                     // this is a compile-time check, so it does not have runtime overhead |                     // this is a compile-time check, so it does not have runtime overhead | ||||||
|                     if (is_same<block_q, half4x4>::value) { |                     if (is_same<block_q, half4x4>::value) { | ||||||
| @@ -2977,7 +2977,7 @@ kernel void kernel_flash_attn_ext( | |||||||
|                     const float m = M[j]; |                     const float m = M[j]; | ||||||
|  |  | ||||||
|                     // scale and apply the logitcap / mask |                     // scale and apply the logitcap / mask | ||||||
|                     float s = ss[j*TF + tiisg]*scale; |                     float s = ((float)(ss[j*TF + tiisg]))*scale; | ||||||
|  |  | ||||||
|                     if (logit_softcap != 0.0f) { |                     if (logit_softcap != 0.0f) { | ||||||
|                         s = logit_softcap*precise::tanh(s); |                         s = logit_softcap*precise::tanh(s); | ||||||
| @@ -3013,7 +3013,7 @@ kernel void kernel_flash_attn_ext( | |||||||
|  |  | ||||||
|             // O = diag(ms)*O |             // O = diag(ms)*O | ||||||
|             { |             { | ||||||
|                 simdgroup_float8x8 mm; |                 simdgroup_half8x8 mm; | ||||||
|                 simdgroup_load(mm, ss + C, TF, 0, false); |                 simdgroup_load(mm, ss + C, TF, 0, false); | ||||||
|  |  | ||||||
|                 for (short i = 0; i < D8; ++i) { |                 for (short i = 0; i < D8; ++i) { | ||||||
| @@ -3024,7 +3024,7 @@ kernel void kernel_flash_attn_ext( | |||||||
|             // O = O + (Q*K^T)*V |             // O = O + (Q*K^T)*V | ||||||
|             { |             { | ||||||
|                 for (short cc = 0; cc < C/8; ++cc) { |                 for (short cc = 0; cc < C/8; ++cc) { | ||||||
|                     simdgroup_float8x8 ms; |                     simdgroup_half8x8 ms; | ||||||
|                     simdgroup_load(ms, ss + 8*cc, TF, 0, false); |                     simdgroup_load(ms, ss + 8*cc, TF, 0, false); | ||||||
|  |  | ||||||
|                     if (is_same<block_q, half4x4>::value) { |                     if (is_same<block_q, half4x4>::value) { | ||||||
| @@ -3137,8 +3137,8 @@ kernel void kernel_flash_attn_ext( | |||||||
|             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 |             // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 | ||||||
|             { |             { | ||||||
|                 simdgroup_half8x8 t; |                 simdgroup_half8x8 t; | ||||||
|                 simdgroup_float8x8 ms0; |                 simdgroup_half8x8 ms0; | ||||||
|                 simdgroup_float8x8 ms1; |                 simdgroup_half8x8 ms1; | ||||||
|  |  | ||||||
|                 simdgroup_load(ms0, ss + C,         TF, 0, false); |                 simdgroup_load(ms0, ss + C,         TF, 0, false); | ||||||
|                 simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); |                 simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false); | ||||||
| @@ -3219,6 +3219,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_ | |||||||
| template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>; | template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>; | ||||||
|  |  | ||||||
| // NOTE: can use half instead of float precision for some extra perf | // NOTE: can use half instead of float precision for some extra perf | ||||||
|  | //       however, by default use F32 since the op should be mostly memory bandwidth bound | ||||||
| // D - head size, Q - queries per threadgroup, C - cache items per threadgroup | // D - head size, Q - queries per threadgroup, C - cache items per threadgroup | ||||||
| template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32> | template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &), short D, short Q = 1, short C = 32> | ||||||
| kernel void kernel_flash_attn_ext_vec( | kernel void kernel_flash_attn_ext_vec( | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov