vulkan: Handle FA with all -inf mask values (#16447)

This commit is contained in:
Jeff Bolz
2025-10-20 22:16:08 -05:00
committed by GitHub
parent 6de8ed7519
commit fb349848f3
4 changed files with 8 additions and 4 deletions

View File

@@ -345,7 +345,7 @@ void main() {
float Lfrcp[Br]; float Lfrcp[Br];
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Lfrcp[r] = 1.0 / Lf[r]; Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
} }
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {

View File

@@ -380,7 +380,7 @@ void main() {
float Lfrcp[rows_per_thread]; float Lfrcp[rows_per_thread];
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Lfrcp[r] = 1.0 / Lf[r]; Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]);
} }
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {

View File

@@ -121,7 +121,11 @@ void main() {
const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);
L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0); L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
#if defined(ACC_TYPE_MAX)
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(-ACC_TYPE_MAX / ACC_TYPE(2));
#else
M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2); M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);
#endif
coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0); coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);
@@ -294,7 +298,7 @@ void main() {
[[unroll]] [[unroll]]
for (int k = 0; k < Ldiag.length(); ++k) { for (int k = 0; k < Ldiag.length(); ++k) {
Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; Ldiag[k] = (Ldiag[k] == 0.0) ? ACC_TYPE(0.0) : (ACC_TYPE(1.0) / Ldiag[k]);
} }
O = Ldiag*O; O = Ldiag*O;

View File

@@ -91,7 +91,7 @@ void main() {
L = L*ms + vs; L = L*ms + vs;
} }
L = 1.0 / L; L = (L == 0.0) ? 0.0 : 1.0 / L;
// D dimension is split across workgroups in the y dimension // D dimension is split across workgroups in the y dimension
uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE; uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE;