	Add --rope-scale parameter (#2544)
* common.cpp : Add --rope-scale parameter
* README.md : Add info about using linear rope scaling
common.cpp

@@ -194,6 +194,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.rope_freq_scale = std::stof(argv[i]);
+        } else if (arg == "--rope-scale") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.rope_freq_scale = 1.0f/std::stof(argv[i]);
         } else if (arg == "--memory-f32") {
             params.memory_f16 = false;
         } else if (arg == "--top-p") {
@@ -564,8 +570,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "  --cfg-negative-prompt PROMPT \n");
     fprintf(stdout, "                        negative prompt to use for guidance. (default: empty)\n");
     fprintf(stdout, "  --cfg-scale N         strength of guidance (default: %f, 1.0 = disable)\n", params.cfg_scale);
-    fprintf(stdout, "  --rope-freq-base N    RoPE base frequency (default: %.1f)\n", params.rope_freq_base);
-    fprintf(stdout, "  --rope-freq-scale N   RoPE frequency scaling factor (default: %g)\n", params.rope_freq_scale);
+    fprintf(stdout, "  --rope-scale N        RoPE context linear scaling factor, inverse of --rope-freq-scale (default: %g)\n", 1.0f/params.rope_freq_scale);
+    fprintf(stdout, "  --rope-freq-base N    RoPE base frequency, used by NTK-aware scaling (default: %.1f)\n", params.rope_freq_base);
+    fprintf(stdout, "  --rope-freq-scale N   RoPE frequency linear scaling factor, inverse of --rope-scale (default: %g)\n", params.rope_freq_scale);
     fprintf(stdout, "  --ignore-eos          ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
     fprintf(stdout, "  --no-penalize-nl      do not penalize newline token\n");
     fprintf(stdout, "  --memory-f32          use f32 instead of f16 for memory key+value (default: disabled)\n");
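The two flags set the same internal value from opposite directions: `--rope-scale 8` stores `rope_freq_scale = 1/8`, exactly what `--rope-freq-scale 0.125` would. Below is a minimal standalone sketch of why this stretches the context, assuming the common linear RoPE formulation theta = freq_scale * pos * freq_base^(-2i/n_dims); it is an illustration, not the actual ggml kernel, and `rope_theta` is a hypothetical helper, not a function from llama.cpp.

```cpp
#include <cmath>
#include <cstdio>

// Hypothetical helper for illustration only. Linear RoPE scaling multiplies
// the token position by freq_scale, so with --rope-scale 8 (freq_scale = 1/8)
// position 32768 rotates exactly as far as position 4096 in the unscaled model.
static float rope_theta(int pos, int pair, int n_dims, float freq_base, float freq_scale) {
    const float theta_scale = powf(freq_base, -2.0f * (float) pair / (float) n_dims);
    return freq_scale * (float) pos * theta_scale;
}

int main() {
    const float freq_scale = 1.0f / 8.0f; // what the parser stores for --rope-scale 8

    // Both lines print the same angle: the scaled 32k position lands where
    // the original 4k position did, keeping the rotations in-distribution.
    printf("scaled   pos 32768: %f\n", rope_theta(32768, 0, 128, 10000.0f, freq_scale));
    printf("unscaled pos  4096: %f\n", rope_theta( 4096, 0, 128, 10000.0f, 1.0f));
    return 0;
}
```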
README.md

@@ -140,6 +140,12 @@ The `--ctx-size` option allows you to set the size of the prompt context used by
 
 -   `-c N, --ctx-size N`: Set the size of the prompt context (default: 512). The LLaMA models were built with a context of 2048, which will yield the best results on longer input/inference. However, increasing the context size beyond 2048 may lead to unpredictable results.
 
+### Extended Context Size
+
+Some fine-tuned models extend the context length by scaling RoPE. For example, if the original pre-trained model has a context length (max sequence length) of 4096 (4k) and the fine-tuned model has 32k, that is a scaling factor of 8, and it should work by setting the above `--ctx-size` to 32768 (32k) and `--rope-scale` to 8.
+
+- `--rope-scale N`: Where N is the linear scaling factor used by the fine-tuned model.
+
 ### Keep Prompt
 
 The `--keep` option allows users to retain the original prompt when the model runs out of context, ensuring a connection to the initial instruction or conversation topic is maintained.
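As a usage sketch for the README example above (the model filename is hypothetical), loading such a 32k fine-tune with the `main` example would look like `./main -m models/model-32k.bin -c 32768 --rope-scale 8 -p "Your prompt"`; passing `--rope-freq-scale 0.125` instead of `--rope-scale 8` is equivalent.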