CUDA: add BF16 support (#11093)

* CUDA: add BF16 support
This commit is contained in:
Johannes Gäßler
2025-01-06 02:33:52 +01:00
committed by GitHub
parent b56f079e28
commit 46e3556e01
6 changed files with 87 additions and 39 deletions

View File

@@ -3,6 +3,7 @@
#include <musa_runtime.h>
#include <musa.h>
#include <mublas.h>
#include <musa_bf16.h>
#include <musa_fp16.h>
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
@@ -132,3 +133,5 @@
#define cudaKernelNodeParams musaKernelNodeParams
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
#define cudaStreamEndCapture musaStreamEndCapture
typedef mt_bfloat16 nv_bfloat16;