// Mirror of https://github.com/ggml-org/llama.cpp.git
// Synced 2025-11-18 11:46:58 +00:00.
// 62 lines, 2.0 KiB. The mirror labels this file "Common Lisp"; the content is OpenCL C.
// Enable half-precision support; read_imageh() in this file returns half4 values.
#pragma OPENCL EXTENSION cl_khr_fp16 : enable

// Sampler used for every image read in this file:
//   - CLK_NORMALIZED_COORDS_FALSE: coordinates are integer texel indices,
//   - CLK_ADDRESS_CLAMP:           out-of-range coordinates are clamped rather than wrapped,
//   - CLK_FILTER_NEAREST:          no interpolation between neighbouring texels.
__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

/*
 * mul_mat_f16_f32_image
 *
 * Matrix multiplication whose operands are fetched through image reads and
 * whose float result is written to a global buffer.
 *
 * Each work-item produces up to four consecutive outputs along N for a single
 * row index m:
 *   get_global_id(0) -> n_4_idx, a group of four n indices (n_base = 4 * n_4_idx)
 *   get_global_id(1) -> m_idx
 * Work-items with n_base >= N or m_idx >= M exit without writing.
 *
 * Parameters:
 *   A_img    - read at texel (k / 4, m); the four components of a texel are
 *              used as four consecutive k values for row m.
 *   B_img    - read at texel (n / 4, k); the four components of a texel are
 *              used as four consecutive n values for row k.
 *              NOTE(review): both images are read with read_imageh(), so they
 *              presumably hold 4-channel half-float texels - confirm against
 *              the host-side image creation.
 *   C_buf    - output buffer base.
 *   c_offset - byte offset added to C_buf before any store.
 *   M, N, K  - problem dimensions; the result for (m, n) is stored at
 *              element index n * M + m.
 */
__kernel void mul_mat_f16_f32_image(
    __read_only image2d_t A_img,
    __read_only image2d_t B_img,
    __global float* C_buf,
    const ulong c_offset,
    const int M,
    const int N,
    const int K
) {
    const int n_4_idx = get_global_id(0);
    const int m_idx   = get_global_id(1);

    const int n_base = n_4_idx << 2;

    // Discard work-items that fall outside the output.
    if (n_base >= N || m_idx >= M) {
        return;
    }

    float4 c_vals = (float4)(0.0f);

    // Number of 4-wide groups needed to cover K (rounded up).
    const int K_4 = (K + 3) / 4;

    for (int k_4_idx = 0; k_4_idx < K_4; ++k_4_idx) {
        const int k_base = k_4_idx << 2;

        // Four consecutive k values of row m_idx, widened to float.
        const float4 a_vals = convert_float4(read_imageh(A_img, SAMPLER, (int2)(k_4_idx, m_idx)));

        // k_base < K always holds inside this loop (k_4_idx < ceil(K / 4)),
        // so the first lane needs no guard.
        {
            const float4 b0 = convert_float4(read_imageh(B_img, SAMPLER, (int2)(n_4_idx, k_base + 0)));
            c_vals = mad(a_vals.x, b0, c_vals);
        }

        // The remaining lanes may run past K in the final group; skipping them
        // keeps padding components of a_vals out of the sum.
        if (k_base + 1 < K) {
            const float4 b1 = convert_float4(read_imageh(B_img, SAMPLER, (int2)(n_4_idx, k_base + 1)));
            c_vals = mad(a_vals.y, b1, c_vals);
        }
        if (k_base + 2 < K) {
            const float4 b2 = convert_float4(read_imageh(B_img, SAMPLER, (int2)(n_4_idx, k_base + 2)));
            c_vals = mad(a_vals.z, b2, c_vals);
        }
        if (k_base + 3 < K) {
            const float4 b3 = convert_float4(read_imageh(B_img, SAMPLER, (int2)(n_4_idx, k_base + 3)));
            c_vals = mad(a_vals.w, b3, c_vals);
        }
    }

    // c_offset is expressed in bytes, hence the detour through a char pointer.
    __global float* C = (__global float*)((__global char*)C_buf + c_offset);

    // Index arithmetic is done in size_t so that n * M + m cannot overflow a
    // signed 32-bit int when the output is large.
    const size_t n_stride = (size_t)M;
    const size_t out_idx  = (size_t)n_base * n_stride + (size_t)m_idx;

    // n_base < N is guaranteed by the early return above; the other three
    // outputs are written only when they exist.
    C[out_idx] = c_vals.x;
    if (n_base + 1 < N) {
        C[out_idx + n_stride] = c_vals.y;
    }
    if (n_base + 2 < N) {
        C[out_idx + 2 * n_stride] = c_vals.z;
    }
    if (n_base + 3 < N) {
        C[out_idx + 3 * n_stride] = c_vals.w;
    }
}