llama : add gpt-oss (#15091)
* oai moe
* compat with new checkpoint
* add attn sink impl
* add rope scaling yarn
* logits match with latest transformers code
* wip chat template
* rm trailing space
* use ggml_scale_bias
* rm redundant is_swa_all
* convert interleaved gate_up
* graph : fix activation function to match reference (#7)
* vocab : handle o200k_harmony special tokens
* ggml : add attention sinks support (#1)
  * llama : add attn sinks
  * ggml : add attn sinks
  * cuda : add attn sinks
  * vulkan : add support for sinks in softmax; remove unnecessary return
* ggml : add fused swiglu_oai op (#11)
  * ggml : add fused swiglu_oai op
  * Update ggml/src/ggml-cpu/ops.cpp
  * update CUDA impl
  * cont : metal impl
  * add vulkan impl
  * test-backend-ops : more test cases, clean up
  * llama : remove unfused impl
  * remove extra lines
* repack mxfp4 upon conversion
* clean up a bit
* enable thinking
* add quick hack to render only some special tokens
* fix bf16 conversion
* remove vocab hack
* webui ok
* support chat parsing for gpt-oss
* fix webui
* direct mapping mxfp4, FINALLY
* force using mxfp4
* properly use lazy tensor
* ggml : add mxfp4
  * ggml : use e8m0 conversion instead of powf
  * change kvalues_mxfp4 table to match e2m1 (#6)
  * metal : remove quantization for now (not used)
  * cuda : fix disabled CUDA graphs due to ffn moe bias
  * vulkan : add support for mxfp4
  * cont : add cm2 dequant
* ggml : add ggml_add_id (#13)
  * ggml : add ggml_add_id
  * add cuda impl
  * llama : add weight support check for add_id
  * perf opt
  * add vulkan impl
  * rename cuda files
  * add metal impl
  * allow in-place ggml_add_id
* llama : keep biases on CPU with --cpu-moe
* llama : fix compile error (ggml-ci)
* cuda : add fallback for __nv_cvt_e8m0_to_bf16raw (ggml-ci)
* cleanup (ggml-ci)
* sycl : fix supports_op for MXFP4 (ggml-ci)
* fix Unknown reasoning format
* ggml-cpu : fix AVX build (ggml-ci)
* fix hip build (ggml-ci)
* cuda : add mxfp4 dequantization support for cuBLAS (ggml-ci)
* ggml-cpu : fix mxfp4 fallback definitions for some architectures (ggml-ci)
* cuda : fix version required for __nv_cvt_e8m0_to_bf16raw

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
Co-authored-by: slaren <slarengh@gmail.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Diego Devesa <slarengh@gmail.com>
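Note for readers of the diff below: the fused swiglu_oai op under test is the gpt-oss variant of SwiGLU, a clamped, gated SiLU. A minimal scalar sketch, assuming the gpt-oss reference formulation; the helper name swiglu_oai_ref and the choice of which operand is the gate are illustrative assumptions, not ggml API (alpha = 1.702f and limit = 7.0f match the test defaults):

    #include <algorithm>
    #include <cmath>

    // Scalar sketch of the swiglu_oai activation, assuming the gpt-oss
    // reference formulation; not ggml's kernel. `x` is the gate input,
    // `g` the linear input.
    static float swiglu_oai_ref(float x, float g, float alpha, float limit) {
        x = std::min(x, limit);                  // gate is clamped from above only
        g = std::max(-limit, std::min(g, limit)); // linear branch is clamped both ways
        const float sig = 1.0f / (1.0f + std::exp(-alpha * x));
        return x * sig * (g + 1.0f);             // SiLU-style gate times (linear + 1)
    }

With alpha = 1.702, x * sigmoid(alpha * x) approximates GELU, which is why the test's init comment below mentions checking for NaNs in GELU.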
@@ -296,6 +296,7 @@ static std::string var_to_str(ggml_scale_mode mode) {
 #define VARS_TO_STR10(a, b, c, d, e, f, g, h, i, j) VAR_TO_STR(a) + "," + VARS_TO_STR9(b, c, d, e, f, g, h, i, j)
 #define VARS_TO_STR11(a, b, c, d, e, f, g, h, i, j, k) VAR_TO_STR(a) + "," + VARS_TO_STR10(b, c, d, e, f, g, h, i, j, k)
 #define VARS_TO_STR12(a, b, c, d, e, f, g, h, i, j, k, l) VAR_TO_STR(a) + "," + VARS_TO_STR11(b, c, d, e, f, g, h, i, j, k, l)
+#define VARS_TO_STR13(a, b, c, d, e, f, g, h, i, j, k, l, m) VAR_TO_STR(a) + "," + VARS_TO_STR12(b, c, d, e, f, g, h, i, j, k, l, m)
 
 #ifdef GGML_USE_SYCL
 static bool inline _isinf(float f) {
@@ -1886,6 +1887,66 @@ struct test_glu_split : public test_case {
     }
 };
 
+struct test_swiglu_oai : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne_a;
+    int v; // view (1 : non-contiguous a)
+    float alpha;
+    float limit;
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, ne_a, v, alpha, limit);
+    }
+
+    test_swiglu_oai(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
+            int v = 0,
+            float alpha = 1.702f,
+            float limit = 7.0f)
+        : type(type), ne_a(ne_a), v(v), alpha(alpha), limit(limit) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a;
+        ggml_tensor * b;
+        if (v & 1) {
+            auto ne = ne_a; ne[0] *= 3;
+            a = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_param(a);
+            ggml_set_name(a, "a");
+
+            a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
+            ggml_set_name(a, "view_of_a");
+
+            b = ggml_new_tensor(ctx, type, 4, ne.data());
+            ggml_set_param(b);
+            ggml_set_name(b, "b");
+
+            b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
+            ggml_set_name(b, "view_of_b");
+        } else {
+            a = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_param(a);
+            ggml_set_name(a, "a");
+
+            b = ggml_new_tensor(ctx, type, 4, ne_a.data());
+            ggml_set_param(b);
+            ggml_set_name(b, "b");
+        }
+
+        ggml_tensor * out = ggml_swiglu_oai(ctx, a, b, alpha, limit);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            // test extended range of values to check for NaNs in GELU
+            init_tensor_uniform(t, -150.f, 150.f);
+        }
+    }
+};
+
 // GGML_OP_GET_ROWS
 struct test_get_rows : public test_case {
     const ggml_type type;
@@ -2483,6 +2544,68 @@ struct test_bin_bcast : public test_case {
     }
 };
 
+// GGML_OP_ADD_ID
+struct test_add_id : public test_case {
+    const ggml_type type_a;
+    const ggml_type type_b;
+    const int64_t n_embd;
+    const int64_t n_experts;
+    const int64_t n_experts_used;
+    const int64_t n_token;
+
+    std::string vars() override {
+        return VARS_TO_STR6(type_a, type_b, n_embd, n_experts, n_experts_used, n_token);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) + ggml_nbytes(t->src[0]) + ggml_nbytes(t->src[2]);
+    }
+
+    test_add_id(ggml_type type_a = GGML_TYPE_F32,
+            ggml_type type_b = GGML_TYPE_F32,
+            int64_t n_embd = 128,
+            int64_t n_experts = 16,
+            int64_t n_experts_used = 8,
+            int64_t n_token = 10)
+        : type_a(type_a), type_b(type_b), n_embd(n_embd),
+          n_experts(n_experts), n_experts_used(n_experts_used), n_token(n_token) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * a = ggml_new_tensor_3d(ctx, type_a, n_embd, n_experts_used, n_token);
+        ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, n_embd, n_experts);
+        ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_experts, n_token);
+        if (n_experts_used != n_experts) {
+            ids = ggml_view_2d(ctx, ids, n_experts_used, n_token, ids->nb[1], 0);
+            ggml_set_name(ids, "view_of_ids");
+        }
+
+        ggml_tensor * out = ggml_add_id(ctx, a, b, ids);
+        ggml_set_name(out, "out");
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I32) {
+                if (ggml_is_view_op(t->op)) { continue; }
+                std::random_device rd;
+                std::default_random_engine rng(rd());
+                // ids
+                for (int64_t r = 0; r < ggml_nrows(t); r++) {
+                    std::vector<int32_t> data(t->ne[0]);
+                    for (int i = 0; i < t->ne[0]; i++) {
+                        data[i] = i % n_experts;
+                    }
+                    std::shuffle(data.begin(), data.end(), rng);
+                    ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
+                }
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
 // GGML_OP_ADD1
 struct test_add1 : public test_case {
     const ggml_type type;
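Note: the new test_add_id exercises ggml_add_id, which adds expert-selected bias rows of b to a. A rough scalar sketch of the semantics implied by the shapes in build_graph above (plain contiguous arrays stand in for ggml tensors; this is not ggml's kernel):

    #include <cstdint>

    // a   : [n_embd, n_experts_used, n_token]  (dim 0 fastest, as in ggml)
    // b   : [n_embd, n_experts]
    // ids : [n_experts_used, n_token]
    // For every token t and used-expert slot e, add bias row b[ids[e][t]].
    static void add_id_ref(float * out, const float * a, const float * b,
                           const int32_t * ids,
                           int64_t n_embd, int64_t n_experts_used, int64_t n_token) {
        for (int64_t t = 0; t < n_token; ++t) {
            for (int64_t e = 0; e < n_experts_used; ++e) {
                const int32_t id = ids[t * n_experts_used + e]; // expert picked for this slot
                const float * arow = a   + (t * n_experts_used + e) * n_embd;
                const float * brow = b   + id * n_embd;
                float       * orow = out + (t * n_experts_used + e) * n_embd;
                for (int64_t i = 0; i < n_embd; ++i) {
                    orow[i] = arow[i] + brow[i];
                }
            }
        }
    }

This is what lets the MoE FFN apply per-expert biases after the router has gathered rows, and why the commit can keep those biases on the CPU with --cpu-moe.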
@@ -3122,11 +3245,11 @@ struct test_mul_mat_id : public test_case {
     }
 
     void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
             if (t->type == GGML_TYPE_I32) {
                 if (ggml_is_view_op(t->op)) { continue; }
-                std::random_device rd;
-                std::default_random_engine rng(rd());
                 // ids
                 for (int64_t r = 0; r < ggml_nrows(t); r++) {
                     std::vector<int32_t> data(t->ne[0]);
@@ -3447,13 +3570,14 @@ struct test_soft_max : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const bool mask;
+    const bool sinks;
     const ggml_type m_prec;
     const std::array<int64_t, 2> nr23; // broadcast only dims 2 and 3
     const float scale;
     const float max_bias;
 
     std::string vars() override {
-        return VARS_TO_STR7(type, ne, mask, m_prec, nr23, scale, max_bias);
+        return VARS_TO_STR8(type, ne, mask, sinks, m_prec, nr23, scale, max_bias);
     }
 
     // the 1024 test with bias occasionally fails:
@@ -3465,11 +3589,12 @@ struct test_soft_max : public test_case {
     test_soft_max(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 5, 4, 3},
             bool mask = false,
+            bool sinks = false,
             ggml_type m_prec = GGML_TYPE_F32,
             std::array<int64_t, 2> nr23 = {1, 1},
             float scale = 1.0f,
             float max_bias = 0.0f)
-        : type(type), ne(ne), mask(mask), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
+        : type(type), ne(ne), mask(mask), sinks(sinks), m_prec(m_prec), nr23(nr23), scale(scale), max_bias(max_bias) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
@@ -3482,7 +3607,14 @@ struct test_soft_max : public test_case {
             ggml_set_name(mask, "mask");
         }
 
+        ggml_tensor * sinks = nullptr;
+        if (this->sinks) {
+            sinks = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[2]*nr23[0]);
+            ggml_set_name(sinks, "sinks");
+        }
+
         ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
+        ggml_soft_max_add_sinks(out, sinks);
         ggml_set_name(out, "out");
 
         return out;
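Note: an attention sink here is one scalar per head (ne[2]*nr23[0] of them above) that joins the softmax normalization without producing an output element, so every attention weight shrinks. A scalar sketch of one row, as a reading aid; illustrative only, not ggml's kernel, and the mask / max_bias (ALiBi) path is omitted:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Softmax over one row with an attention sink: the sink enters the max
    // and the denominator, but emits no output element.
    static void soft_max_with_sink(std::vector<float> & x, float scale, float sink) {
        float vmax = sink;
        for (float v : x) {
            vmax = std::max(vmax, v * scale);
        }
        float sum = std::exp(sink - vmax); // sink contributes to the denominator only
        for (float & v : x) {
            v = std::exp(v * scale - vmax);
            sum += v;
        }
        for (float & v : x) {
            v /= sum;
        }
    }

The ggml_flash_attn_ext_add_sinks call later in this diff threads the same per-head term into the fused attention kernel.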
@@ -4433,6 +4565,7 @@ struct test_flash_attn_ext : public test_case {
     const int64_t nb; // batch size
 
     const bool mask; // use mask
+    const bool sinks; // use sinks
 
     const float max_bias; // ALiBi
     const float logit_softcap; // Gemma 2
@@ -4442,7 +4575,7 @@ struct test_flash_attn_ext : public test_case {
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR12(hsk, hsv, nh, nr23, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
+        return VARS_TO_STR13(hsk, hsv, nh, nr23, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -4457,9 +4590,9 @@ struct test_flash_attn_ext : public test_case {
     }
 
     test_flash_attn_ext(int64_t hsk = 128, int64_t hsv = 128, int64_t nh = 32, std::array<int64_t, 2> nr23 = {1, 1}, int64_t kv = 96, int64_t nb = 8,
-                        bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
+                        bool mask = true, bool sinks = false, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
                         ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
+        : hsk(hsk), hsv(hsv), nh(nh), nr23(nr23), kv(kv), nb(nb), mask(mask), sinks(sinks), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV));
@@ -4499,13 +4632,31 @@ struct test_flash_attn_ext : public test_case {
             ggml_set_name(m, "m");
         }
 
+        ggml_tensor * s = nullptr;
+        if (sinks) {
+            s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, q->ne[2]);
+            ggml_set_name(s, "s");
+        }
+
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hsk), max_bias, logit_softcap);
-        ggml_flash_attn_ext_set_prec(out, prec);
+        ggml_flash_attn_ext_add_sinks(out, s);
+        ggml_flash_attn_ext_set_prec (out, prec);
         ggml_set_name(out, "out");
 
         return out;
     }
 
+    void initialize_tensors(ggml_context * ctx) override {
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (strcmp(t->name, "s") == 0) {
+                // make the sink values more noticeable in order to trigger a test failure when the implementation is wrong
+                init_tensor_uniform(t, -10.0f, 10.0f);
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+
     bool grad_precise() override {
         return true;
     }
@@ -5038,6 +5189,7 @@ static const ggml_type all_types[] = {
     GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
     GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
     GGML_TYPE_Q8_0,
+    GGML_TYPE_MXFP4,
     GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
     GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
     GGML_TYPE_Q6_K,
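Note: GGML_TYPE_MXFP4 joins the quantized types under test. Per the OCP microscaling format, a block is 32 e2m1 (4-bit) values sharing one e8m0 scale, i.e. a pure power of two, 2^(e - 127); that is why the commit message swaps powf for an e8m0 conversion. A hedged decoding sketch; the struct layout, nibble packing order, and names are assumptions for illustration, not ggml's actual block definition:

    #include <cmath>
    #include <cstdint>

    struct mxfp4_block_sketch {
        uint8_t e;      // shared e8m0 exponent (assumed layout)
        uint8_t qs[16]; // 32 packed 4-bit e2m1 values (assumed packing)
    };

    // positive e2m1 magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6; sign in bit 3
    static const float kE2M1[8] = {0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f};

    static void dequant_mxfp4_sketch(const mxfp4_block_sketch * blk, float * out) {
        const float d = std::ldexp(1.0f, (int) blk->e - 127); // e8m0: power-of-two scale
        for (int i = 0; i < 16; ++i) {
            const uint8_t lo = blk->qs[i] & 0x0F;
            const uint8_t hi = blk->qs[i] >> 4;
            out[2*i + 0] = d * kE2M1[lo & 7] * ((lo & 8) ? -1.0f : 1.0f);
            out[2*i + 1] = d * kE2M1[hi & 7] * ((hi & 8) ? -1.0f : 1.0f);
        }
    }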
@@ -5053,6 +5205,7 @@ static const ggml_type base_types[] = {
    GGML_TYPE_Q4_0,
    GGML_TYPE_Q4_1, // for I8MM tests
    GGML_TYPE_Q4_K,
+    GGML_TYPE_MXFP4, // TODO: or "other"
    GGML_TYPE_IQ2_XXS
 };
 
@@ -5089,6 +5242,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
         for (int v : {0, 1}) {
             for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
+                if (op == GGML_GLU_OP_SWIGLU_OAI) {
+                    // SWIGLU_OAI is handled separately
+                    continue;
+                }
+
                 for (bool swapped : {false, true}) {
                     test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
                     test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
@@ -5100,6 +5258,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    for (int v : {0, 1}) {
+        for (float alpha : {.5f, 1.702f}) {
+            for (float limit : {2.0f, 7.0f}) {
+                test_cases.emplace_back(new test_swiglu_oai(GGML_TYPE_F32, { 128, 2, 2, 2 }, v, alpha, limit));
+            }
+        }
+    }
+
     test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
     for (ggml_type type : all_types) {
         for (int b : {1, 7}) {
@@ -5668,6 +5834,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         }
     }
 
+    // add_id
+    for (ggml_type type_a : {GGML_TYPE_F32}) {
+        for (ggml_type type_b : {GGML_TYPE_F32}) {
+            for (int n_mats : {4, 8}) {
+                for (int n_used : {1, 2, 4}) {
+                    for (int n_embd : {32, 129}) {
+                        for (int n_token : {1, 32, 129}) {
+                            test_cases.emplace_back(new test_add_id(type_a, type_b, n_embd, n_mats, n_used, n_token));
+                        }
+                    }
+                }
+            }
+        }
+    }
+
     for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
         test_cases.emplace_back(new test_sqr(type));
         test_cases.emplace_back(new test_sqrt(type));
@@ -5697,38 +5878,40 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     }
 #endif
     for (bool mask : {false, true}) {
-        for (float max_bias : {0.0f, 8.0f}) {
-            if (!mask && max_bias > 0.0f) continue;
-            for (float scale : {1.0f, 0.1f}) {
-                for (int64_t ne0 : {16, 1024}) {
-                    for (int64_t ne1 : {16, 1024}) {
-                        if (mask) {
-                            for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
-                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
-                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {1, 1}, scale, max_bias));
-
-                                if (ne0 <= 32 && ne1 <= 32) {
-                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 3}, mask, m_prec, {3, 1}, scale, max_bias));
-                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, {2, 3}, scale, max_bias));
-                                }
-                            }
-                        } else {
-                            /* The precision of mask here doesn't matter as boolean mask is false */
-                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
-                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, {1, 1}, scale, max_bias));
-                        }
-                    }
-                }
-            }
-        }
+        for (bool sinks : {false, true}) {
+            for (float max_bias : {0.0f, 8.0f}) {
+                if (!mask && max_bias > 0.0f) continue;
+                for (float scale : {1.0f, 0.1f}) {
+                    for (int64_t ne0 : {16, 1024}) {
+                        for (int64_t ne1 : {16, 1024}) {
+                            if (mask) {
+                                for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, sinks, m_prec, {1, 1}, scale, max_bias));
+                                    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, m_prec, {1, 1}, scale, max_bias));
+
+                                    if (ne0 <= 32 && ne1 <= 32) {
+                                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 3}, mask, sinks, m_prec, {3, 1}, scale, max_bias));
+                                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, m_prec, {2, 3}, scale, max_bias));
+                                    }
+                                }
+                            } else {
+                                /* The precision of mask here doesn't matter as boolean mask is false */
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, scale, max_bias));
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, sinks, GGML_TYPE_F32, {1, 1}, scale, max_bias));
+                            }
+                        }
+                    }
+                }
+            }
+        }
     }
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true, false, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, false, GGML_TYPE_F16, {1, 1}, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f));
 
     for (float max_bias : {0.0f, 8.0f}) {
         for (float scale : {1.0f, 0.1f}) {
@@ -5832,27 +6015,29 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
             if (hsk == 576 && hsv != 512) continue; // DeepSeek MLA
 
             for (bool mask : { true, false } ) {
-                for (float max_bias : { 0.0f, 8.0f }) {
-                    if (!mask && max_bias > 0.0f) continue;
-                    for (float logit_softcap : {0.0f, 10.0f}) {
-                        if (hsk != 128 && logit_softcap != 0.0f) continue;
-                        for (int nh : { 4, }) {
-                            for (int nr3 : { 1, 3, }) {
-                                if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
-                                for (int nr2 : { 1, 4, 16 }) {
-                                    if (nr2 == 16 && hsk != 128) continue;
-                                    for (int kv : { 512, 1024, }) {
-                                        if (nr2 != 1 && kv != 512) continue;
-                                        for (int nb : { 1, 3, 32, 35, }) {
-                                            for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
-                                                if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
-                                                for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                                    test_cases.emplace_back(new test_flash_attn_ext(
-                                                        hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
-                                                    // run fewer test cases permuted
-                                                    if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
-                                                        test_cases.emplace_back(new test_flash_attn_ext(
-                                                            hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                for (bool sinks : { true, false } ) {
+                    for (float max_bias : { 0.0f, 8.0f }) {
+                        if (!mask && max_bias > 0.0f) continue;
+                        for (float logit_softcap : {0.0f, 10.0f}) {
+                            if (hsk != 128 && logit_softcap != 0.0f) continue;
+                            for (int nh : { 4, }) {
+                                for (int nr3 : { 1, 3, }) {
+                                    if (hsk > 64 && nr3 > 1) continue; // skip broadcast for large head sizes
+                                    for (int nr2 : { 1, 4, 16 }) {
+                                        if (nr2 == 16 && hsk != 128) continue;
+                                        for (int kv : { 512, 1024, }) {
+                                            if (nr2 != 1 && kv != 512) continue;
+                                            for (int nb : { 1, 3, 32, 35, }) {
+                                                for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                                    if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                                    for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                                        test_cases.emplace_back(new test_flash_attn_ext(
+                                                            hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV));
+                                                        // run fewer test cases permuted
+                                                        if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                                            test_cases.emplace_back(new test_flash_attn_ext(
+                                                                hsk, hsv, nh, {nr2, nr3}, kv, nb, mask, sinks, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                                        }
                                                     }
                                                 }
                                             }
                                         }
@@ -5937,14 +6122,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
     test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
 
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
 
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
@@ -5976,7 +6161,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     for (int kv : { 4096, 8192, 16384, }) {
         for (int hs : { 64, 128, }) {
             for (int nr : { 1, 4, }) {
-                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
+                test_cases.emplace_back(new test_flash_attn_ext(hs, hs, 8, {nr, 1}, kv, 1, true, false, 0, 0, GGML_PREC_F32, GGML_TYPE_F16));
             }
         }
     }
@@ -5988,6 +6173,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
 
     test_cases.emplace_back(new test_mean(GGML_TYPE_F32, {256, 256, 3, 1}));
 
+    for (int n_token : {1, 512}) {
+        test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 128, 4, n_token));
+        test_cases.emplace_back(new test_add_id(GGML_TYPE_F32, GGML_TYPE_F32, 2880, 32, 4, n_token));
+    }
+
     return test_cases;
 }