ggml : add ggml_set_rows

Add ggml_set_rows(a, b, c) which copies rows from 'b' into 'a' using
indices from 'c'.

ref: #8366
This commit is contained in:
Radoslav Gerganov
2025-06-19 11:04:23 +03:00
committed by Georgi Gerganov
parent 7b50d589a8
commit c1a581a10b
5 changed files with 98 additions and 2 deletions

View File

@@ -936,6 +936,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"TRANSPOSE",
"GET_ROWS",
"GET_ROWS_BACK",
"SET_ROWS",
"DIAG",
"DIAG_MASK_INF",
"DIAG_MASK_ZERO",
@@ -986,7 +987,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"OPT_STEP_ADAMW",
};
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
@@ -1032,6 +1033,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"transpose(x)",
"get_rows(x)",
"get_rows_back(x)",
"set_rows(x)",
"diag(x)",
"diag_mask_inf(x)",
"diag_mask_zero(x)",
@@ -1082,7 +1084,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"adamw(x)",
};
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -3395,6 +3397,28 @@ struct ggml_tensor * ggml_get_rows_back(
return result;
}
// ggml_set_rows
struct ggml_tensor * ggml_set_rows(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c) {
GGML_ASSERT(b->ne[2] == c->ne[1]);
GGML_ASSERT(c->ne[3] == 1);
GGML_ASSERT(a->type == GGML_TYPE_F16);
GGML_ASSERT(b->type == GGML_TYPE_F32);
GGML_ASSERT(c->type == GGML_TYPE_I32);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
result->op = GGML_OP_SET_ROWS;
result->src[0] = b;
result->src[1] = c;
return result;
}
// ggml_diag
struct ggml_tensor * ggml_diag(