mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	CUDA: fix MMQ nwarps for AMD with warp_size==32 (#15014)
This commit is contained in:
		@@ -251,25 +251,21 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)
 | 
			
		||||
#endif // AMD_MFMA_AVAILABLE
 | 
			
		||||
 | 
			
		||||
#if defined(GGML_USE_HIP)
 | 
			
		||||
static int mmq_get_nwarps_host(const int cc) {
 | 
			
		||||
    return amd_mfma_available(cc) ? 8 : 4;
 | 
			
		||||
static int mmq_get_nwarps_host(const int cc, const int warp_size) {
 | 
			
		||||
    return amd_mfma_available(cc) ? 8 : 256/warp_size;
 | 
			
		||||
}
 | 
			
		||||
#else
 | 
			
		||||
static int mmq_get_nwarps_host(const int /*cc*/) {
 | 
			
		||||
    return 8;
 | 
			
		||||
static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
 | 
			
		||||
    return 256/warp_size;
 | 
			
		||||
}
 | 
			
		||||
#endif // (GGML_USE_HIP)
 | 
			
		||||
 | 
			
		||||
static constexpr __device__ int mmq_get_nwarps_device() {
 | 
			
		||||
#if defined(GGML_USE_HIP)
 | 
			
		||||
#if defined(AMD_MFMA_AVAILABLE)
 | 
			
		||||
    return 8;
 | 
			
		||||
#else
 | 
			
		||||
    return 4;
 | 
			
		||||
    return 256/ggml_cuda_get_physical_warp_size();
 | 
			
		||||
#endif // AMD_MFMA_AVAILABLE
 | 
			
		||||
#else
 | 
			
		||||
    return 8;
 | 
			
		||||
#endif // defined(GGML_USE_HIP)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ------------------------------------------------------------
 | 
			
		||||
@@ -3472,7 +3468,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
 | 
			
		||||
    const int cc = ggml_cuda_info().devices[id].cc;
 | 
			
		||||
    const int nsm = ggml_cuda_info().devices[id].nsm;
 | 
			
		||||
    const int warp_size = ggml_cuda_info().devices[id].warp_size;
 | 
			
		||||
    const int nwarps = mmq_get_nwarps_host(cc);
 | 
			
		||||
    const int nwarps = mmq_get_nwarps_host(cc, warp_size);
 | 
			
		||||
    const int mmq_y = get_mmq_y_host(cc);
 | 
			
		||||
 | 
			
		||||
    const dim3 block_dims(warp_size, nwarps, 1);
 | 
			
		||||
@@ -3559,7 +3555,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
 | 
			
		||||
    const int    cc     = ggml_cuda_info().devices[id].cc;
 | 
			
		||||
    const size_t smpbo  = ggml_cuda_info().devices[id].smpbo;
 | 
			
		||||
    const int warp_size = ggml_cuda_info().devices[id].warp_size;
 | 
			
		||||
    const int nwarps    = mmq_get_nwarps_host(cc);
 | 
			
		||||
    const int nwarps    = mmq_get_nwarps_host(cc, warp_size);
 | 
			
		||||
 | 
			
		||||
    const int mmq_x_max = get_mmq_x_max_host(cc);
 | 
			
		||||
    const int mmq_y = get_mmq_y_host(cc);
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user