fix cpu im2col_3d

This commit is contained in:
leejet
2025-08-30 11:25:07 +08:00
parent f6278c832f
commit c9b9fabe08

View File

@@ -7094,7 +7094,7 @@ static void ggml_compute_forward_im2col_3d_f16(
for (int64_t iic = ith; iic < IC; iic += nth) { for (int64_t iic = ith; iic < IC; iic += nth) {
// micro kernel // micro kernel
ggml_fp16_t * dst_data = wdata + (in*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW] ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
const float * const src_data = (float *) src1->data + (in*IC + iic)*ID_IH_IW; // [ID, IH, IW] const float * const src_data = (float *) src1->data + (in*IC + iic)*ID_IH_IW; // [ID, IH, IW]
for (int64_t ikd = 0; ikd < KD; ikd++) { for (int64_t ikd = 0; ikd < KD; ikd++) {
@@ -7104,7 +7104,7 @@ static void ggml_compute_forward_im2col_3d_f16(
const int64_t iih = ioh*s1 + ikh*d1 - p1; const int64_t iih = ioh*s1 + ikh*d1 - p1;
const int64_t iid = iod*s2 + ikd*d2 - p2; const int64_t iid = iod*s2 + ikd*d2 - p2;
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0; dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
} else { } else {
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iid*IH_IW + iih*IW + iiw]); dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iid*IH_IW + iih*IW + iiw]);
@@ -7186,7 +7186,7 @@ static void ggml_compute_forward_im2col_3d_f32(
for (int64_t iic = ith; iic < IC; iic += nth) { for (int64_t iic = ith; iic < IC; iic += nth) {
// micro kernel // micro kernel
float * dst_data = wdata + (in*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW] float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
const float * const src_data = (float *) src1->data + (in*IC + iic)*ID_IH_IW; // [ID, IH, IW] const float * const src_data = (float *) src1->data + (in*IC + iic)*ID_IH_IW; // [ID, IH, IW]
for (int64_t ikd = 0; ikd < KD; ikd++) { for (int64_t ikd = 0; ikd < KD; ikd++) {
@@ -7196,7 +7196,7 @@ static void ggml_compute_forward_im2col_3d_f32(
const int64_t iih = ioh*s1 + ikh*d1 - p1; const int64_t iih = ioh*s1 + ikh*d1 - p1;
const int64_t iid = iod*s2 + ikd*d2 - p2; const int64_t iid = iod*s2 + ikd*d2 - p2;
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) { if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0; dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
} else { } else {
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = src_data[iid*IH_IW + iih*IW + iiw]; dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = src_data[iid*IH_IW + iih*IW + iiw];