ggml : ggml_set_rows support broadcast

This commit is contained in:
Georgi Gerganov
2025-06-22 10:28:07 +03:00
parent 313a444b22
commit df71c803b4
3 changed files with 39 additions and 15 deletions

View File

@@ -3410,12 +3410,20 @@ struct ggml_tensor * ggml_set_rows(
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * c) {
GGML_ASSERT(b->ne[2] == c->ne[1]);
GGML_ASSERT(a->ne[0] == b->ne[0]);
GGML_ASSERT(a->ne[2] == b->ne[2]);
GGML_ASSERT(a->ne[3] == b->ne[3]);
GGML_ASSERT(b->ne[1] == c->ne[0]);
GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
GGML_ASSERT(c->ne[3] == 1);
GGML_ASSERT(a->type == GGML_TYPE_F16);
GGML_ASSERT(a->type == GGML_TYPE_F16); // TODO: relax
GGML_ASSERT(b->type == GGML_TYPE_F32);
GGML_ASSERT(c->type == GGML_TYPE_I64);
GGML_ASSERT(ggml_is_contiguous_rows(a));
GGML_ASSERT(ggml_is_contiguous_rows(b));
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
result->op = GGML_OP_SET_ROWS;