Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-03 09:22:01 +00:00)
	HIP: Enable Matrix cores for MMQ Kernels, Enable stream-K for CDNA 3 (#14624)
This commit adds support for MFMA instructions to MMQ. The MFMA-enabled code path added here supports CDNA1/GFX908, CDNA2/GFX90a, and CDNA3/GFX942. For now, this code path and stream-K are only enabled on CDNA3, since on the other devices the new path fails to outperform BLAS in all cases; BLAS is consistently outperformed only on CDNA3, due to issues in the AMD-provided BLAS libraries. This commit also makes MMQ aware of different warp sizes, which as a side effect improves performance on GCN GPUs for all quant formats except q4_0 and q4_1, which regress slightly.
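The warp-size awareness mentioned above matters because NVIDIA warps are 32 lanes wide while GCN/CDNA wavefronts are 64. As a minimal sketch of the pattern (illustrative only, not code from this diff; the real MMQ changes are more involved), a reduction can be parameterized on the logical warp width instead of hard-coding 32:

// Sketch: warp-width-parameterized reduction in the spirit of the
// warp-size awareness described above. Under HIP, ggml maps
// __shfl_xor_sync onto the corresponding AMD shuffle intrinsic.
template <int width>
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
    for (int offset = width/2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset, width);
    }
    return x;
}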
@@ -56,7 +56,7 @@
 #define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 0x803)  // Tonga, Fiji, Polaris, minimum for fast fp16
 #define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 0x900)  // Vega56/64, minimum for fp16 dual issue
 #define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 0x906)  // MI50/Radeon VII, minimum for dp4a
-#define GGML_CUDA_CC_CDNA       (GGML_CUDA_CC_OFFSET_AMD + 0x908)  // MI100, minimum for MFMA, acc registers
+#define GGML_CUDA_CC_CDNA1      (GGML_CUDA_CC_OFFSET_AMD + 0x908)  // MI100, minimum for MFMA, acc registers
 #define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 0x910)  // MI210, minimum acc register renameing
 #define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 0x942)  // MI300
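For orientation (not part of the diff): these AMD compute-capability values embed the GFX architecture id in hexadecimal after a vendor offset, so the generation can be recovered by subtracting the offset. A tiny self-contained illustration, with the offset value assumed for the example:

#include <cstdio>

#define GGML_CUDA_CC_OFFSET_AMD 0x1000000  // assumed value, for illustration only
#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908)
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942)

int main() {
    // Stripping the vendor offset recovers the GFX id of the generation.
    printf("gfx%x\n", GGML_CUDA_CC_CDNA1 - GGML_CUDA_CC_OFFSET_AMD); // gfx908 (MI100)
    printf("gfx%x\n", GGML_CUDA_CC_CDNA3 - GGML_CUDA_CC_OFFSET_AMD); // gfx942 (MI300)
    return 0;
}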
@@ -72,8 +72,9 @@
 #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
 #define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
 #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
-#define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
-#define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_GCN(cc)   (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
+#define GGML_CUDA_CC_IS_CDNA(cc)  (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
+#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)
 
 // Moore Threads
 #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
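A small host-side sketch of how these range macros compose (hypothetical helper, not from this commit). Note that GGML_CUDA_CC_IS_CDNA3(cc) implies GGML_CUDA_CC_IS_CDNA(cc), since 0x942 lies inside the CDNA range, so the more specific check must come first:

// Hypothetical classification helper built on the macros above.
static const char * amd_family_name(const int cc) {
    if (GGML_CUDA_CC_IS_CDNA3(cc)) return "CDNA3";        // MI300 class
    if (GGML_CUDA_CC_IS_CDNA(cc))  return "CDNA1/CDNA2";  // MI100/MI210 class
    if (GGML_CUDA_CC_IS_GCN(cc))   return "GCN";
    return "other/non-AMD";
}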
@@ -226,6 +227,10 @@ typedef float2 dfloat2;
 #define FP16_MMA_AVAILABLE
 #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))
 
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)
+#define AMD_MFMA_AVAILABLE
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)
+
 #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define NEW_MMA_AVAILABLE
 #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
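A sketch of how device code can branch on the new guard (the function name here is hypothetical; the real MMQ kernels are more elaborate). Because AMD_MFMA_AVAILABLE is only defined when HIP compiles device code for CDNA3, the MFMA path compiles out entirely on every other target:

// Hypothetical device helper gated on the new macro.
static __device__ __forceinline__ void mmq_accumulate_tile(/* tile args */) {
#if defined(AMD_MFMA_AVAILABLE)
    // CDNA3: issue MFMA matrix-core instructions for the tile product.
#else
    // All other targets: existing dp4a/warp-level accumulation path.
#endif
}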
@@ -288,6 +293,11 @@ static bool fp32_mma_hardware_available(const int cc) {
     return GGML_CUDA_CC_IS_CDNA(cc);
 }
 
+// AMD CDNA3 matrix cores.. Will add support for other CDNA generations later.
+static bool amd_mfma_available(const int cc) {
+    return cc >= GGML_CUDA_CC_OFFSET_AMD && GGML_CUDA_CC_IS_CDNA3(cc);
+}
+
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
 static bool new_mma_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
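On the host side, amd_mfma_available() can then gate kernel selection at dispatch time. In outline, the CDNA3-only enablement described in the commit message would look like the following hedged sketch (launch_mmq and the commented kernel choices are placeholders, not the actual dispatch code):

// Hypothetical host-side dispatch built on amd_mfma_available().
void launch_mmq(const int cc /*, ... */) {
    if (amd_mfma_available(cc)) {
        // CDNA3 for now: MFMA-enabled MMQ kernel with stream-K decomposition.
    } else {
        // Other devices: existing MMQ path, or BLAS where it is still faster.
    }
}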