@@ -2032,7 +2032,38 @@ kernel void kernel_ssm_conv_f32_f32(
    x[0] = sumf;
}

// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
kernel void kernel_ssm_conv_f32_f32_4(
        constant ggml_metal_kargs_ssm_conv & args,
        device const void * src0,
        device const void * src1,
        device float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        uint3 ntg[[threads_per_threadgroup]]) {
    const int64_t ir = tgpig.x;
    const int64_t i2 = tgpig.y;
    const int64_t i3 = tgpig.z;

    const int64_t nc = args.ne10;
    //const int64_t ncs = args.ne00;
    //const int64_t nr = args.ne01;
    //const int64_t n_t = args.ne1;
    //const int64_t n_s = args.ne2;

    device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
    device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
    device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2);

    float sumf = 0.0f;

    for (int64_t i0 = 0; i0 < nc/4; ++i0) {
        sumf += dot(s[i0], c[i0]);
    }

    x[0] = sumf;
}
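
// [editor note] Illustrative sketch, not part of the patch: the float4 kernel
// above computes the same result as the scalar loop in kernel_ssm_conv_f32_f32,
//
//     float sumf = 0.0f;
//     for (int64_t i0 = 0; i0 < nc; ++i0) {
//         sumf += s[i0] * c[i0];
//     }
//
// but loads four floats at a time and reduces them with Metal's dot(), which
// assumes the convolution width args.ne10 is a multiple of 4.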

// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
kernel void kernel_ssm_scan_f32(
        constant ggml_metal_kargs_ssm_scan & args,
        device const void * src0,
@@ -2044,219 +2075,88 @@ kernel void kernel_ssm_scan_f32(
        device const void * src6,
        device float * dst,
        threadgroup float * shared [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgptg[[simdgroups_per_threadgroup]],
        uint3 tgpg[[threadgroups_per_grid]]) {
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgptg[[simdgroups_per_threadgroup]],
        uint3 tgpg[[threadgroups_per_grid]]) {
    constexpr short NW = N_SIMDWIDTH;

    const int64_t i0 = tpitg.x;
    const int64_t i1 = 0;
    const int64_t ir = tgpig.x; // current head
    const int64_t i3 = tgpig.y; // current seq
    shared[tpitg.x] = 0.0f;

    const uint64_t nb00 = sizeof(float);
    const uint64_t nb10 = sizeof(float);
    const uint64_t nb20 = sizeof(float);
    const int32_t i0 = tpitg.x;
    const int32_t i1 = tgpig.x;
    const int32_t ir = tgpig.y; // current head
    const int32_t i3 = tgpig.z; // current seq

    const int64_t nc = args.d_state;
    const int64_t nr = args.d_inner;
    const int64_t nh = args.n_head;
    const int64_t ng = args.n_group;
    const int64_t n_t = args.n_seq_tokens;
    const int32_t nc = args.d_state;
    const int32_t nr = args.d_inner;
    const int32_t nh = args.n_head;
    const int32_t ng = args.n_group;
    const int32_t n_t = args.n_seq_tokens;

    const int64_t s_off = args.s_off;
    const int32_t s_off = args.s_off;

    device const int32_t * ids = (device const int32_t *) src6;

    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
    device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
    const int64_t i = i0 + i1*nc;
    const int64_t g = ir / (nh / ng); // repeat_interleave

    const int32_t i = i0 + i1*nc;
    const int32_t g = ir / (nh / ng); // repeat_interleave

    float s0 = s0_buff[i];
    float s = s_buff[i];
    float s = 0.0f;

    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
    device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
    device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
    device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
    device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}

    for (int64_t i2 = 0; i2 < n_t; ++i2) {
        device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
        device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
        device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
        device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
    const float A0 = A[i0%args.ne30];

        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
        const float x_dt = x[0] * dt_soft_plus;
    device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
    device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns}
    device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns}
    device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns}

        const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
        s = state;
    device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns}

        // Parallel sum: This relies on the fact that this kernel will be
        // dispatched with each threadgroup having (d_state, 1, 1) threads which
        // are subdivided into SIMD groups of size `sgptg`. The goal is to
        // compute y = sum({state * C[i] for i in range(d_state)}).
        // To parallelize this effectively, we first use simd_sum over each SIMD
        // group to compute the sum of each SIMD group, then place the result in
        // the SIMD group's indexed bucket in the shared memory. We then sum
        // over the individual group sums to compute the final sum.

        // Computed for each thread
        float sumf = state * C[i0];

        // Sum the threads in the simd group => simd sum
        sumf = simd_sum(sumf);

        if (sgptg > 1) {

            // Once per simd group, place the group sum into the shared buffer
            if (tiisg == 0) {
                shared[sgitg] = sumf;
            }

            // Wait for all threads in the threadgroup to reach this point. This
            // ensures that all elements of the shared buffer are populated with the
            // sum of the individual simd groups.
            threadgroup_barrier(mem_flags::mem_threadgroup);

            // For simd group 0 at indices < num simd groups, extract the shared
            // simd sum
            sumf = 0.0f;
            if (sgitg == 0) {
                if (tiisg < sgptg) {
                    sumf = shared[tiisg];
                }
                sumf = simd_sum(sumf);
                if (tiisg == 0) {
                    y[0] = sumf;
                }
            }
        } else if (tiisg == 0) {
            y[0] = sumf;
        }

        // recurse
        s0 = s;
    }

    // Assign the final state to the output buffer
    s_buff[i] = s;
}
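
// [editor note] Illustrative sketch, not part of the patch: the guarded
// softplus both scan kernels apply to dt. For dt > 20, exp(dt) is ~4.85e8 or
// larger, so log(1.0f + exp(dt)) equals dt to within float precision, and
// returning dt directly also avoids overflow of exp() for large inputs:
//
//     static float softplus(float x) {
//         return x <= 20.0f ? log(1.0f + exp(x)) : x;
//     }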

// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
kernel void kernel_ssm_scan_group_f32(
        constant ggml_metal_kargs_ssm_scan & args,
        device const void * src0,
        device const void * src1,
        device const void * src2,
        device const void * src3,
        device const void * src4,
        device const void * src5,
        device const void * src6,
        device float * dst,
        threadgroup float * shared [[threadgroup(0)]],
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 tpitg[[thread_position_in_threadgroup]],
        ushort sgitg[[simdgroup_index_in_threadgroup]],
        ushort tiisg[[thread_index_in_simdgroup]],
        ushort sgptg[[simdgroups_per_threadgroup]],
        uint3 tgpg[[threadgroups_per_grid]]) {

    const int64_t i0 = tpitg.x;
    const int64_t i1 = tgpig.x;
    const int64_t ir = tgpig.y; // current head
    const int64_t i3 = tgpig.z; // current seq

    const uint64_t nb00 = sizeof(float);
    const uint64_t nb10 = sizeof(float);
    const uint64_t nb20 = sizeof(float);

    const int64_t nc = args.d_state;
    const int64_t nr = args.d_inner;
    const int64_t nh = args.n_head;
    const int64_t ng = args.n_group;
    const int64_t n_t = args.n_seq_tokens;

    const int64_t s_off = args.s_off;

    device const int32_t * ids = (device const int32_t *) src6;

    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
    device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
    const int64_t i = i0 + i1*nc;
    const int64_t g = ir / (nh / ng); // repeat_interleave
    float s0 = s0_buff[i];
    float s = s_buff[i];

    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
    device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
    device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
    device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
    device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);

    for (int64_t i2 = 0; i2 < n_t; ++i2) {
        device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns}
        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
        device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns}
        device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns}
        device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}

        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
        const float x_dt = x[0] * dt_soft_plus;
        const float dA = exp(dt_soft_plus * A[0]);

        const float state = (s0 * dA) + (B[i0] * x_dt);
        s = state;

        // Parallel sum: This relies on the fact that this kernel will be
        // dispatched with each threadgroup having (d_state, 1, 1) threads which
        // are subdivided into SIMD groups of size `sgptg`. The goal is to
        // compute y = sum({state * C[i] for i in range(d_state)}).
        // To parallelize this effectively, we first use simd_sum over each SIMD
        // group to compute the sum of each SIMD group, then place the result in
        // the SIMD group's indexed bucket in the shared memory. We then sum
        // over the individual group sums to compute the final sum.

        // Computed for each thread
        float sumf = state * C[i0];

        // Sum the threads in the simd group => simd sum
        sumf = simd_sum(sumf);

        // Once per simd group, place the group sum into the shared buffer
        if (tiisg == 0) {
            shared[sgitg] = sumf;
        }

        // Wait for all threads in the threadgroup to reach this point. This
        // ensures that all elements of the shared buffer are populated with the
        // sum of the individual simd groups.
    for (int i2 = 0; i2 < n_t; i2 += sgptg) {
        threadgroup_barrier(mem_flags::mem_threadgroup);

        // For simd group 0 at indices < num simd groups, extract the shared
        // simd sum
        sumf = 0.0f;
        if (sgitg == 0) {
            if (tiisg < sgptg) {
                sumf = shared[tiisg];
            }
            sumf = simd_sum(sumf);
        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
            const float dt0 = dt[0];
            const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
            const float x_dt = x[0] * dtsp;
            const float dA = exp(dtsp * A0);

            s = (s0 * dA) + (B[i0] * x_dt);

            const float sumf = simd_sum(s * C[i0]);

            if (tiisg == 0) {
                y[0] = sumf;
                shared[t*NW + sgitg] = sumf;
            }

            // recurse
            s0 = s;

            x += args.ns12;
            dt += args.ns21;
            B += args.ns42;
            C += args.ns52;
        }

        // recurse
        s0 = s;
        threadgroup_barrier(mem_flags::mem_threadgroup);

        const float sumf = simd_sum(shared[sgitg*NW + tiisg]);

        if (tiisg == 0 && i2 + sgitg < n_t) {
            y[sgitg*nh*nr] = sumf;
        }

        y += sgptg*nh*nr;
    }

    // Assign the final state to the output buffer
    s_buff[i] = s;
}
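
// [editor note] Illustrative sketch, not part of the patch: the two-level
// reduction both scan kernels build inline, factored into a helper. Assumes a
// threadgroup of d_state threads and a threadgroup buffer of >= sgptg floats;
// the final value is meaningful in SIMD group 0, which is where the kernels
// consume it.
//
//     static float tg_sum(float v, threadgroup float * buf,
//                         ushort tiisg, ushort sgitg, ushort sgptg) {
//         v = simd_sum(v);                  // level 1: within each SIMD group
//         if (sgptg > 1) {
//             if (tiisg == 0) {
//                 buf[sgitg] = v;           // one partial sum per SIMD group
//             }
//             threadgroup_barrier(mem_flags::mem_threadgroup);
//             v = tiisg < sgptg ? buf[tiisg] : 0.0f;
//             v = simd_sum(v);              // level 2: sum the partials
//         }
//         return v;
//     }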

@@ -5770,21 +5670,17 @@ kernel void kernel_flash_attn_ext_vec_reduce(
}

template<typename T0, typename T1>
kernel void kernel_cpy(
kernel void kernel_cpy_t_t(
        constant ggml_metal_kargs_cpy & args,
        device const char * src0,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 tptg[[threads_per_threadgroup]]) {
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;

    if (i01 >= args.ne01) {
        return;
    }
    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;

    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

@@ -5795,190 +5691,70 @@ kernel void kernel_cpy(

    device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
        device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
        dst_data[i00] = (T1) src[0];

        break;
    }
}
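
// [editor note] Dispatch-layout sketch for the i01/iw0 computation above
// (illustrative; assumes the host launches ne01*nw0 threadgroups along x in
// the ntg[1] == 1 case, where "nw0" is a hypothetical name for the number of
// per-row chunks):
//
//     ntg[1] > 1  : tgpig[0] selects a band of ntg[1] consecutive rows;
//                   tiitg/ntg[0] picks the row within the band and
//                   tiitg%ntg[0] the column; iw0 is always 0.
//     ntg[1] == 1 : tgpig[0] encodes a (row, chunk) pair:
//                   row   i01 = tgpig[0] % ne01,
//                   chunk iw0 = tgpig[0] / ne01,
//                   so rows wider than ntg[0] are split across threadgroups.
//
// With the trailing `break`, the for-loop runs at most once: each thread
// copies the single element i00 = iw0*ntg[0] + tiitg%ntg[0] if it is in
// bounds, and the loop header acts purely as the bounds check.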

typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;

template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy<int32_t, float>;
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float, bfloat>;
#endif
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half, float>;
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half, half>;
template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<half, float>;
template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t<half, half>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy<bfloat, float>;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
#endif

// TODO: templetify these kernels
kernel void kernel_cpy_f32_q8_0(
template<short QK,
         typename block_q,
         void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_cpy_f32_q(
        constant ggml_metal_kargs_cpy & args,
        device const char * src0,
        device char * dst,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    const int i01 = tgpig[0];
    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;

    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0;
    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;

    device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
    device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
        device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);

        quantize_q8_0(src, dst_data[i00/QK8_0]);
        quantize_func(src, dst_data[i00]);

        break;
    }
}

kernel void kernel_cpy_f32_q4_0(
        constant ggml_metal_kargs_cpy & args,
        device const char * src0,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    const int i01 = tgpig[0];
typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;

    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0;

    device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);

        quantize_q4_0(src, dst_data[i00/QK4_0]);
    }
}

kernel void kernel_cpy_f32_q4_1(
        constant ggml_metal_kargs_cpy & args,
        device const char * src0,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    const int i01 = tgpig[0];

    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1;

    device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);

        quantize_q4_1(src, dst_data[i00/QK4_1]);
    }
}

kernel void kernel_cpy_f32_q5_0(
        constant ggml_metal_kargs_cpy & args,
        device const char * src0,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    const int i01 = tgpig[0];

    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0;

    device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);

        quantize_q5_0(src, dst_data[i00/QK5_0]);
    }
}

kernel void kernel_cpy_f32_q5_1(
        constant ggml_metal_kargs_cpy & args,
        device const char * src0,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    const int i01 = tgpig[0];

    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1;

    device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);

        quantize_q5_1(src, dst_data[i00/QK5_1]);
    }
}

kernel void kernel_cpy_f32_iq4_nl(
        constant ggml_metal_kargs_cpy & args,
        device const char * src0,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    const int i01 = tgpig[0];

    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL;

    device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);

        quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
    }
}
template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>;
template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0, block_q4_0, quantize_q4_0>;
template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1, block_q4_1, quantize_q4_1>;
template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0, block_q5_0, quantize_q5_0>;
template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1, block_q5_1, quantize_q5_1>;
template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;

template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
kernel void kernel_cpy_q_f32(
@@ -5986,11 +5762,12 @@ kernel void kernel_cpy_q_f32(
        device const char * src0,
        device char * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort3 tpitg[[thread_position_in_threadgroup]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg[[threads_per_threadgroup]]) {
    const int i03 = tgpig[2];
    const int i02 = tgpig[1];
    const int i01 = tgpig[0];
    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;

    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;

@@ -6002,10 +5779,12 @@ kernel void kernel_cpy_q_f32(
    device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
    device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);

    for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
        T4x4 temp;
        dequantize_func(src_data + i00/nl, i00%nl, temp);
        dst_data[i00] = temp;

        break;
    }
}
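
// [editor note] In both rewritten copy kernels the trailing `break` makes the
// for-loop run at most once per thread, so the header is a bounds check
// (i00 < args.nk0) rather than iteration. args.nk0 is presumably the row
// length in per-thread work units: blocks of QK floats for kernel_cpy_f32_q
// (the old bound was ne00 stepped by QK), and 4x4 tiles of 16 elements for
// kernel_cpy_q_f32 (the old bound was ne00/16), with the host sizing the grid
// so the iw0 chunks cover each full row.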

@@ -7765,66 +7544,60 @@ kernel void kernel_mul_mv_mxfp4_f32(
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
        constant ggml_metal_kargs_get_rows & args,
        device const void * src0,
        device const void * src1,
        device float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;
        device const void * src0,
        device const void * src1,
        device void * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg [[threads_per_threadgroup]]) {
    const int32_t iw0 = tgpig.x/args.ne10;
    const int32_t i10 = tgpig.x%args.ne10;
    const int32_t i11 = tgpig.y;
    const int32_t i12 = tgpig.z;

    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
    const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];

    const int64_t i02 = i11;
    const int32_t i02 = i11;
    const int32_t i03 = i12;

    for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
    auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
    auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);

    for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
        float4x4 temp;
        dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
        *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
        dequantize_func(psrc + ind/nl, ind%nl, temp);
        pdst[ind] = temp;

        break;
    }
}

template<typename T>
template<typename T0, typename T>
kernel void kernel_get_rows_f(
        constant ggml_metal_kargs_get_rows & args,
        device const void * src0,
        device const void * src1,
        device float * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;
        device const void * src0,
        device const void * src1,
        device void * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        ushort tiitg[[thread_index_in_threadgroup]],
        ushort3 ntg [[threads_per_threadgroup]]) {
    const int32_t iw0 = tgpig.x/args.ne10;
    const int32_t i10 = tgpig.x%args.ne10;
    const int32_t i11 = tgpig.y;
    const int32_t i12 = tgpig.z;

    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
    const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];

    const int64_t i02 = i11;
    const int32_t i02 = i11;
    const int32_t i03 = i12;

    for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
        (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
        ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
    }
}
    auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
    auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);

kernel void kernel_get_rows_i32(
        constant ggml_metal_kargs_get_rows & args,
        device const void * src0,
        device const void * src1,
        device int32_t * dst,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;
    for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
        pdst[ind] = psrc[ind];

    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];

    const int64_t i02 = i11;

    for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
        (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] =
        ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind];
        break;
    }
}

@@ -8310,12 +8083,13 @@ kernel void kernel_mul_mm_id(
// get rows
//

typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;

template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float, float>;
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half, float>;
template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat, float>;
#endif

typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;