	Update and fix Vulkan soft_max and argsort implementations (#7237)
* Update and fix Vulkan softmax implementation
* Update and fix Vulkan argsort implementation
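The new argsort shader keeps its index array in shared memory (dst_row) and runs a bitonic sort over it. Bitonic networks need a power-of-two element count, so the host now rounds the column count up (ncols_pad) and passes both values to the shader; slots past the real ncols act as sentinels that sort to the end, and only the first ncols indices are written back to data_d. A minimal sketch of that padding rule, with an illustrative helper name that is not taken from the code:

#include <cassert>
#include <cstdint>

// Round ncols up to the next power of two, mirroring the loop added in
// ggml_vk_argsort below; the shader runs one workgroup of ncols_pad threads
// per row, which is why the host asserts ncols_pad <= 1024 (the block size).
static uint32_t argsort_pad(uint32_t ncols) {
    uint32_t ncols_pad = 1;
    while (ncols_pad < ncols) {
        ncols_pad *= 2;
    }
    return ncols_pad;
}

int main() {
    assert(argsort_pad(1)    == 1);
    assert(argsort_pad(1000) == 1024);
    assert(argsort_pad(1024) == 1024);
    return 0;
}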
ggml-vulkan.cpp (192 changed lines)
@@ -294,7 +294,6 @@ struct vk_op_rope_neox_push_constants {
 struct vk_op_soft_max_push_constants {
     uint32_t KX;
     uint32_t KY;
-    uint32_t KZ;
     float scale;
     float max_bias;
     float m0;
@@ -304,7 +303,8 @@ struct vk_op_soft_max_push_constants {
 
 struct vk_op_argsort_push_constants {
     uint32_t ncols;
-    bool ascending;
+    uint32_t ncols_pad;
+    int32_t order;
 };
 
 // Allow pre-recording command buffers
@@ -1501,8 +1501,8 @@ static void ggml_vk_load_shaders(ggml_backend_vk_context * ctx) {
 
     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(ctx, ctx->device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f32, "rope_f32", rope_f32_len, rope_f32_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(ctx, ctx->device->pipeline_rope_f16, "rope_f16", rope_f16_len, rope_f16_data, "main", 3, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -3752,7 +3752,7 @@ static void ggml_vk_op_repeat(ggml_backend_vk_context * ctx, vk_context * subctx
 }
 
 
-static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op) {
+static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op) {
     switch (op) {
     case GGML_OP_ADD:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
@@ -3834,7 +3834,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_soft_max_f32;
         }
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_soft_max_f32_f16;
         }
         return nullptr;
@@ -3900,15 +3900,12 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
 }
 
 template<typename PC>
-static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, const PC&& pc) {
+static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, ggml_op op, const PC&& pc) {
 #ifdef GGML_VULKAN_DEBUG
     std::cerr << "ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", backend=" << src0->backend << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
     if (src1 != nullptr) {
         std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", backend=" << src1->backend << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
     }
-    if (src2 != nullptr) {
-        std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", backend=" << src2->backend << ", ne0=" << src2->ne[0] << ", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3];
-    }
     std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", backend=" << dst->backend << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "), " << ggml_op_name(op) << ")" << std::endl;
 #endif
     GGML_ASSERT(op == GGML_OP_GET_ROWS || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type))));  // NOLINT
@@ -3929,10 +3926,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
     const uint64_t nb2  = dst->nb[2];
     const uint64_t nb3  = dst->nb[3];
 
-    const bool use_src2 = src2 != nullptr;
-    const uint64_t ne2 = use_src2 ? src2->ne[0] * src2->ne[1] : 0;
-
-    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op);
+    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, dst, op);
     ggml_vk_func_t op_func;
 
     if (pipeline == nullptr) {
@@ -3955,18 +3949,15 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) dst->extra;
     ggml_tensor_extra_gpu * extra_src0 = (ggml_tensor_extra_gpu *) src0->extra;
     ggml_tensor_extra_gpu * extra_src1 = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
-    ggml_tensor_extra_gpu * extra_src2 = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
 
     vk_buffer d_X = nullptr;
     size_t x_buf_offset = 0;
     vk_buffer d_Y = nullptr;
     size_t y_buf_offset = 0;
     vk_buffer d_Z = nullptr;
-    size_t z_buf_offset = 0;
 
     bool src0_uma = false;
     bool src1_uma = false;
-    bool src2_uma = false;
 
     if (ctx->device->uma) {
         ggml_vk_host_get(ctx, src0->data, d_X, x_buf_offset);
@@ -3975,15 +3966,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
             ggml_vk_host_get(ctx, src1->data, d_Y, y_buf_offset);
             src1_uma = d_Y != nullptr;
         }
-        if (use_src2) {
-            ggml_vk_host_get(ctx, src1->data, d_Z, z_buf_offset);
-            src2_uma = d_Z != nullptr;
-        }
     }
 
     uint64_t x_sz = ggml_vk_align_size(ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0, ctx->device->properties.limits.minStorageBufferOffsetAlignment);
     uint64_t y_sz = use_src1 ? ggml_vk_align_size(ggml_type_size(src1->type) * ne1, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : 0;
-    uint64_t z_sz = use_src2 ? ggml_vk_align_size(ggml_type_size(src2->type) * ne2, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : 0;
     uint64_t d_sz = ggml_type_size(dst->type) * ne0;
 
     vk_buffer d_D = extra->buffer_gpu.lock();
@@ -4007,12 +3993,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
         GGML_ASSERT(d_Y != nullptr);
     }
 
-    if (use_src2 && !src2_uma) {
-        d_Z = extra_src2->buffer_gpu.lock();
-        z_buf_offset = extra_src2->offset;
-        GGML_ASSERT(d_Z != nullptr);
-    }
-
     if (op_supports_incontiguous) {
         x_sz = ggml_nbytes(src0);
         y_sz = use_src1 ? ggml_nbytes(src1) : 0;
@@ -4048,6 +4028,9 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
         case GGML_OP_GET_ROWS:
             elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) };
             break;
+        case GGML_OP_ARGSORT:
+            elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 };
+            break;
         default:
             elements = { (uint32_t)ggml_nelements(src0), 1, 1 };
             break;
@@ -4066,7 +4049,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
         }
 
         if (op == GGML_OP_SOFT_MAX) {
-            // Empty src1 and src2 are possible on soft_max, but the shader needs buffers
+            // Empty src1 is possible on soft_max, but the shader needs a buffer
             vk_subbuffer subbuf_y;
             if (use_src1) {
                 subbuf_y = { d_Y, y_buf_offset, y_sz };
@@ -4074,15 +4057,8 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
                 subbuf_y = { d_X, 0, d_X->size };
             }
 
-            vk_subbuffer subbuf_z;
-            if (use_src2) {
-                subbuf_z = { d_Z, z_buf_offset, z_sz };
-            } else {
-                subbuf_z = { d_X, 0, d_X->size };
-            }
-
             ggml_vk_sync_buffers(subctx);
-            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
+            ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, subbuf_y, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
         } else if (use_src1) {
             ggml_vk_sync_buffers(subctx);
             ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { { d_X, x_buf_offset, x_sz }, { d_Y, y_buf_offset, y_sz }, { d_D, d_buf_offset, d_sz } }, sizeof(PC), &pc, elements);
@@ -4099,13 +4075,13 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
         }
     } else {
         GGML_ASSERT(op != GGML_OP_SOFT_MAX);
+        GGML_ASSERT(op != GGML_OP_ARGSORT);
 
         ggml_pipeline_allocate_descriptor_sets(ctx, pipeline, ne02 * ne03);
 
         switch (dst->op) {
         case GGML_OP_NORM:
         case GGML_OP_RMS_NORM:
-        case GGML_OP_SOFT_MAX:
             elements = { (uint32_t)ne01, 1, 1 };
             break;
         case GGML_OP_DIAG_MASK_INF:
@@ -4145,7 +4121,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context * subctx, c
 }
 
 static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_REPEAT, { (uint32_t)ggml_nelements(src0), (uint32_t)ggml_nelements(src1), 0.0f, 0.0f });
 }
 
 static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -4153,7 +4129,7 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context * subctx,
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, {
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_GET_ROWS, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -4168,7 +4144,7 @@ static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context * subctx, cons
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, {
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ADD, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -4183,7 +4159,7 @@ static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context * subctx, cons
     const uint32_t src1_type_size = ggml_type_size(src1->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, {
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_MUL, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -4198,7 +4174,7 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context * subctx, co
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, {
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_SCALE, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
@@ -4211,7 +4187,7 @@ static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context * subctx, cons
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, {
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_SQR, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
@@ -4225,7 +4201,7 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context * subctx, co
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, {
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_CLAMP, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
@@ -4240,7 +4216,7 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, cons
     const uint32_t dst_type_size = ggml_type_size(dst->type);
     const uint32_t d_offset = (extra->offset % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
 
-    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
+    ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_CPY, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
@@ -4252,24 +4228,24 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context * subctx, cons
 static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
 
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
 }
 
 static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_RMS_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f });
 }
 
 static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
 }
 
 static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     int32_t * op_params = (int32_t *)dst->op_params;
-    ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
+    ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] });
 }
 
-static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
+static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     float * op_params = (float *)dst->op_params;
 
     float scale = op_params[0];
@@ -4285,13 +4261,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-#pragma message("TODO: src2 is no longer used in soft_max - should be removed and ALiBi calculation should be updated")
-#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/7192")
-
-    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
+    ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_SOFT_MAX, {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
-        src2 != nullptr ? (uint32_t)1 : (uint32_t)0,
         scale, max_bias,
         m0, m1,
         n_head_log2,
@@ -4321,15 +4293,39 @@ static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context * subctx, con
     if (is_neox) {
         const float theta_scale = powf(freq_base, -2.0f/n_dims);
         const float inv_ndims = -1.0f / n_dims;
-        ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1], 0.0f, 0.0f}, theta_scale, inv_ndims });
+        ggml_vk_op_f32<vk_op_rope_neox_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ROPE, {
+            (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1],
+            freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1], 0.0f, 0.0f}, theta_scale, inv_ndims
+        });
     } else {
-        ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ROPE, { (uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1], freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1], 0.0f, 0.0f} });
+        ggml_vk_op_f32<vk_op_rope_push_constants>(ctx, subctx, src0, src1, dst, GGML_OP_ROPE, {
+            (uint32_t)src0->ne[0], freq_scale, (uint32_t)src0->ne[1],
+            freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1], 0.0f, 0.0f}
+        });
     }
 }
 
 static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context * subctx, const ggml_tensor * src0, ggml_tensor * dst) {
     int32_t * op_params = (int32_t *)dst->op_params;
-    ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { (uint32_t)src0->ne[0], ((ggml_sort_order) op_params[0]) == GGML_SORT_ORDER_ASC });
+
+    uint32_t ncols = src0->ne[0];
+
+    uint32_t ncols_pad = 1;
+    while (ncols_pad < ncols) {
+        ncols_pad *= 2;
+    }
+
+    GGML_ASSERT(ncols_pad <= 1024);
+
+    std::cerr << "ncols=" << ncols << " ncols_pad=" << ncols_pad << " ascending=" << op_params[0] << std::endl;
+
+    std::cerr << ((ggml_sort_order) op_params[0]) << " " << GGML_SORT_ORDER_ASC << std::endl;
+
+    ggml_vk_op_f32<vk_op_argsort_push_constants>(ctx, subctx, src0, nullptr, dst, GGML_OP_ARGSORT, {
+        ncols,
+        ncols_pad,
+        op_params[0],
+    });
 }
 
 #ifdef GGML_VULKAN_RUN_TESTS
@@ -5432,7 +5428,6 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
     const ggml_tensor * src0 = node->src[0];
     const ggml_tensor * src1 = node->src[1];
-    const ggml_tensor * src2 = node->src[2];
 
     ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) node->extra;
 
@@ -5547,7 +5542,7 @@ static void ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * nod
 
         break;
     case GGML_OP_SOFT_MAX:
-        ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, src2, node);
+        ggml_vk_soft_max(ctx, ctx->compute_ctx, src0, src1, node);
 
         break;
     case GGML_OP_ROPE:
@@ -6548,7 +6543,7 @@ static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector<c
 }
 
 static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) {
-    if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) {
+    if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) {
         return;
     }
     i0 = std::max(i0, 5);
@@ -6569,6 +6564,8 @@ static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * d
                     val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
                 } else if (tensor->type == GGML_TYPE_F16) {
                     val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]));
+                } else if (tensor->type == GGML_TYPE_I32) {
+                    val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]);
                 } else {
                     GGML_ASSERT(false);
                 }
@@ -6671,7 +6668,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
 
     ggml_tensor * src0 = tensor->src[0];
     ggml_tensor * src1 = tensor->src[1];
-    ggml_tensor * src2 = tensor->src[2];
 
     struct ggml_init_params iparams = {
         /*.mem_size   =*/ 1024*1024*1024,
@@ -6798,66 +6794,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
 
         ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src1", src1_clone);
     }
-    if (src2 != nullptr) {
-        src2_clone = ggml_dup_tensor(ggml_ctx, src2);
-
-        src2_size = ggml_nbytes(src2);
-
-        src2_buffer = malloc(src2_size);
-        src2_clone->data = src2_buffer;
-        if (src2->backend == GGML_BACKEND_TYPE_CPU) {
-            memcpy(src2_clone->data, src2->data, src2_size);
-            memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
-        } else if (src2->backend == GGML_BACKEND_TYPE_GPU) {
-            ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src2->extra;
-            vk_buffer buf = extra->buffer_gpu.lock();
-            uint64_t offset = extra->offset;
-            if (!ggml_is_contiguous(src2) && ggml_vk_dim01_contiguous(src2)) {
-                for (int i3 = 0; i3 < src2->ne[3]; i3++) {
-                    for (int i2 = 0; i2 < src2->ne[2]; i2++) {
-                        const int idx = i3*src2->ne[2] + i2;
-                        ggml_vk_buffer_read(ctx, buf, offset + idx * src2->nb[2], ((char *)src2_clone->data + idx * src2_clone->nb[2]), src2->ne[1] * src2->nb[1]);
-                    }
-                }
-
-                src2_clone->nb[0] = src2->nb[0];
-                src2_clone->nb[1] = src2->nb[1];
-                for (int i = 2; i < GGML_MAX_DIMS; i++) {
-                    src2_clone->nb[i] = src2_clone->nb[i - 1]*src2_clone->ne[i - 1];
-                }
-            } else {
-                if (offset + src2_size >= buf->size) {
-                    src2_size = buf->size - offset;
-                }
-                ggml_vk_buffer_read(ctx, buf, offset, src2_clone->data, src2_size);
-                memcpy(src2_clone->nb, src2->nb, sizeof(size_t) * GGML_MAX_DIMS);
-            }
-        } else {
-            GGML_ASSERT(false);
-        }
-
-        if (vk_output_tensor > 0 && vk_output_tensor == check_counter) {
-            ggml_vk_print_tensor(ctx, src2, "src2");
-            std::cerr << "TENSOR CHECK: " << ggml_op_name(src2_clone->op) << " (check " << check_counter << ")" << std::endl;
-            std::cerr << "src2_clone=" << tensor << " src2_clone->backend: " << src2_clone->backend << " src2_clone->type: " << ggml_type_name(src2_clone->type) << " ne0=" << src2_clone->ne[0] << " nb0=" << src2_clone->nb[0] << " ne1=" << src2_clone->ne[1] << " nb1=" << src2_clone->nb[1] << " ne2=" << src2_clone->ne[2] << " nb2=" << src2_clone->nb[2] << " ne3=" << src2_clone->ne[3] << " nb3=" << src2_clone->nb[3] << std::endl;
-            if (src2->src[0] != nullptr) {
-                std::cerr << "src2->src[0]=" << src2->src[0] << " op=" << ggml_op_name(src2->src[0]->op) << " type=" << ggml_type_name(src2->src[0]->type) << " backend=" << src2->src[0]->backend << " ne0=" << src2->src[0]->ne[0] << " nb0=" << src2->src[0]->nb[0] << " ne1=" << src2->src[0]->ne[1] << " nb1=" << src2->src[0]->nb[1] << " ne2=" << src2->src[0]->ne[2] << " nb2=" << src2->src[0]->nb[2] << " ne3=" << src2->src[0]->ne[3] << " nb3=" << src2->src[0]->nb[3] << std::endl;
-            }
-            if (src2->src[1] != nullptr) {
-                std::cerr << "src2->src[1]=" << src2->src[1] << " op=" << ggml_op_name(src2->src[1]->op) << " type=" << ggml_type_name(src2->src[1]->type) << " backend=" << src2->src[1]->backend << " ne0=" << src2->src[1]->ne[0] << " nb0=" << src2->src[1]->nb[0] << " ne1=" << src2->src[1]->ne[1] << " nb1=" << src2->src[1]->nb[1] << " ne2=" << src2->src[1]->ne[2] << " nb2=" << src2->src[1]->nb[2] << " ne3=" << src2->src[1]->ne[3] << " nb3=" << src2->src[1]->nb[3] << std::endl;
-            }
-            std::cerr << std::endl << "Result:" << std::endl;
-            ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 0, 0);
-            std::cerr << std::endl;
-            std::cerr << std::endl << "Result:" << std::endl;
-            ggml_vk_print_tensor_area(src2_clone, src2_clone->data, 5, 5, 1, 0);
-            std::cerr << std::endl;
-            std::vector<const ggml_tensor *> done;
-            ggml_vk_print_graph_origin(src2_clone, done);
-        }
-
-        ggml_vk_check_tensor(std::string(ggml_op_name(tensor->op)) + "->src2", src2_clone);
-    }
 
     if (tensor->op == GGML_OP_MUL_MAT) {
         tensor_clone = ggml_mul_mat(ggml_ctx, src0_clone, src1_clone);
@@ -6877,7 +6813,7 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
         tensor_clone = ggml_rms_norm(ggml_ctx, src0_clone, *(float *)tensor->op_params);
     } else if (tensor->op == GGML_OP_SOFT_MAX) {
         if (src1 != nullptr) {
-            tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, src2_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
+            tensor_clone = ggml_soft_max_ext(ggml_ctx, src0_clone, src1_clone, ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]);
         } else {
             tensor_clone = ggml_soft_max(ggml_ctx, src0_clone);
         }
@@ -6964,9 +6900,6 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_compute_
     if (src1 != nullptr) {
         free(src1_buffer);
     }
-    if (src2 != nullptr) {
-        free(src2_buffer);
-    }
 
     ggml_free(ggml_ctx);
 }
@@ -7026,8 +6959,11 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_compute_
                         } else if (tensor->type == GGML_TYPE_F16) {
                             correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]));
                             result  = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]));
+                        } else if (tensor->type == GGML_TYPE_I32) {
+                            correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]);
+                            result  = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]);
                         } else {
-                            std::cerr << "comp_size=" << comp_size << " but required is " << (i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]) << std::endl;
+                            std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl;
                         }
                     } else {
                         std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl;
@@ -2432,7 +2432,6 @@ layout (push_constant) uniform parameter
 {
     uint KX;
     uint KY;
-    uint KZ;
     float scale;
     float max_bias;
     float m0;
@@ -2449,8 +2448,7 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
 layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
-layout (binding = 2) readonly buffer Z {C_TYPE data_c[];};
-layout (binding = 3) buffer D {D_TYPE data_d[];};
+layout (binding = 2) buffer D {D_TYPE data_d[];};
 
 shared FLOAT_TYPE vals[BLOCK_SIZE];
 
@@ -2459,7 +2457,7 @@ void main() {
     const uint rowx = gl_WorkGroupID.x;
     const uint rowy = rowx % p.KY;
 
-    float slope = 0.0f;
+    float slope = 1.0f;
 
     // ALiBi
     if (p.max_bias > 0.0f) {
@@ -2472,12 +2470,19 @@ void main() {
     }
 
     // Find max
-    vals[tid] = uintBitsToFloat(0xFF800000);
+    FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
-        vals[tid] = max(vals[tid], FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) + (p.KZ > 0 ? slope * FLOAT_TYPE(data_c[col]) : 0.0f));
+    [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+        const uint col = col0 + tid;
+
+        if (col >= p.KX) {
+            break;
+        }
+
+        max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)));
     }
+    vals[tid] = max_val;
 
     barrier();
     [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
         if (tid < s) {
@@ -2486,15 +2491,21 @@ void main() {
         barrier();
     }
 
-    const FLOAT_TYPE max_val = vals[0];
+    max_val = vals[0];
     barrier();
 
     // Sum up values
     vals[tid] = FLOAT_TYPE(0.0f);
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+    [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+        const uint col = col0 + tid;
+
+        if (col >= p.KX) {
+            break;
+        }
+
         const uint i = rowx * p.KX + col;
-        const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
+        const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
         vals[tid] += val;
         data_d[i] = D_TYPE(val);
     }
@@ -2509,7 +2520,13 @@ void main() {
 
     const D_TYPE divisor = D_TYPE(vals[0]);
 
-    [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
+    [[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
+        const uint col = col0 + tid;
+
+        if (col >= p.KX) {
+            break;
+        }
+
         data_d[rowx*p.KX + col] /= divisor;
     }
 }
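In the updated softmax shader the separate position buffer (data_c, gated by KZ) is gone; ALiBi is instead applied by scaling the mask value data_b with a per-head slope, which is why the slope default changes from 0.0f to 1.0f (with max_bias == 0 the mask must pass through unscaled). The m0/m1/n_head_log2 push constants computed in ggml_vk_soft_max suggest the usual ggml slope schedule, but the shader's ALiBi block itself is elided from this diff, so the following is an assumption-level sketch rather than the exact shader code:

#include <cmath>
#include <cstdint>

// Hedged sketch of the per-head ALiBi slope implied by the m0/m1/n_head_log2
// push constants; the exact in-shader computation is not shown in this diff.
static float alibi_slope(float max_bias, uint32_t n_head_log2, float m0, float m1, uint32_t head) {
    if (max_bias <= 0.0f) {
        return 1.0f; // matches the new `float slope = 1.0f;` default
    }
    return head < n_head_log2 ? std::pow(m0, float(head + 1))
                              : std::pow(m1, float(2 * (head - n_head_log2) + 1));
}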
@@ -2672,20 +2689,26 @@ argsort_src = """
 
 #extension GL_EXT_shader_16bit_storage : require
 
-layout(local_size_x = 1024, local_size_y = 1, local_size_z = 1) in;
+#define BLOCK_SIZE 1024
+#define ASC 0
+
+layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
 layout (binding = 1)          buffer D {int data_d[];};
 
 layout (push_constant) uniform parameter {
     uint ncols;
-    bool ascending;
+    uint ncols_pad;
+    uint order;
 } p;
 
+shared int dst_row[BLOCK_SIZE];
+
 void swap(uint idx0, uint idx1) {
-    int tmp = data_d[idx0];
-    data_d[idx0] = data_d[idx1];
-    data_d[idx1] = tmp;
+    int tmp = dst_row[idx0];
+    dst_row[idx0] = dst_row[idx1];
+    dst_row[idx1] = tmp;
 }
 
 void main() {
@@ -2693,36 +2716,45 @@ void main() {
     const int col = int(gl_LocalInvocationID.x);
     const uint row = gl_WorkGroupID.y;
 
-    if (col >= p.ncols) {
+    if (col >= p.ncols_pad) {
         return;
     }
 
-    const uint a_idx = row * p.ncols;
-    const uint d_idx = row * p.ncols;
+    const uint row_offset = row * p.ncols;
 
     // initialize indices
-    if (col < p.ncols) {
-        data_d[col] = col;
-    }
+    dst_row[col] = col;
     barrier();
 
-    for (uint k = 2; k <= p.ncols; k *= 2) {
+    for (uint k = 2; k <= p.ncols_pad; k *= 2) {
         for (uint j = k / 2; j > 0; j /= 2) {
             const uint ixj = col ^ j;
             if (ixj > col) {
                 if ((col & k) == 0) {
-                    if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]]) {
-                        swap(d_idx + col, d_idx + ixj);
+                    if (dst_row[col] >= p.ncols ||
+                        (dst_row[ixj] < p.ncols && (p.order == ASC ?
+                            data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]] :
+                            data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]]))
+                    ) {
+                        swap(col, ixj);
                     }
                 } else {
-                    if (p.ascending ? data_a[a_idx + data_d[d_idx + col]] < data_a[a_idx + data_d[d_idx + ixj]] : data_a[a_idx + data_d[d_idx + col]] > data_a[a_idx + data_d[d_idx + ixj]]) {
-                        swap(d_idx + col, d_idx + ixj);
+                    if (dst_row[ixj] >= p.ncols ||
+                        (dst_row[col] < p.ncols && (p.order == ASC ?
+                            data_a[row_offset + dst_row[col]] < data_a[row_offset + dst_row[ixj]] :
+                            data_a[row_offset + dst_row[col]] > data_a[row_offset + dst_row[ixj]]))
+                    ) {
+                        swap(col, ixj);
                     }
                 }
             }
             barrier();
         }
     }
+
+    if (col < p.ncols) {
+        data_d[row_offset + col] = dst_row[col];
+    }
 }
"""
 
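For reference, the compare-and-exchange logic of the new bitonic argsort shader can be mirrored on the CPU. This is a sketch written to match the hunk above (sentinel indices greater than or equal to ncols always lose the comparison and therefore sink to the padded tail); it is not code from the repository:

#include <cstdint>
#include <utility>
#include <vector>

// CPU mirror of the shader's per-row bitonic argsort: indices are sorted in a
// power-of-two padded scratch array; entries >= ncols act as sentinels that
// compare as "after everything", so only the first ncols slots are meaningful.
static void bitonic_argsort_row(const float * row, int * dst, uint32_t ncols, bool ascending) {
    uint32_t ncols_pad = 1;
    while (ncols_pad < ncols) {
        ncols_pad *= 2;
    }

    std::vector<uint32_t> idx(ncols_pad);
    for (uint32_t i = 0; i < ncols_pad; ++i) {
        idx[i] = i;
    }

    // true if the element at position a must come after the element at position b
    auto out_of_order = [&](uint32_t a, uint32_t b) {
        if (idx[a] >= ncols) return true;   // sentinel on the left: always push it right
        if (idx[b] >= ncols) return false;  // sentinel on the right: already in place
        return ascending ? row[idx[a]] > row[idx[b]] : row[idx[a]] < row[idx[b]];
    };

    for (uint32_t k = 2; k <= ncols_pad; k *= 2) {           // subsequence length
        for (uint32_t j = k / 2; j > 0; j /= 2) {            // compare distance
            for (uint32_t col = 0; col < ncols_pad; ++col) { // one GPU thread per col
                const uint32_t ixj = col ^ j;
                if (ixj <= col) {
                    continue;
                }
                const bool up = (col & k) == 0; // direction of this subsequence
                if (up ? out_of_order(col, ixj) : out_of_order(ixj, col)) {
                    std::swap(idx[col], idx[ixj]);
                }
            }
        }
    }

    for (uint32_t c = 0; c < ncols; ++c) {
        dst[c] = (int) idx[c];
    }
}

In the shader, the inner col loop corresponds to the workgroup's threads running in parallel, and the barrier() after each j pass plays the role of the sequential loop ordering here.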