	CUDA: determine FA parallel blocks at runtime
tests/test-backend-ops.cpp

@@ -259,6 +259,10 @@ static std::string var_to_str(ggml_type type) {
     return ggml_type_name(type);
 }
 
+static std::string var_to_str(ggml_prec prec) {
+    return prec == GGML_PREC_F32 ? "f32" : "def";
+}
+
 static std::string var_to_str(ggml_op_pool pool) {
     switch (pool) {
         case GGML_OP_POOL_AVG:  return "avg";
@@ -3146,11 +3150,12 @@ struct test_flash_attn_ext : public test_case {
     const float max_bias; // ALiBi
     const float logit_softcap; // Gemma 2
 
+    const ggml_prec prec;
     const ggml_type type_KV;
     std::array<int32_t, 4> permute;
 
     std::string vars() override {
-        return VARS_TO_STR10(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, permute);
+        return VARS_TO_STR11(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, permute);
     }
 
     double max_nmse_err() override {
@@ -3165,9 +3170,9 @@ struct test_flash_attn_ext : public test_case {
     }
 
     test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t nr = 1, int64_t kv = 96, int64_t nb = 8,
-                        bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_type type_KV = GGML_TYPE_F16,
-                        std::array<int32_t, 4> permute = {0, 1, 2, 3})
-        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), type_KV(type_KV), permute(permute) {}
+                        bool mask = true, float max_bias = 0.0f, float logit_softcap = 0.0f, ggml_prec prec = GGML_PREC_F32,
+                        ggml_type type_KV = GGML_TYPE_F16, std::array<int32_t, 4> permute = {0, 1, 2, 3})
+        : hs(hs), nh(nh), nr(nr), kv(kv), nb(nb), mask(mask), max_bias(max_bias), logit_softcap(logit_softcap), prec(prec), type_KV(type_KV), permute(permute) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         const int64_t hs_padded = GGML_PAD(hs, ggml_blck_size(type_KV));
@@ -3201,6 +3206,7 @@ struct test_flash_attn_ext : public test_case {
         }
 
         ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, m, 1.0f/sqrtf(hs), max_bias, logit_softcap);
+        ggml_flash_attn_ext_set_prec(out, prec);
         ggml_set_name(out, "out");
 
         return out;
@@ -4308,11 +4314,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                             for (int kv : { 512, 1024, }) {
                                 if (nr != 1 && kv != 512) continue;
                                 for (int nb : { 1, 3, 32, 35, }) {
-                                    for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
-                                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV));
-                                        // run fewer test cases permuted
-                                        if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
-                                            test_cases.emplace_back(new test_flash_attn_ext(hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, type_KV, {0, 2, 1, 3}));
+                                    for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) {
+                                        if (hs != 128 && prec == GGML_PREC_DEFAULT) continue;
+                                        for (ggml_type type_KV : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) {
+                                            test_cases.emplace_back(new test_flash_attn_ext(
+                                                hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV));
+                                            // run fewer test cases permuted
+                                            if (mask == true && max_bias == 0.0f && logit_softcap == 0 && kv == 512) {
+                                                test_cases.emplace_back(new test_flash_attn_ext(
+                                                    hs, nh, nr, kv, nb, mask, max_bias, logit_softcap, prec, type_KV, {0, 2, 1, 3}));
+                                            }
                                         }
                                     }
                                 }
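For reference, below is a minimal standalone sketch of the call pattern the new prec parameter exercises: build a flash-attention node with ggml_flash_attn_ext, then request an accumulator precision via ggml_flash_attn_ext_set_prec. This sketch is not part of the commit; the tensor shapes, the omitted mask, and the context size are illustrative assumptions that loosely mirror the test defaults above (hs = 128, nh = 32).

// Minimal sketch (not from this commit): requesting a specific accumulator
// precision on a flash-attention node. Shapes and the missing mask are
// illustrative assumptions only.
#include "ggml.h"

#include <cmath>

int main() {
    ggml_init_params params = {
        /*.mem_size   =*/ 32*1024*1024, // generous scratch for this toy graph
        /*.mem_buffer =*/ nullptr,
        /*.no_alloc   =*/ false,
    };
    ggml_context * ctx = ggml_init(params);

    const int64_t hs = 128; // head size
    const int64_t nh = 32;  // number of heads
    const int64_t kv = 512; // KV length
    const int64_t nb = 8;   // batch of query positions

    // Q is F32; K/V use the KV-cache type (F16 here).
    ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
    ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
    ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);

    // No mask, no ALiBi (max_bias = 0), no logit softcap in this sketch.
    ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, /*mask =*/ nullptr,
                                            1.0f/sqrtf(hs), /*max_bias =*/ 0.0f, /*logit_softcap =*/ 0.0f);

    // Force FP32 accumulation; GGML_PREC_DEFAULT leaves the choice to the backend.
    ggml_flash_attn_ext_set_prec(out, GGML_PREC_F32);

    ggml_free(ctx);
    return 0;
}

GGML_PREC_F32 asks the backend to accumulate in FP32, while GGML_PREC_DEFAULT leaves the choice to the backend, which may use lower precision where supported; the test loop above covers both, restricting GGML_PREC_DEFAULT to hs == 128.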
Johannes Gäßler