mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-29 08:41:22 +00:00
finetune: SGD optimizer, more CLI args (#13873)
* examples/finetune -opt SGD (stochastic gradient descent) memory opt
add unit tested GGML_OPT_OPTIMIZER_SGD to ggml - avoids allocating
m, v tensors.
support finetune.cpp arg -opt SGD (or sgd). (default adamw as before)
llama 3.2-1b-F32 result: observed 11gb gpu ram (41 sec/epoch)
when using SGD instead of 19gb (55 sec/epoch) using adamw.
(wikipedia 100 lines finetune)
(
using the same GPU memory, adamw can only do before OOM 512
batch/context, reaching:
train: [███████▉] data=0000140/0000140 loss=0.02575±0.00099 acc=99.52±0.03% t=00:00:47 ETA=00:00:00
val: [███████▉] data=0000008/0000008 loss=4.76565±0.28810 acc=41.46±0.77% t=00:00:00 ETA=00:00:00
SGD is superior, though it converges slower, with max before OOM 1728
batch/context (esp see the better validation perf):
train: [███████▉] data=0000039/0000039 loss=0.00371±0.00010 acc=99.96±0.01% t=00:00:41 ETA=00:00:00
val: [███████▉] data=0000003/0000003 loss=5.11406±0.76034 acc=48.01±0.69% t=00:00:01 ETA=00:00:00
)
note: when finetuning long enough (or w/ enough -lr),
validation accuracy *eventually* drops ('catastrophic forgetting')
-lr-half (halflife) option useful for SGD to avoid oscillation or
super slow underdamped learning (makes setting -lr more forgiving).
terminal -lr for now is set by lr-halvings i.e. if you want at most
1/8 the inital -lr you set -lr-halvings 3.
note: objective loss not directly comparable between adamw, sgd? -
check perplexity or accuracy or consider relative improvements
for convergence
new finetune args -wd 1e-9 to enable weight decay in sgd or adamw,
and max -epochs N (default 2 as before)
cache (1 - wd*alpha) in 'adamw' opt struct -
no noticeable perf benefit, disabled (still done
for new SGD though)
since opt. memory is pre-allocated, the ggml_opt_get_optimizer_params
would probably be able to change between SGD and AdamW with each epoch
but would need to use adamw for the first (unconfirmed - no cmdline arg
to set such a policy yet)
test-opt checks adamw as before and now sgd (except for a few disabled
tests for sgd only; probably just needs logging values and adding
alternate reference values); tolerance on the 'regression'
test is broader for sgd (so we don't need many more epochs)
* Vulkan: Implement GGML_OP_OPT_STEP_SGD
* tests: Fix OPT_STEP_SGD test-backend-ops
* SGD op param store weight-decay and not 1-alpha*wd
* minor + cosmetic changes
* fix vulkan sgd
* try CI fix
---------
Co-authored-by: 0cc4m <picard12@live.de>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
@@ -510,6 +510,7 @@ struct vk_device_struct {
|
||||
vk_pipeline pipeline_rwkv_wkv6_f32;
|
||||
vk_pipeline pipeline_rwkv_wkv7_f32;
|
||||
vk_pipeline pipeline_opt_step_adamw_f32;
|
||||
vk_pipeline pipeline_opt_step_sgd_f32;
|
||||
vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
|
||||
vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
|
||||
vk_pipeline pipeline_conv2d_dw_whcn_f32;
|
||||
@@ -3123,6 +3124,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||
|
||||
// conv2d
|
||||
for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) {
|
||||
uint32_t conv2d_WG_SIZE = 256;
|
||||
@@ -7193,6 +7196,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||
return ctx->device->pipeline_opt_step_adamw_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_opt_step_sgd_f32;
|
||||
}
|
||||
return nullptr;
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||
return ctx->device->pipeline_leaky_relu_f32;
|
||||
@@ -7692,6 +7700,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||
ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz);
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
} else if (op == GGML_OP_OPT_STEP_SGD) {
|
||||
// OPT_STEP_SGD works on src0, it does not need dst
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements);
|
||||
} else if (use_src2) {
|
||||
ggml_vk_sync_buffers(subctx);
|
||||
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
|
||||
@@ -8045,6 +8057,12 @@ static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& su
|
||||
);
|
||||
}
|
||||
|
||||
static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) {
|
||||
const size_t n = ggml_nelements(dst->src[0]);
|
||||
|
||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun);
|
||||
}
|
||||
|
||||
static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||
int * op_params = (int *)dst->op_params;
|
||||
|
||||
@@ -9598,6 +9616,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
break;
|
||||
default:
|
||||
std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl;
|
||||
@@ -9662,6 +9681,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_CONV_2D:
|
||||
case GGML_OP_CONV_2D_DW:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
{
|
||||
// These operations all go through ggml_vk_op_f32, so short-circuit and
|
||||
// do the only thing needed for the dryrun.
|
||||
@@ -9911,6 +9931,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun);
|
||||
|
||||
break;
|
||||
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun);
|
||||
|
||||
break;
|
||||
default:
|
||||
return false;
|
||||
@@ -10014,8 +10039,8 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
case GGML_OP_REPEAT:
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
buf = tensor->buffer;
|
||||
|
||||
break;
|
||||
case GGML_OP_UNARY:
|
||||
switch (ggml_get_unary_op(tensor)) {
|
||||
@@ -11154,6 +11179,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_SIN:
|
||||
case GGML_OP_COS:
|
||||
case GGML_OP_CLAMP:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
case GGML_OP_OPT_STEP_SGD:
|
||||
return op->src[0]->type == GGML_TYPE_F32;
|
||||
case GGML_OP_UPSCALE:
|
||||
case GGML_OP_ACC:
|
||||
@@ -11175,8 +11203,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||
case GGML_OP_POOL_2D:
|
||||
case GGML_OP_RWKV_WKV6:
|
||||
case GGML_OP_RWKV_WKV7:
|
||||
case GGML_OP_LEAKY_RELU:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
return true;
|
||||
case GGML_OP_CONV_TRANSPOSE_1D:
|
||||
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
|
||||
@@ -11774,6 +11800,10 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
|
||||
src_clone[0]->flags = src0->flags;
|
||||
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
|
||||
src_clone[2], src_clone[3], src_clone[4]);
|
||||
} else if (tensor->op == GGML_OP_OPT_STEP_SGD) {
|
||||
src_clone[0]->flags = src0->flags;
|
||||
tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1],
|
||||
src_clone[2]);
|
||||
}
|
||||
else {
|
||||
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
||||
|
||||
Reference in New Issue
Block a user