cuda : implement bf16 cpy ops and enable bf16 cont (#14763)

* implement bf16 cpy ops and enable bf16 cont

* deduplicate copy functions

* deduplicate checks
This commit is contained in:
Sigbjørn Skjæret
2025-07-22 12:33:10 +02:00
committed by GitHub
parent 8e6f8bc875
commit e28c0b80c2
4 changed files with 49 additions and 124 deletions

View File

@@ -4,24 +4,8 @@
// Function-pointer type: takes a read-only byte pointer `src` and a writable
// byte pointer `dst`, returns nothing.
// NOTE(review): presumably the per-element conversion callback used by the
// set_rows kernels — confirm against the kernel that takes this type.
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
// Generic set_rows_1 for an arbitrary (src_t, dst_t) pair.
// This primary template performs no conversion: both pointers are only marked
// as unused and nothing is written through dst_f.
// NOTE(review): looks like a fallback for type pairs that have no dedicated
// specialization — confirm that callers never rely on it producing output.
template<typename src_t, typename dst_t>
__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
GGML_UNUSED(src_f);
GGML_UNUSED(dst_f);
}
// Specialization of set_rows_1 for a float source and a half destination.
// Forwards both pointers unchanged to convert_f32_f16 (defined elsewhere).
// NOTE(review): presumably converts a single f32 value to f16 — confirm
// against the helper's definition.
template<>
__device__ __forceinline__ void set_rows_1<float, half>(const float * src_f, half * dst_h) {
convert_f32_f16(src_f, dst_h);
}
// Specialization of set_rows_1 for a float source and an nv_bfloat16 destination.
// Forwards both pointers unchanged to convert_f32_bf16 (defined elsewhere).
// NOTE(review): presumably converts a single f32 value to bf16 — confirm
// against the helper's definition.
template<>
__device__ __forceinline__ void set_rows_1<float, nv_bfloat16>(const float * src_f, nv_bfloat16 * dst_b) {
convert_f32_bf16(src_f, dst_b);
}
template<>
__device__ __forceinline__ void set_rows_1<float, float>(const float * src_f, float * dst_f) {
convert_f32_f32(src_f, dst_f);
__device__ __forceinline__ void set_rows_1(const src_t * src_f, dst_t * dst_f) {
convert_flt(src_f, dst_f);
}
// Generic quantized set_rows kernel template