mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-14 11:07:10 +00:00
ggml : ggml_set_rows support broadcast
This commit is contained in:
@@ -4530,12 +4530,14 @@ static void ggml_compute_forward_set_rows_f32(
|
||||
GGML_TENSOR_BINARY_OP_LOCALS
|
||||
|
||||
const int64_t nc = ne00;
|
||||
const int64_t nr = ggml_nelements(src1);
|
||||
const int64_t nr = ne01;
|
||||
|
||||
assert(ne0 == nc);
|
||||
assert(ne02 == ne11);
|
||||
assert(nb00 == sizeof(float));
|
||||
assert(ggml_nrows(src0) == nr);
|
||||
assert(ne2 == ne02);
|
||||
assert(ne3 == ne03);
|
||||
assert(src0->type == GGML_TYPE_F32);
|
||||
assert(ne02 % ne11 == 0);
|
||||
assert(ne03 % ne12 == 0);
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
@@ -4547,17 +4549,22 @@ static void ggml_compute_forward_set_rows_f32(
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
for (int64_t i = ir0; i < ir1; ++i) {
|
||||
const int64_t i12 = i/(ne11*ne10);
|
||||
const int64_t i11 = (i - i12*ne11*ne10)/ne10;
|
||||
const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
|
||||
const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
|
||||
for (int64_t i03 = 0; i03 < ne03; ++i03) {
|
||||
for (int64_t i02 = 0; i02 < ne02; ++i02) {
|
||||
for (int64_t i = ir0; i < ir1; ++i) {
|
||||
const int64_t i12 = i03%ne12;
|
||||
const int64_t i11 = i02%ne11;
|
||||
const int64_t i10 = i;
|
||||
|
||||
GGML_ASSERT(i01 >= 0 && i01 < ne1);
|
||||
const int64_t i01 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
|
||||
|
||||
ggml_cpu_fp32_to_fp16(
|
||||
(const float *) ((char *) src0->data + i10*nb01 + i11*nb02 + i12*nb03),
|
||||
(ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i11*nb2 + i12*nb3), nc);
|
||||
GGML_ASSERT(i01 >= 0 && i01 < ne1);
|
||||
|
||||
ggml_cpu_fp32_to_fp16(
|
||||
(const float *) ((char *) src0->data + i*nb01 + i02*nb02 + i03*nb03),
|
||||
(ggml_fp16_t *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3), nc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user