mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-02 09:12:03 +00:00
tests: Fix OPT_STEP_SGD test-backend-ops
This commit is contained in:
@@ -1006,8 +1006,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"CROSS_ENTROPY_LOSS",
|
||||
"CROSS_ENTROPY_LOSS_BACK",
|
||||
"OPT_STEP_ADAMW",
|
||||
"GLU",
|
||||
"OPT_STEP_SGD",
|
||||
|
||||
"GLU",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
|
||||
@@ -1106,8 +1107,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"cross_entropy_loss(x,y)",
|
||||
"cross_entropy_loss_back(x,y)",
|
||||
"adamw(x)",
|
||||
"glu(x)",
|
||||
"sgd(x)",
|
||||
|
||||
"glu(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -5110,7 +5110,7 @@ static const ggml_type other_types[] = {
|
||||
};
|
||||
|
||||
// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
|
||||
static std::vector<std::unique_ptr<test_case>> make_test_cases_eval(bool test_sgd = true) {
|
||||
static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
std::vector<std::unique_ptr<test_case>> test_cases;
|
||||
std::default_random_engine rng(0);
|
||||
|
||||
@@ -5912,7 +5912,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval(bool test_sg
|
||||
test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1}));
|
||||
|
||||
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
|
||||
if (test_sgd)
|
||||
test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, { 10, 5, 4, 3 }));
|
||||
|
||||
#if 0
|
||||
@@ -6051,10 +6050,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
}
|
||||
};
|
||||
|
||||
char const* name = ggml_backend_name(backend);
|
||||
bool const vulkan = strstr(name, "ulkan");
|
||||
bool const sgd = !vulkan;
|
||||
|
||||
if (mode == MODE_TEST) {
|
||||
auto test_cases = make_test_cases_eval();
|
||||
filter_test_cases(test_cases, params_filter);
|
||||
@@ -6080,7 +6075,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
|
||||
}
|
||||
|
||||
if (mode == MODE_GRAD) {
|
||||
auto test_cases = make_test_cases_eval(sgd);
|
||||
auto test_cases = make_test_cases_eval();
|
||||
filter_test_cases(test_cases, params_filter);
|
||||
size_t n_ok = 0;
|
||||
for (auto & test : test_cases) {
|
||||
|
||||
Reference in New Issue
Block a user