mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	tests: Fix OPT_STEP_SGD test-backend-ops
This commit is contained in:
		@@ -1006,8 +1006,9 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
 | 
				
			|||||||
    "CROSS_ENTROPY_LOSS",
 | 
					    "CROSS_ENTROPY_LOSS",
 | 
				
			||||||
    "CROSS_ENTROPY_LOSS_BACK",
 | 
					    "CROSS_ENTROPY_LOSS_BACK",
 | 
				
			||||||
    "OPT_STEP_ADAMW",
 | 
					    "OPT_STEP_ADAMW",
 | 
				
			||||||
    "GLU",
 | 
					 | 
				
			||||||
    "OPT_STEP_SGD",
 | 
					    "OPT_STEP_SGD",
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    "GLU",
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
 | 
					static_assert(GGML_OP_COUNT == 87, "GGML_OP_COUNT != 87");
 | 
				
			||||||
@@ -1106,8 +1107,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 | 
				
			|||||||
    "cross_entropy_loss(x,y)",
 | 
					    "cross_entropy_loss(x,y)",
 | 
				
			||||||
    "cross_entropy_loss_back(x,y)",
 | 
					    "cross_entropy_loss_back(x,y)",
 | 
				
			||||||
    "adamw(x)",
 | 
					    "adamw(x)",
 | 
				
			||||||
    "glu(x)",
 | 
					 | 
				
			||||||
    "sgd(x)",
 | 
					    "sgd(x)",
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    "glu(x)",
 | 
				
			||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 | 
					static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -5110,7 +5110,7 @@ static const ggml_type other_types[] = {
 | 
				
			|||||||
};
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
 | 
					// Test cases for evaluation: should try to cover edge cases while using small input sizes to keep the runtime low
 | 
				
			||||||
static std::vector<std::unique_ptr<test_case>> make_test_cases_eval(bool test_sgd = true) {
 | 
					static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
 | 
				
			||||||
    std::vector<std::unique_ptr<test_case>> test_cases;
 | 
					    std::vector<std::unique_ptr<test_case>> test_cases;
 | 
				
			||||||
    std::default_random_engine rng(0);
 | 
					    std::default_random_engine rng(0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -5912,7 +5912,6 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval(bool test_sg
 | 
				
			|||||||
    test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1}));
 | 
					    test_cases.emplace_back(new test_cross_entropy_loss_back(GGML_TYPE_F32, {30000, 1, 1, 1}));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
 | 
					    test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
 | 
				
			||||||
    if (test_sgd)
 | 
					 | 
				
			||||||
    test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, { 10, 5, 4, 3 }));
 | 
					    test_cases.emplace_back(new test_opt_step_sgd(GGML_TYPE_F32, { 10, 5, 4, 3 }));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#if 0
 | 
					#if 0
 | 
				
			||||||
@@ -6051,10 +6050,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    char const* name = ggml_backend_name(backend);
 | 
					 | 
				
			||||||
    bool const vulkan = strstr(name, "ulkan");
 | 
					 | 
				
			||||||
    bool const sgd = !vulkan;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (mode == MODE_TEST) {
 | 
					    if (mode == MODE_TEST) {
 | 
				
			||||||
        auto test_cases = make_test_cases_eval();
 | 
					        auto test_cases = make_test_cases_eval();
 | 
				
			||||||
        filter_test_cases(test_cases, params_filter);
 | 
					        filter_test_cases(test_cases, params_filter);
 | 
				
			||||||
@@ -6080,7 +6075,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 | 
				
			|||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (mode == MODE_GRAD) {
 | 
					    if (mode == MODE_GRAD) {
 | 
				
			||||||
        auto test_cases = make_test_cases_eval(sgd);
 | 
					        auto test_cases = make_test_cases_eval();
 | 
				
			||||||
        filter_test_cases(test_cases, params_filter);
 | 
					        filter_test_cases(test_cases, params_filter);
 | 
				
			||||||
        size_t n_ok = 0;
 | 
					        size_t n_ok = 0;
 | 
				
			||||||
        for (auto & test : test_cases) {
 | 
					        for (auto & test : test_cases) {
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user