mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
cont : simplify
This commit is contained in:
@@ -2007,11 +2007,20 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||||||
|
|
||||||
GGML_ASSERT(ne01 < 65536);
|
GGML_ASSERT(ne01 < 65536);
|
||||||
|
|
||||||
|
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
|
||||||
|
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
|
||||||
|
ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]);
|
||||||
|
ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0;
|
||||||
|
ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0;
|
||||||
|
|
||||||
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
|
||||||
|
|
||||||
ggml_metal_buffer_id bid_pad = bid_dst;
|
ggml_metal_buffer_id bid_pad = bid_dst;
|
||||||
bid_pad.offs += ggml_nbytes(op);
|
bid_pad.offs += ggml_nbytes(op);
|
||||||
|
|
||||||
|
ggml_metal_buffer_id bid_tmp = bid_pad;
|
||||||
|
bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_pad(op);
|
||||||
|
|
||||||
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
if (!ggml_metal_op_flash_attn_ext_use_vec(op)) {
|
||||||
// half8x8 kernel
|
// half8x8 kernel
|
||||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||||
@@ -2048,14 +2057,10 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
|
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 2);
|
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
|
||||||
if (op->src[3]) {
|
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 3);
|
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
||||||
} else {
|
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
|
|
||||||
}
|
|
||||||
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
|
||||||
|
|
||||||
assert(ne12 == ne22);
|
assert(ne12 == ne22);
|
||||||
assert(ne13 == ne23);
|
assert(ne13 == ne23);
|
||||||
@@ -2137,21 +2142,13 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
|
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
|
||||||
if (op->src[3]) {
|
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
|
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
|
||||||
} else {
|
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
|
ggml_metal_encoder_set_buffer (enc, bid_dst, 7);
|
||||||
}
|
|
||||||
if (op->src[4]) {
|
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
|
|
||||||
} else {
|
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
|
|
||||||
}
|
|
||||||
ggml_metal_encoder_set_buffer (enc, bid_pad, 6);
|
|
||||||
ggml_metal_encoder_set_buffer (enc, bid_dst, 7);
|
|
||||||
|
|
||||||
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
|
||||||
|
|
||||||
@@ -2194,14 +2191,10 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
ggml_metal_encoder_set_pipeline(enc, pipeline0);
|
||||||
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
|
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 2);
|
ggml_metal_encoder_set_buffer (enc, bid_src2, 2);
|
||||||
if (op->src[3]) {
|
ggml_metal_encoder_set_buffer (enc, bid_src3, 3);
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 3);
|
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
||||||
} else {
|
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 3);
|
|
||||||
}
|
|
||||||
ggml_metal_encoder_set_buffer (enc, bid_pad, 4);
|
|
||||||
|
|
||||||
assert(ne12 == ne22);
|
assert(ne12 == ne22);
|
||||||
assert(ne13 == ne23);
|
assert(ne13 == ne23);
|
||||||
@@ -2300,19 +2293,11 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||||||
|
|
||||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
|
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
|
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
|
||||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3);
|
ggml_metal_encoder_set_buffer (enc, bid_src2, 3);
|
||||||
if (op->src[3]) {
|
ggml_metal_encoder_set_buffer (enc, bid_src3, 4);
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4);
|
ggml_metal_encoder_set_buffer (enc, bid_src4, 5);
|
||||||
} else {
|
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4);
|
|
||||||
}
|
|
||||||
if (op->src[4]) {
|
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5);
|
|
||||||
} else {
|
|
||||||
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5);
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t smem = FATTN_SMEM(nsg);
|
const size_t smem = FATTN_SMEM(nsg);
|
||||||
|
|
||||||
@@ -2320,6 +2305,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||||||
GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
|
GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size);
|
||||||
|
|
||||||
if (nwg == 1) {
|
if (nwg == 1) {
|
||||||
|
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0);
|
||||||
|
|
||||||
// using 1 workgroup -> write the result directly into dst
|
// using 1 workgroup -> write the result directly into dst
|
||||||
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
|
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
|
||||||
ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
|
ggml_metal_encoder_set_buffer(enc, bid_dst, 7);
|
||||||
@@ -2329,13 +2316,12 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) {
|
|||||||
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1);
|
||||||
} else {
|
} else {
|
||||||
// sanity checks
|
// sanity checks
|
||||||
|
assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0);
|
||||||
|
|
||||||
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
|
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
|
||||||
GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
|
GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31));
|
||||||
|
|
||||||
// write the results from each workgroup into a temp buffer
|
// write the results from each workgroup into a temp buffer
|
||||||
ggml_metal_buffer_id bid_tmp = bid_dst;
|
|
||||||
bid_tmp.offs += ggml_nbytes(op) + ggml_metal_op_flash_attn_ext_extra_pad(op);
|
|
||||||
|
|
||||||
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
|
ggml_metal_encoder_set_buffer(enc, bid_pad, 6);
|
||||||
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
|
ggml_metal_encoder_set_buffer(enc, bid_tmp, 7);
|
||||||
|
|
||||||
|
|||||||
@@ -4591,7 +4591,6 @@ void kernel_flash_attn_ext_impl(
|
|||||||
|
|
||||||
// mask storage in shared mem
|
// mask storage in shared mem
|
||||||
threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C);
|
threadgroup half2 * sm2 = (threadgroup half2 *) (shmem_f16 + Q*T + 2*C);
|
||||||
threadgroup half * sm = (threadgroup half *) (sm2);
|
|
||||||
|
|
||||||
// per-query mask pointers
|
// per-query mask pointers
|
||||||
device const half2 * pm2[NQ];
|
device const half2 * pm2[NQ];
|
||||||
@@ -4676,6 +4675,8 @@ void kernel_flash_attn_ext_impl(
|
|||||||
v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
|
v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C;
|
||||||
|
|
||||||
if (!FC_flash_attn_ext_has_mask) {
|
if (!FC_flash_attn_ext_has_mask) {
|
||||||
|
threadgroup half * sm = (threadgroup half *) (sm2);
|
||||||
|
|
||||||
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
FOR_UNROLL (short jj = 0; jj < NQ; ++jj) {
|
||||||
const short j = jj*NSG + sgitg;
|
const short j = jj*NSG + sgitg;
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user