mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
vulkan: Handle FA with all -inf mask values (#16447)
This commit is contained in:
@@ -345,7 +345,7 @@ void main() {
|
||||
|
||||
float Lfrcp[Br];
|
||||
[[unroll]] for (uint32_t r = 0; r < Br; ++r) {
|
||||
Lfrcp[r] = 1.0 / Lf[r];
|
||||
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
|
||||
@@ -380,7 +380,7 @@ void main() {
|
||||
|
||||
float Lfrcp[rows_per_thread];
|
||||
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
|
||||
Lfrcp[r] = 1.0 / Lf[r];
|
||||
Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
|
||||
}
|
||||
|
||||
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
|
||||
|
||||
@@ -121,7 +121,11 @@ void main() {
|
||||
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
|
||||
|
||||
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
|
||||
#if defined(ACC_TYPE_MAX)
|
||||
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-ACC_TYPE_MAX / ACC_TYPE(2));
|
||||
#else
|
||||
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
|
||||
#endif
|
||||
|
||||
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
|
||||
|
||||
@@ -294,7 +298,7 @@ void main() {
|
||||
|
||||
[[unroll]]
|
||||
for (int k = 0; k < Ldiag.length(); ++k) {
|
||||
Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
|
||||
Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
|
||||
}
|
||||
|
||||
O = Ldiag*O;
|
||||
|
||||
@@ -91,7 +91,7 @@ void main() {
|
||||
L = L*ms + vs;
|
||||
}
|
||||
|
||||
L = 1.0 / L;
|
||||
L = (L == 0.0) ? 0.0 : 1.0 / L;
|
||||
|
||||
// D dimension is split across workgroups in the y dimension
|
||||
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;
|
||||
|
||||
Reference in New Issue
Block a user