metal : various optimizations + refactoring (#16446)

* metal : ssm_scan minor opts

* metal : get_rows optimize

* metal : cpy optimize

* metal : ssm_conv opt

* metal : ssm_scan simplify

* metal : ssm_scan opt
Georgi Gerganov
2025-10-07 08:21:40 +03:00
committed by GitHub
parent 3df2244df4
commit 8ae32dc9ec
5 changed files with 258 additions and 452 deletions


@@ -338,7 +338,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type));
+    const char * suffix = "";
+
+    if (op->src[1]->ne[0] % 4 == 0) {
+        suffix = "_4";
+    }
+
+    snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix);
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
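Note: the "_4" variant selected here (defined in the .metal changes below) assumes the conv width is a multiple of 4 and accumulates four taps per iteration with a float4 dot product. A minimal CPU sketch of the equivalence this relies on, in plain C++ with example data (not the ggml API):

    #include <cassert>

    // The scalar loop and the 4-wide loop accumulate the same products in the
    // same order, so they agree whenever nc % 4 == 0, which is the condition
    // the host code above checks before selecting the "_4" kernel.
    float conv_scalar(const float * s, const float * c, int nc) {
        float sum = 0.0f;
        for (int i = 0; i < nc; ++i) sum += s[i]*c[i];
        return sum;
    }

    float conv_vec4(const float * s, const float * c, int nc) {
        float sum = 0.0f;
        for (int i = 0; i < nc/4; ++i) {
            for (int k = 0; k < 4; ++k) sum += s[4*i + k]*c[4*i + k]; // dot(s4[i], c4[i])
        }
        return sum;
    }

    int main() {
        const float s[8] = {1, 2, 3, 4, 5, 6, 7, 8};
        const float c[8] = {8, 7, 6, 5, 4, 3, 2, 1};
        assert(conv_scalar(s, c, 8) == conv_vec4(s, c, 8));
        return 0;
    }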
@@ -352,15 +358,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar
 }
 
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) {
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+
     char base[256];
     char name[256];
 
-    if (op->src[3]->ne[0] == 1) {
-        snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type));
-    } else {
-        snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
-    }
-    snprintf(name, 256, "%s", base);
+    const int nsg = (ne00 + 31)/32;
+
+    snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s_nsg=%d", base, nsg);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
 
     if (res) {
@@ -369,7 +375,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar
     res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
 
-    ggml_metal_pipeline_set_smem(res, 32*sizeof(float));
+    ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg);
 
     return res;
 }
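For illustration, the specialization and shared-memory sizing above boil down to: one 32-wide simdgroup per 32 states, and one staging float per thread. A standalone sketch with example numbers (plain C++; d_state stands in for ne00):

    #include <cstdio>

    int main() {
        const int d_state  = 128;                  // example value; ne00 in the diff
        const int nsg      = (d_state + 31)/32;    // simdgroups per threadgroup
        const size_t smem  = 32*sizeof(float)*nsg; // one float per thread

        char name[256];
        snprintf(name, sizeof(name), "kernel_ssm_scan_f32_nsg=%d", nsg);

        printf("%s, smem = %zu bytes\n", name, smem); // kernel_ssm_scan_f32_nsg=4, smem = 512 bytes
        return 0;
    }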


@@ -776,9 +776,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
             };
         }
     case GGML_OP_GET_ROWS:
-        {
-            return op->ne[3] == 1;
-        }
+        return true;
     case GGML_OP_SET_ROWS:
         {
             if (op->src[0]->type != GGML_TYPE_F32) {


@@ -178,6 +178,7 @@ typedef struct {
 } ggml_metal_kargs_clamp;
 
 typedef struct {
+    int64_t  nk0;
     int64_t  ne00;
     int64_t  ne01;
     int64_t  ne02;
@@ -572,32 +573,45 @@ typedef struct {
     int64_t  n_seq_tokens;
     int64_t  n_seqs;
     uint64_t s_off;
+    uint64_t nb00;
     uint64_t nb01;
     uint64_t nb02;
     uint64_t nb03;
+    uint64_t nb10;
     uint64_t nb11;
     uint64_t nb12;
+    uint64_t ns12;
     uint64_t nb13;
+    uint64_t nb20;
     uint64_t nb21;
+    uint64_t ns21;
     uint64_t nb22;
+    int64_t  ne30;
     uint64_t nb31;
     uint64_t nb41;
     uint64_t nb42;
+    uint64_t ns42;
     uint64_t nb43;
     uint64_t nb51;
     uint64_t nb52;
+    uint64_t ns52;
     uint64_t nb53;
+    uint64_t nb0;
 } ggml_metal_kargs_ssm_scan;
 
 typedef struct {
-    int64_t  ne00;
+    int32_t  ne00t;
+    int32_t  ne00;
     uint64_t nb01;
     uint64_t nb02;
-    int64_t  ne10;
+    uint64_t nb03;
+    int32_t  ne10;
     uint64_t nb10;
     uint64_t nb11;
+    uint64_t nb12;
     uint64_t nb1;
     uint64_t nb2;
+    uint64_t nb3;
 } ggml_metal_kargs_get_rows;
 
 typedef struct {


@@ -577,6 +577,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
 
     ggml_metal_kargs_cpy args = {
+        /*.nk0  =*/ ne00,
         /*.ne00 =*/ ne00,
         /*.ne01 =*/ ne01,
         /*.ne02 =*/ ne02,
@@ -906,23 +907,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
 
     ggml_metal_kargs_get_rows args = {
+        /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00,
         /*.ne00 =*/ ne00,
         /*.nb01 =*/ nb01,
         /*.nb02 =*/ nb02,
+        /*.nb03 =*/ nb03,
         /*.ne10 =*/ ne10,
         /*.nb10 =*/ nb10,
         /*.nb11 =*/ nb11,
+        /*.nb12 =*/ nb12,
         /*.nb1  =*/ nb1,
         /*.nb2  =*/ nb2,
+        /*.nb3  =*/ nb3,
     };
 
+    const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
+    const int nw0 = (args.ne00t + nth - 1)/nth;
+
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 3);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1);
 
     return 1;
 }
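The dispatch change splits wide rows across nw0 threadgroups along x instead of always using 32 threads per row. A sketch of the index math with example numbers, mirroring the div/mod decode the kernels use below (iw0 = tgpig.x/ne10, i10 = tgpig.x%ne10); the 1024-thread limit is an illustrative value, not a queried one:

    #include <algorithm>
    #include <cstdio>

    int main() {
        const int ne00t    = 4096; // work items per row (ne00, or ne00/16 when quantized)
        const int max_tptg = 1024; // example max threads per threadgroup
        const int ne10     = 8;    // rows to gather

        const int nth = std::min(ne00t, max_tptg);
        const int nw0 = (ne00t + nth - 1)/nth; // threadgroups cooperating per row

        // the grid is nw0*ne10 wide; each threadgroup recovers its slot:
        for (int tgpig_x = 0; tgpig_x < nw0*ne10; ++tgpig_x) {
            const int iw0 = tgpig_x/ne10; // which chunk of the row
            const int i10 = tgpig_x%ne10; // which row
            (void)iw0; (void)i10;
        }

        printf("nth = %d, nw0 = %d\n", nth, nw0); // nth = 1024, nw0 = 4
        return 0;
    }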
@@ -1172,25 +1181,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs       =*/ n_seqs,
         /*.s_off        =*/ ggml_nelements(op->src[1]) * sizeof(float),
+        /*.nb00         =*/ nb00,
         /*.nb01         =*/ nb01,
         /*.nb02         =*/ nb02,
         /*.nb03         =*/ nb03,
+        /*.nb10         =*/ nb10,
         /*.nb11         =*/ nb11,
         /*.nb12         =*/ nb12,
+        /*.ns12         =*/ nb12/nb10,
         /*.nb13         =*/ nb13,
+        /*.nb20         =*/ nb20,
         /*.nb21         =*/ nb21,
+        /*.ns21         =*/ nb21/nb20,
         /*.nb22         =*/ nb22,
+        /*.ne30         =*/ ne30,
         /*.nb31         =*/ nb31,
         /*.nb41         =*/ nb41,
         /*.nb42         =*/ nb42,
+        /*.ns42         =*/ nb42/nb40,
         /*.nb43         =*/ nb43,
         /*.nb51         =*/ nb51,
         /*.nb52         =*/ nb52,
+        /*.ns52         =*/ nb52/nb50,
         /*.nb53         =*/ nb53,
+        /*.nb0          =*/ nb0,
     };
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op);
 
+    GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+
     const size_t sms = ggml_metal_pipeline_get_smem(pipeline);
 
     ggml_metal_encoder_set_pipeline(enc, pipeline);
@@ -1206,13 +1226,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
     ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0);
 
-    if (ne30 == 1) {
-        // Mamba-2
-        ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
-    } else {
-        GGML_ASSERT(d_inner == 1);
-        ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1);
-    }
+    ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1);
 
     return 1;
 }
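The new ns* fields are element strides: each byte stride divided by the byte stride of a single element (e.g. ns12 = nb12/nb10). This lets the kernel advance its typed float pointers with plain additions (x += args.ns12) instead of recasting through char pointers every token. A small sketch of that equivalence (plain C++, contiguous example data):

    #include <cassert>
    #include <cstdint>

    int main() {
        float buf[64] = {};

        const uint64_t nb10 = sizeof(float); // byte stride between elements
        const uint64_t nb12 = 16*nb10;       // byte stride between tokens
        const uint64_t ns12 = nb12/nb10;     // element stride between tokens

        const float * x = buf;
        const float * a = (const float *)((const char *)x + 2*nb12); // byte arithmetic
        const float * b = x + 2*ns12;                                // element arithmetic
        assert(a == b);
        return 0;
    }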
@@ -1273,26 +1287,23 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0);
 
-    // TODO: support
-    //const int32_t nk00 = ne00/ggml_blck_size(op->type);
-    const int32_t nk00 = ne00;
-
-    int nth = 32; // SIMD width
-
-    while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
-        nth *= 2;
+    int64_t nk0 = ne00;
+    if (ggml_is_quantized(op->src[0]->type)) {
+        nk0 = ne00/16;
+    } else if (ggml_is_quantized(op->type)) {
+        nk0 = ne00/ggml_blck_size(op->type);
     }
 
-    nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
+    int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
 
     // when rows are small, we can batch them together in a single threadgroup
     int nrptg = 1;
 
     // TODO: relax this constraint in the future
     if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) {
-        if (nth > nk00) {
-            nrptg = (nth + nk00 - 1)/nk00;
-            nth   = nk00;
+        if (nth > nk0) {
+            nrptg = (nth + nk0 - 1)/nk0;
+            nth   = nk0;
 
             if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
                 nrptg--;
@@ -1300,10 +1311,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
         }
     }
 
-    nth = std::min(nth, nk00);
+    nth = std::min<int>(nth, nk0);
 
     ggml_metal_kargs_cpy args = {
-        /*.ne00 =*/ nk00,
+        /*.nk0  =*/ nk0,
+        /*.ne00 =*/ ne00,
         /*.ne01 =*/ ne01,
         /*.ne02 =*/ ne02,
         /*.ne03 =*/ ne03,
@@ -1321,12 +1333,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
         /*.nb3  =*/ nb3,
     };
 
+    const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
+
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 2);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1);
 
     return 1;
 }
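A sketch of the new cpy grid width with example numbers: when each row needs more threads than one threadgroup provides (nrptg == 1), nw0 threadgroups are laid out per row; when small rows are batched (nrptg > 1), the row count is divided instead (plain C++, values illustrative):

    #include <cstdio>

    int main() {
        const int ne01 = 512;  // rows
        const int nk0  = 8192; // work items per row
        const int nth  = 1024; // threads per threadgroup (example limit)

        const int nrptg = 1;   // rows per threadgroup
        const int nw0   = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;

        const int grid_x = nw0*(ne01 + nrptg - 1)/nrptg; // x dimension of the dispatch
        printf("nw0 = %d, grid_x = %d\n", nw0, grid_x);  // nw0 = 8, grid_x = 4096
        return 0;
    }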


@@ -2032,7 +2032,38 @@ kernel void kernel_ssm_conv_f32_f32(
     x[0] = sumf;
 }
 
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part
+kernel void kernel_ssm_conv_f32_f32_4(
+        constant ggml_metal_kargs_ssm_conv & args,
+        device const void * src0,
+        device const void * src1,
+        device       float * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t ir = tgpig.x;
+    const int64_t i2 = tgpig.y;
+    const int64_t i3 = tgpig.z;
+
+    const int64_t nc  = args.ne10;
+    //const int64_t ncs = args.ne00;
+    //const int64_t nr  = args.ne01;
+    //const int64_t n_t = args.ne1;
+    //const int64_t n_s = args.ne2;
+
+    device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02);
+    device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11);
+    device       float  * x = (device       float  *) ((device       char *) dst  + ir*args.nb0  + i2*args.nb1  + i3*args.nb2);
+
+    float sumf = 0.0f;
+
+    for (int64_t i0 = 0; i0 < nc/4; ++i0) {
+        sumf += dot(s[i0], c[i0]);
+    }
+
+    x[0] = sumf;
+}
+
+// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
 kernel void kernel_ssm_scan_f32(
         constant ggml_metal_kargs_ssm_scan & args,
         device const void * src0,
@@ -2045,218 +2076,87 @@ kernel void kernel_ssm_scan_f32(
         device       float * dst,
         threadgroup float * shared [[threadgroup(0)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort sgitg[[simdgroup_index_in_threadgroup]],
         ushort tiisg[[thread_index_in_simdgroup]],
         ushort sgptg[[simdgroups_per_threadgroup]],
         uint3 tgpg[[threadgroups_per_grid]]) {
+    constexpr short NW = N_SIMDWIDTH;
+
+    shared[tpitg.x] = 0.0f;
 
-    const int64_t i0 = tpitg.x;
-    const int64_t i1 = 0;
-    const int64_t ir = tgpig.x; // current head
-    const int64_t i3 = tgpig.y; // current seq
-
-    const uint64_t nb00 = sizeof(float);
-    const uint64_t nb10 = sizeof(float);
-    const uint64_t nb20 = sizeof(float);
+    const int32_t i0 = tpitg.x;
+    const int32_t i1 = tgpig.x;
+    const int32_t ir = tgpig.y; // current head
+    const int32_t i3 = tgpig.z; // current seq
 
-    const int64_t nc  = args.d_state;
-    const int64_t nr  = args.d_inner;
-    const int64_t nh  = args.n_head;
-    const int64_t ng  = args.n_group;
-    const int64_t n_t = args.n_seq_tokens;
+    const int32_t nc  = args.d_state;
+    const int32_t nr  = args.d_inner;
+    const int32_t nh  = args.n_head;
+    const int32_t ng  = args.n_group;
+    const int32_t n_t = args.n_seq_tokens;
 
-    const int64_t s_off = args.s_off;
+    const int32_t s_off = args.s_off;
 
     device const int32_t * ids = (device const int32_t *) src6;
 
     device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
     device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
 
-    const int64_t i = i0 + i1*nc;
-    const int64_t g = ir / (nh / ng); // repeat_interleave
+    const int32_t i = i0 + i1*nc;
+    const int32_t g = ir / (nh / ng); // repeat_interleave
 
     float s0 = s0_buff[i];
-    float s  = s_buff[i];
+    float s  = 0.0f;
 
-    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31);
-    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
-    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
-    device const float * B_block  = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
-    device const float * C_block  = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
-    device       float * y_block  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
-
-    for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
-        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
-
-        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
-        const float x_dt = x[0] * dt_soft_plus;
-
-        const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt);
-        s = state;
-
-        // Parallel sum: This relies on the fact that this kernel will be
-        // dispatched with each threadgroup having (d_state, 1, 1) threads which
-        // are subdivided into SIMD groups of size `sgptg`. The goal is to
-        // compute y = sum({state * C[i] for i in range(d_state)}).
-        // To parallelize this effectively, we first use simd_sum over each SIMD
-        // group to compute the sum of each SIMD group, then place the result in
-        // the SIMD group's indexed bucket in the shared memory. We then sum
-        // over the individual group sums to compute the final sum.
-
-        // Computed for each thread
-        float sumf = state * C[i0];
-
-        // Sum the threads in the simd group => simd sum
-        sumf = simd_sum(sumf);
-
-        if (sgptg > 1) {
-            // Once per simd group, place the group sum into the shared buffer
-            if (tiisg == 0) {
-                shared[sgitg] = sumf;
-            }
-
-            // Wait for all threads in the threadgroup to reach this point. This
-            // ensures that all elements of the shared buffer are populated with the
-            // sum of the individual simd groups.
-            threadgroup_barrier(mem_flags::mem_threadgroup);
-
-            // For simd group 0 at indices < num simd groups, extract the shared
-            // simd sum
-            sumf = 0.0f;
-            if (sgitg == 0) {
-                if (tiisg < sgptg) {
-                    sumf = shared[tiisg];
-                }
-                sumf = simd_sum(sumf);
-                if (tiisg == 0) {
-                    y[0] = sumf;
-                }
-            }
-        } else if (tiisg == 0) {
-            y[0] = sumf;
-        }
-
-        // recurse
-        s0 = s;
-    }
-
-    // Assign the final state to the output buffer
-    s_buff[i] = s;
-}
-
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part
-kernel void kernel_ssm_scan_group_f32(
-        constant ggml_metal_kargs_ssm_scan & args,
-        device const void * src0,
-        device const void * src1,
-        device const void * src2,
-        device const void * src3,
-        device const void * src4,
-        device const void * src5,
-        device const void * src6,
-        device       float * dst,
-        threadgroup float * shared [[threadgroup(0)]],
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        ushort sgitg[[simdgroup_index_in_threadgroup]],
-        ushort tiisg[[thread_index_in_simdgroup]],
-        ushort sgptg[[simdgroups_per_threadgroup]],
-        uint3 tgpg[[threadgroups_per_grid]]) {
-    const int64_t i0 = tpitg.x;
-    const int64_t i1 = tgpig.x;
-    const int64_t ir = tgpig.y; // current head
-    const int64_t i3 = tgpig.z; // current seq
-
-    const uint64_t nb00 = sizeof(float);
-    const uint64_t nb10 = sizeof(float);
-    const uint64_t nb20 = sizeof(float);
-
-    const int64_t nc  = args.d_state;
-    const int64_t nr  = args.d_inner;
-    const int64_t nh  = args.n_head;
-    const int64_t ng  = args.n_group;
-    const int64_t n_t = args.n_seq_tokens;
-
-    const int64_t s_off = args.s_off;
-
-    device const int32_t * ids = (device const int32_t *) src6;
-
-    device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
-    device       float * s_buff  = (device       float *) ((device       char *) dst  + ir*args.nb02 + i3*args.nb03 + s_off);
-
-    const int64_t i = i0 + i1*nc;
-    const int64_t g = ir / (nh / ng); // repeat_interleave
-
-    float s0 = s0_buff[i];
-    float s  = s_buff[i];
-
-    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
-    device const float * x_block  = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13);
-    device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22);
-    device const float * B_block  = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43);
-    device const float * C_block  = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53);
-    device       float * y_block  = (device       float *) ((device       char *) dst  + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00);
-
-    for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * x  = (device const float *) ((device const char *) x_block  + i2*args.nb12); // {dim, nh, nt, ns}
-        device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns}
-        device const float * B  = (device const float *) ((device const char *) B_block  + i2*args.nb42); // {d_state, ng, nt, ns}
-        device const float * C  = (device const float *) ((device const char *) C_block  + i2*args.nb52); // {d_state, ng, nt, ns}
-        device       float * y  = (device       float *) ((device       char *) y_block  + i2*(nh*nr*nb00)); // {dim, nh, nt, ns}
-
-        const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
-        const float x_dt = x[0] * dt_soft_plus;
-        const float dA = exp(dt_soft_plus * A[0]);
-
-        const float state = (s0 * dA) + (B[i0] * x_dt);
-        s = state;
-
-        // Parallel sum: This relies on the fact that this kernel will be
-        // dispatched with each threadgroup having (d_state, 1, 1) threads which
-        // are subdivided into SIMD groups of size `sgptg`. The goal is to
-        // compute y = sum({state * C[i] for i in range(d_state)}).
-        // To parallelize this effectively, we first use simd_sum over each SIMD
-        // group to compute the sum of each SIMD group, then place the result in
-        // the SIMD group's indexed bucket in the shared memory. We then sum
-        // over the individual group sums to compute the final sum.
-
-        // Computed for each thread
-        float sumf = state * C[i0];
-
-        // Sum the threads in the simd group => simd sum
-        sumf = simd_sum(sumf);
-
-        // Once per simd group, place the group sum into the shared buffer
-        if (tiisg == 0) {
-            shared[sgitg] = sumf;
-        }
-
-        // Wait for all threads in the threadgroup to reach this point. This
-        // ensures that all elements of the shared buffer are populated with the
-        // sum of the individual simd groups.
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-
-        // For simd group 0 at indices < num simd groups, extract the shared
-        // simd sum
-        sumf = 0.0f;
-        if (sgitg == 0) {
-            if (tiisg < sgptg) {
-                sumf = shared[tiisg];
-            }
-            sumf = simd_sum(sumf);
-            if (tiisg == 0) {
-                y[0] = sumf;
-            }
-        }
-
-        // recurse
-        s0 = s;
-    }
-
-    // Assign the final state to the output buffer
+    device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh}
+
+    const float A0 = A[i0%args.ne30];
+
+    device const float * x  = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns}
+    device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22);                // {nh, nt, ns}
+    device const float * B  = (device const float *)((device const char *) src4 + g*args.nb41  + i3*args.nb43);                // {d_state, ng, nt, ns}
+    device const float * C  = (device const float *)((device const char *) src5 + g*args.nb51  + i3*args.nb53);                // {d_state, ng, nt, ns}
+    device       float * y  = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr));                                                           // {dim, nh, nt, ns}
+
+    for (int i2 = 0; i2 < n_t; i2 += sgptg) {
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        for (int t = 0; t < sgptg && i2 + t < n_t; t++) {
+            const float dt0  = dt[0];
+            const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0;
+            const float x_dt = x[0] * dtsp;
+            const float dA   = exp(dtsp * A0);
+
+            s = (s0 * dA) + (B[i0] * x_dt);
+
+            const float sumf = simd_sum(s * C[i0]);
+
+            if (tiisg == 0) {
+                shared[t*NW + sgitg] = sumf;
+            }
+
+            // recurse
+            s0 = s;
+
+            x  += args.ns12;
+            dt += args.ns21;
+            B  += args.ns42;
+            C  += args.ns52;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        const float sumf = simd_sum(shared[sgitg*NW + tiisg]);
+
+        if (tiisg == 0 && i2 + sgitg < n_t) {
+            y[sgitg*nh*nr] = sumf;
+        }
+
+        y += sgptg*nh*nr;
+    }
 
     s_buff[i] = s;
 }
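The rewritten scan covers sgptg tokens per outer iteration, one simdgroup per token: each simdgroup reduces its state*C partial products with simd_sum, stages one value per (token, simdgroup) pair in shared[t*NW + sgitg], and after the barrier re-reads the staging area transposed (shared[sgitg*NW + tiisg]) so a second simd_sum finishes each token. A CPU sketch of that two-level reduction shape (plain C++; NW = 32 stands in for the simd width):

    #include <cassert>
    #include <vector>

    int main() {
        const int NW = 32, nsg = 4, d_state = NW*nsg;

        std::vector<float> vals(d_state, 1.0f);  // the state*C products (example)
        std::vector<float> shared(nsg*NW, 0.0f); // zero-initialized, as in the kernel

        const int t = 0; // token slot within the current batch of sgptg tokens
        for (int sg = 0; sg < nsg; ++sg) {       // level 1: per-simdgroup simd_sum
            float sum = 0.0f;
            for (int lane = 0; lane < NW; ++lane) sum += vals[sg*NW + lane];
            shared[t*NW + sg] = sum;             // stage one value per simdgroup
        }

        // level 2: the simdgroup owning token slot t re-reads its staged row
        float y = 0.0f;
        for (int lane = 0; lane < NW; ++lane) y += shared[t*NW + lane];

        assert(y == float(d_state)); // 128 ones sum to 128
        return 0;
    }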
@@ -5770,21 +5670,17 @@ kernel void kernel_flash_attn_ext_vec_reduce(
 }
 
 template<typename T0, typename T1>
-kernel void kernel_cpy(
+kernel void kernel_cpy_t_t(
         constant ggml_metal_kargs_cpy & args,
         device const char * src0,
         device       char * dst,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiitg[[thread_index_in_threadgroup]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3 tptg[[threads_per_threadgroup]]) {
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort3 ntg[[threads_per_threadgroup]]) {
     const int i03 = tgpig[2];
     const int i02 = tgpig[1];
-    const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x;
-
-    if (i01 >= args.ne01) {
-        return;
-    }
+    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
+    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
 
     const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
@@ -5795,190 +5691,70 @@ kernel void kernel_cpy(
     device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 
-    for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) {
+    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) {
         device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
         dst_data[i00] = (T1) src[0];
+
+        break;
     }
 }
 
-typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
+typedef decltype(kernel_cpy_t_t<float, float>) kernel_cpy_t;
 
-template [[host_name("kernel_cpy_f32_f32")]]  kernel kernel_cpy_t kernel_cpy<float,  float>;
-template [[host_name("kernel_cpy_f32_f16")]]  kernel kernel_cpy_t kernel_cpy<float,  half>;
-template [[host_name("kernel_cpy_f32_i32")]]  kernel kernel_cpy_t kernel_cpy<float,  int32_t>;
-template [[host_name("kernel_cpy_i32_f32")]]  kernel kernel_cpy_t kernel_cpy<int32_t, float>;
+template [[host_name("kernel_cpy_f32_f32")]]  kernel kernel_cpy_t kernel_cpy_t_t<float,  float>;
+template [[host_name("kernel_cpy_f32_f16")]]  kernel kernel_cpy_t kernel_cpy_t_t<float,  half>;
+template [[host_name("kernel_cpy_f32_i32")]]  kernel kernel_cpy_t kernel_cpy_t_t<float,  int32_t>;
+template [[host_name("kernel_cpy_i32_f32")]]  kernel kernel_cpy_t kernel_cpy_t_t<int32_t, float>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float,  bfloat>;
+template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<float,  bfloat>;
 #endif
-template [[host_name("kernel_cpy_f16_f32")]]  kernel kernel_cpy_t kernel_cpy<half,   float>;
-template [[host_name("kernel_cpy_f16_f16")]]  kernel kernel_cpy_t kernel_cpy<half,   half>;
+template [[host_name("kernel_cpy_f16_f32")]]  kernel kernel_cpy_t kernel_cpy_t_t<half,   float>;
+template [[host_name("kernel_cpy_f16_f16")]]  kernel kernel_cpy_t kernel_cpy_t_t<half,   half>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_cpy_bf16_f32")]]  kernel kernel_cpy_t kernel_cpy<bfloat, float>;
-template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
+template [[host_name("kernel_cpy_bf16_f32")]]  kernel kernel_cpy_t kernel_cpy_t_t<bfloat, float>;
+template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t<bfloat, bfloat>;
 #endif
 
-// TODO: templetify these kernels
-kernel void kernel_cpy_f32_q8_0(
-        constant ggml_metal_kargs_cpy & args,
-        device const char * src0,
-        device       char * dst,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3 ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig[2];
-    const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
-
-    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
-
-    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
-    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
-    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
-    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0;
-
-    device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
-
-    for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
-        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
-
-        quantize_q8_0(src, dst_data[i00/QK8_0]);
-    }
-}
-
-kernel void kernel_cpy_f32_q4_0(
-        constant ggml_metal_kargs_cpy & args,
-        device const char * src0,
-        device       char * dst,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3 ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig[2];
-    const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
-
-    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
-
-    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
-    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
-    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
-    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0;
-
-    device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
-
-    for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
-        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
-
-        quantize_q4_0(src, dst_data[i00/QK4_0]);
-    }
-}
-
-kernel void kernel_cpy_f32_q4_1(
-        constant ggml_metal_kargs_cpy & args,
-        device const char * src0,
-        device       char * dst,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3 ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig[2];
-    const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
-
-    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
-
-    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
-    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
-    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
-    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1;
-
-    device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
-
-    for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
-        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
-
-        quantize_q4_1(src, dst_data[i00/QK4_1]);
-    }
-}
-
-kernel void kernel_cpy_f32_q5_0(
-        constant ggml_metal_kargs_cpy & args,
-        device const char * src0,
-        device       char * dst,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3 ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig[2];
-    const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
-
-    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
-
-    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
-    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
-    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
-    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0;
-
-    device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
-
-    for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
-        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
-
-        quantize_q5_0(src, dst_data[i00/QK5_0]);
-    }
-}
-
-kernel void kernel_cpy_f32_q5_1(
-        constant ggml_metal_kargs_cpy & args,
-        device const char * src0,
-        device       char * dst,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3 ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig[2];
-    const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
-
-    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
-
-    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
-    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
-    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
-    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1;
-
-    device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
-
-    for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
-        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
-
-        quantize_q5_1(src, dst_data[i00/QK5_1]);
-    }
-}
-
-kernel void kernel_cpy_f32_iq4_nl(
-        constant ggml_metal_kargs_cpy & args,
-        device const char * src0,
-        device       char * dst,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
-        ushort3 ntg[[threads_per_threadgroup]]) {
-    const int i03 = tgpig[2];
-    const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
-
-    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
-
-    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
-    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
-    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
-    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL;
-
-    device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
-
-    for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
-        device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
-
-        quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
-    }
-}
+template<short QK,
+         typename block_q,
+         void (*quantize_func)(device const float *, device block_q &)>
+kernel void kernel_cpy_f32_q(
+        constant ggml_metal_kargs_cpy & args,
+        device const char * src0,
+        device       char * dst,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort3 ntg[[threads_per_threadgroup]]) {
+    const int i03 = tgpig[2];
+    const int i02 = tgpig[1];
+    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
+    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
+
+    const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
+
+    const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+    const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+    const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+    const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK;
+
+    device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
+
+    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
+        device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00);
+
+        quantize_func(src, dst_data[i00]);
+
+        break;
+    }
+}
+
+typedef decltype(kernel_cpy_f32_q<QK8_0, block_q8_0, quantize_q8_0>) cpy_f_q_t;
+
+template [[host_name("kernel_cpy_f32_q8_0")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK8_0,  block_q8_0,   quantize_q8_0>;
+template [[host_name("kernel_cpy_f32_q4_0")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK4_0,  block_q4_0,   quantize_q4_0>;
+template [[host_name("kernel_cpy_f32_q4_1")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK4_1,  block_q4_1,   quantize_q4_1>;
+template [[host_name("kernel_cpy_f32_q5_0")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK5_0,  block_q5_0,   quantize_q5_0>;
+template [[host_name("kernel_cpy_f32_q5_1")]]   kernel cpy_f_q_t kernel_cpy_f32_q<QK5_1,  block_q5_1,   quantize_q5_1>;
+template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q<QK4_NL, block_iq4_nl, quantize_iq4_nl>;
 
 template<typename T4x4, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
 kernel void kernel_cpy_q_f32(
@@ -5986,11 +5762,12 @@ kernel void kernel_cpy_q_f32(
         constant ggml_metal_kargs_cpy & args,
         device const char * src0,
         device       char * dst,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort tiitg[[thread_index_in_threadgroup]],
         ushort3 ntg[[threads_per_threadgroup]]) {
     const int i03 = tgpig[2];
     const int i02 = tgpig[1];
-    const int i01 = tgpig[0];
+    const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0];
+    const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0;
 
     const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
@@ -6002,10 +5779,12 @@ kernel void kernel_cpy_q_f32(
     device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01);
     device       T4x4   * dst_data  = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
 
-    for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) {
+    for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) {
         T4x4 temp;
         dequantize_func(src_data + i00/nl, i00%nl, temp);
         dst_data[i00] = temp;
+
+        break;
     }
 }
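The six hand-written kernel_cpy_f32_q* kernels above collapse into one template whose quantizer is a non-type template parameter, so each instantiation inlines its own conversion. A rough C++ analogy of the mechanism (the block types and quantizers here are illustrative stand-ins, not the ggml definitions):

    #include <cstdio>

    struct blk8 { float d; signed   char qs[32]; }; // stands in for block_q8_0
    struct blk4 { float d; unsigned char qs[16]; }; // stands in for block_q4_0

    void quant8(const float * src, blk8 & dst) { dst.d = src[0]; /* ... */ }
    void quant4(const float * src, blk4 & dst) { dst.d = src[0]; /* ... */ }

    // one template replaces N near-identical functions
    template<short QK, typename Block, void (*quantize)(const float *, Block &)>
    void cpy_row_q(const float * src, Block * dst, int nk0) {
        for (int i = 0; i < nk0; ++i) {
            quantize(src + i*QK, dst[i]); // one block per work item
        }
    }

    int main() {
        float row[64] = {1.0f};
        blk8 out8[2]; blk4 out4[2];
        cpy_row_q<32, blk8, quant8>(row, out8, 2);
        cpy_row_q<32, blk4, quant4>(row, out4, 2);
        printf("%g %g\n", out8[0].d, out4[0].d);
        return 0;
    }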
@@ -7767,64 +7546,58 @@ kernel void kernel_get_rows_q(
         constant ggml_metal_kargs_get_rows & args,
         device const void * src0,
         device const void * src1,
-        device       float * dst,
+        device        void * dst,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiitg[[thread_index_in_threadgroup]],
-        uint3 tptg [[threads_per_threadgroup]]) {
-    const int64_t i10 = tgpig.x;
-    const int64_t i11 = tgpig.y;
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort3 ntg [[threads_per_threadgroup]]) {
+    const int32_t iw0 = tgpig.x/args.ne10;
+    const int32_t i10 = tgpig.x%args.ne10;
+    const int32_t i11 = tgpig.y;
+    const int32_t i12 = tgpig.z;
 
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
+    const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
 
-    const int64_t i02 = i11;
+    const int32_t i02 = i11;
+    const int32_t i03 = i12;
 
-    for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) {
+    auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
+    auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1);
+
+    for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
         float4x4 temp;
-        dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp);
-        *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp;
+        dequantize_func(psrc + ind/nl, ind%nl, temp);
+        pdst[ind] = temp;
+
+        break;
     }
 }
 
-template<typename T>
+template<typename T0, typename T>
 kernel void kernel_get_rows_f(
         constant ggml_metal_kargs_get_rows & args,
         device const void * src0,
         device const void * src1,
-        device       float * dst,
+        device        void * dst,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiitg[[thread_index_in_threadgroup]],
-        uint3 tptg [[threads_per_threadgroup]]) {
-    const int64_t i10 = tgpig.x;
-    const int64_t i11 = tgpig.y;
+        ushort tiitg[[thread_index_in_threadgroup]],
+        ushort3 ntg [[threads_per_threadgroup]]) {
+    const int32_t iw0 = tgpig.x/args.ne10;
+    const int32_t i10 = tgpig.x%args.ne10;
+    const int32_t i11 = tgpig.y;
+    const int32_t i12 = tgpig.z;
 
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
+    const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0];
 
-    const int64_t i02 = i11;
+    const int32_t i02 = i11;
+    const int32_t i03 = i12;
 
-    for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
-        ((      device float *) ((      device char *) dst  + i11*args.nb2  + i10*args.nb1))[ind] =
-        ((const device T     *) ((const device char *) src0 + i02*args.nb02 +   r*args.nb01))[ind];
-    }
-}
+    auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01);
+    auto pdst = (      device T  *) ((      device char *) dst  + i12*args.nb3  + i11*args.nb2  + i10*args.nb1);
 
-kernel void kernel_get_rows_i32(
-        constant ggml_metal_kargs_get_rows & args,
-        device const void * src0,
-        device const void * src1,
-        device int32_t * dst,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiitg[[thread_index_in_threadgroup]],
-        uint3 tptg [[threads_per_threadgroup]]) {
-    const int64_t i10 = tgpig.x;
-    const int64_t i11 = tgpig.y;
-
-    const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0];
-
-    const int64_t i02 = i11;
-
-    for (int ind = tiitg; ind < args.ne00; ind += tptg.x) {
-        ((      device int32_t *) ((      device char *) dst  + i11*args.nb2  + i10*args.nb1))[ind] =
-        ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 +   r*args.nb01))[ind];
+    for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) {
+        pdst[ind] = psrc[ind];
+
+        break;
     }
 }
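A recurring pattern in this commit: the old strided loops (for (ind = tiitg; ind < n; ind += tptg.x)) become single-iteration loops ending in break, because the grid is now sized so each thread owns at most one work item; the loop condition survives only as the bounds check for the last, partially filled threadgroup. A sketch of the equivalence when enough threads are launched to cover all n items (plain C++):

    #include <cassert>

    void strided(bool * hit, int n, int tid, int nthreads) {
        for (int i = tid; i < n; i += nthreads) hit[i] = true;
    }

    void single_iter(bool * hit, int n, int tid) {
        for (int i = tid; i < n;) { hit[i] = true; break; } // i.e. if (tid < n)
    }

    int main() {
        const int n = 100, nthreads = 128; // 128 threads cover 100 items
        bool a[100] = {}, b[100] = {};
        for (int tid = 0; tid < nthreads; ++tid) {
            strided(a, n, tid, nthreads);
            single_iter(b, n, tid);
        }
        for (int i = 0; i < n; ++i) assert(a[i] == b[i]);
        return 0;
    }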
@@ -8310,12 +8083,13 @@ kernel void kernel_mul_mm_id(
 // get rows
 //
 
-typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;
+typedef decltype(kernel_get_rows_f<float, float>) get_rows_f_t;
 
-template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_f_t kernel_get_rows_f<float>;
-template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_f_t kernel_get_rows_f<half>;
+template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_f_t kernel_get_rows_f<float,   float>;
+template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_f_t kernel_get_rows_f<half,    float>;
+template [[host_name("kernel_get_rows_i32")]]  kernel get_rows_f_t kernel_get_rows_f<int32_t, int32_t>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat>;
+template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f<bfloat,  float>;
 #endif
 
 typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t; typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;