CUDA: refactor and deduplicate vector FA kernels (#16208)

* CUDA: refactor and deduplicate vector FA kernels
This commit is contained in:
Johannes Gäßler
2025-09-27 18:45:07 +02:00
committed by GitHub
parent 0499b29c6f
commit 75a3a6c2cd
129 changed files with 1396 additions and 1989 deletions

View File

@@ -586,17 +586,42 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
}
static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {
#ifdef FAST_FP16_AVAILABLE
acc += v*u;
#else
const float2 tmpv = __half22float2(v);
const float2 tmpu = __half22float2(u);
float2 tmpacc = __half22float2(acc);
tmpacc.x += tmpv.x * tmpu.x;
tmpacc.y += tmpv.y * tmpu.y;
acc = make_half2(tmpacc.x, tmpacc.y);
#endif // FAST_FP16_AVAILABLE
}
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
template <int nbytes>
template <int nbytes, int alignment = 0>
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
if constexpr (nbytes == 4) {
*(int *) dst = *(const int *) src;
} else if constexpr (nbytes == 8) {
*(int2 *) dst = *(const int2 *) src;
} else if constexpr (nbytes == 16) {
*(int4 *) dst = *(const int4 *) src;
} else {
static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
if constexpr (alignment != 0) {
static_assert(nbytes % alignment == 0, "bad alignment");
}
constexpr int nb_per_cpy = alignment == 0 ? nbytes : alignment;
#pragma unroll
for (int i = 0; i < nbytes/nb_per_cpy; ++i) {
if constexpr (nb_per_cpy == 1) {
((char *) dst)[i] = ((const char *) src)[i];
} else if constexpr (nb_per_cpy == 2) {
((short *) dst)[i] = ((const short *) src)[i];
} else if constexpr (nb_per_cpy == 4) {
((int *) dst)[i] = ((const int *) src)[i];
} else if constexpr (nb_per_cpy == 8) {
((int2 *) dst)[i] = ((const int2 *) src)[i];
} else if constexpr (nb_per_cpy == 16) {
((int4 *) dst)[i] = ((const int4 *) src)[i];
} else {
static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
}
}
}