mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	metal: add support for opt_step_sgd (#16539)
* metal: add support for opt_step_sgd * add newline to pass EditorConfig check
This commit is contained in:
		| @@ -1519,3 +1519,22 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_ | ||||
|  | ||||
|     return res; | ||||
| } | ||||
|  | ||||
| ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) { | ||||
|     assert(op->op == GGML_OP_OPT_STEP_SGD); | ||||
|  | ||||
|     char base[256]; | ||||
|     char name[256]; | ||||
|  | ||||
|     snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type)); | ||||
|     snprintf(name, 256, "%s", base); | ||||
|  | ||||
|     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); | ||||
|     if (res) { | ||||
|         return res; | ||||
|     } | ||||
|  | ||||
|     res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); | ||||
|  | ||||
|     return res; | ||||
| } | ||||
|   | ||||
| @@ -136,6 +136,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d    (ggml_me | ||||
| ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange            (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||||
| ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); | ||||
| ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw    (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||||
| ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd      (ggml_metal_library_t lib, const struct ggml_tensor * op); | ||||
|  | ||||
| ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( | ||||
|         ggml_metal_library_t lib, | ||||
|   | ||||
| @@ -800,6 +800,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te | ||||
|                 }; | ||||
|             } | ||||
|         case GGML_OP_OPT_STEP_ADAMW: | ||||
|         case GGML_OP_OPT_STEP_SGD: | ||||
|             return has_simdgroup_reduction; | ||||
|         default: | ||||
|             return false; | ||||
|   | ||||
| @@ -781,4 +781,8 @@ typedef struct { | ||||
|     int64_t  np; | ||||
| } ggml_metal_kargs_opt_step_adamw; | ||||
|  | ||||
| typedef struct { | ||||
|     int64_t  np; | ||||
| } ggml_metal_kargs_opt_step_sgd; | ||||
|  | ||||
| #endif // GGML_METAL_IMPL | ||||
|   | ||||
| @@ -418,6 +418,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { | ||||
|             { | ||||
|                 n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx); | ||||
|             } break; | ||||
|         case GGML_OP_OPT_STEP_SGD: | ||||
|             { | ||||
|                 n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx); | ||||
|             } break; | ||||
|        default: | ||||
|             { | ||||
|                 GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op)); | ||||
| @@ -3469,3 +3473,37 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { | ||||
|  | ||||
|     return 1; | ||||
| } | ||||
|  | ||||
| int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) { | ||||
|     ggml_tensor * op = ctx->node(idx); | ||||
|  | ||||
|     ggml_metal_library_t lib = ctx->lib; | ||||
|     ggml_metal_encoder_t enc = ctx->enc; | ||||
|  | ||||
|     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); | ||||
|     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); | ||||
|     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne); | ||||
|     GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb); | ||||
|  | ||||
|     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); | ||||
|  | ||||
|     const int64_t np = ggml_nelements(op->src[0]); | ||||
|     ggml_metal_kargs_opt_step_sgd args = { | ||||
|         /*.np =*/ np, | ||||
|     }; | ||||
|  | ||||
|     int ida = 0; | ||||
|  | ||||
|     ggml_metal_encoder_set_pipeline(enc, pipeline); | ||||
|     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), ida++); | ||||
|     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); | ||||
|     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); | ||||
|     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); | ||||
|  | ||||
|     const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); | ||||
|     const int64_t n = (np + nth - 1) / nth; | ||||
|  | ||||
|     ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1); | ||||
|  | ||||
|     return 1; | ||||
| } | ||||
|   | ||||
| @@ -80,6 +80,7 @@ int ggml_metal_op_argmax            (ggml_metal_op_t ctx, int idx); | ||||
| int ggml_metal_op_argsort           (ggml_metal_op_t ctx, int idx); | ||||
| int ggml_metal_op_leaky_relu        (ggml_metal_op_t ctx, int idx); | ||||
| int ggml_metal_op_opt_step_adamw    (ggml_metal_op_t ctx, int idx); | ||||
| int ggml_metal_op_opt_step_sgd      (ggml_metal_op_t ctx, int idx); | ||||
|  | ||||
| #ifdef __cplusplus | ||||
| } | ||||
|   | ||||
| @@ -8806,3 +8806,17 @@ kernel void kernel_opt_step_adamw_f32( | ||||
|  | ||||
|     x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh; | ||||
| } | ||||
|  | ||||
| kernel void kernel_opt_step_sgd_f32( | ||||
|         constant    ggml_metal_kargs_opt_step_sgd & args, | ||||
|         device       float * x, | ||||
|         device const float * g, | ||||
|         device const float * pars, | ||||
|         uint        gid[[thread_position_in_grid]]) { | ||||
|  | ||||
|     if (gid >= args.np) { | ||||
|         return; | ||||
|     } | ||||
|  | ||||
|     x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid]; | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Sam/Samuel
					Sam/Samuel