mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-29 08:41:22 +00:00 
			
		
		
		
	tests : add non-cont K,V FA tests
ggml-ci
This commit is contained in:
		| @@ -4366,26 +4366,32 @@ struct test_flash_attn_ext : public test_case { | ||||
|         const int64_t hsk_padded = GGML_PAD(hsk, ggml_blck_size(type_KV)); | ||||
|         const int64_t hsv_padded = GGML_PAD(hsv, ggml_blck_size(type_KV)); | ||||
|  | ||||
|         auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3) -> ggml_tensor * { | ||||
|         auto const &create_permuted = [&](ggml_type type, int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, bool is_view) -> ggml_tensor * { | ||||
|             int64_t ne[4] = {ne0, ne1, ne2, ne3}; | ||||
|             int64_t ne_perm[4]; | ||||
|             for (int i = 0; i < 4; ++i) { | ||||
|                 ne_perm[permute[i]] = ne[i]; | ||||
|             } | ||||
|             ggml_tensor * t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]); | ||||
|             ggml_tensor * t; | ||||
|             if (is_view) { | ||||
|                 ggml_tensor * t0 = ggml_new_tensor_4d(ctx, type, ne_perm[0], 2*ne_perm[1], ne_perm[2], ne_perm[3]); | ||||
|                 t = ggml_view_4d(ctx, t0, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3], t0->nb[1], t0->nb[2], t0->nb[3], 0); | ||||
|             } else { | ||||
|                 t = ggml_new_tensor_4d(ctx, type, ne_perm[0], ne_perm[1], ne_perm[2], ne_perm[3]); | ||||
|             } | ||||
|             if (permute != std::array<int32_t, 4>{0, 1, 2, 3}) { | ||||
|                 t = ggml_permute(ctx, t, permute[0], permute[1], permute[2], permute[3]); | ||||
|             } | ||||
|             return t; | ||||
|         }; | ||||
|  | ||||
|         ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1]); | ||||
|         ggml_tensor * q = create_permuted(GGML_TYPE_F32, hsk_padded, nb, nh*nr23[0], nr23[1], false); | ||||
|         ggml_set_name(q, "q"); | ||||
|  | ||||
|         ggml_tensor * k = create_permuted(type_KV,       hsk_padded, kv, nh,         nr23[1]); | ||||
|         ggml_tensor * k = create_permuted(type_KV,       hsk_padded, kv, nh,         nr23[1], true); // the K tensor is usually a view of the K cache | ||||
|         ggml_set_name(k, "k"); | ||||
|  | ||||
|         ggml_tensor * v = create_permuted(type_KV,       hsv_padded, kv, nh,         nr23[1]); | ||||
|         ggml_tensor * v = create_permuted(type_KV,       hsv_padded, kv, nh,         nr23[1], true); // the V tensor is usually a view of the V cache | ||||
|         ggml_set_name(v, "v"); | ||||
|  | ||||
|         ggml_tensor * m = nullptr; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Georgi Gerganov
					Georgi Gerganov