diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 94ab1ec0f5..be505748af 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -50,6 +50,7 @@ #include "ggml-cuda/upscale.cuh" #include "ggml-cuda/wkv.cuh" #include "ggml-cuda/gla.cuh" +#include "ggml-cuda/set.cuh" #include "ggml-cuda/set-rows.cuh" #include "ggml-cuda/pad_reflect_1d.cuh" #include "ggml.h" @@ -2416,6 +2417,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SET_ROWS: ggml_cuda_op_set_rows(ctx, dst); break; + case GGML_OP_SET: + ggml_cuda_op_set(ctx, dst); + break; case GGML_OP_DUP: ggml_cuda_dup(ctx, dst); break; @@ -3842,6 +3846,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g op->src[0]->type == GGML_TYPE_F32 && (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32); } break; + case GGML_OP_SET: + { + const ggml_type t = op->type; + return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) && + t == op->src[0]->type && + t == op->src[1]->type; + } break; case GGML_OP_CPY: { ggml_type src0_type = op->src[0]->type; diff --git a/ggml/src/ggml-cuda/set.cu b/ggml/src/ggml-cuda/set.cu new file mode 100644 index 0000000000..04bfe07ba0 --- /dev/null +++ b/ggml/src/ggml-cuda/set.cu @@ -0,0 +1,39 @@ +#include "set.cuh" +#include "cpy.cuh" + +void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)); + GGML_ASSERT(src1->type == src0->type); + GGML_ASSERT(dst ->type == src0->type); + + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + + const size_t nb1 = ((int32_t *) dst->op_params)[0]; + const size_t nb2 = ((int32_t *) dst->op_params)[1]; + const size_t nb3 = ((int32_t *) dst->op_params)[2]; + const size_t offset = ((int32_t *) dst->op_params)[3]; + const bool inplace= (bool) ((int32_t *) dst->op_params)[4]; + + if (!inplace) { + ggml_cuda_cpy(ctx, src0, dst); + } + + ggml_tensor dst_view = *dst; + dst_view.data = (void *)((char *)dst->data + offset); + dst_view.ne[0] = src1->ne[0]; + dst_view.ne[1] = src1->ne[1]; + dst_view.ne[2] = src1->ne[2]; + dst_view.ne[3] = src1->ne[3]; + + dst_view.nb[0] = ggml_element_size(dst); + dst_view.nb[1] = nb1; + dst_view.nb[2] = nb2; + dst_view.nb[3] = nb3; + + ggml_cuda_cpy(ctx, src1, &dst_view); +} diff --git a/ggml/src/ggml-cuda/set.cuh b/ggml/src/ggml-cuda/set.cuh new file mode 100644 index 0000000000..dd09529f3e --- /dev/null +++ b/ggml/src/ggml-cuda/set.cuh @@ -0,0 +1,7 @@ +#pragma once + +#include "common.cuh" + +#define CUDA_SET_BLOCK_SIZE 256 + +void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);