mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-03 09:22:01 +00:00
fix test_im2col_3d
This commit is contained in:
@@ -4054,10 +4054,10 @@ struct test_im2col_3d : public test_case {
|
||||
test_im2col_3d(ggml_type type_input = GGML_TYPE_F32, ggml_type type_kernel = GGML_TYPE_F16, ggml_type dst_type = GGML_TYPE_F32,
|
||||
std::array<int64_t, 4> ne_input = {10, 10, 10, 9}, // [OC*IC, KD, KH, KW]
|
||||
std::array<int64_t, 4> ne_kernel = {3, 3, 3, 1}, // [N*IC, ID, IH, IW]
|
||||
int64_t IC = 3,
|
||||
int s0 = 1, int s1 = 1, int s2 = 1,
|
||||
int p0 = 1, int p1 = 1, int p2 = 1,
|
||||
int d0 = 1, int d1 = 1, int d2 = 1,
|
||||
int64_t IC = 3)
|
||||
int d0 = 1, int d1 = 1, int d2 = 1)
|
||||
: type_input(type_input), type_kernel(type_kernel), dst_type(dst_type), ne_input(ne_input), ne_kernel(ne_kernel), s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), IC(IC) {}
|
||||
|
||||
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||
@@ -5689,7 +5689,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||
for (int d1 : {1, 3}) {
|
||||
for (int d2 : {1, 3}) {
|
||||
test_cases.emplace_back(new test_im2col_3d(
|
||||
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 3, 3}, {3, 3, 3, 3},
|
||||
GGML_TYPE_F32, GGML_TYPE_F32, GGML_TYPE_F32, {20, 20, 10, 3}, {3, 3, 3, 3},
|
||||
3, s0, s1, s2, p0, p1, p2, d0, d1, d2));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user