diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp
index 48d6747934..a6a528b74a 100644
--- a/tools/quantize/quantize.cpp
+++ b/tools/quantize/quantize.cpp
@@ -69,6 +69,7 @@ static const char * const LLM_KV_QUANTIZE_IMATRIX_FILE = "quantize.imatrix
 static const char * const LLM_KV_QUANTIZE_IMATRIX_DATASET = "quantize.imatrix.dataset";
 static const char * const LLM_KV_QUANTIZE_IMATRIX_N_ENTRIES = "quantize.imatrix.entries_count";
 static const char * const LLM_KV_QUANTIZE_IMATRIX_N_CHUNKS = "quantize.imatrix.chunks_count";
+static const char * const LLM_KV_QUANTIZE_IMATRIX_PRIOR_W = "quantize.imatrix.prior_weight";
 
 // TODO: share with imatrix.cpp
 static const char * const LLM_KV_IMATRIX_DATASETS = "imatrix.datasets";
@@ -214,7 +215,7 @@ static int load_legacy_imatrix(const std::string & imatrix_file, std::vector<st
     return m_last_call;
 }
 
-static int load_imatrix(const std::string & imatrix_file, std::vector<std::string> & imatrix_datasets, std::unordered_map<std::string, std::vector<float>> & imatrix_data, float prior_weight) {
+static int load_imatrix(const std::string & imatrix_file, std::vector<std::string> & imatrix_datasets, std::unordered_map<std::string, std::vector<float>> & imatrix_data, float & prior_weight) {
 
     struct ggml_context * ctx = nullptr;
     struct gguf_init_params meta_gguf_params = {
@@ -224,6 +225,7 @@ static int load_imatrix(const std::string & imatrix_file, std::vector
         std::vector<std::string> & imatrix_dataset,
         const std::vector<std::string> & included_weights,
         const std::vector<std::string> & excluded_weights,
         std::unordered_map<std::string, std::vector<float>> & imatrix_data,
-        float prior_weight) {
+        float & prior_weight) {
     int m_last_call = -1;
     if (!imatrix_file.empty()) {
         m_last_call = load_imatrix(imatrix_file, imatrix_dataset, imatrix_data, prior_weight);
@@ -574,6 +576,14 @@ int main(int argc, char ** argv) {
             kvo.val_i64 = m_last_call;
             kv_overrides.emplace_back(std::move(kvo));
         }
+
+        {
+            llama_model_kv_override kvo;
+            std::strcpy(kvo.key, LLM_KV_QUANTIZE_IMATRIX_PRIOR_W);
+            kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
+            kvo.val_f64 = prior_weight;
+            kv_overrides.emplace_back(std::move(kvo));
+        }
     }
     if (!kv_overrides.empty()) {
         kv_overrides.emplace_back();