mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-27 08:21:30 +00:00
cuda : fix multi-seq, quantized FA
ggml-ci
This commit is contained in:
@@ -745,10 +745,14 @@ void launch_fattn(
|
||||
size_t nb23 = V ? V->nb[3] : nb13;
|
||||
|
||||
if (need_f16_K && K->type != GGML_TYPE_F16) {
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(K));
|
||||
K_f16.alloc(ggml_nelements(K));
|
||||
const int64_t n_seq = K->ne[3];
|
||||
const int64_t n_eps = (K->nb[3]/ggml_type_size(K->type))*ggml_blck_size(K->type); // elements per sequence
|
||||
|
||||
K_f16.alloc(n_seq*n_eps);
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type);
|
||||
to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream);
|
||||
for (int s = 0; s < n_seq; ++s) {
|
||||
to_fp16(K_data + s*K->nb[3], K_f16.ptr + s*n_eps, n_eps, main_stream);
|
||||
}
|
||||
K_data = (char *) K_f16.ptr;
|
||||
|
||||
const size_t bs = ggml_blck_size(K->type);
|
||||
@@ -760,10 +764,14 @@ void launch_fattn(
|
||||
}
|
||||
|
||||
if (V && need_f16_V && V->type != GGML_TYPE_F16) {
|
||||
GGML_ASSERT(ggml_is_contiguously_allocated(V));
|
||||
V_f16.alloc(ggml_nelements(V));
|
||||
const int64_t n_seq = V->ne[3];
|
||||
const int64_t n_eps = (V->nb[3]/ggml_type_size(V->type))*ggml_blck_size(V->type);
|
||||
|
||||
V_f16.alloc(n_seq*n_eps);
|
||||
to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type);
|
||||
to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream);
|
||||
for (int s = 0; s < n_seq; ++s) {
|
||||
to_fp16(V_data + s*V->nb[3], V_f16.ptr + s*n_eps, n_eps, main_stream);
|
||||
}
|
||||
V_data = (char *) V_f16.ptr;
|
||||
|
||||
const size_t bs = ggml_blck_size(V->type);
|
||||
|
||||
@@ -5525,6 +5525,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
test_cases.emplace_back(new test_timestep_embedding());
|
||||
test_cases.emplace_back(new test_leaky_relu());
|
||||
|
||||
test_cases.emplace_back(new test_flash_attn_ext(128, 128, 4, {1, 3}, 512, 128, true, 0.0f, 0.0f, GGML_PREC_DEFAULT, GGML_TYPE_Q8_0));
|
||||
|
||||
for (int hsk : { 64, 80, 128, 192, 256, 576 }) {
|
||||
for (int hsv : { 64, 80, 128, 192, 256, 512 }) {
|
||||
if (hsk != 192 && hsk != 576 && hsk != hsv) continue;
|
||||
|
||||
Reference in New Issue
Block a user