mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-07 09:57:00 +00:00
ggml : ggml_set_rows support broadcast
This commit is contained in:
@@ -3410,12 +3410,20 @@ struct ggml_tensor * ggml_set_rows(
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_tensor * c) {
|
||||
GGML_ASSERT(b->ne[2] == c->ne[1]);
|
||||
GGML_ASSERT(a->ne[0] == b->ne[0]);
|
||||
GGML_ASSERT(a->ne[2] == b->ne[2]);
|
||||
GGML_ASSERT(a->ne[3] == b->ne[3]);
|
||||
GGML_ASSERT(b->ne[1] == c->ne[0]);
|
||||
GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
|
||||
GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
|
||||
GGML_ASSERT(c->ne[3] == 1);
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F16);
|
||||
GGML_ASSERT(a->type == GGML_TYPE_F16); // TODO: relax
|
||||
GGML_ASSERT(b->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(c->type == GGML_TYPE_I64);
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(a));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(b));
|
||||
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
||||
|
||||
result->op = GGML_OP_SET_ROWS;
|
||||
|
||||
Reference in New Issue
Block a user