mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	quantize: Handle user-defined quantization levels for additional tensors (#12511)
* Add llama_model_quantize_params parameters * Add new quantize parameters parsing and validation * Update usage * Add new parameters defaults * Add new quantization parameters logic * Add llama_model_quantize_params parameters * Add new quantize parameters parsing and validation * Update usage * Add new parameters defaults * Add new quantization parameters logic * Minor refactoring as per the contributors' coding guidelines * Update descriptions to match existing style * Add llama_model_quantize_params parameters * Add new quantize parameters parsing and validation * Update usage * Add new parameters defaults * Add new quantization parameters logic * Minor refactoring as per the contributors' guidelines * Implement general --tensor-type instead of tensor-specific command option * Fix implied type bug * Restore missing #includes * Add regex capability for tensor selection * Refactor function name and update ALLOWED_TENSOR_TYPE * Add missing #include * Handle edge case when tensor name is cls.output * Minor logging improvement
This commit is contained in:
		| @@ -9,6 +9,7 @@ | |||||||
| #include <fstream> | #include <fstream> | ||||||
| #include <cmath> | #include <cmath> | ||||||
| #include <cctype> | #include <cctype> | ||||||
|  | #include <algorithm> | ||||||
|  |  | ||||||
| struct quant_option { | struct quant_option { | ||||||
|     std::string name; |     std::string name; | ||||||
| @@ -16,7 +17,7 @@ struct quant_option { | |||||||
|     std::string desc; |     std::string desc; | ||||||
| }; | }; | ||||||
|  |  | ||||||
| static const std::vector<struct quant_option> QUANT_OPTIONS = { | static const std::vector<quant_option> QUANT_OPTIONS = { | ||||||
|     { "Q4_0",     LLAMA_FTYPE_MOSTLY_Q4_0,     " 4.34G, +0.4685 ppl @ Llama-3-8B",  }, |     { "Q4_0",     LLAMA_FTYPE_MOSTLY_Q4_0,     " 4.34G, +0.4685 ppl @ Llama-3-8B",  }, | ||||||
|     { "Q4_1",     LLAMA_FTYPE_MOSTLY_Q4_1,     " 4.78G, +0.4511 ppl @ Llama-3-8B",  }, |     { "Q4_1",     LLAMA_FTYPE_MOSTLY_Q4_1,     " 4.78G, +0.4511 ppl @ Llama-3-8B",  }, | ||||||
|     { "Q5_0",     LLAMA_FTYPE_MOSTLY_Q5_0,     " 5.21G, +0.1316 ppl @ Llama-3-8B",  }, |     { "Q5_0",     LLAMA_FTYPE_MOSTLY_Q5_0,     " 5.21G, +0.1316 ppl @ Llama-3-8B",  }, | ||||||
| @@ -105,7 +106,8 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp | |||||||
| // | // | ||||||
| [[noreturn]] | [[noreturn]] | ||||||
| static void usage(const char * executable) { | static void usage(const char * executable) { | ||||||
|     printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n", executable); |     printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights] [--exclude-weights] [--output-tensor-type]\n", executable); | ||||||
|  |     printf("       [--token-embedding-type] [--tensor-type] [--keep-split] [--override-kv] model-f32.gguf [model-quant.gguf] type [nthreads]\n\n"); | ||||||
|     printf("  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n"); |     printf("  --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n"); | ||||||
|     printf("  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n"); |     printf("  --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n"); | ||||||
|     printf("  --pure: Disable k-quant mixtures and quantize all tensors to the same type\n"); |     printf("  --pure: Disable k-quant mixtures and quantize all tensors to the same type\n"); | ||||||
| @@ -114,6 +116,8 @@ static void usage(const char * executable) { | |||||||
|     printf("  --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n"); |     printf("  --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n"); | ||||||
|     printf("  --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n"); |     printf("  --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n"); | ||||||
|     printf("  --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n"); |     printf("  --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n"); | ||||||
|  |     printf("  --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n"); | ||||||
|  |     printf("      Advanced option to selectively quantize tensors. May be specified multiple times.\n"); | ||||||
|     printf("  --keep-split: will generate quantized model in the same shards as input\n"); |     printf("  --keep-split: will generate quantized model in the same shards as input\n"); | ||||||
|     printf("  --override-kv KEY=TYPE:VALUE\n"); |     printf("  --override-kv KEY=TYPE:VALUE\n"); | ||||||
|     printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n"); |     printf("      Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n"); | ||||||
| @@ -244,6 +248,107 @@ static ggml_type parse_ggml_type(const char * arg) { | |||||||
|     return GGML_TYPE_COUNT; |     return GGML_TYPE_COUNT; | ||||||
| } | } | ||||||
|  |  | ||||||
|  | // Allowed tensors for arbitrary quantization with --tensor-type option | ||||||
|  | static const std::vector<std::string> ALLOWED_TENSOR_TYPE = { | ||||||
|  |     "attn_k", | ||||||
|  |     "attn_kv_a_mqa", | ||||||
|  |     "attn_kv_b", | ||||||
|  |     "attn_o", | ||||||
|  |     "attn_output", | ||||||
|  |     "attn_q", | ||||||
|  |     "attn_q_a", | ||||||
|  |     "attn_q_b", | ||||||
|  |     "attn_qkv", | ||||||
|  |     "attn_v", | ||||||
|  |     "channel_mix_key", | ||||||
|  |     "channel_mix_receptance", | ||||||
|  |     "channel_mix_value", | ||||||
|  |     "cls", | ||||||
|  |     "cls.output", | ||||||
|  |     "cross_attn_k", | ||||||
|  |     "cross_attn_o", | ||||||
|  |     "cross_attn_q", | ||||||
|  |     "cross_attn_v", | ||||||
|  |     "ffn_act", | ||||||
|  |     "ffn_down", | ||||||
|  |     "ffn_down_exps", | ||||||
|  |     "ffn_down_shexp", | ||||||
|  |     "ffn_gate", | ||||||
|  |     "ffn_gate_exps", | ||||||
|  |     "ffn_gate_shexp", | ||||||
|  |     "ffn_up", | ||||||
|  |     "ffn_up_exps", | ||||||
|  |     "ffn_up_shexp", | ||||||
|  |     "ssm_in", | ||||||
|  |     "ssm_out", | ||||||
|  |     "time_mix_gate", | ||||||
|  |     "time_mix_key", | ||||||
|  |     "time_mix_output", | ||||||
|  |     "time_mix_receptance", | ||||||
|  |     "time_mix_value", | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | // changes to this struct must be replicated in llama-quant.cpp | ||||||
|  | struct tensor_quantization { | ||||||
|  |     std::string name; | ||||||
|  |     ggml_type quant = GGML_TYPE_COUNT; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | static bool parse_tensor_type(const char * data, std::vector<tensor_quantization> & tensor_type) { | ||||||
|  |     const char * sep = strchr(data, '='); | ||||||
|  |     if (sep == nullptr) { | ||||||
|  |         printf("\n%s: malformed tensor type '%s'\n\n", __func__, data); | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     const size_t tn_len = sep - data; | ||||||
|  |     if (tn_len == 0) { | ||||||
|  |         printf("\n%s: missing tensor name\n\n", __func__); | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (const size_t qt_len = strlen(sep); qt_len == 1) { | ||||||
|  |         printf("\n%s: missing quantization type\n\n", __func__); | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     std::string tn(data, tn_len); | ||||||
|  |     std::transform(tn.begin(), tn.end(), tn.begin(), tolower); | ||||||
|  |     sep++; | ||||||
|  |     const std::string qt(sep); | ||||||
|  |  | ||||||
|  |     bool found = false; | ||||||
|  |     for (const auto & allowed : ALLOWED_TENSOR_TYPE) { | ||||||
|  |         std::string tensor; | ||||||
|  |         tensor = tn.rfind('.') != std::string::npos ? tn.substr(tn.rfind('.') + 1) : tn; | ||||||
|  |         // handle special case of cls.output | ||||||
|  |         std::string cls_output = "cls.output"; | ||||||
|  |         if (tn.find(cls_output) != std::string::npos) { | ||||||
|  |             tensor = "cls.output"; | ||||||
|  |         } | ||||||
|  |         // check if an allowed tensor exists and it's at the end of the kv string | ||||||
|  |         if (tensor == allowed) { | ||||||
|  |             found = true; | ||||||
|  |             break; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     if (!found) { | ||||||
|  |         printf("\n%s: invalid tensor name '%s'\n\n", __func__, tn.c_str()); | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if (parse_ggml_type(qt.c_str()) == GGML_TYPE_COUNT) { | ||||||
|  |         printf("\n%s: invalid quantization type '%s'\n\n", __func__, qt.c_str()); | ||||||
|  |         return false; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     tensor_quantization tqz; | ||||||
|  |     tqz.name = tn; | ||||||
|  |     tqz.quant = parse_ggml_type(qt.c_str()); | ||||||
|  |     tensor_type.emplace_back(std::move(tqz)); | ||||||
|  |     return true; | ||||||
|  | } | ||||||
|  |  | ||||||
| int main(int argc, char ** argv) { | int main(int argc, char ** argv) { | ||||||
|     if (argc < 3) { |     if (argc < 3) { | ||||||
|         usage(argv[0]); |         usage(argv[0]); | ||||||
| @@ -255,6 +360,7 @@ int main(int argc, char ** argv) { | |||||||
|     std::string imatrix_file; |     std::string imatrix_file; | ||||||
|     std::vector<std::string> included_weights, excluded_weights; |     std::vector<std::string> included_weights, excluded_weights; | ||||||
|     std::vector<llama_model_kv_override> kv_overrides; |     std::vector<llama_model_kv_override> kv_overrides; | ||||||
|  |     std::vector<tensor_quantization> tensor_types; | ||||||
|  |  | ||||||
|     for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { |     for (; arg_idx < argc && strncmp(argv[arg_idx], "--", 2) == 0; arg_idx++) { | ||||||
|         if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) { |         if (strcmp(argv[arg_idx], "--leave-output-tensor") == 0) { | ||||||
| @@ -277,6 +383,10 @@ int main(int argc, char ** argv) { | |||||||
|             } else { |             } else { | ||||||
|                 usage(argv[0]); |                 usage(argv[0]); | ||||||
|             } |             } | ||||||
|  |         } else if (strcmp(argv[arg_idx], "--tensor-type") == 0) { | ||||||
|  |             if (arg_idx == argc-1 || !parse_tensor_type(argv[++arg_idx], tensor_types)) { | ||||||
|  |                 usage(argv[0]); | ||||||
|  |             } | ||||||
|         } else if (strcmp(argv[arg_idx], "--override-kv") == 0) { |         } else if (strcmp(argv[arg_idx], "--override-kv") == 0) { | ||||||
|             if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) { |             if (arg_idx == argc-1 || !string_parse_kv_override(argv[++arg_idx], kv_overrides)) { | ||||||
|                 usage(argv[0]); |                 usage(argv[0]); | ||||||
| @@ -361,6 +471,9 @@ int main(int argc, char ** argv) { | |||||||
|         kv_overrides.back().key[0] = 0; |         kv_overrides.back().key[0] = 0; | ||||||
|         params.kv_overrides = &kv_overrides; |         params.kv_overrides = &kv_overrides; | ||||||
|     } |     } | ||||||
|  |     if (!tensor_types.empty()) { | ||||||
|  |         params.tensor_types = &tensor_types; | ||||||
|  |     } | ||||||
|  |  | ||||||
|     llama_backend_init(); |     llama_backend_init(); | ||||||
|  |  | ||||||
|   | |||||||
| @@ -367,17 +367,18 @@ extern "C" { | |||||||
|  |  | ||||||
|     // model quantization parameters |     // model quantization parameters | ||||||
|     typedef struct llama_model_quantize_params { |     typedef struct llama_model_quantize_params { | ||||||
|         int32_t nthread;                     // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() |         int32_t nthread;                      // number of threads to use for quantizing, if <=0 will use std::thread::hardware_concurrency() | ||||||
|         enum llama_ftype ftype;              // quantize to this llama_ftype |         enum llama_ftype ftype;               // quantize to this llama_ftype | ||||||
|         enum ggml_type output_tensor_type;   // output tensor type |         enum ggml_type output_tensor_type;    // output tensor type | ||||||
|         enum ggml_type token_embedding_type; // token embeddings tensor type |         enum ggml_type token_embedding_type;  // token embeddings tensor type | ||||||
|         bool allow_requantize;               // allow quantizing non-f32/f16 tensors |         bool allow_requantize;                // allow quantizing non-f32/f16 tensors | ||||||
|         bool quantize_output_tensor;         // quantize output.weight |         bool quantize_output_tensor;          // quantize output.weight | ||||||
|         bool only_copy;                      // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored |         bool only_copy;                       // only copy tensors - ftype, allow_requantize and quantize_output_tensor are ignored | ||||||
|         bool pure;                           // quantize all tensors to the default type |         bool pure;                            // quantize all tensors to the default type | ||||||
|         bool keep_split;                     // quantize to the same number of shards |         bool keep_split;                      // quantize to the same number of shards | ||||||
|         void * imatrix;                      // pointer to importance matrix data |         void * imatrix;                       // pointer to importance matrix data | ||||||
|         void * kv_overrides;                 // pointer to vector containing overrides |         void * kv_overrides;                  // pointer to vector containing overrides | ||||||
|  |         void * tensor_types;                  // pointer to vector containing tensor types | ||||||
|     } llama_model_quantize_params; |     } llama_model_quantize_params; | ||||||
|  |  | ||||||
|     typedef struct llama_logit_bias { |     typedef struct llama_logit_bias { | ||||||
|   | |||||||
| @@ -10,6 +10,7 @@ | |||||||
| #include <cinttypes> | #include <cinttypes> | ||||||
| #include <fstream> | #include <fstream> | ||||||
| #include <mutex> | #include <mutex> | ||||||
|  | #include <regex> | ||||||
| #include <thread> | #include <thread> | ||||||
| #include <unordered_map> | #include <unordered_map> | ||||||
|  |  | ||||||
| @@ -47,8 +48,14 @@ struct quantize_state_impl { | |||||||
|         {} |         {} | ||||||
| }; | }; | ||||||
|  |  | ||||||
|  | // changes to this struct must be replicated in quantize.cpp | ||||||
|  | struct tensor_quantization { | ||||||
|  |     std::string name; | ||||||
|  |     ggml_type quant = GGML_TYPE_COUNT; | ||||||
|  | }; | ||||||
|  |  | ||||||
| static void llama_tensor_dequantize_impl( | static void llama_tensor_dequantize_impl( | ||||||
|     struct ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers, |     ggml_tensor * tensor, std::vector<no_init<float>> & output, std::vector<std::thread> & workers, | ||||||
|     const size_t nelements, const int nthread |     const size_t nelements, const int nthread | ||||||
| ) { | ) { | ||||||
|     if (output.size() < nelements) { |     if (output.size() < nelements) { | ||||||
| @@ -536,7 +543,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: | |||||||
|     model.load_hparams(ml); |     model.load_hparams(ml); | ||||||
|     model.load_stats  (ml); |     model.load_stats  (ml); | ||||||
|  |  | ||||||
|     struct quantize_state_impl qs(model, params); |     quantize_state_impl qs(model, params); | ||||||
|  |  | ||||||
|     if (params->only_copy) { |     if (params->only_copy) { | ||||||
|         ftype = ml.ftype; |         ftype = ml.ftype; | ||||||
| @@ -661,7 +668,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: | |||||||
|     // populate the original tensors so we get an initial meta data |     // populate the original tensors so we get an initial meta data | ||||||
|     for (const auto * it : tensors) { |     for (const auto * it : tensors) { | ||||||
|         uint16_t i_split = params->keep_split ? it->idx : 0; |         uint16_t i_split = params->keep_split ? it->idx : 0; | ||||||
|         struct ggml_tensor * tensor = it->tensor; |         ggml_tensor * tensor = it->tensor; | ||||||
|         if (!ctx_outs[i_split]) { |         if (!ctx_outs[i_split]) { | ||||||
|             ctx_outs[i_split].reset(gguf_init_empty()); |             ctx_outs[i_split].reset(gguf_init_empty()); | ||||||
|         } |         } | ||||||
| @@ -710,7 +717,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: | |||||||
|     new_ofstream(0); |     new_ofstream(0); | ||||||
|     for (const auto * it : tensors) { |     for (const auto * it : tensors) { | ||||||
|         const auto & weight = *it; |         const auto & weight = *it; | ||||||
|         struct ggml_tensor * tensor = weight.tensor; |         ggml_tensor * tensor = weight.tensor; | ||||||
|         if (weight.idx != cur_split && params->keep_split) { |         if (weight.idx != cur_split && params->keep_split) { | ||||||
|             close_ofstream(); |             close_ofstream(); | ||||||
|             new_ofstream(weight.idx); |             new_ofstream(weight.idx); | ||||||
| @@ -776,7 +783,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: | |||||||
|         // do not quantize relative position bias (T5) |         // do not quantize relative position bias (T5) | ||||||
|         quantize &= name.find("attn_rel_b.weight") == std::string::npos; |         quantize &= name.find("attn_rel_b.weight") == std::string::npos; | ||||||
|  |  | ||||||
|         enum ggml_type new_type; |         ggml_type new_type; | ||||||
|         void * new_data; |         void * new_data; | ||||||
|         size_t new_size; |         size_t new_size; | ||||||
|  |  | ||||||
| @@ -786,6 +793,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: | |||||||
|             // get more optimal quantization type based on the tensor shape, layer, etc. |             // get more optimal quantization type based on the tensor shape, layer, etc. | ||||||
|             if (!params->pure && ggml_is_quantized(default_type)) { |             if (!params->pure && ggml_is_quantized(default_type)) { | ||||||
|                 new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); |                 new_type = llama_tensor_get_type(qs, new_type, tensor, ftype); | ||||||
|  |                 // unless the user specifies a type | ||||||
|  |                 if (params->tensor_types) { | ||||||
|  |                     const std::vector<tensor_quantization> & tensor_types = *static_cast<const std::vector<tensor_quantization> *>(params->tensor_types); | ||||||
|  |                     for (const auto & [tname, qtype] : tensor_types) { | ||||||
|  |                         if (std::regex pattern(tname); std::regex_search(tensor->name, pattern)) { | ||||||
|  |                             if (qtype != new_type) { | ||||||
|  |                                 LLAMA_LOG_DEBUG("(overriding %s -> %s), ", ggml_type_name(new_type), ggml_type_name(qtype)); | ||||||
|  |                             } | ||||||
|  |                             new_type = qtype; | ||||||
|  |                             break; | ||||||
|  |                         } | ||||||
|  |                     } | ||||||
|  |                 } | ||||||
|             } |             } | ||||||
|             if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { |             if (params->token_embedding_type < GGML_TYPE_COUNT && strcmp(tensor->name, "token_embd.weight") == 0) { | ||||||
|                 new_type = params->token_embedding_type; |                 new_type = params->token_embedding_type; | ||||||
| @@ -910,8 +930,8 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: | |||||||
| // interface implementation | // interface implementation | ||||||
| // | // | ||||||
|  |  | ||||||
| struct llama_model_quantize_params llama_model_quantize_default_params() { | llama_model_quantize_params llama_model_quantize_default_params() { | ||||||
|     struct llama_model_quantize_params result = { |     llama_model_quantize_params result = { | ||||||
|         /*.nthread                     =*/ 0, |         /*.nthread                     =*/ 0, | ||||||
|         /*.ftype                       =*/ LLAMA_FTYPE_MOSTLY_Q5_1, |         /*.ftype                       =*/ LLAMA_FTYPE_MOSTLY_Q5_1, | ||||||
|         /*.output_tensor_type          =*/ GGML_TYPE_COUNT, |         /*.output_tensor_type          =*/ GGML_TYPE_COUNT, | ||||||
| @@ -923,6 +943,7 @@ struct llama_model_quantize_params llama_model_quantize_default_params() { | |||||||
|         /*.keep_split                  =*/ false, |         /*.keep_split                  =*/ false, | ||||||
|         /*.imatrix                     =*/ nullptr, |         /*.imatrix                     =*/ nullptr, | ||||||
|         /*.kv_overrides                =*/ nullptr, |         /*.kv_overrides                =*/ nullptr, | ||||||
|  |         /*.tensor_type                 =*/ nullptr, | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     return result; |     return result; | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Ed Addario
					Ed Addario