Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-27 08:21:30 +00:00)
CUDA: use FP32 arithmetic for conv2d (#15683)
This commit is contained in:
@@ -82,7 +82,7 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
     int64_t n, c_out, out_y, out_x;
     Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
 
-    T acc = 0;
+    float acc = 0.0f;
 
     for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
         kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
@@ -93,21 +93,15 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
             for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
                 const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
 
-                T input_val;
-                if (std::is_same<T, half>::value) {
-                    input_val = __float2half(input[Layout::input_index(n, c_in, in_y, in_x, P)]);
-                } else {
-                    input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
-                }
-
-                T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
+                const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
+                const float kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
                 acc += (input_val * kernel_val);
             }
         }
     }
 
     // [N, OC, OH, OW]
-    output[Layout::output_index(n, c_out, out_y, out_x, P)] = (float) acc;
+    output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;
 }
 
 template <typename T>
Reference in New Issue
Block a user