Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-04 09:32:00 +00:00)
finetune: SGD optimizer, more CLI args (#13873)
* examples/finetune -opt SGD (stochastic gradient descent) memory opt
add unit-tested GGML_OPT_OPTIMIZER_SGD to ggml - avoids allocating the
AdamW m, v momentum tensors.
support finetune.cpp arg -opt SGD (or sgd); the default is adamw, as before.
llama 3.2-1b-F32 result (finetuning on 100 lines of Wikipedia): observed
11 GB GPU RAM (41 sec/epoch) using SGD vs. 19 GB (55 sec/epoch) using adamw.
(with the same GPU memory, adamw can fit at most a 512 batch/context before
OOM, reaching:
train: [███████▉] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00
val: [███████▉] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00
SGD comes out ahead here, though it converges more slowly: it fits up to a
1728 batch/context before OOM (note especially the better validation
performance):
train: [███████▉] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00
val: [███████▉] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00
)
note: when finetuning long enough (or with a large enough -lr), validation
accuracy *eventually* drops ('catastrophic forgetting').
the -lr-half (halflife) option is useful for SGD to avoid oscillation or
very slow underdamped learning (it makes setting -lr more forgiving).
the terminal -lr is for now set via -lr-halvings, i.e. if you want at most
1/8 of the initial -lr, set -lr-halvings 3.
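as a rough illustration of the decay this describes (a sketch only - the
helper name and the epoch-based interpolation are illustrative, not the
actual finetune.cpp code):

    #include <cmath>

    // Hypothetical helper: learning rate for `epoch` out of `n_epochs`,
    // decaying so the final epoch uses lr0 * 2^-halvings.
    // E.g. lr0 = 1e-3, halvings = 3 -> final lr = 1.25e-4 (1/8 of lr0).
    static float lr_at_epoch(float lr0, int halvings, int epoch, int n_epochs) {
        const float progress = n_epochs > 1 ? (float) epoch / (n_epochs - 1) : 1.0f;
        return lr0 * std::pow(0.5f, halvings * progress);
    }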
note: objective loss may not be directly comparable between adamw and sgd -
check perplexity or accuracy instead, or compare relative improvements when
judging convergence.
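for instance, with mean cross-entropy loss, perplexity = exp(loss): the
validation runs above give exp(4.766) ≈ 117 for adamw and exp(5.114) ≈ 166
for sgd, even though sgd's validation accuracy is higher - one reason to
look at more than a single metric.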
new finetune args: -wd 1e-9 to enable weight decay in sgd or adamw, and
-epochs N to set the max number of epochs (default 2, as before).
caching (1 - wd*alpha) in the 'adamw' opt struct gave no noticeable perf
benefit, so it is disabled there (it is still done for the new SGD, though).
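for reference, the SGD update with decoupled weight decay described here, as
a standalone sketch (plain C++ over raw floats, not the actual ggml kernel):

    // One SGD step with decoupled weight decay, applied elementwise:
    //   w = w*(1 - alpha*wd) - alpha*g
    // (1 - alpha*wd) is constant per step, so it can be computed once
    // ("cached") rather than per element.
    static void sgd_step(float * w, const float * g, int64_t n, float alpha, float wd) {
        const float keep = 1.0f - alpha*wd; // cached decay factor
        for (int64_t i = 0; i < n; ++i) {
            w[i] = w[i]*keep - alpha*g[i];
        }
    }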
since optimizer memory is pre-allocated, ggml_opt_get_optimizer_params could
probably switch between SGD and AdamW with each epoch, but it would need to
use adamw for the first one (unconfirmed - there is no cmdline arg to set
such a policy yet).
test-opt checks adamw as before and now also sgd (except for a few tests
disabled for sgd only; these probably just need their values logged and
alternate reference values added); the tolerance on the 'regression' test
is broader for sgd (so we don't need many more epochs).
* Vulkan: Implement GGML_OP_OPT_STEP_SGD
* tests: Fix OPT_STEP_SGD test-backend-ops
* SGD op param stores weight-decay and not 1-alpha*wd
* minor + cosmetic changes
* fix Vulkan SGD
* try CI fix
---------
Co-authored-by: 0cc4m <picard12@live.de>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
--- a/ggml/src/ggml-opt.cpp
+++ b/ggml/src/ggml-opt.cpp
@@ -64,9 +64,11 @@ struct ggml_opt_context {
     int32_t opt_i = 0;
     bool loss_per_datapoint = false;

     ggml_opt_get_optimizer_params get_opt_pars = nullptr;
     void * get_opt_pars_ud = nullptr;
-    struct ggml_tensor * adamw_params = nullptr;
+    struct ggml_tensor * opt_step_params = nullptr; // Stores output of get_opt_pars.
+
+    enum ggml_opt_optimizer_type optimizer = GGML_OPT_OPTIMIZER_TYPE_ADAMW;
 };

 struct ggml_opt_result {
@@ -229,9 +231,13 @@ struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * us
     result.adamw.eps = 1e-8f;
     result.adamw.wd = 0.0f;

+    result.sgd.alpha = 1e-3f;
+    result.sgd.wd = 0.0f;
+
     return result;
 }

+
 struct ggml_opt_optimizer_params ggml_opt_get_constant_optimizer_params(void * userdata) {
     return *((struct ggml_opt_optimizer_params *) userdata);
 }
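the constant-params hook in this hunk gives a way to pin fixed
hyperparameters; a minimal usage sketch (the values are made up, the wiring
follows the fields shown above):

    // Sketch: supply fixed SGD hyperparameters through the userdata hook.
    struct ggml_opt_optimizer_params my_pars = ggml_opt_get_default_optimizer_params(/*userdata =*/ nullptr);
    my_pars.sgd.alpha = 1e-4f; // learning rate
    my_pars.sgd.wd    = 1e-9f; // weight decay
    // then, when building ggml_opt_params:
    //   params.get_opt_pars    = ggml_opt_get_constant_optimizer_params;
    //   params.get_opt_pars_ud = &my_pars;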
@@ -249,6 +255,7 @@ struct ggml_opt_params ggml_opt_default_params(
         /*opt_period      =*/ 1,
         /*get_opt_pars    =*/ ggml_opt_get_default_optimizer_params,
         /*get_opt_pars_ud =*/ nullptr,
+        /*optimizer       =*/ GGML_OPT_OPTIMIZER_TYPE_ADAMW,
     };
 }

@@ -316,9 +323,14 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     GGML_ASSERT(opt_ctx->ctx_compute && "no compute context set, either use static graphs or set one with ggml_opt_prepare_alloc");
     GGML_ASSERT((!opt_ctx->static_graphs || opt_ctx->inputs->data) && "when using static graphs the inputs must be allocated statically");

+    const enum ggml_opt_optimizer_type optimizer = opt_ctx->optimizer;
+
     const bool accumulate = opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_GRAD &&
         !(opt_ctx->static_graphs && opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT && opt_ctx->opt_period == 1);

+    const bool need_momenta = opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT &&
+        opt_ctx->optimizer == GGML_OPT_OPTIMIZER_TYPE_ADAMW;
+
     ggml_set_input(opt_ctx->inputs);
     ggml_set_output(opt_ctx->outputs);

@@ -340,8 +352,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     // - pred (if using static graphs)
     // - ncorrect (if using static graphs, 2 tensors).
     constexpr size_t n_loss = 1;
-    const size_t tensors_per_param = (accumulate ? 1 : 0) +
-        (opt_ctx->build_type_alloc == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0);
+    const size_t tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0);
     const size_t tensors_const = opt_ctx->static_graphs ? 9 : 0;
     const size_t size_meta = (n_loss + tensors_per_param*n_param + tensors_const) * ggml_tensor_overhead();
     struct ggml_init_params params = {
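a worked count for the new tensors_per_param line (values follow directly
from the hunk above):

    // tensors_per_param = (accumulate ? 1 : 0) + (need_momenta ? 2 : 0)
    // adamw, accumulating:    1 + 2 = 3 extra tensors per parameter (grad acc, m, v)
    // sgd,   accumulating:    1 + 0 = 1
    // sgd,   opt_period == 1: 0 + 0 = 0
    // sgd never pays for m and v - the source of the memory saving.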
@@ -458,7 +469,7 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
         }
     }

-    if (opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
+    if (need_momenta && opt_ctx->build_type_alloc >= GGML_OPT_BUILD_TYPE_OPT) {
         opt_ctx->grad_m.resize(n_nodes);
         opt_ctx->grad_v.resize(n_nodes);
         for (int i = 0; i < n_nodes; ++i) {
@@ -492,23 +503,36 @@ static void ggml_opt_build(ggml_opt_context_t opt_ctx) {
     // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
     opt_ctx->gb_opt = ggml_graph_dup(opt_ctx->ctx_compute, opt_ctx->gb_grad, /*force_grads =*/ true);

-    opt_ctx->adamw_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, 7);
-    ggml_set_input(opt_ctx->adamw_params);
-    ggml_set_name(opt_ctx->adamw_params, "adamw_params");
-
+    opt_ctx->opt_step_params = ggml_new_tensor_1d(opt_ctx->ctx_cpu, GGML_TYPE_F32, need_momenta ? 7 : 2);
+    ggml_tensor * adamw_params = opt_ctx->opt_step_params;
+    ggml_set_input(adamw_params);
+    const char * optimizer_name = ggml_opt_optimizer_name(opt_ctx->optimizer);
+    ggml_format_name(adamw_params, "%s_params", optimizer_name);
     for (int i = opt_ctx->gf->n_nodes-1; i >= 0; --i) {
         struct ggml_tensor * node = opt_ctx->gb_opt->nodes[i];
         struct ggml_tensor * grad = ggml_graph_get_grad(opt_ctx->gb_opt, node);

         if (grad && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
-            struct ggml_tensor * m        = opt_ctx->grad_m[i];
-            struct ggml_tensor * v        = opt_ctx->grad_v[i];
-            struct ggml_tensor * opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, opt_ctx->adamw_params);
-
-            ggml_set_name(m, (std::string("AdamW m for ") + std::string(node->name)).c_str());
-            ggml_set_name(v, (std::string("AdamW v for ") + std::string(node->name)).c_str());
-            ggml_set_name(opt_step, (std::string("AdamW step for ") + std::string(node->name)).c_str());
-
+            struct ggml_tensor * m = nullptr;
+            struct ggml_tensor * v = nullptr;
+            if (need_momenta) {
+                m = opt_ctx->grad_m[i];
+                v = opt_ctx->grad_v[i];
+                ggml_format_name(m, "AdamW m for %s", node->name);
+                ggml_format_name(v, "AdamW v for %s", node->name);
+            }
+            struct ggml_tensor * opt_step;
+            switch (optimizer) {
+                case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+                    opt_step = ggml_opt_step_adamw(opt_ctx->ctx_compute, node, grad, m, v, adamw_params);
+                    break;
+                case GGML_OPT_OPTIMIZER_TYPE_SGD:
+                    opt_step = ggml_opt_step_sgd(opt_ctx->ctx_compute, node, grad, adamw_params);
+                    break;
+                default:
+                    GGML_ABORT("fatal error");
+            }
+            ggml_format_name(opt_step, "%s step for %s", optimizer_name, node->name);
             ggml_build_forward_expand(opt_ctx->gb_opt, opt_step);
         }
     }
@@ -534,6 +558,7 @@ ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) {
     result->opt_period      = params.opt_period;
     result->get_opt_pars    = params.get_opt_pars;
     result->get_opt_pars_ud = params.get_opt_pars_ud;
+    result->optimizer       = params.optimizer;

    GGML_ASSERT(result->opt_period >= 1);

@@ -756,29 +781,43 @@ void ggml_opt_alloc(ggml_opt_context_t opt_ctx, bool backward) {
 void ggml_opt_eval(ggml_opt_context_t opt_ctx, ggml_opt_result_t result) {
     GGML_ASSERT(opt_ctx->eval_ready);
     if (opt_ctx->allocated_graph == opt_ctx->gb_opt) {
-        struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);
+        const ggml_opt_optimizer_params & opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud);

-        GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
-        GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
-        GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
-        GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
-
-        // beta1, beta2 after applying warmup
-        const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
-        const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
-
-        float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params);
-        adamw_par_data[0] = opt_pars.adamw.alpha;
-        adamw_par_data[1] = opt_pars.adamw.beta1;
-        adamw_par_data[2] = opt_pars.adamw.beta2;
-        adamw_par_data[3] = opt_pars.adamw.eps;
-        adamw_par_data[4] = opt_pars.adamw.wd;
-        adamw_par_data[5] = beta1h;
-        adamw_par_data[6] = beta2h;
+        switch (opt_ctx->optimizer) {
+            case GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
+                GGML_ASSERT(opt_pars.adamw.alpha > 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f);
+                GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f);
+                GGML_ASSERT(opt_pars.adamw.eps >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.wd >= 0.0f);
+                GGML_ASSERT(opt_pars.adamw.wd <= 1.0f);
+
+                // beta1, beta2 after applying warmup
+                const float beta1h = 1.0f / (1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter));
+                const float beta2h = 1.0f / (1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter));
+
+                float * adamw_par_data = ggml_get_data_f32(opt_ctx->opt_step_params);
+                adamw_par_data[0] = opt_pars.adamw.alpha;
+                adamw_par_data[1] = opt_pars.adamw.beta1;
+                adamw_par_data[2] = opt_pars.adamw.beta2;
+                adamw_par_data[3] = opt_pars.adamw.eps;
+                adamw_par_data[4] = opt_pars.adamw.wd;
+                adamw_par_data[5] = beta1h;
+                adamw_par_data[6] = beta2h;
+            } break;
+            case GGML_OPT_OPTIMIZER_TYPE_SGD: {
+                GGML_ASSERT(opt_pars.sgd.alpha > 0.0f);
+                GGML_ASSERT(opt_pars.sgd.wd >= 0.0f);
+                GGML_ASSERT(opt_pars.sgd.wd <= 1.0f);
+                float * sgd = ggml_get_data_f32(opt_ctx->opt_step_params);
+                sgd[0] = opt_pars.sgd.alpha;
+                sgd[1] = opt_pars.sgd.wd;
+            } break;
+            default:
+                GGML_ABORT("fatal error");
+        }
     }

     ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy);
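the packed layout of opt_step_params, as written by the two cases above:

    // adamw (7 floats):                       sgd (2 floats):
    // [0] alpha (learning rate)               [0] alpha (learning rate)
    // [1] beta1                               [1] wd (weight decay)
    // [2] beta2
    // [3] eps
    // [4] wd (weight decay)
    // [5] beta1h = 1/(1 - beta1^iter)  // bias correction ("warmup")
    // [6] beta2h = 1/(1 - beta2^iter)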
@@ -963,6 +1002,7 @@ void ggml_opt_fit(
         ggml_tensor * outputs,
         ggml_opt_dataset_t dataset,
         enum ggml_opt_loss_type loss_type,
+        enum ggml_opt_optimizer_type optimizer,
         ggml_opt_get_optimizer_params get_opt_pars,
         int64_t nepoch,
         int64_t nbatch_logical,
@@ -993,6 +1033,7 @@ void ggml_opt_fit(
     params.opt_period      = opt_period;
     params.get_opt_pars    = get_opt_pars;
     params.get_opt_pars_ud = &epoch;
+    params.optimizer       = optimizer;
     ggml_opt_context_t opt_ctx = ggml_opt_init(params);

     // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch.
@@ -1035,3 +1076,18 @@ void ggml_opt_fit(
     ggml_opt_result_free(result_train);
     ggml_opt_result_free(result_val);
 }
+
+enum ggml_opt_optimizer_type ggml_opt_context_optimizer_type(ggml_opt_context_t c) {
+    return c->optimizer;
+}
+
+GGML_API const char * ggml_opt_optimizer_name(enum ggml_opt_optimizer_type o) {
+    switch (o) {
+        case GGML_OPT_OPTIMIZER_TYPE_ADAMW:
+            return "adamw";
+        case GGML_OPT_OPTIMIZER_TYPE_SGD:
+            return "sgd";
+        default:
+            return "undefined";
+    };
+}
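with the optimizer argument now threaded through ggml_opt_fit (placement per
the -963,6 hunk above), selecting SGD from user code looks roughly like this
- a sketch; the elided leading and trailing arguments follow the existing
ggml-opt.h signature:

    // Sketch: fit with SGD and the default hyperparameters.
    ggml_opt_fit(/* backend_sched, ctx_compute, */ inputs, outputs, dataset,
                 GGML_OPT_LOSS_TYPE_CROSS_ENTROPY,
                 GGML_OPT_OPTIMIZER_TYPE_SGD,
                 ggml_opt_get_default_optimizer_params,
                 /*nepoch =*/ 2,
                 /*nbatch_logical =*/ 512,
                 /* ... remaining args per ggml-opt.h ... */);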