avoid ggml_conv_3d conflict

This commit is contained in:
leejet
2025-08-30 03:28:07 +08:00
parent d30e07dbb3
commit df05913bc4
3 changed files with 5 additions and 5 deletions

View File

@@ -1977,7 +1977,7 @@ extern "C" {
int d0, // dilation dimension 0 int d0, // dilation dimension 0
int d1); // dilation dimension 1 int d1); // dilation dimension 1
GGML_API struct ggml_tensor * ggml_conv_3d( GGML_API struct ggml_tensor * ggml_conv_3d_direct(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC] struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC]
struct ggml_tensor * b, // input [W, H, D, C * N] struct ggml_tensor * b, // input [W, H, D, C * N]

View File

@@ -4568,9 +4568,9 @@ struct ggml_tensor * ggml_conv_2d_direct(
return result; return result;
} }
// ggml_conv_3d // ggml_conv_3d_direct
struct ggml_tensor * ggml_conv_3d( struct ggml_tensor * ggml_conv_3d_direct(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b, struct ggml_tensor * b,

View File

@@ -4196,7 +4196,7 @@ struct test_conv_3d : public test_case {
return (uint64_t)N * OC * OD * OH * OW * (2 * IC * KD * KH * KW - 1); return (uint64_t)N * OC * OD * OH * OW * (2 * IC * KD * KH * KW - 1);
} }
test_conv_3d( test_conv_3d_direct(
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW,
int64_t OC, int64_t KD, int64_t KH, int64_t KW, int64_t OC, int64_t KD, int64_t KH, int64_t KW,
int s0, int s1, int s2, int s0, int s1, int s2,
@@ -4221,7 +4221,7 @@ struct test_conv_3d : public test_case {
ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel); ggml_tensor * kernel = ggml_new_tensor(ctx, type_kernel, 4, ne_kernel);
ggml_set_name(kernel, "kernel"); ggml_set_name(kernel, "kernel");
ggml_tensor * out = ggml_conv_3d(ctx, kernel, input, s0, s1, s2, p0, p1, p2, d0, d1, d2, (int)IC, (int)N, (int)OC); ggml_tensor * out = ggml_conv_3d_direct(ctx, kernel, input, s0, s1, s2, p0, p1, p2, d0, d1, d2, (int)IC, (int)N, (int)OC);
ggml_set_name(out, "out"); ggml_set_name(out, "out");
return out; return out;
} }