tests: Fix OPT_STEP_SGD test-backend-ops

This commit is contained in:
0cc4m
2025-07-20 07:22:28 +00:00
parent 9d0312425e
commit 2ec70c964b
2 changed files with 7 additions and 10 deletions

View File

@@ -1006,8 +1006,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK",
"OPT_STEP_ADAMW",
"GLU",
"OPT_STEP_SGD",
"GLU",
};
static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
@@ -1106,8 +1107,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss(x,y)",
"cross_entropy_loss_back(x,y)",
"adamw(x)",
"glu(x)",
"sgd(x)",
"glu(x)",
};
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

View File

@@ -5110,7 +5110,7 @@ static const ggml_type other_types[] = {
};
// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
static std::vector<std::unique_ptr<test_case>> make_test_cases_eval(bool test_sgd = true) {
static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
std::vector<std::unique_ptr<test_case>> test_cases;
std::default_random_engine rng(0);
@@ -5912,7 +5912,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval(bool test_sg
test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1}));
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
if (test_sgd)
test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, { 10, 5, 4, 3 }));
#if 0
@@ -6051,10 +6050,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
};
char const* name = ggml_backend_name(backend);
bool const vulkan = strstr(name, "ulkan");
bool const sgd = !vulkan;
if (mode == MODE_TEST) {
auto test_cases = make_test_cases_eval();
filter_test_cases(test_cases, params_filter);
@@ -6080,7 +6075,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
}
if (mode == MODE_GRAD) {
auto test_cases = make_test_cases_eval(sgd);
auto test_cases = make_test_cases_eval();
filter_test_cases(test_cases, params_filter);
size_t n_ok = 0;
for (auto & test : test_cases) {