finetune: SGD optimizer, more CLI args (#13873)
* examples/finetune -opt SGD (stochastic gradient descent) memory opt

add unit-tested GGML_OPT_OPTIMIZER_SGD to ggml - avoids allocating the
AdamW m, v moment tensors.
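
For illustration only, a minimal self-contained sketch of the two update rules
in plain C (standard formulations; the function and variable names are
illustrative, not the ggml kernels): SGD touches only the weights w and the
gradient g, while AdamW must keep the per-parameter moment buffers m and v
alive across steps, which is where the extra optimizer memory goes.

    #include <math.h>
    #include <stddef.h>

    // SGD with decoupled weight decay: only w and g are needed, no extra state.
    static void sgd_step(float * w, const float * g, size_t n, float alpha, float wd) {
        const float keep = 1.0f - alpha * wd;
        for (size_t i = 0; i < n; ++i) {
            w[i] = w[i] * keep - alpha * g[i];
        }
    }

    // AdamW (standard formulation) additionally needs persistent per-parameter
    // buffers m and v of the same size as w.
    // t is the 1-based step number used for bias correction.
    static void adamw_step(float * w, const float * g, float * m, float * v, size_t n, int t,
                           float alpha, float beta1, float beta2, float eps, float wd) {
        const float keep = 1.0f - alpha * wd;             // decoupled weight decay
        const float b1c  = 1.0f - powf(beta1, (float) t); // bias-correction denominators
        const float b2c  = 1.0f - powf(beta2, (float) t);
        for (size_t i = 0; i < n; ++i) {
            m[i] = beta1 * m[i] + (1.0f - beta1) * g[i];
            v[i] = beta2 * v[i] + (1.0f - beta2) * g[i] * g[i];
            const float mh = m[i] / b1c;
            const float vh = sqrtf(v[i] / b2c) + eps;
            w[i] = w[i] * keep - alpha * mh / vh;
        }
    }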
support finetune.cpp arg -opt SGD (or sgd); the default remains adamw as before.

llama 3.2-1b F32 result: observed 11 GB GPU RAM (41 sec/epoch) when using
SGD instead of 19 GB (55 sec/epoch) when using adamw
(wikipedia 100-line finetune).
(
with the same GPU memory, adamw can fit at most a 512 batch/context
before OOM, reaching:
train: [███████▉] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00
val: [███████▉] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00

SGD comes out ahead here, though it converges more slowly, fitting up to a
1728 batch/context before OOM (note especially the better validation accuracy):
train: [███████▉] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00
val: [███████▉] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00
)
note: when finetuning long enough (or with a high enough -lr),
validation accuracy *eventually* drops ('catastrophic forgetting')
the -lr-half (half-life) option is useful for SGD to avoid oscillation or
very slow underdamped learning (it makes setting -lr more forgiving).
the terminal -lr is for now set via -lr-halvings, i.e. if you want at most
1/8 of the initial -lr you set -lr-halvings 3.
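
For intuition only, a small sketch of one plausible reading of this schedule
(the exact formula used by finetune.cpp is not shown here, so treat both the
formula and the function name below as assumptions): the learning rate halves
every -lr-half epochs and is floored at the initial -lr divided by
2^(-lr-halvings).

    #include <math.h>

    // hypothetical half-life LR schedule with a floor (sketch, not finetune.cpp's code)
    // lr0: initial -lr, halflife: -lr-half (in epochs), halvings: -lr-halvings
    static float lr_at_epoch(float lr0, float halflife, int halvings, float epoch) {
        const float lr_min = lr0 / powf(2.0f, (float) halvings); // terminal lr, e.g. lr0/8 for 3 halvings
        const float lr     = lr0 * powf(0.5f, epoch / halflife); // halve every `halflife` epochs
        return lr > lr_min ? lr : lr_min;
    }
    // e.g. lr_at_epoch(1e-3f, 1.0f, 3, 4.0f) returns 1.25e-4f (clamped at lr0/8)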
note: the objective loss is not directly comparable between adamw and sgd -
check perplexity or accuracy instead, or look at relative improvements,
when judging convergence
new finetune args: -wd 1e-9 enables weight decay in sgd or adamw,
and -epochs N caps the number of epochs (default 2 as before)

caching (1 - wd*alpha) in the 'adamw' opt struct showed no noticeable
perf benefit and is disabled (it is still done for the new SGD, though)
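
As a concrete illustration of decoupled weight decay and the cached keep
factor, a tiny standalone C example with made-up values (not the defaults):

    #include <stdio.h>

    // One SGD step with decoupled weight decay, matching w = w*(1 - alpha*wd) - alpha*g.
    // The (1 - alpha*wd) factor is the cached "keep" value referred to above.
    int main(void) {
        const float alpha = 0.01f;  // learning rate (illustrative)
        const float wd    = 0.1f;   // weight decay (illustrative; finetune's -wd takes values like 1e-9)
        const float keep  = 1.0f - alpha * wd;   // 0.999
        float w = 0.5f;
        const float g = 0.2f;
        w = w * keep - alpha * g;                // 0.5*0.999 - 0.01*0.2 = 0.4975
        printf("keep = %g, w = %g\n", keep, w);  // keep = 0.999, w = 0.4975
        return 0;
    }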
since the optimizer memory is pre-allocated, the ggml_opt_get_optimizer_params
callback could probably switch between SGD and AdamW from one epoch to the
next, but it would need to use adamw for the first epoch (unconfirmed - there
is no cmdline arg to set such a policy yet)
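
For context, this is roughly what such a per-epoch parameter callback could
look like; ggml_opt_get_optimizer_params and ggml_opt_get_default_optimizer_params
are the existing ggml-opt callback type and helper, but the .sgd field names
below are assumptions based on this PR, and the per-epoch switching itself is
exactly the unconfirmed part:

    #include "ggml-opt.h"

    // sketch of a ggml_opt_get_optimizer_params callback driven by epoch state;
    // the .sgd field names are assumptions based on this PR
    struct epoch_state { int epoch; float lr; };

    static struct ggml_opt_optimizer_params my_opt_params(void * userdata) {
        const struct epoch_state * st = (const struct epoch_state *) userdata;
        struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(NULL);
        p.adamw.alpha = st->lr;  // learning rate used when running AdamW
        p.adamw.wd    = 1e-9f;   // weight decay, as with finetune's -wd
        p.sgd.alpha   = st->lr;  // assumed: SGD learning rate field added by this PR
        p.sgd.wd      = 1e-9f;   // assumed: SGD weight decay field added by this PR
        return p;
    }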
test-opt checks adamw as before and now also sgd (except for a few tests
disabled for sgd only; they probably just need logging of values and
alternate reference values); the tolerance on the 'regression' test is
broader for sgd (so we don't need many more epochs)
* Vulkan: Implement GGML_OP_OPT_STEP_SGD
* tests: Fix OPT_STEP_SGD test-backend-ops
* SGD op param: store weight decay, not 1-alpha*wd
* minor + cosmetic changes
* fix vulkan sgd
* try CI fix
---------
Co-authored-by: 0cc4m <picard12@live.de>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
@@ -10330,6 +10330,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
 
     const float alpha  = adamw_params_ptr[0];
     const float beta1  = adamw_params_ptr[1];
     const float beta2  = adamw_params_ptr[2];
@@ -10337,7 +10338,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const float wd     = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
 
+    const float keep = 1.f - alpha * wd;
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -10360,7 +10361,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
@@ -10382,3 +10383,63 @@ void ggml_compute_forward_opt_step_adamw(
         }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0       = dst->src[0];
+    const ggml_tensor * src0_grad  = dst->src[1];
+    const ggml_tensor * sgd_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(sgd_params) == 2);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * sgd_params_ptr = ggml_get_data_f32(sgd_params);
+    const float   alpha          = sgd_params_ptr[0];
+    const float   keep           = 1.f - alpha * sgd_params_ptr[1];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float       * w = (float       *) ((char       *) src0->data      + offset);  // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset);  // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}