mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-20 12:07:33 +00:00
ggml : support bcast ggml_soft_max_ext, ggml_flash_attn_ext (#14435)
ggml-ci
This commit is contained in:
@@ -229,7 +229,9 @@ typedef struct {
|
||||
uint64_t nb21;
|
||||
uint64_t nb22;
|
||||
uint64_t nb23;
|
||||
int32_t ne32;
|
||||
uint64_t nb31;
|
||||
uint64_t nb32;
|
||||
int32_t ne1;
|
||||
int32_t ne2;
|
||||
float scale;
|
||||
@@ -461,9 +463,21 @@ typedef struct {
|
||||
} ggml_metal_kargs_sum_rows;
|
||||
|
||||
typedef struct {
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int32_t ne00;
|
||||
int32_t ne01;
|
||||
int32_t ne02;
|
||||
uint64_t nb01;
|
||||
uint64_t nb02;
|
||||
uint64_t nb03;
|
||||
int32_t ne11;
|
||||
int32_t ne12;
|
||||
int32_t ne13;
|
||||
uint64_t nb11;
|
||||
uint64_t nb12;
|
||||
uint64_t nb13;
|
||||
uint64_t nb1;
|
||||
uint64_t nb2;
|
||||
uint64_t nb3;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
|
||||
@@ -1725,7 +1725,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_MEAN:
|
||||
case GGML_OP_SOFT_MAX:
|
||||
case GGML_OP_GROUP_NORM:
|
||||
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
|
||||
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
|
||||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_L2_NORM:
|
||||
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
|
||||
@@ -2644,10 +2644,7 @@ static bool ggml_metal_encode_node(
|
||||
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(scale));
|
||||
memcpy(&max_bias, ((const int32_t *) dst->op_params) + 1, sizeof(max_bias));
|
||||
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src0->ne[1];
|
||||
|
||||
const uint32_t n_head = nrows_x/nrows_y;
|
||||
const uint32_t n_head = src0->ne[2];
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
@@ -2707,6 +2704,18 @@ static bool ggml_metal_encode_node(
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.ne11 =*/ ne11,
|
||||
/*.ne12 =*/ ne12,
|
||||
/*.ne13 =*/ ne13,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.nb1 =*/ nb1,
|
||||
/*.nb2 =*/ nb2,
|
||||
/*.nb3 =*/ nb3,
|
||||
/*.scale =*/ scale,
|
||||
/*.max_bias =*/ max_bias,
|
||||
/*.m0 =*/ m0,
|
||||
@@ -2726,7 +2735,7 @@ static bool ggml_metal_encode_node(
|
||||
|
||||
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
} break;
|
||||
case GGML_OP_DIAG_MASK_INF:
|
||||
{
|
||||
@@ -4979,7 +4988,9 @@ static bool ggml_metal_encode_node(
|
||||
/*.nb21 =*/ nb21,
|
||||
/*.nb22 =*/ nb22,
|
||||
/*.nb23 =*/ nb23,
|
||||
/*.ne32 =*/ ne32,
|
||||
/*.nb31 =*/ nb31,
|
||||
/*.nb32 =*/ nb32,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.ne2 =*/ ne2,
|
||||
/*.scale =*/ scale,
|
||||
|
||||
@@ -1320,24 +1320,28 @@ kernel void kernel_soft_max(
|
||||
device char * dst,
|
||||
constant ggml_metal_kargs_soft_max & args,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
|
||||
const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
|
||||
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
|
||||
uint3 tptg[[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
const int32_t i01 = tgpig.x;
|
||||
|
||||
device const float * psrc0 = (device const float *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00 : nullptr;
|
||||
device float * pdst = (device float *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00);
|
||||
const int32_t i13 = i03%args.ne13;
|
||||
const int32_t i12 = i02%args.ne12;
|
||||
const int32_t i11 = i01;
|
||||
|
||||
device const float * psrc0 = (device const float *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
||||
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
|
||||
device float * pdst = (device float *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
// ALiBi
|
||||
if (args.max_bias > 0.0f) {
|
||||
const int64_t h = i02;
|
||||
const int32_t h = i02;
|
||||
|
||||
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
||||
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
||||
@@ -1348,13 +1352,13 @@ kernel void kernel_soft_max(
|
||||
// parallel max
|
||||
float lmax = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
||||
lmax = MAX(lmax, psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f));
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
float max_val = simd_max(lmax);
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (tptg.x > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = -INFINITY;
|
||||
}
|
||||
@@ -1373,7 +1377,7 @@ kernel void kernel_soft_max(
|
||||
|
||||
// parallel sum
|
||||
float lsum = 0.0f;
|
||||
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
||||
const float exp_psrc0 = exp((psrc0[i00]*args.scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
|
||||
lsum += exp_psrc0;
|
||||
pdst[i00] = exp_psrc0;
|
||||
@@ -1385,7 +1389,7 @@ kernel void kernel_soft_max(
|
||||
|
||||
float sum = simd_sum(lsum);
|
||||
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (tptg.x > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
@@ -1404,7 +1408,7 @@ kernel void kernel_soft_max(
|
||||
|
||||
const float inv_sum = 1.0f/sum;
|
||||
|
||||
for (int i00 = tpitg; i00 < args.ne00; i00 += ntg) {
|
||||
for (int i00 = tpitg.x; i00 < args.ne00; i00 += tptg.x) {
|
||||
pdst[i00] *= inv_sum;
|
||||
}
|
||||
}
|
||||
@@ -1416,23 +1420,27 @@ kernel void kernel_soft_max_4(
|
||||
device char * dst,
|
||||
constant ggml_metal_kargs_soft_max & args,
|
||||
threadgroup float * buf [[threadgroup(0)]],
|
||||
uint tgpig[[threadgroup_position_in_grid]],
|
||||
uint tpitg[[thread_position_in_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint3 tpitg[[thread_position_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint ntg[[threads_per_threadgroup]]) {
|
||||
const int64_t i03 = (tgpig) / (args.ne02*args.ne01);
|
||||
const int64_t i02 = (tgpig - i03*args.ne02*args.ne01) / args.ne01;
|
||||
const int64_t i01 = (tgpig - i03*args.ne02*args.ne01 - i02*args.ne01);
|
||||
uint3 tptg[[threads_per_threadgroup]]) {
|
||||
const int32_t i03 = tgpig.z;
|
||||
const int32_t i02 = tgpig.y;
|
||||
const int32_t i01 = tgpig.x;
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *) src0 + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
|
||||
device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*args.ne00/4 : nullptr;
|
||||
device float4 * pdst4 = (device float4 *) dst + (i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00)/4;
|
||||
const int32_t i13 = i03%args.ne13;
|
||||
const int32_t i12 = i02%args.ne12;
|
||||
const int32_t i11 = i01;
|
||||
|
||||
device const float4 * psrc4 = (device const float4 *) (src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
||||
device const T * pmask = src1 != src0 ? (device const T * ) (src1 + i11*args.nb11 + i12*args.nb12 + i13*args.nb13) : nullptr;
|
||||
device float4 * pdst4 = (device float4 *) (dst + i01*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
if (args.max_bias > 0.0f) {
|
||||
const int64_t h = i02;
|
||||
const int32_t h = i02;
|
||||
|
||||
const float base = h < args.n_head_log2 ? args.m0 : args.m1;
|
||||
const int exp = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
|
||||
@@ -1443,14 +1451,14 @@ kernel void kernel_soft_max_4(
|
||||
// parallel max
|
||||
float4 lmax4 = -INFINITY;
|
||||
|
||||
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
|
||||
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
||||
lmax4 = fmax(lmax4, psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
|
||||
}
|
||||
|
||||
const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
|
||||
|
||||
float max_val = simd_max(lmax);
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (tptg.x > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = -INFINITY;
|
||||
}
|
||||
@@ -1469,7 +1477,7 @@ kernel void kernel_soft_max_4(
|
||||
|
||||
// parallel sum
|
||||
float4 lsum4 = 0.0f;
|
||||
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
|
||||
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
||||
const float4 exp_psrc4 = exp((psrc4[i00]*args.scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
|
||||
lsum4 += exp_psrc4;
|
||||
pdst4[i00] = exp_psrc4;
|
||||
@@ -1483,7 +1491,7 @@ kernel void kernel_soft_max_4(
|
||||
|
||||
float sum = simd_sum(lsum);
|
||||
|
||||
if (ntg > N_SIMDWIDTH) {
|
||||
if (tptg.x > N_SIMDWIDTH) {
|
||||
if (sgitg == 0) {
|
||||
buf[tiisg] = 0.0f;
|
||||
}
|
||||
@@ -1502,7 +1510,7 @@ kernel void kernel_soft_max_4(
|
||||
|
||||
const float inv_sum = 1.0f/sum;
|
||||
|
||||
for (int i00 = tpitg; i00 < args.ne00/4; i00 += ntg) {
|
||||
for (int i00 = tpitg.x; i00 < args.ne00/4; i00 += tptg.x) {
|
||||
pdst4[i00] *= inv_sum;
|
||||
}
|
||||
}
|
||||
@@ -3776,7 +3784,7 @@ kernel void kernel_flash_attn_ext(
|
||||
// load the mask in shared memory
|
||||
#pragma unroll(Q)
|
||||
for (short j = 0; j < Q; ++j) {
|
||||
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
|
||||
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
|
||||
|
||||
const float m = pm[ic + tiisg];
|
||||
|
||||
@@ -4262,7 +4270,7 @@ kernel void kernel_flash_attn_ext_vec(
|
||||
const bool has_mask = mask != q;
|
||||
|
||||
// pointer to the mask
|
||||
device const half * pm = (device const half *) (mask + iq1*args.nb31);
|
||||
device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
|
||||
|
||||
float slope = 1.0f;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user