@@ -7207,6 +7207,148 @@ void ggml_compute_forward_conv_2d(
    ggml_compute_forward_conv_2d_impl(params, src0, src1, dst, src0->type);
}

// ggml_compute_forward_conv_3d

static void ggml_compute_forward_conv_3d_impl(const ggml_compute_params * params,
                                              const ggml_tensor *         kernel,
                                              const ggml_tensor *         src,
                                              ggml_tensor *               dst,
                                              ggml_type                   kernel_type) {

    GGML_ASSERT(ggml_is_contiguous(kernel));
    GGML_ASSERT(kernel_type == GGML_TYPE_F16 || kernel_type == GGML_TYPE_F32);
    GGML_ASSERT(kernel->type == kernel_type);

    const ggml_type_traits * traits = ggml_get_type_traits(kernel_type);
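
    // convolution hyper-parameters packed into dst->op_params, presumably by the
    // matching conv_3d graph-building op: strides s0..s2, padding p0..p2,
    // dilation d0..d2, input channels c, batch size n, output channels oc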
    const int32_t s0 = dst->op_params[0];
    const int32_t s1 = dst->op_params[1];
    const int32_t s2 = dst->op_params[2];
    const int32_t p0 = dst->op_params[3];
    const int32_t p1 = dst->op_params[4];
    const int32_t p2 = dst->op_params[5];
    const int32_t d0 = dst->op_params[6];
    const int32_t d1 = dst->op_params[7];
    const int32_t d2 = dst->op_params[8];
    const int32_t c  = dst->op_params[9];
    const int32_t n  = dst->op_params[10];
    const int32_t oc = dst->op_params[11];

    const int64_t src_w = src->ne[0];
    const int64_t src_h = src->ne[1];
    const int64_t src_d = src->ne[2];
    const int64_t knl_w = kernel->ne[0];
    const int64_t knl_h = kernel->ne[1];
    const int64_t knl_d = kernel->ne[2];
    const int64_t dst_w = dst->ne[0];
    const int64_t dst_h = dst->ne[1];
    const int64_t dst_d = dst->ne[2];

    const float * src_data = (float *) src->data;
    void        * knl_data = kernel->data;
    float       * dst_data = (float *) dst->data;
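
    // im2col row length: each output position gathers knl_w*knl_h*knl_d values
    // from every one of the c input channels; patch_total counts output positions
    // across the whole batch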
    const int64_t knl_n_per_channel = knl_w * knl_h * knl_d;
    const int64_t knl_n_total       = knl_n_per_channel * c;
    const int64_t patch_total       = n * dst_w * dst_h * dst_d;
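
    // each patch needs one im2col row of knl_n_total elements (stored in the
    // kernel's type) plus oc floats of GEMM output; patches are processed in
    // batches sized to fit params->wsize, rounded down to a multiple of 8 when possible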
    const int64_t space_per_patch   = knl_n_total * traits->type_size + oc * sizeof(float);
    const int64_t batch_size        = params->wsize / space_per_patch;
    const int64_t patches_per_batch = batch_size > 8 ? (batch_size / 8) * 8 : batch_size;
    const int64_t batch_n           = (patch_total + patches_per_batch - 1) / patches_per_batch;

    GGML_ASSERT(patches_per_batch > 0 && batch_size >= 1);

    void * tmp = params->wdata;
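
    // per batch of patches: (1) every thread builds im2col rows for its slice,
    // (2) one GEMM multiplies all rows by the flattened kernel, (3) the result is
    // scattered back into dst; barriers separate the phases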
    for (int64_t batch_i = 0; batch_i < batch_n; ++batch_i) {
        const int64_t patch_start_batch = batch_i * patches_per_batch;
        const int64_t patch_end_batch   = std::min(patch_start_batch + patches_per_batch, patch_total);
        const int64_t patch_n_in_batch  = patch_end_batch - patch_start_batch;

        const int64_t patch_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
        const int64_t patch_start      = patch_start_batch + params->ith * patch_per_thread;
        const int64_t patch_end        = std::min(patch_start + patch_per_thread, patch_end_batch);
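
        // im2col: decompose the linear patch index p into output coordinates
        // (batch_idx, dst_z, dst_y, dst_x) and gather the corresponding input window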
        for (int64_t p = patch_start; p < patch_end; ++p) {
            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
            const int64_t dst_y      = p_in_depth / dst_w;
            const int64_t dst_x      = p_in_depth % dst_w;

            char * dst_row = (char *) tmp + (p % patches_per_batch) * knl_n_total * traits->type_size;
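
            // walk the receptive field; positions that fall outside the source
            // volume contribute zero (implicit zero padding), and values are
            // converted to the kernel's type so the GEMM sees a homogeneous matrix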
            for (int64_t ic = 0; ic < c; ++ic) {
                for (int64_t kz = 0; kz < knl_d; ++kz) {
                    for (int64_t ky = 0; ky < knl_h; ++ky) {
                        for (int64_t kx = 0; kx < knl_w; ++kx) {
                            const int64_t sz = dst_z * s2 + kz * d2 - p2;
                            const int64_t sy = dst_y * s1 + ky * d1 - p1;
                            const int64_t sx = dst_x * s0 + kx * d0 - p0;

                            int64_t dst_idx = ic * knl_n_per_channel + kz * (knl_h * knl_w) + ky * knl_w + kx;

                            float src_val;
                            if (sz < 0 || sz >= src_d || sy < 0 || sy >= src_h || sx < 0 || sx >= src_w) {
                                src_val = 0.0f;
                            } else {
                                const int64_t cn_idx = batch_idx * c + ic;
                                const float * src_ptr = (const float *)((const char *)src_data + sx*src->nb[0] + sy*src->nb[1] + sz*src->nb[2] + cn_idx*src->nb[3]);
                                src_val = *src_ptr;
                            }

                            char * element_ptr = dst_row + dst_idx * traits->type_size;
                            if (kernel_type == GGML_TYPE_F32) {
                                *(float *) element_ptr = src_val;
                            } else if (kernel_type == GGML_TYPE_F16) {
                                *(ggml_fp16_t *) element_ptr = GGML_CPU_FP32_TO_FP16(src_val);
                            }
                        }
                    }
                }
            }
        }
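
        // wait for every thread to finish its im2col rows, then run a single GEMM
        // for the whole batch: each im2col row times the flattened kernel yields
        // oc output values per patch, written to gemm_output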
        ggml_barrier(params->threadpool);

        float * gemm_output = (float *) ((char *) tmp + patches_per_batch * knl_n_total * traits->type_size);
        ggml_call_mul_mat(kernel_type, params, patch_n_in_batch, oc, knl_n_total, tmp, knl_data, gemm_output);

        ggml_barrier(params->threadpool);
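
        // scatter the GEMM result, which is contiguous per patch (oc values each),
        // back into dst's strided layout; this permute pass is again split across threads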
        const int64_t permute_per_thread = (patch_n_in_batch + params->nth - 1) / params->nth;
        const int64_t permute_start      = params->ith * permute_per_thread;
        const int64_t permute_end        = std::min(permute_start + permute_per_thread, patch_n_in_batch);

        for (int64_t i = permute_start; i < permute_end; ++i) {
            const int64_t p          = patch_start_batch + i;
            const int64_t p_in_batch = p % (dst_w * dst_h * dst_d);
            const int64_t p_in_depth = p_in_batch % (dst_w * dst_h);
            const int64_t batch_idx  = p / (dst_w * dst_h * dst_d);
            const int64_t dst_z      = p_in_batch / (dst_w * dst_h);
            const int64_t dst_y      = p_in_depth / dst_w;
            const int64_t dst_x      = p_in_depth % dst_w;

            for (int64_t ioc = 0; ioc < oc; ++ioc) {
                const float value = gemm_output[i * oc + ioc];
                const int64_t ocn_idx = batch_idx * oc + ioc;
                float * dst_ptr = (float *)((char *)dst_data + dst_x*dst->nb[0] + dst_y*dst->nb[1] + dst_z*dst->nb[2] + ocn_idx*dst->nb[3]);
                *dst_ptr = value;
            }
        }
    }
}
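
// thin wrapper: src0 is the kernel, src1 the input volume; the kernel may be
// f16 or f32, so its type is forwarded to the implementation above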
void ggml_compute_forward_conv_3d(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    ggml_compute_forward_conv_3d_impl(params, src0, src1, dst, src0->type);
}

// ggml_compute_forward_conv_transpose_2d

void ggml_compute_forward_conv_transpose_2d(