vulkan: Support mul_mat_id with f32 accumulators (#15337)

* vulkan: Add missing bounds checking to scalar/coopmat1 mul_mat_id

* vulkan: Support mul_mat_id with f32 accumulators, but they are not hooked up

- There's no explicit way to request f32 precision for mul_mat_id, but there
probably should be, and this gets the code in place for that.
- A couple fixes to check_results.
- Remove casts to fp16 in coopmat1 FA shader (found by inspection).
This commit is contained in:
Jeff Bolz
2025-08-16 04:18:31 -05:00
committed by GitHub
parent 2e2b22ba66
commit de2192794f
2 changed files with 79 additions and 89 deletions

View File

@@ -210,7 +210,7 @@ void main() {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
}
}
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
@@ -233,7 +233,7 @@ void main() {
vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]);
#endif
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] += float16_t(Pf[r]) * ACC_TYPEV4(Vf);
Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf);
}
}
}
@@ -288,7 +288,7 @@ void main() {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
Of[r][d] = float16_t(eMf[r]) * Of[r][d];
Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d];
tmpshv4[tid] = Of[r][d];
barrier();
@@ -357,7 +357,7 @@ void main() {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= float16_t(Lfrcp[r]);
Of[r][d] *= ACC_TYPE(Lfrcp[r]);
}
}