	Merge branch 'master' into compilade/refactor-kv-cache
@@ -28,10 +28,11 @@ int main(int argc, char ** argv) {
     std::string result2;
 
     // init
-    llama_model * model;
-    llama_context * ctx;
+    llama_init_result llama_init = llama_init_from_gpt_params(params);
+
+    llama_model * model = llama_init.model;
+    llama_context * ctx = llama_init.context;
 
-    std::tie(model, ctx) = llama_init_from_gpt_params(params);
     if (model == nullptr || ctx == nullptr) {
         fprintf(stderr, "%s : failed to init\n", __func__);
         return 1;
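The hunk above is master's refactor of the common init helper reaching this branch: llama_init_from_gpt_params now returns a llama_init_result struct, and the example unpacks its .model and .context fields instead of filling two out-variables through std::tie. A plausible motivation (not stated in the diff) is that a struct return drops the tuple unpacking and leaves room for additional result fields without touching every call site.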
@@ -47,7 +48,7 @@ int main(int argc, char ** argv) {
     // save state (rng, logits, embedding and kv_cache) to file
     {
         std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
-        const size_t written = llama_state_get_data(ctx, state_mem.data());
+        const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());
 
         FILE *fp_write = fopen("dump_state.bin", "wb");
         fwrite(state_mem.data(), 1, written, fp_write);
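The save path is purely a signature change: llama_state_get_data now also takes the capacity of the destination buffer, so the copy is bounded rather than trusting the caller to have sized the buffer with llama_state_get_size. A minimal sketch of the same step as a standalone helper (the helper name and the fopen check are illustrative additions; assumes llama.h, <cstdio> and <vector>):

    // Hypothetical helper: dump the full context state to a file using the
    // new bounded llama_state_get_data(ctx, dst, size) signature.
    static bool save_state_to_file(llama_context * ctx, const char * path) {
        std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
        const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size());

        FILE * fp = fopen(path, "wb");
        if (fp == NULL) {
            return false;
        }
        const size_t n = fwrite(state_mem.data(), 1, written, fp);
        fclose(fp);
        return n == written;
    }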
@@ -99,13 +100,16 @@ int main(int argc, char ** argv) {
 
     // load state (rng, logits, embedding and kv_cache) from file
     {
-        std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
+        std::vector<uint8_t> state_mem;
 
         FILE * fp_read = fopen("dump_state.bin", "rb");
+        fseek(fp_read, 0, SEEK_END);
+        state_mem.resize(ftell(fp_read));
+        fseek(fp_read, 0, SEEK_SET);
         const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
         fclose(fp_read);
 
-        if (read != llama_state_set_data(ctx2, state_mem.data())) {
+        if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
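Two things change on the load path: the buffer is now sized from the file itself (fseek/ftell) rather than from llama_state_get_size, and llama_state_set_data receives the buffer size so it can reject a truncated or mismatched dump. The same logic as a helper (hypothetical name; same assumed headers as above):

    // Hypothetical helper: restore context state from a file, sizing the
    // buffer from the file and passing that size to llama_state_set_data.
    static bool load_state_from_file(llama_context * ctx, const char * path) {
        FILE * fp = fopen(path, "rb");
        if (fp == NULL) {
            return false;
        }
        fseek(fp, 0, SEEK_END);
        std::vector<uint8_t> state_mem(ftell(fp));
        fseek(fp, 0, SEEK_SET);

        const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp);
        fclose(fp);

        // llama_state_set_data reports how many bytes it consumed;
        // a mismatch with what was read from disk signals a bad dump.
        return read == llama_state_set_data(ctx, state_mem.data(), state_mem.size());
    }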
@@ -159,13 +163,16 @@ int main(int argc, char ** argv) {
 
     // load state (rng, logits, embedding and kv_cache) from file
     {
-        std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
+        std::vector<uint8_t> state_mem;
 
         FILE * fp_read = fopen("dump_state.bin", "rb");
+        fseek(fp_read, 0, SEEK_END);
+        state_mem.resize(ftell(fp_read));
+        fseek(fp_read, 0, SEEK_SET);
         const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
         fclose(fp_read);
 
-        if (read != llama_state_set_data(ctx3, state_mem.data())) {
+        if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx3);
             llama_free_model(model);
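This hunk is the same change applied verbatim to the third context (ctx3); a helper like load_state_from_file sketched above would cover both call sites.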
@@ -182,7 +189,7 @@ int main(int argc, char ** argv) {
     {
         // save kv of seq 0
        std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx3, 0));
-        const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), 0);
+        const size_t ncopy = llama_state_seq_get_data(ctx3, seq_store.data(), seq_store.size(), 0);
         if (ncopy != seq_store.size()) {
             fprintf(stderr, "\n%s : seq copy data length %zd does not match expected length %zd\n", __func__, ncopy, seq_store.size());
             llama_free(ctx3);
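The sequence-level API follows the same pattern: llama_state_seq_get_data now takes the destination capacity before the sequence id and returns the number of bytes actually copied, which the example checks against the size reported by llama_state_seq_get_size.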
@@ -196,7 +203,7 @@ int main(int argc, char ** argv) {
         fprintf(stderr, "%s : kv cache cleared\n", __func__);
 
         // restore kv into seq 1
-        const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), 1);
+        const size_t nset = llama_state_seq_set_data(ctx3, seq_store.data(), seq_store.size(), 1);
         if (nset != seq_store.size()) {
             fprintf(stderr, "\n%s : seq set data length %zd does not match expected length %zd\n", __func__, nset, seq_store.size());
             llama_free(ctx3);
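Together, the last two hunks perform a bounded copy of sequence 0's KV state into sequence 1. A condensed sketch of that round trip (hypothetical helper; llama_seq_id is the sequence id type from llama.h):

    // Hypothetical helper: duplicate the KV state of seq_src into seq_dst
    // using the bounded llama_state_seq_{get,set}_data signatures.
    static bool copy_seq_state(llama_context * ctx, llama_seq_id seq_src, llama_seq_id seq_dst) {
        std::vector<uint8_t> seq_store(llama_state_seq_get_size(ctx, seq_src));

        const size_t ncopy = llama_state_seq_get_data(ctx, seq_store.data(), seq_store.size(), seq_src);
        if (ncopy != seq_store.size()) {
            return false;
        }
        const size_t nset = llama_state_seq_set_data(ctx, seq_store.data(), seq_store.size(), seq_dst);
        return nset == seq_store.size();
    }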
 