diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 17a1a8c30a..b6ccc8b9da 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2007,11 +2007,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ne01 < 65536); + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]); + ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0; + ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0; + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); ggml_metal_buffer_id bid_pad = bid_dst; bid_pad.offs += ggml_nbytes(op); + ggml_metal_buffer_id bid_tmp = bid_pad; + bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op); + if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { // half8x8 kernel const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! @@ -2048,14 +2057,10 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 2); - if (op->src[3]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 3); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3); - } - ggml_metal_encoder_set_buffer (enc, bid_pad, 4); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_src2, 2); + ggml_metal_encoder_set_buffer (enc, bid_src3, 3); + ggml_metal_encoder_set_buffer (enc, bid_pad, 4); assert(ne12 == ne22); assert(ne13 == ne23); @@ -2137,21 +2142,13 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); - if (op->src[3]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4); - } - if (op->src[4]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5); - } - ggml_metal_encoder_set_buffer (enc, bid_pad, 6); - ggml_metal_encoder_set_buffer (enc, bid_dst, 7); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_src2, 3); + ggml_metal_encoder_set_buffer (enc, bid_src3, 4); + ggml_metal_encoder_set_buffer (enc, bid_src4, 5); + ggml_metal_encoder_set_buffer (enc, bid_pad, 6); + ggml_metal_encoder_set_buffer (enc, bid_dst, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); @@ -2194,14 +2191,10 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_pipeline(enc, pipeline0); ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 2); - if (op->src[3]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 3); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3); - } - ggml_metal_encoder_set_buffer (enc, bid_pad, 4); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_src2, 2); + ggml_metal_encoder_set_buffer (enc, bid_src3, 3); + ggml_metal_encoder_set_buffer (enc, bid_pad, 4); assert(ne12 == ne22); assert(ne13 == ne23); @@ -2300,19 +2293,11 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); - if (op->src[3]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4); - } - if (op->src[4]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5); - } + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_src2, 3); + ggml_metal_encoder_set_buffer (enc, bid_src3, 4); + ggml_metal_encoder_set_buffer (enc, bid_src4, 5); const size_t smem = FATTN_SMEM(nsg); @@ -2320,6 +2305,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size); if (nwg == 1) { + assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0); + // using 1 workgroup -> write the result directly into dst ggml_metal_encoder_set_buffer(enc, bid_pad, 6); ggml_metal_encoder_set_buffer(enc, bid_dst, 7); @@ -2329,13 +2316,12 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); } else { // sanity checks + assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0); + GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31)); // write the results from each workgroup into a temp buffer - ggml_metal_buffer_id bid_tmp = bid_dst; - bid_tmp.offs += ggml_nbytes(op) + ggml_metal_op_flash_attn_ext_extra_pad(op); - ggml_metal_encoder_set_buffer(enc, bid_pad, 6); ggml_metal_encoder_set_buffer(enc, bid_tmp, 7); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0deb53d298..d636f7d776 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4591,7 +4591,6 @@ void kernel_flash_attn_ext_impl( // mask storage in shared mem threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C); - threadgroup half * sm = (threadgroup half *) (sm2); // per-query mask pointers device const half2 * pm2[NQ]; @@ -4676,6 +4675,8 @@ void kernel_flash_attn_ext_impl( v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; if (!FC_flash_attn_ext_has_mask) { + threadgroup half * sm = (threadgroup half *) (sm2); + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { const short j = jj*NSG + sgitg;