mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	CUDA: add fused rms norm (#14800)
This commit is contained in:
		| @@ -2641,6 +2641,7 @@ struct test_rms_norm_mul_add : public test_case { | ||||
|     const ggml_type type; | ||||
|     const std::array<int64_t, 4> ne; | ||||
|     const float eps; | ||||
|     const bool broadcast; | ||||
|  | ||||
|     std::string op_desc(ggml_tensor * t) override { | ||||
|         GGML_UNUSED(t); | ||||
| @@ -2650,18 +2651,21 @@ struct test_rms_norm_mul_add : public test_case { | ||||
|     bool run_whole_graph() override { return true; } | ||||
|  | ||||
|     std::string vars() override { | ||||
|         return VARS_TO_STR3(type, ne, eps); | ||||
|         return VARS_TO_STR4(type, ne, eps, broadcast); | ||||
|     } | ||||
|  | ||||
|     test_rms_norm_mul_add(ggml_type type = GGML_TYPE_F32, | ||||
|             std::array<int64_t, 4> ne = {64, 5, 4, 3}, | ||||
|             float eps = 1e-6f) | ||||
|         : type(type), ne(ne), eps(eps) {} | ||||
|             float eps = 1e-6f, bool broadcast = false) | ||||
|         : type(type), ne(ne), eps(eps), broadcast(broadcast) {} | ||||
|  | ||||
|     ggml_tensor * build_graph(ggml_context * ctx) override { | ||||
|         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); | ||||
|         std::array<int64_t, 4> broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4}; | ||||
|  | ||||
|         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data()); | ||||
|         ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); | ||||
|         ggml_tensor * c = ggml_new_tensor(ctx, type, 4, ne.data()); | ||||
|  | ||||
|         ggml_set_param(a); | ||||
|         ggml_set_name(a, "a"); | ||||
|         ggml_set_param(b); | ||||
| @@ -5354,6 +5358,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() { | ||||
|     } | ||||
|     for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) { | ||||
|         test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); | ||||
|         test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true)); | ||||
|     } | ||||
|  | ||||
|     test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f)); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Aman Gupta
					Aman Gupta