	ggml-cuda.cu: Clean up warnings when compiling with clang
Author: KerfuffleV2
1 changed file: ggml-cuda.cu (93 lines changed)
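Most of the changes below are two recurring idioms for silencing clang diagnostics: when a pointer derived from a const source is cast, the target pointer type keeps the const qualifier (otherwise -Wcast-qual fires), and parameters that a given helper does not use are discarded with (void) casts (otherwise -Wunused-parameter fires under -Wextra). A minimal, self-contained sketch of both idioms in plain C++; the file, function, and variable names here are illustrative stand-ins, not code from ggml-cuda.cu:

    // warn_idioms.cpp - illustrative only; builds with e.g.
    //   clang++ -std=c++11 -Wall -Wextra -Wcast-qual -c warn_idioms.cpp
    #include <cstdint>

    // Reading an int out of a const byte buffer: the intermediate pointer keeps
    // the const qualifier, so no qualifier is cast away and -Wcast-qual stays quiet.
    static int get_int_from_bytes(const int8_t * x8, int i32) {
        const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // 2-byte alignment assumed
        int x32 = 0;
        x32 |= x16[0] <<  0;
        x32 |= x16[1] << 16;
        return x32;
    }

    // A helper that takes extra parameters only to match a common signature:
    // the (void) casts mark them as intentionally unused, which silences
    // -Wunused-parameter without changing the interface.
    static int tile_helper(const int * x_ql, const int * x_qh, const int * x_sc) {
        (void)x_qh; (void)x_sc; // unused in this variant of the helper
        return x_ql[0];
    }

    int main() {
        const int8_t buf[8] = {1, 0, 2, 0, 3, 0, 4, 0};
        const int q[1]      = {42};
        const int r         = get_int_from_bytes(buf, 0) + tile_helper(q, nullptr, nullptr);
        return r == 0 ? 1 : 0; // use the results so nothing is flagged as unused
    }

The same two patterns repeat for each quantization type's allocate_tiles, load_tiles, and vec_dot helpers in the hunks below.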
									
									
									
									
									
								
@@ -235,7 +235,7 @@ typedef float2 dfloat2;
 #endif //GGML_CUDA_F16

 static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const int & i32) {
-    const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment

     int x32 = 0;
     x32 |= x16[0] <<  0;
@@ -245,7 +245,7 @@ static __device__ __forceinline__ int get_int_from_int8(const int8_t * x8, const
 }

 static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, const int & i32) {
-    const uint16_t * x16 = (uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment
+    const uint16_t * x16 = (const uint16_t *) (x8 + sizeof(int) * i32); // assume at least 2 byte alignment

     int x32 = 0;
     x32 |= x16[0] <<  0;
@@ -255,11 +255,11 @@ static __device__ __forceinline__ int get_int_from_uint8(const uint8_t * x8, con
 }

 static __device__ __forceinline__ int get_int_from_int8_aligned(const int8_t * x8, const int & i32) {
-    return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
 }

 static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * x8, const int & i32) {
-    return *((int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
+    return *((const int *) (x8 + sizeof(int) * i32)); // assume at least 4 byte alignment
 }

 template<typename T>
@@ -469,7 +469,7 @@ static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUA
 #define MUL_MAT_SRC1_COL_STRIDE 128

 #define MAX_STREAMS 8
-static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { nullptr };
+static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_DEVICES][MAX_STREAMS] = { { nullptr } };

 struct ggml_tensor_extra_gpu {
     void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
@@ -2248,6 +2248,7 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
 }

 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;

     __shared__ int  tile_x_qs[mmq_y * (WARP_SIZE)       + mmq_y];
     __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0];
@@ -2259,7 +2260,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_0(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh; (void)x_sc;
     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
     GGML_CUDA_ASSUME(k >= 0);
@@ -2268,7 +2269,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI4_0;
     const int kqsx = k % QI4_0;

-    const block_q4_0 * bx0 = (block_q4_0 *) vx;
+    const block_q4_0 * bx0 = (const block_q4_0 *) vx;

     float * x_dmf = (float *) x_dm;

@@ -2306,9 +2307,10 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;

     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
-    const float * x_dmf = (float *) x_dm;
+    const float * x_dmf = (const float *) x_dm;

     int u[2*VDR_Q4_0_Q8_1_MMQ];

@@ -2342,6 +2344,7 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
 }

 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;

     __shared__ int   tile_x_qs[mmq_y * (WARP_SIZE) +     + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1];
@@ -2353,6 +2356,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_1(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_1(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh; (void)x_sc;

     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -2362,7 +2366,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI4_1;
     const int kqsx = k % QI4_1;

-    const block_q4_1 * bx0 = (block_q4_1 *) vx;
+    const block_q4_1 * bx0 = (const block_q4_1 *) vx;

 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2397,6 +2401,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;

     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));

@@ -2434,6 +2439,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
 }

 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;

     __shared__ int  tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
     __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0];
@@ -2445,6 +2451,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_0(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh; (void)x_sc;

     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -2454,7 +2461,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI5_0;
     const int kqsx = k % QI5_0;

-    const block_q5_0 * bx0 = (block_q5_0 *) vx;
+    const block_q5_0 * bx0 = (const block_q5_0 *) vx;

 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2509,6 +2516,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;

     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0;
@@ -2548,6 +2556,7 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
 }

 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;

     __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1];
@@ -2559,6 +2568,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_1(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_1(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh; (void)x_sc;

     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset < nwarps);
@@ -2568,7 +2578,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI5_1;
     const int kqsx = k % QI5_1;

-    const block_q5_1 * bx0 = (block_q5_1 *) vx;
+    const block_q5_1 * bx0 = (const block_q5_1 *) vx;

 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2620,6 +2630,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;

     const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2));
     const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1;
@@ -2654,6 +2665,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
 }

 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh; (void)x_sc;

     __shared__ int  tile_x_qs[mmq_y * (WARP_SIZE)       + mmq_y];
     __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0];
@@ -2665,6 +2677,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q8_0(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh; (void)x_sc;

     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -2675,7 +2688,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kqsx = k % QI8_0;
     float * x_dmf = (float *) x_dm;

-    const block_q8_0 * bx0 = (block_q8_0 *) vx;
+    const block_q8_0 * bx0 = (const block_q8_0 *) vx;

 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2710,6 +2723,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh; (void)x_sc;

     const float * x_dmf = (const float *) x_dm;
     const float * y_df  = (const float *) y_ds;
@@ -2743,6 +2757,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
 }

 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh;

     __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K];
@@ -2756,6 +2771,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q2_K(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q2_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh;

     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -2765,7 +2781,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI2_K;
     const int kqsx = k % QI2_K;

-    const block_q2_K * bx0 = (block_q2_K *) vx;
+    const block_q2_K * bx0 = (const block_q2_K *) vx;

 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2813,6 +2829,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh;

     const int kbx = k / QI2_K;
     const int ky  = (k % QI2_K) * QR2_K;
@@ -2886,7 +2903,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI3_K;
     const int kqsx = k % QI3_K;

-    const block_q3_K * bx0 = (block_q3_K *) vx;
+    const block_q3_K * bx0 = (const block_q3_K *) vx;

 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -2967,7 +2984,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat(
     const float * x_dmf = (const float *) x_dm;
     const float * y_df  = (const float *) y_ds;

-    const int8_t * scales = ((int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;
+    const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4;

     int v[QR3_K*VDR_Q3_K_Q8_1_MMQ];

@@ -3082,6 +3099,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 }

 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh;

     __shared__ int   tile_x_ql[mmq_y * (WARP_SIZE)       + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K];
@@ -3095,6 +3113,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q4_K(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q4_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh;

     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -3104,7 +3123,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI4_K; // == 0 if QK_K == 256
     const int kqsx = k % QI4_K; // == k if QK_K == 256

-    const block_q4_K * bx0 = (block_q4_K *) vx;
+    const block_q4_K * bx0 = (const block_q4_K *) vx;

 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -3149,7 +3168,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8);

-        const int * scales = (int *) bxi->scales;
+        const int * scales = (const int *) bxi->scales;

         const int ksc = k % (WARP_SIZE/8);

@@ -3164,6 +3183,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh;

     const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8);

@@ -3263,6 +3283,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 }

 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh;

     __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K];
@@ -3276,6 +3297,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q5_K(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q5_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh;

     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -3285,7 +3307,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI5_K; // == 0 if QK_K == 256
     const int kqsx = k % QI5_K; // == k if QK_K == 256

-    const block_q5_K * bx0 = (block_q5_K *) vx;
+    const block_q5_K * bx0 = (const block_q5_K *) vx;

 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -3341,7 +3363,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8);

-        const int * scales = (int *) bxi->scales;
+        const int * scales = (const int *) bxi->scales;

         const int ksc = k % (WARP_SIZE/8);

@@ -3356,6 +3378,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh;

     const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8);

@@ -3392,6 +3415,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
 }

 template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
+    (void)x_qh;

     __shared__ int   tile_x_ql[mmq_y * (2*WARP_SIZE)     + mmq_y];
     __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K];
@@ -3405,6 +3429,7 @@ template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q6_K(
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_q6_K(
     const void * __restrict__ vx, int * __restrict__ x_ql, half2 * __restrict__ x_dm, int * __restrict__ x_qh,
     int * __restrict__ x_sc, const int & i_offset, const int & i_max, const int & k, const int & blocks_per_row) {
+    (void)x_qh;

     GGML_CUDA_ASSUME(i_offset >= 0);
     GGML_CUDA_ASSUME(i_offset <  nwarps);
@@ -3414,7 +3439,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     const int kbx  = k / QI6_K; // == 0 if QK_K == 256
     const int kqsx = k % QI6_K; // == k if QK_K == 256

-    const block_q6_K * bx0 = (block_q6_K *) vx;
+    const block_q6_K * bx0 = (const block_q6_K *) vx;

 #pragma unroll
     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
@@ -3476,6 +3501,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
     const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
     const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
+    (void)x_qh;

     const float * x_dmf = (const float *) x_dm;
     const float * y_df  = (const float *) y_ds;
@@ -3518,7 +3544,7 @@ static __device__ __forceinline__ void mul_mat_q(
     __shared__ int    tile_y_qs[mmq_x * WARP_SIZE];
     __shared__ half2  tile_y_ds[mmq_x * WARP_SIZE/QI8_1];

-    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {0.0f};
+    float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}};

     for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {

@@ -6023,18 +6049,18 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
     if (nb0 == ts && nb1 == ts*ne0/bs) {
         return cudaMemcpyAsync(dst_ptr, x, i1_diff*nb1, kind, stream);
-    } else if (nb0 == ts) {
+    }
+    if (nb0 == ts) {
         return cudaMemcpy2DAsync(dst_ptr, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, kind, stream);
-    } else {
+    }
     for (int64_t i1 = 0; i1 < i1_diff; i1++) {
         const void * rx = (const void *) ((const char *) x + i1*nb1);
         void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
         // pretend the row is a matrix with cols=1
         cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, kind, stream);
-            if (r != cudaSuccess) return r;
+        if (r != cudaSuccess) { return r; }
     }
     return cudaSuccess;
-    }
 }

 static void ggml_cuda_op_repeat(
@@ -6989,7 +7015,7 @@ static void ggml_cuda_op_mul_mat(
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
-    const int64_t nrows0 = ggml_nrows(src0);
+    // const int64_t nrows0 = ggml_nrows(src0);

     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
@@ -7090,7 +7116,7 @@ static void ggml_cuda_op_mul_mat(
         if (src0_on_device && src0_is_contiguous) {
             src0_dd[id] = (char *) src0_extra->data_device[id];
         } else {
-            const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
+            // const size_t size_src0_ddq = split ? (row_high[id]-row_low[id])*ne00 * src0_ts/src0_bs : ggml_nbytes(src0);
             src0_dd[id] = (char *) ggml_cuda_pool_malloc(ggml_nbytes(src0), &src0_as[id]);
         }

@@ -7323,7 +7349,7 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
 }

 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-    if (!g_cublas_loaded) return false;
+    if (!g_cublas_loaded) { return false; }

     const int64_t ne10 = src1->ne[0];

@@ -7401,7 +7427,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }

-__global__ void k_compute_batched_ptrs(
+__global__ static void k_compute_batched_ptrs(
         const half * src0_as_f16, const half * src1_as_f16, half * dst_f16,
         const void ** ptrs_src, void ** ptrs_dst,
         int ne12, int ne13,
@@ -8017,7 +8043,7 @@ void ggml_cuda_free_scratch() {
 }

 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
-    if (!g_cublas_loaded) return false;
+    if (!g_cublas_loaded) { return false; }

     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
@@ -8316,14 +8342,14 @@ static ggml_backend_graph_plan_t ggml_backend_cuda_graph_plan_create(ggml_backen
     UNUSED(cgraph);
 }

-static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+[[noreturn]] static void ggml_backend_cuda_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
     GGML_ASSERT(!"not implemented");

     UNUSED(backend);
     UNUSED(plan);
 }

-static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
+[[noreturn]] static void ggml_backend_cuda_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
     GGML_ASSERT(!"not implemented");

     UNUSED(backend);
@@ -8339,8 +8365,9 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph
     for (int i = 0; i < cgraph->n_nodes; i++) {
         ggml_tensor * node = cgraph->nodes[i];

-        if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE)
+        if (node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE) {
             continue;
+        }
         assert(node->backend == GGML_BACKEND_GPU);
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             if (node->src[j] != nullptr) {
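The remaining edits in the diff above target other diagnostics: the nested aggregate initializers for g_cudaStreams and sum gain one brace pair per array dimension, the always-asserting graph-plan stubs are marked [[noreturn]], the file-local kernel k_compute_batched_ptrs is declared static, and single-statement if bodies gain braces. A small plain-C++ sketch of those idioms, again with made-up names and an assumed warning set (the exact flags used by llama.cpp's clang build are not shown on this page):

    // warn_idioms_2.cpp - illustrative only; builds with e.g.
    //   clang++ -std=c++11 -Wall -Wextra -Wmissing-noreturn -c warn_idioms_2.cpp
    #include <cassert>
    #include <cstdlib>

    #define MAX_DEVICES 2
    #define MAX_STREAMS 8

    // A 2-D aggregate wants one brace pair per dimension; a bare { nullptr }
    // triggers clang's -Wmissing-braces (part of -Wall).
    static int * g_streams[MAX_DEVICES][MAX_STREAMS] = { { nullptr } };

    // A stub that only asserts: [[noreturn]] documents that control never
    // returns normally, which also satisfies -Wmissing-noreturn.
    [[noreturn]] static void plan_free_stub(void * plan) {
        (void)plan;    // parameter kept for interface compatibility
        assert(!"not implemented");
        std::abort();  // keep the noreturn promise even when NDEBUG disables assert
    }

    int main() {
        if (g_streams[0][0] == nullptr) { // braces even around a one-line body
            return 0;
        }
        plan_free_stub(nullptr);
    }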