mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal : fix build and some more comments (#10229)
This commit is contained in:
		| @@ -3041,6 +3041,8 @@ static void ggml_metal_encode_node( | |||||||
|  |  | ||||||
|                 bool use_vec_kernel = false; |                 bool use_vec_kernel = false; | ||||||
|  |  | ||||||
|  |                 // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) | ||||||
|  |                 //       for now avoiding mainly to keep the number of templates/kernels a bit lower | ||||||
|                 if (ne01 >= 4 || (ne00%128 != 0)) { |                 if (ne01 >= 4 || (ne00%128 != 0)) { | ||||||
|                     switch (src1->type) { |                     switch (src1->type) { | ||||||
|                         case GGML_TYPE_F16: |                         case GGML_TYPE_F16: | ||||||
|   | |||||||
| @@ -3356,7 +3356,7 @@ kernel void kernel_flash_attn_ext_vec( | |||||||
|     const short D4  = D/4; |     const short D4  = D/4; | ||||||
|     const short D16 = D/16; |     const short D16 = D/16; | ||||||
|     const short NW  = N_SIMDWIDTH; |     const short NW  = N_SIMDWIDTH; | ||||||
|     const short NL  = NW/4; |     const short NL  = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0 | ||||||
|     const short SH  = 2*C;  // shared memory per simdgroup |     const short SH  = 2*C;  // shared memory per simdgroup | ||||||
|  |  | ||||||
|     const short T = D + nsg*SH; // shared memory size per query in (half) |     const short T = D + nsg*SH; // shared memory size per query in (half) | ||||||
| @@ -3448,7 +3448,7 @@ kernel void kernel_flash_attn_ext_vec( | |||||||
|  |  | ||||||
|             // Q*K^T |             // Q*K^T | ||||||
|             { |             { | ||||||
|                 // each simdgroup processes 1 query and 4 keys |                 // each simdgroup processes 1 query and 4 (NW/NL) keys | ||||||
|                 for (short cc = 0; cc < C/4; ++cc) { |                 for (short cc = 0; cc < C/4; ++cc) { | ||||||
|                     qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 }; |                     qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 }; | ||||||
|  |  | ||||||
| @@ -3646,7 +3646,7 @@ kernel void kernel_flash_attn_ext_vec( | |||||||
|     half,  half4,  half4x4, \ |     half,  half4,  half4x4, \ | ||||||
|                    half4x4 |                    half4x4 | ||||||
|  |  | ||||||
| typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t; | typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t; | ||||||
|  |  | ||||||
| template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,     1, dequantize_f16,  128>; | template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,     1, dequantize_f16,  128>; | ||||||
| #if defined(GGML_METAL_USE_BF16) | #if defined(GGML_METAL_USE_BF16) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov