diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 2c6de09ab9..b9a8d37715 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1379,6 +1379,15 @@ extern "C" { struct ggml_tensor * b, // row indices struct ggml_tensor * c); // data for ggml_get_rows, only used for its shape + // a TD [n_embd, ne1, ne2, ne3] + // b TS [n_embd, n_rows, ne02, ne03] | ne02 == ne2, ne03 == ne3 + // c I64 [n_rows, ne11, ne12, 1] | c[i] in [0, ne1) + // + // broadcast: + // ne2 % ne11 == 0 + // ne3 % ne12 == 0 + // + // return view(a) GGML_API struct ggml_tensor * ggml_set_rows( struct ggml_context * ctx, struct ggml_tensor * a, // destination diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 3dd57ccefc..1c8bdfbcef 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -4530,12 +4530,14 @@ static void ggml_compute_forward_set_rows_f32( GGML_TENSOR_BINARY_OP_LOCALS const int64_t nc = ne00; - const int64_t nr = ggml_nelements(src1); + const int64_t nr = ne01; assert(ne0 == nc); - assert(ne02 == ne11); - assert(nb00 == sizeof(float)); - assert(ggml_nrows(src0) == nr); + assert(ne2 == ne02); + assert(ne3 == ne03); + assert(src0->type == GGML_TYPE_F32); + assert(ne02 % ne11 == 0); + assert(ne03 % ne12 == 0); const int ith = params->ith; const int nth = params->nth; @@ -4547,17 +4549,22 @@ static void ggml_compute_forward_set_rows_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int64_t i = ir0; i < ir1; ++i) { - const int64_t i12 = i/(ne11*ne10); - const int64_t i11 = (i - i12*ne11*ne10)/ne10; - const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); - const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); + for (int64_t i03 = 0; i03 < ne03; ++i03) { + for (int64_t i02 = 0; i02 < ne02; ++i02) { + for (int64_t i = ir0; i < ir1; ++i) { + const int64_t i12 = i03%ne12; + const int64_t i11 = i02%ne11; + const int64_t i10 = i; - GGML_ASSERT(i01 >= 0 && i01 < ne1); + const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); - ggml_cpu_fp32_to_fp16( - (const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03), - (ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i11*nb2 + i12*nb3), nc); + GGML_ASSERT(i01 >= 0 && i01 < ne1); + + ggml_cpu_fp32_to_fp16( + (const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03), + (ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), nc); + } + } } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 4779565110..12d9ad70ac 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3410,12 +3410,20 @@ struct ggml_tensor * ggml_set_rows( struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_tensor * c) { - GGML_ASSERT(b->ne[2] == c->ne[1]); + GGML_ASSERT(a->ne[0] == b->ne[0]); + GGML_ASSERT(a->ne[2] == b->ne[2]); + GGML_ASSERT(a->ne[3] == b->ne[3]); + GGML_ASSERT(b->ne[1] == c->ne[0]); + GGML_ASSERT(b->ne[2] % c->ne[1] == 0); + GGML_ASSERT(b->ne[3] % c->ne[2] == 0); GGML_ASSERT(c->ne[3] == 1); - GGML_ASSERT(a->type == GGML_TYPE_F16); + GGML_ASSERT(a->type == GGML_TYPE_F16); // TODO: relax GGML_ASSERT(b->type == GGML_TYPE_F32); GGML_ASSERT(c->type == GGML_TYPE_I64); + GGML_ASSERT(ggml_is_contiguous_rows(a)); + GGML_ASSERT(ggml_is_contiguous_rows(b)); + struct ggml_tensor * result = ggml_view_tensor(ctx, a); result->op = GGML_OP_SET_ROWS;