k_quants tuning for Falcon-7b (#2816)

* Make ggml-cuda.cu build with QK_K = 64

  Using LLAMA_CUDA_FORCE_DMMV = ON and -nommq, it runs and produces a meaningful result.

* k_quants tuning for Falcon-7b

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
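Context for the change below: with the default super-block size the CUDA q4_K block stores its super-block scale/min pair as a single half2 named dm, while the GGML_QKK_64 variant stored the same pair as half d[2]. Renaming the array to dm lets code that accesses the field by name compile under either setting. A minimal sketch of the two layouts after this commit, assembled from the diff below (illustrative only, not a complete copy of ggml-cuda.cu):

// Sketch of the two block_q4_K layouts after this commit (from the diff below).
// Assumes CUDA's half/half2 types and a QK_K that is 256 by default,
// or 64 when the build defines GGML_QKK_64.
#include <cuda_fp16.h>   // half, half2
#include <stdint.h>      // uint8_t

#ifdef GGML_QKK_64
typedef struct {
    half    dm[2];             // super-block scale (dm[0]) and min (dm[1])
    uint8_t scales[2];         // 4-bit block scales/mins
    uint8_t qs[QK_K/2];        // 4-bit quants, two per byte
} block_q4_K;
static_assert(sizeof(block_q4_K) == sizeof(half2) + QK_K/2 + 2, "wrong q4_K block size/padding");
#else
typedef struct {
    half2 dm;                  // super-block scale/min for the quantized scales/mins
    // ... remaining fields are unchanged by this commit ...
} block_q4_K;
#endif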
ggml-cuda.cu (25 changed lines):
@@ -306,11 +306,11 @@ typedef struct {
 #define QI4_K (QK_K / (4*QR4_K))
 #ifdef GGML_QKK_64
 typedef struct {
-    half    d[2];              // super-block scales/mins
+    half    dm[2];             // super-block scales/mins
     uint8_t scales[2];         // 4-bit block scales/mins
     uint8_t qs[QK_K/2];        // 4--bit quants
 } block_q4_K;
-static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + QK_K/2 + 2, "wrong q4_K block size/padding");
+static_assert(sizeof(block_q4_K) == sizeof(half2) + QK_K/2 + 2, "wrong q4_K block size/padding");
 #else
 typedef struct {
     half2 dm;                  // super-block scale for quantized scales/mins
@@ -737,8 +737,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
     const int tid = threadIdx.x;
     const uint8_t * q = x[i].qs;
     float * y = yy + i*QK_K;
-    const float d = (float)x[i].d[0];
-    const float m = (float)x[i].d[1];
+    const float d = (float)x[i].dm[0];
+    const float m = (float)x[i].dm[1];
     y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
     y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >>  4) - m * (x[i].scales[1] >> 4);
 #endif
@@ -1155,8 +1155,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
         const uint16_t * a = (const uint16_t *)x[i].scales;
         aux16[0] = a[0] & 0x0f0f;
         aux16[1] = (a[0] >> 4) & 0x0f0f;
-        const float d = (float)x[i].d[0];
-        const float m = (float)x[i].d[1];
+        const float d = (float)x[i].dm[0];
+        const float m = (float)x[i].dm[1];
         float sum = 0.f;
         for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
             sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
@@ -2845,8 +2845,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
     aux16[0] = a[0] & 0x0f0f;
     aux16[1] = (a[0] >> 4) & 0x0f0f;

-    const float dall = bq4_K->d[0];
-    const float dmin = bq4_K->d[1];
+    const float dall = bq4_K->dm[0];
+    const float dmin = bq4_K->dm[1];

     const float d8_1 = __low2float(bq8_1[0].ds);
     const float d8_2 = __low2float(bq8_1[1].ds);
@@ -2929,7 +2929,11 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd;

+#if QK_K == 256
         x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm;
+#else
+        x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = {bxi->dm[0], bxi->dm[1]};
+#endif
     }

 #pragma unroll
@@ -3119,7 +3123,9 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin

         const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd;

+#if QK_K == 256
         x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm;
+#endif
     }

 #pragma unroll
@@ -4709,6 +4715,8 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
     const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {

+#if QK_K == 256
+
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
     const int compute_capability = g_compute_capabilities[id];
@@ -4740,6 +4748,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
         mul_mat_q3_K<need_check><<<block_nums, block_dims, 0, stream>>>
             (vx, vy, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst);
     }
+#endif
 }

 static void ggml_mul_mat_q4_K_q8_1_cuda(
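For reference, the per-block arithmetic of the QK_K == 64 dequantization path above (the hunk at line 737) can be restated as a host-side helper. This is an illustrative sketch, not a function from the library; d and m correspond to dm[0] and dm[1] already converted from fp16 to float.

#include <stdint.h>

// Illustrative host-side dequantization of one q4_K block in the QK_K == 64
// layout, mirroring y[tid+0]/y[tid+32] in dequantize_block_q4_K above.
static void dequantize_q4_K_64(float d, float m,
                               const uint8_t scales[2],  // packed 4-bit scale (low) / min (high)
                               const uint8_t qs[32],     // 64 4-bit quants, two per byte
                               float y[64]) {
    for (int j = 0; j < 32; ++j) {
        // low nibbles of qs produce the first 32 values, high nibbles the last 32
        y[j +  0] = d * (scales[0] & 0xF) * (qs[j] & 0xF) - m * (scales[0] >> 4);
        y[j + 32] = d * (scales[1] & 0xF) * (qs[j] >>  4) - m * (scales[1] >> 4);
    }
}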
							
								
								
									
llama.cpp (37 changed lines):
@@ -4776,7 +4776,10 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s

             if (name == tn(LLM_TENSOR_OUTPUT, "weight")) {
                 int nx = tensor->ne[0];
-                if (nx % QK_K == 0) {
+                if (model.arch == LLM_ARCH_FALCON || nx % QK_K != 0) {
+                    new_type = GGML_TYPE_Q8_0;
+                }
+                else if (new_type != GGML_TYPE_Q8_0) {
                     new_type = GGML_TYPE_Q6_K;
                 }
             } else if (name.find("attn_v.weight") != std::string::npos) {
@@ -4800,17 +4803,39 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
             } else if (name.find("ffn_down.weight") != std::string::npos) {
                 if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
                 else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
-                    new_type = i_feed_forward_w2 < 2 ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
+                    new_type = i_feed_forward_w2 < 2 ? GGML_TYPE_Q5_K
+                             : model.arch != LLM_ARCH_FALCON || use_more_bits(i_feed_forward_w2, n_feed_forward_w2) ? GGML_TYPE_Q4_K
+                             : GGML_TYPE_Q3_K;
+                }
+                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) {
+                    new_type = model.arch == LLM_ARCH_FALCON ? GGML_TYPE_Q4_K : GGML_TYPE_Q5_K;
+                }
+                else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) {
+                    if (model.arch == LLM_ARCH_FALCON) {
+                        new_type = i_feed_forward_w2 < 2 ? GGML_TYPE_Q6_K :
+                                   use_more_bits(i_feed_forward_w2, n_feed_forward_w2) ? GGML_TYPE_Q5_K : GGML_TYPE_Q4_K;
+                    } else {
+                        if (use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
+                    }
+                }
+                else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M && use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
+                else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && model.arch != LLM_ARCH_FALCON && i_feed_forward_w2 < 4) {
+                    new_type = GGML_TYPE_Q5_K;
                 }
-                else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
-                else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) &&
-                         use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) new_type = GGML_TYPE_Q6_K;
-                else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && i_feed_forward_w2 < 4) new_type = GGML_TYPE_Q5_K;
                 ++i_feed_forward_w2;
             } else if (name.find("attn_output.weight") != std::string::npos) {
+                if (model.arch != LLM_ARCH_FALCON) {
                     if      (ftype == LLAMA_FTYPE_MOSTLY_Q2_K  ) new_type = GGML_TYPE_Q3_K;
                     else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) new_type = GGML_TYPE_Q4_K;
                     else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
+                } else {
+                    if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
+                }
+            }
+            else if (name.find("attn_qkv.weight") != std::string::npos) {
+                if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q4_K;
+                else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_M) new_type = GGML_TYPE_Q5_K;
+                else if (ftype == LLAMA_FTYPE_MOSTLY_Q5_K_M) new_type = GGML_TYPE_Q6_K;
             }
             else if (name.find("ffn_gate.weight") != std::string::npos || name.find("ffn_up.weight") != std::string::npos) {
                 if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q3_K;
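The chained conditional in the Q3_K_M ffn_down branch above can be easy to misread because ?: associates to the right. Restated as if/else (a readability aid only, not code from the commit), it selects:

// Equivalent if/else form of the Q3_K_M ffn_down selection above.
if (i_feed_forward_w2 < 2) {
    new_type = GGML_TYPE_Q5_K;   // the first two ffn_down tensors always get 5 bits
} else if (model.arch != LLM_ARCH_FALCON ||
           use_more_bits(i_feed_forward_w2, n_feed_forward_w2)) {
    new_type = GGML_TYPE_Q4_K;   // non-Falcon models, plus selected Falcon layers
} else {
    new_type = GGML_TYPE_Q3_K;   // the remaining Falcon ffn_down tensors stay at 3 bits
}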