mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	gguf : calculate n_mult
This commit is contained in:
		@@ -514,16 +514,30 @@ struct ggml_context * ctx_data = NULL;
 | 
				
			|||||||
    return gguf_get_arr_n(gguf_ctx, i);
 | 
					    return gguf_get_arr_n(gguf_ctx, i);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    int find_n_mult(const int n_ff, const int n_embd) {
 | 
				
			||||||
 | 
					        int n_mults[3] = {8192, 1, -1};
 | 
				
			||||||
 | 
					        for (int i = 0; i < 3; ++i) {
 | 
				
			||||||
 | 
					            int calc_ff = (((8 * n_embd) / 3 + n_mults[i] - 1) / n_mults[i]) * n_mults[i];
 | 
				
			||||||
 | 
					            if (calc_ff == n_ff) {
 | 
				
			||||||
 | 
					                return n_mults[i];
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        throw std::runtime_error(format("failed to find n_mult for n_ff = %d and n_emb = %d\n", n_ff, n_embd));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    void read_hparams() {
 | 
					    void read_hparams() {
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // TODO make keysconstants in header
 | 
					        // TODO make keysconstants in header
 | 
				
			||||||
        // TODO: read all hparams from file
 | 
					        // TODO: read all hparams from file
 | 
				
			||||||
        hparams.n_vocab = read_n_vocab();
 | 
					        hparams.n_vocab = read_n_vocab();
 | 
				
			||||||
 | 
					        hparams.n_ctx   = read_u32("llama.context_length");
 | 
				
			||||||
        hparams.n_embd  = read_u32("llama.embedding_length");
 | 
					        hparams.n_embd  = read_u32("llama.embedding_length");
 | 
				
			||||||
        //hparams.n_mult  = file.read_u32();
 | 
					        uint32_t n_ff    = read_u32("llama.feed_forward_length");
 | 
				
			||||||
 | 
					        hparams.n_mult  = find_n_mult(n_ff, hparams.n_embd);
 | 
				
			||||||
        hparams.n_head  = read_u32("llama.attention.head_count");
 | 
					        hparams.n_head  = read_u32("llama.attention.head_count");
 | 
				
			||||||
        hparams.n_layer = read_u32("llama.layer_count");
 | 
					        hparams.n_layer = read_u32("llama.layer_count");
 | 
				
			||||||
        //hparams.n_rot   = file.read_u32();
 | 
					        hparams.n_rot   = hparams.n_embd / hparams.n_head;
 | 
				
			||||||
        //hparams.ftype   = (enum llama_ftype) file.read_u32();
 | 
					        //hparams.ftype   = (enum llama_ftype) file.read_u32();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // LLaMAv2
 | 
					        // LLaMAv2
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user