mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-05 09:36:52 +00:00
cuda: add SET operation support (#16804)
* feat(cuda): add GGML_OP_SET support Implement CUDA kernel for SET operation with f32 support. All tests passing (14598/14598). * cuda(set): add I32 support; keep F32 * refactor(cuda): use ggml_cuda_cpy to unify SET operator logic and remove code duplication * Update ggml/src/ggml-cuda/ggml-cuda.cu Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * Update ggml/src/ggml-cuda/set.cu Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
@@ -50,6 +50,7 @@
|
|||||||
#include "ggml-cuda/upscale.cuh"
|
#include "ggml-cuda/upscale.cuh"
|
||||||
#include "ggml-cuda/wkv.cuh"
|
#include "ggml-cuda/wkv.cuh"
|
||||||
#include "ggml-cuda/gla.cuh"
|
#include "ggml-cuda/gla.cuh"
|
||||||
|
#include "ggml-cuda/set.cuh"
|
||||||
#include "ggml-cuda/set-rows.cuh"
|
#include "ggml-cuda/set-rows.cuh"
|
||||||
#include "ggml-cuda/pad_reflect_1d.cuh"
|
#include "ggml-cuda/pad_reflect_1d.cuh"
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
@@ -2416,6 +2417,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
|||||||
case GGML_OP_SET_ROWS:
|
case GGML_OP_SET_ROWS:
|
||||||
ggml_cuda_op_set_rows(ctx, dst);
|
ggml_cuda_op_set_rows(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_SET:
|
||||||
|
ggml_cuda_op_set(ctx, dst);
|
||||||
|
break;
|
||||||
case GGML_OP_DUP:
|
case GGML_OP_DUP:
|
||||||
ggml_cuda_dup(ctx, dst);
|
ggml_cuda_dup(ctx, dst);
|
||||||
break;
|
break;
|
||||||
@@ -3842,6 +3846,13 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||||||
op->src[0]->type == GGML_TYPE_F32 &&
|
op->src[0]->type == GGML_TYPE_F32 &&
|
||||||
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
|
(op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_SET:
|
||||||
|
{
|
||||||
|
const ggml_type t = op->type;
|
||||||
|
return (t == GGML_TYPE_F32 || t == GGML_TYPE_I32) &&
|
||||||
|
t == op->src[0]->type &&
|
||||||
|
t == op->src[1]->type;
|
||||||
|
} break;
|
||||||
case GGML_OP_CPY:
|
case GGML_OP_CPY:
|
||||||
{
|
{
|
||||||
ggml_type src0_type = op->src[0]->type;
|
ggml_type src0_type = op->src[0]->type;
|
||||||
|
|||||||
39
ggml/src/ggml-cuda/set.cu
Normal file
39
ggml/src/ggml-cuda/set.cu
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
#include "set.cuh"
|
||||||
|
#include "cpy.cuh"
|
||||||
|
|
||||||
|
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
|
|
||||||
|
GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
|
||||||
|
GGML_ASSERT(src1->type == src0->type);
|
||||||
|
GGML_ASSERT(dst ->type == src0->type);
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src1));
|
||||||
|
|
||||||
|
const size_t nb1 = ((int32_t *) dst->op_params)[0];
|
||||||
|
const size_t nb2 = ((int32_t *) dst->op_params)[1];
|
||||||
|
const size_t nb3 = ((int32_t *) dst->op_params)[2];
|
||||||
|
const size_t offset = ((int32_t *) dst->op_params)[3];
|
||||||
|
const bool inplace= (bool) ((int32_t *) dst->op_params)[4];
|
||||||
|
|
||||||
|
if (!inplace) {
|
||||||
|
ggml_cuda_cpy(ctx, src0, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor dst_view = *dst;
|
||||||
|
dst_view.data = (void *)((char *)dst->data + offset);
|
||||||
|
dst_view.ne[0] = src1->ne[0];
|
||||||
|
dst_view.ne[1] = src1->ne[1];
|
||||||
|
dst_view.ne[2] = src1->ne[2];
|
||||||
|
dst_view.ne[3] = src1->ne[3];
|
||||||
|
|
||||||
|
dst_view.nb[0] = ggml_element_size(dst);
|
||||||
|
dst_view.nb[1] = nb1;
|
||||||
|
dst_view.nb[2] = nb2;
|
||||||
|
dst_view.nb[3] = nb3;
|
||||||
|
|
||||||
|
ggml_cuda_cpy(ctx, src1, &dst_view);
|
||||||
|
}
|
||||||
7
ggml/src/ggml-cuda/set.cuh
Normal file
7
ggml/src/ggml-cuda/set.cuh
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "common.cuh"
|
||||||
|
|
||||||
|
#define CUDA_SET_BLOCK_SIZE 256
|
||||||
|
|
||||||
|
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
Reference in New Issue
Block a user