mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-11-05 09:36:52 +00:00
add conv3d support
This commit is contained in:
@@ -1870,6 +1870,41 @@ extern "C" {
|
||||
int d0, // dilation dimension 0
|
||||
int d1); // dilation dimension 1
|
||||
|
||||
GGML_API struct ggml_tensor * ggml_im2col_3d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int64_t IC,
|
||||
int s0, // stride width
|
||||
int s1, // stride height
|
||||
int s2, // stride depth
|
||||
int p0, // padding width
|
||||
int p1, // padding height
|
||||
int p2, // padding depth
|
||||
int d0, // dilation width
|
||||
int d1, // dilation height
|
||||
int d2, // dilation depth
|
||||
enum ggml_type dst_type);
|
||||
|
||||
// a: [OC*IC, KD, KH, KW]
|
||||
// b: [N*IC, ID, IH, IW]
|
||||
// result: [N*OC, OD, OH, OW]
|
||||
GGML_API struct ggml_tensor * ggml_conv_3d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int64_t IC,
|
||||
int s0, // stride width
|
||||
int s1, // stride height
|
||||
int s2, // stride depth
|
||||
int p0, // padding width
|
||||
int p1, // padding height
|
||||
int p2, // padding depth
|
||||
int d0, // dilation width
|
||||
int d1, // dilation height
|
||||
int d2 // dilation depth
|
||||
);
|
||||
|
||||
// kernel size is a->ne[0] x a->ne[1]
|
||||
// stride is equal to kernel size
|
||||
// padding is zero
|
||||
|
||||
@@ -4361,6 +4361,88 @@ struct ggml_tensor * ggml_conv_2d(
|
||||
return result;
|
||||
}
|
||||
|
||||
// a: [OC*IC, KD, KH, KW]
|
||||
// b: [N*IC, ID, IH, IW]
|
||||
// result: [N*OD, OH, OW, IC * KD * KH * KW]
|
||||
struct ggml_tensor * ggml_im2col_3d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int64_t IC,
|
||||
int s0, // stride width
|
||||
int s1, // stride height
|
||||
int s2, // stride depth
|
||||
int p0, // padding width
|
||||
int p1, // padding height
|
||||
int p2, // padding depth
|
||||
int d0, // dilation width
|
||||
int d1, // dilation height
|
||||
int d2, // dilation depth
|
||||
enum ggml_type dst_type) {
|
||||
|
||||
const int64_t N = b->ne[3] / IC;
|
||||
const int64_t ID = b->ne[2];
|
||||
const int64_t IH = b->ne[1];
|
||||
const int64_t IW = b->ne[0];
|
||||
|
||||
const int64_t OC = a->ne[3] / IC;
|
||||
const int64_t KD = a->ne[2];
|
||||
const int64_t KH = a->ne[1];
|
||||
const int64_t KW = a->ne[0];
|
||||
const int64_t OD = ggml_calc_conv_output_size(ID, KD, s2, p2, d2);
|
||||
|
||||
struct ggml_tensor* img = ggml_reshape_4d(ctx, b, IW*IH, ID, IC, N); // [N, IC, ID, IH * IW]
|
||||
img = ggml_cont(ctx, ggml_permute(ctx, img, 2, 0, 1, 3)); // [N, IH*IW, IC, ID]
|
||||
img = ggml_reshape_3d(ctx, b, ID, IC, IW*IH*N); // [N*IH*IW, IC, ID]
|
||||
|
||||
a = ggml_reshape_3d(ctx, a, KD, IC, OC*KW*KH); // [OC*KW*KH, IC, KD]
|
||||
img = ggml_im2col(ctx, a, img, s2, 1, p2, 0, d2, 1, false, GGML_TYPE_F32); // [N*IH*IW, OD, IC*KD]
|
||||
|
||||
img = ggml_reshape_4d(ctx, img, IC*KD, OD, IW*IH, N); // [N, IH*IW, OD, IC*KD]
|
||||
img = ggml_cont(ctx, ggml_permute(ctx, img, 1, 2, 0, 3)); // [N, OD, IC*KD, IH*IW]
|
||||
img = ggml_reshape_4d(ctx, img, IW, IH, IC*KD, OD*N); // [N*OD, IC*KD, IH, IW]
|
||||
|
||||
a = ggml_reshape_4d(ctx, a, KW, KH, IC*KD, OC); // [OC, KD*IC, KH, KW]
|
||||
|
||||
img = ggml_im2col(ctx, a, img, s0, s1, p0, p1, d0, d1, true, dst_type); // [N * OD, OH, OW, IC * KD * KH * KW]
|
||||
return img;
|
||||
}
|
||||
|
||||
// a: [OC*IC, KD, KH, KW]
|
||||
// b: [N*IC, ID, IH, IW]
|
||||
// result: [N*OC, OD, OH, OW]
|
||||
struct ggml_tensor * ggml_conv_3d(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
int64_t IC,
|
||||
int s0, // stride width
|
||||
int s1, // stride height
|
||||
int s2, // stride depth
|
||||
int p0, // padding width
|
||||
int p1, // padding height
|
||||
int p2, // padding depth
|
||||
int d0, // dilation width
|
||||
int d1, // dilation height
|
||||
int d2 // dilation depth
|
||||
) {
|
||||
struct ggml_tensor * im2col = ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW]
|
||||
|
||||
int64_t OC = a->ne[3] / IC;
|
||||
int64_t N = b->ne[3] / IC;
|
||||
struct ggml_tensor * result =
|
||||
ggml_mul_mat(ctx,
|
||||
ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW]
|
||||
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC)); // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW]
|
||||
|
||||
int64_t OD = im2col->ne[3] / N;
|
||||
result = ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW]
|
||||
result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW]
|
||||
result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW]
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
// ggml_conv_2d_sk_p0
|
||||
|
||||
struct ggml_tensor * ggml_conv_2d_sk_p0(
|
||||
|
||||
Reference in New Issue
Block a user