mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	llama : add llm_build_norm helper function
ggml-ci
This commit is contained in:
		
							
								
								
									
										407
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										407
									
								
								llama.cpp
									
									
									
									
									
								
							@@ -972,7 +972,7 @@ struct llama_mlock {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
typedef void (*offload_func_t)(struct ggml_tensor * tensor);
 | 
					typedef void (*offload_func_t)(struct ggml_tensor * tensor);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static void ggml_offload_nop(struct ggml_tensor * tensor) { // don't offload by default
 | 
					static void ggml_offload_nop(struct ggml_tensor * tensor) {
 | 
				
			||||||
    (void) tensor;
 | 
					    (void) tensor;
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -3093,6 +3093,42 @@ static bool llama_model_load(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
using llm_build_cb = std::function<void(struct ggml_tensor * cur, const char * name, int nl)>;
 | 
					using llm_build_cb = std::function<void(struct ggml_tensor * cur, const char * name, int nl)>;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					enum llm_norm_type {
 | 
				
			||||||
 | 
					    LLM_NORM,
 | 
				
			||||||
 | 
					    LLM_NORM_RMS,
 | 
				
			||||||
 | 
					};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					static struct ggml_tensor * llm_build_norm(
 | 
				
			||||||
 | 
					        struct ggml_context * ctx,
 | 
				
			||||||
 | 
					         struct ggml_tensor * cur,
 | 
				
			||||||
 | 
					         struct ggml_tensor * mw,
 | 
				
			||||||
 | 
					         struct ggml_tensor * mb,
 | 
				
			||||||
 | 
					              llm_norm_type   type,
 | 
				
			||||||
 | 
					                      float   eps,
 | 
				
			||||||
 | 
					         const llm_build_cb & cb,
 | 
				
			||||||
 | 
					                        int   il) {
 | 
				
			||||||
 | 
					    switch (type) {
 | 
				
			||||||
 | 
					        case LLM_NORM:     cur = ggml_norm    (ctx, cur, eps); break;
 | 
				
			||||||
 | 
					        case LLM_NORM_RMS: cur = ggml_rms_norm(ctx, cur, eps); break;
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					    if (mw || mb) {
 | 
				
			||||||
 | 
					        cb(cur, "norm", il);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (mw) {
 | 
				
			||||||
 | 
					        cur = ggml_mul(ctx, cur, mw);
 | 
				
			||||||
 | 
					        if (mb) {
 | 
				
			||||||
 | 
					            cb(cur, "norm_w", il);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if (mb) {
 | 
				
			||||||
 | 
					        cur = ggml_add(ctx, cur, mb);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return cur;
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static struct ggml_cgraph * llm_build_llama(
 | 
					static struct ggml_cgraph * llm_build_llama(
 | 
				
			||||||
        llama_context  & lctx,
 | 
					        llama_context  & lctx,
 | 
				
			||||||
    const llama_batch  & batch,
 | 
					    const llama_batch  & batch,
 | 
				
			||||||
@@ -3192,14 +3228,11 @@ static struct ggml_cgraph * llm_build_llama(
 | 
				
			|||||||
        struct ggml_tensor * inpSA = inpL;
 | 
					        struct ggml_tensor * inpSA = inpL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // norm
 | 
					        // norm
 | 
				
			||||||
        {
 | 
					        cur = llm_build_norm(ctx0, inpL,
 | 
				
			||||||
            cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
 | 
					                model.layers[il].attn_norm,
 | 
				
			||||||
            cb(cur, "rms_norm_0", il);
 | 
					                NULL,
 | 
				
			||||||
 | 
					                LLM_NORM_RMS, norm_rms_eps, cb, il);
 | 
				
			||||||
            // cur = cur*attn_norm(broadcasted)
 | 
					        cb(cur, "attn_norm", il);
 | 
				
			||||||
            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
 | 
					 | 
				
			||||||
            cb(cur, "attn_norm_0", il);
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // self-attention
 | 
					        // self-attention
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
@@ -3307,15 +3340,11 @@ static struct ggml_cgraph * llm_build_llama(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        // feed-forward network
 | 
					        // feed-forward network
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            // norm
 | 
					            cur = llm_build_norm(ctx0, inpFF,
 | 
				
			||||||
            {
 | 
					                    model.layers[il].ffn_norm,
 | 
				
			||||||
                cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
 | 
					                    NULL,
 | 
				
			||||||
                cb(cur, "rms_norm_1", il);
 | 
					                    LLM_NORM_RMS, norm_rms_eps, cb, il);
 | 
				
			||||||
 | 
					 | 
				
			||||||
                // cur = cur*ffn_norm(broadcasted)
 | 
					 | 
				
			||||||
                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
 | 
					 | 
				
			||||||
            cb(cur, "ffn_norm", il);
 | 
					            cb(cur, "ffn_norm", il);
 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
 | 
					            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
 | 
				
			||||||
                    model.layers[il].w3,
 | 
					                    model.layers[il].w3,
 | 
				
			||||||
@@ -3349,15 +3378,11 @@ static struct ggml_cgraph * llm_build_llama(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    cur = inpL;
 | 
					    cur = inpL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // norm
 | 
					    cur = llm_build_norm(ctx0, cur,
 | 
				
			||||||
    {
 | 
					            model.output_norm,
 | 
				
			||||||
        cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
 | 
					            NULL,
 | 
				
			||||||
        cb(cur, "rms_norm_2", -1);
 | 
					            LLM_NORM_RMS, norm_rms_eps, cb, -1);
 | 
				
			||||||
 | 
					 | 
				
			||||||
        // cur = cur*norm(broadcasted)
 | 
					 | 
				
			||||||
        cur = ggml_mul(ctx0, cur, model.output_norm);
 | 
					 | 
				
			||||||
    cb(cur, "result_norm", -1);
 | 
					    cb(cur, "result_norm", -1);
 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // lm_head
 | 
					    // lm_head
 | 
				
			||||||
    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
					    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
				
			||||||
@@ -3466,15 +3491,11 @@ static struct ggml_cgraph * llm_build_baichaun(
 | 
				
			|||||||
    for (int il = 0; il < n_layer; ++il) {
 | 
					    for (int il = 0; il < n_layer; ++il) {
 | 
				
			||||||
        struct ggml_tensor * inpSA = inpL;
 | 
					        struct ggml_tensor * inpSA = inpL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // norm
 | 
					        cur = llm_build_norm(ctx0, inpL,
 | 
				
			||||||
        {
 | 
					                model.layers[il].attn_norm,
 | 
				
			||||||
            cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
 | 
					                NULL,
 | 
				
			||||||
            cb(cur, "rms_norm_0", il);
 | 
					                LLM_NORM_RMS, norm_rms_eps, cb, il);
 | 
				
			||||||
 | 
					        cb(cur, "attn_norm", il);
 | 
				
			||||||
            // cur = cur*attn_norm(broadcasted)
 | 
					 | 
				
			||||||
            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
 | 
					 | 
				
			||||||
            cb(cur, "attn_norm_0", il);
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // self-attention
 | 
					        // self-attention
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
@@ -3600,15 +3621,11 @@ static struct ggml_cgraph * llm_build_baichaun(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        // feed-forward network
 | 
					        // feed-forward network
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            // norm
 | 
					            cur = llm_build_norm(ctx0, inpFF,
 | 
				
			||||||
            {
 | 
					                    model.layers[il].ffn_norm,
 | 
				
			||||||
                cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
 | 
					                    NULL,
 | 
				
			||||||
                cb(cur, "rms_norm_1", il);
 | 
					                    LLM_NORM_RMS, norm_rms_eps, cb, il);
 | 
				
			||||||
 | 
					 | 
				
			||||||
                // cur = cur*ffn_norm(broadcasted)
 | 
					 | 
				
			||||||
                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
 | 
					 | 
				
			||||||
            cb(cur, "ffn_norm", il);
 | 
					            cb(cur, "ffn_norm", il);
 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
 | 
					            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
 | 
				
			||||||
                    model.layers[il].w3,
 | 
					                    model.layers[il].w3,
 | 
				
			||||||
@@ -3763,27 +3780,21 @@ static struct ggml_cgraph * llm_build_falcon(
 | 
				
			|||||||
        struct ggml_tensor * attn_norm;
 | 
					        struct ggml_tensor * attn_norm;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // self-attention
 | 
					        // self-attention
 | 
				
			||||||
        // TODO: refactor into common function (shared with LLaMA)
 | 
					 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            attn_norm = ggml_norm(ctx0, inpL, norm_eps);
 | 
					            attn_norm = llm_build_norm(ctx0, inpL,
 | 
				
			||||||
            cb(attn_norm, "attn_norm_0", il);
 | 
					                    model.layers[il].attn_norm,
 | 
				
			||||||
 | 
					                    model.layers[il].attn_norm_b,
 | 
				
			||||||
 | 
					                    LLM_NORM, norm_eps, cb, il);
 | 
				
			||||||
 | 
					            cb(attn_norm, "attn_norm", il);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            attn_norm = ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm);
 | 
					            if (model.layers[il].attn_norm_2) {
 | 
				
			||||||
            cb(attn_norm, "attn_norm_0_w", il);
 | 
					                // Falcon-40B
 | 
				
			||||||
 | 
					                cur = llm_build_norm(ctx0, attn_norm,
 | 
				
			||||||
            attn_norm = ggml_add(ctx0, attn_norm, model.layers[il].attn_norm_b);
 | 
					                        model.layers[il].attn_norm_2,
 | 
				
			||||||
            cb(attn_norm, "attn_norm_0_wb", il);
 | 
					                        model.layers[il].attn_norm_2_b,
 | 
				
			||||||
 | 
					                        LLM_NORM, norm_eps, cb, il);
 | 
				
			||||||
            if (model.layers[il].attn_norm_2) { // Falcon-40B
 | 
					 | 
				
			||||||
                cur = ggml_norm(ctx0, inpL, norm_eps);
 | 
					 | 
				
			||||||
                cb(cur, "attn_norm_2", il);
 | 
					                cb(cur, "attn_norm_2", il);
 | 
				
			||||||
 | 
					            } else {
 | 
				
			||||||
                cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm_2);
 | 
					 | 
				
			||||||
                cb(cur, "attn_norm_2_w", il);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_2_b);
 | 
					 | 
				
			||||||
                cb(cur, "attn_norm_2_wb", il);
 | 
					 | 
				
			||||||
            } else { // Falcon 7B
 | 
					 | 
				
			||||||
                cur = attn_norm;
 | 
					                cur = attn_norm;
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -3925,16 +3936,11 @@ static struct ggml_cgraph * llm_build_falcon(
 | 
				
			|||||||
    cur = inpL;
 | 
					    cur = inpL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // norm
 | 
					    // norm
 | 
				
			||||||
    {
 | 
					    cur = llm_build_norm(ctx0, cur,
 | 
				
			||||||
        cur = ggml_norm(ctx0, cur, norm_eps);
 | 
					            model.output_norm,
 | 
				
			||||||
        cb(cur, "out_norm_0", -1);
 | 
					            model.output_norm_b,
 | 
				
			||||||
 | 
					            LLM_NORM, norm_eps, cb, -1);
 | 
				
			||||||
        cur = ggml_mul(ctx0, cur, model.output_norm);
 | 
					 | 
				
			||||||
        cb(cur, "out_norm_0_w", -1);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        cur = ggml_add(ctx0, cur, model.output_norm_b);
 | 
					 | 
				
			||||||
    cb(cur, "result_norm", -1);
 | 
					    cb(cur, "result_norm", -1);
 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
					    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
				
			||||||
    cb(cur, "result_output", -1);
 | 
					    cb(cur, "result_output", -1);
 | 
				
			||||||
@@ -4024,17 +4030,11 @@ static struct ggml_cgraph * llm_build_starcoder(
 | 
				
			|||||||
    cb(inpL, "inpL", -1);
 | 
					    cb(inpL, "inpL", -1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    for (int il = 0; il < n_layer; ++il) {
 | 
					    for (int il = 0; il < n_layer; ++il) {
 | 
				
			||||||
        {
 | 
					        cur = llm_build_norm(ctx0, inpL,
 | 
				
			||||||
            // Norm
 | 
					                model.layers[il].attn_norm,
 | 
				
			||||||
            cur = ggml_norm(ctx0, inpL, norm_eps);
 | 
					                model.layers[il].attn_norm_b,
 | 
				
			||||||
            cb(cur, "attn_norm_0", il);
 | 
					                LLM_NORM, norm_eps, cb, il);
 | 
				
			||||||
 | 
					        cb(cur, "attn_norm", il);
 | 
				
			||||||
            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
 | 
					 | 
				
			||||||
            cb(cur, "attn_norm_0_w", il);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b);
 | 
					 | 
				
			||||||
            cb(cur, "attn_norm_0_wb", il);
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            // Self Attention
 | 
					            // Self Attention
 | 
				
			||||||
@@ -4130,17 +4130,11 @@ static struct ggml_cgraph * llm_build_starcoder(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        // FF
 | 
					        // FF
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            // Norm
 | 
					            cur = llm_build_norm(ctx0, inpFF,
 | 
				
			||||||
            {
 | 
					                    model.layers[il].ffn_norm,
 | 
				
			||||||
                cur = ggml_norm(ctx0, inpFF, norm_eps);
 | 
					                    model.layers[il].ffn_norm_b,
 | 
				
			||||||
                cb(cur, "ffn_norm_0", il);
 | 
					                    LLM_NORM, norm_eps, cb, il);
 | 
				
			||||||
 | 
					            cb(cur, "ffn_norm", il);
 | 
				
			||||||
                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
 | 
					 | 
				
			||||||
                cb(cur, "ffn_norm_0_w", il);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                cur = ggml_add(ctx0, cur, model.layers[il].ffn_norm_b);
 | 
					 | 
				
			||||||
                cb(cur, "ffn_norm_0_wb", il);
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);
 | 
					            cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3);
 | 
				
			||||||
            cb(cur, "result_w3", il);
 | 
					            cb(cur, "result_w3", il);
 | 
				
			||||||
@@ -4161,17 +4155,11 @@ static struct ggml_cgraph * llm_build_starcoder(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Output Norm
 | 
					    cur = llm_build_norm(ctx0, inpL,
 | 
				
			||||||
    {
 | 
					            model.output_norm,
 | 
				
			||||||
        cur = ggml_norm(ctx0, inpL, norm_eps);
 | 
					            model.output_norm_b,
 | 
				
			||||||
        cb(cur, "out_norm_0", -1);
 | 
					            LLM_NORM, norm_eps, cb, -1);
 | 
				
			||||||
 | 
					 | 
				
			||||||
        cur = ggml_mul(ctx0, cur, model.output_norm);
 | 
					 | 
				
			||||||
        cb(cur, "out_norm_0_w", -1);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        cur = ggml_add(ctx0, cur, model.output_norm_b);
 | 
					 | 
				
			||||||
    cb(cur, "result_norm", -1);
 | 
					    cb(cur, "result_norm", -1);
 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
					    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
				
			||||||
    cb(cur, "result_output", -1);
 | 
					    cb(cur, "result_output", -1);
 | 
				
			||||||
@@ -4271,16 +4259,11 @@ static struct ggml_cgraph * llm_build_persimmon(
 | 
				
			|||||||
    for (int il = 0; il < n_layer; ++il) {
 | 
					    for (int il = 0; il < n_layer; ++il) {
 | 
				
			||||||
        struct ggml_tensor * residual = inpL;
 | 
					        struct ggml_tensor * residual = inpL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        {
 | 
					        cur = llm_build_norm(ctx0, inpL,
 | 
				
			||||||
            cur = ggml_norm(ctx0, inpL, norm_eps);
 | 
					                model.layers[il].attn_norm,
 | 
				
			||||||
            cb(cur, "attn_norm_0", il);
 | 
					                model.layers[il].attn_norm_b,
 | 
				
			||||||
 | 
					                LLM_NORM, norm_eps, cb, il);
 | 
				
			||||||
            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
 | 
					        cb(cur, "attn_norm", il);
 | 
				
			||||||
            cb(cur, "attn_norm_0_w", il);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b);
 | 
					 | 
				
			||||||
            cb(cur, "attn_norm_0_wb", il);
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // self attention
 | 
					        // self attention
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
@@ -4316,22 +4299,16 @@ static struct ggml_cgraph * llm_build_persimmon(
 | 
				
			|||||||
            cb(tmpk, "tmpk", il);
 | 
					            cb(tmpk, "tmpk", il);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            // Q/K Layernorm
 | 
					            // Q/K Layernorm
 | 
				
			||||||
            tmpq = ggml_norm(ctx0, tmpq, norm_eps);
 | 
					            tmpq = llm_build_norm(ctx0, tmpq,
 | 
				
			||||||
 | 
					                    model.layers[il].attn_q_norm,
 | 
				
			||||||
 | 
					                    model.layers[il].attn_q_norm_b,
 | 
				
			||||||
 | 
					                    LLM_NORM, norm_eps, cb, il);
 | 
				
			||||||
            cb(tmpq, "tmpq", il);
 | 
					            cb(tmpq, "tmpq", il);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            tmpq = ggml_mul(ctx0, tmpq, model.layers[il].attn_q_norm);
 | 
					            tmpk = llm_build_norm(ctx0, tmpk,
 | 
				
			||||||
            cb(tmpq, "tmpq", il);
 | 
					                    model.layers[il].attn_k_norm,
 | 
				
			||||||
 | 
					                    model.layers[il].attn_k_norm_b,
 | 
				
			||||||
            tmpq = ggml_add(ctx0, tmpq, model.layers[il].attn_q_norm_b);
 | 
					                    LLM_NORM, norm_eps, cb, il);
 | 
				
			||||||
            cb(tmpq, "tmpq", il);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            tmpk = ggml_norm(ctx0, tmpk, norm_eps);
 | 
					 | 
				
			||||||
            cb(tmpk, "tmpk", il);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            tmpk =  ggml_mul(ctx0, tmpk, model.layers[il].attn_k_norm);
 | 
					 | 
				
			||||||
            cb(tmpk, "tmpk", il);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            tmpk =  ggml_add(ctx0, tmpk, model.layers[il].attn_k_norm_b);
 | 
					 | 
				
			||||||
            cb(tmpk, "tmpk", il);
 | 
					            cb(tmpk, "tmpk", il);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            // RoPE the first n_rot of q/k, pass the other half, and concat.
 | 
					            // RoPE the first n_rot of q/k, pass the other half, and concat.
 | 
				
			||||||
@@ -4480,17 +4457,11 @@ static struct ggml_cgraph * llm_build_persimmon(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            // MLP
 | 
					            // MLP
 | 
				
			||||||
            {
 | 
					            cur = llm_build_norm(ctx0, inpFF,
 | 
				
			||||||
                // Norm
 | 
					                    model.layers[il].ffn_norm,
 | 
				
			||||||
                cur = ggml_norm(ctx0, inpFF, norm_eps);
 | 
					                    model.layers[il].ffn_norm_b,
 | 
				
			||||||
                cb(cur, "ffn_norm_0", il);
 | 
					                    LLM_NORM, norm_eps, cb, il);
 | 
				
			||||||
 | 
					            cb(cur, "ffn_norm", il);
 | 
				
			||||||
                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
 | 
					 | 
				
			||||||
                cb(cur, "ffn_norm_0_w", il);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                cur = ggml_add(ctx0, cur, model.layers[il].ffn_norm_b);
 | 
					 | 
				
			||||||
                cb(cur, "ffn_norm_0_wb", il);
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
 | 
					            cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
 | 
				
			||||||
            cb(cur, "result_w3", il);
 | 
					            cb(cur, "result_w3", il);
 | 
				
			||||||
@@ -4519,16 +4490,11 @@ static struct ggml_cgraph * llm_build_persimmon(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    cur = inpL;
 | 
					    cur = inpL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    {
 | 
					    cur = llm_build_norm(ctx0, cur,
 | 
				
			||||||
        cur = ggml_norm(ctx0, cur, norm_eps);
 | 
					            model.output_norm,
 | 
				
			||||||
        cb(cur, "out_norm_0", -1);
 | 
					            model.output_norm_b,
 | 
				
			||||||
 | 
					            LLM_NORM, norm_eps, cb, -1);
 | 
				
			||||||
        cur = ggml_mul(ctx0, cur, model.output_norm);
 | 
					 | 
				
			||||||
        cb(cur, "out_norm_0_w", -1);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        cur = ggml_add(ctx0, cur, model.output_norm_b);
 | 
					 | 
				
			||||||
    cb(cur, "result_norm", -1);
 | 
					    cb(cur, "result_norm", -1);
 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
					    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
				
			||||||
    cb(cur, "result_output", -1);
 | 
					    cb(cur, "result_output", -1);
 | 
				
			||||||
@@ -4609,15 +4575,11 @@ static struct ggml_cgraph * llm_build_refact(
 | 
				
			|||||||
    for (int il = 0; il < n_layer; ++il) {
 | 
					    for (int il = 0; il < n_layer; ++il) {
 | 
				
			||||||
        struct ggml_tensor * inpSA = inpL;
 | 
					        struct ggml_tensor * inpSA = inpL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // norm
 | 
					        cur = llm_build_norm(ctx0, inpL,
 | 
				
			||||||
        {
 | 
					                model.layers[il].attn_norm,
 | 
				
			||||||
            cur = ggml_rms_norm(ctx0, inpL, norm_rms_eps);
 | 
					                NULL,
 | 
				
			||||||
            cb(cur, "rms_norm_0", il);
 | 
					                LLM_NORM_RMS, norm_rms_eps, cb, il);
 | 
				
			||||||
 | 
					        cb(cur, "attn_norm", il);
 | 
				
			||||||
            // cur = cur*attn_norm(broadcasted)
 | 
					 | 
				
			||||||
            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
 | 
					 | 
				
			||||||
            cb(cur, "attn_norm_0", il);
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // self-attention
 | 
					        // self-attention
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
@@ -4719,15 +4681,11 @@ static struct ggml_cgraph * llm_build_refact(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        // feed-forward network
 | 
					        // feed-forward network
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            // norm
 | 
					            cur = llm_build_norm(ctx0, inpFF,
 | 
				
			||||||
            {
 | 
					                    model.layers[il].ffn_norm,
 | 
				
			||||||
                cur = ggml_rms_norm(ctx0, inpFF, norm_rms_eps);
 | 
					                    NULL,
 | 
				
			||||||
                cb(cur, "rms_norm_1", il);
 | 
					                    LLM_NORM_RMS, norm_rms_eps, cb, il);
 | 
				
			||||||
 | 
					 | 
				
			||||||
                // cur = cur*ffn_norm(broadcasted)
 | 
					 | 
				
			||||||
                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
 | 
					 | 
				
			||||||
            cb(cur, "ffn_norm", il);
 | 
					            cb(cur, "ffn_norm", il);
 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
 | 
					            struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
 | 
				
			||||||
                    model.layers[il].w3,
 | 
					                    model.layers[il].w3,
 | 
				
			||||||
@@ -4761,15 +4719,11 @@ static struct ggml_cgraph * llm_build_refact(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    cur = inpL;
 | 
					    cur = inpL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // norm
 | 
					    cur = llm_build_norm(ctx0, cur,
 | 
				
			||||||
    {
 | 
					            model.output_norm,
 | 
				
			||||||
        cur = ggml_rms_norm(ctx0, cur, norm_rms_eps);
 | 
					            NULL,
 | 
				
			||||||
        cb(cur, "rms_norm_2", -1);
 | 
					            LLM_NORM_RMS, norm_rms_eps, cb, -1);
 | 
				
			||||||
 | 
					 | 
				
			||||||
        // cur = cur*norm(broadcasted)
 | 
					 | 
				
			||||||
        cur = ggml_mul(ctx0, cur, model.output_norm);
 | 
					 | 
				
			||||||
    cb(cur, "result_norm", -1);
 | 
					    cb(cur, "result_norm", -1);
 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // lm_head
 | 
					    // lm_head
 | 
				
			||||||
    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
					    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
				
			||||||
@@ -4851,30 +4805,18 @@ static struct ggml_cgraph * llm_build_bloom(
 | 
				
			|||||||
    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
 | 
					    struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
 | 
				
			||||||
    cb(KQ_mask, "KQ_mask", -1);
 | 
					    cb(KQ_mask, "KQ_mask", -1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // norm
 | 
					    inpL = llm_build_norm(ctx0, embd,
 | 
				
			||||||
    {
 | 
					            model.tok_norm,
 | 
				
			||||||
        inpL = ggml_norm(ctx0, embd, norm_eps);
 | 
					            model.tok_norm_b,
 | 
				
			||||||
 | 
					            LLM_NORM, norm_eps, cb, -1);
 | 
				
			||||||
    cb(inpL, "inp_norm", -1);
 | 
					    cb(inpL, "inp_norm", -1);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        inpL = ggml_mul(ctx0, inpL, model.tok_norm);
 | 
					 | 
				
			||||||
        cb(inpL, "inp_norm_w", -1);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        inpL = ggml_add (ctx0, inpL, model.tok_norm_b);
 | 
					 | 
				
			||||||
        cb(inpL, "inp_norm_wb", -1);
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    for (int il = 0; il < n_layer; ++il) {
 | 
					    for (int il = 0; il < n_layer; ++il) {
 | 
				
			||||||
        {
 | 
					        cur = llm_build_norm(ctx0, inpL,
 | 
				
			||||||
            // Norm
 | 
					                model.layers[il].attn_norm,
 | 
				
			||||||
            cur = ggml_norm(ctx0, inpL, norm_eps);
 | 
					                model.layers[il].attn_norm_b,
 | 
				
			||||||
            cb(cur, "attn_norm_0", il);
 | 
					                LLM_NORM, norm_eps, cb, il);
 | 
				
			||||||
 | 
					        cb(cur, "attn_norm", il);
 | 
				
			||||||
            cur = ggml_mul(ctx0, cur, model.layers[il].attn_norm);
 | 
					 | 
				
			||||||
            cb(cur, "attn_norm_0_w", il);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            cur = ggml_add(ctx0, cur, model.layers[il].attn_norm_b);
 | 
					 | 
				
			||||||
            cb(cur, "attn_norm_0_wb", il);
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            // Self Attention
 | 
					            // Self Attention
 | 
				
			||||||
@@ -4984,17 +4926,11 @@ static struct ggml_cgraph * llm_build_bloom(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        // FF
 | 
					        // FF
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            // Norm
 | 
					            cur = llm_build_norm(ctx0, inpFF,
 | 
				
			||||||
            {
 | 
					                    model.layers[il].ffn_norm,
 | 
				
			||||||
                cur = ggml_norm(ctx0, inpFF, norm_eps);
 | 
					                    model.layers[il].ffn_norm_b,
 | 
				
			||||||
                cb(cur, "ffn_norm_0", il);
 | 
					                    LLM_NORM, norm_eps, cb, il);
 | 
				
			||||||
 | 
					            cb(cur, "ffn_norm", il);
 | 
				
			||||||
                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
 | 
					 | 
				
			||||||
                cb(cur, "ffn_norm_0_w", il);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                cur = ggml_add(ctx0, cur, model.layers[il].ffn_norm_b);
 | 
					 | 
				
			||||||
                cb(cur, "ffn_norm_0_wb", il);
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
 | 
					            cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
 | 
				
			||||||
            cb(cur, "result_w3", il);
 | 
					            cb(cur, "result_w3", il);
 | 
				
			||||||
@@ -5016,17 +4952,11 @@ static struct ggml_cgraph * llm_build_bloom(
 | 
				
			|||||||
        cb(inpL, "inpFF_+_result_w2", il);
 | 
					        cb(inpL, "inpFF_+_result_w2", il);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // Output Norm
 | 
					    cur = llm_build_norm(ctx0, inpL,
 | 
				
			||||||
    {
 | 
					            model.output_norm,
 | 
				
			||||||
        cur = ggml_norm(ctx0, inpL, norm_eps);
 | 
					            model.output_norm_b,
 | 
				
			||||||
        cb(cur, "out_norm_0", -1);
 | 
					            LLM_NORM, norm_eps, cb, -1);
 | 
				
			||||||
 | 
					 | 
				
			||||||
        cur = ggml_mul(ctx0, cur, model.output_norm);
 | 
					 | 
				
			||||||
        cb(cur, "out_norm_0_w", -1);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        cur = ggml_add(ctx0, cur, model.output_norm_b);
 | 
					 | 
				
			||||||
    cb(cur, "result_norm", -1);
 | 
					    cb(cur, "result_norm", -1);
 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
					    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
				
			||||||
    cb(cur, "result_output", -1);
 | 
					    cb(cur, "result_output", -1);
 | 
				
			||||||
@@ -5109,18 +5039,15 @@ static struct ggml_cgraph * llm_build_mpt(
 | 
				
			|||||||
    for (int il = 0; il < n_layer; ++il) {
 | 
					    for (int il = 0; il < n_layer; ++il) {
 | 
				
			||||||
        struct ggml_tensor * attn_norm;
 | 
					        struct ggml_tensor * attn_norm;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        attn_norm = llm_build_norm(ctx0, inpL,
 | 
				
			||||||
 | 
					                model.layers[il].attn_norm,
 | 
				
			||||||
 | 
					                NULL,
 | 
				
			||||||
 | 
					                LLM_NORM, norm_eps, cb, il);
 | 
				
			||||||
 | 
					        cb(attn_norm, "attn_norm", il);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // self-attention
 | 
					        // self-attention
 | 
				
			||||||
        // TODO: refactor into common function (shared with LLaMA)
 | 
					 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            attn_norm = ggml_norm(ctx0, inpL, norm_eps);
 | 
					 | 
				
			||||||
            cb(attn_norm, "attn_norm_0", il);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            attn_norm = ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm);
 | 
					 | 
				
			||||||
            cb(attn_norm, "attn_norm_0_w", il);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            if (1) {
 | 
					 | 
				
			||||||
            cur = attn_norm;
 | 
					            cur = attn_norm;
 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            // compute QKV
 | 
					            // compute QKV
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -5230,14 +5157,11 @@ static struct ggml_cgraph * llm_build_mpt(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        // feed forward
 | 
					        // feed forward
 | 
				
			||||||
        {
 | 
					        {
 | 
				
			||||||
            // Norm
 | 
					            cur = llm_build_norm(ctx0, attn_out,
 | 
				
			||||||
            {
 | 
					                    model.layers[il].ffn_norm,
 | 
				
			||||||
                cur = ggml_norm(ctx0, attn_out, norm_eps);
 | 
					                    NULL,
 | 
				
			||||||
                cb(cur, "ffn_norm_0", il);
 | 
					                    LLM_NORM, norm_eps, cb, il);
 | 
				
			||||||
 | 
					            cb(cur, "ffn_norm", il);
 | 
				
			||||||
                cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm);
 | 
					 | 
				
			||||||
                cb(cur, "ffn_norm_0_w", il);
 | 
					 | 
				
			||||||
            }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
 | 
					            cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur);
 | 
				
			||||||
            cb(cur, "result_w3", il);
 | 
					            cb(cur, "result_w3", il);
 | 
				
			||||||
@@ -5258,14 +5182,11 @@ static struct ggml_cgraph * llm_build_mpt(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    cur = inpL;
 | 
					    cur = inpL;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // norm
 | 
					    cur = llm_build_norm(ctx0, cur,
 | 
				
			||||||
    {
 | 
					            model.output_norm,
 | 
				
			||||||
        cur = ggml_norm(ctx0, cur, norm_eps);
 | 
					            NULL,
 | 
				
			||||||
        cb(cur, "out_norm_0", -1);
 | 
					            LLM_NORM, norm_eps, cb, -1);
 | 
				
			||||||
 | 
					 | 
				
			||||||
        cur = ggml_mul(ctx0, cur, model.output_norm);
 | 
					 | 
				
			||||||
    cb(cur, "result_norm", -1);
 | 
					    cb(cur, "result_norm", -1);
 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
					    cur = ggml_mul_mat(ctx0, model.output, cur);
 | 
				
			||||||
    cb(cur, "result_output", -1);
 | 
					    cb(cur, "result_output", -1);
 | 
				
			||||||
@@ -5378,15 +5299,12 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
 | 
				
			|||||||
    { "inp_norm_w",                 OFFLOAD_FUNC_NR  },
 | 
					    { "inp_norm_w",                 OFFLOAD_FUNC_NR  },
 | 
				
			||||||
    { "inp_norm_wb",                OFFLOAD_FUNC_NR  },
 | 
					    { "inp_norm_wb",                OFFLOAD_FUNC_NR  },
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    { "rms_norm_0",                 OFFLOAD_FUNC     },
 | 
					    { "norm",                       OFFLOAD_FUNC     },
 | 
				
			||||||
 | 
					    { "norm_w",                     OFFLOAD_FUNC     },
 | 
				
			||||||
    { "attn_norm_0",                OFFLOAD_FUNC     },
 | 
					    { "norm_wb",                    OFFLOAD_FUNC     },
 | 
				
			||||||
    { "attn_norm_0_w",              OFFLOAD_FUNC     },
 | 
					 | 
				
			||||||
    { "attn_norm_0_wb",             OFFLOAD_FUNC     },
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    { "attn_norm",                  OFFLOAD_FUNC     },
 | 
				
			||||||
    { "attn_norm_2",                OFFLOAD_FUNC     },
 | 
					    { "attn_norm_2",                OFFLOAD_FUNC     },
 | 
				
			||||||
    { "attn_norm_2_w",              OFFLOAD_FUNC     },
 | 
					 | 
				
			||||||
    { "attn_norm_2_wb",             OFFLOAD_FUNC     },
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    { "wqkv",                       OFFLOAD_FUNC_KQ  },
 | 
					    { "wqkv",                       OFFLOAD_FUNC_KQ  },
 | 
				
			||||||
    { "bqkv",                       OFFLOAD_FUNC_KQ  },
 | 
					    { "bqkv",                       OFFLOAD_FUNC_KQ  },
 | 
				
			||||||
@@ -5614,20 +5532,19 @@ static struct ggml_cgraph * llama_build_graph(
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        static const std::unordered_map<llm_offload_func_e, std::string, std::hash<int>> k_offload_func_name = {
 | 
					        static const std::unordered_map<llm_offload_func_e, std::string, std::hash<int>> k_offload_func_name = {
 | 
				
			||||||
            { OFFLOAD_FUNC_NOP, "CPU" },
 | 
					            { OFFLOAD_FUNC_NOP, "CPU" },
 | 
				
			||||||
 | 
					            { OFFLOAD_FUNC_OUT, "CPU" },
 | 
				
			||||||
#ifdef GGML_USE_CUBLAS
 | 
					#ifdef GGML_USE_CUBLAS
 | 
				
			||||||
            { OFFLOAD_FUNC,     "GPU (CUDA)" },
 | 
					            { OFFLOAD_FUNC,     "GPU (CUDA)" },
 | 
				
			||||||
            { OFFLOAD_FUNC_KQ,  "GPU (CUDA) KQ" },
 | 
					            { OFFLOAD_FUNC_KQ,  "GPU (CUDA) KQ" },
 | 
				
			||||||
            { OFFLOAD_FUNC_V,   "GPU (CUDA) V" },
 | 
					            { OFFLOAD_FUNC_V,   "GPU (CUDA) V" },
 | 
				
			||||||
            { OFFLOAD_FUNC_NR,  "GPU (CUDA) NR" },
 | 
					            { OFFLOAD_FUNC_NR,  "GPU (CUDA) NR" },
 | 
				
			||||||
            { OFFLOAD_FUNC_EMB, "GPU (CUDA) EMB" },
 | 
					            { OFFLOAD_FUNC_EMB, "GPU (CUDA) EMB" },
 | 
				
			||||||
            { OFFLOAD_FUNC_OUT, "GPU (CUDA) OUT" },
 | 
					 | 
				
			||||||
#else
 | 
					#else
 | 
				
			||||||
            { OFFLOAD_FUNC,     "CPU" },
 | 
					            { OFFLOAD_FUNC,     "CPU" },
 | 
				
			||||||
            { OFFLOAD_FUNC_KQ,  "CPU" },
 | 
					            { OFFLOAD_FUNC_KQ,  "CPU" },
 | 
				
			||||||
            { OFFLOAD_FUNC_V,   "CPU" },
 | 
					            { OFFLOAD_FUNC_V,   "CPU" },
 | 
				
			||||||
            { OFFLOAD_FUNC_NR,  "CPU" },
 | 
					            { OFFLOAD_FUNC_NR,  "CPU" },
 | 
				
			||||||
            { OFFLOAD_FUNC_EMB, "CPU" },
 | 
					            { OFFLOAD_FUNC_EMB, "CPU" },
 | 
				
			||||||
            { OFFLOAD_FUNC_OUT, "CPU" },
 | 
					 | 
				
			||||||
#endif // GGML_USE_CUBLAS
 | 
					#endif // GGML_USE_CUBLAS
 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user