diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp
index e61b097833..0eefc0b137 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-device.cpp
@@ -318,6 +318,44 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_librar
     return res;
 }
 
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk(ggml_metal_library_t lib, const ggml_tensor * op) {
+    GGML_ASSERT(op->op == GGML_OP_CUMSUM);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_cumsum_blk_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add(ggml_metal_library_t lib, const ggml_tensor * op) {
+    GGML_ASSERT(op->op == GGML_OP_CUMSUM);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_cumsum_add_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}
+
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max(ggml_metal_library_t lib, const ggml_tensor * op) {
     GGML_ASSERT(!op->src[1] || op->src[1]->type == GGML_TYPE_F16 || op->src[1]->type == GGML_TYPE_F32);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h
index 5539abda33..39ee6e3427 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.h
+++ b/ggml/src/ggml-metal/ggml-metal-device.h
@@ -113,6 +113,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary          (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows       (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_blk     (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cumsum_add     (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max       (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv       (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan       (ggml_metal_library_t lib, const struct ggml_tensor * op);
diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m
index 741b1a44db..e073d7af16 100644
--- a/ggml/src/ggml-metal/ggml-metal-device.m
+++ b/ggml/src/ggml-metal/ggml-metal-device.m
@@ -870,6 +870,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
         case GGML_OP_SUM:
             return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
         case GGML_OP_SUM_ROWS:
+        case GGML_OP_CUMSUM:
         case GGML_OP_MEAN:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_GROUP_NORM:
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
index dd889cd90d..0fae97029f 100644
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
@@ -612,6 +612,45 @@ typedef struct {
     uint64_t nb3;
 } ggml_metal_kargs_sum_rows;
 
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  net0;
+    int64_t  net1;
+    int64_t  net2;
+    int64_t  net3;
+    uint64_t nbt0;
+    uint64_t nbt1;
+    uint64_t nbt2;
+    uint64_t nbt3;
+    bool     outb;
+} ggml_metal_kargs_cumsum_blk;
+
+typedef struct {
+    int64_t  ne00;
+    int64_t  ne01;
+    int64_t  ne02;
+    int64_t  ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t  net0;
+    int64_t  net1;
+    int64_t  net2;
+    int64_t  net3;
+    uint64_t nbt0;
+    uint64_t nbt1;
+    uint64_t nbt2;
+    uint64_t nbt3;
+} ggml_metal_kargs_cumsum_add;
+
 typedef struct {
     int32_t ne00;
     int32_t ne01;
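
Note on the two structs above: ne00..nb03 describe the source row (shape and byte strides), net0..nbt3 describe a temporary tensor that holds one partial sum per block of scanned elements, and outb tells the block kernel whether to emit those block totals at all. The following is a minimal CPU sketch (not part of the patch; toy block size, illustrative file name) of the three-dispatch scheme the kernels below implement:

// cumsum_ref.cpp - reference sketch of the blocked cumsum for a single row.
// Pass 1 = kernel_cumsum_blk on src, pass 2 = kernel_cumsum_blk on the block
// totals, pass 3 = kernel_cumsum_add.
#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
    const int nth  = 4;                     // threads per threadgroup (toy value)
    const int ne00 = 10;                    // row length
    const int net0 = (ne00 + nth - 1)/nth;  // number of blocks per row

    std::vector<float> src(ne00, 1.0f);     // cumsum of all ones -> 1, 2, 3, ...
    std::vector<float> dst(ne00);
    std::vector<float> tmp(net0);           // per-block totals (the "tmp" buffer)

    // pass 1: inclusive scan inside each block; the last thread of each block
    // writes the block total to tmp (only when args.outb is set)
    for (int ib = 0; ib < net0; ++ib) {
        float s = 0.0f;
        for (int i = ib*nth; i < std::min((ib + 1)*nth, ne00); ++i) {
            s += src[i];
            dst[i] = s;
        }
        tmp[ib] = s;
    }

    // pass 2: the same block kernel scans tmp in place (a single threadgroup
    // suffices because ne00 <= nth*nth is asserted on the host side)
    for (int ib = 1; ib < net0; ++ib) {
        tmp[ib] += tmp[ib - 1];
    }

    // pass 3: every block except the first adds the running total of all
    // preceding blocks
    for (int ib = 1; ib < net0; ++ib) {
        for (int i = ib*nth; i < std::min((ib + 1)*nth, ne00); ++i) {
            dst[i] += tmp[ib - 1];
        }
    }

    for (int i = 0; i < ne00; ++i) {
        printf("%g ", dst[i]);              // expected: 1 2 3 4 5 6 7 8 9 10
    }
    printf("\n");
}
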
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 0c1714fdbc..973fc9e747 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -311,6 +311,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_sum_rows(ctx, idx);
             } break;
+        case GGML_OP_CUMSUM:
+            {
+                n_fuse = ggml_metal_op_cumsum(ctx, idx);
+            } break;
         case GGML_OP_SOFT_MAX:
             {
                 n_fuse = ggml_metal_op_soft_max(ctx, idx);
@@ -539,7 +543,7 @@ int ggml_metal_op_repeat(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_repeat(lib, op->type);
 
@@ -585,7 +589,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32);
     GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
@@ -694,7 +698,7 @@ int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     float scale;
     float bias;
@@ -733,7 +737,7 @@ int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     float min;
     float max;
@@ -772,7 +776,7 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     int64_t n = ggml_nelements(op);
 
@@ -802,7 +806,7 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     if (op->src[1]) {
         GGML_ASSERT(ggml_are_same_shape(op->src[0], op->src[1]));
@@ -834,18 +838,6 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
 
     const int32_t nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00/2);
 
-    //[encoder setComputePipelineState:pipeline];
-    //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-    //if (src1) {
-    //    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-    //} else {
-    //    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-    //}
-    //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-    //[encoder setBytes:&args length:sizeof(args) atIndex:3];
-
-    //[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -907,7 +899,7 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_kargs_sum_rows args = {
         /*.ne00 =*/ ne00,
@@ -941,14 +933,6 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
 
     const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
 
-    //[encoder setComputePipelineState:pipeline];
-    //[encoder setBytes:&args length:sizeof(args) atIndex:0];
-    //[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
-    //[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-    //[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
-
-    //[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -961,6 +945,149 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
     return 1;
 }
 
+int ggml_metal_op_cumsum(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
+
+    ggml_metal_pipeline_t pipeline_blk = ggml_metal_library_get_pipeline_cumsum_blk(lib, op);
+
+    int nth = 1;
+    while (nth < ne00 && 2*nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline_blk)) {
+        nth *= 2;
+    }
+
+    GGML_ASSERT(ne00 <= nth*nth);
+
+    const int64_t net0 = (ne00 + nth - 1)/nth;
+    const int64_t net1 = ne01;
+    const int64_t net2 = ne02;
+    const int64_t net3 = ne03;
+
+    const uint64_t nbt0 = sizeof(float);
+    const uint64_t nbt1 = net0*nbt0;
+    const uint64_t nbt2 = net1*nbt1;
+    const uint64_t nbt3 = net2*nbt2;
+
+    const size_t smem = GGML_PAD(32*sizeof(float), 16);
+
+    ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
+    ggml_metal_buffer_id bid_dst  = ggml_metal_get_buffer_id(op);
+
+    ggml_metal_buffer_id bid_tmp = bid_dst;
+    bid_tmp.offs += ggml_nbytes(op);
+
+    {
+        ggml_metal_kargs_cumsum_blk args = {
+            /*.ne00 =*/ ne00,
+            /*.ne01 =*/ ne01,
+            /*.ne02 =*/ ne02,
+            /*.ne03 =*/ ne03,
+            /*.nb00 =*/ nb00,
+            /*.nb01 =*/ nb01,
+            /*.nb02 =*/ nb02,
+            /*.nb03 =*/ nb03,
+            /*.net0 =*/ net0,
+            /*.net1 =*/ net1,
+            /*.net2 =*/ net2,
+            /*.net3 =*/ net3,
+            /*.nbt0 =*/ nbt0,
+            /*.nbt1 =*/ nbt1,
+            /*.nbt2 =*/ nbt2,
+            /*.nbt3 =*/ nbt3,
+            /*.outb =*/ ne00 > nth,
+        };
+
+        ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
+        ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+        ggml_metal_encoder_set_buffer  (enc, bid_src0, 1);
+        ggml_metal_encoder_set_buffer  (enc, bid_tmp,  2);
+        ggml_metal_encoder_set_buffer  (enc, bid_dst,  3);
+
+        ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+        ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
+    }
+
+    if (ne00 > nth) {
+        ggml_metal_op_concurrency_reset(ctx);
+
+        {
+            ggml_metal_kargs_cumsum_blk args = {
+                /*.ne00 =*/ net0,
+                /*.ne01 =*/ net1,
+                /*.ne02 =*/ net2,
+                /*.ne03 =*/ net3,
+                /*.nb00 =*/ nbt0,
+                /*.nb01 =*/ nbt1,
+                /*.nb02 =*/ nbt2,
+                /*.nb03 =*/ nbt3,
+                /*.net0 =*/ net0,
+                /*.net1 =*/ net1,
+                /*.net2 =*/ net2,
+                /*.net3 =*/ net3,
+                /*.nbt0 =*/ nbt0,
+                /*.nbt1 =*/ nbt1,
+                /*.nbt2 =*/ nbt2,
+                /*.nbt3 =*/ nbt3,
+                /*.outb =*/ false,
+            };
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline_blk);
+            ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 1);
+            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 2);
+            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 3);
+
+            ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, net1, net2, net3, nth, 1, 1);
+        }
+
+        ggml_metal_op_concurrency_reset(ctx);
+
+        {
+            ggml_metal_pipeline_t pipeline_add = ggml_metal_library_get_pipeline_cumsum_add(lib, op);
+
+            ggml_metal_kargs_cumsum_add args = {
+                /*.ne00 =*/ ne00,
+                /*.ne01 =*/ ne01,
+                /*.ne02 =*/ ne02,
+                /*.ne03 =*/ ne03,
+                /*.nb00 =*/ nb00,
+                /*.nb01 =*/ nb01,
+                /*.nb02 =*/ nb02,
+                /*.nb03 =*/ nb03,
+                /*.net0 =*/ net0,
+                /*.net1 =*/ net1,
+                /*.net2 =*/ net2,
+                /*.net3 =*/ net3,
+                /*.nbt0 =*/ nbt0,
+                /*.nbt1 =*/ nbt1,
+                /*.nbt2 =*/ nbt2,
+                /*.nbt3 =*/ nbt3,
+            };
+
+            ggml_metal_encoder_set_pipeline(enc, pipeline_add);
+            ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
+            ggml_metal_encoder_set_buffer  (enc, bid_tmp, 1);
+            ggml_metal_encoder_set_buffer  (enc, bid_dst, 2);
+
+            ggml_metal_encoder_dispatch_threadgroups(enc, net0*ne01, ne02, ne03, nth, 1, 1);
+        }
+    }
+
+    return 1;
+}
+
 int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
     ggml_tensor * op = ctx->node(idx);
 
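
Note: bid_tmp above is carved out of the destination allocation, directly past ggml_nbytes(op); that is what the GGML_OP_CUMSUM case added to ggml_backend_metal_buffer_type_get_alloc_size (res *= 2) pays for. The assert ne00 <= nth*nth guarantees that one second-level pass suffices: a row produces net0 = ceil(ne00/nth) <= nth block totals, which a single threadgroup can scan. A toy check of that arithmetic (not part of the patch), assuming the pipeline reports a maximum of 1024 threads per threadgroup; the real value comes from the compiled pipeline:

// dispatch_check.cpp - verify the nth/net0 computation for a few row lengths
#include <cassert>
#include <cstdio>
#include <initializer_list>

int main() {
    const int max_tg = 1024; // assumed max threads per threadgroup

    for (int ne00 : {10, 1024, 2047, 201*1204}) {
        int nth = 1;
        while (nth < ne00 && 2*nth <= max_tg) {
            nth *= 2;
        }
        const int net0 = (ne00 + nth - 1)/nth;

        assert(ne00 <= nth*nth); // net0 <= nth: one scan of the totals suffices
        printf("ne00 = %6d -> nth = %4d, blocks = %3d, second pass: %s\n",
                ne00, nth, net0, net0 > 1 ? "yes" : "no");
    }
}
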
@@ -972,7 +1099,7 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type);
 
@@ -1017,7 +1144,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
 
@@ -1081,7 +1208,7 @@ int ggml_metal_op_soft_max(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     float scale;
     float max_bias;
@@ -1169,7 +1296,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_kargs_ssm_conv args = {
         /*.ne00 =*/ ne00,
@@ -1224,7 +1351,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne6, op->src[6], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb6, op->src[6], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const ggml_tensor * src3 = op->src[3];
     const ggml_tensor * src4 = op->src[4];
@@ -1310,7 +1437,7 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const int64_t B = op->op == GGML_OP_RWKV_WKV6 ? op->src[5]->ne[1] : op->src[6]->ne[1];
     const int64_t T = op->src[0]->ne[2];
@@ -1351,7 +1478,7 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
 
@@ -1424,7 +1551,7 @@ int ggml_metal_op_pool_2d(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const int32_t * opts = op->op_params;
     ggml_op_pool op_pool = (ggml_op_pool) opts[0];
@@ -1488,7 +1615,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     GGML_ASSERT(ne00 == ne10);
 
@@ -1729,7 +1856,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     // src2 = ids
     GGML_ASSERT(op->src[2]->type == GGML_TYPE_I32);
@@ -2689,7 +2816,7 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     float eps;
     memcpy(&eps, op->op_params, sizeof(float));
@@ -2737,7 +2864,7 @@ int ggml_metal_op_group_norm(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const int32_t ngrp = ((const int32_t *) op->op_params)[0];
 
@@ -2792,7 +2919,7 @@ int ggml_metal_op_norm(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     float eps;
     memcpy(&eps, op->op_params, sizeof(float));
@@ -2928,7 +3055,7 @@ int ggml_metal_op_rope(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     // make sure we have one or more position id(ne10) per token(ne02)
     GGML_ASSERT(ne10 % ne02 == 0);
@@ -3022,7 +3149,7 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const int32_t s0 = ((const int32_t *)(op->op_params))[0];
     const int32_t s1 = ((const int32_t *)(op->op_params))[1];
@@ -3172,7 +3299,7 @@ int ggml_metal_op_conv_transpose_1d(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const int32_t s0 = ((const int32_t *)(op->op_params))[0];
 
@@ -3217,7 +3344,7 @@ int ggml_metal_op_conv_transpose_2d(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const int32_t s0 = ((const int32_t *)(op->op_params))[0];
 
@@ -3271,7 +3398,7 @@ int ggml_metal_op_upscale(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const float sf0 = (float)ne0/op->src[0]->ne[0];
     const float sf1 = (float)ne1/op->src[0]->ne[1];
@@ -3324,7 +3451,7 @@ int ggml_metal_op_pad(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_kargs_pad args = {
         /*.ne00 =*/ ne00,
@@ -3368,7 +3495,7 @@ int ggml_metal_op_pad_reflect_1d(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_kargs_pad_reflect_1d args = {
         /*.ne00 =*/ ne00,
@@ -3412,7 +3539,7 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
     ggml_metal_encoder_t enc = ctx->enc;
 
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     float start;
     float step;
@@ -3430,12 +3557,6 @@ int ggml_metal_op_arange(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_arange(lib, op);
 
-    //[encoder setComputePipelineState:pipeline];
-    //[encoder setBuffer:id_dst offset:offs_dst atIndex:0];
-    //[encoder setBytes:&args length:sizeof(args) atIndex:1];
-
-    //[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-
     ggml_metal_encoder_set_pipeline(enc, pipeline);
     ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), 0);
     ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op), 1);
@@ -3454,7 +3575,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     const int dim        = op->op_params[0];
     const int max_period = op->op_params[1];
@@ -3488,7 +3609,7 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_kargs_argmax args = {
         /*.ne00 = */ ne00,
@@ -3529,7 +3650,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_argsort(lib, op);
 
@@ -3539,7 +3660,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
         nth *= 2;
     }
 
-    const int nptg = (ne00 + nth - 1)/nth;
+    const int npr = (ne00 + nth - 1)/nth;
 
     // Metal kernels require the buffer size to be multiple of 16 bytes
     // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
@@ -3551,7 +3672,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
     ggml_metal_buffer_id bid_tmp = bid_dst;
     bid_tmp.offs += ggml_nbytes(op);
 
-    if ((int) ceil(std::log(nptg) / std::log(2)) % 2 == 1) {
+    if ((int) ceil(std::log(npr) / std::log(2)) % 2 == 1) {
         std::swap(bid_dst, bid_tmp);
     }
 
@@ -3573,7 +3694,7 @@ int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
 
     ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
 
-    ggml_metal_encoder_dispatch_threadgroups(enc, nptg*ne01, ne02, ne03, nth, 1, 1);
+    ggml_metal_encoder_dispatch_threadgroups(enc, npr*ne01, ne02, ne03, nth, 1, 1);
 
     ggml_metal_pipeline_t pipeline_merge = ggml_metal_library_get_pipeline_argsort_merge(lib, op);
 
@@ -3626,7 +3747,7 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     float slope;
     memcpy(&slope, op->op_params, sizeof(float));
@@ -3662,7 +3783,7 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);
 
@@ -3698,7 +3819,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
     GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
-    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+    GGML_TENSOR_LOCALS(uint64_t, nb,  op,         nb);
 
     ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
 
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h
index 3cf400dc45..332e550ee7 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.h
+++ b/ggml/src/ggml-metal/ggml-metal-ops.h
@@ -52,6 +52,7 @@ int ggml_metal_op_unary           (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_glu             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_sum             (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_sum_rows        (ggml_metal_op_t ctx, int idx);
+int ggml_metal_op_cumsum          (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_get_rows        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_set_rows        (ggml_metal_op_t ctx, int idx);
 int ggml_metal_op_soft_max        (ggml_metal_op_t ctx, int idx);
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
index 35f07f3e71..f6033ddc97 100644
--- a/ggml/src/ggml-metal/ggml-metal.cpp
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
@@ -197,6 +197,7 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_
                 res += ggml_metal_op_flash_attn_ext_extra_blk(tensor);
                 res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor);
             } break;
+        case GGML_OP_CUMSUM:
         case GGML_OP_ARGSORT:
             {
                 res *= 2;
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
index 8afc7318f6..eabb22165d 100644
--- a/ggml/src/ggml-metal/ggml-metal.metal
+++ b/ggml/src/ggml-metal/ggml-metal.metal
@@ -1832,6 +1832,117 @@ typedef decltype(kernel_sum_rows<false>) kernel_sum_rows_t;
 template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows<false>;
 template [[host_name("kernel_mean_f32")]]     kernel kernel_sum_rows_t kernel_sum_rows<true>;
 
+template <typename T>
+kernel void kernel_cumsum_blk(
+        constant ggml_metal_kargs_cumsum_blk & args,
+        device const char * src0,
+        device       char * tmp,
+        device       char * dst,
+        threadgroup  char * shmem [[threadgroup(0)]],
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int ib = tgpig[0]/args.ne01;
+
+    const int i00 = ib*ntg.x;
+    const int i01 = tgpig[0]%args.ne01;
+    const int i02 = tgpig[1];
+    const int i03 = tgpig[2];
+
+    device const float * src0_row = (device const float *) (src0 +
+            args.nb01*i01 +
+            args.nb02*i02 +
+            args.nb03*i03);
+
+    threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+
+    float v = 0.0f;
+
+    if (i00 + tpitg.x < args.ne00) {
+        v = src0_row[i00 + tpitg.x];
+    }
+
+    float s = simd_prefix_inclusive_sum(v);
+
+    if (tiisg == N_SIMDWIDTH - 1) {
+        shmem_f32[sgitg] = s;
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    if (sgitg == 0) {
+        shmem_f32[tiisg] = simd_prefix_exclusive_sum(shmem_f32[tiisg]);
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    s += shmem_f32[sgitg];
+
+    device float * dst_row = (device float *) dst +
+            args.ne00*i01 +
+            args.ne00*args.ne01*i02 +
+            args.ne00*args.ne01*args.ne02*i03;
+
+    if (i00 + tpitg.x < args.ne00) {
+        dst_row[i00 + tpitg.x] = s;
+    }
+
+    if (args.outb && tpitg.x == ntg.x - 1) {
+        device float * tmp_row = (device float *) tmp +
+                args.net0*i01 +
+                args.net0*args.net1*i02 +
+                args.net0*args.net1*args.net2*i03;
+
+        tmp_row[ib] = s;
+    }
+}
+
+typedef decltype(kernel_cumsum_blk<float>) kernel_cumsum_blk_t;
+
+template [[host_name("kernel_cumsum_blk_f32")]] kernel kernel_cumsum_blk_t kernel_cumsum_blk<float>;
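
Note: kernel_cumsum_blk builds its threadgroup-wide inclusive scan hierarchically: each simdgroup scans its lanes with the Metal built-in simd_prefix_inclusive_sum, the last lane of each simdgroup parks the simdgroup total in shmem_f32, simdgroup 0 exclusive-scans those totals, and every thread then adds its simdgroup's offset. A CPU sketch of the same three stages (expository only, not part of the patch; W stands in for N_SIMDWIDTH):

// scan_hierarchy.cpp - emulate the hierarchical threadgroup scan on the CPU
#include <cstdio>
#include <vector>

int main() {
    const int W   = 32;                  // simd width (N_SIMDWIDTH)
    const int nth = 128;                 // threads per threadgroup
    std::vector<float> v(nth, 1.0f);     // per-thread input values
    std::vector<float> s(nth);           // per-thread scan results
    std::vector<float> shmem(nth/W);     // one float per simdgroup (shmem_f32)

    // stage 1: simd_prefix_inclusive_sum within each simdgroup; the last lane
    // (tiisg == N_SIMDWIDTH - 1) deposits the simdgroup total in shmem
    for (int sg = 0; sg < nth/W; ++sg) {
        float acc = 0.0f;
        for (int lane = 0; lane < W; ++lane) {
            acc += v[sg*W + lane];
            s[sg*W + lane] = acc;
        }
        shmem[sg] = acc;
    }

    // stage 2: simdgroup 0 replaces the totals with their exclusive prefix sums
    float acc = 0.0f;
    for (int sg = 0; sg < nth/W; ++sg) {
        const float t = shmem[sg];
        shmem[sg] = acc;
        acc += t;
    }

    // stage 3: every thread adds the combined total of the simdgroups before it
    for (int sg = 0; sg < nth/W; ++sg) {
        for (int lane = 0; lane < W; ++lane) {
            s[sg*W + lane] += shmem[sg];
        }
    }

    printf("last element = %g (expected %d)\n", s[nth - 1], nth);
}
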
+
+template <typename T>
+kernel void kernel_cumsum_add(
+        constant ggml_metal_kargs_cumsum_add & args,
+        device const char * tmp,
+        device       char * dst,
+        uint3   tgpig[[threadgroup_position_in_grid]],
+        ushort3 tpitg[[thread_position_in_threadgroup]],
+        ushort  sgitg[[simdgroup_index_in_threadgroup]],
+        ushort  tiisg[[thread_index_in_simdgroup]],
+        ushort3   ntg[[threads_per_threadgroup]]) {
+    const int ib = tgpig[0]/args.ne01;
+
+    if (ib == 0) {
+        return;
+    }
+
+    const int i00 = ib*ntg.x;
+    const int i01 = tgpig[0]%args.ne01;
+    const int i02 = tgpig[1];
+    const int i03 = tgpig[2];
+
+    device const float * tmp_row = (device const float *) (tmp +
+            args.nbt1*i01 +
+            args.nbt2*i02 +
+            args.nbt3*i03);
+
+    device float * dst_row = (device float *) dst +
+            args.ne00*i01 +
+            args.ne00*args.ne01*i02 +
+            args.ne00*args.ne01*args.ne02*i03;
+
+    if (i00 + tpitg.x < args.ne00) {
+        dst_row[i00 + tpitg.x] += tmp_row[ib - 1];
+    }
+}
+
+typedef decltype(kernel_cumsum_add<float>) kernel_cumsum_add_t;
+
+template [[host_name("kernel_cumsum_add_f32")]] kernel kernel_cumsum_add_t kernel_cumsum_add<float>;
+
 template <typename T>
 kernel void kernel_soft_max(
         constant ggml_metal_kargs_soft_max & args,
@@ -4543,7 +4654,7 @@ typedef void (argsort_t)(
         constant ggml_metal_kargs_argsort & args,
         device const char    * src0,
         device       int32_t * dst,
-        threadgroup  int32_t * smem_i32 [[threadgroup(0)]],
+        threadgroup  int32_t * shmem_i32 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort3   ntg[[threads_per_threadgroup]]);
@@ -4553,7 +4664,7 @@ kernel void kernel_argsort_f32_i32(
         constant ggml_metal_kargs_argsort & args,
         device const char    * src0,
         device       int32_t * dst,
-        threadgroup  int32_t * smem_i32 [[threadgroup(0)]],
+        threadgroup  int32_t * shmem_i32 [[threadgroup(0)]],
         uint3   tgpig[[threadgroup_position_in_grid]],
         ushort3 tpitg[[thread_position_in_threadgroup]],
         ushort3   ntg[[threads_per_threadgroup]]) {
@@ -4565,10 +4676,10 @@ kernel void kernel_argsort_f32_i32(
     const int i02 = tgpig[1];
     const int i03 = tgpig[2];
 
-    device const float * x_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
+    device const float * src0_row = (device const float *) (src0 + args.nb01*i01 + args.nb02*i02 + args.nb03*i03);
 
     // initialize indices
-    smem_i32[col] = i00 + col;
+    shmem_i32[col] = i00 + col;
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
 
@@ -4577,20 +4688,20 @@ kernel void kernel_argsort_f32_i32(
             int ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (smem_i32[col] >= args.ne00 ||
-                        (smem_i32[ixj] <  args.ne00 && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[smem_i32[col]] > x_row[smem_i32[ixj]] :
-                            x_row[smem_i32[col]] < x_row[smem_i32[ixj]]))
+                    if (shmem_i32[col] >= args.ne00 ||
+                        (shmem_i32[ixj] <  args.ne00 && (order == GGML_SORT_ORDER_ASC ?
+                            src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]] :
+                            src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]]))
                     ) {
-                        SWAP(smem_i32[col], smem_i32[ixj]);
+                        SWAP(shmem_i32[col], shmem_i32[ixj]);
                     }
                 } else {
-                    if (smem_i32[ixj] >= args.ne00 ||
-                        (smem_i32[col] <  args.ne00 && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[smem_i32[col]] < x_row[smem_i32[ixj]] :
-                            x_row[smem_i32[col]] > x_row[smem_i32[ixj]]))
+                    if (shmem_i32[ixj] >= args.ne00 ||
+                        (shmem_i32[col] <  args.ne00 && (order == GGML_SORT_ORDER_ASC ?
+                            src0_row[shmem_i32[col]] < src0_row[shmem_i32[ixj]] :
+                            src0_row[shmem_i32[col]] > src0_row[shmem_i32[ixj]]))
                     ) {
-                        SWAP(smem_i32[col], smem_i32[ixj]);
+                        SWAP(shmem_i32[col], shmem_i32[ixj]);
                     }
                 }
             }
@@ -4603,7 +4714,7 @@ kernel void kernel_argsort_f32_i32(
 
     if (i00 + col < args.ne00) {
         dst += i00 + args.ne00*i01 + args.ne00*args.ne01*i02 + args.ne00*args.ne01*args.ne02*i03;
-        dst[col] = smem_i32[col];
+        dst[col] = shmem_i32[col];
     }
 }
 
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index a87190e9f4..267bead8c4 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7558,7 +7558,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_arange());
     test_cases.emplace_back(new test_timestep_embedding());
     test_cases.emplace_back(new test_leaky_relu());
-    test_cases.emplace_back(new test_cumsum());
+
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, {   10, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, {  127, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, {  128, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, {  255, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, {  256, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, {  511, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, {  512, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1023, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 1024, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2047, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 201*1204, 1, 1, 1 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 312*1205, 1, 1, 1 }));
     test_cases.emplace_back(new test_xielu());
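
Note on the test sizes: the rows straddle the power-of-two threadgroup widths (127/128 through 2047/2048), so both the single-dispatch path (ne00 <= nth, outb == false) and the blocked three-dispatch path are exercised, and the two large single-row cases force many blocks: assuming nth = 1024, 201*1204 = 242004 elements produce ceil(242004/1024) = 237 block totals, comfortably within the nth*nth limit asserted in ggml_metal_op_cumsum.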