mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
metal : optimize FA vec for large sequences and BS <= 8 (#15566)
* metal : optmize FA vec for large heads and sequences * metal : adjust small-batch mul mv kernels ggml-ci * batched-bench : fix total speed computation ggml-ci * cont : add comments ggml-ci
This commit is contained in:
@@ -249,6 +249,7 @@ typedef struct {
|
||||
uint64_t nb33;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
int32_t ne3;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
@@ -257,6 +258,11 @@ typedef struct {
|
||||
float logit_softcap;
|
||||
} ggml_metal_kargs_flash_attn_ext;
|
||||
|
||||
typedef struct {
|
||||
int32_t nrows;
|
||||
int32_t ne20;
|
||||
} ggml_metal_kargs_flash_attn_ext_reduce;
|
||||
|
||||
typedef struct {
|
||||
int32_t ne00;
|
||||
int32_t ne02;
|
||||
|
||||
@@ -291,6 +291,10 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
|
||||
@@ -575,6 +579,7 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE,
|
||||
GGML_METAL_KERNEL_TYPE_SET_I32,
|
||||
GGML_METAL_KERNEL_TYPE_SET_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||
@@ -1324,6 +1329,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_MXFP4_F32, mul_mv_mxfp4_f32, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2, mul_mv_ext_f32_f32_r1_2, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3, mul_mv_ext_f32_f32_r1_3, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4, mul_mv_ext_f32_f32_r1_4, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5, mul_mv_ext_f32_f32_r1_5, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
|
||||
@@ -1609,6 +1618,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_HK576_HV512, flash_attn_ext_vec_q5_0_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_HK576_HV512, flash_attn_ext_vec_q5_1_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_HK576_HV512, flash_attn_ext_vec_q8_0_hk576_hv512, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE, flash_attn_ext_reduce, has_simdgroup_reduction);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32, set_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32, set_i32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||
@@ -3385,15 +3395,16 @@ static int ggml_metal_encode_node(
|
||||
|
||||
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
||||
// to the matrix-vector kernel
|
||||
const int ne11_mm_min = 4;
|
||||
const int ne11_mm_min = 8;
|
||||
|
||||
// first try to use small-batch mat-mv kernels
|
||||
// these should be efficient for BS [2, ~8]
|
||||
if (src1t == GGML_TYPE_F32 && (ne00%256 == 0) &&
|
||||
if (src1t == GGML_TYPE_F32 && (ne00%128 == 0) &&
|
||||
(
|
||||
(
|
||||
(
|
||||
src0t == GGML_TYPE_F16 || // TODO: helper function
|
||||
src0t == GGML_TYPE_F32 || // TODO: helper function
|
||||
src0t == GGML_TYPE_F16 ||
|
||||
src0t == GGML_TYPE_Q4_0 ||
|
||||
src0t == GGML_TYPE_Q4_1 ||
|
||||
src0t == GGML_TYPE_Q5_0 ||
|
||||
@@ -3421,7 +3432,17 @@ static int ggml_metal_encode_node(
|
||||
// values and there can be some tail effects when nsg is high. need to confirm this
|
||||
//
|
||||
const int nsg = 2; // num simdgroups per threadgroup
|
||||
const int nxpsg = ne11 < 3 ? 16 : 8; // num threads along row per simdgroup
|
||||
|
||||
// num threads along row per simdgroup
|
||||
int nxpsg = 0;
|
||||
if (ne00 % 256 == 0 && ne11 < 3) {
|
||||
nxpsg = 16;
|
||||
} else if (ne00 % 128 == 0) {
|
||||
nxpsg = 8;
|
||||
} else {
|
||||
nxpsg = 4;
|
||||
}
|
||||
|
||||
const int nypsg = 32/nxpsg; // num threads along col per simdgroup (i.e. a simdgroup processes that many src0 rows at a time)
|
||||
const int r0ptg = nypsg*nsg; // num src0 rows per threadgroup
|
||||
int r1ptg = 4; // num src1 rows per threadgroup
|
||||
@@ -3444,6 +3465,14 @@ static int ggml_metal_encode_node(
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
switch (r1ptg) {
|
||||
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_2].pipeline; break;
|
||||
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_3].pipeline; break;
|
||||
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_4].pipeline; break;
|
||||
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F32_F32_R1_5].pipeline; break;
|
||||
default: GGML_ABORT("not implemented");
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
switch (r1ptg) {
|
||||
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
|
||||
@@ -3598,7 +3627,7 @@ static int ggml_metal_encode_node(
|
||||
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32 ].pipeline; break;
|
||||
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
|
||||
case GGML_TYPE_MXFP4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_MXFP4_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q2_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q3_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_K: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32 ].pipeline; break;
|
||||
@@ -5482,6 +5511,7 @@ static int ggml_metal_encode_node(
|
||||
/*.nb33 =*/ nb33,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.ne3 =*/ ne3,
|
||||
/*.scale =*/ scale,
|
||||
/*.max_bias =*/ max_bias,
|
||||
/*.m0 =*/ m0,
|
||||
@@ -5505,7 +5535,6 @@ static int ggml_metal_encode_node(
|
||||
} else {
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:5];
|
||||
}
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
||||
|
||||
if (!use_vec_kernel) {
|
||||
// half8x8 kernel
|
||||
@@ -5531,7 +5560,7 @@ static int ggml_metal_encode_node(
|
||||
|
||||
while (true) {
|
||||
const size_t smem = FATTN_SMEM(nsgmax);
|
||||
if (smem > device.maxThreadgroupMemoryLength) {
|
||||
if (smem > device.maxThreadgroupMemoryLength/2) {
|
||||
break;
|
||||
}
|
||||
nsgmax *= 2;
|
||||
@@ -5543,15 +5572,18 @@ static int ggml_metal_encode_node(
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
||||
|
||||
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
#undef FATTN_SMEM
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
#undef FATTN_SMEM
|
||||
} else {
|
||||
// half4x4 kernel
|
||||
const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !!
|
||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||
const int64_t nkpsg = 1*ncpsg; // TODO: make adjustable
|
||||
|
||||
GGML_ASSERT(nqptg <= 32);
|
||||
GGML_ASSERT(nqptg % 1 == 0);
|
||||
@@ -5561,15 +5593,17 @@ static int ggml_metal_encode_node(
|
||||
// for each query, we load it as f16 in shared memory (ne00)
|
||||
// and store the soft_max values and the mask
|
||||
//
|
||||
// ne00*(nsg)
|
||||
// ne20*(nsg)
|
||||
// each simdgroup has a full f32 head vector in shared mem to accumulate results
|
||||
//
|
||||
#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)) + 2*ne20*(nsg))*(sizeof(float)/2), 16))
|
||||
//#define FATTN_SMEM(nsg) (GGML_PAD((nqptg*(GGML_PAD(ne00, 128) + 4*ncpsg*(nsg)))*(sizeof(float)/2), 16))
|
||||
|
||||
int64_t nsgmax = 2;
|
||||
while (true) {
|
||||
const size_t smem = FATTN_SMEM(nsgmax);
|
||||
if (smem > device.maxThreadgroupMemoryLength) {
|
||||
// avoid using more than half of the threadgroup memory - can cause slow downs especially for large head sizes
|
||||
if (smem > device.maxThreadgroupMemoryLength/2) {
|
||||
break;
|
||||
}
|
||||
nsgmax *= 2;
|
||||
@@ -5577,7 +5611,7 @@ static int ggml_metal_encode_node(
|
||||
nsgmax /= 2;
|
||||
|
||||
// simdgroups per threadgroup (a.k.a. warps)
|
||||
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
|
||||
const int64_t nsgt = MAX(2, MIN(nsgmax, MIN((ne11 + nkpsg - 1)/(nkpsg), (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)));
|
||||
|
||||
int64_t nsg = 1;
|
||||
while (nsg <= nsgt) {
|
||||
@@ -5585,13 +5619,74 @@ static int ggml_metal_encode_node(
|
||||
}
|
||||
nsg /= 2;
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
// workgroups
|
||||
// each workgroup handles nsg*nkpsg cache values
|
||||
uint16_t nwg = 1;
|
||||
if (4*nsg*nkpsg >= ne11) {
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
//printf("smem: %zu, max: %zu, nsg = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg);
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
//printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
|
||||
// using 1 workgroup -> write the result directly into dst
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:6];
|
||||
[encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
} else {
|
||||
nwg = 32;
|
||||
nsg = MIN(4, nsg);
|
||||
|
||||
const size_t smem = FATTN_SMEM(nsg);
|
||||
|
||||
//printf("smem: %zu, max: %zu, nsg = %d, nsgmax = %d\n", smem, device.maxThreadgroupMemoryLength, (int) nsg, (int) nsgmax);
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
|
||||
// sanity checks
|
||||
GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3);
|
||||
GGML_ASSERT(ne1*ne2*ne3 <= (1u << 31));
|
||||
|
||||
const int32_t nrows = ne1*ne2*ne3;
|
||||
|
||||
// temp buffer for writing the results from each workgroup
|
||||
// - ne20: the size of the head vector
|
||||
// - + 2: the S and M values for each intermediate result
|
||||
const size_t s_tmp = ggml_type_size(GGML_TYPE_F32)*(nrows*nwg*(ne20 + 2));
|
||||
id<MTLBuffer> h_tmp = ggml_metal_mem_pool_alloc(mem_pool, s_tmp);
|
||||
if (!h_tmp) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_tmp);
|
||||
return 0;
|
||||
}
|
||||
|
||||
//printf("ne01 = %d, ne02 = %d, ne03 = %d, ne20 = %d\n", ne01, ne02, ne03, ne20);
|
||||
//printf("needed memory: %.3f MiB\n", (float) (ne01*ne02*ne03*ne20*sizeof(float))/1024.0f/1024.0f);
|
||||
|
||||
[encoder setBuffer:h_tmp offset:0 atIndex:6];
|
||||
[encoder setBytes:&nwg length:sizeof(uint16_t) atIndex:7];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
|
||||
// reduce the results from the workgroups
|
||||
{
|
||||
ggml_metal_kargs_flash_attn_ext_reduce args0 = {
|
||||
nrows,
|
||||
ne20,
|
||||
};
|
||||
|
||||
id<MTLComputePipelineState> pipeline0 = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_REDUCE].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline0];
|
||||
[encoder setBytes:&args0 length:sizeof(args0) atIndex:0];
|
||||
[encoder setBuffer:h_tmp offset:0 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
|
||||
//printf("ne1 = %d, ne2 = %d, ne3 = %d, ne20 = %d\n", ne1, ne2, ne3, ne20);
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(32*32, 1, 1)];
|
||||
}
|
||||
}
|
||||
#undef FATTN_SMEM
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
|
||||
@@ -68,6 +68,11 @@ void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg)
|
||||
reg = (type4x4)(*src);
|
||||
}
|
||||
|
||||
template <typename type4>
|
||||
void dequantize_f32_t4(device const float4 * src, short il, thread type4 & reg) {
|
||||
reg = (type4)(*src);
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
||||
reg = (type4x4)(*src);
|
||||
@@ -3015,7 +3020,6 @@ void kernel_mul_mv_ext_q4_f32_impl(
|
||||
#pragma unroll(r1ptg)
|
||||
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
|
||||
sumf[ir1] += dot(lx[ch], y4[ir1][ch*nxpsg]);
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3200,6 +3204,11 @@ kernel void kernel_mul_mv_ext_q4x4_f32_disp(
|
||||
typedef decltype(kernel_mul_mv_ext_q4_f32_disp <2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q4_f32_t;
|
||||
typedef decltype(kernel_mul_mv_ext_q4x4_f32_disp<2, block_q4_K, 256, dequantize_q4_K>) mul_mv_ext_q4x4_f32_t;
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, float4, 4, dequantize_f32_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, float4, 4, dequantize_f32_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, float4, 4, dequantize_f32_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_f32_f32_r1_5")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<5, float4, 4, dequantize_f32_t4>;
|
||||
|
||||
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<2, half4, 4, dequantize_f16_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<3, half4, 4, dequantize_f16_t4>;
|
||||
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q4_f32_t kernel_mul_mv_ext_q4_f32_disp<4, half4, 4, dequantize_f16_t4>;
|
||||
@@ -4786,14 +4795,16 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
device const char * mask,
|
||||
device const char * sinks,
|
||||
device char * dst,
|
||||
constant uint16_t & nwg,
|
||||
threadgroup half * shmem_f16 [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 ntg[[threads_per_threadgroup]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const short nsg = ntg.y; // number of simdgroups
|
||||
const short iwg = tgpig[2]%nwg;
|
||||
|
||||
const int iq3 = tgpig[2];
|
||||
const int iq3 = tgpig[2]/nwg;
|
||||
const int iq2 = tgpig[1];
|
||||
const int iq1 = tgpig[0];
|
||||
|
||||
@@ -4872,7 +4883,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
|
||||
// loop over the KV cache
|
||||
// each simdgroup handles blocks of Q rows and C columns
|
||||
for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
|
||||
for (int ic0 = (int) iwg*C*nsg; ic0 < args.ne11; ic0 += (int) nwg*C*nsg) {
|
||||
const int ic = ic0 + C*sgitg;
|
||||
if (ic >= args.ne11) {
|
||||
break;
|
||||
@@ -5002,7 +5013,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
}
|
||||
}
|
||||
|
||||
if (sinks != q && sgitg == 0) {
|
||||
if (sinks != q && sgitg == 0 && iwg == 0) {
|
||||
const float m = M;
|
||||
const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2;
|
||||
|
||||
@@ -5111,14 +5122,25 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
device float4 * dst4 = (device float4 *) dst;
|
||||
|
||||
// final rescale with 1/S and store to global memory
|
||||
if (sgitg == 0) {
|
||||
const float S = ss[0];
|
||||
const int64_t nrows = args.ne3*args.ne2*args.ne1;
|
||||
const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1;
|
||||
|
||||
device float4 * dst4 = (device float4 *) dst;
|
||||
device float * dst1 = (device float *) dst + nrows*DV*nwg; // the S and M are stored after the results
|
||||
|
||||
const float S = nwg == 1 ? 1.0f/ss[0] : 1.0f;
|
||||
|
||||
// interleave the workgroup data
|
||||
for (short i = tiisg; i < DV4; i += NW) {
|
||||
dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*DV4 + i] = (float4) sr4[i]/S;
|
||||
dst4[rid*DV4*nwg + nwg*i + iwg] = (float4) sr4[i]*S;
|
||||
}
|
||||
|
||||
// store S and M
|
||||
if (nwg > 1 && tiisg == 0) {
|
||||
dst1[rid*(2*nwg) + 2*iwg + 0] = ss[0];
|
||||
dst1[rid*(2*nwg) + 2*iwg + 1] = ss[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -5218,6 +5240,41 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_hk576_hv512")]] kernel flas
|
||||
|
||||
#undef FA_TYPES
|
||||
|
||||
kernel void kernel_flash_attn_ext_reduce(
|
||||
constant ggml_metal_kargs_flash_attn_ext_reduce & args,
|
||||
device const char * htmp,
|
||||
device char * dst,
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
ushort tiisg[[thread_index_in_simdgroup]],
|
||||
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
const uint64_t rid = tgpig;
|
||||
|
||||
const short nwg = 32;
|
||||
const short iwg = tiisg;
|
||||
const short DV = args.ne20;
|
||||
const short DV4 = DV/4;
|
||||
|
||||
device const float4 * htmp4 = (device const float4 *) htmp + rid*DV4*nwg;
|
||||
device const float * ss = (device const float *) htmp + (uint64_t)args.nrows*DV*nwg;
|
||||
device float4 * dst4 = (device float4 *) dst + rid*DV4;
|
||||
|
||||
float S = ss[rid*(2*nwg) + 2*iwg + 0];
|
||||
float M = ss[rid*(2*nwg) + 2*iwg + 1];
|
||||
|
||||
const float m = simd_max(M);
|
||||
const float ms = exp(M - m);
|
||||
|
||||
S = 1.0f/simd_sum(S*ms);
|
||||
|
||||
for (int i = sgitg; i < DV4; i += nwg) {
|
||||
const float4 v = simd_sum(htmp4[i*nwg + iwg]*ms);
|
||||
|
||||
if (iwg == 0) {
|
||||
dst4[i] = v*S;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_set(
|
||||
constant ggml_metal_kargs_set & args,
|
||||
|
||||
Reference in New Issue
Block a user