diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index a6a528b74a..6846850bdd 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -291,8 +291,15 @@ static int load_imatrix(const std::string & imatrix_file, std::vectordata)[j]; if (count > 0.0f) { + float sumw = 0.0f; for (int64_t i = 0; i < ne0; ++i) { - e[j*ne0 + i] = (((const float *) sums->data)[j*ne0 + i] + prior_weight) / (count + prior_weight); + sumw += ((const float *) sums->data)[j*ne0 + i]; + } + // the neutral prior is equal weights, and it should reduce the variance by weighted-averaging with the mean + const float prior_value = sumw / ne0; + + for (int64_t i = 0; i < ne0; ++i) { + e[j*ne0 + i] = (((const float *) sums->data)[j*ne0 + i] + prior_value * prior_weight) / (count + prior_weight); } } else { // Partial imatrix data, this tensor never got any input during calibration