	load starcoder weight

 llama.cpp | 69 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 69 insertions(+)

@@ -937,6 +937,7 @@ struct llama_hparams {
     uint32_t n_layer     = 32;
     uint32_t n_rot       = 64;
     uint32_t n_ff        = 11008;
+    uint32_t n_positions = -1;    // StarCoder
 
     float f_norm_eps     = 1e-5;
     float f_norm_rms_eps = 1e-5;
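
Note on the new hparam: StarCoder (GPT-BigCode) encodes positions with a learned absolute position-embedding table rather than RoPE, so the loader must know how many position rows the checkpoint provides. A minimal sketch of the constraint this field implies; the helper name starcoder_pos_ok is a hypothetical illustration, not code from this commit:

#include <cstdint>

// Hypothetical guard: every position id handed to the graph must index
// into the learned table, i.e. stay strictly below n_positions.
static bool starcoder_pos_ok(uint32_t n_positions, uint32_t n_past, uint32_t n_tokens) {
    // the last token of this batch lands at position n_past + n_tokens - 1
    return n_past + n_tokens <= n_positions;
}
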
@@ -1068,6 +1069,7 @@ struct llama_model {
     llama_vocab   vocab;
 
     struct ggml_tensor * tok_embeddings;
+    struct ggml_tensor * pos_embeddings;
 
     struct ggml_tensor * output_norm;
     struct ggml_tensor * output_norm_b;
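
Note on pos_embeddings: it mirrors tok_embeddings, since a GPT-2-style graph adds a learned position vector to each token embedding before the first block. The sketch below uses real ggml calls (ggml_get_rows, ggml_add), but the function name and input tensors are illustrative assumptions, not code from this commit:

#include "ggml.h"

// Illustrative: combine wte[token] + wpe[position] for one batch.
static struct ggml_tensor * build_inp_embd(
        struct ggml_context * ctx,
        struct ggml_tensor  * tok_embeddings,   // {n_embd, n_vocab}
        struct ggml_tensor  * pos_embeddings,   // {n_embd, n_positions}
        struct ggml_tensor  * inp_tokens,       // I32 token ids, [n_tokens]
        struct ggml_tensor  * inp_positions) {  // I32 position ids, [n_tokens]
    struct ggml_tensor * tok = ggml_get_rows(ctx, tok_embeddings, inp_tokens);
    struct ggml_tensor * pos = ggml_get_rows(ctx, pos_embeddings, inp_positions);
    return ggml_add(ctx, tok, pos);
}
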
@@ -2184,6 +2186,73 @@ static void llm_load_tensors(
                         layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
                         layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
 
                         if (backend == GGML_BACKEND_GPU) {
                             vram_weights +=
                                 ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) +
                                 ggml_nbytes(layer.wqkv)      + ggml_nbytes(layer.wo)          +
                                 ggml_nbytes(layer.w2)        + ggml_nbytes(layer.w3);
                         }
                     }
                 } break;
+            case LLM_ARCH_STARCODER:
+                {
+                    model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+                    model.pos_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_POS_EMBD, "weight"), {n_embd, hparams.n_positions}, GGML_BACKEND_CPU);
+
+                    // output
+                    {
+                        ggml_backend backend_norm;
+                        ggml_backend backend_output;
+
+                        if (n_gpu_layers > int(n_layer)) {
+                            // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
+                            // on Windows however this is detrimental unless everything is on the GPU
+#ifndef _WIN32
+                            backend_norm = low_vram ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+#else
+                            backend_norm = low_vram || n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+#endif // _WIN32
+
+                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        } else {
+                            backend_norm   = GGML_BACKEND_CPU;
+                            backend_output = GGML_BACKEND_CPU;
+                        }
+
+                        model.output_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd},          backend_norm);
+                        model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd},          backend_norm);
+                        model.output        = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, backend_output);
+
+                        if (backend_norm == GGML_BACKEND_GPU) {
+                            vram_weights += ggml_nbytes(model.output_norm);
+                            vram_weights += ggml_nbytes(model.output_norm_b);
+                        }
+                        if (backend_output == GGML_BACKEND_GPU_SPLIT) {
+                            vram_weights += ggml_nbytes(model.output);
+                        }
+                    }
+
+                    const uint32_t n_ff = hparams.n_ff;
+
+                    const int i_gpu_start = n_layer - n_gpu_layers;
+
+                    model.layers.resize(n_layer);
+
+                    for (uint32_t i = 0; i < n_layer; ++i) {
+                        const ggml_backend backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+                        const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, backend);
+                        layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM,   "bias", i),   {n_embd}, backend);
+
+                        layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd_gqa}, backend_split);
+                        layer.wo   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd},                backend_split);
+
+                        layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, backend_split);
+                        layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, backend_split);
+
+                        if (backend == GGML_BACKEND_GPU) {
+                            vram_weights +=
+                                ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) +
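
Note on the layer split: the tail of the hunk (the rest of the per-layer VRAM accounting) is truncated in this mirror, but the per-layer backend choice follows the same i_gpu_start rule as the other architectures, keeping the first n_layer - n_gpu_layers layers on the CPU and offloading the rest. A self-contained sketch of that arithmetic with made-up sizes (the constants are illustrative, not StarCoder's real shapes):

#include <cstdio>

int main() {
    const int n_layer      = 40;                      // hypothetical layer count
    const int n_gpu_layers = 10;                      // layers requested on the GPU
    const int i_gpu_start  = n_layer - n_gpu_layers;  // first offloaded layer: 30

    // layers [0, i_gpu_start) stay on the CPU; [i_gpu_start, n_layer) are offloaded
    for (int i = 0; i < n_layer; ++i) {
        printf("layer %2d -> %s\n", i, i < i_gpu_start ? "CPU" : "GPU (offload)");
    }
    return 0;
}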
 