mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)
	all : be more strict about converting float to double (#458)
* Be more strict about converting float to double

* Test equivalence of round, SILU implementations

  Test module is commented out in CMakeLists.txt because the tests may take a long time, depending on how much the compiler optimizes.

* Fix softmax in perplexity.cpp

* all : prefer float over double where appropriate

* perplexity : add <cmath>

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
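The second bullet refers to checking that the float-only math paths (roundf, expf-based SiLU) agree with the previous double-based ones. A minimal sketch of such a check, written in plain C with illustrative reference implementations and ranges (this is not the repository's test module, which ships separately and is disabled by default):

    #include <math.h>
    #include <stdio.h>

    static float  silu_f32(float x)  { return x/(1.0f + expf(-x)); }
    static double silu_f64(double x) { return x/(1.0  + exp (-x)); }

    int main(void) {
        float max_diff = 0.0f;
        for (float x = -16.0f; x <= 16.0f; x += 0.001f) {
            /* float SiLU vs. double reference */
            const float d = fabsf(silu_f32(x) - (float) silu_f64((double) x));
            if (d > max_diff) max_diff = d;
            /* roundf vs. round should agree exactly on these inputs */
            if (roundf(x) != (float) round((double) x)) {
                printf("round mismatch at x = %f\n", (double) x);
            }
        }
        printf("max |silu_f32 - silu_f64| = %g\n", (double) max_diff);
        return 0;
    }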
@@ -124,17 +124,18 @@ if (LLAMA_ALL_WARNINGS)
             -Wall
             -Wextra
             -Wpedantic
-            -Wshadow
             -Wcast-qual
+            -Wdouble-promotion
+            -Wshadow
             -Wstrict-prototypes
             -Wpointer-arith
-            -Wno-unused-function
         )
         set(cxx_flags
             -Wall
             -Wextra
             -Wpedantic
             -Wcast-qual
+            -Wdouble-promotion
         )
     else()
         # todo : msvc
Makefile | 4
@@ -35,6 +35,10 @@ CFLAGS   = -I.              -O3 -DNDEBUG -std=c11   -fPIC
 CXXFLAGS = -I. -I./examples -O3 -DNDEBUG -std=c++11 -fPIC
 LDFLAGS  =
 
+# warnings
+CFLAGS   += -Wall -Wextra -Wpedantic -Wcast-qual -Wdouble-promotion -Wshadow -Wstrict-prototypes -Wpointer-arith -Wno-unused-function
+CXXFLAGS += -Wall -Wextra -Wpedantic -Wcast-qual -Wno-unused-function
+
 # OS specific
 # TODO: support Windows
 ifeq ($(UNAME_S),Linux)
@@ -215,13 +215,13 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "                        prompt file to start generation.\n");
     fprintf(stderr, "  -n N, --n_predict N   number of tokens to predict (default: %d, -1 = infinity)\n", params.n_predict);
     fprintf(stderr, "  --top_k N             top-k sampling (default: %d)\n", params.top_k);
-    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", params.top_p);
+    fprintf(stderr, "  --top_p N             top-p sampling (default: %.1f)\n", (double)params.top_p);
     fprintf(stderr, "  --repeat_last_n N     last n tokens to consider for penalize (default: %d)\n", params.repeat_last_n);
-    fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f)\n", params.repeat_penalty);
+    fprintf(stderr, "  --repeat_penalty N    penalize repeat sequence of tokens (default: %.1f)\n", (double)params.repeat_penalty);
     fprintf(stderr, "  -c N, --ctx_size N    size of the prompt context (default: %d)\n", params.n_ctx);
     fprintf(stderr, "  --ignore-eos          ignore end of stream token and continue generating\n");
     fprintf(stderr, "  --memory_f32          use f32 instead of f16 for memory key+value\n");
-    fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", params.temp);
+    fprintf(stderr, "  --temp N              temperature (default: %.1f)\n", (double)params.temp);
     fprintf(stderr, "  --n_parts N           number of model parts (default: -1 = determine from dimensions)\n");
     fprintf(stderr, "  -b N, --batch_size N  batch size for prompt processing (default: %d)\n", params.n_batch);
     fprintf(stderr, "  --perplexity          compute perplexity over the prompt\n");
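The (double) casts above deal with variadic argument promotion: a float passed through the "..." of fprintf is implicitly promoted to double (which is what %f consumes), and -Wdouble-promotion reports that implicit widening. Making the cast explicit keeps the behaviour identical while silencing the warning. A small illustration with an assumed float parameter:

    #include <stdio.h>

    void print_top_p(float top_p) {
        /* implicit: top_p is promoted to double by the variadic call,
           which -Wdouble-promotion flags */
        /* printf("top_p = %.1f\n", top_p); */

        /* explicit: same value and same output, the promotion is now intentional */
        printf("top_p = %.1f\n", (double) top_p);
    }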
@@ -209,7 +209,8 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
         }
     }
-    fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+    fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
+        params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");
 
@@ -274,7 +275,7 @@ int main(int argc, char ** argv) {
 
         if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
             // out of user input, sample next token
-            const float top_k          = params.top_k;
+            const int32_t top_k          = params.top_k;
             const float   top_p          = params.top_p;
             const float   temp           = params.temp;
             const float   repeat_penalty = params.repeat_penalty;
@@ -1,15 +1,17 @@
 #include "common.h"
 #include "llama.h"
 
-std::vector<double> softmax(const std::vector<float>& logits) {
-    std::vector<double> probs(logits.size());
+#include <cmath>
+
+std::vector<float> softmax(const std::vector<float>& logits) {
+    std::vector<float> probs(logits.size());
     float max_logit = logits[0];
     for (float v : logits) max_logit = std::max(max_logit, v);
     double sum_exp = 0.0;
     for (size_t i = 0; i < logits.size(); i++) {
         // Subtract the maximum logit value from the current logit value for numerical stability
-        float logit = logits[i] - max_logit;
-        double exp_logit = std::exp(logit);
+        const float logit = logits[i] - max_logit;
+        const float exp_logit = expf(logit);
         sum_exp += exp_logit;
         probs[i] = exp_logit;
     }
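The hunk above ends before the normalization step, so for context here is the whole computation in one place: softmax is shift-invariant, softmax(x)_i = exp(x_i - m) / sum_j exp(x_j - m) for any constant m, and picking m as the largest logit keeps every exponent non-positive so expf cannot overflow. An illustrative plain-C restatement (not the project's code):

    #include <math.h>

    void softmax_stable(float * probs, const float * logits, int n) {
        float max_logit = logits[0];
        for (int i = 1; i < n; i++) {
            if (logits[i] > max_logit) max_logit = logits[i];
        }

        double sum_exp = 0.0;                        /* accumulate the sum in double, as above */
        for (int i = 0; i < n; i++) {
            probs[i] = expf(logits[i] - max_logit);  /* exponent <= 0: no overflow */
            sum_exp += probs[i];
        }
        for (int i = 0; i < n; i++) {
            probs[i] = (float) (probs[i]/sum_exp);
        }
    }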
@@ -24,14 +26,16 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     auto tokens = ::llama_tokenize(ctx, params.prompt, true);
 
     int count = 0;
-    double nll = 0.0;
     int seq_count = tokens.size() / params.n_ctx;
 
+    double nll = 0.0;
+
     fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
 
     for (int i = 0; i < seq_count; ++i) {
         int start = i * params.n_ctx;
-        int end = start + params.n_ctx - 1;
+        int end = start + params.n_ctx - 1; // TODO: this is not optimal, e.g. it makes the batch 511 instead of 512
+                                            //       it is better to always be power of 2 for better performance
         std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
         auto start_t = std::chrono::high_resolution_clock::now();
         if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
@@ -40,7 +44,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         }
         auto end_t = std::chrono::high_resolution_clock::now();
         if (i == 0) {
-            double seconds = std::chrono::duration<double>(end_t - start_t).count();
+            const float seconds = std::chrono::duration<float>(end_t - start_t).count();
             printf("%.2f seconds per pass - ETA %.2f hours\n", seconds, (seconds * seq_count) / (60.0*60.0));
         }
         // We get the logits for all the tokens in the context window (params.n_ctx)
@@ -63,7 +67,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
             std::vector<float> tok_logits(
                 logits + j * n_vocab,
                 logits + (j + 1) * n_vocab);
-            double prob = softmax(tok_logits)[tokens[start + j + 1]];
+            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
             nll += -std::log(prob);
             ++count;
         }
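The loop above accumulates negative log-likelihood one predicted token at a time; the reported perplexity is exp(nll/count). Note that each probability is now computed in float while nll itself stays double: the per-token term is cheap, but the long running sum benefits from the extra precision. A sketch of the final reporting step under that assumption (the actual print statement sits outside this hunk):

    #include <math.h>
    #include <stdio.h>

    /* ppl = exp( (1/N) * sum_i -log p(token_i | preceding tokens) ) */
    void report_perplexity(double nll, int count) {
        printf("perplexity: %.4lf\n", exp(nll/count));
    }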
@@ -50,8 +50,8 @@ int main(int argc, char ** argv) {
         const int64_t t_main_end_us = ggml_time_us();
 
         printf("\n");
-        printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0f);
-        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0f);
+        printf("%s: quantize time = %8.2f ms\n", __func__, t_quantize_us/1000.0);
+        printf("%s:    total time = %8.2f ms\n", __func__, (t_main_end_us - t_main_start_us)/1000.0);
     }
 
     return 0;
ggml.c | 138
@@ -150,10 +150,10 @@ typedef double ggml_float;
 //
 #include <arm_neon.h>
 
-#define GGML_COMPUTE_FP16_TO_FP32(x) (x)
+#define GGML_COMPUTE_FP16_TO_FP32(x) ((float) (x))
 #define GGML_COMPUTE_FP32_TO_FP16(x) (x)
 
-#define GGML_FP16_TO_FP32(x) (x)
+#define GGML_FP16_TO_FP32(x) ((float) (x))
 #define GGML_FP32_TO_FP16(x) (x)
 
 #else
@@ -322,7 +322,7 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
 // note: do not use these inside ggml.c
 // these are meant to be used via the ggml.h API
 float ggml_fp16_to_fp32(ggml_fp16_t x) {
-    return GGML_FP16_TO_FP32(x);
+    return (float) GGML_FP16_TO_FP32(x);
 }
 
 ggml_fp16_t ggml_fp32_to_fp16(float x) {
@@ -488,8 +488,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
             const float v0 = x[i*QK + l + 0]*id;
             const float v1 = x[i*QK + l + 1]*id;
 
-            const uint8_t vi0 = ((int8_t) (round(v0))) + 8;
-            const uint8_t vi1 = ((int8_t) (round(v1))) + 8;
+            const uint8_t vi0 = (int8_t)roundf(v0) + 8;
+            const uint8_t vi1 = (int8_t)roundf(v1) + 8;
 
             assert(vi0 >= 0 && vi0 < 16);
             assert(vi1 >= 0 && vi1 < 16);
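round() takes and returns double, so calling it on a float silently promotes the argument and does the rounding in double; roundf() keeps the whole operation in float, which is all the quantizer needs. A tiny illustration with a made-up value:

    #include <math.h>

    void round_example(void) {
        const float v = 3.4f;
        const double a = round(v);   /* v promoted to double; reported by -Wdouble-promotion */
        const float  b = roundf(v);  /* stays in float end to end */
        (void) a; (void) b;
    }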
@@ -566,7 +566,7 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
                 MAX(vgetq_lane_f32(amaxv[0], 2), vgetq_lane_f32(amaxv[0], 3)));
 
         const float d = amax / ((1 << 3) - 1);
-        const float id = d ? 1.0/d : 0.0;
+        const float id = d ? 1.0f/d : 0.0f;
 
         y[i].d = d;
 
@@ -716,8 +716,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
             const float v0 = (x[i*QK + l + 0] - min)*id;
             const float v1 = (x[i*QK + l + 1] - min)*id;
 
-            const uint8_t vi0 = round(v0);
-            const uint8_t vi1 = round(v1);
+            const uint8_t vi0 = roundf(v0);
+            const uint8_t vi1 = roundf(v1);
 
             assert(vi0 >= 0 && vi0 < 16);
             assert(vi1 >= 0 && vi1 < 16);
@@ -1001,7 +1001,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
         }                                                         \
         const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 (x[0])); \
         const float32x4_t t1 = vcvt_f32_f16(vget_high_f16(x[0])); \
-        res = vaddvq_f32(vaddq_f32(t0, t1));                      \
+        res = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \
     }
 
     #define GGML_F16_VEC                GGML_F16x8
@@ -1437,9 +1437,8 @@ inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, co
 inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   }
 
 inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
-    ggml_float sumf = 0.0;
-
 #ifdef GGML_SIMD
+    float sumf = 0.0f;
     const int np = (n & ~(GGML_F32_STEP - 1));
 
     GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO };
@@ -1465,8 +1464,9 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
     }
 #else
     // scalar
+    ggml_float sumf = 0.0;
     for (int i = 0; i < n; ++i) {
-        sumf += x[i]*y[i];
+        sumf += (ggml_float)(x[i]*y[i]);
     }
 #endif
 
@@ -1529,11 +1529,11 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
 
     // leftovers
     for (int i = np; i < n; ++i) {
-        sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
     }
 #else
     for (int i = 0; i < n; ++i) {
-        sumf += GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]);
+        sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i]));
     }
 #endif
 
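ggml_float is a typedef for double (visible in the first ggml.c hunk header above), so these casts express the intended pattern: each product is computed in float, but the running sum is kept in double so long reductions lose less precision, and the explicit cast tells -Wdouble-promotion that the widening is deliberate. A standalone sketch of the same pattern with an illustrative typedef name:

    typedef double accum_t;   /* stand-in for ggml_float in this sketch */

    float dot_product(const float * x, const float * y, int n) {
        accum_t sum = 0.0;
        for (int i = 0; i < n; ++i) {
            sum += (accum_t)(x[i]*y[i]);   /* multiply in float, accumulate in double */
        }
        return (float) sum;
    }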
@@ -1549,7 +1549,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
     const block_q4_0 * restrict x = vx;
     const block_q4_0 * restrict y = vy;
 
-    float sumf = 0.0;
+    ggml_float sumf = 0.0;
 
 #if defined(__ARM_NEON)
     float sum0 = 0.0f;
@@ -1644,7 +1644,7 @@ inline static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void
 #endif
     }
 
-    sumf = sum0 + sum1;
+    sumf = (ggml_float)(sum0 + sum1);
 #elif defined(__AVX512F__)
     // Initialize accumulator with zeros
     __m512 acc0 = _mm512_setzero_ps();
@@ -1972,13 +1972,13 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * re
     // leftovers
     for (int i = np; i < n; ++i) {
         for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
         }
     }
 #else
     for (int i = 0; i < n; ++i) {
         for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
-            sumf[j] += GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]);
+            sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i]));
         }
     }
 #endif
@@ -2049,19 +2049,19 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) {
 #endif
 }
 
-inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrt(*s);   }
+inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s);   }
 inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
-inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrt(x[i]); }
+inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
 inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); }
 inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); }
 inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; }
 inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
 
-static const ggml_float GELU_COEF_A    = 0.044715;
-static const ggml_float SQRT_2_OVER_PI = 0.79788456080286535587989211986876;
+static const float GELU_COEF_A    = 0.044715f;
+static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 inline static float ggml_gelu_f32(float x) {
-    return 0.5*x*(1.0 + tanh(SQRT_2_OVER_PI*x*(1.0 + GELU_COEF_A*x*x)));
+    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
 inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
@@ -2090,7 +2090,7 @@ inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) {
 
 // Sigmoid Linear Unit (SiLU) function
 inline static float ggml_silu_f32(float x) {
-    return x/(1.0 + exp(-x));
+    return x/(1.0f + expf(-x));
 }
 
 inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
@@ -2121,7 +2121,7 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
     ggml_float sum = 0.0;
     for (int i = 0; i < n; ++i) {
-        sum += x[i];
+        sum += (ggml_float)x[i];
     }
     *s = sum;
 #else
@@ -2131,7 +2131,7 @@ inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 
 inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
 #ifndef GGML_USE_ACCELERATE
-    ggml_float max = -INFINITY;
+    float max = -INFINITY;
     for (int i = 0; i < n; ++i) {
         max = MAX(max, x[i]);
     }
@@ -2141,7 +2141,10 @@ inline static void ggml_vec_max_f32(const int n, float * s, const float * x) {
 #endif
 }
 
-inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { ggml_vec_norm_f32(n, s, x); *s = 1./(*s); }
+inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) {
+    ggml_vec_norm_f32(n, s, x);
+    *s = 1.f/(*s);
+}
 
 //
 // logging
@@ -2540,7 +2543,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
                 const float f = table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(ii);
                 table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
                 table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
-                table_exp_f16[i]  = GGML_FP32_TO_FP16(exp(f));
+                table_exp_f16[i]  = GGML_FP32_TO_FP16(expf(f));
             }
 
             const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
@@ -5583,7 +5586,7 @@ static void ggml_compute_forward_norm_f32(
     const size_t nb2 = dst->nb[2];
     const size_t nb3 = dst->nb[3];
 
-    const ggml_float eps = 1e-5f; // TODO: make this a parameter
+    const float eps = 1e-5f; // TODO: make this a parameter
 
     // TODO: optimize
     for (int i03 = 0; i03 < ne03; i03++) {
@@ -5591,23 +5594,24 @@ static void ggml_compute_forward_norm_f32(
             for (int i01 = ith; i01 < ne01; i01 += nth) {
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
-                ggml_float mean = 0.0;
+                ggml_float sum = 0.0;
                 for (int i00 = 0; i00 < ne00; i00++) {
-                    mean += x[i00];
+                    sum += (ggml_float)x[i00];
                 }
 
-                mean /= ne00;
+                float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
                 ggml_float sum2 = 0.0;
                 for (int i00 = 0; i00 < ne00; i00++) {
-                    ggml_float v = x[i00] - mean;
+                    float v = x[i00] - mean;
                     y[i00] = v;
-                    sum2 += v*v;
+                    sum2 += (ggml_float)(v*v);
                 }
 
-                const float scale = 1.0/sqrt(sum2/ne00 + eps);
+                float variance = sum2/ne00;
+                const float scale = 1.0f/sqrtf(variance + eps);
 
                 ggml_vec_scale_f32(ne00, y, scale);
             }
@@ -5665,7 +5669,7 @@ static void ggml_compute_forward_rms_norm_f32(
     const size_t nb2 = dst->nb[2];
     const size_t nb3 = dst->nb[3];
 
-    const ggml_float eps = 1e-6f; // TODO: make this a parameter
+    const float eps = 1e-6f; // TODO: make this a parameter
 
     // TODO: optimize
     for (int i03 = 0; i03 < ne03; i03++) {
@@ -5673,12 +5677,12 @@ static void ggml_compute_forward_rms_norm_f32(
             for (int i01 = ith; i01 < ne01; i01 += nth) {
                 const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
 
-                ggml_float mean = 0.0;
+                ggml_float sum = 0.0;
                 for (int i00 = 0; i00 < ne00; i00++) {
-                    mean += x[i00] * x[i00];
+                    sum += (ggml_float)(x[i00] * x[i00]);
                 }
 
-                mean /= ne00;
+                float mean = sum/ne00;
 
                 float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
 
@@ -5687,7 +5691,7 @@ static void ggml_compute_forward_rms_norm_f32(
                 //     y[i00] = x[i00];
                 // }
 
-                const float scale = 1.0/sqrt(mean + eps);
+                const float scale = 1.0f/sqrtf(mean + eps);
 
                 ggml_vec_scale_f32(ne00, y, scale);
             }
@@ -6913,12 +6917,12 @@ static void ggml_compute_forward_soft_max_f32(
                 ggml_fp16_t s = GGML_FP32_TO_FP16(p[i] - max);
                 memcpy(&scvt, &s, sizeof(scvt));
                 const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt]);
-                sum += val;
+                sum += (ggml_float)val;
                 p[i] = val;
             }
         }
 
-        assert(sum > 0.0f);
+        assert(sum > 0.0);
 
         sum = 1.0/sum;
         ggml_vec_scale_f32(nc, p, sum);
@@ -6994,16 +6998,16 @@ static void ggml_compute_forward_rope_f32(
             const int p = (mode == 0 ? n_past + i2 : i2);
             for (int i1 = 0; i1 < ne1; i1++) {
                 for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const double theta = pow(10000.0, ((double)-i0)/n_dims);
+                    const float theta = powf(10000.0, ((float)-i0)/n_dims);
 
-                    const double cos_theta = cos(p*theta);
-                    const double sin_theta = sin(p*theta);
+                    const float cos_theta = cosf(p*theta);
+                    const float sin_theta = sinf(p*theta);
 
                     const float * const src = (float *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                           float * dst_data  = (float *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
-                    double x0 = src[0];
-                    double x1 = src[1];
+                    const float x0 = src[0];
+                    const float x1 = src[1];
 
                     dst_data[0] = x0*cos_theta - x1*sin_theta;
                     dst_data[1] = x0*sin_theta + x1*cos_theta;
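The rotation math itself is unchanged here, only moved from double to float: each (x0, x1) pair is rotated by the angle p*theta, with theta = 10000^(-i0/n_dims). A compact restatement of that single step, equivalent to the body of the loop above (illustrative helper name):

    #include <math.h>

    /* Rotate one (x0, x1) pair for position p and dimension index i0. */
    void rope_pair(float * out0, float * out1, float x0, float x1, int p, int i0, int n_dims) {
        const float theta     = powf(10000.0f, ((float) -i0)/n_dims);
        const float cos_theta = cosf(p*theta);
        const float sin_theta = sinf(p*theta);

        *out0 = x0*cos_theta - x1*sin_theta;
        *out1 = x0*sin_theta + x1*cos_theta;
    }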
@@ -7050,16 +7054,16 @@ static void ggml_compute_forward_rope_f16(
             const int p = (mode == 0 ? n_past + i2 : i2);
             for (int i1 = 0; i1 < ne1; i1++) {
                 for (int i0 = 0; i0 < n_dims; i0 += 2) {
-                    const double theta = pow(10000.0, ((double)-i0)/n_dims);
+                    const float theta = powf(10000.0, ((float)-i0)/n_dims);
 
-                    const double cos_theta = cos(p*theta);
-                    const double sin_theta = sin(p*theta);
+                    const float cos_theta = cosf(p*theta);
+                    const float sin_theta = sinf(p*theta);
 
                     const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
                           ggml_fp16_t * dst_data  = (ggml_fp16_t *)((char *)  dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
 
-                    double x0 = ggml_fp16_to_fp32(src[0]);
-                    double x1 = ggml_fp16_to_fp32(src[1]);
+                    const float x0 = ggml_fp16_to_fp32(src[0]);
+                    const float x1 = ggml_fp16_to_fp32(src[1]);
 
                     dst_data[0] = ggml_fp32_to_fp16(x0*cos_theta - x1*sin_theta);
                     dst_data[1] = ggml_fp32_to_fp16(x0*sin_theta + x1*cos_theta);
@@ -7735,7 +7739,7 @@ static void ggml_compute_forward_flash_attn_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    const float scale = 1.0/sqrt((double) D);
+    const float scale = 1.0f/sqrtf(D);
 
     //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
 
@@ -7782,7 +7786,7 @@ static void ggml_compute_forward_flash_attn_f32(
             float max = -INFINITY;
             ggml_vec_max_f32(M, &max, S);
 
-            float sum = 0.0f;
+            ggml_float sum = 0.0;
             {
 #ifdef GGML_SOFT_MAX_ACCELERATE
                 max = -max;
@@ -7803,7 +7807,7 @@ static void ggml_compute_forward_flash_attn_f32(
                             ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
                             memcpy(&scvt[j], &s, sizeof(uint16_t));
                             const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
-                            sump[j] += val;
+                            sump[j] += (ggml_float)val;
                             SS[j] = val;
                         }
                     }
@@ -7815,7 +7819,7 @@ static void ggml_compute_forward_flash_attn_f32(
 #endif
             }
 
-            assert(sum > 0.0f);
+            assert(sum > 0.0);
 
             sum = 1.0/sum;
             ggml_vec_scale_f32(M, S, sum);
@@ -7944,7 +7948,7 @@ static void ggml_compute_forward_flash_attn_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    const float scale = 1.0/sqrt((double) D);
+    const float scale = 1.0f/sqrtf(D);
 
     //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
 
@@ -8008,7 +8012,7 @@ static void ggml_compute_forward_flash_attn_f16(
             float max = -INFINITY;
             ggml_vec_max_f32(M, &max, S);
 
-            float sum = 0.0f;
+            ggml_float sum = 0.0;
             {
 #ifdef GGML_SOFT_MAX_ACCELERATE
                 max = -max;
@@ -8029,7 +8033,7 @@ static void ggml_compute_forward_flash_attn_f16(
                             ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
                             memcpy(&scvt[j], &s, sizeof(uint16_t));
                             const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
-                            sump[j] += val;
+                            sump[j] += (ggml_float)val;
                             SS[j] = val;
                         }
                     }
@@ -8041,7 +8045,7 @@ static void ggml_compute_forward_flash_attn_f16(
 #endif
             }
 
-            assert(sum > 0.0f);
+            assert(sum > 0.0);
 
             sum = 1.0/sum;
             ggml_vec_scale_f32(M, S, sum);
@@ -9566,7 +9570,7 @@ label=\"%d [%d, %d] | <x>%s",
             fprintf(fp, "  \"%p\" [ \
 style = filled; fillcolor = %s; shape = record; \
 label=\"<x>%.1e\"; ]\n",
-                    (void *) node, color, ggml_get_f32_1d(node, 0));
+                    (void *) node, color, (double)ggml_get_f32_1d(node, 0));
         } else {
             fprintf(fp, "  \"%p\" [ \
 style = filled; fillcolor = %s; shape = record; \
@@ -9804,7 +9808,7 @@ static enum ggml_opt_result ggml_opt_adam(
             if (params.past <= t) {
                 const float rate = (pf[t%params.past] - fx)/fx;
 
-                if (fabs(rate) < params.delta) {
+                if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
@@ -9883,7 +9887,7 @@ static enum ggml_opt_result linesearch_backtracking(
     const float dec = 0.5f;
     const float inc = 2.1f;
 
-    if (*step <= 0.) {
+    if (*step <= 0.f) {
         return GGML_LINESEARCH_INVALID_PARAMETERS;
     }
 
@@ -9971,7 +9975,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         struct ggml_cgraph * gb) {
     if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE ||
         params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
-        if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1. <= params.lbfgs.wolfe) {
+        if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) {
             return GGML_OPT_INVALID_WOLFE;
         }
     }
@@ -10092,8 +10096,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
 
         GGML_PRINT_DEBUG("f = %10.6f\n", ggml_get_f32_1d(f, 0));
 
-        if (xnorm < 1.0) {
-            xnorm = 1.0;
+        if (xnorm < 1.0f) {
+            xnorm = 1.0f;
         }
         if (gnorm/xnorm <= params.lbfgs.eps) {
             // converged
@@ -10106,7 +10110,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
             if (params.past <= k) {
                 const float rate = (pf[k%params.past] - fx)/fx;
 
-                if (fabs(rate) < params.delta) {
+                if (fabsf(rate) < params.delta) {
                     return GGML_OPT_OK;
                 }
             }
llama.cpp | 52
@@ -779,8 +779,8 @@ static bool llama_model_load(
 
                 // progress
                 if (progress_callback) {
-                    double current_file_progress = double(size_t(fin.tellg()) - file_offset) / double(file_size - file_offset);
-                    double current_progress = (double(i) + current_file_progress) / double(n_parts);
+                    float current_file_progress = float(size_t(fin.tellg()) - file_offset) / float(file_size - file_offset);
+                    float current_progress = (float(i) + current_file_progress) / float(n_parts);
                     progress_callback(current_progress, progress_callback_user_data);
                 }
                 if (model.n_loaded % 8 == 0) {
@@ -922,7 +922,7 @@ static bool llama_eval_internal(
             struct ggml_tensor * KQ_scaled =
                 ggml_scale(ctx0,
                         KQ,
-                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)));
+                        ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head)));
 
             // KQ_masked = mask_past(KQ_scaled)
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
@@ -1240,12 +1240,12 @@ static std::vector<llama_vocab::id> llama_tokenize(const llama_vocab & vocab, co
 // sampling
 //
 
-static void sample_top_k(std::vector<std::pair<double, llama_vocab::id>> & logits_id, int top_k) {
+static void sample_top_k(std::vector<std::pair<float, llama_vocab::id>> & logits_id, int top_k) {
     // find the top k tokens
     std::partial_sort(
             logits_id.begin(),
             logits_id.begin() + top_k, logits_id.end(),
-            [](const std::pair<double, llama_vocab::id> & a, const std::pair<double, llama_vocab::id> & b) {
+            [](const std::pair<float, llama_vocab::id> & a, const std::pair<float, llama_vocab::id> & b) {
         return a.first > b.first;
     });
 
@@ -1256,9 +1256,9 @@ static llama_vocab::id llama_sample_top_p_top_k(
         llama_context & lctx,
         const std::vector<llama_vocab::id> & last_n_tokens,
         int top_k,
-        double top_p,
-        double temp,
-        double repeat_penalty) {
+        float top_p,
+        float temp,
+        float repeat_penalty) {
     auto & rng = lctx.rng;
 
     const int n_logits = lctx.model.hparams.n_vocab;
@@ -1266,17 +1266,17 @@ static llama_vocab::id llama_sample_top_p_top_k(
     const auto & logits = lctx.logits;
     const auto * plogits = logits.data() + logits.size() - n_logits;
 
-    std::vector<std::pair<double, llama_vocab::id>> logits_id;
+    std::vector<std::pair<float, llama_vocab::id>> logits_id;
     logits_id.reserve(n_logits);
 
     {
-        const double scale = 1.0/temp;
+        const float scale = 1.0f/temp;
         for (int i = 0; i < n_logits; ++i) {
             // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
             // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
             if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
                 // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
-                if (plogits[i] < 0.0) {
+                if (plogits[i] < 0.0f) {
                     logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
                 } else {
                     logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
@@ -1289,18 +1289,18 @@ static llama_vocab::id llama_sample_top_p_top_k(
 
     sample_top_k(logits_id, top_k);
 
-    double maxl = -std::numeric_limits<double>::infinity();
+    float maxl = -std::numeric_limits<float>::infinity();
     for (const auto & kv : logits_id) {
         maxl = std::max(maxl, kv.first);
     }
 
     // compute probs for the top k tokens
-    std::vector<double> probs;
+    std::vector<float> probs;
     probs.reserve(logits_id.size());
 
     double sum = 0.0;
     for (const auto & kv : logits_id) {
-        double p = exp(kv.first - maxl);
+        const float p = expf(kv.first - maxl);
         probs.push_back(p);
         sum += p;
     }
@@ -1310,8 +1310,8 @@ static llama_vocab::id llama_sample_top_p_top_k(
         p /= sum;
     }
 
-    if (top_p < 1.0f) {
-        double cumsum = 0.0f;
+    if (top_p < 1.0) {
+        double cumsum = 0.0;
         for (int i = 0; i < (int) probs.size(); i++) {
             cumsum += probs[i];
             if (cumsum >= top_p) {
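Note that the running totals (sum, cumsum) stay double even though the per-token probabilities are now float: summing thousands of small terms is where precision is most easily lost. The nucleus (top-p) cut-off itself is simple; a plain-C sketch of the idea, assuming probabilities already sorted in descending order and normalized:

    /* Keep the smallest prefix whose cumulative mass reaches top_p;
       returns the number of candidates to keep. */
    int truncate_top_p(const float * probs, int n, float top_p) {
        double cumsum = 0.0;           /* running sum in double, as above */
        for (int i = 0; i < n; i++) {
            cumsum += probs[i];
            if (cumsum >= top_p) {
                return i + 1;
            }
        }
        return n;
    }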
| @@ -1590,7 +1590,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s | |||||||
|                 } |                 } | ||||||
|  |  | ||||||
|                 for (int i = 0; i < (int) hist_cur.size(); ++i) { |                 for (int i = 0; i < (int) hist_cur.size(); ++i) { | ||||||
|                     printf("%5.3f ", hist_cur[i] / (float)nelements); |                     printf("%5.3f ", hist_cur[i] / float(nelements)); | ||||||
|                 } |                 } | ||||||
|                 printf("\n"); |                 printf("\n"); | ||||||
|             } else { |             } else { | ||||||
| @@ -1613,7 +1613,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s | |||||||
|  |  | ||||||
|             printf("%s: hist: ", __func__); |             printf("%s: hist: ", __func__); | ||||||
|             for (int i = 0; i < (int) hist_all.size(); ++i) { |             for (int i = 0; i < (int) hist_all.size(); ++i) { | ||||||
|                 printf("%5.3f ", hist_all[i] / (float)sum_all); |                 printf("%5.3f ", hist_all[i] / float(sum_all)); | ||||||
|             } |             } | ||||||
|             printf("\n"); |             printf("\n"); | ||||||
|         } |         } | ||||||
| @@ -1795,9 +1795,9 @@ llama_token llama_sample_top_p_top_k( | |||||||
|       const llama_token * last_n_tokens_data, |       const llama_token * last_n_tokens_data, | ||||||
|                     int   last_n_tokens_size, |                     int   last_n_tokens_size, | ||||||
|                     int   top_k, |                     int   top_k, | ||||||
|                  double   top_p, |                   float   top_p, | ||||||
|                  double   temp, |                   float   temp, | ||||||
|                  double   repeat_penalty) { |                   float   repeat_penalty) { | ||||||
|     const int64_t t_start_sample_us = ggml_time_us(); |     const int64_t t_start_sample_us = ggml_time_us(); | ||||||
|  |  | ||||||
|     llama_token result = 0; |     llama_token result = 0; | ||||||
| @@ -1828,11 +1828,11 @@ void llama_print_timings(struct llama_context * ctx) { | |||||||
|     const int32_t n_p_eval = std::max(1, ctx->n_p_eval); |     const int32_t n_p_eval = std::max(1, ctx->n_p_eval); | ||||||
|  |  | ||||||
|     fprintf(stderr, "\n"); |     fprintf(stderr, "\n"); | ||||||
|     fprintf(stderr, "%s:        load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0f); |     fprintf(stderr, "%s:        load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0); | ||||||
|     fprintf(stderr, "%s:      sample time = %8.2f ms / %5d runs   (%8.2f ms per run)\n",   __func__, 1e-3f * ctx->t_sample_us, n_sample, 1e-3f * ctx->t_sample_us / n_sample); |     fprintf(stderr, "%s:      sample time = %8.2f ms / %5d runs   (%8.2f ms per run)\n",   __func__, 1e-3 * ctx->t_sample_us, n_sample, 1e-3 * ctx->t_sample_us / n_sample); | ||||||
|     fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token)\n", __func__, 1e-3f * ctx->t_p_eval_us, n_p_eval, 1e-3f * ctx->t_p_eval_us / n_p_eval); |     fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token)\n", __func__, 1e-3 * ctx->t_p_eval_us, n_p_eval, 1e-3 * ctx->t_p_eval_us / n_p_eval); | ||||||
|     fprintf(stderr, "%s:        eval time = %8.2f ms / %5d runs   (%8.2f ms per run)\n",   __func__, 1e-3f * ctx->t_eval_us,   n_eval,   1e-3f * ctx->t_eval_us   / n_eval); |     fprintf(stderr, "%s:        eval time = %8.2f ms / %5d runs   (%8.2f ms per run)\n",   __func__, 1e-3 * ctx->t_eval_us,   n_eval,   1e-3 * ctx->t_eval_us   / n_eval); | ||||||
|     fprintf(stderr, "%s:       total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0f); |     fprintf(stderr, "%s:       total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0); | ||||||
| } | } | ||||||
|  |  | ||||||
| void llama_reset_timings(struct llama_context * ctx) { | void llama_reset_timings(struct llama_context * ctx) { | ||||||
|   | |||||||
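The sampling hunk above illustrates the commit's overall approach: per-element math stays in float (maxl, the probs vector, expf), and only the running sum used to normalize the probabilities remains a double accumulator. Below is a minimal standalone sketch of that pattern; the function and variable names are hypothetical and the code is not taken from the project's sources.

// Hypothetical sketch of the float-math / double-accumulator pattern.
#include <math.h>
#include <stdio.h>

static void softmax_inplace(float * v, int n) {
    float maxv = -INFINITY;
    for (int i = 0; i < n; i++) {
        maxv = fmaxf(maxv, v[i]);
    }

    double sum = 0.0;                        // only the accumulator is double
    for (int i = 0; i < n; i++) {
        const float p = expf(v[i] - maxv);   // per-element math stays in float
        v[i] = p;
        sum += (double)p;                    // widen explicitly for the long sum
    }

    for (int i = 0; i < n; i++) {
        v[i] = (float)(v[i] / sum);          // narrow back to float explicitly
    }
}

int main(void) {
    float logits[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    softmax_inplace(logits, 4);
    for (int i = 0; i < 4; i++) {
        printf("%.6f\n", (double)logits[i]); // explicit cast, as varargs would promote anyway
    }
    return 0;
}

Built with the warning flags this commit enables (in particular -Wdouble-promotion), a sketch like this should compile without promotion warnings; the explicit casts mark the two places where a width change is actually intended.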
							
								
								
									
llama.h: 8 changed lines
							| @@ -45,7 +45,7 @@ extern "C" { | |||||||
|  |  | ||||||
|     } llama_token_data; |     } llama_token_data; | ||||||
|  |  | ||||||
|     typedef void (*llama_progress_callback)(double progress, void *ctx); |     typedef void (*llama_progress_callback)(float progress, void *ctx); | ||||||
|  |  | ||||||
|     struct llama_context_params { |     struct llama_context_params { | ||||||
|         int n_ctx;   // text context |         int n_ctx;   // text context | ||||||
| @@ -134,9 +134,9 @@ extern "C" { | |||||||
|           const llama_token * last_n_tokens_data, |           const llama_token * last_n_tokens_data, | ||||||
|                         int   last_n_tokens_size, |                         int   last_n_tokens_size, | ||||||
|                         int   top_k, |                         int   top_k, | ||||||
|                      double   top_p, |                       float   top_p, | ||||||
|                      double   temp, |                       float   temp, | ||||||
|                      double   repeat_penalty); |                       float   repeat_penalty); | ||||||
|  |  | ||||||
|     // Performance information |     // Performance information | ||||||
|     LLAMA_API void llama_print_timings(struct llama_context * ctx); |     LLAMA_API void llama_print_timings(struct llama_context * ctx); | ||||||
|   | |||||||
| @@ -5,5 +5,6 @@ function(llama_add_test source) | |||||||
|     add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN}) |     add_test(NAME ${TEST_TARGET} COMMAND $<TARGET_FILE:${TEST_TARGET}> ${ARGN}) | ||||||
| endfunction() | endfunction() | ||||||
|  |  | ||||||
|  | # llama_add_test(test-double-float.c) # SLOW | ||||||
| llama_add_test(test-quantize.c) | llama_add_test(test-quantize.c) | ||||||
| llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin) | llama_add_test(test-tokenizer-0.cpp ${CMAKE_CURRENT_SOURCE_DIR}/../models/ggml-vocab.bin) | ||||||
|   | |||||||
							
								
								
									
tests/test-double-float.c: 53 lines (new file)
							| @@ -0,0 +1,53 @@ | |||||||
|  | // These tests may take a long time! | ||||||
|  | // They are to prove that conversion from double to float of various functions in ggml.c doesn't affect the result. | ||||||
|  | // This is done by checking all finite (non-NaN, non-infinite) floats. | ||||||
|  |  | ||||||
|  | #undef NDEBUG | ||||||
|  | #include <assert.h> | ||||||
|  | #include <immintrin.h> | ||||||
|  | #include <math.h> | ||||||
|  | #include <stdint.h> | ||||||
|  |  | ||||||
|  | #pragma GCC diagnostic push | ||||||
|  | #pragma GCC diagnostic ignored "-Wdouble-promotion" | ||||||
|  |  | ||||||
|  | // ggml.c::quantize_row_q4_0_reference | ||||||
|  | inline static uint8_t round_orig(float v0) { return ((int8_t) (round(v0))) + 8; } | ||||||
|  |  | ||||||
|  | // ggml.c::ggml_silu_f32 | ||||||
|  | inline static float silu_orig(float x) { | ||||||
|  |     return x/(1.0 + exp(-x)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #pragma GCC diagnostic pop | ||||||
|  |  | ||||||
|  | // ggml.c::quantize_row_q4_0_reference | ||||||
|  | inline static uint8_t round_float(float v0) { return (int8_t)roundf(v0) + 8; } | ||||||
|  |  | ||||||
|  | // ggml.c::ggml_silu_f32 | ||||||
|  | inline static float silu_float(float x) { | ||||||
|  |     return x/(1.0f + expf(-x)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | int main(void) { | ||||||
|  |     uint32_t x = UINT32_MAX; | ||||||
|  |     do { | ||||||
|  |         float f = *(float *)&x; | ||||||
|  |         assert(!isfinite(f) || (round_orig(f) == round_float(f))); | ||||||
|  |     } while (x--); | ||||||
|  |  | ||||||
|  | #ifdef __F16C__ | ||||||
|  |     // GELU and SILU implementations are used with a FP16 lookup table. | ||||||
|  |     // The original and float-only results are not equal for all inputs after converting to FP16. | ||||||
|  |     // GELU is an approximation anyway (tanh), not tested here. | ||||||
|  |     // For SILU, verify that the results are at least the closest floating point numbers, if the FP16 values don't match. | ||||||
|  |     for (x = 0; x <= UINT16_MAX; x++) { | ||||||
|  |         float f = _cvtsh_ss(x); | ||||||
|  |         const float so = silu_orig(f); | ||||||
|  |         const float sf = silu_float(f); | ||||||
|  |         assert(   (_cvtss_sh(so, 0) == _cvtss_sh(sf, 0)) | ||||||
|  |                || (nextafterf(so, sf) == sf) | ||||||
|  |                || (nextafterf(sf, so) == so)); | ||||||
|  |     } | ||||||
|  | #endif | ||||||
|  | } | ||||||
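For context (this example is not part of the commit): -Wdouble-promotion flags the silent float-to-double widening that the *_orig variants above rely on, which is why the test wraps them in the diagnostic pragmas. A hypothetical before/after of the same SILU expression:

// Hypothetical illustration of what -Wdouble-promotion reports; not from ggml.c.
#include <math.h>
#include <stdio.h>

static float silu_promoted(float x) {
    return x / (1.0 + exp(-x));     // warns: x is implicitly widened to double
}

static float silu_float_only(float x) {
    return x / (1.0f + expf(-x));   // all-float math, no promotion, no warning
}

int main(void) {
    const float x = 1.0f;
    // The two results may differ in the last bit; tests/test-double-float.c above
    // checks (for all FP16 inputs) that they are identical after FP16 conversion
    // or at worst adjacent floats.
    printf("%.9g %.9g\n", (double)silu_promoted(x), (double)silu_float_only(x));
    return 0;
}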
Stephan Walter