Merge branch 'master' into compilade/mamba2

This commit is contained in:
Francis Couture-Harpin
2025-07-02 02:39:04 -04:00
157 changed files with 7288 additions and 3177 deletions

View File

@@ -382,6 +382,8 @@ struct test_case {
return 0;
}
virtual bool run_whole_graph() { return false; }
ggml_cgraph * gf = nullptr;
ggml_cgraph * gb = nullptr;
@@ -574,7 +576,7 @@ struct test_case {
GGML_UNUSED(index);
};
const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud, run_whole_graph() ? out : nullptr);
if (!cmp_ok) {
printf("compare failed ");
@@ -1104,6 +1106,107 @@ struct test_unary : public test_case {
};
// GGML_OP_GLU
struct test_glu : public test_case {
const ggml_glu_op op;
const ggml_type type;
const std::array<int64_t, 4> ne_a;
int v; // view (1 : non-contiguous a)
bool swapped;
std::string vars() override {
return VARS_TO_STR4(type, ne_a, v, swapped);
}
test_glu(ggml_glu_op op,
ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
int v = 0,
bool swapped = false)
: op(op), type(type), ne_a(ne_a), v(v), swapped(swapped) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a;
if (v & 1) {
auto ne = ne_a; ne[0] *= 3;
a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view_of_a");
} else {
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_name(a, "a");
}
ggml_tensor * out = ggml_glu(ctx, a, op, swapped);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
// test extended range of values to check for NaNs in GELU
init_tensor_uniform(t, -150.f, 150.f);
}
}
};
struct test_glu_split : public test_case {
const ggml_glu_op op;
const ggml_type type;
const std::array<int64_t, 4> ne_a;
int v; // view (1 : non-contiguous a)
std::string vars() override {
return VARS_TO_STR3(type, ne_a, v) + ",split";
}
test_glu_split(ggml_glu_op op,
ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
int v = 0)
: op(op), type(type), ne_a(ne_a), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a;
ggml_tensor * b;
if (v & 1) {
auto ne = ne_a; ne[0] *= 3;
a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
ggml_set_name(a, "view_of_a");
b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(b, "b");
b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
ggml_set_name(a, "view_of_b");
} else {
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_name(a, "a");
b = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_set_name(b, "b");
}
ggml_tensor * out = ggml_glu_split(ctx, a, b, op);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
// test extended range of values to check for NaNs in GELU
init_tensor_uniform(t, -150.f, 150.f);
}
}
};
// GGML_OP_GET_ROWS
struct test_get_rows : public test_case {
const ggml_type type;
@@ -1213,6 +1316,76 @@ struct test_get_rows_back : public test_case {
}
};
// GGML_OP_SET_ROWS
struct test_set_rows : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const std::array<int, 2> nr23; // broadcast only dims 2 and 3
const int r; // rows to set
const bool v; // view (non-contiguous src1)
std::string vars() override {
return VARS_TO_STR5(type, ne, nr23, r, v);
}
test_set_rows(ggml_type type,
std::array<int64_t, 4> ne,
std::array<int, 2> nr23,
int r, bool v = false)
: type(type), ne(ne), nr23(nr23), r(r), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * dst = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
ggml_set_name(dst, "dst");
ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], r, ne[2]*nr23[0], ne[3]*nr23[1]);
ggml_set_name(src, "src");
ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, r, ne[2], ne[3]);
ggml_set_name(row_idxs, "row_idxs");
if (v) {
src = ggml_view_4d(ctx, src, ne[0], r/2, ne[2]*nr23[0], ne[3]*nr23[1], src->nb[1], src->nb[2], src->nb[3], 0);
row_idxs = ggml_view_3d(ctx, row_idxs, r/2, ne[2], ne[3], row_idxs->nb[1], row_idxs->nb[2], 0);
ggml_set_name(row_idxs, "view_of_rows");
}
ggml_tensor * out = ggml_set_rows(ctx, dst, src, row_idxs);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
std::random_device rd;
std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I64) {
if (ggml_is_view_op(t->op)) {
continue;
}
for (int i2 = 0; i2 < t->ne[2]; i2++) {
for (int i1 = 0; i1 < t->ne[1]; i1++) {
// generate a shuffled subset of row indices
std::vector<int64_t> data(ne[1]);
for (int i = 0; i < ne[1]; i++) {
data[i] = i;
}
std::shuffle(data.begin(), data.end(), rng);
data.resize(t->ne[0]);
const size_t offs = i1*t->nb[1] + i2*t->nb[2];
ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
}
}
} else {
init_tensor_uniform(t);
}
}
}
};
// GGML_OP_ARGMAX
struct test_argmax : public test_case {
const ggml_type type;
@@ -1826,6 +1999,63 @@ struct test_rms_norm_back : public test_case {
}
};
// GGML_OP_RMS_NORM + GGML_OP_MUL
struct test_rms_norm_mul : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const float eps;
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return "RMS_NORM_MUL";
}
bool run_whole_graph() override { return true; }
std::string vars() override {
return VARS_TO_STR3(type, ne, eps);
}
test_rms_norm_mul(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {64, 5, 4, 3},
float eps = 1e-6f)
: type(type), ne(ne), eps(eps) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_set_param(b);
ggml_set_name(b, "b");
// Use a and b early, so we don't end up with an OP_NONE between rms_norm and mul
a = ggml_add(ctx, a, b);
ggml_tensor * out = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);
ggml_set_name(out, "out");
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
init_tensor_uniform(t, -10.f, 10.f);
}
}
double max_nmse_err() override {
return 1e-6;
}
float grad_eps() override {
return 1.0f;
}
bool grad_precise() override {
return true;
}
};
// GGML_OP_SSM_CONV
struct test_ssm_conv : public test_case {
const ggml_type type;
@@ -3096,28 +3326,28 @@ struct test_upscale : public test_case {
}
};
// GGML_OP_UPSCALE (ext)
struct test_upscale_ext : public test_case {
// GGML_OP_UPSCALE (via ggml_interpolate)
struct test_interpolate : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
const std::array<int64_t, 4> ne_tgt;
const ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST;
const uint32_t mode = GGML_SCALE_MODE_NEAREST;
std::string vars() override {
return VARS_TO_STR4(type, ne, ne_tgt, mode);
}
test_upscale_ext(ggml_type type = GGML_TYPE_F32,
test_interpolate(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {2, 5, 7, 11},
std::array<int64_t, 4> ne_tgt = {5, 7, 11, 13},
ggml_scale_mode mode = GGML_SCALE_MODE_NEAREST)
uint32_t mode = GGML_SCALE_MODE_NEAREST)
: type(type), ne(ne), ne_tgt(ne_tgt), mode(mode) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_name(a, "a");
ggml_tensor * out = ggml_upscale_ext(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3], mode);
ggml_tensor * out = ggml_interpolate(ctx, a, ne_tgt[0], ne_tgt[1],ne_tgt[2], ne_tgt[3], mode);
ggml_set_name(out, "out");
return out;
@@ -3696,6 +3926,7 @@ struct test_llama : public test_llm {
static constexpr float attn_factor = 1.0f;
static constexpr float beta_fast = 32.0f;
static constexpr float beta_slow = 1.0f;
bool fused;
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
@@ -3711,7 +3942,9 @@ struct test_llama : public test_llm {
return 2e-3;
}
test_llama(int n_tokens = 1)
bool run_whole_graph() override { return fused; }
test_llama(int n_tokens = 1, bool fused = false)
: test_llm({
/*n_vocab =*/ 32000,
/*n_embd =*/ 3200,
@@ -3723,7 +3956,9 @@ struct test_llama : public test_llm {
/*f_norm_eps =*/ 0.f,
/*f_norm_rms_eps =*/ 1e-5f,
/*n_tokens =*/ n_tokens,
}) {
})
, fused(fused)
{
}
ggml_tensor * build_graph(ggml_context * ctx) override {
@@ -3990,6 +4225,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}
// glu ops
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
for (int v : {0, 1}) {
for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
for (bool swapped : {false, true}) {
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
}
test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v));
test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v));
}
}
}
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
for (ggml_type type : all_types) {
for (int b : {1, 7}) {
@@ -4014,6 +4264,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
}
test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, { 1, 8, 1, 3 }, { 1, 1 }, 2, false));
for (ggml_type type : all_types) {
for (int b : {1, 7}) {
for (bool v : {false, true}) {
test_cases.emplace_back(new test_set_rows(type, { 256, 5, b, 3 }, { 1, 1, }, 1, v));
test_cases.emplace_back(new test_set_rows(type, { 256, 11, 1, b }, { 2, 3, }, 7, v));
test_cases.emplace_back(new test_set_rows(type, { 3*ggml_blck_size(type), 3, b, 1 }, { 2, 3, }, 2, v));
if (ggml_blck_size(type) == 1) {
test_cases.emplace_back(new test_set_rows(type, { 31, 3, b, 1 }, { 2, 3, }, 2, v));
test_cases.emplace_back(new test_set_rows(type, { 33, 5, 1, b }, { 2, 3, }, 1, v));
}
}
}
}
for (ggml_type type_input : {GGML_TYPE_F32}) {
for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
for (int k0 : {1, 3}) {
@@ -4249,6 +4516,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
@@ -4283,39 +4553,45 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
#if 1
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
// test cases without permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {1, 1}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {3, 2}, {2, 2}));
std::vector<int> ks = { 256 };
if (ggml_blck_size(type_a) == 1) {
ks.push_back(4);
}
for (auto k : ks) {
// test cases without permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {1, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {1, 1}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 1}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {1, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {3, 2}, {2, 2}));
// test cases with permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
// test cases with permutation
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {2, 3}, {1, 1}, {0, 3, 2, 1}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 1, 3, 2}));
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {2, 3}, {1, 1}, {0, 3, 2, 1}));
}
// test cases with large ne00/ne10 to cover stream-k fixup
test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 1024, {3, 2}, {1, 1}));
@@ -4363,8 +4639,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (auto nr : {1,4}) {
for (uint32_t m = 0; m < 2; ++m) {
for (uint32_t k = 0; k < 2; ++k) {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true));
for (ggml_type type: {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 1056 + m, 1, 128 + k, {bs, 1}, {nr, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(type, GGML_TYPE_F32, 128 + m, 1, 1056 + k, {bs, 1}, {nr, 1}, {0, 1, 2, 3}, true));
}
}
}
}
@@ -4376,6 +4654,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
// this case is verified (pass) in Intel(R) Data Center GPU Max 1100 (sycl backend) and NV A30 (cuda backend)
// test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F16, 512, 262144, 9216, {1, 1}, {1, 1}));
// test large experts*tokens
for (bool b : {false, true}) {
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
}
for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
for (int n_mats : {4, 8}) {
@@ -4552,8 +4835,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) {
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode));
test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true));
test_cases.emplace_back(new test_upscale_ext(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode));
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode));
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {5, 7, 11, 13}, {2, 5, 7, 11}, mode));
}
test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS));
test_cases.emplace_back(new test_sum());
test_cases.emplace_back(new test_sum_rows());
@@ -4613,8 +4898,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
// these tests are disabled to save execution time, but they can be handy for debugging
#if 0
// these tests are disabled to save execution time, sbut they can be handy for debugging
test_cases.emplace_back(new test_llama(2, true));
test_cases.emplace_back(new test_llama(1));
test_cases.emplace_back(new test_llama(2));
test_cases.emplace_back(new test_falcon(1));