mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-05 09:36:52 +00:00
ggml : add ggml_set_rows
Add ggml_set_rows(a, b, c) which copies rows from 'b' into 'a' using indices from 'c'. ref: #8366
This commit is contained in:
committed by
Georgi Gerganov
parent
7b50d589a8
commit
c1a581a10b
@@ -936,6 +936,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"TRANSPOSE",
|
||||
"GET_ROWS",
|
||||
"GET_ROWS_BACK",
|
||||
"SET_ROWS",
|
||||
"DIAG",
|
||||
"DIAG_MASK_INF",
|
||||
"DIAG_MASK_ZERO",
|
||||
@@ -986,7 +987,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||
"OPT_STEP_ADAMW",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
@@ -1032,6 +1033,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"transpose(x)",
|
||||
"get_rows(x)",
|
||||
"get_rows_back(x)",
|
||||
"set_rows(x)",
|
||||
"diag(x)",
|
||||
"diag_mask_inf(x)",
|
||||
"diag_mask_zero(x)",
|
||||
@@ -1082,7 +1084,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"adamw(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
|
||||
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
@@ -3395,6 +3397,28 @@ struct ggml_tensor * ggml_get_rows_back(
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_set_rows
|
||||
|
||||
struct ggml_tensor * ggml_set_rows(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c) {
|
||||
GGML_ASSERT(b->ne[2] == c->ne[1]);
|
||||
GGML_ASSERT(c->ne[3] == 1);
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(b->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(c->type == GGML_TYPE_I32);
|
||||
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
||||
|
||||
result->op = GGML_OP_SET_ROWS;
|
||||
result->src[0] = b;
|
||||
result->src[1] = c;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_diag
|
||||
|
||||
struct ggml_tensor * ggml_diag(
|
||||
|
||||
Reference in New Issue
Block a user