diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 03cc2716a7..537a7bc4a8 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -12,13 +12,9 @@ __embed_ggml-common.h__ #define GGML_METAL_USE_METAL4 #ifdef GGML_METAL_USE_METAL4 -#include #include #include - -using namespace metal; -using namespace mpp::tensor_ops; #endif using namespace metal; @@ -1754,7 +1750,7 @@ kernel void kernel_op_sum_f32( float sumf = 0; - for (int64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) { + for (uint64_t i0 = tpitg.x; i0 < args.np; i0 += ntg.x) { sumf += src0[i0]; } @@ -5457,6 +5453,7 @@ template [[host_name("kernel_flash_attn_ext_q8_0_dk576_dv512")]] kernel flash_at #undef FA_TYPES #undef FA_TYPES_BF +#undef FA_TYPES_F32 constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_EXT_VEC + 0)]]; constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; @@ -6078,6 +6075,7 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk576_dv512")]] kernel flas template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk576_dv512")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #undef FA_TYPES +#undef FA_TYPES_F32 constant int32_t FC_flash_attn_ext_vec_reduce_DV [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 0)]]; constant int32_t FC_flash_attn_ext_vec_reduce_NWG [[function_constant(FC_FLASH_ATTN_EXT_VEC_REDUCE + 1)]]; @@ -8211,9 +8209,9 @@ kernel void kernel_mul_mm( auto tA = tensor, tensor_inline>(sa, dextents(NK, NR0)); auto tB = tensor, tensor_inline>(sb, dextents(NR1, NK )); - constexpr auto desc = matmul2d_descriptor(NR1, NR0, NK, false, true, false, matmul2d_descriptor::mode::multiply_accumulate); + constexpr auto desc = mpp::tensor_ops::matmul2d_descriptor(NR1, NR0, NK, false, true, false, mpp::tensor_ops::matmul2d_descriptor::mode::multiply_accumulate); - matmul2d> mm; + mpp::tensor_ops::matmul2d> mm; auto cT = mm.get_destination_cooperative_tensor(); #endif @@ -8359,16 +8357,28 @@ kernel void kernel_mul_mm( } } - for (short i = 0; i < 8; ++i) { + if (FC_mul_mm_bc_inp) { + for (short i = 0; i < 8; ++i) { + const short sx = (tiitg%NL1); + const short sy = (tiitg/NL1)/8; + + const short lx = i; + const short ly = (tiitg/NL1)%8; + //const short lx = (tiitg/NL1)%8; + //const short ly = i; + + *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; + } + } else { const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; - const short lx = i; + //const short lx = i; const short ly = (tiitg/NL1)%8; //const short lx = (tiitg/NL1)%8; //const short ly = i; - *(sb + NK*(8*sy + ly) + 8*sx + lx) = loop_k + iy + i < args.ne00 ? (S1) *((device T1 *) y + i) : 0; + *(threadgroup S1_2x4 *)(sb + NK*(8*sy + ly) + 8*sx) = (S1_2x4)(*((device T1_2x4 *) y)); } il = (il + 2 < nl) ? il + 2 : il % 2;