HIP: use v_dot2_f32_f16 instruction for FA (#15884)

This commit is contained in:
Johannes Gäßler
2025-09-09 14:04:43 +02:00
committed by GitHub
parent ed54e32558
commit 17bc5a815f
2 changed files with 26 additions and 6 deletions

View File

@@ -545,6 +545,31 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
#endif // defined(GGML_USE_HIP)
}
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {
acc += v*u;
}
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) {
acc += v.x*u.x;
acc += v.y*u.y;
}
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
#if defined(GGML_USE_HIP) && defined(GCN)
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
#else
#ifdef FAST_FP16_AVAILABLE
const float2 tmp = __half22float2(v*u);
acc += tmp.x + tmp.y;
#else
const float2 tmpv = __half22float2(v);
const float2 tmpu = __half22float2(u);
acc += tmpv.x * tmpu.x;
acc += tmpv.y * tmpu.y;
#endif // FAST_FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(GCN)
}
static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#if CUDART_VERSION >= 12080
const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);

View File

@@ -304,12 +304,7 @@ static __global__ void flash_attn_tile(
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
#ifdef FAST_FP16_AVAILABLE
const float2 tmp = __half22float2(K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps]);
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += tmp.x + tmp.y;
#else
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps];
#endif // FAST_FP16_AVAILABLE
ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
}
}
}