Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)
	cpu: move all the operators into a separate c++ file (except mul_mat) (ggml/1167)
* cpu: refactor SIMD mappings and vectorized op functions into separate files
* Fix warning for ggml_float to float
* Fix warnings
* cpu: move all the operations (except mul_mat) to a separate c++ file
* fix whitespace
* Update ggml/src/ggml-cpu/vec.h

Co-authored-by: Diego Devesa <slarengh@gmail.com>

* Fix PR comments - use GGML_UNUSED, use cassert in ops.cpp
* Reverse the order of import for ops.h and vec.h, to match what was present in ggml-cpu.c previously

---------

Co-authored-by: Diego Devesa <slarengh@gmail.com>
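The "use GGML_UNUSED, use cassert in ops.cpp" review fix refers to a common ggml pattern: parameters that are only read by an assert get explicitly voided so release (NDEBUG) builds do not warn about them, and the C++ translation unit pulls in assert() through <cassert>. A minimal, self-contained sketch of that pattern, assuming GGML_UNUSED expands to a (void) cast as in ggml.h (dot_first_row is a hypothetical helper, not part of this commit):

#include <cassert>                 // C++ header providing assert(), as used by ops.cpp

#define GGML_UNUSED(x) (void)(x)   // assumed definition, matching ggml.h

// Hypothetical helper: assert a precondition, then mark the parameter that is
// only read by the assert as intentionally unused so -Wunused-parameter stays
// quiet when NDEBUG strips the assert away.
static float dot_first_row(const float * x, const float * y, int n, int nrc) {
    assert(nrc == 1);
    GGML_UNUSED(nrc);

    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += x[i]*y[i];
    }
    return sum;
}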
@@ -28,6 +28,11 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
         ggml-cpu/binary-ops.cpp
         ggml-cpu/unary-ops.h
         ggml-cpu/unary-ops.cpp
+        ggml-cpu/simd-mappings.h
+        ggml-cpu/vec.h
+        ggml-cpu/vec.cpp
+        ggml-cpu/ops.h
+        ggml-cpu/ops.cpp
         )
 
     target_compile_features(${GGML_CPU_NAME} PRIVATE c_std_11 cxx_std_17)
										
											
ggml/src/ggml-cpu/ops.cpp: new file, 8719 lines. File diff suppressed because it is too large.
							
								
								
									
ggml/src/ggml-cpu/ops.h: new file, 128 lines.
							| @@ -0,0 +1,128 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include "ggml.h" | ||||
|  | ||||
| // | ||||
| // cache line | ||||
| // | ||||
|  | ||||
| #if defined(__cpp_lib_hardware_interference_size) | ||||
| #define CACHE_LINE_SIZE std::hardware_destructive_interference_size | ||||
| #else | ||||
| #if defined(__POWER9_VECTOR__) | ||||
| #define CACHE_LINE_SIZE 128 | ||||
| #elif defined(__VXE__) || defined(__VXE2__) | ||||
| #define CACHE_LINE_SIZE 256 | ||||
| #else | ||||
| #define CACHE_LINE_SIZE 64 | ||||
| #endif | ||||
| #endif | ||||
|  | ||||
| static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float); | ||||
|  | ||||
| #ifdef __cplusplus | ||||
| extern "C" { | ||||
| #endif | ||||
|  | ||||
| void ggml_compute_forward_dup(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_add(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_add1(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_acc(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_sum(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_sum_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_mean(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_argmax(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_count_equal(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_repeat(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_repeat_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_concat(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_silu_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_rms_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_rms_norm_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_group_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_l2_norm(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_out_prod(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_scale(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_set(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_cpy(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_cont(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_reshape(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_view(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_permute(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_soft_max(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_soft_max_ext_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_rope(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_rope_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_pool_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_pool_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_pool_2d_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_upscale(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_pad(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_pad_reflect_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_flash_attn_ext( | ||||
|     const struct ggml_compute_params * params, | ||||
|     const struct ggml_tensor * q, | ||||
|     const struct ggml_tensor * k, | ||||
|     const struct ggml_tensor * v, | ||||
|     const struct ggml_tensor * mask, | ||||
|     struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_flash_attn_back( | ||||
|         const struct ggml_compute_params * params, | ||||
|         const bool masked, | ||||
|         struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_ssm_conv(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_rwkv_wkv7(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_gla(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_map_unary( | ||||
|     const struct ggml_compute_params * params, | ||||
|     struct ggml_tensor * dst, | ||||
|     const ggml_unary_op_f32_t fun); | ||||
| void ggml_compute_forward_map_binary( | ||||
|     const struct ggml_compute_params * params, | ||||
|     struct ggml_tensor * dst, | ||||
|     const ggml_binary_op_f32_t fun); | ||||
| void ggml_compute_forward_map_custom1_f32( | ||||
|     const struct ggml_compute_params * params, | ||||
|     struct ggml_tensor * dst, | ||||
|     const ggml_custom1_op_f32_t fun); | ||||
| void ggml_compute_forward_map_custom2_f32( | ||||
|     const struct ggml_compute_params * params, | ||||
|     struct ggml_tensor * dst, | ||||
|     const ggml_custom2_op_f32_t fun); | ||||
| void ggml_compute_forward_map_custom3_f32( | ||||
|     const struct ggml_compute_params * params, | ||||
|     struct ggml_tensor * dst, | ||||
|     const ggml_custom3_op_f32_t fun); | ||||
| void ggml_compute_forward_map_custom1(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_map_custom2(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_map_custom3(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
| void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst); | ||||
|  | ||||
| #ifdef __cplusplus | ||||
| } | ||||
| #endif | ||||
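All of the kernels declared above share the same (params, dst) calling convention, with inputs read from dst->src[...], which keeps the move out of ggml-cpu.c largely mechanical. A hedged sketch of how such a kernel table gets driven (example_dispatch and the shortened case list are illustrative only, not the actual ggml-cpu.c dispatcher; the ggml-cpu-impl.h include is an assumption about where struct ggml_compute_params is defined):

#include "ggml.h"
#include "ggml-cpu-impl.h"   // assumption: internal header defining struct ggml_compute_params
#include "ops.h"

// Illustrative only: route a graph node to the kernel declared in ops.h by
// switching on the node's op type.
static void example_dispatch(const struct ggml_compute_params * params, struct ggml_tensor * dst) {
    switch (dst->op) {
        case GGML_OP_DUP:   ggml_compute_forward_dup  (params, dst); break;
        case GGML_OP_ADD:   ggml_compute_forward_add  (params, dst); break;
        case GGML_OP_SCALE: ggml_compute_forward_scale(params, dst); break;
        // ... one case per operator declared above
        default: GGML_ABORT("unsupported op");
    }
}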
							
								
								
									
ggml/src/ggml-cpu/simd-mappings.h: new file, 884 lines.
							| @@ -0,0 +1,884 @@ | ||||
| #pragma once | ||||
|  | ||||
| #include "ggml-cpu-impl.h" | ||||
|  | ||||
| // | ||||
| // simd mappings | ||||
| // | ||||
|  | ||||
| // we define a common set of C macros which map to specific intrinsics based on the current architecture | ||||
| // we then implement the fundamental computation operations below using only these macros | ||||
| // adding support for new architectures requires defining the corresponding SIMD macros | ||||
| // | ||||
| // GGML_F32_STEP / GGML_F16_STEP | ||||
| //   number of elements to process in a single step | ||||
| // | ||||
| // GGML_F32_EPR / GGML_F16_EPR | ||||
| //   number of elements to fit in a single register | ||||
| // | ||||
|  | ||||
| #if defined(__ARM_NEON) && defined(__ARM_FEATURE_FMA) | ||||
|  | ||||
| #define GGML_SIMD | ||||
|  | ||||
| // F32 NEON | ||||
|  | ||||
| #define GGML_F32_STEP 16 | ||||
| #define GGML_F32_EPR  4 | ||||
|  | ||||
| #define GGML_F32x4              float32x4_t | ||||
| #define GGML_F32x4_ZERO         vdupq_n_f32(0.0f) | ||||
| #define GGML_F32x4_SET1(x)      vdupq_n_f32(x) | ||||
| #define GGML_F32x4_LOAD         vld1q_f32 | ||||
| #define GGML_F32x4_STORE        vst1q_f32 | ||||
| #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) | ||||
| #define GGML_F32x4_ADD          vaddq_f32 | ||||
| #define GGML_F32x4_MUL          vmulq_f32 | ||||
| #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) | ||||
| #define GGML_F32x4_REDUCE(res, x)                       \ | ||||
| {                                                       \ | ||||
|     int offset = GGML_F32_ARR >> 1;                     \ | ||||
|     for (int i = 0; i < offset; ++i) {                  \ | ||||
|         (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \ | ||||
|     }                                                   \ | ||||
|     offset >>= 1;                                       \ | ||||
|     for (int i = 0; i < offset; ++i) {                  \ | ||||
|         (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \ | ||||
|     }                                                   \ | ||||
|     offset >>= 1;                                       \ | ||||
|     for (int i = 0; i < offset; ++i) {                  \ | ||||
|         (x)[i] = vaddq_f32((x)[i], (x)[offset+i]);      \ | ||||
|     }                                                   \ | ||||
|     (res) = (ggml_float) GGML_F32x4_REDUCE_ONE((x)[0]); \ | ||||
| } | ||||
|  | ||||
| #define GGML_F32_VEC        GGML_F32x4 | ||||
| #define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO | ||||
| #define GGML_F32_VEC_SET1   GGML_F32x4_SET1 | ||||
| #define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD | ||||
| #define GGML_F32_VEC_STORE  GGML_F32x4_STORE | ||||
| #define GGML_F32_VEC_FMA    GGML_F32x4_FMA | ||||
| #define GGML_F32_VEC_ADD    GGML_F32x4_ADD | ||||
| #define GGML_F32_VEC_MUL    GGML_F32x4_MUL | ||||
| #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE | ||||
|  | ||||
| // F16 NEON | ||||
|  | ||||
| #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) | ||||
|     #define GGML_F16_STEP 32 | ||||
|     #define GGML_F16_EPR  8 | ||||
|  | ||||
|     #define GGML_F16x8              float16x8_t | ||||
|     #define GGML_F16x8_ZERO         vdupq_n_f16(0.0f) | ||||
|     #define GGML_F16x8_SET1(x)      vdupq_n_f16(x) | ||||
|     #define GGML_F16x8_LOAD(x)      vld1q_f16((const ggml_fp16_internal_t *)(x)) | ||||
|     #define GGML_F16x8_STORE        vst1q_f16 | ||||
|     #define GGML_F16x8_FMA(a, b, c) vfmaq_f16(a, b, c) | ||||
|     #define GGML_F16x8_ADD          vaddq_f16 | ||||
|     #define GGML_F16x8_MUL          vmulq_f16 | ||||
|     #define GGML_F16x8_REDUCE(res, x)                               \ | ||||
|     do {                                                            \ | ||||
|         int offset = GGML_F16_ARR >> 1;                             \ | ||||
|         for (int i = 0; i < offset; ++i) {                          \ | ||||
|             (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \ | ||||
|         }                                                           \ | ||||
|         offset >>= 1;                                               \ | ||||
|         for (int i = 0; i < offset; ++i) {                          \ | ||||
|             (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \ | ||||
|         }                                                           \ | ||||
|         offset >>= 1;                                               \ | ||||
|         for (int i = 0; i < offset; ++i) {                          \ | ||||
|             (x)[i] = vaddq_f16((x)[i], (x)[offset+i]);              \ | ||||
|         }                                                           \ | ||||
|         const float32x4_t t0 = vcvt_f32_f16(vget_low_f16 ((x)[0])); \ | ||||
|         const float32x4_t t1 = vcvt_f32_f16(vget_high_f16((x)[0])); \ | ||||
|         (res) = (ggml_float) vaddvq_f32(vaddq_f32(t0, t1));         \ | ||||
|     } while (0) | ||||
|  | ||||
|     #define GGML_F16_VEC                GGML_F16x8 | ||||
|     #define GGML_F16_VEC_ZERO           GGML_F16x8_ZERO | ||||
|     #define GGML_F16_VEC_SET1           GGML_F16x8_SET1 | ||||
|     #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x8_LOAD(p) | ||||
|     #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((ggml_fp16_internal_t *)(p), (r)[i]) | ||||
|     #define GGML_F16_VEC_FMA            GGML_F16x8_FMA | ||||
|     #define GGML_F16_VEC_ADD            GGML_F16x8_ADD | ||||
|     #define GGML_F16_VEC_MUL            GGML_F16x8_MUL | ||||
|     #define GGML_F16_VEC_REDUCE         GGML_F16x8_REDUCE | ||||
| #else | ||||
|     // if FP16 vector arithmetic is not supported, we use FP32 instead | ||||
|     // and take advantage of the vcvt_ functions to convert to/from FP16 | ||||
|  | ||||
|     #define GGML_F16_STEP 16 | ||||
|     #define GGML_F16_EPR  4 | ||||
|  | ||||
|     #define GGML_F32Cx4              float32x4_t | ||||
|     #define GGML_F32Cx4_ZERO         vdupq_n_f32(0.0f) | ||||
|     #define GGML_F32Cx4_SET1(x)      vdupq_n_f32(x) | ||||
|     #define GGML_F32Cx4_LOAD(x)      vcvt_f32_f16(vld1_f16((const ggml_fp16_internal_t *)(x))) | ||||
|     #define GGML_F32Cx4_STORE(x, y)  vst1_f16(x, vcvt_f16_f32(y)) | ||||
|     #define GGML_F32Cx4_FMA(a, b, c) vfmaq_f32(a, b, c) | ||||
|     #define GGML_F32Cx4_ADD          vaddq_f32 | ||||
|     #define GGML_F32Cx4_MUL          vmulq_f32 | ||||
|     #define GGML_F32Cx4_REDUCE       GGML_F32x4_REDUCE | ||||
|  | ||||
|     #define GGML_F16_VEC                GGML_F32Cx4 | ||||
|     #define GGML_F16_VEC_ZERO           GGML_F32Cx4_ZERO | ||||
|     #define GGML_F16_VEC_SET1           GGML_F32Cx4_SET1 | ||||
|     #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx4_LOAD(p) | ||||
|     #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((ggml_fp16_internal_t *)(p), r[i]) | ||||
|     #define GGML_F16_VEC_FMA            GGML_F32Cx4_FMA | ||||
|     #define GGML_F16_VEC_ADD            GGML_F32Cx4_ADD | ||||
|     #define GGML_F16_VEC_MUL            GGML_F32Cx4_MUL | ||||
|     #define GGML_F16_VEC_REDUCE         GGML_F32Cx4_REDUCE | ||||
| #endif | ||||
|  | ||||
| #elif defined(__AVX512F__) | ||||
|  | ||||
| #define GGML_SIMD | ||||
|  | ||||
| // F32 AVX512 | ||||
|  | ||||
| #define GGML_F32_STEP 64 | ||||
| #define GGML_F32_EPR  16 | ||||
|  | ||||
| #define GGML_F32x16         __m512 | ||||
| #define GGML_F32x16_ZERO    _mm512_setzero_ps() | ||||
| #define GGML_F32x16_SET1(x) _mm512_set1_ps(x) | ||||
| #define GGML_F32x16_LOAD    _mm512_loadu_ps | ||||
| #define GGML_F32x16_STORE   _mm512_storeu_ps | ||||
| // _mm512_fmadd_ps is defined in AVX512F so no guard is required | ||||
| #define GGML_F32x16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) | ||||
| #define GGML_F32x16_ADD     _mm512_add_ps | ||||
| #define GGML_F32x16_MUL     _mm512_mul_ps | ||||
| #define GGML_F32x16_REDUCE(res, x)                                    \ | ||||
| do {                                                                  \ | ||||
|     int offset = GGML_F32_ARR >> 1;                                   \ | ||||
|     for (int i = 0; i < offset; ++i) {                                \ | ||||
|         x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \ | ||||
|     }                                                                 \ | ||||
|     offset >>= 1;                                                     \ | ||||
|     for (int i = 0; i < offset; ++i) {                                \ | ||||
|         x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \ | ||||
|     }                                                                 \ | ||||
|     offset >>= 1;                                                     \ | ||||
|     for (int i = 0; i < offset; ++i) {                                \ | ||||
|         x[i] = _mm512_add_ps(x[i], x[offset+i]);                      \ | ||||
|     }                                                                 \ | ||||
|     res = (ggml_float) _mm512_reduce_add_ps(x[0]);                    \ | ||||
| } while (0) | ||||
|  | ||||
| // TODO: is this optimal ? | ||||
|  | ||||
| #define GGML_F32_VEC        GGML_F32x16 | ||||
| #define GGML_F32_VEC_ZERO   GGML_F32x16_ZERO | ||||
| #define GGML_F32_VEC_SET1   GGML_F32x16_SET1 | ||||
| #define GGML_F32_VEC_LOAD   GGML_F32x16_LOAD | ||||
| #define GGML_F32_VEC_STORE  GGML_F32x16_STORE | ||||
| #define GGML_F32_VEC_FMA    GGML_F32x16_FMA | ||||
| #define GGML_F32_VEC_ADD    GGML_F32x16_ADD | ||||
| #define GGML_F32_VEC_MUL    GGML_F32x16_MUL | ||||
| #define GGML_F32_VEC_REDUCE GGML_F32x16_REDUCE | ||||
|  | ||||
| // F16 AVX512 | ||||
|  | ||||
| // F16 AVX | ||||
|  | ||||
| #define GGML_F16_STEP 64 | ||||
| #define GGML_F16_EPR  16 | ||||
|  | ||||
| // AVX512 has FP16 extension (AVX512_FP16) but I don't have it on my machine so I use FP32 instead | ||||
|  | ||||
| #define GGML_F32Cx16             __m512 | ||||
| #define GGML_F32Cx16_ZERO        _mm512_setzero_ps() | ||||
| #define GGML_F32Cx16_SET1(x)     _mm512_set1_ps(x) | ||||
|  | ||||
| // unlike  _mm256_cvt intrinsics that require F16C, _mm512_cvt is defined in AVX512F | ||||
| // so F16C guard isn't required | ||||
| #define GGML_F32Cx16_LOAD(x)     _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x))) | ||||
| #define GGML_F32Cx16_STORE(x, y) _mm256_storeu_si256((__m256i *)(x), _mm512_cvtps_ph(y, 0)) | ||||
|  | ||||
| #define GGML_F32Cx16_FMA(a, b, c) _mm512_fmadd_ps(b, c, a) | ||||
| #define GGML_F32Cx16_ADD         _mm512_add_ps | ||||
| #define GGML_F32Cx16_MUL         _mm512_mul_ps | ||||
| #define GGML_F32Cx16_REDUCE(res, x)                               \ | ||||
| do {                                                              \ | ||||
|     int offset = GGML_F32_ARR >> 1;                               \ | ||||
|     for (int i = 0; i < offset; ++i) {                            \ | ||||
|         x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \ | ||||
|     }                                                             \ | ||||
|     offset >>= 1;                                                 \ | ||||
|     for (int i = 0; i < offset; ++i) {                            \ | ||||
|         x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \ | ||||
|     }                                                             \ | ||||
|     offset >>= 1;                                                 \ | ||||
|     for (int i = 0; i < offset; ++i) {                            \ | ||||
|         x[i] = _mm512_add_ps(x[i], x[offset+i]);                  \ | ||||
|     }                                                             \ | ||||
|     res = (ggml_float) _mm512_reduce_add_ps(x[0]);                \ | ||||
| } while (0) | ||||
|  | ||||
| #define GGML_F16_VEC                GGML_F32Cx16 | ||||
| #define GGML_F16_VEC_ZERO           GGML_F32Cx16_ZERO | ||||
| #define GGML_F16_VEC_SET1           GGML_F32Cx16_SET1 | ||||
| #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx16_LOAD(p) | ||||
| #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx16_STORE(p, r[i]) | ||||
| #define GGML_F16_VEC_FMA            GGML_F32Cx16_FMA | ||||
| #define GGML_F16_VEC_ADD            GGML_F32Cx16_ADD | ||||
| #define GGML_F16_VEC_MUL            GGML_F32Cx16_MUL | ||||
|  | ||||
| #define GGML_F16_VEC_REDUCE         GGML_F32Cx16_REDUCE | ||||
| #elif defined(__AVX__) | ||||
|  | ||||
| #define GGML_SIMD | ||||
|  | ||||
| // F32 AVX | ||||
|  | ||||
| #define GGML_F32_STEP 32 | ||||
| #define GGML_F32_EPR  8 | ||||
|  | ||||
| #define GGML_F32x8         __m256 | ||||
| #define GGML_F32x8_ZERO    _mm256_setzero_ps() | ||||
| #define GGML_F32x8_SET1(x) _mm256_set1_ps(x) | ||||
| #define GGML_F32x8_LOAD    _mm256_loadu_ps | ||||
| #define GGML_F32x8_STORE   _mm256_storeu_ps | ||||
| #if defined(__FMA__) | ||||
|     #define GGML_F32x8_FMA(a, b, c) _mm256_fmadd_ps(b, c, a) | ||||
| #else | ||||
|     #define GGML_F32x8_FMA(a, b, c) _mm256_add_ps(_mm256_mul_ps(b, c), a) | ||||
| #endif | ||||
| #define GGML_F32x8_ADD     _mm256_add_ps | ||||
| #define GGML_F32x8_MUL     _mm256_mul_ps | ||||
| #define GGML_F32x8_REDUCE(res, x)                                 \ | ||||
| do {                                                              \ | ||||
|     int offset = GGML_F32_ARR >> 1;                               \ | ||||
|     for (int i = 0; i < offset; ++i) {                            \ | ||||
|         x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \ | ||||
|     }                                                             \ | ||||
|     offset >>= 1;                                                 \ | ||||
|     for (int i = 0; i < offset; ++i) {                            \ | ||||
|         x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \ | ||||
|     }                                                             \ | ||||
|     offset >>= 1;                                                 \ | ||||
|     for (int i = 0; i < offset; ++i) {                            \ | ||||
|         x[i] = _mm256_add_ps(x[i], x[offset+i]);                  \ | ||||
|     }                                                             \ | ||||
|     const __m128 t0 = _mm_add_ps(_mm256_castps256_ps128(x[0]),    \ | ||||
|                                  _mm256_extractf128_ps(x[0], 1)); \ | ||||
|     const __m128 t1 = _mm_hadd_ps(t0, t0);                        \ | ||||
|     res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t1, t1));        \ | ||||
| } while (0) | ||||
| // TODO: is this optimal ? | ||||
|  | ||||
| #define GGML_F32_VEC        GGML_F32x8 | ||||
| #define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO | ||||
| #define GGML_F32_VEC_SET1   GGML_F32x8_SET1 | ||||
| #define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD | ||||
| #define GGML_F32_VEC_STORE  GGML_F32x8_STORE | ||||
| #define GGML_F32_VEC_FMA    GGML_F32x8_FMA | ||||
| #define GGML_F32_VEC_ADD    GGML_F32x8_ADD | ||||
| #define GGML_F32_VEC_MUL    GGML_F32x8_MUL | ||||
| #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE | ||||
|  | ||||
| // F16 AVX | ||||
|  | ||||
| #define GGML_F16_STEP 32 | ||||
| #define GGML_F16_EPR  8 | ||||
|  | ||||
| // F16 arithmetic is not supported by AVX, so we use F32 instead | ||||
|  | ||||
| #define GGML_F32Cx8             __m256 | ||||
| #define GGML_F32Cx8_ZERO        _mm256_setzero_ps() | ||||
| #define GGML_F32Cx8_SET1(x)     _mm256_set1_ps(x) | ||||
|  | ||||
| #if defined(__F16C__) | ||||
| // the  _mm256_cvt intrinsics require F16C | ||||
| #define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) | ||||
| #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) | ||||
| #else | ||||
| static inline __m256 __avx_f32cx8_load(const ggml_fp16_t * x) { | ||||
|     float tmp[8]; | ||||
|  | ||||
|     for (int i = 0; i < 8; i++) { | ||||
|         tmp[i] = GGML_FP16_TO_FP32(x[i]); | ||||
|     } | ||||
|  | ||||
|     return _mm256_loadu_ps(tmp); | ||||
| } | ||||
| static inline void __avx_f32cx8_store(ggml_fp16_t *x, __m256 y) { | ||||
|     float arr[8]; | ||||
|  | ||||
|     _mm256_storeu_ps(arr, y); | ||||
|  | ||||
|     for (int i = 0; i < 8; i++) | ||||
|         x[i] = GGML_FP32_TO_FP16(arr[i]); | ||||
| } | ||||
| #define GGML_F32Cx8_LOAD(x)     __avx_f32cx8_load(x) | ||||
| #define GGML_F32Cx8_STORE(x, y) __avx_f32cx8_store(x, y) | ||||
| #endif | ||||
|  | ||||
| #define GGML_F32Cx8_FMA         GGML_F32x8_FMA | ||||
| #define GGML_F32Cx8_ADD         _mm256_add_ps | ||||
| #define GGML_F32Cx8_MUL         _mm256_mul_ps | ||||
| #define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE | ||||
|  | ||||
| #define GGML_F16_VEC                GGML_F32Cx8 | ||||
| #define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO | ||||
| #define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1 | ||||
| #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p) | ||||
| #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i]) | ||||
| #define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA | ||||
| #define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD | ||||
| #define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL | ||||
| #define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE | ||||
|  | ||||
| #elif defined(__POWER9_VECTOR__) | ||||
|  | ||||
| #define GGML_SIMD | ||||
|  | ||||
| // F32 POWER9 | ||||
|  | ||||
| #define GGML_F32_STEP 32 | ||||
| #define GGML_F32_EPR  4 | ||||
|  | ||||
| #define GGML_F32x4              vector float | ||||
| #define GGML_F32x4_ZERO         0.0f | ||||
| #define GGML_F32x4_SET1         vec_splats | ||||
| #define GGML_F32x4_LOAD(p)      vec_xl(0, p) | ||||
| #define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p) | ||||
| #define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) | ||||
| #define GGML_F32x4_ADD          vec_add | ||||
| #define GGML_F32x4_MUL          vec_mul | ||||
| #define GGML_F32x4_REDUCE(res, x)              \ | ||||
| {                                              \ | ||||
|     int offset = GGML_F32_ARR >> 1;            \ | ||||
|     for (int i = 0; i < offset; ++i) {         \ | ||||
|         x[i] = vec_add(x[i], x[offset+i]);     \ | ||||
|     }                                          \ | ||||
|     offset >>= 1;                              \ | ||||
|     for (int i = 0; i < offset; ++i) {         \ | ||||
|         x[i] = vec_add(x[i], x[offset+i]);     \ | ||||
|     }                                          \ | ||||
|     offset >>= 1;                              \ | ||||
|     for (int i = 0; i < offset; ++i) {         \ | ||||
|         x[i] = vec_add(x[i], x[offset+i]);     \ | ||||
|     }                                          \ | ||||
|     res = vec_extract(x[0], 0) +               \ | ||||
|           vec_extract(x[0], 1) +               \ | ||||
|           vec_extract(x[0], 2) +               \ | ||||
|           vec_extract(x[0], 3);                \ | ||||
| } | ||||
|  | ||||
| #define GGML_F32_VEC        GGML_F32x4 | ||||
| #define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO | ||||
| #define GGML_F32_VEC_SET1   GGML_F32x4_SET1 | ||||
| #define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD | ||||
| #define GGML_F32_VEC_STORE  GGML_F32x4_STORE | ||||
| #define GGML_F32_VEC_FMA    GGML_F32x4_FMA | ||||
| #define GGML_F32_VEC_ADD    GGML_F32x4_ADD | ||||
| #define GGML_F32_VEC_MUL    GGML_F32x4_MUL | ||||
| #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE | ||||
|  | ||||
| // F16 POWER9 | ||||
| #define GGML_F16_STEP       GGML_F32_STEP | ||||
| #define GGML_F16_EPR        GGML_F32_EPR | ||||
| #define GGML_F16_VEC        GGML_F32x4 | ||||
| #define GGML_F16_VEC_ZERO   GGML_F32x4_ZERO | ||||
| #define GGML_F16_VEC_SET1   GGML_F32x4_SET1 | ||||
| #define GGML_F16_VEC_FMA    GGML_F32x4_FMA | ||||
| #define GGML_F16_VEC_ADD    GGML_F32x4_ADD | ||||
| #define GGML_F16_VEC_MUL    GGML_F32x4_MUL | ||||
| #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE | ||||
| // Use vec_xl, not vec_ld, in case the load address is not aligned. | ||||
| #define GGML_F16_VEC_LOAD(p, i) (i & 0x1) ?                   \ | ||||
|   vec_extract_fp32_from_shorth(vec_xl(0, p - GGML_F16_EPR)) : \ | ||||
|   vec_extract_fp32_from_shortl(vec_xl(0, p)) | ||||
| #define GGML_ENDIAN_BYTE(i) ((unsigned char *)&(uint16_t){1})[i] | ||||
| #define GGML_F16_VEC_STORE(p, r, i)                             \ | ||||
|   if (i & 0x1)                                                  \ | ||||
|     vec_xst(vec_pack_to_short_fp32(r[i - GGML_ENDIAN_BYTE(1)],  \ | ||||
|                                    r[i - GGML_ENDIAN_BYTE(0)]), \ | ||||
|             0, p - GGML_F16_EPR) | ||||
|  | ||||
| #elif defined(__wasm_simd128__) | ||||
|  | ||||
| #define GGML_SIMD | ||||
|  | ||||
| // F32 WASM | ||||
|  | ||||
| #define GGML_F32_STEP 16 | ||||
| #define GGML_F32_EPR  4 | ||||
|  | ||||
| #define GGML_F32x4              v128_t | ||||
| #define GGML_F32x4_ZERO         wasm_f32x4_splat(0.0f) | ||||
| #define GGML_F32x4_SET1(x)      wasm_f32x4_splat(x) | ||||
| #define GGML_F32x4_LOAD         wasm_v128_load | ||||
| #define GGML_F32x4_STORE        wasm_v128_store | ||||
| #define GGML_F32x4_FMA(a, b, c) wasm_f32x4_add(wasm_f32x4_mul(b, c), a) | ||||
| #define GGML_F32x4_ADD          wasm_f32x4_add | ||||
| #define GGML_F32x4_MUL          wasm_f32x4_mul | ||||
| #define GGML_F32x4_REDUCE(res, x)                  \ | ||||
| {                                                  \ | ||||
|     int offset = GGML_F32_ARR >> 1;                \ | ||||
|     for (int i = 0; i < offset; ++i) {             \ | ||||
|         x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \ | ||||
|     }                                              \ | ||||
|     offset >>= 1;                                  \ | ||||
|     for (int i = 0; i < offset; ++i) {             \ | ||||
|         x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \ | ||||
|     }                                              \ | ||||
|     offset >>= 1;                                  \ | ||||
|     for (int i = 0; i < offset; ++i) {             \ | ||||
|         x[i] = wasm_f32x4_add(x[i], x[offset+i]);  \ | ||||
|     }                                              \ | ||||
|     res = wasm_f32x4_extract_lane(x[0], 0) +       \ | ||||
|           wasm_f32x4_extract_lane(x[0], 1) +       \ | ||||
|           wasm_f32x4_extract_lane(x[0], 2) +       \ | ||||
|           wasm_f32x4_extract_lane(x[0], 3);        \ | ||||
| } | ||||
|  | ||||
| #define GGML_F32_VEC        GGML_F32x4 | ||||
| #define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO | ||||
| #define GGML_F32_VEC_SET1   GGML_F32x4_SET1 | ||||
| #define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD | ||||
| #define GGML_F32_VEC_STORE  GGML_F32x4_STORE | ||||
| #define GGML_F32_VEC_FMA    GGML_F32x4_FMA | ||||
| #define GGML_F32_VEC_ADD    GGML_F32x4_ADD | ||||
| #define GGML_F32_VEC_MUL    GGML_F32x4_MUL | ||||
| #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE | ||||
|  | ||||
| // F16 WASM | ||||
|  | ||||
| #define GGML_F16_STEP 16 | ||||
| #define GGML_F16_EPR  4 | ||||
|  | ||||
| inline static v128_t __wasm_f16x4_load(const ggml_fp16_t * p) { | ||||
|     float tmp[4]; | ||||
|  | ||||
|     tmp[0] = GGML_FP16_TO_FP32(p[0]); | ||||
|     tmp[1] = GGML_FP16_TO_FP32(p[1]); | ||||
|     tmp[2] = GGML_FP16_TO_FP32(p[2]); | ||||
|     tmp[3] = GGML_FP16_TO_FP32(p[3]); | ||||
|  | ||||
|     return wasm_v128_load(tmp); | ||||
| } | ||||
|  | ||||
| inline static void __wasm_f16x4_store(ggml_fp16_t * p, v128_t x) { | ||||
|     float tmp[4]; | ||||
|  | ||||
|     wasm_v128_store(tmp, x); | ||||
|  | ||||
|     p[0] = GGML_FP32_TO_FP16(tmp[0]); | ||||
|     p[1] = GGML_FP32_TO_FP16(tmp[1]); | ||||
|     p[2] = GGML_FP32_TO_FP16(tmp[2]); | ||||
|     p[3] = GGML_FP32_TO_FP16(tmp[3]); | ||||
| } | ||||
|  | ||||
| #define GGML_F16x4             v128_t | ||||
| #define GGML_F16x4_ZERO        wasm_f32x4_splat(0.0f) | ||||
| #define GGML_F16x4_SET1(x)     wasm_f32x4_splat(x) | ||||
| #define GGML_F16x4_LOAD(x)     __wasm_f16x4_load(x) | ||||
| #define GGML_F16x4_STORE(x, y) __wasm_f16x4_store(x, y) | ||||
| #define GGML_F16x4_FMA         GGML_F32x4_FMA | ||||
| #define GGML_F16x4_ADD         wasm_f32x4_add | ||||
| #define GGML_F16x4_MUL         wasm_f32x4_mul | ||||
| #define GGML_F16x4_REDUCE(res, x)                           \ | ||||
| {                                                           \ | ||||
|     int offset = GGML_F16_ARR >> 1;                         \ | ||||
|     for (int i = 0; i < offset; ++i) {                      \ | ||||
|         x[i] = wasm_f32x4_add(x[i], x[offset+i]);           \ | ||||
|     }                                                       \ | ||||
|     offset >>= 1;                                           \ | ||||
|     for (int i = 0; i < offset; ++i) {                      \ | ||||
|         x[i] = wasm_f32x4_add(x[i], x[offset+i]);           \ | ||||
|     }                                                       \ | ||||
|     offset >>= 1;                                           \ | ||||
|     for (int i = 0; i < offset; ++i) {                      \ | ||||
|         x[i] = wasm_f32x4_add(x[i], x[offset+i]);           \ | ||||
|     }                                                       \ | ||||
|     res = (ggml_float) (wasm_f32x4_extract_lane(x[0], 0) +  \ | ||||
|           wasm_f32x4_extract_lane(x[0], 1) +                \ | ||||
|           wasm_f32x4_extract_lane(x[0], 2) +                \ | ||||
|           wasm_f32x4_extract_lane(x[0], 3));                \ | ||||
| } | ||||
|  | ||||
| #define GGML_F16_VEC                GGML_F16x4 | ||||
| #define GGML_F16_VEC_ZERO           GGML_F16x4_ZERO | ||||
| #define GGML_F16_VEC_SET1           GGML_F16x4_SET1 | ||||
| #define GGML_F16_VEC_LOAD(p, i)     GGML_F16x4_LOAD(p) | ||||
| #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x4_STORE(p, r[i]) | ||||
| #define GGML_F16_VEC_FMA            GGML_F16x4_FMA | ||||
| #define GGML_F16_VEC_ADD            GGML_F16x4_ADD | ||||
| #define GGML_F16_VEC_MUL            GGML_F16x4_MUL | ||||
| #define GGML_F16_VEC_REDUCE         GGML_F16x4_REDUCE | ||||
|  | ||||
| #elif defined(__SSE3__) | ||||
|  | ||||
| #define GGML_SIMD | ||||
|  | ||||
| // F32 SSE | ||||
|  | ||||
| #define GGML_F32_STEP 32 | ||||
| #define GGML_F32_EPR  4 | ||||
|  | ||||
| #define GGML_F32x4         __m128 | ||||
| #define GGML_F32x4_ZERO    _mm_setzero_ps() | ||||
| #define GGML_F32x4_SET1(x) _mm_set1_ps(x) | ||||
| #define GGML_F32x4_LOAD    _mm_loadu_ps | ||||
| #define GGML_F32x4_STORE   _mm_storeu_ps | ||||
| #if defined(__FMA__) | ||||
|     // TODO: Does this work? | ||||
|     #define GGML_F32x4_FMA(a, b, c) _mm_fmadd_ps(b, c, a) | ||||
| #else | ||||
|     #define GGML_F32x4_FMA(a, b, c) _mm_add_ps(_mm_mul_ps(b, c), a) | ||||
| #endif | ||||
| #define GGML_F32x4_ADD     _mm_add_ps | ||||
| #define GGML_F32x4_MUL     _mm_mul_ps | ||||
| #define GGML_F32x4_REDUCE(res, x)                                 \ | ||||
| {                                                                 \ | ||||
|     int offset = GGML_F32_ARR >> 1;                               \ | ||||
|     for (int i = 0; i < offset; ++i) {                            \ | ||||
|         x[i] = _mm_add_ps(x[i], x[offset+i]);                     \ | ||||
|     }                                                             \ | ||||
|     offset >>= 1;                                                 \ | ||||
|     for (int i = 0; i < offset; ++i) {                            \ | ||||
|         x[i] = _mm_add_ps(x[i], x[offset+i]);                     \ | ||||
|     }                                                             \ | ||||
|     offset >>= 1;                                                 \ | ||||
|     for (int i = 0; i < offset; ++i) {                            \ | ||||
|         x[i] = _mm_add_ps(x[i], x[offset+i]);                     \ | ||||
|     }                                                             \ | ||||
|     const __m128 t0 = _mm_hadd_ps(x[0], x[0]);                    \ | ||||
|     res = (ggml_float) _mm_cvtss_f32(_mm_hadd_ps(t0, t0));        \ | ||||
| } | ||||
| // TODO: is this optimal ? | ||||
|  | ||||
| #define GGML_F32_VEC        GGML_F32x4 | ||||
| #define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO | ||||
| #define GGML_F32_VEC_SET1   GGML_F32x4_SET1 | ||||
| #define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD | ||||
| #define GGML_F32_VEC_STORE  GGML_F32x4_STORE | ||||
| #define GGML_F32_VEC_FMA    GGML_F32x4_FMA | ||||
| #define GGML_F32_VEC_ADD    GGML_F32x4_ADD | ||||
| #define GGML_F32_VEC_MUL    GGML_F32x4_MUL | ||||
| #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE | ||||
|  | ||||
| // F16 SSE | ||||
|  | ||||
| #define GGML_F16_STEP 32 | ||||
| #define GGML_F16_EPR  4 | ||||
|  | ||||
| static inline __m128 __sse_f16x4_load(const ggml_fp16_t * x) { | ||||
|     float tmp[4]; | ||||
|  | ||||
|     tmp[0] = GGML_FP16_TO_FP32(x[0]); | ||||
|     tmp[1] = GGML_FP16_TO_FP32(x[1]); | ||||
|     tmp[2] = GGML_FP16_TO_FP32(x[2]); | ||||
|     tmp[3] = GGML_FP16_TO_FP32(x[3]); | ||||
|  | ||||
|     return _mm_loadu_ps(tmp); | ||||
| } | ||||
|  | ||||
| static inline void __sse_f16x4_store(ggml_fp16_t * x, __m128 y) { | ||||
|     float arr[4]; | ||||
|  | ||||
|     _mm_storeu_ps(arr, y); | ||||
|  | ||||
|     x[0] = GGML_FP32_TO_FP16(arr[0]); | ||||
|     x[1] = GGML_FP32_TO_FP16(arr[1]); | ||||
|     x[2] = GGML_FP32_TO_FP16(arr[2]); | ||||
|     x[3] = GGML_FP32_TO_FP16(arr[3]); | ||||
| } | ||||
|  | ||||
| #define GGML_F32Cx4             __m128 | ||||
| #define GGML_F32Cx4_ZERO        _mm_setzero_ps() | ||||
| #define GGML_F32Cx4_SET1(x)     _mm_set1_ps(x) | ||||
| #define GGML_F32Cx4_LOAD(x)     __sse_f16x4_load(x) | ||||
| #define GGML_F32Cx4_STORE(x, y) __sse_f16x4_store(x, y) | ||||
| #define GGML_F32Cx4_FMA         GGML_F32x4_FMA | ||||
| #define GGML_F32Cx4_ADD         _mm_add_ps | ||||
| #define GGML_F32Cx4_MUL         _mm_mul_ps | ||||
| #define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE | ||||
|  | ||||
| #define GGML_F16_VEC                 GGML_F32Cx4 | ||||
| #define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO | ||||
| #define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1 | ||||
| #define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p) | ||||
| #define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i]) | ||||
| #define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA | ||||
| #define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD | ||||
| #define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL | ||||
| #define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE | ||||
|  | ||||
| #elif defined(__loongarch_asx) | ||||
|  | ||||
| #define GGML_SIMD | ||||
|  | ||||
| // F32 LASX | ||||
| #define GGML_F32_STEP 32 | ||||
| #define GGML_F32_EPR  8 | ||||
|  | ||||
| #define GGML_F32x8         __m256 | ||||
| #define GGML_F32x8_ZERO    (__m256)__lasx_xvldi(0) | ||||
| #define GGML_F32x8_SET1(x) (__m256)__lasx_xvreplfr2vr_s((x)) | ||||
| #define GGML_F32x8_LOAD(x) (__m256)__lasx_xvld((x), 0) | ||||
| #define GGML_F32x8_STORE(x,y)   __lasx_xvst((y), (x), 0) | ||||
| #define GGML_F32x8_FMA(a, b, c) __lasx_xvfmadd_s(b, c, a) | ||||
| #define GGML_F32x8_ADD     __lasx_xvfadd_s | ||||
| #define GGML_F32x8_MUL     __lasx_xvfmul_s | ||||
| #define GGML_F32x8_REDUCE(res, x)                                 \ | ||||
| do {                                                              \ | ||||
|     int offset = GGML_F32_ARR >> 1;                               \ | ||||
|     for (int i = 0; i < offset; ++i) {                            \ | ||||
|         x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \ | ||||
|     }                                                             \ | ||||
|     offset >>= 1;                                                 \ | ||||
|     for (int i = 0; i < offset; ++i) {                            \ | ||||
|         x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \ | ||||
|     }                                                             \ | ||||
|     offset >>= 1;                                                 \ | ||||
|     for (int i = 0; i < offset; ++i) {                            \ | ||||
|         x[i] = __lasx_xvfadd_s(x[i], x[offset+i]);                  \ | ||||
|     }                                                             \ | ||||
|     float *tmp_p = (float *)&x[0]; \ | ||||
|     res = tmp_p[0] + tmp_p[1] + tmp_p[2] + tmp_p[3] + tmp_p[4] + tmp_p[5] + tmp_p[6] + tmp_p[7];  \ | ||||
| } while (0) | ||||
| // TODO: is this optimal ? | ||||
|  | ||||
| #define GGML_F32_VEC        GGML_F32x8 | ||||
| #define GGML_F32_VEC_ZERO   GGML_F32x8_ZERO | ||||
| #define GGML_F32_VEC_SET1   GGML_F32x8_SET1 | ||||
| #define GGML_F32_VEC_LOAD   GGML_F32x8_LOAD | ||||
| #define GGML_F32_VEC_STORE  GGML_F32x8_STORE | ||||
| #define GGML_F32_VEC_FMA    GGML_F32x8_FMA | ||||
| #define GGML_F32_VEC_ADD    GGML_F32x8_ADD | ||||
| #define GGML_F32_VEC_MUL    GGML_F32x8_MUL | ||||
| #define GGML_F32_VEC_REDUCE GGML_F32x8_REDUCE | ||||
|  | ||||
| // F16 LASX | ||||
|  | ||||
| #define GGML_F16_STEP 32 | ||||
| #define GGML_F16_EPR  8 | ||||
|  | ||||
| // F16 arithmetic is not supported by LASX, so we use F32 instead | ||||
|  | ||||
| #define GGML_F32Cx8          __m256 | ||||
| #define GGML_F32Cx8_ZERO    (__m256)__lasx_xvldi(0) | ||||
| #define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x)) | ||||
|  | ||||
| static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) { | ||||
|     __m256i a; | ||||
|     memcpy(&a, x, sizeof(ggml_fp16_t) * 8); | ||||
|     a = __lasx_xvpermi_d(a, 0 | (1 << 4)); | ||||
|     return __lasx_xvfcvtl_s_h(a); | ||||
| } | ||||
|  | ||||
| static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) { | ||||
|     __m256i a = __lasx_xvfcvt_h_s(y, y); | ||||
|     a = __lasx_xvpermi_d(a, 0 | (2 << 2)); | ||||
|     memcpy(x, &a, sizeof(ggml_fp16_t) * 8); | ||||
| } | ||||
| #define GGML_F32Cx8_LOAD(x)     __lasx_f32cx8_load(x) | ||||
| #define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y) | ||||
|  | ||||
| #define GGML_F32Cx8_FMA         GGML_F32x8_FMA | ||||
| #define GGML_F32Cx8_ADD         __lasx_xvfadd_s | ||||
| #define GGML_F32Cx8_MUL         __lasx_xvfmul_s | ||||
| #define GGML_F32Cx8_REDUCE      GGML_F32x8_REDUCE | ||||
|  | ||||
| #define GGML_F16_VEC                GGML_F32Cx8 | ||||
| #define GGML_F16_VEC_ZERO           GGML_F32Cx8_ZERO | ||||
| #define GGML_F16_VEC_SET1           GGML_F32Cx8_SET1 | ||||
| #define GGML_F16_VEC_LOAD(p, i)     GGML_F32Cx8_LOAD(p) | ||||
| #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx8_STORE(p, r[i]) | ||||
| #define GGML_F16_VEC_FMA            GGML_F32Cx8_FMA | ||||
| #define GGML_F16_VEC_ADD            GGML_F32Cx8_ADD | ||||
| #define GGML_F16_VEC_MUL            GGML_F32Cx8_MUL | ||||
| #define GGML_F16_VEC_REDUCE         GGML_F32Cx8_REDUCE | ||||
|  | ||||
| #elif defined(__loongarch_sx) | ||||
|  | ||||
| #define GGML_SIMD | ||||
|  | ||||
| // F32 LSX | ||||
|  | ||||
| #define GGML_F32_STEP 32 | ||||
| #define GGML_F32_EPR  4 | ||||
|  | ||||
| #define GGML_F32x4         __m128 | ||||
| #define GGML_F32x4_ZERO    __lsx_vldi(0) | ||||
| #define GGML_F32x4_SET1(x) __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) | ||||
| #define GGML_F32x4_LOAD(x) __lsx_vld((x), 0) | ||||
| #define GGML_F32x4_STORE(x, y)      __lsx_vst((y), (x), 0) | ||||
| #define GGML_F32x4_FMA(a, b, c) __lsx_vfmadd_s(b, c, a) | ||||
| #define GGML_F32x4_ADD     __lsx_vfadd_s | ||||
| #define GGML_F32x4_MUL     __lsx_vfmul_s | ||||
| #define GGML_F32x4_REDUCE(res, x)                                                     \ | ||||
| {                                                                                     \ | ||||
|     int offset = GGML_F32_ARR >> 1;                                                   \ | ||||
|     for (int i = 0; i < offset; ++i) {                                                \ | ||||
|         x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \ | ||||
|     }                                                                                 \ | ||||
|     offset >>= 1;                                                                     \ | ||||
|     for (int i = 0; i < offset; ++i) {                                                \ | ||||
|         x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \ | ||||
|     }                                                                                 \ | ||||
|     offset >>= 1;                                                                     \ | ||||
|     for (int i = 0; i < offset; ++i) {                                                \ | ||||
|         x[i] = __lsx_vfadd_s(x[i], x[offset + i]);                                    \ | ||||
|     }                                                                                 \ | ||||
|     __m128i tmp     = __lsx_vsrli_d((__m128i) x[0], 32);                              \ | ||||
|     tmp             = (__m128i) __lsx_vfadd_s((__m128) tmp, x[0]);                    \ | ||||
|     tmp             = __lsx_vpickev_w(__lsx_vldi(0), tmp);                            \ | ||||
|     const __m128 t0 = __lsx_vshuf4i_w(tmp, 0x88);                                     \ | ||||
|     tmp             = __lsx_vsrli_d((__m128i) t0, 32);                                \ | ||||
|     tmp             = (__m128i) __lsx_vfadd_s((__m128) tmp, t0);                      \ | ||||
|     tmp             = __lsx_vpickev_w(__lsx_vldi(0), tmp);                            \ | ||||
|     res             = (ggml_float) __lsx_vpickve2gr_w(__lsx_vshuf4i_w(tmp, 0x88), 0); \ | ||||
| } | ||||
|  | ||||
| #define GGML_F32_VEC        GGML_F32x4 | ||||
| #define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO | ||||
| #define GGML_F32_VEC_SET1   GGML_F32x4_SET1 | ||||
| #define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD | ||||
| #define GGML_F32_VEC_STORE  GGML_F32x4_STORE | ||||
| #define GGML_F32_VEC_FMA    GGML_F32x4_FMA | ||||
| #define GGML_F32_VEC_ADD    GGML_F32x4_ADD | ||||
| #define GGML_F32_VEC_MUL    GGML_F32x4_MUL | ||||
| #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE | ||||
|  | ||||
| // F16 LSX | ||||
|  | ||||
| #define GGML_F16_STEP 32 | ||||
| #define GGML_F16_EPR  4 | ||||
|  | ||||
| static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) { | ||||
|     float tmp[4]; | ||||
|  | ||||
|     tmp[0] = GGML_FP16_TO_FP32(x[0]); | ||||
|     tmp[1] = GGML_FP16_TO_FP32(x[1]); | ||||
|     tmp[2] = GGML_FP16_TO_FP32(x[2]); | ||||
|     tmp[3] = GGML_FP16_TO_FP32(x[3]); | ||||
|  | ||||
|     return __lsx_vld(tmp, 0); | ||||
| } | ||||
|  | ||||
| static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { | ||||
|     float arr[4]; | ||||
|  | ||||
|     __lsx_vst(y, arr, 0); | ||||
|  | ||||
|     x[0] = GGML_FP32_TO_FP16(arr[0]); | ||||
|     x[1] = GGML_FP32_TO_FP16(arr[1]); | ||||
|     x[2] = GGML_FP32_TO_FP16(arr[2]); | ||||
|     x[3] = GGML_FP32_TO_FP16(arr[3]); | ||||
| } | ||||
|  | ||||
| #define GGML_F32Cx4             __m128 | ||||
| #define GGML_F32Cx4_ZERO        __lsx_vldi(0) | ||||
| #define GGML_F32Cx4_SET1(x)     __lsx_vinsgr2vr_w(__lsx_vldi(0),(x), 0) | ||||
| #define GGML_F32Cx4_LOAD(x)     __lsx_f16x4_load(x) | ||||
| #define GGML_F32Cx4_STORE(x, y) __lsx_f16x4_store(x, y) | ||||
| #define GGML_F32Cx4_FMA         GGML_F32x4_FMA | ||||
| #define GGML_F32Cx4_ADD         __lsx_vfadd_s | ||||
| #define GGML_F32Cx4_MUL         __lsx_vfmul_s | ||||
| #define GGML_F32Cx4_REDUCE      GGML_F32x4_REDUCE | ||||
|  | ||||
| #define GGML_F16_VEC                 GGML_F32Cx4 | ||||
| #define GGML_F16_VEC_ZERO            GGML_F32Cx4_ZERO | ||||
| #define GGML_F16_VEC_SET1            GGML_F32Cx4_SET1 | ||||
| #define GGML_F16_VEC_LOAD(p, i)      GGML_F32Cx4_LOAD(p) | ||||
| #define GGML_F16_VEC_STORE(p, r, i)  GGML_F32Cx4_STORE(p, r[i]) | ||||
| #define GGML_F16_VEC_FMA             GGML_F32Cx4_FMA | ||||
| #define GGML_F16_VEC_ADD             GGML_F32Cx4_ADD | ||||
| #define GGML_F16_VEC_MUL             GGML_F32Cx4_MUL | ||||
| #define GGML_F16_VEC_REDUCE          GGML_F32Cx4_REDUCE | ||||
|  | ||||
| #elif defined(__VXE__) || defined(__VXE2__) | ||||
|  | ||||
| #define GGML_SIMD | ||||
|  | ||||
| // F32 s390x | ||||
|  | ||||
| #define GGML_F32_STEP 32 | ||||
| #define GGML_F32_EPR  4 | ||||
|  | ||||
| #define GGML_F32x4              __vector float | ||||
| #define GGML_F32x4_ZERO         vec_splats(0.0f) | ||||
| #define GGML_F32x4_SET1         vec_splats | ||||
| #define GGML_F32x4_LOAD(p)      vec_xl(0, p) | ||||
| #define GGML_F32x4_STORE(p, r)  vec_xst(r, 0, p) | ||||
| #define GGML_F32x4_FMA(a, b, c) vec_madd(b, c, a) | ||||
| #define GGML_F32x4_ADD          vec_add | ||||
| #define GGML_F32x4_MUL          vec_mul | ||||
| #define GGML_F32x4_REDUCE(res, x)                   \ | ||||
| {                                                   \ | ||||
|     int offset = GGML_F32_ARR >> 1;                 \ | ||||
|     for (int i = 0; i < offset; ++i) {              \ | ||||
|         x[i] = vec_add(x[i], x[offset + i]);        \ | ||||
|     }                                               \ | ||||
|     offset >>= 1;                                   \ | ||||
|     for (int i = 0; i < offset; ++i) {              \ | ||||
|         x[i] = vec_add(x[i], x[offset + i]);        \ | ||||
|     }                                               \ | ||||
|     offset >>= 1;                                   \ | ||||
|     for (int i = 0; i < offset; ++i) {              \ | ||||
|         x[i] = vec_add(x[i], x[offset + i]);        \ | ||||
|     }                                               \ | ||||
|     res = vec_extract(x[0], 0) +                    \ | ||||
|           vec_extract(x[0], 1) +                    \ | ||||
|           vec_extract(x[0], 2) +                    \ | ||||
|           vec_extract(x[0], 3);                     \ | ||||
| } | ||||
|  | ||||
| #define GGML_F32_VEC        GGML_F32x4 | ||||
| #define GGML_F32_VEC_ZERO   GGML_F32x4_ZERO | ||||
| #define GGML_F32_VEC_SET1   GGML_F32x4_SET1 | ||||
| #define GGML_F32_VEC_LOAD   GGML_F32x4_LOAD | ||||
| #define GGML_F32_VEC_STORE  GGML_F32x4_STORE | ||||
| #define GGML_F32_VEC_FMA    GGML_F32x4_FMA | ||||
| #define GGML_F32_VEC_ADD    GGML_F32x4_ADD | ||||
| #define GGML_F32_VEC_MUL    GGML_F32x4_MUL | ||||
| #define GGML_F32_VEC_REDUCE GGML_F32x4_REDUCE | ||||
|  | ||||
| // F16 s390x | ||||
| #define GGML_F16_STEP GGML_F32_STEP | ||||
| #define GGML_F16_EPR  GGML_F32_EPR | ||||
|  | ||||
| static inline __vector float __lzs_f16cx4_load(const ggml_fp16_t * x) { | ||||
|     float tmp[4]; | ||||
|  | ||||
|     for (int i = 0; i < 4; i++) { | ||||
|         tmp[i] = GGML_FP16_TO_FP32(x[i]); | ||||
|     } | ||||
|  | ||||
|     return vec_xl(0, tmp); | ||||
| } | ||||
|  | ||||
| static inline void __lzs_f16cx4_store(ggml_fp16_t * x, __vector float y) { | ||||
|     float arr[4]; | ||||
|  | ||||
|     vec_xst(y, 0, arr); | ||||
|  | ||||
|     for (int i = 0; i < 4; i++) { | ||||
|         x[i] = GGML_FP32_TO_FP16(arr[i]); | ||||
|     } | ||||
| } | ||||
|  | ||||
| #define GGML_F16_VEC                GGML_F32x4 | ||||
| #define GGML_F16_VEC_ZERO           GGML_F32x4_ZERO | ||||
| #define GGML_F16_VEC_SET1           GGML_F32x4_SET1 | ||||
| #define GGML_F16_VEC_LOAD(p, i)     __lzs_f16cx4_load(p) | ||||
| #define GGML_F16_VEC_STORE(p, r, i) __lzs_f16cx4_store(p, r[i]) | ||||
| #define GGML_F16_VEC_FMA            GGML_F32x4_FMA | ||||
| #define GGML_F16_VEC_ADD            GGML_F32x4_ADD | ||||
| #define GGML_F16_VEC_MUL            GGML_F32x4_MUL | ||||
| #define GGML_F16_VEC_REDUCE         GGML_F32x4_REDUCE | ||||
|  | ||||
| #endif | ||||
|  | ||||
| // GGML_F32_ARR / GGML_F16_ARR | ||||
| //   number of registers to use per step | ||||
| #ifdef GGML_SIMD | ||||
| #define GGML_F32_ARR (GGML_F32_STEP/GGML_F32_EPR) | ||||
| #define GGML_F16_ARR (GGML_F16_STEP/GGML_F16_EPR) | ||||
| #endif | ||||
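To show how the mappings above compose into a vectorized loop (a hedged sketch; example_vec_mad_f32 is an illustrative name, though the real helpers in vec.h/vec.cpp follow the same shape): one step consumes GGML_F32_STEP elements as GGML_F32_ARR registers of GGML_F32_EPR lanes each, and anything left over is handled by a scalar tail.

#include "simd-mappings.h"

// Illustrative SAXPY-style loop (y += x*v) written only in terms of the
// GGML_F32_* macros defined above; falls back to plain scalar code when no
// SIMD mapping is available for the target architecture.
static void example_vec_mad_f32(const int n, float * y, const float * x, const float v) {
#if defined(GGML_SIMD)
    const int np = (n & ~(GGML_F32_STEP - 1));      // largest multiple of GGML_F32_STEP

    GGML_F32_VEC vx = GGML_F32_VEC_SET1(v);

    GGML_F32_VEC ax[GGML_F32_ARR];
    GGML_F32_VEC ay[GGML_F32_ARR];

    for (int i = 0; i < np; i += GGML_F32_STEP) {
        for (int j = 0; j < GGML_F32_ARR; j++) {
            ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
            ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR);
            ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx);          // ay += ax*vx
            GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
        }
    }

    // leftovers
    for (int i = np; i < n; ++i) {
        y[i] += x[i]*v;
    }
#else
    // scalar fallback
    for (int i = 0; i < n; ++i) {
        y[i] += x[i]*v;
    }
#endif
}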
							
								
								
									
ggml/src/ggml-cpu/vec.cpp: new file, 258 lines.
							| @@ -0,0 +1,258 @@ | ||||
| #include "vec.h" | ||||
|  | ||||
| #include <cassert> | ||||
|  | ||||
| #if defined(_MSC_VER) | ||||
| // disable "possible loss of data" to avoid hundreds of casts | ||||
| // we should just be careful :) | ||||
| #pragma warning(disable: 4244 4267) | ||||
| #endif | ||||
|  | ||||
| // precomputed gelu table for f16 (128 KB) | ||||
| ggml_fp16_t ggml_table_gelu_f16[1 << 16]; | ||||
|  | ||||
| // precomputed quick gelu table for f16 (128 KB) | ||||
| ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16]; | ||||
|  | ||||
| void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc) { | ||||
|     assert(nrc == 1); | ||||
|     GGML_UNUSED(nrc); | ||||
|     GGML_UNUSED(bx); | ||||
|     GGML_UNUSED(by); | ||||
|     GGML_UNUSED(bs); | ||||
|  | ||||
| #if defined(GGML_SIMD) | ||||
|     float sumf = 0.0f; | ||||
|     const int np = (n & ~(GGML_F32_STEP - 1)); | ||||
|  | ||||
|     GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; | ||||
|  | ||||
|     GGML_F32_VEC ax[GGML_F32_ARR]; | ||||
|     GGML_F32_VEC ay[GGML_F32_ARR]; | ||||
|  | ||||
|     for (int i = 0; i < np; i += GGML_F32_STEP) { | ||||
|         for (int j = 0; j < GGML_F32_ARR; j++) { | ||||
|             ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); | ||||
|             ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); | ||||
|  | ||||
|             sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], ay[j]); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // reduce sum0..sum3 to sum0 | ||||
|     GGML_F32_VEC_REDUCE(sumf, sum); | ||||
|  | ||||
|     // leftovers | ||||
|     for (int i = np; i < n; ++i) { | ||||
|         sumf += x[i]*y[i]; | ||||
|     } | ||||
| #else | ||||
|     // scalar | ||||
|     ggml_float sumf = 0.0; | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         sumf += (ggml_float)(x[i]*y[i]); | ||||
|     } | ||||
| #endif | ||||
|  | ||||
|     *s = sumf; | ||||
| } | ||||
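
ggml_vec_dot_f32 follows the pattern sketched earlier: it keeps GGML_F32_ARR independent accumulators so the FMAs can overlap, reduces them once with GGML_F32_VEC_REDUCE, and finishes the remainder in scalar code. A small usage sketch (illustrative only; the bs/bx/by strides are unused in this path and nrc must be 1):

// usage sketch, not part of the patch
#include <stdio.h>
#include "vec.h"

int main(void) {
    float x[8] = { 1, 2, 3, 4, 5, 6, 7, 8 };
    float y[8] = { 1, 1, 1, 1, 1, 1, 1, 1 };
    float s = 0.0f;
    ggml_vec_dot_f32(8, &s, 0, x, 0, y, 0, /*nrc=*/1);
    printf("dot = %f\n", s);   // 36.000000
    return 0;
}
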
|  | ||||
| void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc) { | ||||
|     assert(nrc == 1); | ||||
|     GGML_UNUSED(nrc); | ||||
|     GGML_UNUSED(bx); | ||||
|     GGML_UNUSED(by); | ||||
|     GGML_UNUSED(bs); | ||||
|     int i = 0; | ||||
|     ggml_float sumf = 0; | ||||
|  | ||||
| #if defined(__AVX512BF16__) | ||||
|     __m512 c1 = _mm512_setzero_ps(); | ||||
|     __m512 c2 = _mm512_setzero_ps(); | ||||
|     for (; i + 64 <= n; i += 64) { | ||||
|         c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512((x + i))), | ||||
|                              m512bh(_mm512_loadu_si512((y + i)))); | ||||
|         c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512((x + i + 32))), | ||||
|                              m512bh(_mm512_loadu_si512((y + i + 32)))); | ||||
|     } | ||||
|     sumf += (ggml_float)_mm512_reduce_add_ps(c1); | ||||
|     sumf += (ggml_float)_mm512_reduce_add_ps(c2); | ||||
|  | ||||
| #elif defined(__AVX512F__) | ||||
| #define LOAD(p) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(_mm256_loadu_si256((const __m256i *)(p))), 16)) | ||||
|     __m512 c1 = _mm512_setzero_ps(); | ||||
|     __m512 c2 = _mm512_setzero_ps(); | ||||
|     for (; i + 32 <= n; i += 32) { | ||||
|         c1 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i), LOAD(y + i)), c1); | ||||
|         c2 = _mm512_add_ps(_mm512_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c2); | ||||
|     } | ||||
|     sumf += (ggml_float)_mm512_reduce_add_ps(c1); | ||||
|     sumf += (ggml_float)_mm512_reduce_add_ps(c2); | ||||
|  | ||||
| #undef LOAD | ||||
| #elif defined(__AVX2__) || defined(__AVX__) | ||||
| #if defined(__AVX2__) | ||||
| #define LOAD(p) _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)) | ||||
| #else | ||||
| #define LOAD(p) _mm256_castsi256_ps(_mm256_insertf128_si256(_mm256_castsi128_si256(_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_loadu_si128((const __m128i *)(p))), 16)), (_mm_slli_epi32(_mm_cvtepu16_epi32(_mm_bsrli_si128(_mm_loadu_si128((const __m128i *)(p)), 8)), 16)), 1)) | ||||
| #endif | ||||
|     __m256 c1 = _mm256_setzero_ps(); | ||||
|     __m256 c2 = _mm256_setzero_ps(); | ||||
|     __m256 c3 = _mm256_setzero_ps(); | ||||
|     __m256 c4 = _mm256_setzero_ps(); | ||||
|     for (; i + 32 <= n; i += 32) { | ||||
|         c1 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i), LOAD(y + i)), c1); | ||||
|         c2 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 8), LOAD(y + i + 8)), c2); | ||||
|         c3 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 16), LOAD(y + i + 16)), c3); | ||||
|         c4 = _mm256_add_ps(_mm256_mul_ps(LOAD(x + i + 24), LOAD(y + i + 24)), c4); | ||||
|     } | ||||
|     __m128 g; | ||||
|     c1 = _mm256_add_ps(_mm256_add_ps(c1, c3), | ||||
|                        _mm256_add_ps(c2, c4)); | ||||
|     g = _mm_add_ps(_mm256_extractf128_ps(c1, 1), | ||||
|                    _mm256_castps256_ps128(c1)); | ||||
|     g = _mm_add_ps(g, _mm_movehl_ps(g, g)); | ||||
|     g = _mm_add_ss(g, _mm_movehdup_ps(g)); | ||||
|     sumf += (ggml_float)_mm_cvtss_f32(g); | ||||
|  | ||||
| #undef LOAD | ||||
| #endif | ||||
|  | ||||
|     for (; i < n; ++i) { | ||||
|         sumf += (ggml_float)(GGML_BF16_TO_FP32(x[i]) * | ||||
|                              GGML_BF16_TO_FP32(y[i])); | ||||
|     } | ||||
|     *s = sumf; | ||||
| } | ||||
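
The AVX-512/AVX2 fallbacks above widen bf16 to f32 on the fly: since bf16 is just the upper 16 bits of an IEEE-754 binary32, the LOAD macros zero-extend each 16-bit value and shift it left by 16. A scalar sketch of that widening (an assumption about what GGML_BF16_TO_FP32 boils down to, not code taken from ggml-impl.h):

#include <stdint.h>
#include <string.h>

// bf16 keeps the sign, the 8-bit exponent and the top 7 mantissa bits of a
// float, so widening is just a 16-bit left shift of the raw bits.
static float bf16_bits_to_f32_sketch(uint16_t h) {
    const uint32_t u = (uint32_t) h << 16;
    float f;
    memcpy(&f, &u, sizeof(f));
    return f;
}
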
|  | ||||
| void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc) { | ||||
|     assert(nrc == 1); | ||||
|     GGML_UNUSED(nrc); | ||||
|     GGML_UNUSED(bx); | ||||
|     GGML_UNUSED(by); | ||||
|     GGML_UNUSED(bs); | ||||
|  | ||||
|     ggml_float sumf = 0.0; | ||||
|  | ||||
| #if defined(GGML_SIMD) | ||||
|     const int np = (n & ~(GGML_F16_STEP - 1)); | ||||
|  | ||||
|     GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; | ||||
|  | ||||
|     GGML_F16_VEC ax[GGML_F16_ARR]; | ||||
|     GGML_F16_VEC ay[GGML_F16_ARR]; | ||||
|  | ||||
|     for (int i = 0; i < np; i += GGML_F16_STEP) { | ||||
|         for (int j = 0; j < GGML_F16_ARR; j++) { | ||||
|             ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); | ||||
|             ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); | ||||
|  | ||||
|             sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // reduce sum0..sum3 to sum0 | ||||
|     GGML_F16_VEC_REDUCE(sumf, sum); | ||||
|  | ||||
|     // leftovers | ||||
|     for (int i = np; i < n; ++i) { | ||||
|         sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); | ||||
|     } | ||||
| #else | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         sumf += (ggml_float)(GGML_FP16_TO_FP32(x[i])*GGML_FP16_TO_FP32(y[i])); | ||||
|     } | ||||
| #endif | ||||
|  | ||||
|     *s = sumf; | ||||
| } | ||||
|  | ||||
| void ggml_vec_silu_f32(const int n, float * y, const float * x) { | ||||
|     int i = 0; | ||||
| #if defined(__AVX512F__) && defined(__AVX512DQ__) | ||||
|     for (; i + 15 < n; i += 16) { | ||||
|         _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i))); | ||||
|     } | ||||
| #elif defined(__AVX2__) && defined(__FMA__) | ||||
|     for (; i + 7 < n; i += 8) { | ||||
|         _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i))); | ||||
|     } | ||||
| #elif defined(__SSE2__) | ||||
|     for (; i + 3 < n; i += 4) { | ||||
|         _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i))); | ||||
|     } | ||||
| #elif defined(__ARM_NEON) && defined(__aarch64__) | ||||
|     for (; i + 3 < n; i += 4) { | ||||
|         vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i))); | ||||
|     } | ||||
| #endif | ||||
|     for (; i < n; ++i) { | ||||
|         y[i] = ggml_silu_f32(x[i]); | ||||
|     } | ||||
| } | ||||
|  | ||||
| ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { | ||||
|     int i = 0; | ||||
|     ggml_float sum = 0; | ||||
| #if defined(__AVX512F__) && defined(__AVX512DQ__) | ||||
|     for (; i + 15 < n; i += 16) { | ||||
|         __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i), | ||||
|                                                _mm512_set1_ps(max))); | ||||
|         _mm512_storeu_ps(y + i, val); | ||||
|         sum += (ggml_float)_mm512_reduce_add_ps(val); | ||||
|     } | ||||
| #elif defined(__AVX2__) && defined(__FMA__) | ||||
|     for (; i + 7 < n; i += 8) { | ||||
|         __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i), | ||||
|                                                _mm256_set1_ps(max))); | ||||
|         _mm256_storeu_ps(y + i, val); | ||||
|         __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1), | ||||
|                                  _mm256_castps256_ps128(val)); | ||||
|         val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2)); | ||||
|         val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2)); | ||||
|         sum += (ggml_float)_mm_cvtss_f32(val2); | ||||
|     } | ||||
| #elif defined(__SSE2__) | ||||
|     for (; i + 3 < n; i += 4) { | ||||
|         __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i), | ||||
|                                             _mm_set1_ps(max))); | ||||
|         _mm_storeu_ps(y + i, val); | ||||
| #if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) | ||||
|         val = _mm_add_ps(val, _mm_movehl_ps(val, val)); | ||||
|         val = _mm_add_ss(val, _mm_movehdup_ps(val)); | ||||
| #else | ||||
|         __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1)); | ||||
|         val = _mm_add_ps(val, tmp); | ||||
|         tmp = _mm_movehl_ps(tmp, val); | ||||
|         val = _mm_add_ss(val, tmp); | ||||
| #endif | ||||
|         sum += (ggml_float)_mm_cvtss_f32(val); | ||||
|     } | ||||
| #elif defined(__ARM_NEON) && defined(__aarch64__) | ||||
|     for (; i + 3 < n; i += 4) { | ||||
|         float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i), | ||||
|                                                 vdupq_n_f32(max))); | ||||
|         vst1q_f32(y + i, val); | ||||
|         sum += (ggml_float)vaddvq_f32(val); | ||||
|     } | ||||
| #endif | ||||
|     for (; i < n; ++i) { | ||||
|         float val = expf(x[i] - max); | ||||
|         sum += (ggml_float)val; | ||||
|         y[i] = val; | ||||
|     } | ||||
|     return sum; | ||||
| } | ||||
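
ggml_vec_soft_max_f32 expects the caller to pass in the row maximum (subtracted before exponentiation for numerical stability) and returns the unnormalized sum of exponentials, leaving the final rescale to the caller. A minimal sketch of that composition using the vec.h helpers added later in this commit (an illustration, not copied from ops.cpp):

#include <math.h>
#include "vec.h"

// y[i] = exp(x[i] - max(x)) / sum_j exp(x[j] - max(x))
static void softmax_row_sketch(const int n, float * y, const float * x) {
    float max = -INFINITY;
    ggml_vec_max_f32(n, &max, x);                                // row maximum
    const ggml_float sum = ggml_vec_soft_max_f32(n, y, x, max);  // y[i] = exp(x[i]-max)
    ggml_vec_scale_f32(n, y, 1.0f/(float) sum);                  // normalize in place
}
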
|  | ||||
| ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max) { | ||||
|     // log(soft_max) = log(soft_max_i / soft_max_sum) = log(soft_max_i) - log(soft_max_sum) = (logit_i - max) - log(soft_max_sum) | ||||
|  | ||||
|     int i = 0; | ||||
|     ggml_float sum = 0; | ||||
|     for (; i < n; ++i) { | ||||
|         float val = x[i] - max; | ||||
|         y[i] = val; | ||||
|         sum += (ggml_float)expf(val); | ||||
|     } | ||||
|     return sum = (ggml_float)logf(sum); | ||||
| } | ||||
							
								
								
									
802  ggml/src/ggml-cpu/vec.h  Normal file
							| @@ -0,0 +1,802 @@ | ||||
| // Vectorized functions for fundamental operations | ||||
|  | ||||
| #pragma once | ||||
|  | ||||
| #include "ggml-impl.h" | ||||
| #include "simd-mappings.h" | ||||
| #include "ggml.h" | ||||
|  | ||||
| #if defined(GGML_USE_ACCELERATE) | ||||
| #include <Accelerate/Accelerate.h> | ||||
| #endif | ||||
|  | ||||
| // floating point type used to accumulate sums | ||||
| typedef double ggml_float; | ||||
|  | ||||
| #define GGML_GELU_FP16 | ||||
| #define GGML_GELU_QUICK_FP16 | ||||
|  | ||||
| #define GGML_SOFT_MAX_UNROLL 4 | ||||
| #define GGML_VEC_DOT_UNROLL  2 | ||||
| #define GGML_VEC_MAD_UNROLL  32 | ||||
|  | ||||
| #ifdef __cplusplus | ||||
| extern "C" { | ||||
| #endif | ||||
|  | ||||
| // | ||||
| // global data | ||||
| // | ||||
|  | ||||
| // precomputed gelu table for f16 (128 KB) | ||||
| extern ggml_fp16_t ggml_table_gelu_f16[1 << 16]; | ||||
|  | ||||
| // precomputed quick gelu table for f16 (128 KB) | ||||
| extern ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16]; | ||||
|  | ||||
| // | ||||
| // fundamental operations | ||||
| // | ||||
|  | ||||
| void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * GGML_RESTRICT x, size_t bx, const float * GGML_RESTRICT y, size_t by, int nrc); | ||||
| void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * GGML_RESTRICT x, size_t bx, ggml_bf16_t * GGML_RESTRICT y, size_t by, int nrc); | ||||
| void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc); | ||||
|  | ||||
| void ggml_vec_silu_f32(const int n, float * y, const float * x); | ||||
| ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max); | ||||
| ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max); | ||||
|  | ||||
| inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; } | ||||
| inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } | ||||
|  | ||||
| inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t   v) { for (int i = 0; i < n; ++i) x[i] = v;    } | ||||
| inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; } | ||||
|  | ||||
| inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const ggml_fp16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } | ||||
| inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; } | ||||
| inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; } | ||||
| inline static void ggml_vec_add_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) + GGML_FP16_TO_FP32(y[i])); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float   v) { for (int i = 0; i < n; ++i) z[i]  = x[i] + v;    } | ||||
| inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i] += x[i];        } | ||||
| inline static void ggml_vec_acc1_f32(const int n, float * y, const float   v)                  { for (int i = 0; i < n; ++i) y[i] += v;           } | ||||
| inline static void ggml_vec_sub_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] - y[i]; } | ||||
| inline static void ggml_vec_sub_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) - GGML_FP16_TO_FP32(y[i])); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_set_f32 (const int n, float * x, const float   v)                  { for (int i = 0; i < n; ++i) x[i]  = v;           } | ||||
| inline static void ggml_vec_cpy_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = x[i];        } | ||||
| inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)                  { for (int i = 0; i < n; ++i) y[i]  = -x[i];       } | ||||
| inline static void ggml_vec_neg_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16(-GGML_FP16_TO_FP32(x[i])); | ||||
|     } | ||||
| } | ||||
|  | ||||
| inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]*y[i];   } | ||||
| inline static void ggml_vec_mul_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) * GGML_FP16_TO_FP32(y[i])); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i]/y[i];   } | ||||
| inline static void ggml_vec_div_f16 (const int n, ggml_fp16_t * z, const ggml_fp16_t * x, const ggml_fp16_t * y) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         z[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(x[i]) / GGML_FP16_TO_FP32(y[i])); | ||||
|     } | ||||
| } | ||||
|  | ||||
| // compute GGML_VEC_DOT_UNROLL dot products at once | ||||
| // xs - x row stride in bytes | ||||
| inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GGML_RESTRICT s, void * GGML_RESTRICT xv, ggml_fp16_t * GGML_RESTRICT y) { | ||||
|     ggml_float sumf[GGML_VEC_DOT_UNROLL] = { 0.0 }; | ||||
|  | ||||
|     ggml_fp16_t * GGML_RESTRICT x[GGML_VEC_DOT_UNROLL]; | ||||
|  | ||||
|     for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { | ||||
|         x[i] = (ggml_fp16_t *) ((char *) xv + i*xs); | ||||
|     } | ||||
|  | ||||
| #if defined(GGML_SIMD) | ||||
|     const int np = (n & ~(GGML_F16_STEP - 1)); | ||||
|  | ||||
|     GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; | ||||
|  | ||||
|     GGML_F16_VEC ax[GGML_F16_ARR]; | ||||
|     GGML_F16_VEC ay[GGML_F16_ARR]; | ||||
|  | ||||
|     for (int i = 0; i < np; i += GGML_F16_STEP) { | ||||
|         for (int j = 0; j < GGML_F16_ARR; j++) { | ||||
|             ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); | ||||
|  | ||||
|             for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { | ||||
|                 ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); | ||||
|  | ||||
|                 sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // reduce sum0..sum3 to sum0 | ||||
|     for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { | ||||
|         GGML_F16_VEC_REDUCE(sumf[k], sum[k]); | ||||
|     } | ||||
|  | ||||
|     // leftovers | ||||
|     for (int i = np; i < n; ++i) { | ||||
|         for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { | ||||
|             sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); | ||||
|         } | ||||
|     } | ||||
| #else | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { | ||||
|             sumf[j] += (ggml_float)(GGML_FP16_TO_FP32(x[j][i])*GGML_FP16_TO_FP32(y[i])); | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
|  | ||||
|     for (int i = 0; i < GGML_VEC_DOT_UNROLL; ++i) { | ||||
|         s[i] = (float)sumf[i]; | ||||
|     } | ||||
| } | ||||
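
ggml_vec_dot_f16_unroll computes GGML_VEC_DOT_UNROLL (currently 2) dot products against the same y in one pass; xv points at the first row and xs is the byte stride between rows. A small usage sketch (illustrative only, not from the patch):

#include <stdio.h>
#include "vec.h"

int main(void) {
    enum { N = 4 };
    ggml_fp16_t rows[GGML_VEC_DOT_UNROLL][N];   // two contiguous rows
    ggml_fp16_t y[N];
    for (int i = 0; i < N; ++i) {
        rows[0][i] = GGML_FP32_TO_FP16(1.0f);
        rows[1][i] = GGML_FP32_TO_FP16(2.0f);
        y[i]       = GGML_FP32_TO_FP16(1.0f);
    }
    float s[GGML_VEC_DOT_UNROLL];
    // xs = N*sizeof(ggml_fp16_t): byte distance from one row to the next
    ggml_vec_dot_f16_unroll(N, N*sizeof(ggml_fp16_t), s, rows, y);
    printf("%f %f\n", s[0], s[1]);   // 4.000000 8.000000
    return 0;
}
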
|  | ||||
| inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const float * GGML_RESTRICT x, const float v) { | ||||
| #if defined(GGML_SIMD) | ||||
|     const int np = (n & ~(GGML_F32_STEP - 1)); | ||||
|  | ||||
|     GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); | ||||
|  | ||||
|     GGML_F32_VEC ax[GGML_F32_ARR]; | ||||
|     GGML_F32_VEC ay[GGML_F32_ARR]; | ||||
|  | ||||
|     for (int i = 0; i < np; i += GGML_F32_STEP) { | ||||
|         for (int j = 0; j < GGML_F32_ARR; j++) { | ||||
|             ax[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); | ||||
|             ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); | ||||
|             ay[j] = GGML_F32_VEC_FMA(ay[j], ax[j], vx); | ||||
|  | ||||
|             GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // leftovers | ||||
|     for (int i = np; i < n; ++i) { | ||||
|         y[i] += x[i]*v; | ||||
|     } | ||||
| #else | ||||
|     // scalar | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] += x[i]*v; | ||||
|     } | ||||
| #endif | ||||
| } | ||||
|  | ||||
| inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) { | ||||
| #if defined(GGML_SIMD) | ||||
|     const int np = (n & ~(GGML_F16_STEP - 1)); | ||||
|  | ||||
|     GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); | ||||
|  | ||||
|     GGML_F16_VEC ax[GGML_F16_ARR]; | ||||
|     GGML_F16_VEC ay[GGML_F16_ARR]; | ||||
|  | ||||
|     for (int i = 0; i < np; i += GGML_F16_STEP) { | ||||
|         for (int j = 0; j < GGML_F16_ARR; j++) { | ||||
|             ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); | ||||
|             ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); | ||||
|             ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); | ||||
|  | ||||
|             GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // leftovers | ||||
|     for (int i = np; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); | ||||
|     } | ||||
| #else | ||||
|     // scalar | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); | ||||
|     } | ||||
| #endif | ||||
| } | ||||
|  | ||||
| // xs and vs are byte strides of x and v | ||||
| inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * GGML_RESTRICT y, const float * GGML_RESTRICT xv, const float * GGML_RESTRICT vv) { | ||||
|  | ||||
|     const float * GGML_RESTRICT x[GGML_VEC_MAD_UNROLL]; | ||||
|     const float * GGML_RESTRICT v[GGML_VEC_MAD_UNROLL]; | ||||
|  | ||||
|     for (int i = 0; i < GGML_VEC_MAD_UNROLL; ++i) { | ||||
|         x[i] = (const float *) ((const char *) xv + i*xs); | ||||
|         v[i] = (const float *) ((const char *) vv + i*vs); | ||||
|     } | ||||
|  | ||||
| #if defined(GGML_SIMD) | ||||
|     const int np = (n & ~(GGML_F32_STEP - 1)); | ||||
|  | ||||
|     GGML_F32_VEC vx[GGML_VEC_MAD_UNROLL]; | ||||
|  | ||||
|     for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { | ||||
|         vx[k] = GGML_F32_VEC_SET1(v[k][0]); | ||||
|     } | ||||
|  | ||||
|     GGML_F32_VEC ax[GGML_VEC_MAD_UNROLL][GGML_F32_ARR]; | ||||
|     GGML_F32_VEC ay[GGML_F32_ARR]; | ||||
|  | ||||
|     for (int i = 0; i < np; i += GGML_F32_STEP) { | ||||
|         for (int j = 0; j < GGML_F32_ARR; j++) { | ||||
|             ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); | ||||
|  | ||||
|             for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { | ||||
|                 ax[k][j] = GGML_F32_VEC_LOAD(x[k] + i + j*GGML_F32_EPR); | ||||
|                 ay[j] = GGML_F32_VEC_FMA(ay[j], ax[k][j], vx[k]); | ||||
|             } | ||||
|  | ||||
|             GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // leftovers | ||||
|     for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { | ||||
|         for (int i = np; i < n; ++i) { | ||||
|             y[i] += x[k][i]*v[k][0]; | ||||
|         } | ||||
|     } | ||||
| #else | ||||
|     // scalar | ||||
|     for (int k = 0; k < GGML_VEC_MAD_UNROLL; ++k) { | ||||
|         for (int i = 0; i < n; ++i) { | ||||
|             y[i] += x[k][i]*v[k][0]; | ||||
|         } | ||||
|     } | ||||
| #endif | ||||
| } | ||||
|  | ||||
| //inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { for (int i = 0; i < n; ++i) y[i] *= v;          } | ||||
| inline static void ggml_vec_scale_f32(const int n, float * y, const float   v) { | ||||
| #if defined(GGML_USE_ACCELERATE) | ||||
|     vDSP_vsmul(y, 1, &v, y, 1, n); | ||||
| #elif defined(GGML_SIMD) | ||||
|     const int np = (n & ~(GGML_F32_STEP - 1)); | ||||
|  | ||||
|     GGML_F32_VEC vx = GGML_F32_VEC_SET1(v); | ||||
|  | ||||
|     GGML_F32_VEC ay[GGML_F32_ARR]; | ||||
|  | ||||
|     for (int i = 0; i < np; i += GGML_F32_STEP) { | ||||
|         for (int j = 0; j < GGML_F32_ARR; j++) { | ||||
|             ay[j] = GGML_F32_VEC_LOAD(y + i + j*GGML_F32_EPR); | ||||
|             ay[j] = GGML_F32_VEC_MUL(ay[j], vx); | ||||
|  | ||||
|             GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // leftovers | ||||
|     for (int i = np; i < n; ++i) { | ||||
|         y[i] *= v; | ||||
|     } | ||||
| #else | ||||
|     // scalar | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] *= v; | ||||
|     } | ||||
| #endif | ||||
| } | ||||
|  | ||||
| inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { | ||||
| #if defined(GGML_SIMD) | ||||
|     const int np = (n & ~(GGML_F16_STEP - 1)); | ||||
|  | ||||
|     GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); | ||||
|  | ||||
|     GGML_F16_VEC ay[GGML_F16_ARR]; | ||||
|  | ||||
|     for (int i = 0; i < np; i += GGML_F16_STEP) { | ||||
|         for (int j = 0; j < GGML_F16_ARR; j++) { | ||||
|             ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); | ||||
|             ay[j] = GGML_F16_VEC_MUL(ay[j], vx); | ||||
|  | ||||
|             GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     // leftovers | ||||
|     for (int i = np; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); | ||||
|     } | ||||
| #else | ||||
|     // scalar | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); | ||||
|     } | ||||
| #endif | ||||
| } | ||||
|  | ||||
| inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s);   } | ||||
| inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   } | ||||
| inline static void ggml_vec_sqr_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         float v = GGML_FP16_TO_FP32(x[i]); | ||||
|         y[i] = GGML_FP32_TO_FP16(v*v); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } | ||||
| inline static void ggml_vec_sqrt_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16(sqrtf(GGML_FP16_TO_FP32(x[i]))); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_log_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]);  } | ||||
| inline static void ggml_vec_log_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16(logf(GGML_FP16_TO_FP32(x[i]))); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_sin_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sinf(x[i]);  } | ||||
| inline static void ggml_vec_sin_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16(sinf(GGML_FP16_TO_FP32(x[i]))); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_cos_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = cosf(x[i]);  } | ||||
| inline static void ggml_vec_cos_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16(cosf(GGML_FP16_TO_FP32(x[i]))); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_abs_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fabsf(x[i]); } | ||||
| inline static void ggml_vec_abs_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16(fabsf(GGML_FP16_TO_FP32(x[i]))); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_sgn_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : ((x[i] < 0.f) ? -1.f : 0.f); } | ||||
| inline static void ggml_vec_sgn_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         float v = GGML_FP16_TO_FP32(x[i]); | ||||
|         y[i] = GGML_FP32_TO_FP16((v > 0.f) ? 1.f : ((v < 0.f) ? -1.f : 0.f)); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_step_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? 1.f : 0.f; } | ||||
| inline static void ggml_vec_step_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16((GGML_FP16_TO_FP32(x[i]) > 0.f) ? 1.f : 0.f); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]);  } | ||||
| inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16(tanhf(GGML_FP16_TO_FP32(x[i]))); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_elu_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); } | ||||
| inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         const float v = GGML_FP16_TO_FP32(x[i]); | ||||
|         y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v : expm1f(v)); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } | ||||
| inline static void ggml_vec_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         float v = GGML_FP16_TO_FP32(x[i]); | ||||
|         y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v : 0.f); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); } | ||||
| inline static void ggml_vec_leaky_relu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const float ns) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         float v = GGML_FP16_TO_FP32(x[i]); | ||||
|         y[i] = GGML_FP32_TO_FP16(((v > 0.f) ? v : 0.f) + ns * ((v < 0.0f) ? v : 0.f)); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_sigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = 1.f / (1.f + expf(-x[i])); } | ||||
| inline static void ggml_vec_sigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16(1.f / (1.f + expf(-GGML_FP16_TO_FP32(x[i])))); | ||||
|     } | ||||
| } | ||||
| // TODO: optimize performance | ||||
| inline static void ggml_vec_hardswish_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } | ||||
| inline static void ggml_vec_hardswish_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         float v = GGML_FP16_TO_FP32(x[i]); | ||||
|         y[i] = GGML_FP32_TO_FP16(v * fminf(1.0f, fmaxf(0.0f, (v + 3.0f) / 6.0f))); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_hardsigmoid_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f)); } | ||||
| inline static void ggml_vec_hardsigmoid_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16(fminf(1.0f, fmaxf(0.0f, (GGML_FP16_TO_FP32(x[i]) + 3.0f) / 6.0f))); | ||||
|     } | ||||
| } | ||||
| inline static void ggml_vec_exp_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = expf(x[i]); } | ||||
| inline static void ggml_vec_exp_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = GGML_FP32_TO_FP16(expf(GGML_FP16_TO_FP32(x[i]))); | ||||
|     } | ||||
| } | ||||
|  | ||||
| static const float GELU_COEF_A     = 0.044715f; | ||||
| static const float GELU_QUICK_COEF = -1.702f; | ||||
| static const float SQRT_2_OVER_PI  = 0.79788456080286535587989211986876f; | ||||
|  | ||||
| inline static float ggml_gelu_f32(float x) { | ||||
|     return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); | ||||
| } | ||||
|  | ||||
| inline static void ggml_vec_gelu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     const uint16_t * i16 = (const uint16_t *) x; | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = ggml_table_gelu_f16[i16[i]]; | ||||
|     } | ||||
| } | ||||
|  | ||||
| #ifdef GGML_GELU_FP16 | ||||
| inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { | ||||
|     uint16_t t; | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         if (x[i] <= -10.0f) { | ||||
|             y[i] = 0.0f; | ||||
|         } else if (x[i] >= 10.0f) { | ||||
|             y[i] = x[i]; | ||||
|         } else { | ||||
|             ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); | ||||
|             memcpy(&t, &fp16, sizeof(uint16_t)); | ||||
|             y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| #else | ||||
| inline static void ggml_vec_gelu_f32(const int n, float * y, const float * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = ggml_gelu_f32(x[i]); | ||||
|     } | ||||
| } | ||||
| #endif | ||||
|  | ||||
| inline static float ggml_gelu_quick_f32(float x) { | ||||
|     return x*(1.0f/(1.0f+expf(GELU_QUICK_COEF*x))); | ||||
| } | ||||
|  | ||||
| //inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
| //    const uint16_t * i16 = (const uint16_t *) x; | ||||
| //    for (int i = 0; i < n; ++i) { | ||||
| //        y[i] = ggml_table_gelu_quick_f16[i16[i]]; | ||||
| //    } | ||||
| //} | ||||
|  | ||||
| #ifdef GGML_GELU_QUICK_FP16 | ||||
| inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { | ||||
|     uint16_t t; | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]); | ||||
|         memcpy(&t, &fp16, sizeof(uint16_t)); | ||||
|         y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_quick_f16[t]); | ||||
|     } | ||||
| } | ||||
| #else | ||||
| inline static void ggml_vec_gelu_quick_f32(const int n, float * y, const float * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = ggml_gelu_quick_f32(x[i]); | ||||
|     } | ||||
| } | ||||
| #endif | ||||
|  | ||||
| inline static void ggml_vec_gelu_quick_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         float v = GGML_FP16_TO_FP32(x[i]); | ||||
|         y[i] = GGML_FP32_TO_FP16(v*(1.0f/(1.0f+expf(GELU_QUICK_COEF*v)))); | ||||
|     } | ||||
| } | ||||
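
With GGML_GELU_FP16 (and GGML_GELU_QUICK_FP16) defined above, the f32 paths approximate GELU through the 64K-entry fp16 tables declared in vec.cpp: the input is converted to fp16, its bit pattern is used as the table index, and the result is widened back, with the ±10 guard branches in ggml_vec_gelu_f32 keeping large-magnitude inputs exact. A hedged sketch of how such tables can be filled once at startup (the actual initialization lives elsewhere in ggml-cpu and may differ in detail):

#include <stdint.h>
#include <string.h>
#include "vec.h"

static void init_gelu_tables_sketch(void) {
    for (uint32_t i = 0; i < (1u << 16); ++i) {
        const uint16_t bits = (uint16_t) i;    // every possible f16 bit pattern
        ggml_fp16_t h;
        memcpy(&h, &bits, sizeof(h));
        const float f = GGML_FP16_TO_FP32(h);
        ggml_table_gelu_f16[i]       = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
        ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
    }
}
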
|  | ||||
| // Sigmoid Linear Unit (SiLU) function | ||||
| inline static float ggml_silu_f32(float x) { | ||||
|     return x/(1.0f + expf(-x)); | ||||
| } | ||||
| inline static ggml_fp16_t ggml_silu_f16(ggml_fp16_t x) { | ||||
|     float v = GGML_FP16_TO_FP32(x); | ||||
|     return GGML_FP32_TO_FP16(v/(1.0f + expf(-v))); | ||||
| } | ||||
|  | ||||
| #if __FINITE_MATH_ONLY__ | ||||
| #error "some routines in ggml.c require non-finite math arithmetics -- pass -fno-finite-math-only to the compiler to fix" | ||||
| #error "ref: https://github.com/ggml-org/llama.cpp/pull/7154#issuecomment-2143844461" | ||||
| #endif | ||||
|  | ||||
| #if defined(__ARM_NEON) && defined(__aarch64__) | ||||
|  | ||||
| // adapted from arm limited optimized routine | ||||
| // the maximum error is 1.45358 plus 0.5 ulps | ||||
| // numbers above 88.38 will flush to infinity | ||||
| // numbers beneath -103.97 will flush to zero | ||||
| inline static float32x4_t ggml_v_expf(float32x4_t x) { | ||||
|     const float32x4_t r = vdupq_n_f32(0x1.8p23f); | ||||
|     const float32x4_t z = vfmaq_f32(r, x, vdupq_n_f32(0x1.715476p+0f)); | ||||
|     const float32x4_t n = vsubq_f32(z, r); | ||||
|     const float32x4_t b = vfmsq_f32(vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f)), n, | ||||
|                                     vdupq_n_f32(0x1.7f7d1cp-20f)); | ||||
|     const uint32x4_t e = vshlq_n_u32(vreinterpretq_u32_f32(z), 23); | ||||
|     const float32x4_t k = vreinterpretq_f32_u32(vaddq_u32(e, vreinterpretq_u32_f32(vdupq_n_f32(1)))); | ||||
|     const uint32x4_t c = vcagtq_f32(n, vdupq_n_f32(126)); | ||||
|     const float32x4_t u = vmulq_f32(b, b); | ||||
|     const float32x4_t j = vfmaq_f32( | ||||
|         vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), b), | ||||
|         vfmaq_f32(vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), b), | ||||
|                   vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), b), u), u); | ||||
|     if (!vpaddd_u64(vreinterpretq_u64_u32(c))) | ||||
|         return vfmaq_f32(k, j, k); | ||||
|     const uint32x4_t d = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000)); | ||||
|     const float32x4_t s1 = vreinterpretq_f32_u32(vaddq_u32(d, vdupq_n_u32(0x7f000000))); | ||||
|     const float32x4_t s2 = vreinterpretq_f32_u32(vsubq_u32(e, d)); | ||||
|     return vbslq_f32(vcagtq_f32(n, vdupq_n_f32(192)), vmulq_f32(s1, s1), | ||||
|                      vbslq_f32(c, vmulq_f32(vfmaq_f32(s2, s2, j), s1), vfmaq_f32(k, k, j))); | ||||
| } | ||||
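
This NEON routine and the x86 variants that follow share the same scheme: n = round(x/ln 2) is obtained by adding the 0x1.8p23f rounding constant, a short polynomial evaluates exp on the remainder, and the result is scaled by 2^n with integer arithmetic on the exponent field; the branch at the end only handles large |n| (overflow/underflow). A scalar sketch of the same range reduction, for exposition only (it defers to expf/ldexpf instead of the polynomial and the exponent-field bit tricks):

#include <math.h>

// exp(x) = 2^n * exp(r), n = round(x / ln 2), r = x - n*ln 2, |r| <= ln(2)/2
static float expf_range_reduction_sketch(float x) {
    const float inv_ln2 = 0x1.715476p+0f;   // same 1/ln(2) constant as above
    const float n = rintf(x * inv_ln2);
    const float r = x - n * 0x1.62e43p-1f;  // subtract n*ln(2)
    return ldexpf(expf(r), (int) n);        // scale by 2^n via the exponent
}
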
|  | ||||
| // computes silu x/(1+exp(-x)) in single precision vector | ||||
| inline static float32x4_t ggml_v_silu(float32x4_t x) { | ||||
|     const float32x4_t one = vdupq_n_f32(1.0f); | ||||
|     const float32x4_t zero = vdupq_n_f32(0.0f); | ||||
|     const float32x4_t neg_x = vsubq_f32(zero, x); | ||||
|     const float32x4_t exp_neg_x = ggml_v_expf(neg_x); | ||||
|     const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x); | ||||
|     return vdivq_f32(x, one_plus_exp_neg_x); | ||||
| } | ||||
|  | ||||
| #elif defined(__AVX512F__) && defined(__AVX512DQ__) | ||||
|  | ||||
| // adapted from arm limited optimized routine | ||||
| // the maximum error is 1.45358 plus 0.5 ulps | ||||
| // numbers above 88.38 will flush to infinity | ||||
| // numbers beneath -103.97 will flush to zero | ||||
| inline static __m512 ggml_v_expf(__m512 x) { | ||||
|   const __m512 r = _mm512_set1_ps(0x1.8p23f); | ||||
|   const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r); | ||||
|   const __m512 n = _mm512_sub_ps(z, r); | ||||
|   const __m512 b = | ||||
|       _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f), | ||||
|                        _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x)); | ||||
|   const __mmask16 d = | ||||
|       _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ); | ||||
|   const __m512 u = _mm512_mul_ps(b, b); | ||||
|   const __m512 j = _mm512_fmadd_ps( | ||||
|       _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b, | ||||
|                                       _mm512_set1_ps(0x1.573e2ep-5f)), | ||||
|                       u, | ||||
|                       _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b, | ||||
|                                       _mm512_set1_ps(0x1.fffdb6p-2f))), | ||||
|       u, | ||||
|       _mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F))); | ||||
|   const __m512 res = _mm512_scalef_ps(j, n); | ||||
|   if (_mm512_kortestz(d, d)) | ||||
|     return res; | ||||
|   const __m512 zero = _mm512_setzero_ps(); | ||||
|   const __m512 alt = _mm512_mask_blend_ps( | ||||
|       _mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero); | ||||
|   return _mm512_mask_blend_ps(d, res, alt); | ||||
| } | ||||
|  | ||||
| // computes silu x/(1+exp(-x)) in single precision vector | ||||
| inline static __m512 ggml_v_silu(__m512 x) { | ||||
|     const __m512 one = _mm512_set1_ps(1); | ||||
|     const __m512 zero = _mm512_setzero_ps(); | ||||
|     const __m512 neg_x = _mm512_sub_ps(zero, x); | ||||
|     const __m512 exp_neg_x = ggml_v_expf(neg_x); | ||||
|     const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x); | ||||
|     return _mm512_div_ps(x, one_plus_exp_neg_x); | ||||
| } | ||||
|  | ||||
| #elif defined(__AVX2__) && defined(__FMA__) | ||||
|  | ||||
| // adapted from arm limited optimized routine | ||||
| // the maximum error is 1.45358 plus 0.5 ulps | ||||
| // numbers above 88.38 will flush to infinity | ||||
| // numbers beneath -103.97 will flush to zero | ||||
| inline static __m256 ggml_v_expf(__m256 x) { | ||||
|   const __m256 r = _mm256_set1_ps(0x1.8p23f); | ||||
|   const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r); | ||||
|   const __m256 n = _mm256_sub_ps(z, r); | ||||
|   const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f), | ||||
|                                     _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x)); | ||||
|   const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23); | ||||
|   const __m256 k = _mm256_castsi256_ps( | ||||
|       _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1)))); | ||||
|   const __m256i c = _mm256_castps_si256( | ||||
|       _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), | ||||
|                     _mm256_set1_ps(126), _CMP_GT_OQ)); | ||||
|   const __m256 u = _mm256_mul_ps(b, b); | ||||
|   const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b, | ||||
|                                                                    _mm256_set1_ps(0x1.573e2ep-5f)), u, | ||||
|                                                    _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b, | ||||
|                                                                    _mm256_set1_ps(0x1.fffdb6p-2f))), | ||||
|                                    u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b)); | ||||
|   if (!_mm256_movemask_ps(_mm256_castsi256_ps(c))) | ||||
|     return _mm256_fmadd_ps(j, k, k); | ||||
|   const __m256i g = _mm256_and_si256( | ||||
|       _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)), | ||||
|       _mm256_set1_epi32(0x82000000u)); | ||||
|   const __m256 s1 = | ||||
|       _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u))); | ||||
|   const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g)); | ||||
|   const __m256i d = _mm256_castps_si256( | ||||
|       _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n), | ||||
|                     _mm256_set1_ps(192), _CMP_GT_OQ)); | ||||
|   return _mm256_or_ps( | ||||
|       _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)), | ||||
|       _mm256_andnot_ps( | ||||
|           _mm256_castsi256_ps(d), | ||||
|           _mm256_or_ps( | ||||
|               _mm256_and_ps(_mm256_castsi256_ps(c), | ||||
|                             _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)), | ||||
|               _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k))))); | ||||
| } | ||||
|  | ||||
| // computes silu x/(1+exp(-x)) in single precision vector | ||||
| inline static __m256 ggml_v_silu(__m256 x) { | ||||
|     const __m256 one = _mm256_set1_ps(1); | ||||
|     const __m256 zero = _mm256_setzero_ps(); | ||||
|     const __m256 neg_x = _mm256_sub_ps(zero, x); | ||||
|     const __m256 exp_neg_x = ggml_v_expf(neg_x); | ||||
|     const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x); | ||||
|     return _mm256_div_ps(x, one_plus_exp_neg_x); | ||||
| } | ||||
|  | ||||
| #elif defined(__SSE2__) // __AVX2__ / __ARM_NEON | ||||
|  | ||||
| #if defined(__FMA__) | ||||
| #define MADD128(x, y, z) _mm_fmadd_ps(x, y, z) | ||||
| #define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z) | ||||
| #else | ||||
| #define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z) | ||||
| #define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y)) | ||||
| #endif | ||||
|  | ||||
| // adapted from arm limited optimized routine | ||||
| // the maximum error is 1.45358 plus 0.5 ulps | ||||
| // numbers above 88.38 will flush to infinity | ||||
| // numbers beneath -103.97 will flush to zero | ||||
| inline static __m128 ggml_v_expf(__m128 x) { | ||||
|     const __m128 r = _mm_set1_ps(0x1.8p23f); | ||||
|     const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r); | ||||
|     const __m128 n = _mm_sub_ps(z, r); | ||||
|     const __m128 b = | ||||
|         NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x)); | ||||
|     const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23); | ||||
|     const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1)))); | ||||
|     const __m128i c = | ||||
|         _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126))); | ||||
|     const __m128 u = _mm_mul_ps(b, b); | ||||
|     const __m128 j = | ||||
|         MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u, | ||||
|                         MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))), | ||||
|                 u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b)); | ||||
|     if (!_mm_movemask_epi8(c)) | ||||
|         return MADD128(j, k, k); | ||||
|     const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())), | ||||
|                                     _mm_set1_epi32(0x82000000u)); | ||||
|     const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u))); | ||||
|     const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g)); | ||||
|     const __m128i d = | ||||
|         _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192))); | ||||
|     return _mm_or_ps( | ||||
|         _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)), | ||||
|         _mm_andnot_ps(_mm_castsi128_ps(d), | ||||
|                       _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)), | ||||
|                                 _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k))))); | ||||
| } | ||||
|  | ||||
| // computes silu x/(1+exp(-x)) in single precision vector | ||||
| inline static __m128 ggml_v_silu(__m128 x) { | ||||
|     const __m128 one = _mm_set1_ps(1); | ||||
|     const __m128 zero = _mm_setzero_ps(); | ||||
|     const __m128 neg_x = _mm_sub_ps(zero, x); | ||||
|     const __m128 exp_neg_x = ggml_v_expf(neg_x); | ||||
|     const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x); | ||||
|     return _mm_div_ps(x, one_plus_exp_neg_x); | ||||
| } | ||||
|  | ||||
| #endif // __ARM_NEON / __AVX2__ / __SSE2__ | ||||
|  | ||||
| inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         y[i] = ggml_silu_f16(x[i]); | ||||
|     } | ||||
| } | ||||
|  | ||||
| inline static float ggml_silu_backward_f32(float x, float dy) { | ||||
|     const float s = 1.0f/(1.0f + expf(-x)); | ||||
|     return dy*s*(1.0f + x*(1.0f - s)); | ||||
| } | ||||
|  | ||||
| inline static ggml_fp16_t ggml_silu_backward_f16(ggml_fp16_t x, ggml_fp16_t dy) { | ||||
|     const float v = GGML_FP16_TO_FP32(x); | ||||
|     const float s = 1.0f/(1.0f + expf(-v)); | ||||
|     return GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(dy)*s*(1.0f + v*(1.0f - s))); | ||||
| } | ||||
|  | ||||
| inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         dx[i] = ggml_silu_backward_f32(x[i], dy[i]); | ||||
|     } | ||||
| } | ||||
|  | ||||
| inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, const ggml_fp16_t * x, const ggml_fp16_t * dy) { | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         dx[i] = ggml_silu_backward_f16(x[i], dy[i]); | ||||
|     } | ||||
| } | ||||
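
The backward helpers apply the SiLU derivative, silu'(x) = s(x) * (1 + x*(1 - s(x))) with s(x) = 1/(1 + exp(-x)), to the incoming gradient dy. A tiny finite-difference check of that formula (illustrative only):

#include <stdio.h>
#include "vec.h"

int main(void) {
    const float x = 0.7f, dy = 1.0f, eps = 1e-3f;
    const float numeric  = (ggml_silu_f32(x + eps) - ggml_silu_f32(x - eps)) / (2.0f*eps);
    const float analytic = ggml_silu_backward_f32(x, dy);
    printf("numeric=%f analytic=%f\n", numeric, analytic);   // should agree to ~1e-4
    return 0;
}
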
|  | ||||
| inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) { | ||||
| #ifndef GGML_USE_ACCELERATE | ||||
|     ggml_float sum = 0.0; | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         sum += (ggml_float)x[i]; | ||||
|     } | ||||
|     *s = (float)sum; | ||||
| #else | ||||
|     vDSP_sve(x, 1, s, n); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| inline static void ggml_vec_sum_f32_ggf(const int n, ggml_float * s, const float * x) { | ||||
|     ggml_float sum = 0.0; | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         sum += (ggml_float)x[i]; | ||||
|     } | ||||
|     *s = sum; | ||||
| } | ||||
|  | ||||
| inline static void ggml_vec_sum_f16_ggf(const int n, float * s, const ggml_fp16_t * x) { | ||||
|     float sum = 0.0f; | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         sum += GGML_FP16_TO_FP32(x[i]); | ||||
|     } | ||||
|     *s = sum; | ||||
| } | ||||
|  | ||||
| inline static void ggml_vec_sum_bf16_ggf(const int n, float * s, const ggml_bf16_t * x) { | ||||
|     float sum = 0.0f; | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         sum += GGML_BF16_TO_FP32(x[i]); | ||||
|     } | ||||
|     *s = sum; | ||||
| } | ||||
|  | ||||
| inline static void ggml_vec_max_f32(const int n, float * s, const float * x) { | ||||
| #ifndef GGML_USE_ACCELERATE | ||||
|     float max = -INFINITY; | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         max = MAX(max, x[i]); | ||||
|     } | ||||
|     *s = max; | ||||
| #else | ||||
|     vDSP_maxv(x, 1, s, n); | ||||
| #endif | ||||
| } | ||||
|  | ||||
| inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x) { | ||||
|     ggml_vec_norm_f32(n, s, x); | ||||
|     *s = 1.f/(*s); | ||||
| } | ||||
|  | ||||
| inline static void ggml_vec_argmax_f32(const int n, int * s, const float * x) { | ||||
|     float max = -INFINITY; | ||||
|     int idx = 0; | ||||
|     for (int i = 0; i < n; ++i) { | ||||
|         max = MAX(max, x[i]); | ||||
|         if (max == x[i]) { idx = i; } | ||||
|     } | ||||
|     *s = idx; | ||||
| } | ||||
|  | ||||
| #ifdef __cplusplus | ||||
| } | ||||
| #endif | ||||