mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	cont : gate the ggml_set_rows usage with env var
ggml-ci
This commit is contained in:
		@@ -130,6 +130,13 @@ llama_kv_cache_unified::llama_kv_cache_unified(
 | 
			
		||||
 | 
			
		||||
    const char * LLAMA_KV_CACHE_DEBUG = getenv("LLAMA_KV_CACHE_DEBUG");
 | 
			
		||||
    debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0;
 | 
			
		||||
 | 
			
		||||
    const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
 | 
			
		||||
    supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
 | 
			
		||||
 | 
			
		||||
    if (!supports_set_rows) {
 | 
			
		||||
        LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__);
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void llama_kv_cache_unified::clear(bool data) {
 | 
			
		||||
@@ -751,15 +758,21 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_
 | 
			
		||||
 | 
			
		||||
    auto * k = layers[ikv].k;
 | 
			
		||||
 | 
			
		||||
    const int64_t n_embd_k_gqa = k->ne[0];
 | 
			
		||||
    const int64_t n_tokens = k_cur->ne[2];
 | 
			
		||||
 | 
			
		||||
    if (kv_idxs) {
 | 
			
		||||
        return ggml_set_rows(ctx, k, ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens), kv_idxs);
 | 
			
		||||
    k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens);
 | 
			
		||||
 | 
			
		||||
    if (kv_idxs && supports_set_rows) {
 | 
			
		||||
        return ggml_set_rows(ctx, k, k_cur, kv_idxs);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // TODO: fallback to old ggml_cpy() method for backwards compatibility
 | 
			
		||||
    //       will be removed when ggml_set_rows() is adopted by all backends
 | 
			
		||||
 | 
			
		||||
    ggml_tensor * k_view = ggml_view_1d(ctx, k,
 | 
			
		||||
            n_tokens*hparams.n_embd_k_gqa(il),
 | 
			
		||||
            ggml_row_size(k->type, hparams.n_embd_k_gqa(il))*head_cur);
 | 
			
		||||
            n_tokens*n_embd_k_gqa,
 | 
			
		||||
            ggml_row_size(k->type, n_embd_k_gqa)*head_cur);
 | 
			
		||||
 | 
			
		||||
    return ggml_cpy(ctx, k_cur, k_view);
 | 
			
		||||
}
 | 
			
		||||
@@ -769,37 +782,43 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 | 
			
		||||
 | 
			
		||||
    auto * v = layers[ikv].v;
 | 
			
		||||
 | 
			
		||||
    const int64_t n_embd_v_gqa = v->ne[0];
 | 
			
		||||
    const int64_t n_tokens = v_cur->ne[2];
 | 
			
		||||
 | 
			
		||||
    v_cur = ggml_reshape_2d(ctx, v_cur, hparams.n_embd_v_gqa(il), n_tokens);
 | 
			
		||||
    v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens);
 | 
			
		||||
 | 
			
		||||
    if (kv_idxs && supports_set_rows) {
 | 
			
		||||
        if (!v_trans) {
 | 
			
		||||
            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        // note: the V cache is transposed when not using flash attention
 | 
			
		||||
        v_cur = ggml_transpose(ctx, v_cur);
 | 
			
		||||
 | 
			
		||||
        // the row becomes a single element and we repeat the KV indices d_head times
 | 
			
		||||
        ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
 | 
			
		||||
 | 
			
		||||
        v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
 | 
			
		||||
 | 
			
		||||
        // TODO: this repeat can be avoided if ggml_set_rows() supports broadcast
 | 
			
		||||
        kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
 | 
			
		||||
 | 
			
		||||
        return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // TODO: fallback to old ggml_cpy() method for backwards compatibility
 | 
			
		||||
    //       will be removed when ggml_set_rows() is adopted by all backends
 | 
			
		||||
 | 
			
		||||
    ggml_tensor * v_view = nullptr;
 | 
			
		||||
 | 
			
		||||
    if (!v_trans) {
 | 
			
		||||
        if (kv_idxs) {
 | 
			
		||||
            return ggml_set_rows(ctx, v, v_cur, kv_idxs);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        v_view = ggml_view_1d(ctx, v,
 | 
			
		||||
                n_tokens*hparams.n_embd_v_gqa(il),
 | 
			
		||||
                ggml_row_size(v->type, hparams.n_embd_v_gqa(il))*head_cur);
 | 
			
		||||
                n_tokens*n_embd_v_gqa,
 | 
			
		||||
                ggml_row_size(v->type, n_embd_v_gqa)*head_cur);
 | 
			
		||||
    } else {
 | 
			
		||||
        v_cur = ggml_transpose(ctx, v_cur);
 | 
			
		||||
 | 
			
		||||
        // note: the V cache is transposed when not using flash attention
 | 
			
		||||
        if (kv_idxs) {
 | 
			
		||||
            // the row becomes a single element and we repeat the KV indices d_head times
 | 
			
		||||
            // TODO: this seems not very optimal - can we do something better?
 | 
			
		||||
            v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]);
 | 
			
		||||
 | 
			
		||||
            v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]);
 | 
			
		||||
 | 
			
		||||
            kv_idxs = ggml_repeat_4d(ctx, kv_idxs, v_cur->ne[1], v_cur->ne[2], 1, 1);
 | 
			
		||||
 | 
			
		||||
            return ggml_set_rows(ctx, v_view, v_cur, kv_idxs);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        v_view = ggml_view_2d(ctx, v, n_tokens, hparams.n_embd_v_gqa(il),
 | 
			
		||||
        v_view = ggml_view_2d(ctx, v, n_tokens, n_embd_v_gqa,
 | 
			
		||||
                (v->ne[1])*ggml_element_size(v),
 | 
			
		||||
                (head_cur)*ggml_element_size(v));
 | 
			
		||||
    }
 | 
			
		||||
@@ -808,6 +827,10 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void llama_kv_cache_unified::set_input_kv_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, uint32_t head_cur) const {
 | 
			
		||||
    if (!supports_set_rows) {
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const uint32_t n_tokens = ubatch->n_tokens;
 | 
			
		||||
 | 
			
		||||
    GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer));
 | 
			
		||||
 
 | 
			
		||||
@@ -158,8 +158,13 @@ private:
 | 
			
		||||
    // SWA
 | 
			
		||||
    const uint32_t n_swa = 0;
 | 
			
		||||
 | 
			
		||||
    // env: LLAMA_KV_CACHE_DEBUG
 | 
			
		||||
    int debug = 0;
 | 
			
		||||
 | 
			
		||||
    // env: LLAMA_SET_ROWS (temporary)
 | 
			
		||||
    // ref: https://github.com/ggml-org/llama.cpp/pull/14285
 | 
			
		||||
    int supports_set_rows = false;
 | 
			
		||||
 | 
			
		||||
    const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE;
 | 
			
		||||
 | 
			
		||||
    std::vector<ggml_context_ptr>        ctxs;
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user