Merge branch 'master' into compilade/refactor-kv-cache

This commit is contained in:
Francis Couture-Harpin
2024-06-01 11:51:41 -04:00
164 changed files with 3908 additions and 2232 deletions

173
ggml.c
View File

@@ -60,6 +60,9 @@
typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool;
typedef atomic_int atomic_flag;
#define ATOMIC_FLAG_INIT 0
static void atomic_store(atomic_int * ptr, LONG val) {
InterlockedExchange(ptr, val);
@@ -73,6 +76,12 @@ static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
return atomic_fetch_add(ptr, -(dec));
}
static atomic_bool atomic_flag_test_and_set(atomic_flag * ptr) {
return InterlockedExchange(ptr, 1);
}
static void atomic_flag_clear(atomic_flag * ptr) {
InterlockedExchange(ptr, 0);
}
typedef HANDLE pthread_t;
@@ -1567,11 +1576,11 @@ do { \
// F16 arithmetic is not supported by AVX, so we use F32 instead
#define GGML_F32Cx8 __m256
#define GGML_F32Cx8 __m256
#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
float tmp[8];
for (int i = 0; i < 8; i++) {
@@ -1580,13 +1589,14 @@ static inline __m256 __lasx_f32cx8_load(ggml_fp16_t *x) {
return (__m256)__lasx_xvld(tmp, 0);
}
static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
float arr[8];
__lasx_xvst(y, arr, 0);
for (int i = 0; i < 8; i++)
for (int i = 0; i < 8; i++) {
x[i] = GGML_FP32_TO_FP16(arr[i]);
}
}
#define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
@@ -1662,7 +1672,7 @@ static inline void __lasx_f32cx8_store(ggml_fp16_t *x, __m256 y) {
#define GGML_F16_STEP 32
#define GGML_F16_EPR 4
static inline __m128 __lsx_f16x4_load(ggml_fp16_t *x) {
static inline __m128 __lsx_f16x4_load(const ggml_fp16_t * x) {
float tmp[4];
tmp[0] = GGML_FP16_TO_FP32(x[0]);
@@ -1673,7 +1683,7 @@ static inline __m128 __lsx_f16x4_load(ggml_fp16_t *x) {
return __lsx_vld(tmp, 0);
}
static inline void __lsx_f16x4_store(ggml_fp16_t *x, __m128 y) {
static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
float arr[4];
__lsx_vst(y, arr, 0);
@@ -2306,32 +2316,27 @@ inline static __m512 ggml_v_expf(__m512 x) {
const __m512 r = _mm512_set1_ps(0x1.8p23f);
const __m512 z = _mm512_fmadd_ps(x, _mm512_set1_ps(0x1.715476p+0f), r);
const __m512 n = _mm512_sub_ps(z, r);
const __m512 b = _mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
const __m512i e = _mm512_slli_epi32(_mm512_castps_si512(z), 23);
const __m512 k = _mm512_castsi512_ps(_mm512_add_epi32(e, _mm512_castps_si512(_mm512_set1_ps(1))));
const __mmask16 c = _mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(126), _CMP_GT_OQ);
const __m512 u = _mm512_mul_ps(b, b);
const __m512 j = _mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
_mm512_set1_ps(0x1.573e2ep-5f)), u,
_mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
_mm512_set1_ps(0x1.fffdb6p-2f))),
u, _mm512_mul_ps(_mm512_set1_ps(0x1.ffffecp-1f), b));
if (_mm512_kortestz(c, c))
return _mm512_fmadd_ps(j, k, k);
const __m512i g = _mm512_and_si512(
_mm512_movm_epi32(_mm512_cmp_ps_mask(n, _mm512_setzero_ps(), _CMP_LE_OQ)),
_mm512_set1_epi32(0x82000000u));
const __m512 s1 =
_mm512_castsi512_ps(_mm512_add_epi32(g, _mm512_set1_epi32(0x7f000000u)));
const __m512 s2 = _mm512_castsi512_ps(_mm512_sub_epi32(e, g));
const __m512 b =
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.7f7d1cp-20f),
_mm512_fnmadd_ps(n, _mm512_set1_ps(0x1.62e4p-1f), x));
const __mmask16 d =
_mm512_cmp_ps_mask(_mm512_abs_ps(n), _mm512_set1_ps(192), _CMP_GT_OQ);
return _mm512_mask_blend_ps(
d, _mm512_mask_blend_ps(
c, _mm512_fmadd_ps(k, j, k),
_mm512_mul_ps(_mm512_fmadd_ps(s2, j, s2), s1)),
_mm512_mul_ps(s1, s1));
const __m512 u = _mm512_mul_ps(b, b);
const __m512 j = _mm512_fmadd_ps(
_mm512_fmadd_ps(_mm512_fmadd_ps(_mm512_set1_ps(0x1.0e4020p-7f), b,
_mm512_set1_ps(0x1.573e2ep-5f)),
u,
_mm512_fmadd_ps(_mm512_set1_ps(0x1.555e66p-3f), b,
_mm512_set1_ps(0x1.fffdb6p-2f))),
u,
_mm512_fmadd_ps(_mm512_set1_ps(0x1.ffffecp-1f), b, _mm512_set1_ps(1.0F)));
const __m512 res = _mm512_scalef_ps(j, n);
if (_mm512_kortestz(d, d))
return res;
const __m512 zero = _mm512_setzero_ps();
const __m512 alt = _mm512_mask_blend_ps(
_mm512_cmp_ps_mask(n, zero, _CMP_LE_OQ), _mm512_set1_ps(INFINITY), zero);
return _mm512_mask_blend_ps(d, res, alt);
}
// computes silu x/(1+exp(-x)) in single precision vector
@@ -2883,24 +2888,20 @@ struct ggml_state {
// global state
static struct ggml_state g_state;
static atomic_int g_state_barrier = 0;
static atomic_flag g_state_critical = ATOMIC_FLAG_INIT;
// barrier via spin lock
inline static void ggml_critical_section_start(void) {
int processing = atomic_fetch_add(&g_state_barrier, 1);
while (processing > 0) {
// wait for other threads to finish
atomic_fetch_sub(&g_state_barrier, 1);
sched_yield(); // TODO: reconsider this
processing = atomic_fetch_add(&g_state_barrier, 1);
while (atomic_flag_test_and_set(&g_state_critical)) {
// spin
sched_yield();
}
}
// TODO: make this somehow automatically executed
// some sort of "sentry" mechanism
inline static void ggml_critical_section_end(void) {
atomic_fetch_sub(&g_state_barrier, 1);
atomic_flag_clear(&g_state_critical);
}
#if defined(__gnu_linux__)
@@ -3216,7 +3217,11 @@ GGML_CALL bool ggml_is_contiguous(const struct ggml_tensor * tensor) {
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}
static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * tensor) {
GGML_CALL bool ggml_is_contiguous_0(const struct ggml_tensor * tensor) {
return ggml_is_contiguous(tensor);
}
GGML_CALL bool ggml_is_contiguous_1(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return
@@ -3225,6 +3230,14 @@ static inline bool ggml_is_contiguous_except_dim_1(const struct ggml_tensor * te
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}
GGML_CALL bool ggml_is_contiguous_2(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return
tensor->nb[0] == ggml_type_size(tensor->type) &&
tensor->nb[3] == tensor->nb[2]*tensor->ne[2];
}
GGML_CALL bool ggml_is_permuted(const struct ggml_tensor * tensor) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
@@ -6392,6 +6405,16 @@ struct ggml_tensor * ggml_rope_custom_inplace(
);
}
struct ggml_tensor * ggml_rope_xpos_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int n_dims,
float base,
bool down) {
return ggml_rope_impl(ctx, a, b, NULL, n_dims, 0, 0, 0, 10000.0f, 1.0f, 0.0f, 1.0f, 0.0f, 0.0f, base, down, true);
}
// ggml_rope_back
struct ggml_tensor * ggml_rope_back(
@@ -11008,7 +11031,7 @@ static void ggml_compute_forward_concat_f32(
static void ggml_compute_forward_concat(
const struct ggml_compute_params * params,
struct ggml_tensor* dst) {
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
@@ -11401,8 +11424,8 @@ static void ggml_compute_forward_gelu_f32(
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11464,8 +11487,8 @@ static void ggml_compute_forward_gelu_quick_f32(
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11527,8 +11550,8 @@ static void ggml_compute_forward_silu_f32(
const struct ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
@@ -11639,9 +11662,9 @@ static void ggml_compute_forward_silu_back_f32(
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * grad = dst->src[1];
GGML_ASSERT(ggml_is_contiguous_except_dim_1(grad));
GGML_ASSERT(ggml_is_contiguous_except_dim_1(src0));
GGML_ASSERT(ggml_is_contiguous_except_dim_1(dst));
GGML_ASSERT(ggml_is_contiguous_1(grad));
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_are_same_shape(src0, grad));
@@ -14339,7 +14362,7 @@ static void ggml_compute_forward_rope_f32(
int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.f/n_dims;
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
@@ -14388,7 +14411,7 @@ static void ggml_compute_forward_rope_f32(
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta) * sin_sign;
theta_base *= theta_scale;
theta_base *= theta_scale;
block_theta *= theta_scale;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14423,29 +14446,22 @@ static void ggml_compute_forward_rope_f32(
dst_data[1] = x0*sin_theta*zeta + x1*cos_theta*zeta;
}
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
// it seems we have to rope just the first n_dims elements and do nothing with the rest
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
const int64_t ib = 0;
const int64_t i0 = ic/2;
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
sin_theta *= sin_sign;
sin_theta *= sin_sign;
theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2;
const float * const src = (float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -14524,7 +14540,7 @@ static void ggml_compute_forward_rope_f16(
int ir = 0;
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.f/n_dims;
float corr_dims[2];
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims);
@@ -14573,7 +14589,7 @@ static void ggml_compute_forward_rope_f16(
const float cos_block_theta = cosf(block_theta);
const float sin_block_theta = sinf(block_theta) * sin_sign;
theta_base *= theta_scale;
theta_base *= theta_scale;
block_theta *= theta_scale;
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
@@ -14604,29 +14620,22 @@ static void ggml_compute_forward_rope_f16(
dst_data[1] = GGML_FP32_TO_FP16(x0*sin_theta + x1*cos_theta);
}
} else {
// TODO: this might be wrong for ne0 != n_dims - need double check
// it seems we have to rope just the first n_dims elements and do nothing with the rest
// ref: https://github.com/ml-explore/mlx/blob/dc2edc762c797e3b8de50b1dad4dc0a131691033/benchmarks/python/llama_jax_bench.py#L11-L26
theta_base *= freq_scale;
// ref: https://github.com/jquesnelle/yarn/blob/master/scaled_rope/LlamaYaRNScaledRotaryEmbedding.py
for (int64_t ic = 0; ic < ne0; ic += 2) {
if (ic < n_dims) {
const int64_t ib = 0;
const int64_t i0 = ic/2;
// simplified from `(ib * n_dims + ic) * inv_ndims`
float cur_rot = inv_ndims * ic - ib;
float freq_factor = freq_factors ? freq_factors[ic/2] : 1.0f;
const float freq_factor = freq_factors ? freq_factors[i0] : 1.0f;
float cos_theta, sin_theta;
rope_yarn(
theta_base/freq_factor, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor,
theta_base/freq_factor, freq_scale, corr_dims, ic, ext_factor, attn_factor,
&cos_theta, &sin_theta
);
sin_theta *= sin_sign;
sin_theta *= sin_sign;
theta_base *= theta_scale;
const int64_t i0 = ib*n_dims + ic/2;
const ggml_fp16_t * const src = (ggml_fp16_t *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
ggml_fp16_t * dst_data = (ggml_fp16_t *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
@@ -22797,6 +22806,14 @@ int ggml_cpu_has_sycl(void) {
#endif
}
int ggml_cpu_has_rpc(void) {
#if defined(GGML_USE_RPC)
return 1;
#else
return 0;
#endif
}
int ggml_cpu_has_gpublas(void) {
return ggml_cpu_has_cuda() || ggml_cpu_has_clblast() || ggml_cpu_has_vulkan() || ggml_cpu_has_kompute() ||
ggml_cpu_has_sycl();