mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-29 08:41:22 +00:00 
			
		
		
		
	vulkan: implement more backpropagation operators (#11914)
* vulkan: implement GGML_OP_ROPE_BACK
* vulkan: implement GGML_OP_RMS_NORM_BACK
* vulkan: implement GGML_OP_SILU_BACK
* vulkan: implement GGML_OP_SOFTMAX_BACK
This commit is contained in:
		| @@ -241,15 +241,18 @@ struct vk_device_struct { | ||||
|     vk_pipeline pipeline_norm_f32; | ||||
|     vk_pipeline pipeline_group_norm_f32; | ||||
|     vk_pipeline pipeline_rms_norm_f32; | ||||
|     vk_pipeline pipeline_rms_norm_back_f32; | ||||
|     vk_pipeline pipeline_gelu_f32; | ||||
|     vk_pipeline pipeline_gelu_quick_f32; | ||||
|     vk_pipeline pipeline_silu_f32; | ||||
|     vk_pipeline pipeline_silu_back_f32; | ||||
|     vk_pipeline pipeline_relu_f32; | ||||
|     vk_pipeline pipeline_leaky_relu_f32; | ||||
|     vk_pipeline pipeline_tanh_f32; | ||||
|     vk_pipeline pipeline_diag_mask_inf_f32; | ||||
|     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; | ||||
|     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; | ||||
|     vk_pipeline pipeline_soft_max_back_f32; | ||||
|     vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; | ||||
|     vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; | ||||
|     vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; | ||||
| @@ -504,6 +507,7 @@ struct vk_op_rope_push_constants { | ||||
|     uint32_t s1; | ||||
|     uint32_t s2; | ||||
|     int32_t sections[4]; | ||||
|     uint32_t is_back; | ||||
| }; | ||||
|  | ||||
| struct vk_op_soft_max_push_constants { | ||||
| @@ -2121,6 +2125,7 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); | ||||
|  | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); | ||||
| @@ -2180,6 +2185,7 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); | ||||
| @@ -2190,6 +2196,7 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); | ||||
|  | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); | ||||
| @@ -5283,6 +5290,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const | ||||
|     case GGML_OP_CONT: | ||||
|     case GGML_OP_DUP: | ||||
|         return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type); | ||||
|     case GGML_OP_SILU_BACK: | ||||
|         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { | ||||
|             return ctx->device->pipeline_silu_back_f32; | ||||
|         } | ||||
|         return nullptr; | ||||
|     case GGML_OP_NORM: | ||||
|         if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { | ||||
|             return ctx->device->pipeline_norm_f32; | ||||
| @@ -5298,6 +5310,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const | ||||
|             return ctx->device->pipeline_rms_norm_f32; | ||||
|         } | ||||
|         return nullptr; | ||||
|     case GGML_OP_RMS_NORM_BACK: | ||||
|         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { | ||||
|             return ctx->device->pipeline_rms_norm_back_f32; | ||||
|         } | ||||
|         return nullptr; | ||||
|     case GGML_OP_UNARY: | ||||
|         switch (ggml_get_unary_op(dst)) { | ||||
|             case GGML_UNARY_OP_SILU: | ||||
| @@ -5344,7 +5361,13 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const | ||||
|             return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16; | ||||
|         } | ||||
|         return nullptr; | ||||
|     case GGML_OP_SOFT_MAX_BACK: | ||||
|         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { | ||||
|             return ctx->device->pipeline_soft_max_back_f32; | ||||
|         } | ||||
|         return nullptr; | ||||
|     case GGML_OP_ROPE: | ||||
|     case GGML_OP_ROPE_BACK: | ||||
|         { | ||||
|             const int mode = ((const int32_t *) dst->op_params)[2]; | ||||
|             const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; | ||||
| @@ -5672,7 +5695,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co | ||||
|     switch (op) { | ||||
|     case GGML_OP_NORM: | ||||
|     case GGML_OP_RMS_NORM: | ||||
|     case GGML_OP_RMS_NORM_BACK: | ||||
|     case GGML_OP_SOFT_MAX: | ||||
|     case GGML_OP_SOFT_MAX_BACK: | ||||
|     case GGML_OP_SUM_ROWS: | ||||
|     case GGML_OP_ARGMAX: | ||||
|         { | ||||
| @@ -5696,6 +5721,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co | ||||
|         } break; | ||||
|     case GGML_OP_DIAG_MASK_INF: | ||||
|     case GGML_OP_ROPE: | ||||
|     case GGML_OP_ROPE_BACK: | ||||
|         elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; | ||||
|         break; | ||||
|     case GGML_OP_GET_ROWS: | ||||
| @@ -5791,7 +5817,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co | ||||
|  | ||||
|         ggml_vk_sync_buffers(subctx); | ||||
|         ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements); | ||||
|     } else if (op == GGML_OP_ROPE) { | ||||
|     } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { | ||||
|         // Empty src2 is possible in rope, but the shader needs a buffer | ||||
|         vk_subbuffer subbuf_z; | ||||
|         if (use_src2) { | ||||
| @@ -6313,6 +6339,10 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const | ||||
|     }, dryrun); | ||||
| } | ||||
|  | ||||
| static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { | ||||
|     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); | ||||
| } | ||||
|  | ||||
| static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { | ||||
|     float * op_params = (float *)dst->op_params; | ||||
|  | ||||
| @@ -6335,6 +6365,11 @@ static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, | ||||
|     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); | ||||
| } | ||||
|  | ||||
| static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { | ||||
|     float * op_params = (float *)dst->op_params; | ||||
|     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); | ||||
| } | ||||
|  | ||||
| static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { | ||||
|     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); | ||||
| } | ||||
| @@ -6370,7 +6405,12 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, | ||||
|     }, dryrun); | ||||
| } | ||||
|  | ||||
| static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { | ||||
| static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { | ||||
|     float * op_params = (float *)dst->op_params; | ||||
|     ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun); | ||||
| } | ||||
|  | ||||
| static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) { | ||||
|     const int n_dims        = ((int32_t *) dst->op_params)[1]; | ||||
|     const int mode          = ((int32_t *) dst->op_params)[2]; | ||||
|     // const int n_ctx         = ((int32_t *) dst->op_params)[3]; | ||||
| @@ -6398,7 +6438,7 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, cons | ||||
|         (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], | ||||
|         freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, | ||||
|         src2 != nullptr, (uint32_t)src0->ne[2], s1, s2, | ||||
|         sections[0], sections[1], sections[2], sections[3], | ||||
|         sections[0], sections[1], sections[2], sections[3], backprop | ||||
|     }, dryrun); | ||||
| } | ||||
|  | ||||
| @@ -7319,12 +7359,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod | ||||
|     case GGML_OP_CPY: | ||||
|     case GGML_OP_CONT: | ||||
|     case GGML_OP_DUP: | ||||
|     case GGML_OP_SILU_BACK: | ||||
|     case GGML_OP_NORM: | ||||
|     case GGML_OP_GROUP_NORM: | ||||
|     case GGML_OP_RMS_NORM: | ||||
|     case GGML_OP_RMS_NORM_BACK: | ||||
|     case GGML_OP_DIAG_MASK_INF: | ||||
|     case GGML_OP_SOFT_MAX: | ||||
|     case GGML_OP_SOFT_MAX_BACK: | ||||
|     case GGML_OP_ROPE: | ||||
|     case GGML_OP_ROPE_BACK: | ||||
|     case GGML_OP_MUL_MAT: | ||||
|     case GGML_OP_MUL_MAT_ID: | ||||
|     case GGML_OP_ARGSORT: | ||||
| @@ -7377,13 +7421,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod | ||||
|         case GGML_OP_CPY: | ||||
|         case GGML_OP_CONT: | ||||
|         case GGML_OP_DUP: | ||||
|         case GGML_OP_SILU_BACK: | ||||
|         case GGML_OP_NORM: | ||||
|         case GGML_OP_GROUP_NORM: | ||||
|         case GGML_OP_RMS_NORM: | ||||
|         case GGML_OP_RMS_NORM_BACK: | ||||
|         case GGML_OP_UNARY: | ||||
|         case GGML_OP_DIAG_MASK_INF: | ||||
|         case GGML_OP_SOFT_MAX: | ||||
|         case GGML_OP_SOFT_MAX_BACK: | ||||
|         case GGML_OP_ROPE: | ||||
|         case GGML_OP_ROPE_BACK: | ||||
|         case GGML_OP_ARGSORT: | ||||
|         case GGML_OP_SUM: | ||||
|         case GGML_OP_SUM_ROWS: | ||||
| @@ -7475,6 +7523,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod | ||||
|     case GGML_OP_DUP: | ||||
|         ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun); | ||||
|  | ||||
|         break; | ||||
|     case GGML_OP_SILU_BACK: | ||||
|         ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun); | ||||
|  | ||||
|         break; | ||||
|     case GGML_OP_NORM: | ||||
|         ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun); | ||||
| @@ -7487,6 +7539,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod | ||||
|     case GGML_OP_RMS_NORM: | ||||
|         ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun); | ||||
|  | ||||
|         break; | ||||
|     case GGML_OP_RMS_NORM_BACK: | ||||
|         ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun); | ||||
|  | ||||
|         break; | ||||
|     case GGML_OP_UNARY: | ||||
|         switch (ggml_get_unary_op(node)) { | ||||
| @@ -7508,9 +7564,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod | ||||
|     case GGML_OP_SOFT_MAX: | ||||
|         ggml_vk_soft_max(ctx, compute_ctx, src0, src1, node, dryrun); | ||||
|  | ||||
|         break; | ||||
|     case GGML_OP_SOFT_MAX_BACK: | ||||
|         ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun); | ||||
|  | ||||
|         break; | ||||
|     case GGML_OP_ROPE: | ||||
|         ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, dryrun); | ||||
|         ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun); | ||||
|  | ||||
|         break; | ||||
|     case GGML_OP_ROPE_BACK: | ||||
|         ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun); | ||||
|  | ||||
|         break; | ||||
|     case GGML_OP_ARGSORT: | ||||
| @@ -7636,12 +7700,16 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor * | ||||
|     case GGML_OP_CPY: | ||||
|     case GGML_OP_CONT: | ||||
|     case GGML_OP_DUP: | ||||
|     case GGML_OP_SILU_BACK: | ||||
|     case GGML_OP_NORM: | ||||
|     case GGML_OP_GROUP_NORM: | ||||
|     case GGML_OP_RMS_NORM: | ||||
|     case GGML_OP_RMS_NORM_BACK: | ||||
|     case GGML_OP_DIAG_MASK_INF: | ||||
|     case GGML_OP_SOFT_MAX: | ||||
|     case GGML_OP_SOFT_MAX_BACK: | ||||
|     case GGML_OP_ROPE: | ||||
|     case GGML_OP_ROPE_BACK: | ||||
|     case GGML_OP_RESHAPE: | ||||
|     case GGML_OP_VIEW: | ||||
|     case GGML_OP_PERMUTE: | ||||
| @@ -8560,6 +8628,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm | ||||
|         case GGML_OP_REPEAT_BACK: | ||||
|             return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; | ||||
|         case GGML_OP_ROPE: | ||||
|         case GGML_OP_ROPE_BACK: | ||||
|         case GGML_OP_NONE: | ||||
|         case GGML_OP_RESHAPE: | ||||
|         case GGML_OP_VIEW: | ||||
| @@ -8576,6 +8645,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm | ||||
|         case GGML_OP_MUL: | ||||
|         case GGML_OP_DIV: | ||||
|         case GGML_OP_CONCAT: | ||||
|         case GGML_OP_SILU_BACK: | ||||
|         case GGML_OP_RMS_NORM_BACK: | ||||
|         case GGML_OP_UPSCALE: | ||||
|         case GGML_OP_SCALE: | ||||
|         case GGML_OP_SQR: | ||||
| @@ -8585,6 +8656,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm | ||||
|         case GGML_OP_PAD: | ||||
|         case GGML_OP_DIAG_MASK_INF: | ||||
|         case GGML_OP_SOFT_MAX: | ||||
|         case GGML_OP_SOFT_MAX_BACK: | ||||
|         case GGML_OP_ARGSORT: | ||||
|         case GGML_OP_SUM: | ||||
|         case GGML_OP_SUM_ROWS: | ||||
| @@ -8976,15 +9048,22 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { | ||||
|         tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], *(int *)tensor->op_params, ((float *)tensor->op_params)[1]); | ||||
|     } else if (tensor->op == GGML_OP_RMS_NORM) { | ||||
|         tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); | ||||
|     } else if (tensor->op == GGML_OP_RMS_NORM_BACK) { | ||||
|         const float eps = ((float *) tensor->op_params)[0]; | ||||
|         tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps); | ||||
|     } else if (tensor->op == GGML_OP_SILU_BACK) { | ||||
|         tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]); | ||||
|     } else if (tensor->op == GGML_OP_SOFT_MAX) { | ||||
|         if (src1 != nullptr) { | ||||
|             tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); | ||||
|         } else { | ||||
|             tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); | ||||
|         } | ||||
|     } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) { | ||||
|         tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); | ||||
|     } else if (tensor->op == GGML_OP_DIAG_MASK_INF) { | ||||
|         tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], *(int *)tensor->op_params); | ||||
|     } else if (tensor->op == GGML_OP_ROPE) { | ||||
|     } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) { | ||||
|         const int n_dims      = ((int32_t *) tensor->op_params)[1]; | ||||
|         const int mode        = ((int32_t *) tensor->op_params)[2]; | ||||
|         //const int n_ctx_ggml       = ((int32_t *) tensor->op_params)[3]; | ||||
| @@ -8997,9 +9076,17 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) { | ||||
|         const float beta_slow       = ((float *) tensor->op_params)[10]; | ||||
|         if (mode & GGML_ROPE_TYPE_MROPE) { | ||||
|             int32_t *sections = ((int32_t *) tensor->op_params) + 11; | ||||
|             tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); | ||||
|             if (tensor->op == GGML_OP_ROPE) { | ||||
|                 tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); | ||||
|             } else { | ||||
|                 tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); | ||||
|             } | ||||
|         } else { | ||||
|             tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); | ||||
|             if (tensor->op == GGML_OP_ROPE) { | ||||
|                 tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); | ||||
|             } else { | ||||
|                 tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); | ||||
|             } | ||||
|         } | ||||
|     } else if (tensor->op == GGML_OP_UNARY) { | ||||
|         switch (ggml_get_unary_op(tensor)) { | ||||
|   | ||||
							
								
								
									
										55
									
								
								ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | ||||
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer G {A_TYPE data_a[];}; // incoming gradient g = dL/dy
layout (binding = 1) readonly buffer X {B_TYPE data_b[];}; // forward input x
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // result dL/dx

// Shared-memory scratch for the two per-row tree reductions.
shared FLOAT_TYPE sum_xx[BLOCK_SIZE]; // partial sums of x*x
shared FLOAT_TYPE sum_xg[BLOCK_SIZE]; // partial sums of x*g

void main() {
    // One workgroup per row; the 3D workgroup id is flattened into a row index.
    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
    const uint tid = gl_LocalInvocationID.x;

    // Compute derivative of x[i]/norm(x) = g[i]/norm(x) - x[i] dot(x,g)/KX / norm(x)^1.5

    // partial sums for thread in warp
    sum_xx[tid] = FLOAT_TYPE(0.0f);
    sum_xg[tid] = FLOAT_TYPE(0.0f);

    // Each thread strides across the row accumulating its partial sums.
    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
        const FLOAT_TYPE gi = FLOAT_TYPE(data_a[row*p.KX + col]);
        const FLOAT_TYPE xi = FLOAT_TYPE(data_b[row*p.KX + col]);
        sum_xx[tid] += xi * xi;
        sum_xg[tid] += xi * gi;
    }

    // sum up partial sums and write back result
    barrier();
    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sum_xx[tid] += sum_xx[tid + s];
            sum_xg[tid] += sum_xg[tid + s];
        }
        barrier();
    }

    // scale_g = 1/rms(x) with mean = sum(x*x)/KX.
    // scale_x folds -dot(x,g) / (KX * rms^3) into a single factor:
    //   dot(x,g) / (KX * (mean+eps)) * inversesqrt(mean+eps)
    // = dot(x,g) / (sum_xx + KX*eps) * scale_g.
    const FLOAT_TYPE eps = FLOAT_TYPE(p.param1);
    const FLOAT_TYPE mean = sum_xx[0] / FLOAT_TYPE(p.KX);
    const FLOAT_TYPE scale_g = inversesqrt(mean + eps);
    const FLOAT_TYPE scale_x = -scale_g * sum_xg[0] / (sum_xx[0] + FLOAT_TYPE(p.KX) * eps);

    // dL/dx[i] = g[i]/rms - x[i] * dot(x,g) / (KX * rms^3)
    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
        data_d[row*p.KX + col] = D_TYPE(
            scale_g * FLOAT_TYPE(data_a[row*p.KX + col]) +
            scale_x * FLOAT_TYPE(data_b[row*p.KX + col]));
    }
}
| @@ -29,6 +29,7 @@ layout (push_constant) uniform parameter { | ||||
|     uint s1; | ||||
|     uint s2; | ||||
|     int sections[4]; | ||||
|     uint is_back; | ||||
| } p; | ||||
|  | ||||
| float rope_yarn_ramp(const float low, const float high, const uint i0) { | ||||
| @@ -48,6 +49,10 @@ void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out | ||||
|         // Get n-d magnitude scaling corrected for interpolation | ||||
|         mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); | ||||
|     } | ||||
|     // Backpropagation uses inverted rotation | ||||
|     if (p.is_back != 0) { | ||||
|         theta = -theta; | ||||
|     } | ||||
|     cos_theta = cos(theta) * mscale; | ||||
|     sin_theta = sin(theta) * mscale; | ||||
| } | ||||
|   | ||||
							
								
								
									
										26
									
								
								ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,26 @@ | ||||
#version 450

#include "generic_head.comp"
#include "types.comp"

#extension GL_EXT_control_flow_attributes : enable

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer G {A_TYPE data_g[];}; // incoming gradient dL/dy
layout (binding = 1) readonly buffer X {B_TYPE data_x[];}; // forward input x
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // result dL/dx

void main() {
    // Flatten the 3D global invocation id into a linear element index.
    const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;

    if (idx >= p.KX) {
        return;
    }

    // SiLU(x) = x * sigma(x) with sigma(x) = 1/(1+exp(-x)), so
    // dSiLU/dx = sigma(x) + x * sigma(x) * (1 - sigma(x)); chain rule multiplies by g.
    const float x   = float(data_x[idx]);
    const float sig = 1.0f / (1.0f + exp(-x));

    data_d[idx] = D_TYPE(data_g[idx] * (sig + x * sig * (1 - sig)));
}
							
								
								
									
										50
									
								
								ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,50 @@ | ||||
#version 450

#extension GL_EXT_control_flow_attributes : enable

#include "generic_head.comp"
#include "types.comp"

layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// In this shader Y = softmax(X) and X is not provided as input.

layout (binding = 0) readonly buffer G {A_TYPE data_g[];}; // incoming gradient dL/dY
layout (binding = 1) readonly buffer Y {B_TYPE data_y[];}; // forward output Y = softmax(X)
layout (binding = 2) buffer D {D_TYPE data_d[];};          // result dL/dX

// Shared-memory scratch for the per-row dot(Y, G) reduction.
shared FLOAT_TYPE sum_yg[BLOCK_SIZE];

void main() {
    // One workgroup per row; the 3D workgroup id is flattened into a row index.
    const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
    const uint tid = gl_LocalInvocationID.x;

    // p.param1 is the scale forwarded from the forward softmax op.
    FLOAT_TYPE scale = p.param1;

    // partial sums for thread in warp
    sum_yg[tid] = FLOAT_TYPE(0.0f);

    // Each thread strides across the row accumulating its share of dot(Y, G).
    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
        const FLOAT_TYPE gi = FLOAT_TYPE(data_g[row*p.KX + col]);
        const FLOAT_TYPE yi = FLOAT_TYPE(data_y[row*p.KX + col]);
        sum_yg[tid] += yi * gi;
    }

    // sum up partial sums and write back result
    barrier();
    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            sum_yg[tid] += sum_yg[tid + s];
        }
        barrier();
    }

    const FLOAT_TYPE dot_yg = sum_yg[0];

    // Softmax Jacobian applied to G: dL/dX = scale * Y * (G - dot(Y, G)).
    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
        data_d[row*p.KX + col] = D_TYPE(scale
            * (FLOAT_TYPE(data_g[row*p.KX + col]) - dot_yg)
            * FLOAT_TYPE(data_y[row*p.KX + col]));
    }
}
| @@ -427,6 +427,7 @@ void process_shaders() { | ||||
|     string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); | ||||
|     string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); | ||||
|     string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); | ||||
|     string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); | ||||
|  | ||||
|     string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); | ||||
|     string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); | ||||
| @@ -477,6 +478,7 @@ void process_shaders() { | ||||
|     string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); | ||||
|     string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); | ||||
|     string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); | ||||
|     string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); | ||||
|     string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); | ||||
|     string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); | ||||
|     string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); | ||||
| @@ -485,6 +487,7 @@ void process_shaders() { | ||||
|  | ||||
|     string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); | ||||
|     string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); | ||||
|     string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); | ||||
|  | ||||
|     string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); | ||||
|     string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Rémy O
					Rémy O