Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)

	metal : add GGML_METAL_FORCE_FATTN_PREC_F16
ggml-ci
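
The new GGML_METAL_FORCE_FATTN_PREC_F16 option compiles the Metal flash-attention kernels with F16 accumulators instead of the default F32. Going by the Makefile and CMake changes below, it would typically be enabled at build time roughly like this (illustrative invocations, not part of the commit):

    # CMake
    cmake -B build -DGGML_METAL_FORCE_FATTN_PREC_F16=ON
    cmake --build build

    # Makefile
    make GGML_METAL_FORCE_FATTN_PREC_F16=1

When the flag is set, ggml_metal_init defines the same macro for the Metal compiler and logs "GGML_METAL_FORCE_FATTN_PREC_F16 = yes" at startup; otherwise the kernels keep F32 accumulators.
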
@@ -876,6 +876,11 @@ endif # GGML_HIPBLAS

ifdef GGML_METAL
	MK_CPPFLAGS += -DGGML_USE_METAL

ifdef GGML_METAL_FORCE_FATTN_PREC_F16
	MK_CPPFLAGS += -DGGML_METAL_FORCE_FATTN_PREC_F16
endif # GGML_METAL_FORCE_FATTN_PREC_F16

	MK_LDFLAGS  += -framework Foundation -framework Metal -framework MetalKit
	OBJ_GGML	+= ggml/src/ggml-metal.o
ifdef GGML_METAL_NDEBUG

| @@ -153,6 +153,7 @@ option(GGML_VULKAN_VALIDATE                 "ggml: enable Vulkan validation" | ||||
| option(GGML_VULKAN_RUN_TESTS                "ggml: run Vulkan tests"                          OFF) | ||||
| option(GGML_KOMPUTE                         "ggml: use Kompute"                               OFF) | ||||
| option(GGML_METAL                           "ggml: use Metal"                                 ${GGML_METAL_DEFAULT}) | ||||
| option(GGML_METAL_FORCE_FATTN_PREC_F16      "ggml: force F16 accumulators for FA kernels"     OFF) | ||||
| option(GGML_METAL_NDEBUG                    "ggml: disable Metal debugging"                   OFF) | ||||
| option(GGML_METAL_SHADER_DEBUG              "ggml: compile Metal with -fno-fast-math"         OFF) | ||||
| option(GGML_METAL_EMBED_LIBRARY             "ggml: embed Metal library"                       ${GGML_METAL}) | ||||
|   | ||||
@@ -58,6 +58,10 @@ if (GGML_METAL)
        add_compile_definitions(GGML_METAL_NDEBUG)
    endif()

    if (GGML_METAL_FORCE_FATTN_PREC_F16)
        add_compile_definitions(GGML_METAL_FORCE_FATTN_PREC_F16)
    endif()

    # copy ggml-common.h and ggml-metal.metal to bin directory
    configure_file(ggml-common.h    ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h    COPYONLY)
    configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY)

@@ -12,9 +12,6 @@
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))

// TODO: for now, always use F32 for flash attention to avoid compiling 2 sets of kernels
#define GGML_METAL_FORCE_FATTN_PREC_F32

// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 64

@@ -499,9 +496,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
                // dictionary of preprocessor macros
                NSMutableDictionary * prep = [NSMutableDictionary dictionary];

                // add GGML_METAL_FORCE_FATTN_PREC_F32
#if defined(GGML_METAL_FORCE_FATTN_PREC_F32)
                [prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F32"];
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
                [prep setObject:@"1" forKey:@"GGML_METAL_FORCE_FATTN_PREC_F16"];
#endif

                MTLCompileOptions * options = [MTLCompileOptions new];
@@ -554,6 +550,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
        }
    }

#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
    GGML_LOG_INFO("%s: GGML_METAL_FORCE_FATTN_PREC_F16  = yes\n", __func__);
#else
    GGML_LOG_INFO("%s: GGML_METAL_FORCE_FATTN_PREC_F16  = no\n",  __func__);
#endif
    GGML_LOG_INFO("%s: simdgroup reduction   = %s\n", __func__, ctx_dev->has_simdgroup_reduction     ? "true" : "false");
    GGML_LOG_INFO("%s: simdgroup matrix mul. = %s\n", __func__, ctx_dev->has_simdgroup_mm            ? "true" : "false");
    GGML_LOG_INFO("%s: bfloat                = %s\n", __func__, ctx_dev->has_bfloat                  ? "true" : "false");
@@ -3224,10 +3225,12 @@ static void ggml_metal_encode_node(
                    GGML_ASSERT(nqptg  % 8  == 0);
                    GGML_ASSERT(ncpsg  % 32 == 0);

#ifdef GGML_METAL_FORCE_FATTN_PREC_F32
#ifdef GGML_METAL_FORCE_FATTN_PREC_F16
                    const enum ggml_prec prec = GGML_PREC_DEFAULT;
#else
                    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
                    // TODO: support both precisions
                    const enum ggml_prec prec = GGML_PREC_F32;
                    //const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(dst);
#endif

                    const int64_t nhalfs = prec == GGML_PREC_DEFAULT ? 1 : 2;

@@ -2755,8 +2755,16 @@ kernel void kernel_leaky_relu_f32(
}

// ref: https://arxiv.org/pdf/2307.08691.pdf
// D - head size, Q - queries per threadgroup, KV - key/value processed per each simdgroup, C - cache items per threadgroup
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &), short D, short Q = 8, short KV = 8, short C = 32>
template<
    typename block_q,
    short nl,
    void (*dequantize_func)(device const block_q *, short, thread half4x4 &),
    typename s_t,    // attention accumulation types
    typename s8x8_t,
    short D,         // head size
    short Q  = 8,    // queries per threadgroup
    short KV = 8,    // key/value processed per each simdgroup
    short C  = 32>   // cache items per threadgroup
kernel void kernel_flash_attn_ext(
        device const  char * q,
        device const  char * k,
@@ -2805,13 +2813,15 @@ kernel void kernel_flash_attn_ext(
    const short NW  = N_SIMDWIDTH;
    const short SH  = (C + Q); // shared memory per simdgroup in (half)

    const short T  = D + nsg*SH; // shared memory size per query in (half)
    const short TF = T;          // shared memory size per query in (float)
    const short T4 = T/4;        // shared memory size per query in (half4)
    const short SF = sizeof(s_t)/sizeof(half);

    threadgroup half  * sq  = (threadgroup half  *) (shared +            0*D); // holds the query data
    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +            0*D); // same as above but in half4
    threadgroup half  * ss  = (threadgroup half  *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
    const short T  = D + SF*nsg*SH; // shared memory size per query in (half)
    const short TS = T/SF;          // shared memory size per query in (s_t)
    const short T4 = T/4;           // shared memory size per query in (half4)

    threadgroup half  * sq  = (threadgroup half  *) (shared +               0*D); // holds the query data
    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +               0*D); // same as above but in half4
    threadgroup s_t   * ss  = (threadgroup s_t   *) (shared + SF*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix

    threadgroup half    * skv  = (threadgroup half    *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K and V in shared memory
    threadgroup half4x4 * skv4 = (threadgroup half4x4 *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in half4x4
@@ -2840,7 +2850,7 @@ kernel void kernel_flash_attn_ext(
    // zero out shared memory SH
    for (short j = 0; j < Q; ++j) {
        for (short i = tiisg; i < SH; i += NW) {
            ss[j*TF + i] = 0.0h;
            ss[j*TS + i] = 0.0f;
        }
    }

@@ -2905,7 +2915,7 @@ kernel void kernel_flash_attn_ext(
            // Q*K^T
            {
                for (short cc = 0; cc < C/8; ++cc) {
                    simdgroup_half8x8 mqk = make_filled_simdgroup_matrix<half, 8>(0.h);
                    s8x8_t mqk = make_filled_simdgroup_matrix<s_t, 8>(0.0f);

                    // this is compile-time check, so it does not have runtime overhead
                    if (is_same<block_q, half4x4>::value) {
@@ -2962,7 +2972,7 @@ kernel void kernel_flash_attn_ext(
                        }
                    }

                    simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
                    simdgroup_store(mqk, ss + 8*cc, TS, 0, false);
                }
            }

@@ -2977,7 +2987,7 @@ kernel void kernel_flash_attn_ext(
                    const float m = M[j];

                    // scale and apply the logitcap / mask
                    float s = ((float)(ss[j*TF + tiisg]))*scale;
                    float s = ((float)(ss[j*TS + tiisg]))*scale;

                    if (logit_softcap != 0.0f) {
                        s = logit_softcap*precise::tanh(s);
@@ -2997,12 +3007,12 @@ kernel void kernel_flash_attn_ext(
                    S[j] = S[j]*ms[j] + simd_sum(vs);

                    // the P matrix from the paper (Q rows, C columns)
                    ss[j*TF + tiisg] = vs;
                    ss[j*TS + tiisg] = vs;
                }

                // create a QxQ diagonal matrix for rescaling the output
                if (tiisg < Q) {
                    ss[tiisg*TF + C + tiisg] = ms[tiisg];
                    ss[tiisg*TS + C + tiisg] = ms[tiisg];
                }
            }

@@ -3013,8 +3023,8 @@ kernel void kernel_flash_attn_ext(

            // O = diag(ms)*O
            {
                simdgroup_half8x8 mm;
                simdgroup_load(mm, ss + C, TF, 0, false);
                s8x8_t mm;
                simdgroup_load(mm, ss + C, TS, 0, false);

                for (short i = 0; i < D8; ++i) {
                    simdgroup_multiply(lo[i], mm, lo[i]);
@@ -3024,8 +3034,8 @@ kernel void kernel_flash_attn_ext(
            // O = O + (Q*K^T)*V
            {
                for (short cc = 0; cc < C/8; ++cc) {
                    simdgroup_half8x8 ms;
                    simdgroup_load(ms, ss + 8*cc, TF, 0, false);
                    s8x8_t ms;
                    simdgroup_load(ms, ss + 8*cc, TS, 0, false);

                    if (is_same<block_q, half4x4>::value) {
                        // we can read directly from global memory
@@ -3087,8 +3097,8 @@ kernel void kernel_flash_attn_ext(
        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
        for (short j = 0; j < Q; ++j) {
            if (tiisg == 0) {
                ss[j*TF + 0] = S[j];
                ss[j*TF + 1] = M[j];
                ss[j*TS + 0] = S[j];
                ss[j*TS + 1] = M[j];
            }
        }
    }
@@ -3112,11 +3122,11 @@ kernel void kernel_flash_attn_ext(
        // the first simdgroup accumulates the results from the other simdgroups
        if (sgitg == 0) {
            for (short j = 0; j < Q; ++j) {
                const float S0 = ss[j*TF +         0];
                const float S1 = ss[j*TF + sg*SH + 0];
                const float S0 = ss[j*TS +         0];
                const float S1 = ss[j*TS + sg*SH + 0];

                const float M0 = ss[j*TF +         1];
                const float M1 = ss[j*TF + sg*SH + 1];
                const float M0 = ss[j*TS +         1];
                const float M1 = ss[j*TS + sg*SH + 1];

                M = max(M0, M1);

@@ -3126,22 +3136,23 @@ kernel void kernel_flash_attn_ext(
                S = S0*ms0 + S1*ms1;

                if (tiisg == 0) {
                    ss[j*TF + 0] = S;
                    ss[j*TF + 1] = M;
                    ss[j*TS + 0] = S;
                    ss[j*TS + 1] = M;

                    ss[j*TF + C + j        ] = ms0;
                    ss[j*TF + C + j + sg*SH] = ms1;
                    ss[j*TS + C + j        ] = ms0;
                    ss[j*TS + C + j + sg*SH] = ms1;
                }
            }

            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
            {
                simdgroup_half8x8 t;
                simdgroup_half8x8 ms0;
                simdgroup_half8x8 ms1;

                simdgroup_load(ms0, ss + C,         TF, 0, false);
                simdgroup_load(ms1, ss + C + sg*SH, TF, 0, false);
                s8x8_t ms0;
                s8x8_t ms1;

                simdgroup_load(ms0, ss + C,         TS, 0, false);
                simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false);

                for (short i = 0; i < D8; ++i) {
                    simdgroup_load    (t, sq + i*8, T, 0, false);
@@ -3165,7 +3176,7 @@ kernel void kernel_flash_attn_ext(
    // final rescale with 1/S and store to global memory
    if (sgitg == 0) {
        for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
            const float S = ss[j*TF + 0];
            const float S = ss[j*TS + 0];

            for (short i = tiisg; i < D4; i += NW) {
                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
@@ -3174,49 +3185,57 @@ kernel void kernel_flash_attn_ext(
    }
}

typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>) flash_attn_ext_t;
#if defined(GGML_METAL_FORCE_FATTN_PREC_F16)
#define S_T    half
#define S8x8_T simdgroup_half8x8
#else
#define S_T    float
#define S8x8_T simdgroup_float8x8
#endif

template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 96>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 112>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 128>;
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4, 1, dequantize_f16, 256>;
typedef decltype(kernel_flash_attn_ext<half4x4, 1, dequantize_f16, S_T, S8x8_T, 64>) flash_attn_ext_t;

template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 96>;
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 112>;
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, 256>;
template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4,     1, dequantize_f16,  S_T, S8x8_T, 64>;
template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4,     1, dequantize_f16,  S_T, S8x8_T, 80>;
template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4,     1, dequantize_f16,  S_T, S8x8_T, 96>;
template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4,     1, dequantize_f16,  S_T, S8x8_T, 112>;
template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4,     1, dequantize_f16,  S_T, S8x8_T, 128>;
template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<half4x4,     1, dequantize_f16,  S_T, S8x8_T, 256>;

template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 96>;
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 112>;
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, 256>;
template [[host_name("kernel_flash_attn_ext_q4_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 64>;
template [[host_name("kernel_flash_attn_ext_q4_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 80>;
template [[host_name("kernel_flash_attn_ext_q4_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 96>;
template [[host_name("kernel_flash_attn_ext_q4_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 112>;
template [[host_name("kernel_flash_attn_ext_q4_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 128>;
template [[host_name("kernel_flash_attn_ext_q4_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_0, 2, dequantize_q4_0, S_T, S8x8_T, 256>;

template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 96>;
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 112>;
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, 256>;
template [[host_name("kernel_flash_attn_ext_q4_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 64>;
template [[host_name("kernel_flash_attn_ext_q4_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 80>;
template [[host_name("kernel_flash_attn_ext_q4_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 96>;
template [[host_name("kernel_flash_attn_ext_q4_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 112>;
template [[host_name("kernel_flash_attn_ext_q4_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 128>;
template [[host_name("kernel_flash_attn_ext_q4_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q4_1, 2, dequantize_q4_1, S_T, S8x8_T, 256>;

template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 96>;
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 112>;
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, 256>;
template [[host_name("kernel_flash_attn_ext_q5_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 64>;
template [[host_name("kernel_flash_attn_ext_q5_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 80>;
template [[host_name("kernel_flash_attn_ext_q5_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 96>;
template [[host_name("kernel_flash_attn_ext_q5_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 112>;
template [[host_name("kernel_flash_attn_ext_q5_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 128>;
template [[host_name("kernel_flash_attn_ext_q5_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_0, 2, dequantize_q5_0, S_T, S8x8_T, 256>;

template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 96>;
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 112>;
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, 256>;
template [[host_name("kernel_flash_attn_ext_q5_1_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 64>;
template [[host_name("kernel_flash_attn_ext_q5_1_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 80>;
template [[host_name("kernel_flash_attn_ext_q5_1_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 96>;
template [[host_name("kernel_flash_attn_ext_q5_1_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 112>;
template [[host_name("kernel_flash_attn_ext_q5_1_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 128>;
template [[host_name("kernel_flash_attn_ext_q5_1_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q5_1, 2, dequantize_q5_1, S_T, S8x8_T, 256>;

template [[host_name("kernel_flash_attn_ext_q8_0_h64" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 64>;
template [[host_name("kernel_flash_attn_ext_q8_0_h80" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 80>;
template [[host_name("kernel_flash_attn_ext_q8_0_h96" )]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 96>;
template [[host_name("kernel_flash_attn_ext_q8_0_h112")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 112>;
template [[host_name("kernel_flash_attn_ext_q8_0_h128")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 128>;
template [[host_name("kernel_flash_attn_ext_q8_0_h256")]] kernel flash_attn_ext_t kernel_flash_attn_ext<block_q8_0, 2, dequantize_q8_0, S_T, S8x8_T, 256>;

// NOTE: can use half instead of float precision for some extra perf
//       however, by default use F32 since the op should be mostly memory bandwidth bound

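For a rough sense of why the accumulator type matters, the kernel's per-query threadgroup scratch size T = D + SF*nsg*SH depends on SF = sizeof(s_t)/sizeof(half), so F32 accumulators double the attention scratch relative to F16. Below is a minimal sketch in plain C (not part of the commit; D, Q, C and nsg are assumed example values, with nsg the number of simdgroups per threadgroup):

    // Illustrative sketch: how the flash-attention threadgroup scratch scales
    // with the accumulator type s_t. sizeof_s_t is 2 for half (F16) and 4 for float (F32).
    #include <stdio.h>

    static int shared_halfs_per_query(int D, int Q, int C, int nsg, int sizeof_s_t) {
        const int SH = C + Q;          // scratch per simdgroup, in s_t elements
        const int SF = sizeof_s_t / 2; // sizeof(s_t)/sizeof(half): 1 for half, 2 for float
        return D + SF*nsg*SH;          // T: shared memory per query, in half units
    }

    int main(void) {
        const int D = 128, Q = 8, C = 32, nsg = 4; // assumed example configuration
        printf("F16 accumulators: %d halfs per query\n", shared_halfs_per_query(D, Q, C, nsg, 2)); // 288
        printf("F32 accumulators: %d halfs per query\n", shared_halfs_per_query(D, Q, C, nsg, 4)); // 448
        return 0;
    }

On the host side, the nhalfs = prec == GGML_PREC_DEFAULT ? 1 : 2 factor added in ggml-metal.m appears to mirror the same SF scaling when sizing the kernel's threadgroup memory.
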
Georgi Gerganov