mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	ggml : rewrite silu and softmax for cpu (#7154)
This change upstreams llamafile's vectorized expf() functions. This lets us compute softmax and silu more accurately than the short[65536] lookup table that GGML previously used to make these operations go faster. We can support aarch64 and sse2+ with a worst-case rounding error of 2 ulp. It makes `make -j8 tests && ./tests/test-backend-ops -o SOFT_MAX -b CPU perf` go 1.5x faster for SSE2+FMA, 1.9x faster for AVX2+FMA and 2.1x faster on AVX512.
This commit is contained in:
		
							
								
								
									
										476
									
								
								ggml.c
									
									
									
									
									
								
							
							
						
						
									
										476
									
								
								ggml.c
									
									
									
									
									
								
							@@ -165,9 +165,6 @@ void ggml_print_backtrace(void) {
 | 
			
		||||
#define GGML_DEBUG 0
 | 
			
		||||
#define GGML_GELU_FP16
 | 
			
		||||
#define GGML_GELU_QUICK_FP16
 | 
			
		||||
#define GGML_SILU_FP16
 | 
			
		||||
// #define GGML_CROSS_ENTROPY_EXP_FP16
 | 
			
		||||
// #define GGML_FLASH_ATTN_EXP_FP16
 | 
			
		||||
 | 
			
		||||
#define GGML_SOFT_MAX_UNROLL 4
 | 
			
		||||
#define GGML_VEC_DOT_UNROLL  2
 | 
			
		||||
@@ -318,12 +315,6 @@ static ggml_fp16_t ggml_table_gelu_f16[1 << 16];
 | 
			
		||||
// precomputed quick gelu table for f16 (128 KB)
 | 
			
		||||
static ggml_fp16_t ggml_table_gelu_quick_f16[1 << 16];
 | 
			
		||||
 | 
			
		||||
// precomputed silu table for f16 (128 KB)
 | 
			
		||||
static ggml_fp16_t ggml_table_silu_f16[1 << 16];
 | 
			
		||||
 | 
			
		||||
// precomputed exp table for f16 (128 KB)
 | 
			
		||||
static ggml_fp16_t ggml_table_exp_f16[1 << 16];
 | 
			
		||||
 | 
			
		||||
// precomputed f32 table for f16 (256 KB) (ggml-impl.h)
 | 
			
		||||
float ggml_table_f32_f16[1 << 16];
 | 
			
		||||
 | 
			
		||||
@@ -2085,52 +2076,291 @@ inline static float ggml_silu_f32(float x) {
 | 
			
		||||
    return x/(1.0f + expf(-x));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
//inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {
 | 
			
		||||
//    const uint16_t * i16 = (const uint16_t *) x;
 | 
			
		||||
//    for (int i = 0; i < n; ++i) {
 | 
			
		||||
//        y[i] = ggml_table_silu_f16[i16[i]];
 | 
			
		||||
//    }
 | 
			
		||||
//}
 | 
			
		||||
#if defined(__ARM_NEON)
 | 
			
		||||
 | 
			
		||||
#ifdef GGML_SILU_FP16
 | 
			
		||||
inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
 | 
			
		||||
    uint16_t t;
 | 
			
		||||
    for (int i = 0; i < n; ++i) {
 | 
			
		||||
        ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
 | 
			
		||||
        memcpy(&t, &fp16, sizeof(uint16_t));
 | 
			
		||||
        y[i] = GGML_FP16_TO_FP32(ggml_table_silu_f16[t]);
 | 
			
		||||
    }
 | 
			
		||||
// Vectorized expf() over four packed floats (NEON).
//
// adapted from the Arm Limited optimized routines
//   - the maximum error is 1.45358 plus 0.5 ulps
//   - numbers above 88.38 will flush to infinity
//   - numbers beneath -103.97 will flush to zero
//
// Outline: n = x*0x1.715476p+0 (~log2(e)) rounded to an integer by adding and
// subtracting 0x1.8p23, the remainder r = x - n*ln2 (ln2 split into a high and
// a low constant) goes through a short polynomial, and the result is scaled by
// 2^n, which is assembled directly in the float exponent bits.
//
// NOTE(review): vpaddd_u64 appears to be an A64-only intrinsic - confirm that
// the preprocessor guard around this function excludes 32-bit ARM.
inline static float32x4_t ggml_v_expf(float32x4_t x) {
    const float32x4_t shift  = vdupq_n_f32(0x1.8p23f);
    const float32x4_t biased = vfmaq_f32(shift, x, vdupq_n_f32(0x1.715476p+0f));
    const float32x4_t n      = vsubq_f32(biased, shift);

    // r = (x - n*ln2_hi) - n*ln2_lo
    const float32x4_t r_hi = vfmsq_f32(x, n, vdupq_n_f32(0x1.62e4p-1f));
    const float32x4_t r    = vfmsq_f32(r_hi, n, vdupq_n_f32(0x1.7f7d1cp-20f));

    // the low mantissa bits of `biased` hold n; move them into the exponent field
    const uint32x4_t  ebits = vshlq_n_u32(vreinterpretq_u32_f32(biased), 23);
    const float32x4_t scale = vreinterpretq_f32_u32(
        vaddq_u32(ebits, vreinterpretq_u32_f32(vdupq_n_f32(1))));

    // lanes with |n| > 126 cannot be scaled in a single step
    const uint32x4_t special = vcagtq_f32(n, vdupq_n_f32(126));

    const float32x4_t r2    = vmulq_f32(r, r);
    const float32x4_t p_lo  = vfmaq_f32(vdupq_n_f32(0x1.fffdb6p-2f), vdupq_n_f32(0x1.555e66p-3f), r);
    const float32x4_t p_hi  = vfmaq_f32(vdupq_n_f32(0x1.573e2ep-5f), vdupq_n_f32(0x1.0e4020p-7f), r);
    const float32x4_t p_lin = vmulq_f32(vdupq_n_f32(0x1.ffffecp-1f), r);
    const float32x4_t poly  = vfmaq_f32(p_lin, vfmaq_f32(p_lo, p_hi, r2), r2);

    // fast path: no lane needs special handling
    if (!vpaddd_u64(vreinterpretq_u64_u32(special))) {
        return vfmaq_f32(scale, poly, scale);
    }

    // slow path: split the scaling into two factors, and saturate |n| > 192
    const uint32x4_t  adj = vandq_u32(vclezq_f32(n), vdupq_n_u32(0x82000000));
    const float32x4_t s1  = vreinterpretq_f32_u32(vaddq_u32(adj, vdupq_n_u32(0x7f000000)));
    const float32x4_t s2  = vreinterpretq_f32_u32(vsubq_u32(ebits, adj));

    const uint32x4_t  out_of_range = vcagtq_f32(n, vdupq_n_f32(192));
    const float32x4_t saturated    = vmulq_f32(s1, s1);
    const float32x4_t rescaled     = vmulq_f32(vfmaq_f32(s2, s2, poly), s1);
    const float32x4_t normal       = vfmaq_f32(scale, scale, poly);

    return vbslq_f32(out_of_range, saturated, vbslq_f32(special, rescaled, normal));
}
 | 
			
		||||
 | 
			
		||||
// SiLU on four packed floats: x / (1 + exp(-x)), single precision.
// -x is formed as (0 - x); exp() comes from ggml_v_expf.
inline static float32x4_t ggml_v_silu(float32x4_t x) {
    const float32x4_t minus_x = vsubq_f32(vdupq_n_f32(0.0f), x);
    const float32x4_t denom   = vaddq_f32(vdupq_n_f32(1.0f), ggml_v_expf(minus_x));
    return vdivq_f32(x, denom);
}
 | 
			
		||||
 | 
			
		||||
#elif defined(__AVX512F__) && defined(__AVX512DQ__)
 | 
			
		||||
 | 
			
		||||
// adapted from arm limited optimized routine
 | 
			
		||||
// the maximum error is 1.45358 plus 0.5 ulps
 | 
			
		||||
// numbers above 88.38 will flush to infinity
 | 
			
		||||
// numbers beneath -103.97 will flush to zero
 | 
			
		||||
inline static __m512 ggml_v_expf(__m512 x) {
 | 
			
		||||
  const __m512 r = _mm512_set1_ps(0x1.8p23f);
 | 
			
		||||
  const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
 | 
			
		||||
  const __m512 n = _mm512_sub_ps(z, r);
 | 
			
		||||
  const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
 | 
			
		||||
                                    _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
 | 
			
		||||
  const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
 | 
			
		||||
  const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
 | 
			
		||||
  const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
 | 
			
		||||
  const __m512 u = _mm512_mul_ps(b, b);
 | 
			
		||||
  const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
 | 
			
		||||
                                                                   _mm512_set1_ps(0x1.573e2ep-5f)), u,
 | 
			
		||||
                                                   _mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
 | 
			
		||||
                                                                   _mm512_set1_ps(0x1.fffdb6p-2f))),
 | 
			
		||||
                                   u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
 | 
			
		||||
  if (_mm512_kortestz(c, c))
 | 
			
		||||
    return _mm512_fmadd_ps(j, k, k);
 | 
			
		||||
  const __m512i g = _mm512_and_si512(
 | 
			
		||||
      _mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
 | 
			
		||||
      _mm512_set1_epi32(0x82000000u));
 | 
			
		||||
  const __m512 s1 =
 | 
			
		||||
      _mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
 | 
			
		||||
  const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
 | 
			
		||||
  const __mmask16 d =
 | 
			
		||||
      _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
 | 
			
		||||
  return _mm512_mask_blend_ps(
 | 
			
		||||
      d, _mm512_mask_blend_ps(
 | 
			
		||||
          c, _mm512_fmadd_ps(k, j, k),
 | 
			
		||||
          _mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
 | 
			
		||||
      _mm512_mul_ps(s1, s1));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// computes silu x/(1+exp(-x)) in single precision vector
 | 
			
		||||
inline static __m512 ggml_v_silu(__m512 x) {
 | 
			
		||||
    const __m512 one = _mm512_set1_ps(1);
 | 
			
		||||
    const __m512 zero = _mm512_setzero_ps();
 | 
			
		||||
    const __m512 neg_x = _mm512_sub_ps(zero, x);
 | 
			
		||||
    const __m512 exp_neg_x = ggml_v_expf(neg_x);
 | 
			
		||||
    const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
 | 
			
		||||
    return _mm512_div_ps(x, one_plus_exp_neg_x);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#elif defined(__AVX2__) && defined(__FMA__)
 | 
			
		||||
 | 
			
		||||
// adapted from arm limited optimized routine
 | 
			
		||||
// the maximum error is 1.45358 plus 0.5 ulps
 | 
			
		||||
// numbers above 88.38 will flush to infinity
 | 
			
		||||
// numbers beneath -103.97 will flush to zero
 | 
			
		||||
inline static __m256 ggml_v_expf(__m256 x) {
 | 
			
		||||
  const __m256 r = _mm256_set1_ps(0x1.8p23f);
 | 
			
		||||
  const __m256 z = _mm256_fmadd_ps(x, _mm256_set1_ps(0x1.715476p+0f), r);
 | 
			
		||||
  const __m256 n = _mm256_sub_ps(z, r);
 | 
			
		||||
  const __m256 b = _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.7f7d1cp-20f),
 | 
			
		||||
                                    _mm256_fnmadd_ps(n, _mm256_set1_ps(0x1.62e4p-1f), x));
 | 
			
		||||
  const __m256i e = _mm256_slli_epi32(_mm256_castps_si256(z), 23);
 | 
			
		||||
  const __m256 k = _mm256_castsi256_ps(
 | 
			
		||||
      _mm256_add_epi32(e, _mm256_castps_si256(_mm256_set1_ps(1))));
 | 
			
		||||
  const __m256i c = _mm256_castps_si256(
 | 
			
		||||
      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
 | 
			
		||||
                    _mm256_set1_ps(126), _CMP_GT_OQ));
 | 
			
		||||
  const __m256 u = _mm256_mul_ps(b, b);
 | 
			
		||||
  const __m256 j = _mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_fmadd_ps(_mm256_set1_ps(0x1.0e4020p-7f), b,
 | 
			
		||||
                                                                   _mm256_set1_ps(0x1.573e2ep-5f)), u,
 | 
			
		||||
                                                   _mm256_fmadd_ps(_mm256_set1_ps(0x1.555e66p-3f), b,
 | 
			
		||||
                                                                   _mm256_set1_ps(0x1.fffdb6p-2f))),
 | 
			
		||||
                                   u, _mm256_mul_ps(_mm256_set1_ps(0x1.ffffecp-1f), b));
 | 
			
		||||
  if (!_mm256_movemask_ps(_mm256_castsi256_ps(c)))
 | 
			
		||||
    return _mm256_fmadd_ps(j, k, k);
 | 
			
		||||
  const __m256i g = _mm256_and_si256(
 | 
			
		||||
      _mm256_castps_si256(_mm256_cmp_ps(n, _mm256_setzero_ps(), _CMP_LE_OQ)),
 | 
			
		||||
      _mm256_set1_epi32(0x82000000u));
 | 
			
		||||
  const __m256 s1 =
 | 
			
		||||
      _mm256_castsi256_ps(_mm256_add_epi32(g, _mm256_set1_epi32(0x7f000000u)));
 | 
			
		||||
  const __m256 s2 = _mm256_castsi256_ps(_mm256_sub_epi32(e, g));
 | 
			
		||||
  const __m256i d = _mm256_castps_si256(
 | 
			
		||||
      _mm256_cmp_ps(_mm256_andnot_ps(_mm256_set1_ps(-0.f), n),
 | 
			
		||||
                    _mm256_set1_ps(192), _CMP_GT_OQ));
 | 
			
		||||
  return _mm256_or_ps(
 | 
			
		||||
      _mm256_and_ps(_mm256_castsi256_ps(d), _mm256_mul_ps(s1, s1)),
 | 
			
		||||
      _mm256_andnot_ps(
 | 
			
		||||
          _mm256_castsi256_ps(d),
 | 
			
		||||
          _mm256_or_ps(
 | 
			
		||||
              _mm256_and_ps(_mm256_castsi256_ps(c),
 | 
			
		||||
                            _mm256_mul_ps(_mm256_fmadd_ps(s2, j, s2), s1)),
 | 
			
		||||
              _mm256_andnot_ps(_mm256_castsi256_ps(c), _mm256_fmadd_ps(k, j, k)))));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// computes silu x/(1+exp(-x)) in single precision vector
 | 
			
		||||
inline static __m256 ggml_v_silu(__m256 x) {
 | 
			
		||||
    const __m256 one = _mm256_set1_ps(1);
 | 
			
		||||
    const __m256 zero = _mm256_setzero_ps();
 | 
			
		||||
    const __m256 neg_x = _mm256_sub_ps(zero, x);
 | 
			
		||||
    const __m256 exp_neg_x = ggml_v_expf(neg_x);
 | 
			
		||||
    const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
 | 
			
		||||
    return _mm256_div_ps(x, one_plus_exp_neg_x);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#elif defined(__SSE2__) // __AVX2__ / __ARM_NEON
 | 
			
		||||
 | 
			
		||||
#if defined(__FMA__)
 | 
			
		||||
#define MADD128(x, y, z) _mm_fmadd_ps(x, y, z)
 | 
			
		||||
#define NMADD128(x, y, z) _mm_fnmadd_ps(x, y, z)
 | 
			
		||||
#else
 | 
			
		||||
inline static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
 | 
			
		||||
    for (int i = 0; i < n; ++i) {
 | 
			
		||||
#define MADD128(x, y, z) _mm_add_ps(_mm_mul_ps(x, y), z)
 | 
			
		||||
#define NMADD128(x, y, z) _mm_sub_ps(z, _mm_mul_ps(x, y))
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
// adapted from arm limited optimized routine
 | 
			
		||||
// the maximum error is 1.45358 plus 0.5 ulps
 | 
			
		||||
// numbers above 88.38 will flush to infinity
 | 
			
		||||
// numbers beneath -103.97 will flush to zero
 | 
			
		||||
inline static __m128 ggml_v_expf(__m128 x) {
 | 
			
		||||
    const __m128 r = _mm_set1_ps(0x1.8p23f);
 | 
			
		||||
    const __m128 z = MADD128(x, _mm_set1_ps(0x1.715476p+0f), r);
 | 
			
		||||
    const __m128 n = _mm_sub_ps(z, r);
 | 
			
		||||
    const __m128 b =
 | 
			
		||||
        NMADD128(n, _mm_set1_ps(0x1.7f7d1cp-20f), NMADD128(n, _mm_set1_ps(0x1.62e4p-1f), x));
 | 
			
		||||
    const __m128i e = _mm_slli_epi32(_mm_castps_si128(z), 23);
 | 
			
		||||
    const __m128 k = _mm_castsi128_ps(_mm_add_epi32(e, _mm_castps_si128(_mm_set1_ps(1))));
 | 
			
		||||
    const __m128i c =
 | 
			
		||||
        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(126)));
 | 
			
		||||
    const __m128 u = _mm_mul_ps(b, b);
 | 
			
		||||
    const __m128 j =
 | 
			
		||||
        MADD128(MADD128(MADD128(_mm_set1_ps(0x1.0e4020p-7f), b, _mm_set1_ps(0x1.573e2ep-5f)), u,
 | 
			
		||||
                        MADD128(_mm_set1_ps(0x1.555e66p-3f), b, _mm_set1_ps(0x1.fffdb6p-2f))),
 | 
			
		||||
                u, _mm_mul_ps(_mm_set1_ps(0x1.ffffecp-1f), b));
 | 
			
		||||
    if (!_mm_movemask_epi8(c))
 | 
			
		||||
        return MADD128(j, k, k);
 | 
			
		||||
    const __m128i g = _mm_and_si128(_mm_castps_si128(_mm_cmple_ps(n, _mm_setzero_ps())),
 | 
			
		||||
                                    _mm_set1_epi32(0x82000000u));
 | 
			
		||||
    const __m128 s1 = _mm_castsi128_ps(_mm_add_epi32(g, _mm_set1_epi32(0x7f000000u)));
 | 
			
		||||
    const __m128 s2 = _mm_castsi128_ps(_mm_sub_epi32(e, g));
 | 
			
		||||
    const __m128i d =
 | 
			
		||||
        _mm_castps_si128(_mm_cmpgt_ps(_mm_andnot_ps(_mm_set1_ps(-0.f), n), _mm_set1_ps(192)));
 | 
			
		||||
    return _mm_or_ps(
 | 
			
		||||
        _mm_and_ps(_mm_castsi128_ps(d), _mm_mul_ps(s1, s1)),
 | 
			
		||||
        _mm_andnot_ps(_mm_castsi128_ps(d),
 | 
			
		||||
                      _mm_or_ps(_mm_and_ps(_mm_castsi128_ps(c), _mm_mul_ps(MADD128(s2, j, s2), s1)),
 | 
			
		||||
                                _mm_andnot_ps(_mm_castsi128_ps(c), MADD128(k, j, k)))));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// computes silu x/(1+exp(-x)) in single precision vector
 | 
			
		||||
inline static __m128 ggml_v_silu(__m128 x) {
 | 
			
		||||
    const __m128 one = _mm_set1_ps(1);
 | 
			
		||||
    const __m128 zero = _mm_setzero_ps();
 | 
			
		||||
    const __m128 neg_x = _mm_sub_ps(zero, x);
 | 
			
		||||
    const __m128 exp_neg_x = ggml_v_expf(neg_x);
 | 
			
		||||
    const __m128 one_plus_exp_neg_x = _mm_add_ps(one, exp_neg_x);
 | 
			
		||||
    return _mm_div_ps(x, one_plus_exp_neg_x);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
#endif // __ARM_NEON / __AVX2__ / __SSE2__
 | 
			
		||||
 | 
			
		||||
// Applies SiLU element-wise: y[i] = x[i] / (1 + exp(-x[i])) for i in [0, n).
//
//   n - number of elements to process
//   y - destination, at least n floats (stored with unaligned vector stores)
//   x - source, at least n floats (read with unaligned vector loads)
//
// Full vectors are handled by ggml_v_silu for the widest instruction set that is
// enabled at compile time; whatever is left over (or everything, when no vector
// path is enabled) goes through the scalar ggml_silu_f32.
//
// The NEON path is limited to AArch64: the vector helper relies on intrinsics
// such as vdivq_f32 that are not provided for 32-bit ARM, so 32-bit NEON targets
// use the scalar loop instead.
static void ggml_vec_silu_f32(const int n, float * y, const float * x) {
    int i = 0;
#if defined(__AVX512F__) && defined(__AVX512DQ__)
    for (; i + 15 < n; i += 16) {
        _mm512_storeu_ps(y + i, ggml_v_silu(_mm512_loadu_ps(x + i)));
    }
#elif defined(__AVX2__) && defined(__FMA__)
    for (; i + 7 < n; i += 8) {
        _mm256_storeu_ps(y + i, ggml_v_silu(_mm256_loadu_ps(x + i)));
    }
#elif defined(__SSE2__)
    for (; i + 3 < n; i += 4) {
        _mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
    }
#elif defined(__ARM_NEON) && defined(__aarch64__)
    for (; i + 3 < n; i += 4) {
        vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
    }
#endif
    // scalar tail (and fallback when no vector path was compiled in)
    for (; i < n; ++i) {
        y[i] = ggml_silu_f32(x[i]);
    }
}
 | 
			
		||||
 | 
			
		||||
static ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
 | 
			
		||||
    int i = 0;
 | 
			
		||||
    ggml_float sum = 0;
 | 
			
		||||
#if defined(__AVX512F__) && defined(__AVX512DQ__)
 | 
			
		||||
    for (; i + 15 < n; i += 16) {
 | 
			
		||||
        __m512 val = ggml_v_expf(_mm512_sub_ps(_mm512_loadu_ps(x + i),
 | 
			
		||||
                                               _mm512_set1_ps(max)));
 | 
			
		||||
        _mm512_storeu_ps(y + i, val);
 | 
			
		||||
        sum += (ggml_float)_mm512_reduce_add_ps(val);
 | 
			
		||||
    }
 | 
			
		||||
#elif defined(__AVX2__) && defined(__FMA__)
 | 
			
		||||
    for (; i + 7 < n; i += 8) {
 | 
			
		||||
        __m256 val = ggml_v_expf(_mm256_sub_ps(_mm256_loadu_ps(x + i),
 | 
			
		||||
                                               _mm256_set1_ps(max)));
 | 
			
		||||
        _mm256_storeu_ps(y + i, val);
 | 
			
		||||
        __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1),
 | 
			
		||||
                                 _mm256_castps256_ps128(val));
 | 
			
		||||
        val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2));
 | 
			
		||||
        val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2));
 | 
			
		||||
        sum += (ggml_float)_mm_cvtss_f32(val2);
 | 
			
		||||
    }
 | 
			
		||||
#elif defined(__SSE2__)
 | 
			
		||||
    for (; i + 3 < n; i += 4) {
 | 
			
		||||
        __m128 val = ggml_v_expf(_mm_sub_ps(_mm_loadu_ps(x + i),
 | 
			
		||||
                                            _mm_set1_ps(max)));
 | 
			
		||||
        _mm_storeu_ps(y + i, val);
 | 
			
		||||
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
 | 
			
		||||
        val = _mm_add_ps(val, _mm_movehl_ps(val, val));
 | 
			
		||||
        val = _mm_add_ss(val, _mm_movehdup_ps(val));
 | 
			
		||||
#else
 | 
			
		||||
        __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1));
 | 
			
		||||
        val = _mm_add_ps(val, tmp);
 | 
			
		||||
        tmp = _mm_movehl_ps(tmp, val);
 | 
			
		||||
        val = _mm_add_ss(val, tmp);
 | 
			
		||||
#endif
 | 
			
		||||
        sum += (ggml_float)_mm_cvtss_f32(val);
 | 
			
		||||
    }
 | 
			
		||||
#elif defined(__ARM_NEON)
 | 
			
		||||
    for (; i + 3 < n; i += 4) {
 | 
			
		||||
        float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),
 | 
			
		||||
                                                vdupq_n_f32(max)));
 | 
			
		||||
        vst1q_f32(y + i, val);
 | 
			
		||||
        sum += (ggml_float)vaddvq_f32(val);
 | 
			
		||||
    }
 | 
			
		||||
#endif
 | 
			
		||||
    for (; i < n; ++i) {
 | 
			
		||||
        float val = expf(x[i] - max);
 | 
			
		||||
        sum += (ggml_float)val;
 | 
			
		||||
        y[i] = val;
 | 
			
		||||
    }
 | 
			
		||||
    return sum;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
inline static float ggml_silu_backward_f32(float x, float dy) {
 | 
			
		||||
    const float s = 1.0f/(1.0f + expf(-x));
 | 
			
		||||
    return dy*s*(1.0f + x*(1.0f - s));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Element-wise SiLU backward: dx[i] = ggml_silu_backward_f32(x_i, dy[i]) for i in [0, n).
//
// When GGML_SILU_FP16 is defined, x_i is x[i] rounded through f16 and back;
// the original note states that forward silu was computed on the f16
// equivalent of x[i], so the derivative is taken at that same point.
// Otherwise x_i is x[i] itself.
inline static void ggml_vec_silu_backward_f32(const int n, float * dx, const float * x, const float * dy) {
    for (int i = 0; i < n; ++i) {
#ifdef GGML_SILU_FP16
        // take the derivative at the f16 equivalent of x[i]
        ggml_fp16_t x_f16 = GGML_FP32_TO_FP16(x[i]);
        const float x_eval = GGML_FP16_TO_FP32(x_f16);
#else
        const float x_eval = x[i];
#endif
        dx[i] = ggml_silu_backward_f32(x_eval, dy[i]);
    }
}
 | 
			
		||||
 | 
			
		||||
inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
 | 
			
		||||
#ifndef GGML_USE_ACCELERATE
 | 
			
		||||
@@ -2922,8 +3152,6 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 | 
			
		||||
                float f = ggml_table_f32_f16[i] = GGML_COMPUTE_FP16_TO_FP32(u.fp16);
 | 
			
		||||
                ggml_table_gelu_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_f32(f));
 | 
			
		||||
                ggml_table_gelu_quick_f16[i] = GGML_FP32_TO_FP16(ggml_gelu_quick_f32(f));
 | 
			
		||||
                ggml_table_silu_f16[i] = GGML_FP32_TO_FP16(ggml_silu_f32(f));
 | 
			
		||||
                ggml_table_exp_f16[i] = GGML_FP32_TO_FP16(expf(f));
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
 | 
			
		||||
@@ -13600,22 +13828,7 @@ static void ggml_compute_forward_soft_max_f32(
 | 
			
		||||
        float max = -INFINITY;
 | 
			
		||||
        ggml_vec_max_f32(nc, &max, wp);
 | 
			
		||||
 | 
			
		||||
        ggml_float sum = 0.0;
 | 
			
		||||
 | 
			
		||||
        uint16_t scvt;
 | 
			
		||||
        for (int i = 0; i < nc; i++) {
 | 
			
		||||
            if (wp[i] == -INFINITY) {
 | 
			
		||||
                dp[i] = 0.0f;
 | 
			
		||||
            } else {
 | 
			
		||||
                // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
 | 
			
		||||
                ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
 | 
			
		||||
                memcpy(&scvt, &s, sizeof(scvt));
 | 
			
		||||
                const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
 | 
			
		||||
                sum += (ggml_float)val;
 | 
			
		||||
                dp[i] = val;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
 | 
			
		||||
        assert(sum > 0.0);
 | 
			
		||||
 | 
			
		||||
        sum = 1.0/sum;
 | 
			
		||||
@@ -15374,37 +15587,7 @@ static void ggml_compute_forward_flash_attn_f32(
 | 
			
		||||
                vvexpf(S, S, &Mup);
 | 
			
		||||
                ggml_vec_sum_f32(Mup, &sum, S);
 | 
			
		||||
#else
 | 
			
		||||
                uint16_t   scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
 | 
			
		||||
                ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
 | 
			
		||||
 | 
			
		||||
                for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
 | 
			
		||||
                    if (i >= masked_begin) {
 | 
			
		||||
                        break;
 | 
			
		||||
                    }
 | 
			
		||||
                    float * SS = S + i;
 | 
			
		||||
 | 
			
		||||
                    for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
 | 
			
		||||
                        if (i + j >= masked_begin) {
 | 
			
		||||
                            break;
 | 
			
		||||
                        } else if (SS[j] == -INFINITY) {
 | 
			
		||||
                            SS[j] = 0.0f;
 | 
			
		||||
                        } else {
 | 
			
		||||
#ifndef GGML_FLASH_ATTN_EXP_FP16
 | 
			
		||||
                            const float val = expf(SS[j] - max);
 | 
			
		||||
#else
 | 
			
		||||
                            ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
 | 
			
		||||
                            memcpy(&scvt[j], &s, sizeof(uint16_t));
 | 
			
		||||
                            const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
 | 
			
		||||
#endif
 | 
			
		||||
                            sump[j] += (ggml_float)val;
 | 
			
		||||
                            SS[j] = val;
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
 | 
			
		||||
                    sum += sump[i];
 | 
			
		||||
                }
 | 
			
		||||
                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
 | 
			
		||||
#endif
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
@@ -15586,28 +15769,7 @@ static void ggml_compute_forward_flash_attn_f16(
 | 
			
		||||
                vvexpf(S, S, &Mup);
 | 
			
		||||
                ggml_vec_sum_f32(Mup, &sum, S);
 | 
			
		||||
#else
 | 
			
		||||
                uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
 | 
			
		||||
                ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
 | 
			
		||||
 | 
			
		||||
                for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
 | 
			
		||||
                    float * SS = S + i;
 | 
			
		||||
 | 
			
		||||
                    for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
 | 
			
		||||
                        if (SS[j] == -INFINITY) {
 | 
			
		||||
                            SS[j] = 0.0f;
 | 
			
		||||
                        } else {
 | 
			
		||||
                            ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
 | 
			
		||||
                            memcpy(&scvt[j], &s, sizeof(uint16_t));
 | 
			
		||||
                            const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
 | 
			
		||||
                            sump[j] += (ggml_float)val;
 | 
			
		||||
                            SS[j] = val;
 | 
			
		||||
                        }
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
 | 
			
		||||
                    sum += sump[i];
 | 
			
		||||
                }
 | 
			
		||||
                sum = ggml_vec_soft_max_f32(Mup, S, S, max);
 | 
			
		||||
#endif
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
@@ -16234,38 +16396,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
 | 
			
		||||
                        vvexpf(SM, SM, &Mup);
 | 
			
		||||
                        ggml_vec_sum_f32(Mup, &sum, SM);
 | 
			
		||||
#else
 | 
			
		||||
                        uint16_t   scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
 | 
			
		||||
                        ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
 | 
			
		||||
 | 
			
		||||
                        for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
 | 
			
		||||
                            if (i >= masked_begin) {
 | 
			
		||||
                                break;
 | 
			
		||||
                            }
 | 
			
		||||
                            float * SR =  S + i;
 | 
			
		||||
                            float * SW = SM + i;
 | 
			
		||||
 | 
			
		||||
                            for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
 | 
			
		||||
                                if (i + j >= masked_begin) {
 | 
			
		||||
                                    break;
 | 
			
		||||
                                } else if (SR[j] == -INFINITY) {
 | 
			
		||||
                                    SW[j] = 0.0f;
 | 
			
		||||
                                } else {
 | 
			
		||||
#ifndef GGML_FLASH_ATTN_EXP_FP16
 | 
			
		||||
                                    const float val = expf(SR[j] - max);
 | 
			
		||||
#else
 | 
			
		||||
                                    ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
 | 
			
		||||
                                    memcpy(&scvt[j], &s, sizeof(uint16_t));
 | 
			
		||||
                                    const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt[j]]);
 | 
			
		||||
#endif
 | 
			
		||||
                                    sump[j] += (ggml_float)val;
 | 
			
		||||
                                    SW[j] = val;
 | 
			
		||||
                                }
 | 
			
		||||
                            }
 | 
			
		||||
                        }
 | 
			
		||||
 | 
			
		||||
                        for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
 | 
			
		||||
                            sum += sump[i];
 | 
			
		||||
                        }
 | 
			
		||||
                        sum = ggml_vec_soft_max_f32(Mup, SM, S, max);
 | 
			
		||||
#endif
 | 
			
		||||
                    }
 | 
			
		||||
 | 
			
		||||
@@ -17291,35 +17422,15 @@ static void ggml_compute_forward_cross_entropy_loss_f32(
 | 
			
		||||
            assert(!isnan(s1[i]));
 | 
			
		||||
        }
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
        // soft_max
 | 
			
		||||
        ggml_float sum = 0.0;
 | 
			
		||||
        {
 | 
			
		||||
            float max = -INFINITY;
 | 
			
		||||
            ggml_vec_max_f32(nc, &max, s0);
 | 
			
		||||
 | 
			
		||||
            uint16_t scvt; UNUSED(scvt);
 | 
			
		||||
            for (int i = 0; i < nc; i++) {
 | 
			
		||||
                if (s0[i] == -INFINITY) {
 | 
			
		||||
                    st[i] = 0.0f;
 | 
			
		||||
                } else {
 | 
			
		||||
#ifndef GGML_CROSS_ENTROPY_EXP_FP16
 | 
			
		||||
                    const float s = s0[i] - max;
 | 
			
		||||
                    const float val = expf(s);
 | 
			
		||||
#else
 | 
			
		||||
                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
 | 
			
		||||
                    memcpy(&scvt, &s, sizeof(scvt));
 | 
			
		||||
                    const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
 | 
			
		||||
#endif
 | 
			
		||||
                    sum += (ggml_float)val;
 | 
			
		||||
                    st[i] = val;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            assert(sum > 0.0);
 | 
			
		||||
            // sum = 1.0/sum;
 | 
			
		||||
        }
 | 
			
		||||
        // avoid log(0) by rescaling from [0..1] to [eps..1]
 | 
			
		||||
        float max = -INFINITY;
 | 
			
		||||
        ggml_vec_max_f32(nc, &max, s0);
 | 
			
		||||
        ggml_float sum = ggml_vec_soft_max_f32(nc, st, s0, max);
 | 
			
		||||
        assert(sum > 0.0);
 | 
			
		||||
        sum = (1.0 - eps) / sum;
 | 
			
		||||
 | 
			
		||||
        // avoid log(0) by rescaling from [0..1] to [eps..1]
 | 
			
		||||
        ggml_vec_scale_f32(nc, st, sum);
 | 
			
		||||
        ggml_vec_add1_f32(nc, st, st, eps);
 | 
			
		||||
        ggml_vec_log_f32(nc, st, st);
 | 
			
		||||
@@ -17409,32 +17520,11 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
        // soft_max
 | 
			
		||||
        ggml_float sum = 0.0;
 | 
			
		||||
        {
 | 
			
		||||
            float max = -INFINITY;
 | 
			
		||||
            ggml_vec_max_f32(nc, &max, s0);
 | 
			
		||||
 | 
			
		||||
            uint16_t scvt; UNUSED(scvt);
 | 
			
		||||
            for (int i = 0; i < nc; i++) {
 | 
			
		||||
                if (s0[i] == -INFINITY) {
 | 
			
		||||
                    ds0[i] = 0.0f;
 | 
			
		||||
                } else {
 | 
			
		||||
#ifndef GGML_CROSS_ENTROPY_EXP_FP16
 | 
			
		||||
                    const float s = s0[i] - max;
 | 
			
		||||
                    const float val = expf(s);
 | 
			
		||||
#else
 | 
			
		||||
                    ggml_fp16_t s = GGML_FP32_TO_FP16(s0[i] - max);
 | 
			
		||||
                    memcpy(&scvt, &s, sizeof(scvt));
 | 
			
		||||
                    const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
 | 
			
		||||
#endif
 | 
			
		||||
                    sum += (ggml_float)val;
 | 
			
		||||
                    ds0[i] = val;
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            assert(sum > 0.0);
 | 
			
		||||
            sum = (1.0 - eps)/sum;
 | 
			
		||||
        }
 | 
			
		||||
        float max = -INFINITY;
 | 
			
		||||
        ggml_vec_max_f32(nc, &max, s0);
 | 
			
		||||
        ggml_float sum = ggml_vec_soft_max_f32(nc, ds0, s0, max);
 | 
			
		||||
        assert(sum > 0.0);
 | 
			
		||||
        sum = (1.0 - eps) / sum;
 | 
			
		||||
 | 
			
		||||
        // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
 | 
			
		||||
        ggml_vec_scale_f32(nc, ds0, sum);
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user