ggml : ggml_set_rows support broadcast

This commit is contained in:
Georgi Gerganov
2025-06-22 10:28:07 +03:00
parent 313a444b22
commit df71c803b4
3 changed files with 39 additions and 15 deletions

View File

@@ -4530,12 +4530,14 @@ static void ggml_compute_forward_set_rows_f32(
GGML_TENSOR_BINARY_OP_LOCALS
const int64_t nc = ne00;
const int64_t nr = ggml_nelements(src1);
const int64_t nr = ne01;
assert(ne0 == nc);
assert(ne02 == ne11);
assert(nb00 == sizeof(float));
assert(ggml_nrows(src0) == nr);
assert(ne2 == ne02);
assert(ne3 == ne03);
assert(src0->type == GGML_TYPE_F32);
assert(ne02 % ne11 == 0);
assert(ne03 % ne12 == 0);
const int ith = params->ith;
const int nth = params->nth;
@@ -4547,17 +4549,22 @@ static void ggml_compute_forward_set_rows_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int64_t i = ir0; i < ir1; ++i) {
const int64_t i12 = i/(ne11*ne10);
const int64_t i11 = (i - i12*ne11*ne10)/ne10;
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
for (int64_t i03 = 0; i03 < ne03; ++i03) {
for (int64_t i02 = 0; i02 < ne02; ++i02) {
for (int64_t i = ir0; i < ir1; ++i) {
const int64_t i12 = i03%ne12;
const int64_t i11 = i02%ne11;
const int64_t i10 = i;
GGML_ASSERT(i01 >= 0 && i01 < ne1);
const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
ggml_cpu_fp32_to_fp16(
(const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03),
(ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i11*nb2 + i12*nb3), nc);
GGML_ASSERT(i01 >= 0 && i01 < ne1);
ggml_cpu_fp32_to_fp16(
(const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
(ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), nc);
}
}
}
}