mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-03 09:22:01 +00:00
fix cpu im2col_3d
This commit is contained in:
@@ -7094,7 +7094,7 @@ static void ggml_compute_forward_im2col_3d_f16(
|
||||
for (int64_t iic = ith; iic < IC; iic += nth) {
|
||||
|
||||
// micro kernel
|
||||
ggml_fp16_t * dst_data = wdata + (in*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
|
||||
ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
|
||||
const float * const src_data = (float *) src1->data + (in*IC + iic)*ID_IH_IW; // [ID, IH, IW]
|
||||
|
||||
for (int64_t ikd = 0; ikd < KD; ikd++) {
|
||||
@@ -7104,7 +7104,7 @@ static void ggml_compute_forward_im2col_3d_f16(
|
||||
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
||||
const int64_t iid = iod*s2 + ikd*d2 - p2;
|
||||
|
||||
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
|
||||
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
|
||||
} else {
|
||||
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(src_data[iid*IH_IW + iih*IW + iiw]);
|
||||
@@ -7186,7 +7186,7 @@ static void ggml_compute_forward_im2col_3d_f32(
|
||||
for (int64_t iic = ith; iic < IC; iic += nth) {
|
||||
|
||||
// micro kernel
|
||||
float * dst_data = wdata + (in*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
|
||||
float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
|
||||
const float * const src_data = (float *) src1->data + (in*IC + iic)*ID_IH_IW; // [ID, IH, IW]
|
||||
|
||||
for (int64_t ikd = 0; ikd < KD; ikd++) {
|
||||
@@ -7196,7 +7196,7 @@ static void ggml_compute_forward_im2col_3d_f32(
|
||||
const int64_t iih = ioh*s1 + ikh*d1 - p1;
|
||||
const int64_t iid = iod*s2 + ikd*d2 - p2;
|
||||
|
||||
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
|
||||
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
|
||||
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
|
||||
} else {
|
||||
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = src_data[iid*IH_IW + iih*IW + iiw];
|
||||
|
||||
Reference in New Issue
Block a user