mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	CUDA: more info when no device code (#5088)
This commit is contained in:
		
							
								
								
									
										89
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							
							
						
						
									
										89
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							@@ -13,6 +13,10 @@
 | 
				
			|||||||
#include <map>
 | 
					#include <map>
 | 
				
			||||||
#include <array>
 | 
					#include <array>
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					// stringize macro for converting __CUDA_ARCH_LIST__ (list of integers) to string
 | 
				
			||||||
 | 
					#define STRINGIZE_IMPL(...) #__VA_ARGS__
 | 
				
			||||||
 | 
					#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#if defined(GGML_USE_HIPBLAS)
 | 
					#if defined(GGML_USE_HIPBLAS)
 | 
				
			||||||
#include <hip/hip_runtime.h>
 | 
					#include <hip/hip_runtime.h>
 | 
				
			||||||
#include <hipblas/hipblas.h>
 | 
					#include <hipblas/hipblas.h>
 | 
				
			||||||
@@ -584,13 +588,28 @@ static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0,
 | 
				
			|||||||
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 | 
					static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
[[noreturn]]
 | 
					[[noreturn]]
 | 
				
			||||||
static __device__ void bad_arch() {
 | 
					static __device__ void no_device_code(
 | 
				
			||||||
    printf("ERROR: ggml-cuda was compiled without support for the current GPU architecture.\n");
 | 
					    const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 | 
				
			||||||
 | 
					    printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
 | 
				
			||||||
 | 
					           file_name, line, function_name, arch);
 | 
				
			||||||
 | 
					    (void) arch_list;
 | 
				
			||||||
 | 
					#else
 | 
				
			||||||
 | 
					    printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
 | 
				
			||||||
 | 
					           file_name, line, function_name, arch, arch_list);
 | 
				
			||||||
 | 
					#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
 | 
				
			||||||
    __trap();
 | 
					    __trap();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    (void) bad_arch; // suppress unused function warning
 | 
					    (void) no_device_code; // suppress unused function warning
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#ifdef __CUDA_ARCH__
 | 
				
			||||||
 | 
					#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
 | 
				
			||||||
 | 
					#else
 | 
				
			||||||
 | 
					#define NO_DEVICE_CODE GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
 | 
				
			||||||
 | 
					#endif // __CUDA_ARCH__
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static __device__ __forceinline__ float warp_reduce_sum(float x) {
 | 
					static __device__ __forceinline__ float warp_reduce_sum(float x) {
 | 
				
			||||||
#pragma unroll
 | 
					#pragma unroll
 | 
				
			||||||
    for (int mask = 16; mask > 0; mask >>= 1) {
 | 
					    for (int mask = 16; mask > 0; mask >>= 1) {
 | 
				
			||||||
@@ -617,7 +636,7 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 | 
				
			|||||||
    return a;
 | 
					    return a;
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    (void) a;
 | 
					    (void) a;
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 | 
					#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -638,7 +657,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 | 
				
			|||||||
    return x;
 | 
					    return x;
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    (void) x;
 | 
					    (void) x;
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
 | 
					#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2421,7 +2440,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    (void) vx; (void) y; (void) k;
 | 
					    (void) vx; (void) y; (void) k;
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= CC_PASCAL
 | 
					#endif // __CUDA_ARCH__ >= CC_PASCAL
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2452,7 +2471,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp
 | 
				
			|||||||
    // second part effectively subtracts 8 from each quant value
 | 
					    // second part effectively subtracts 8 from each quant value
 | 
				
			||||||
    return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
 | 
					    return d4 * (sumi * ds8f.x - (8*vdr/QI4_0) * ds8f.y);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2489,7 +2508,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp
 | 
				
			|||||||
    // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
 | 
					    // scale second part of sum by QI8_1/(vdr * QR4_1) to compensate for multiple threads adding it
 | 
				
			||||||
    return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
 | 
					    return sumi * d4d8 + m4s8 / (QI8_1 / (vdr * QR4_1));
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2524,7 +2543,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp
 | 
				
			|||||||
    // second part effectively subtracts 16 from each quant value
 | 
					    // second part effectively subtracts 16 from each quant value
 | 
				
			||||||
    return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
 | 
					    return d5 * (sumi * ds8f.x - (16*vdr/QI5_0) * ds8f.y);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2569,7 +2588,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp
 | 
				
			|||||||
    return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
 | 
					    return sumi*d5d8 + m5s8 / (QI5_1 / vdr);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2590,7 +2609,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    return d8_0*d8_1 * sumi;
 | 
					    return d8_0*d8_1 * sumi;
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2620,7 +2639,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
 | 
				
			|||||||
    // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
 | 
					    // scale second part of sum by QI8_1/ vdr to compensate for multiple threads adding it
 | 
				
			||||||
    return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
 | 
					    return sumi*d8d8 + m8s8 / (QI8_1 / vdr);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2655,7 +2674,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    return dm2f.x*sumf_d - dm2f.y*sumf_m;
 | 
					    return dm2f.x*sumf_d - dm2f.y*sumf_m;
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2692,7 +2711,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
 | 
					    return d8 * (dm2f.x*sumi_d - dm2f.y*sumi_m);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2732,7 +2751,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    return d3 * sumf;
 | 
					    return d3 * sumf;
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2757,7 +2776,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    return d3*d8 * sumi;
 | 
					    return d3*d8 * sumi;
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2790,7 +2809,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
 | 
				
			|||||||
    return dm4f.x*sumf_d - dm4f.y*sumf_m;
 | 
					    return dm4f.x*sumf_d - dm4f.y*sumf_m;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2823,7 +2842,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
 | 
				
			|||||||
    return dm4f.x*sumf_d - dm4f.y*sumf_m;
 | 
					    return dm4f.x*sumf_d - dm4f.y*sumf_m;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2863,7 +2882,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
 | 
				
			|||||||
    return dm5f.x*sumf_d - dm5f.y*sumf_m;
 | 
					    return dm5f.x*sumf_d - dm5f.y*sumf_m;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2896,7 +2915,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
 | 
				
			|||||||
    return dm4f.x*sumf_d - dm4f.y*sumf_m;
 | 
					    return dm4f.x*sumf_d - dm4f.y*sumf_m;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2926,7 +2945,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    return d*sumf;
 | 
					    return d*sumf;
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -2957,7 +2976,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
 | 
				
			|||||||
    return d6 * sumf_d;
 | 
					    return d6 * sumf_d;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -3823,7 +3842,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 | 
				
			|||||||
    return dall * sumf_d - dmin * sumf_m;
 | 
					    return dall * sumf_d - dmin * sumf_m;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
@@ -4006,7 +4025,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 | 
				
			|||||||
    return d * sumf_d;
 | 
					    return d * sumf_d;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
					#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#endif
 | 
					#endif
 | 
				
			||||||
@@ -4501,7 +4520,7 @@ template <bool need_check> static __global__ void
 | 
				
			|||||||
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
					        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    (void) vec_dot_q4_0_q8_1_mul_mat;
 | 
					    (void) vec_dot_q4_0_q8_1_mul_mat;
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
					#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -4570,7 +4589,7 @@ template <bool need_check> static __global__ void
 | 
				
			|||||||
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
					        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    (void) vec_dot_q4_1_q8_1_mul_mat;
 | 
					    (void) vec_dot_q4_1_q8_1_mul_mat;
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
					#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -4637,7 +4656,7 @@ template <bool need_check> static __global__ void
 | 
				
			|||||||
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
					        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    (void) vec_dot_q5_0_q8_1_mul_mat;
 | 
					    (void) vec_dot_q5_0_q8_1_mul_mat;
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
					#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -4704,7 +4723,7 @@ mul_mat_q5_1(
 | 
				
			|||||||
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
					        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    (void) vec_dot_q5_1_q8_1_mul_mat;
 | 
					    (void) vec_dot_q5_1_q8_1_mul_mat;
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
					#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -4771,7 +4790,7 @@ template <bool need_check> static __global__ void
 | 
				
			|||||||
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
					        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    (void) vec_dot_q8_0_q8_1_mul_mat;
 | 
					    (void) vec_dot_q8_0_q8_1_mul_mat;
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
					#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -4838,7 +4857,7 @@ mul_mat_q2_K(
 | 
				
			|||||||
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
					        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    (void) vec_dot_q2_K_q8_1_mul_mat;
 | 
					    (void) vec_dot_q2_K_q8_1_mul_mat;
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
					#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -4907,7 +4926,7 @@ template <bool need_check> static __global__ void
 | 
				
			|||||||
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
					        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    (void) vec_dot_q3_K_q8_1_mul_mat;
 | 
					    (void) vec_dot_q3_K_q8_1_mul_mat;
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
					#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -4976,7 +4995,7 @@ template <bool need_check> static __global__ void
 | 
				
			|||||||
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
					        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    (void) vec_dot_q4_K_q8_1_mul_mat;
 | 
					    (void) vec_dot_q4_K_q8_1_mul_mat;
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
					#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -5043,7 +5062,7 @@ mul_mat_q5_K(
 | 
				
			|||||||
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
					        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    (void) vec_dot_q5_K_q8_1_mul_mat;
 | 
					    (void) vec_dot_q5_K_q8_1_mul_mat;
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
					#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -5112,7 +5131,7 @@ template <bool need_check> static __global__ void
 | 
				
			|||||||
        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
					        (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    (void) vec_dot_q6_K_q8_1_mul_mat;
 | 
					    (void) vec_dot_q6_K_q8_1_mul_mat;
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
					#endif // __CUDA_ARCH__ >= CC_VOLTA
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -5835,7 +5854,7 @@ static __global__ void soft_max_f16(const float * x, const float * y, float * ds
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
 | 
					    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
 | 
				
			||||||
    bad_arch();
 | 
					    NO_DEVICE_CODE;
 | 
				
			||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
 | 
					#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user