Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
@@ -12,13 +12,9 @@ __embed_ggml-common.h__
 #define GGML_METAL_USE_METAL4
 
 #ifdef GGML_METAL_USE_METAL4
-#include <metal_stdlib>
 #include <metal_tensor>
 
 #include <MetalPerformancePrimitives/MetalPerformancePrimitives.h>
-
-using namespace metal;
-using namespace mpp::tensor_ops;
 #endif
 
 using namespace metal;
@@ -1754,7 +1750,7 @@ kernel void kernel_op_sum_f32(
 
     float sumf = 0;
 
-    for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
+    for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) {
         sumf += src0[i0];
     }
 
@@ -5457,6 +5453,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_at
 
 #undef FA_TYPES
 #undef FA_TYPES_BF
+#undef FA_TYPES_F32
 
 constant bool FC_flash_attn_ext_vec_has_mask  [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]];
 constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]];
@@ -6078,6 +6075,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flas
 template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES,     block_q8_0, 8, dequantize_q8_0_t4, block_q8_0,  8, dequantize_q8_0_t4, 576, 512, 2>;
 
 #undef FA_TYPES
+#undef FA_TYPES_F32
 
 constant int32_t FC_flash_attn_ext_vec_reduce_DV  [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]];
 constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]];
@@ -8211,9 +8209,9 @@ kernel void kernel_mul_mm(
     auto tA = tensor<threadgroup S0,    dextents<int32_t, 2>, tensor_inline>(sa, dextents<int32_t, 2>(NK,  NR0));
     auto tB = tensor<threadgroup S1,    dextents<int32_t, 2>, tensor_inline>(sb, dextents<int32_t, 2>(NR1, NK ));
 
-    constexpr auto desc = matmul2d_descriptor(NR1, NR0, NK, false, true, false, matmul2d_descriptor::mode::multiply_accumulate);
+    constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate);
 
-    matmul2d<desc, execution_simdgroups<4>> mm;
+    mpp::tensor_ops::matmul2d<desc, execution_simdgroups<4>> mm;
 
     auto cT = mm.get_destination_cooperative_tensor<decltype(tA), decltype(tB), float>();
 #endif
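The hunk above pairs with the first one: once `using namespace mpp::tensor_ops;` is dropped from the header block, the `matmul2d_descriptor` and `matmul2d` uses inside kernel_mul_mm have to be written with their full namespace. A minimal plain C++ illustration of that name-lookup effect, using a stand-in declaration rather than the Metal Performance Primitives API:

// Stand-in namespace and type, only to show the lookup change.
namespace mpp { namespace tensor_ops {
    struct matmul2d_descriptor { int m, n, k; };
} }

int main() {
    // With no `using namespace mpp::tensor_ops;` directive in scope,
    // only the fully qualified spelling resolves, as in the hunk above.
    mpp::tensor_ops::matmul2d_descriptor desc{64, 32, 16};
    return desc.m == 64 ? 0 : 1;
}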
@@ -8359,6 +8357,7 @@ kernel void kernel_mul_mm(
             }
         }
 
+        if (FC_mul_mm_bc_inp) {
             for (short i = 0; i < 8; ++i) {
                 const short sx = (tiitg%NL1);
                 const short sy = (tiitg/NL1)/8;
@@ -8370,6 +8369,17 @@ kernel void kernel_mul_mm(
 
                 *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0;
             }
+        } else {
+            const short sx = (tiitg%NL1);
+            const short sy = (tiitg/NL1)/8;
+
+            //const short lx = i;
+            const short ly = (tiitg/NL1)%8;
+            //const short lx = (tiitg/NL1)%8;
+            //const short ly = i;
+
+            *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y));
+        }
 
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x  = (il < 2) ? x + (2 + nl - 1)/nl : x;
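The last two hunks gate the existing bounds-checked staging loop behind FC_mul_mm_bc_inp and add an unchecked fast path that moves a whole 8-element chunk with one vectorized S1_2x4 store instead of eight conditional scalar stores. Below is a small host-side C++ sketch of that pattern; the types and helpers are illustrative stand-ins for the kernel's T1/S1 templates, and a plain memcpy stands in for the converting vector store.

#include <array>
#include <cstdio>
#include <cstring>

using src_t = float;   // stand-in for the device-side input type T1
using dst_t = float;   // stand-in for the threadgroup staging type S1

// Checked path (the `if (FC_mul_mm_bc_inp)` branch): convert one element at
// a time and zero-pad once the global column index passes ne00.
static void copy_row_checked(dst_t *tile, const src_t *y, int base, int ne00) {
    for (int i = 0; i < 8; ++i) {
        tile[i] = (base + i < ne00) ? (dst_t) y[i] : (dst_t) 0;
    }
}

// Unchecked path (the new `else` branch): the chunk is known to be in range,
// so it is moved with a single wide copy instead of eight conditionals.
static void copy_row_vectorized(dst_t *tile, const src_t *y) {
    std::memcpy(tile, y, 8 * sizeof(src_t));
}

int main() {
    std::array<src_t, 8> y  = {1, 2, 3, 4, 5, 6, 7, 8};
    std::array<dst_t, 8> sb = {};

    copy_row_checked(sb.data(), y.data(), /*base=*/4, /*ne00=*/6); // zero-pads the tail
    for (dst_t v : sb) std::printf("%g ", v);
    std::printf("\n");

    copy_row_vectorized(sb.data(), y.data()); // full, unconditional chunk copy
    for (dst_t v : sb) std::printf("%g ", v);
    std::printf("\n");
    return 0;
}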