mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-01 09:01:57 +00:00
metal : simplify soft max kernel
ggml-ci
This commit is contained in:
@@ -203,9 +203,9 @@ kernel void kernel_soft_max(
|
|||||||
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
float lmax = (tpitg < ne00) ? (psrc0[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY;
|
float lmax = -INFINITY;
|
||||||
|
|
||||||
for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
|
||||||
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -284,9 +284,9 @@ kernel void kernel_soft_max_4(
|
|||||||
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
|
||||||
|
|
||||||
// parallel max
|
// parallel max
|
||||||
float4 lmax4 = tpitg < ne00/4 ? (psrc4[tpitg]*scale + (pmask ? pmask[tpitg] : 0.0f)) : -INFINITY;
|
float4 lmax4 = -INFINITY;
|
||||||
|
|
||||||
for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
|
for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
|
||||||
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user