llama.cpp (https://github.com/ggml-org/llama.cpp.git), commit 5cdb27e091

* examples/finetune -opt SGD (stochastic gradient descent) memory opt

Add unit-tested GGML_OPT_OPTIMIZER_SGD to ggml; it avoids allocating the
AdamW m and v moment tensors.

Support the finetune.cpp arg -opt SGD (or sgd); the default stays adamw,
as before.

llama 3.2-1b-F32 result: 11 GB of GPU RAM observed with SGD (41 sec/epoch)
vs. 19 GB (55 sec/epoch) with adamw, finetuning on 100 lines of Wikipedia.

(With the same amount of GPU memory, adamw fits at most a 512 batch/context
before OOM, reaching:

train: [███████▉] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00
val:   [███████▉] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00

SGD comes out ahead here: it converges more slowly, but its maximum before
OOM is 1728 batch/context; note especially the better validation
performance:

train: [███████▉] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00
val:   [███████▉] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00
)

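The memory saving follows from the per-element update rules: AdamW keeps two
extra state tensors (first and second moments m, v) of the same shape as the
weights, while plain SGD with decoupled weight decay touches only the weight
and its gradient. A minimal sketch of the two updates (illustrative only, not
the actual ggml kernels; function and parameter names are assumptions):

    #include <cmath>

    // SGD with decoupled weight decay: no per-parameter optimizer state.
    inline void sgd_step(float & w, float g, float alpha, float wd) {
        w = w * (1.0f - alpha * wd) - alpha * g;
    }

    // AdamW: needs persistent m and v values for every weight.
    inline void adamw_step(float & w, float g, float & m, float & v,
                           float alpha, float beta1, float beta2, float eps,
                           float wd, float beta1_t, float beta2_t) {
        m = beta1 * m + (1.0f - beta1) * g;
        v = beta2 * v + (1.0f - beta2) * g * g;
        const float mh = m / (1.0f - beta1_t);   // bias-corrected first moment
        const float vh = v / (1.0f - beta2_t);   // bias-corrected second moment
        w = w * (1.0f - alpha * wd) - alpha * mh / (std::sqrt(vh) + eps);
    }

The factor (1 - alpha*wd) in both updates is the decoupled weight-decay
multiplier discussed further down.
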
note: when finetuning for long enough (or with a large enough -lr),
validation accuracy *eventually* drops ('catastrophic forgetting').

The -lr-half (half-life) option is useful with SGD to avoid oscillation or
very slow, underdamped learning; it makes setting -lr more forgiving. The
terminal -lr is for now set via -lr-halvings, i.e. if you want to decay to at
most 1/8 of the initial -lr, set -lr-halvings 3.

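As a rough sketch of such a schedule (assuming simple exponential decay by
half-life, floored at the terminal rate; the exact schedule is defined by the
finetune/common lr code, and the function and parameter names below are
assumptions):

    #include <cmath>

    // lr0: initial -lr, lr_half: half-life in epochs, lr_halvings: max halvings.
    float lr_at_epoch(float lr0, float epoch, float lr_half, int lr_halvings) {
        const float lr_min = lr0 / std::pow(2.0f, (float) lr_halvings);
        const float lr     = lr0 * std::pow(0.5f, epoch / lr_half);
        return lr > lr_min ? lr : lr_min;  // never decay below lr0 / 2^halvings
    }
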
note: the objective loss is probably not directly comparable between adamw
and sgd; check perplexity or accuracy, or compare relative improvements, when
judging convergence.

New finetune args: -wd 1e-9 enables weight decay in sgd or adamw, and
-epochs N sets the maximum number of epochs (default 2, as before).

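Put together, an invocation using the new options might look like the
following (binary name and the -m/-f model/data flags are the usual llama.cpp
ones and are assumed here; paths and values are placeholders):

    ./llama-finetune -m model.gguf -f train.txt \
        -opt sgd -lr 1e-4 -lr-halvings 3 -wd 1e-9 -epochs 2
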
Caching (1 - wd*alpha) in the 'adamw' opt struct showed no noticeable perf
benefit, so it is disabled there (it is still done for the new SGD, though).

Since optimizer memory is pre-allocated, the ggml_opt_get_optimizer_params
callback could probably switch between SGD and AdamW from epoch to epoch, but
the first epoch would have to use adamw (unconfirmed; there is no cmdline arg
to set such a policy yet).

test-opt checks adamw as before and now also sgd (except for a few tests
disabled for sgd only; these probably just need their values logged and
alternate reference values added). The tolerance on the 'regression' test is
broader for sgd, so we don't need many more epochs.

* Vulkan: Implement GGML_OP_OPT_STEP_SGD
* tests: Fix OPT_STEP_SGD test-backend-ops
* SGD op param store weight-decay and not 1-alpha*wd
* minor + cosmetic changes
* fix vulkan sgd
* try CI fix
---------
Co-authored-by: 0cc4m <picard12@live.de>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

finetune.cpp · 97 lines · 3.3 KiB · C++

#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <vector>

#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267)  // possible loss of data
#endif

int main(int argc, char ** argv) {
    common_params params;
    params.escape = false;

    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_FINETUNE)) {
        return 1;
    }

    if (params.use_mmap) {
        LOG_INF("%s: force disabling memory mapping because it would result in read-only pointers to the weights\n",
                __func__);
        params.use_mmap = false;
    }
    if (params.cache_type_k != GGML_TYPE_F32) {
        LOG_INF("%s: force changing k cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
        params.cache_type_k = GGML_TYPE_F32;
    }
    if (params.cache_type_v != GGML_TYPE_F32) {
        LOG_INF("%s: force changing v cache type to f32 due to a lack of f16 support for OUT_PROD\n", __func__);
        params.cache_type_v = GGML_TYPE_F32;
    }

    common_init();
    llama_backend_init();
    llama_numa_init(params.numa);

    // load the model and apply lora adapter, if any
    common_init_result   llama_init = common_init_from_params(params);
    llama_model_ptr    & model      = llama_init.model;
    llama_context_ptr  & ctx        = llama_init.context;

    if (model == NULL) {
        LOG_ERR("%s: unable to load model\n", __func__);
        return 1;
    }

    // print system information
    {
        LOG_INF("\n");
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    }

    std::vector<llama_token> tokens  = common_tokenize(ctx.get(), params.prompt, true);
    ggml_opt_dataset_t       dataset = common_opt_dataset_init(ctx.get(), tokens, llama_n_ctx(ctx.get()) / 2);

    struct lr_opt & lr = params.lr;
    LOG_INF("-optimizer %s -lr0 %.2g -wd %.2g -lr-min %.2g -min-epochs %.2g -epochs %d -period %.2g -val %.2g\n",
            ggml_opt_optimizer_name(params.optimizer), (double) lr.lr0, (double) lr.wd, (double) lr.lr_min, (double) lr.decay_epochs,
            (unsigned) lr.epochs, (double) params.n_batch / params.n_ubatch, (double) params.val_split);

    struct llama_opt_params lopt_params{
        /*n_ctx_train     =*/0,
        /*param_filter    =*/llama_opt_param_filter_all,
        /*param_filter_ud =*/nullptr,
        /*get_opt_pars    =*/common_opt_lr_pars,
        /*get_opt_pars_ud =*/&params.lr,
        /*optimizer_type  =*/params.optimizer,
    };
    llama_opt_init(ctx.get(), model.get(), lopt_params);

    const int64_t idata_split = ggml_opt_dataset_ndata(dataset) * (1.0f - params.val_split);

    ggml_opt_result_t result_train = ggml_opt_result_init();
    ggml_opt_result_t result_eval  = ggml_opt_result_init();

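    // Training loop: each llama_opt_epoch() call makes one pass over the dataset,
    // using the batches before idata_split for training and the remainder for
    // validation, drawing a progress bar for each phase. The results are reset
    // after every epoch so the reported loss/accuracy are per-epoch values.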
    for (lr.epoch = 0; lr.epoch < lr.epochs; ++lr.epoch) {
        llama_opt_epoch(ctx.get(), dataset, result_train, result_eval, idata_split,
                        ggml_opt_epoch_callback_progress_bar, ggml_opt_epoch_callback_progress_bar);
        fprintf(stderr, "\n");

        ggml_opt_result_reset(result_train);
        ggml_opt_result_reset(result_eval);
    }
    ggml_opt_result_free(result_train);
    ggml_opt_result_free(result_eval);

    llama_model_save_to_file(model.get(), params.out_file.c_str());

    llama_backend_free();

    return 0;
}