diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 42f5b555fe..c4a26b7d10 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7095,7 +7095,7 @@ static void ggml_compute_forward_im2col_3d_f16( // micro kernel ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW] - const float * const src_data = (float *) src1->data + (in*IC + iic)*ID_IH_IW; // [ID, IH, IW] + const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW] for (int64_t ikd = 0; ikd < KD; ikd++) { for (int64_t ikh = 0; ikh < KH; ikh++) { @@ -7107,7 +7107,8 @@ static void ggml_compute_forward_im2col_3d_f16( if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0; } else { - dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iid*IH_IW + iih*IW + iiw]); + const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW] + dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s); } } } @@ -7187,7 +7188,7 @@ static void ggml_compute_forward_im2col_3d_f32( // micro kernel float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW] - const float * const src_data = (float *) src1->data + (in*IC + iic)*ID_IH_IW; // [ID, IH, IW] + const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW] for (int64_t ikd = 0; ikd < KD; ikd++) { for (int64_t ikh = 0; ikh < KH; ikh++) { @@ -7199,7 +7200,8 @@ static void ggml_compute_forward_im2col_3d_f32( if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) { dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0; } else { - dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = src_data[iid*IH_IW + iih*IW + iiw]; + const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW] + dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s; } } }