mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-04 09:32:00 +00:00 
			
		
		
		
	add input embeddings handling
This commit is contained in:
		
							
								
								
									
										329
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										329
									
								
								llama.cpp
									
									
									
									
									
								
							@@ -3424,6 +3424,331 @@ static struct ggml_cgraph * llm_build_falcon(
 | 
			
		||||
    return gf;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static struct ggml_cgraph * llm_build_starcoder(
 | 
			
		||||
         llama_context & lctx,
 | 
			
		||||
     const llama_token * tokens,
 | 
			
		||||
           const float * embd,
 | 
			
		||||
                   int   n_tokens,
 | 
			
		||||
                   int   n_past) {
 | 
			
		||||
 | 
			
		||||
    GGML_ASSERT((!tokens && embd) || (tokens && !embd)); // NOLINT
 | 
			
		||||
 | 
			
		||||
    const int N = n_tokens;
 | 
			
		||||
 | 
			
		||||
    const auto & model   = lctx.model;
 | 
			
		||||
    const auto & hparams = model.hparams;
 | 
			
		||||
 | 
			
		||||
    const auto & kv_self = lctx.kv_self;
 | 
			
		||||
 | 
			
		||||
    GGML_ASSERT(!!kv_self.ctx);
 | 
			
		||||
 | 
			
		||||
    const int64_t n_embd      = hparams.n_embd;
 | 
			
		||||
    const int64_t n_layer     = hparams.n_layer;
 | 
			
		||||
    const int64_t n_ctx       = hparams.n_ctx;
 | 
			
		||||
    const int64_t n_head      = hparams.n_head;
 | 
			
		||||
    const int64_t n_head_kv   = hparams.n_head_kv;
 | 
			
		||||
    const int64_t n_embd_head = hparams.n_embd_head();
 | 
			
		||||
    const int64_t n_embd_gqa  = hparams.n_embd_gqa();
 | 
			
		||||
 | 
			
		||||
    GGML_ASSERT(n_embd_head == hparams.n_rot);
 | 
			
		||||
 | 
			
		||||
    const float freq_base  = hparams.rope_freq_base;
 | 
			
		||||
    const float freq_scale = hparams.rope_freq_scale;
 | 
			
		||||
    const float norm_eps   = hparams.f_norm_eps;
 | 
			
		||||
 | 
			
		||||
    const int n_gpu_layers = model.n_gpu_layers;
 | 
			
		||||
 | 
			
		||||
    auto & buf_compute = lctx.buf_compute;
 | 
			
		||||
 | 
			
		||||
    struct ggml_init_params params = {
 | 
			
		||||
        /*.mem_size   =*/ buf_compute.size,
 | 
			
		||||
        /*.mem_buffer =*/ buf_compute.data,
 | 
			
		||||
        /*.no_alloc   =*/ false,
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    params.no_alloc = true;
 | 
			
		||||
 | 
			
		||||
    struct ggml_context * ctx0 = ggml_init(params);
 | 
			
		||||
 | 
			
		||||
    ggml_cgraph * gf = ggml_new_graph(ctx0);
 | 
			
		||||
 | 
			
		||||
    struct ggml_tensor * cur;
 | 
			
		||||
    struct ggml_tensor * token;
 | 
			
		||||
    struct ggml_tensor * position;
 | 
			
		||||
    struct ggml_tensor * inpL;
 | 
			
		||||
 | 
			
		||||
    if (tokens) {
 | 
			
		||||
        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
 | 
			
		||||
 | 
			
		||||
        ggml_allocr_alloc(lctx.alloc, inp_tokens);
 | 
			
		||||
        if (!ggml_allocr_is_measure(lctx.alloc)) {
 | 
			
		||||
            memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
 | 
			
		||||
        }
 | 
			
		||||
        ggml_set_name(inp_tokens, "inp_tokens");
 | 
			
		||||
 | 
			
		||||
        token = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
 | 
			
		||||
    } else {
 | 
			
		||||
#ifdef GGML_USE_MPI
 | 
			
		||||
        GGML_ASSERT(false && "not implemented");
 | 
			
		||||
#endif
 | 
			
		||||
 | 
			
		||||
        token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);
 | 
			
		||||
 | 
			
		||||
        ggml_allocr_alloc(lctx.alloc, token);
 | 
			
		||||
        if (!ggml_allocr_is_measure(lctx.alloc)) {
 | 
			
		||||
            memcpy(token->data, embd, N * n_embd * ggml_element_size(inpL));
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    {
 | 
			
		||||
        // Compute position embeddings.
 | 
			
		||||
        struct ggml_tensor * inp_positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
 | 
			
		||||
        ggml_allocr_alloc(lctx.alloc, inp_positions);
 | 
			
		||||
        if (!ggml_allocr_is_measure(lctx.alloc)) {
 | 
			
		||||
            for (int i = 0; i < N; ++i) {
 | 
			
		||||
                ((int32_t *) inp_positions->data)[i] = n_past + i;
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
        ggml_set_name(inp_positions, "inp_positions");
 | 
			
		||||
 | 
			
		||||
        position = ggml_get_rows(ctx0, model.pos_embeddings, inp_positions);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    inpL = ggml_add(ctx0, token, position);
 | 
			
		||||
 | 
			
		||||
    const int i_gpu_start = n_layer - n_gpu_layers;
 | 
			
		||||
    (void) i_gpu_start;
 | 
			
		||||
 | 
			
		||||
    // offload functions set the tensor output backend to GPU
 | 
			
		||||
    // tensors are GPU-accelerated if any input or the output has been offloaded
 | 
			
		||||
    //
 | 
			
		||||
    // with the low VRAM option VRAM scratch is disabled in llama_load_model_internal
 | 
			
		||||
    // in that case ggml_cuda_assign_buffers has no effect
 | 
			
		||||
    offload_func_t offload_func_nr = llama_nop; // nr = non-repeating
 | 
			
		||||
    offload_func_t offload_func_kq = llama_nop;
 | 
			
		||||
    offload_func_t offload_func_v  = llama_nop;
 | 
			
		||||
 | 
			
		||||
#ifdef GGML_USE_CUBLAS
 | 
			
		||||
    if (n_gpu_layers > n_layer) {
 | 
			
		||||
        offload_func_nr = ggml_cuda_assign_buffers_no_alloc;
 | 
			
		||||
    }
 | 
			
		||||
    if (n_gpu_layers > n_layer + 1) {
 | 
			
		||||
        offload_func_v  = ggml_cuda_assign_buffers_no_alloc;
 | 
			
		||||
    }
 | 
			
		||||
    if (n_gpu_layers > n_layer + 2) {
 | 
			
		||||
        offload_func_kq = ggml_cuda_assign_buffers_no_alloc;
 | 
			
		||||
    }
 | 
			
		||||
#endif // GGML_USE_CUBLAS
 | 
			
		||||
 | 
			
		||||
    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
 | 
			
		||||
    ggml_allocr_alloc(lctx.alloc, KQ_scale);
 | 
			
		||||
    if (!ggml_allocr_is_measure(lctx.alloc)) {
 | 
			
		||||
        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
 | 
			
		||||
    }
 | 
			
		||||
    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");
 | 
			
		||||
 | 
			
		||||
    for (int il = 0; il < n_layer; ++il) {
 | 
			
		||||
        struct ggml_tensor * attn_norm;
 | 
			
		||||
 | 
			
		||||
        offload_func_t offload_func = llama_nop;
 | 
			
		||||
 | 
			
		||||
#ifdef GGML_USE_CUBLAS
 | 
			
		||||
        if (il >= i_gpu_start) {
 | 
			
		||||
            offload_func = ggml_cuda_assign_buffers_no_alloc;
 | 
			
		||||
        }
 | 
			
		||||
#endif // GGML_USE_CUBLAS
 | 
			
		||||
 | 
			
		||||
        // self-attention
 | 
			
		||||
        // TODO: refactor into common function (shared with LLaMA)
 | 
			
		||||
        {
 | 
			
		||||
            attn_norm = ggml_norm(ctx0, inpL, norm_eps);
 | 
			
		||||
            offload_func(attn_norm);
 | 
			
		||||
 | 
			
		||||
            attn_norm = ggml_add(ctx0,
 | 
			
		||||
                    ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm),
 | 
			
		||||
                    model.layers[il].attn_norm_b);
 | 
			
		||||
            offload_func(attn_norm->src[0]);
 | 
			
		||||
            offload_func(attn_norm);
 | 
			
		||||
 | 
			
		||||
            if (model.layers[il].attn_norm_2) { // Falcon-40B
 | 
			
		||||
                cur = ggml_norm(ctx0, inpL, norm_eps);
 | 
			
		||||
                offload_func(cur);
 | 
			
		||||
 | 
			
		||||
                cur = ggml_add(ctx0,
 | 
			
		||||
                        ggml_mul(ctx0, cur, model.layers[il].attn_norm_2),
 | 
			
		||||
                        model.layers[il].attn_norm_2_b);
 | 
			
		||||
                offload_func(cur->src[0]);
 | 
			
		||||
                offload_func(cur);
 | 
			
		||||
            } else { // Falcon 7B
 | 
			
		||||
                cur = attn_norm;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // compute QKV
 | 
			
		||||
 | 
			
		||||
            cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
 | 
			
		||||
            offload_func_kq(cur);
 | 
			
		||||
 | 
			
		||||
            // Note that the strides for Kcur, Vcur are set up so that the
 | 
			
		||||
            // resulting views are misaligned with the tensor's storage
 | 
			
		||||
            // (by applying the K/V offset we shift the tensor's original
 | 
			
		||||
            // view to stick out behind the viewed QKV tensor's allocated
 | 
			
		||||
            // memory, so to say). This is ok because no actual accesses
 | 
			
		||||
            // happen to that out-of-range memory, but it can require some
 | 
			
		||||
            // trickery when trying to accurately dump these views for
 | 
			
		||||
            // debugging.
 | 
			
		||||
 | 
			
		||||
            const size_t wsize = ggml_type_size(cur->type);
 | 
			
		||||
 | 
			
		||||
            // TODO: these 2 ggml_conts are technically not needed, but we add them until CUDA support for
 | 
			
		||||
            //       non-contiguous views is added for the rope operator
 | 
			
		||||
            struct ggml_tensor * tmpq = ggml_cont(ctx0, ggml_view_3d(
 | 
			
		||||
                ctx0, cur, n_embd_head, n_head, N,
 | 
			
		||||
                wsize * n_embd_head,
 | 
			
		||||
                wsize * n_embd_head * (n_head + 2 * n_head_kv),
 | 
			
		||||
                0));
 | 
			
		||||
            offload_func_kq(tmpq);
 | 
			
		||||
 | 
			
		||||
            struct ggml_tensor * tmpk = ggml_cont(ctx0, ggml_view_3d(
 | 
			
		||||
                ctx0, cur, n_embd_head, n_head_kv, N,
 | 
			
		||||
                wsize * n_embd_head,
 | 
			
		||||
                wsize * n_embd_head * (n_head + 2 * n_head_kv),
 | 
			
		||||
                wsize * n_embd_head *  n_head));
 | 
			
		||||
            offload_func_kq(tmpk);
 | 
			
		||||
 | 
			
		||||
            struct ggml_tensor * tmpv = ggml_view_3d(
 | 
			
		||||
                ctx0, cur, n_embd_head, n_head_kv, N,
 | 
			
		||||
                wsize * n_embd_head,
 | 
			
		||||
                wsize * n_embd_head * (n_head + 2 * n_head_kv),
 | 
			
		||||
                wsize * n_embd_head * (n_head +     n_head_kv));
 | 
			
		||||
            offload_func_v(tmpv);
 | 
			
		||||
 | 
			
		||||
            // using mode = 2 for neox mode
 | 
			
		||||
            struct ggml_tensor * Qcur = ggml_rope_custom_inplace(ctx0, tmpq, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
 | 
			
		||||
            offload_func_kq(Qcur);
 | 
			
		||||
            struct ggml_tensor * Kcur = ggml_rope_custom_inplace(ctx0, tmpk, n_past, n_embd_head, 2, 0, freq_base, freq_scale);
 | 
			
		||||
            offload_func_kq(Kcur);
 | 
			
		||||
 | 
			
		||||
            {
 | 
			
		||||
                struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, N));
 | 
			
		||||
                offload_func_v(Vcur);
 | 
			
		||||
                offload_func_v(Vcur->src[0]->src[0]);
 | 
			
		||||
                ggml_set_name(Vcur, "Vcur");
 | 
			
		||||
 | 
			
		||||
                struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + n_past));
 | 
			
		||||
                offload_func_kq(k);
 | 
			
		||||
                ggml_set_name(k, "k");
 | 
			
		||||
 | 
			
		||||
                struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd_gqa,
 | 
			
		||||
                        (   n_ctx)*ggml_element_size(kv_self.v),
 | 
			
		||||
                        (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + n_past*ggml_element_size(kv_self.v));
 | 
			
		||||
                offload_func_v(v);
 | 
			
		||||
 | 
			
		||||
                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k));
 | 
			
		||||
                ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v));
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3);
 | 
			
		||||
            offload_func_kq(Q);
 | 
			
		||||
            ggml_set_name(Q, "Q");
 | 
			
		||||
 | 
			
		||||
            struct ggml_tensor * K =
 | 
			
		||||
                ggml_view_3d(ctx0, kv_self.k,
 | 
			
		||||
                        n_embd_head, n_past + N, n_head_kv,
 | 
			
		||||
                        ggml_element_size(kv_self.k)*n_embd_gqa,
 | 
			
		||||
                        ggml_element_size(kv_self.k)*n_embd_head,
 | 
			
		||||
                        ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il);
 | 
			
		||||
            offload_func_kq(K);
 | 
			
		||||
            ggml_set_name(K, "K");
 | 
			
		||||
 | 
			
		||||
            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 | 
			
		||||
            offload_func_kq(KQ);
 | 
			
		||||
            ggml_set_name(KQ, "KQ");
 | 
			
		||||
 | 
			
		||||
            struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
 | 
			
		||||
            offload_func_kq(KQ_scaled);
 | 
			
		||||
            ggml_set_name(KQ_scaled, "KQ_scaled");
 | 
			
		||||
 | 
			
		||||
            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);
 | 
			
		||||
            offload_func_kq(KQ_masked);
 | 
			
		||||
            ggml_set_name(KQ_masked, "KQ_masked");
 | 
			
		||||
 | 
			
		||||
            struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked);
 | 
			
		||||
            offload_func_v(KQ_soft_max);
 | 
			
		||||
            ggml_set_name(KQ_soft_max, "KQ_soft_max");
 | 
			
		||||
 | 
			
		||||
            struct ggml_tensor * V =
 | 
			
		||||
                ggml_view_3d(ctx0, kv_self.v,
 | 
			
		||||
                        n_past + N, n_embd_head, n_head_kv,
 | 
			
		||||
                        ggml_element_size(kv_self.v)*n_ctx,
 | 
			
		||||
                        ggml_element_size(kv_self.v)*n_ctx*n_embd_head,
 | 
			
		||||
                        ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il);
 | 
			
		||||
            offload_func_v(V);
 | 
			
		||||
            ggml_set_name(V, "V");
 | 
			
		||||
 | 
			
		||||
            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
 | 
			
		||||
            offload_func_v(KQV);
 | 
			
		||||
            ggml_set_name(KQV, "KQV");
 | 
			
		||||
 | 
			
		||||
            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 | 
			
		||||
            offload_func_v(KQV_merged);
 | 
			
		||||
            ggml_set_name(KQV_merged, "KQV_merged");
 | 
			
		||||
 | 
			
		||||
            cur = ggml_cpy(ctx0, KQV_merged, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
 | 
			
		||||
            offload_func_v(cur);
 | 
			
		||||
            ggml_set_name(cur, "KQV_merged_contiguous");
 | 
			
		||||
 | 
			
		||||
            cur = ggml_mul_mat(ctx0, model.layers[il].wo, cur);
 | 
			
		||||
            offload_func(cur);
 | 
			
		||||
            ggml_set_name(cur, "result_wo");
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        struct ggml_tensor * attn_out = cur;
 | 
			
		||||
 | 
			
		||||
        // feed forward
 | 
			
		||||
        {
 | 
			
		||||
            struct ggml_tensor * inpFF = attn_norm;
 | 
			
		||||
 | 
			
		||||
            cur = ggml_mul_mat(ctx0, model.layers[il].w3, inpFF);
 | 
			
		||||
            offload_func(cur);
 | 
			
		||||
 | 
			
		||||
            cur = ggml_gelu(ctx0, cur);
 | 
			
		||||
            offload_func(cur);
 | 
			
		||||
            cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur);
 | 
			
		||||
            offload_func(cur);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        cur = ggml_add(ctx0, cur, attn_out);
 | 
			
		||||
        offload_func(cur);
 | 
			
		||||
        cur = ggml_add(ctx0, cur, inpL);
 | 
			
		||||
        offload_func(cur);
 | 
			
		||||
 | 
			
		||||
        // input for next layer
 | 
			
		||||
        inpL = cur;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    cur = inpL;
 | 
			
		||||
 | 
			
		||||
    // norm
 | 
			
		||||
    {
 | 
			
		||||
        cur = ggml_norm(ctx0, cur, norm_eps);
 | 
			
		||||
        offload_func_nr(cur);
 | 
			
		||||
 | 
			
		||||
        cur = ggml_add(ctx0,
 | 
			
		||||
                ggml_mul(ctx0, cur, model.output_norm),
 | 
			
		||||
                model.output_norm_b);
 | 
			
		||||
        ggml_set_name(cur, "result_norm");
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
			
		||||
    ggml_set_name(cur, "result_output");
 | 
			
		||||
 | 
			
		||||
    ggml_build_forward_expand(gf, cur);
 | 
			
		||||
 | 
			
		||||
    ggml_free(ctx0);
 | 
			
		||||
 | 
			
		||||
    return gf;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
static struct ggml_cgraph * llama_build_graph(
 | 
			
		||||
         llama_context & lctx,
 | 
			
		||||
     const llama_token * tokens,
 | 
			
		||||
@@ -3447,6 +3772,10 @@ static struct ggml_cgraph * llama_build_graph(
 | 
			
		||||
            {
 | 
			
		||||
                result = llm_build_falcon(lctx, tokens, embd, n_tokens, n_past);
 | 
			
		||||
            } break;
 | 
			
		||||
        case LLM_ARCH_STARCODER:
 | 
			
		||||
            {
 | 
			
		||||
                result = llm_build_starcoder(lctx, tokens, embd, n_tokens, n_past);
 | 
			
		||||
            } break;
 | 
			
		||||
        default:
 | 
			
		||||
            GGML_ASSERT(false);
 | 
			
		||||
    };
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user