mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-11-03 09:22:01 +00:00 
			
		
		
		
	kompute : improve backend to pass test_backend_ops (#10542)
* kompute: op_unary: reject unsupported parameters Signed-off-by: Sergio Lopez <slp@redhat.com> * kompute: softmax: implement ALiBi support Signed-off-by: Sergio Lopez <slp@redhat.com> * kompute: rope: implement neox and phi3 support Signed-off-by: Sergio Lopez <slp@redhat.com> * kompute: op_mul_mat_q4_k permutted support Signed-off-by: Sergio Lopez <slp@redhat.com> * kompute: op_mul_mat_[q4_0|q4_1|q8_0] permutted support Signed-off-by: Sergio Lopez <slp@redhat.com> * kompute: op_mul_mat_f16 permutted support Signed-off-by: Sergio Lopez <slp@redhat.com> * kompute: op_mul_mat_q6_k permutted support Signed-off-by: Sergio Lopez <slp@redhat.com> --------- Signed-off-by: Sergio Lopez <slp@redhat.com>
This commit is contained in:
		@@ -105,8 +105,10 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
 | 
				
			|||||||
        kompute-shaders/op_getrows_q4_0.comp
 | 
					        kompute-shaders/op_getrows_q4_0.comp
 | 
				
			||||||
        kompute-shaders/op_getrows_q4_1.comp
 | 
					        kompute-shaders/op_getrows_q4_1.comp
 | 
				
			||||||
        kompute-shaders/op_getrows_q6_k.comp
 | 
					        kompute-shaders/op_getrows_q6_k.comp
 | 
				
			||||||
        kompute-shaders/op_rope_f16.comp
 | 
					        kompute-shaders/op_rope_norm_f16.comp
 | 
				
			||||||
        kompute-shaders/op_rope_f32.comp
 | 
					        kompute-shaders/op_rope_norm_f32.comp
 | 
				
			||||||
 | 
					        kompute-shaders/op_rope_neox_f16.comp
 | 
				
			||||||
 | 
					        kompute-shaders/op_rope_neox_f32.comp
 | 
				
			||||||
        kompute-shaders/op_cpy_f16_f16.comp
 | 
					        kompute-shaders/op_cpy_f16_f16.comp
 | 
				
			||||||
        kompute-shaders/op_cpy_f16_f32.comp
 | 
					        kompute-shaders/op_cpy_f16_f32.comp
 | 
				
			||||||
        kompute-shaders/op_cpy_f32_f16.comp
 | 
					        kompute-shaders/op_cpy_f32_f16.comp
 | 
				
			||||||
@@ -139,8 +141,10 @@ if (EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/kompute/CMakeLists.txt")
 | 
				
			|||||||
        shaderop_getrows_q4_0.h
 | 
					        shaderop_getrows_q4_0.h
 | 
				
			||||||
        shaderop_getrows_q4_1.h
 | 
					        shaderop_getrows_q4_1.h
 | 
				
			||||||
        shaderop_getrows_q6_k.h
 | 
					        shaderop_getrows_q6_k.h
 | 
				
			||||||
        shaderop_rope_f16.h
 | 
					        shaderop_rope_norm_f16.h
 | 
				
			||||||
        shaderop_rope_f32.h
 | 
					        shaderop_rope_norm_f32.h
 | 
				
			||||||
 | 
					        shaderop_rope_neox_f16.h
 | 
				
			||||||
 | 
					        shaderop_rope_neox_f32.h
 | 
				
			||||||
        shaderop_cpy_f16_f16.h
 | 
					        shaderop_cpy_f16_f16.h
 | 
				
			||||||
        shaderop_cpy_f16_f32.h
 | 
					        shaderop_cpy_f16_f32.h
 | 
				
			||||||
        shaderop_cpy_f32_f16.h
 | 
					        shaderop_cpy_f32_f16.h
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -28,8 +28,10 @@
 | 
				
			|||||||
#include "shaderop_getrows_q4_0.h"
 | 
					#include "shaderop_getrows_q4_0.h"
 | 
				
			||||||
#include "shaderop_getrows_q4_1.h"
 | 
					#include "shaderop_getrows_q4_1.h"
 | 
				
			||||||
#include "shaderop_getrows_q6_k.h"
 | 
					#include "shaderop_getrows_q6_k.h"
 | 
				
			||||||
#include "shaderop_rope_f16.h"
 | 
					#include "shaderop_rope_norm_f16.h"
 | 
				
			||||||
#include "shaderop_rope_f32.h"
 | 
					#include "shaderop_rope_norm_f32.h"
 | 
				
			||||||
 | 
					#include "shaderop_rope_neox_f16.h"
 | 
				
			||||||
 | 
					#include "shaderop_rope_neox_f32.h"
 | 
				
			||||||
#include "shaderop_cpy_f16_f16.h"
 | 
					#include "shaderop_cpy_f16_f16.h"
 | 
				
			||||||
#include "shaderop_cpy_f16_f32.h"
 | 
					#include "shaderop_cpy_f16_f32.h"
 | 
				
			||||||
#include "shaderop_cpy_f32_f16.h"
 | 
					#include "shaderop_cpy_f32_f16.h"
 | 
				
			||||||
@@ -345,7 +347,7 @@ void ggml_vk_allocate_descriptor_pool(struct ggml_kompute_context * ctx, size_t
 | 
				
			|||||||
    std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
 | 
					    std::vector<vk::DescriptorPoolSize> descriptorPoolSizes = {
 | 
				
			||||||
        vk::DescriptorPoolSize(
 | 
					        vk::DescriptorPoolSize(
 | 
				
			||||||
          vk::DescriptorType::eStorageBuffer,
 | 
					          vk::DescriptorType::eStorageBuffer,
 | 
				
			||||||
          3 * size // Descriptor count is number of possible tensors to pass into an algorithm
 | 
					          4 * size // Descriptor count is number of possible tensors to pass into an algorithm
 | 
				
			||||||
          )
 | 
					          )
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -788,7 +790,8 @@ static void ggml_vk_soft_max(
 | 
				
			|||||||
    const std::shared_ptr<kp::Tensor>& out,
 | 
					    const std::shared_ptr<kp::Tensor>& out,
 | 
				
			||||||
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
 | 
					    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
 | 
				
			||||||
    int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
 | 
					    int32_t ne00, int32_t ne01, int32_t ne02, uint32_t ne03,
 | 
				
			||||||
    float scale
 | 
					    float scale, float max_bias, float m0, float m1,
 | 
				
			||||||
 | 
					    uint32_t n_head_log2
 | 
				
			||||||
) {
 | 
					) {
 | 
				
			||||||
    const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
 | 
					    const static auto spirv = getSpirvShader(kp::shader_data::op_softmax_comp_spv,
 | 
				
			||||||
        kp::shader_data::op_softmax_comp_spv_len);
 | 
					        kp::shader_data::op_softmax_comp_spv_len);
 | 
				
			||||||
@@ -796,12 +799,14 @@ static void ggml_vk_soft_max(
 | 
				
			|||||||
    struct PushConstants {
 | 
					    struct PushConstants {
 | 
				
			||||||
        uint32_t inAOff, inBOff, outOff;
 | 
					        uint32_t inAOff, inBOff, outOff;
 | 
				
			||||||
        int32_t ne00, ne01, ne02;
 | 
					        int32_t ne00, ne01, ne02;
 | 
				
			||||||
        float scale;
 | 
					        float scale, max_bias, m0, m1;
 | 
				
			||||||
 | 
					        uint32_t n_head_log2;
 | 
				
			||||||
        int32_t mask;
 | 
					        int32_t mask;
 | 
				
			||||||
    } pushConsts {
 | 
					    } pushConsts {
 | 
				
			||||||
        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
 | 
					        safe_divide(inAOff, 4), safe_divide(inBOff, 4), safe_divide(outOff, 4),
 | 
				
			||||||
        ne00, ne01, ne02,
 | 
					        ne00, ne01, ne02,
 | 
				
			||||||
        scale,
 | 
					        scale, max_bias, m0, m1,
 | 
				
			||||||
 | 
					        n_head_log2,
 | 
				
			||||||
        bool(inB)
 | 
					        bool(inB)
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -911,9 +916,9 @@ static void ggml_vk_mul_mat_f16(
 | 
				
			|||||||
    const std::shared_ptr<kp::Tensor>& out,
 | 
					    const std::shared_ptr<kp::Tensor>& out,
 | 
				
			||||||
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
 | 
					    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
 | 
				
			||||||
    int32_t ne00, int32_t ne01, int32_t ne02,
 | 
					    int32_t ne00, int32_t ne01, int32_t ne02,
 | 
				
			||||||
    uint32_t nb00, uint32_t nb01, uint32_t nb02,
 | 
					    uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
 | 
				
			||||||
    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
 | 
					    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
 | 
				
			||||||
    uint32_t nb10, uint32_t nb11, uint32_t nb12,
 | 
					    uint32_t nb10, uint32_t nb11, uint32_t nb12, uint32_t nb13,
 | 
				
			||||||
    int32_t ne0, int32_t ne1,
 | 
					    int32_t ne0, int32_t ne1,
 | 
				
			||||||
    uint32_t r2, uint32_t r3
 | 
					    uint32_t r2, uint32_t r3
 | 
				
			||||||
) {
 | 
					) {
 | 
				
			||||||
@@ -923,17 +928,17 @@ static void ggml_vk_mul_mat_f16(
 | 
				
			|||||||
    struct PushConstants {
 | 
					    struct PushConstants {
 | 
				
			||||||
        uint32_t inAOff, inBOff, outOff;
 | 
					        uint32_t inAOff, inBOff, outOff;
 | 
				
			||||||
        int32_t ne00, ne01, ne02;
 | 
					        int32_t ne00, ne01, ne02;
 | 
				
			||||||
        uint32_t nb00, nb01, nb02;
 | 
					        uint32_t nb00, nb01, nb02, nb03;
 | 
				
			||||||
        int32_t ne10, ne11, ne12;
 | 
					        int32_t ne10, ne11, ne12;
 | 
				
			||||||
        uint32_t nb10, nb11, nb12;
 | 
					        uint32_t nb10, nb11, nb12, nb13;
 | 
				
			||||||
        int32_t ne0, ne1;
 | 
					        int32_t ne0, ne1;
 | 
				
			||||||
        uint32_t r2, r3;
 | 
					        uint32_t r2, r3;
 | 
				
			||||||
    } pushConsts {
 | 
					    } pushConsts {
 | 
				
			||||||
        safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
 | 
					        safe_divide(inAOff, 2), safe_divide(inBOff, 4), safe_divide(outOff, 4),
 | 
				
			||||||
        ne00, ne01, ne02,
 | 
					        ne00, ne01, ne02,
 | 
				
			||||||
        nb00, nb01, nb02,
 | 
					        nb00, nb01, nb02, nb03,
 | 
				
			||||||
        ne10, ne11, ne12,
 | 
					        ne10, ne11, ne12,
 | 
				
			||||||
        nb10, nb11, nb12,
 | 
					        nb10, nb11, nb12, nb13,
 | 
				
			||||||
        ne0, ne1,
 | 
					        ne0, ne1,
 | 
				
			||||||
        r2, r3
 | 
					        r2, r3
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
@@ -1013,6 +1018,8 @@ static void ggml_vk_mul_mat_impl(
 | 
				
			|||||||
    int32_t ne00, int32_t ne01, int32_t ne02,
 | 
					    int32_t ne00, int32_t ne01, int32_t ne02,
 | 
				
			||||||
    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
 | 
					    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
 | 
				
			||||||
    int32_t ne0, int32_t ne1,
 | 
					    int32_t ne0, int32_t ne1,
 | 
				
			||||||
 | 
					    uint32_t nb01, uint32_t nb02, uint32_t nb03,
 | 
				
			||||||
 | 
					    uint32_t nb11, uint32_t nb12, uint32_t nb13,
 | 
				
			||||||
    uint32_t r2, uint32_t r3
 | 
					    uint32_t r2, uint32_t r3
 | 
				
			||||||
) {
 | 
					) {
 | 
				
			||||||
    struct PushConstants {
 | 
					    struct PushConstants {
 | 
				
			||||||
@@ -1020,19 +1027,23 @@ static void ggml_vk_mul_mat_impl(
 | 
				
			|||||||
        int32_t ne00, ne01, ne02;
 | 
					        int32_t ne00, ne01, ne02;
 | 
				
			||||||
        int32_t ne10, ne12;
 | 
					        int32_t ne10, ne12;
 | 
				
			||||||
        int32_t ne0, ne1;
 | 
					        int32_t ne0, ne1;
 | 
				
			||||||
 | 
					        uint32_t nb01, nb02, nb03;
 | 
				
			||||||
 | 
					        uint32_t nb11, nb12, nb13;
 | 
				
			||||||
        uint32_t r2, r3;
 | 
					        uint32_t r2, r3;
 | 
				
			||||||
    } pushConsts {
 | 
					    } pushConsts {
 | 
				
			||||||
        safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
 | 
					        safe_divide(inAOff, block_size), safe_divide(inBOff, 4), safe_divide(outOff, 4),
 | 
				
			||||||
        ne00, ne01, ne02,
 | 
					        ne00, ne01, ne02,
 | 
				
			||||||
        ne10, ne12,
 | 
					        ne10, ne12,
 | 
				
			||||||
        ne0, ne1,
 | 
					        ne0, ne1,
 | 
				
			||||||
 | 
					        nb01, nb02, nb03,
 | 
				
			||||||
 | 
					        nb11, nb12, nb13,
 | 
				
			||||||
        r2, r3
 | 
					        r2, r3
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto name = std::string(__func__) + "_" + suffix;
 | 
					    auto name = std::string(__func__) + "_" + suffix;
 | 
				
			||||||
    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
 | 
					    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
 | 
				
			||||||
    if (!komputeManager()->hasAlgorithm(name)) {
 | 
					    if (!komputeManager()->hasAlgorithm(name)) {
 | 
				
			||||||
        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
 | 
					        const uint32_t local_x = (ggml_vk_current_device().subgroupSize * 2) / 8;
 | 
				
			||||||
        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
 | 
					        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(name, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 7)/8), unsigned(ne11), unsigned(ne12*ne13)}, {local_x}, {pushConsts});
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
        s_algo = komputeManager()->getAlgorithm(name);
 | 
					        s_algo = komputeManager()->getAlgorithm(name);
 | 
				
			||||||
@@ -1074,19 +1085,26 @@ static void ggml_vk_mul_mat_q4_k(
 | 
				
			|||||||
    const std::shared_ptr<kp::Tensor>& inB,
 | 
					    const std::shared_ptr<kp::Tensor>& inB,
 | 
				
			||||||
    const std::shared_ptr<kp::Tensor>& out,
 | 
					    const std::shared_ptr<kp::Tensor>& out,
 | 
				
			||||||
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
 | 
					    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
 | 
				
			||||||
    int32_t ne00, int32_t ne01, int32_t ne02, int32_t ne10,
 | 
					    int32_t ne00, int32_t ne01, int32_t ne02,
 | 
				
			||||||
    int32_t ne11, int32_t ne12, int32_t ne13, int32_t ne0,
 | 
					    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
 | 
				
			||||||
    int32_t ne1, int32_t r2, int32_t r3
 | 
					    int32_t ne0, int32_t ne1,
 | 
				
			||||||
 | 
					    uint32_t nb01, uint32_t nb02, uint32_t nb03,
 | 
				
			||||||
 | 
					    uint32_t nb11, uint32_t nb12, uint32_t nb13,
 | 
				
			||||||
 | 
					    uint32_t r2, uint32_t r3
 | 
				
			||||||
) {
 | 
					) {
 | 
				
			||||||
    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
 | 
					    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q4_k_comp_spv,
 | 
				
			||||||
        kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
 | 
					        kp::shader_data::op_mul_mat_q4_k_comp_spv_len);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    struct PushConstants {
 | 
					    struct PushConstants {
 | 
				
			||||||
        uint32_t inAOff, inBOff, outOff;
 | 
					        uint32_t inAOff, inBOff, outOff;
 | 
				
			||||||
        int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3;
 | 
					        int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
 | 
				
			||||||
 | 
					        uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
 | 
				
			||||||
 | 
					        uint32_t r2, r3;
 | 
				
			||||||
    } pushConsts {
 | 
					    } pushConsts {
 | 
				
			||||||
        0, 0, 0,
 | 
					        inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
 | 
				
			||||||
        ne00, ne10, ne0, ne1, ne01, ne02, ne12, r2, r3
 | 
					        ne00, ne10, ne0, ne1, ne01, ne02, ne12,
 | 
				
			||||||
 | 
					        nb01, nb02, nb03, nb11, nb12, nb13,
 | 
				
			||||||
 | 
					        r2, r3
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
 | 
					    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
 | 
				
			||||||
@@ -1108,28 +1126,37 @@ static void ggml_vk_mul_mat_q6_k(
 | 
				
			|||||||
    const std::shared_ptr<kp::Tensor>& inB,
 | 
					    const std::shared_ptr<kp::Tensor>& inB,
 | 
				
			||||||
    const std::shared_ptr<kp::Tensor>& out,
 | 
					    const std::shared_ptr<kp::Tensor>& out,
 | 
				
			||||||
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
 | 
					    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
 | 
				
			||||||
    int32_t ne00, int32_t ne10, int32_t ne0, int32_t ne1,
 | 
					    int32_t ne00, int32_t ne01, int32_t ne02,
 | 
				
			||||||
    int32_t ne01, int32_t ne11, int32_t ne12, int32_t ne02
 | 
					    int32_t ne10, int32_t ne11, int32_t ne12, int32_t ne13,
 | 
				
			||||||
 | 
					    int32_t ne0, int32_t ne1,
 | 
				
			||||||
 | 
					    uint32_t nb01, uint32_t nb02, uint32_t nb03,
 | 
				
			||||||
 | 
					    uint32_t nb11, uint32_t nb12, uint32_t nb13,
 | 
				
			||||||
 | 
					    uint32_t r2, uint32_t r3
 | 
				
			||||||
) {
 | 
					) {
 | 
				
			||||||
    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
 | 
					    const static auto spirv = getSpirvShader(kp::shader_data::op_mul_mat_q6_k_comp_spv,
 | 
				
			||||||
        kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
 | 
					        kp::shader_data::op_mul_mat_q6_k_comp_spv_len);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    struct PushConstants {
 | 
					    struct PushConstants {
 | 
				
			||||||
        uint32_t inAOff, inBOff, outOff;
 | 
					        uint32_t inAOff, inBOff, outOff;
 | 
				
			||||||
        int32_t ne00, ne10, ne0, ne1, ne01, gqa;
 | 
					        int32_t ne00, ne10, ne0, ne1, ne01, ne02, ne12;
 | 
				
			||||||
 | 
					        uint32_t nb01, nb02, nb03, nb11, nb12, nb13;
 | 
				
			||||||
 | 
					        uint32_t r2, r3;
 | 
				
			||||||
    } pushConsts {
 | 
					    } pushConsts {
 | 
				
			||||||
        inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
 | 
					        inAOff, safe_divide(inBOff, 4), safe_divide(outOff, 4),
 | 
				
			||||||
        ne00, ne10, ne0, ne1, ne01, ne12/ne02
 | 
					        ne00, ne10, ne0, ne1, ne01, ne02, ne12,
 | 
				
			||||||
 | 
					        nb01, nb02, nb03, nb11, nb12, nb13,
 | 
				
			||||||
 | 
					        r2, r3
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
 | 
					    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
 | 
				
			||||||
    if (!komputeManager()->hasAlgorithm(__func__)) {
 | 
					    if (!komputeManager()->hasAlgorithm(__func__)) {
 | 
				
			||||||
        const uint32_t local_x = ggml_vk_current_device().subgroupSize * 2;
 | 
					        const uint32_t local_x = 2;
 | 
				
			||||||
        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)}, {local_x}, {pushConsts});
 | 
					        const uint32_t local_y = ggml_vk_current_device().subgroupSize;
 | 
				
			||||||
 | 
					        s_algo = komputeManager()->algorithm<uint32_t, PushConstants>(__func__, s_kompute_context->pool.get(), {inA, inB, out}, spirv, {unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)}, {local_x, local_y}, {pushConsts});
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
        s_algo = komputeManager()->getAlgorithm(__func__);
 | 
					        s_algo = komputeManager()->getAlgorithm(__func__);
 | 
				
			||||||
        s_algo->setTensors({inA, inB, out});
 | 
					        s_algo->setTensors({inA, inB, out});
 | 
				
			||||||
        s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)});
 | 
					        s_algo->setWorkgroup({unsigned((ne01 + 1)/2), unsigned(ne11), unsigned(ne12)*unsigned(ne13)});
 | 
				
			||||||
        s_algo->setPushConstants<PushConstants>({pushConsts});
 | 
					        s_algo->setPushConstants<PushConstants>({pushConsts});
 | 
				
			||||||
        s_algo->updateDescriptors(s_kompute_context->pool.get());
 | 
					        s_algo->updateDescriptors(s_kompute_context->pool.get());
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
@@ -1217,10 +1244,11 @@ static void ggml_vk_rope(
 | 
				
			|||||||
    kp::Sequence& seq,
 | 
					    kp::Sequence& seq,
 | 
				
			||||||
    const std::shared_ptr<kp::Tensor>& inA,
 | 
					    const std::shared_ptr<kp::Tensor>& inA,
 | 
				
			||||||
    const std::shared_ptr<kp::Tensor>& inB,
 | 
					    const std::shared_ptr<kp::Tensor>& inB,
 | 
				
			||||||
 | 
					    const std::shared_ptr<kp::Tensor>& inC,
 | 
				
			||||||
    const std::shared_ptr<kp::Tensor>& out,
 | 
					    const std::shared_ptr<kp::Tensor>& out,
 | 
				
			||||||
    uint32_t inAOff, uint32_t inBOff, uint32_t outOff,
 | 
					    uint32_t inAOff, uint32_t inBOff, uint32_t inCOff, uint32_t outOff,
 | 
				
			||||||
    ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
 | 
					    ggml_type src0t, int32_t n_dims, int32_t mode, int32_t n_ctx_orig,
 | 
				
			||||||
    float freq_base, float freq_scale, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
 | 
					    float freq_base, float freq_scale, bool has_freq_factors, float ext_factor, float attn_factor, float beta_fast, float beta_slow,
 | 
				
			||||||
    int32_t ne01, int32_t ne02, int32_t ne03,
 | 
					    int32_t ne01, int32_t ne02, int32_t ne03,
 | 
				
			||||||
    uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
 | 
					    uint32_t nb00, uint32_t nb01, uint32_t nb02, uint32_t nb03,
 | 
				
			||||||
    int32_t ne0,
 | 
					    int32_t ne0,
 | 
				
			||||||
@@ -1228,11 +1256,17 @@ static void ggml_vk_rope(
 | 
				
			|||||||
) {
 | 
					) {
 | 
				
			||||||
    GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
 | 
					    GGML_ASSERT(src0t == GGML_TYPE_F16 || src0t == GGML_TYPE_F32);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    static const auto spirv_f16 = getSpirvShader(
 | 
					    static const auto spirv_norm_f16 = getSpirvShader(
 | 
				
			||||||
        kp::shader_data::op_rope_f16_comp_spv, kp::shader_data::op_rope_f16_comp_spv_len
 | 
					        kp::shader_data::op_rope_norm_f16_comp_spv, kp::shader_data::op_rope_norm_f16_comp_spv_len
 | 
				
			||||||
    );
 | 
					    );
 | 
				
			||||||
    static const auto spirv_f32 = getSpirvShader(
 | 
					    static const auto spirv_norm_f32 = getSpirvShader(
 | 
				
			||||||
        kp::shader_data::op_rope_f32_comp_spv, kp::shader_data::op_rope_f32_comp_spv_len
 | 
					        kp::shader_data::op_rope_norm_f32_comp_spv, kp::shader_data::op_rope_norm_f32_comp_spv_len
 | 
				
			||||||
 | 
					    );
 | 
				
			||||||
 | 
					    static const auto spirv_neox_f16 = getSpirvShader(
 | 
				
			||||||
 | 
					        kp::shader_data::op_rope_neox_f16_comp_spv, kp::shader_data::op_rope_neox_f16_comp_spv_len
 | 
				
			||||||
 | 
					    );
 | 
				
			||||||
 | 
					    static const auto spirv_neox_f32 = getSpirvShader(
 | 
				
			||||||
 | 
					        kp::shader_data::op_rope_neox_f32_comp_spv, kp::shader_data::op_rope_neox_f32_comp_spv_len
 | 
				
			||||||
    );
 | 
					    );
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
 | 
					    int type_size = src0t == GGML_TYPE_F16 ? 2 : 4;
 | 
				
			||||||
@@ -1247,32 +1281,40 @@ static void ggml_vk_rope(
 | 
				
			|||||||
    GGML_ASSERT(nb0  % type_size == 0);
 | 
					    GGML_ASSERT(nb0  % type_size == 0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    struct PushConstants {
 | 
					    struct PushConstants {
 | 
				
			||||||
        uint32_t inAOff, inBOff, outOff;
 | 
					        uint32_t inAOff, inBOff, inCOff, outOff;
 | 
				
			||||||
        int32_t n_dims, mode, n_ctx_orig;
 | 
					        int32_t n_dims, mode, n_ctx_orig;
 | 
				
			||||||
        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 | 
					        float freq_base, freq_scale;
 | 
				
			||||||
 | 
					        bool has_freq_factors;
 | 
				
			||||||
 | 
					        float ext_factor, attn_factor, beta_fast, beta_slow;
 | 
				
			||||||
        uint32_t nb00, nb01, nb02, nb03;
 | 
					        uint32_t nb00, nb01, nb02, nb03;
 | 
				
			||||||
        int32_t ne0;
 | 
					        int32_t ne0;
 | 
				
			||||||
        uint32_t nb0, nb1, nb2, nb3;
 | 
					        uint32_t nb0, nb1, nb2, nb3;
 | 
				
			||||||
    } pushConsts {
 | 
					    } pushConsts {
 | 
				
			||||||
        safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(outOff, type_size),
 | 
					        safe_divide(inAOff, type_size), safe_divide(inBOff, 4), safe_divide(inCOff, type_size), safe_divide(outOff, type_size),
 | 
				
			||||||
        n_dims, mode, n_ctx_orig,
 | 
					        n_dims, mode, n_ctx_orig,
 | 
				
			||||||
        freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
 | 
					        freq_base, freq_scale,
 | 
				
			||||||
 | 
					        has_freq_factors,
 | 
				
			||||||
 | 
					        ext_factor, attn_factor, beta_fast, beta_slow,
 | 
				
			||||||
        nb00, nb01, nb02, nb03,
 | 
					        nb00, nb01, nb02, nb03,
 | 
				
			||||||
        ne0,
 | 
					        ne0,
 | 
				
			||||||
        nb0, nb1, nb2, nb3
 | 
					        nb0, nb1, nb2, nb3
 | 
				
			||||||
    };
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    auto name = std::string(__func__) + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
 | 
					    auto & inC_ = inC ? inC : inA;
 | 
				
			||||||
 | 
					    const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
 | 
				
			||||||
 | 
					    const bool is_f16 = src0t == GGML_TYPE_F16;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    auto name = std::string(__func__) + (is_neox ? "_neox" : "_norm") + (src0t == GGML_TYPE_F16 ? "_f16" : "_f32");
 | 
				
			||||||
    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
 | 
					    std::shared_ptr<kp::Algorithm> s_algo = nullptr;
 | 
				
			||||||
    if (!komputeManager()->hasAlgorithm(name)) {
 | 
					    if (!komputeManager()->hasAlgorithm(name)) {
 | 
				
			||||||
 | 
					        auto & spirv = is_neox ? is_f16 ? spirv_neox_f16 : spirv_neox_f32 : is_f16 ? spirv_norm_f16 : spirv_norm_f32;
 | 
				
			||||||
        s_algo = komputeManager()->algorithm<float, PushConstants>(
 | 
					        s_algo = komputeManager()->algorithm<float, PushConstants>(
 | 
				
			||||||
            name, s_kompute_context->pool.get(), {inA, inB, out},
 | 
					            name, s_kompute_context->pool.get(), {inA, inB, inC_, out}, spirv,
 | 
				
			||||||
            src0t == GGML_TYPE_F16 ? spirv_f16 : spirv_f32,
 | 
					 | 
				
			||||||
            {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
 | 
					            {unsigned(ne01), unsigned(ne02), unsigned(ne03)}, {}, {pushConsts}
 | 
				
			||||||
        );
 | 
					        );
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
        s_algo = komputeManager()->getAlgorithm(name);
 | 
					        s_algo = komputeManager()->getAlgorithm(name);
 | 
				
			||||||
        s_algo->setTensors({inA, inB, out});
 | 
					        s_algo->setTensors({inA, inB, inC_, out});
 | 
				
			||||||
        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
 | 
					        s_algo->setWorkgroup({unsigned(ne01), unsigned(ne02), unsigned(ne03)});
 | 
				
			||||||
        s_algo->setPushConstants<PushConstants>({pushConsts});
 | 
					        s_algo->setPushConstants<PushConstants>({pushConsts});
 | 
				
			||||||
        s_algo->updateDescriptors(s_kompute_context->pool.get());
 | 
					        s_algo->updateDescriptors(s_kompute_context->pool.get());
 | 
				
			||||||
@@ -1351,11 +1393,15 @@ static void ggml_vk_cpy_f16_f32(Args&&... args) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
 | 
					static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
 | 
				
			||||||
 | 
					    int64_t n = ggml_nelements(op);
 | 
				
			||||||
    switch (op->op) {
 | 
					    switch (op->op) {
 | 
				
			||||||
        case GGML_OP_UNARY:
 | 
					        case GGML_OP_UNARY:
 | 
				
			||||||
 | 
					            if (n % 4 != 0) return false;
 | 
				
			||||||
            switch (ggml_get_unary_op(op)) {
 | 
					            switch (ggml_get_unary_op(op)) {
 | 
				
			||||||
                case GGML_UNARY_OP_RELU:
 | 
					 | 
				
			||||||
                case GGML_UNARY_OP_GELU:
 | 
					                case GGML_UNARY_OP_GELU:
 | 
				
			||||||
 | 
					                    if (n % 8 != 0) return false;
 | 
				
			||||||
 | 
					                    // fall through
 | 
				
			||||||
 | 
					                case GGML_UNARY_OP_RELU:
 | 
				
			||||||
                case GGML_UNARY_OP_SILU:
 | 
					                case GGML_UNARY_OP_SILU:
 | 
				
			||||||
                    return ggml_is_contiguous(op->src[0]);
 | 
					                    return ggml_is_contiguous(op->src[0]);
 | 
				
			||||||
                default:
 | 
					                default:
 | 
				
			||||||
@@ -1413,8 +1459,8 @@ static bool ggml_backend_kompute_device_supports_op(ggml_backend_dev_t dev, cons
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
            switch (op->src[0]->type) {
 | 
					            switch (op->src[0]->type) {
 | 
				
			||||||
                case GGML_TYPE_F32:
 | 
					                case GGML_TYPE_F32:
 | 
				
			||||||
                case GGML_TYPE_Q6_K:
 | 
					 | 
				
			||||||
                    return op->ne[3] == 1;
 | 
					                    return op->ne[3] == 1;
 | 
				
			||||||
 | 
					                case GGML_TYPE_Q6_K:
 | 
				
			||||||
                case GGML_TYPE_F16:
 | 
					                case GGML_TYPE_F16:
 | 
				
			||||||
                case GGML_TYPE_Q8_0:
 | 
					                case GGML_TYPE_Q8_0:
 | 
				
			||||||
                case GGML_TYPE_Q4_0:
 | 
					                case GGML_TYPE_Q4_0:
 | 
				
			||||||
@@ -1515,9 +1561,11 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 | 
				
			|||||||
            const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
 | 
					            const static std::shared_ptr<kp::Tensor> nullTensor = nullptr;
 | 
				
			||||||
            uint32_t off_src0 = 0;
 | 
					            uint32_t off_src0 = 0;
 | 
				
			||||||
            uint32_t off_src1 = 0;
 | 
					            uint32_t off_src1 = 0;
 | 
				
			||||||
 | 
					            uint32_t off_src2 = 0;
 | 
				
			||||||
            uint32_t off_dst  = 0;
 | 
					            uint32_t off_dst  = 0;
 | 
				
			||||||
            const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
 | 
					            const std::shared_ptr<kp::Tensor>& id_src0 = src0 ? ggml_vk_get_tensor(src0, &off_src0) : nullTensor;
 | 
				
			||||||
            const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
 | 
					            const std::shared_ptr<kp::Tensor>& id_src1 = src1 ? ggml_vk_get_tensor(src1, &off_src1) : nullTensor;
 | 
				
			||||||
 | 
					            const std::shared_ptr<kp::Tensor>& id_src2 = src2 ? ggml_vk_get_tensor(src2, &off_src2) : nullTensor;
 | 
				
			||||||
            const std::shared_ptr<kp::Tensor>& id_dst  = dst  ? ggml_vk_get_tensor(dst,  &off_dst)  : nullTensor;
 | 
					            const std::shared_ptr<kp::Tensor>& id_dst  = dst  ? ggml_vk_get_tensor(dst,  &off_dst)  : nullTensor;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            switch (dst->op) {
 | 
					            switch (dst->op) {
 | 
				
			||||||
@@ -1593,11 +1641,16 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 | 
				
			|||||||
#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
 | 
					#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
 | 
				
			||||||
                        GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
 | 
					                        GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#pragma message("TODO: add ALiBi support")
 | 
					                        const int64_t nrows_x = ggml_nrows(src0);
 | 
				
			||||||
#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/7192")
 | 
					                        const int64_t nrows_y = src0->ne[1];
 | 
				
			||||||
                        GGML_ASSERT(max_bias == 0.0f);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
 | 
					                        const uint32_t n_head      = nrows_x/nrows_y;
 | 
				
			||||||
 | 
					                        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
 | 
				
			||||||
 | 
					                        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale, max_bias, m0, m1, n_head_log2);
 | 
				
			||||||
                    } break;
 | 
					                    } break;
 | 
				
			||||||
                case GGML_OP_DIAG_MASK_INF:
 | 
					                case GGML_OP_DIAG_MASK_INF:
 | 
				
			||||||
                    {
 | 
					                    {
 | 
				
			||||||
@@ -1649,38 +1702,44 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 | 
				
			|||||||
                            case GGML_TYPE_F16:
 | 
					                            case GGML_TYPE_F16:
 | 
				
			||||||
                                ggml_vk_mul_mat_f16(
 | 
					                                ggml_vk_mul_mat_f16(
 | 
				
			||||||
                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
 | 
					                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
 | 
				
			||||||
                                    ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, ne13, nb10, nb11, nb12,
 | 
					                                    ne00, ne01, ne02, nb00, nb01, nb02, nb03,
 | 
				
			||||||
 | 
					                                    ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
 | 
				
			||||||
                                    ne0, ne1, r2, r3
 | 
					                                    ne0, ne1, r2, r3
 | 
				
			||||||
                                );
 | 
					                                );
 | 
				
			||||||
                                break;
 | 
					                                break;
 | 
				
			||||||
                            case GGML_TYPE_Q8_0:
 | 
					                            case GGML_TYPE_Q8_0:
 | 
				
			||||||
                                ggml_vk_mul_mat_q8_0(
 | 
					                                ggml_vk_mul_mat_q8_0(
 | 
				
			||||||
                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
 | 
					                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
 | 
				
			||||||
                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
 | 
					                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
 | 
				
			||||||
 | 
					                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
 | 
				
			||||||
                                );
 | 
					                                );
 | 
				
			||||||
                                break;
 | 
					                                break;
 | 
				
			||||||
                            case GGML_TYPE_Q4_0:
 | 
					                            case GGML_TYPE_Q4_0:
 | 
				
			||||||
                                ggml_vk_mul_mat_q4_0(
 | 
					                                ggml_vk_mul_mat_q4_0(
 | 
				
			||||||
                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
 | 
					                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
 | 
				
			||||||
                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
 | 
					                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
 | 
				
			||||||
 | 
					                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
 | 
				
			||||||
                                );
 | 
					                                );
 | 
				
			||||||
                                break;
 | 
					                                break;
 | 
				
			||||||
                            case GGML_TYPE_Q4_1:
 | 
					                            case GGML_TYPE_Q4_1:
 | 
				
			||||||
                                ggml_vk_mul_mat_q4_1(
 | 
					                                ggml_vk_mul_mat_q4_1(
 | 
				
			||||||
                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
 | 
					                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
 | 
				
			||||||
                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, r2, r3
 | 
					                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
 | 
				
			||||||
 | 
					                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
 | 
				
			||||||
                                );
 | 
					                                );
 | 
				
			||||||
                                break;
 | 
					                                break;
 | 
				
			||||||
                            case GGML_TYPE_Q4_K:
 | 
					                            case GGML_TYPE_Q4_K:
 | 
				
			||||||
                                ggml_vk_mul_mat_q4_k(
 | 
					                                ggml_vk_mul_mat_q4_k(
 | 
				
			||||||
                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
 | 
					                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
 | 
				
			||||||
                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1, ne12/ne02, ne13/ne03
 | 
					                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
 | 
				
			||||||
 | 
					                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
 | 
				
			||||||
                                );
 | 
					                                );
 | 
				
			||||||
                                break;
 | 
					                                break;
 | 
				
			||||||
                            case GGML_TYPE_Q6_K:
 | 
					                            case GGML_TYPE_Q6_K:
 | 
				
			||||||
                                ggml_vk_mul_mat_q6_k(
 | 
					                                ggml_vk_mul_mat_q6_k(
 | 
				
			||||||
                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
 | 
					                                    seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst,
 | 
				
			||||||
                                    ne00, ne10, ne0, ne1, ne01, ne11, ne12, ne02
 | 
					                                    ne00, ne01, ne02, ne10, ne11, ne12, ne13, ne0, ne1,
 | 
				
			||||||
 | 
					                                    nb01, nb02, nb03, nb11, nb12, nb13, r2, r3
 | 
				
			||||||
                                );
 | 
					                                );
 | 
				
			||||||
                                break;
 | 
					                                break;
 | 
				
			||||||
                            default: {
 | 
					                            default: {
 | 
				
			||||||
@@ -1709,13 +1768,6 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 | 
				
			|||||||
                    } break;
 | 
					                    } break;
 | 
				
			||||||
                case GGML_OP_ROPE:
 | 
					                case GGML_OP_ROPE:
 | 
				
			||||||
                    {
 | 
					                    {
 | 
				
			||||||
#pragma message("TODO: implement phi3 frequency factors support")
 | 
					 | 
				
			||||||
#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7225")
 | 
					 | 
				
			||||||
                        GGML_ASSERT(dst->src[2] == nullptr && "phi3 frequency factors not implemented yet");
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#pragma message("TODO: update rope NORM mode to match NEOX mode")
 | 
					 | 
				
			||||||
#pragma message("      https://github.com/ggerganov/llama.cpp/pull/7634")
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                        GGML_ASSERT(ne10 == ne02);
 | 
					                        GGML_ASSERT(ne10 == ne02);
 | 
				
			||||||
                        GGML_ASSERT(src0t == dstt);
 | 
					                        GGML_ASSERT(src0t == dstt);
 | 
				
			||||||
                        // const int n_past = ((int32_t *) dst->op_params)[0];
 | 
					                        // const int n_past = ((int32_t *) dst->op_params)[0];
 | 
				
			||||||
@@ -1724,6 +1776,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 | 
				
			|||||||
                        // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
 | 
					                        // skip 3, n_ctx used in GLM RoPE, unimplemented in Vulkan
 | 
				
			||||||
                        const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 | 
					                        const int n_ctx_orig = ((int32_t *) dst->op_params)[4];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                        const bool has_freq_factors = dst->src[2] != nullptr;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 | 
					                        float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
 | 
				
			||||||
                        memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
 | 
					                        memcpy(&freq_base,   (int32_t *) dst->op_params +  5, sizeof(float));
 | 
				
			||||||
                        memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
 | 
					                        memcpy(&freq_scale,  (int32_t *) dst->op_params +  6, sizeof(float));
 | 
				
			||||||
@@ -1732,8 +1786,8 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
 | 
				
			|||||||
                        memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
 | 
					                        memcpy(&beta_fast,   (int32_t *) dst->op_params +  9, sizeof(float));
 | 
				
			||||||
                        memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 | 
					                        memcpy(&beta_slow,   (int32_t *) dst->op_params + 10, sizeof(float));
 | 
				
			||||||
                        ggml_vk_rope(
 | 
					                        ggml_vk_rope(
 | 
				
			||||||
                            seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, src0t, n_dims, mode, n_ctx_orig,
 | 
					                            seq, id_src0, id_src1, id_src2, id_dst, off_src0, off_src1, off_src2, off_dst, src0t, n_dims, mode, n_ctx_orig,
 | 
				
			||||||
                            freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow,
 | 
					                            freq_base, freq_scale, has_freq_factors, ext_factor, attn_factor, beta_fast, beta_slow,
 | 
				
			||||||
                            ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
 | 
					                            ne01, ne02, ne03, nb00, nb01, nb02, nb03, ne0, nb0, nb1, nb2, nb3
 | 
				
			||||||
                        );
 | 
					                        );
 | 
				
			||||||
                    } break;
 | 
					                    } break;
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -3,6 +3,7 @@
 | 
				
			|||||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
 | 
					#extension GL_EXT_shader_explicit_arithmetic_types_float16: require
 | 
				
			||||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
 | 
					#extension GL_EXT_shader_explicit_arithmetic_types_int8: require
 | 
				
			||||||
#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
 | 
					#extension GL_EXT_shader_explicit_arithmetic_types_int16: require
 | 
				
			||||||
 | 
					#extension GL_EXT_shader_explicit_arithmetic_types_int64: require
 | 
				
			||||||
#extension GL_EXT_control_flow_attributes: enable
 | 
					#extension GL_EXT_control_flow_attributes: enable
 | 
				
			||||||
#extension GL_KHR_shader_subgroup_arithmetic : require
 | 
					#extension GL_KHR_shader_subgroup_arithmetic : require
 | 
				
			||||||
#extension GL_EXT_debug_printf : enable
 | 
					#extension GL_EXT_debug_printf : enable
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -20,12 +20,14 @@ layout (push_constant) uniform parameter {
 | 
				
			|||||||
    uint nb00;
 | 
					    uint nb00;
 | 
				
			||||||
    uint nb01;
 | 
					    uint nb01;
 | 
				
			||||||
    uint nb02;
 | 
					    uint nb02;
 | 
				
			||||||
 | 
					    uint nb03;
 | 
				
			||||||
    int ne10;
 | 
					    int ne10;
 | 
				
			||||||
    int ne11;
 | 
					    int ne11;
 | 
				
			||||||
    int ne12;
 | 
					    int ne12;
 | 
				
			||||||
    uint nb10;
 | 
					    uint nb10;
 | 
				
			||||||
    uint nb11;
 | 
					    uint nb11;
 | 
				
			||||||
    uint nb12;
 | 
					    uint nb12;
 | 
				
			||||||
 | 
					    uint nb13;
 | 
				
			||||||
    int ne0;
 | 
					    int ne0;
 | 
				
			||||||
    int ne1;
 | 
					    int ne1;
 | 
				
			||||||
    uint r2;
 | 
					    uint r2;
 | 
				
			||||||
@@ -42,7 +44,7 @@ void main() {
 | 
				
			|||||||
    const uint i12 = im%pcs.ne12;
 | 
					    const uint i12 = im%pcs.ne12;
 | 
				
			||||||
    const uint i13 = im/pcs.ne12;
 | 
					    const uint i13 = im/pcs.ne12;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb02*pcs.ne02;
 | 
					    const uint offset0 = r0*pcs.nb01 + (i12/pcs.r2)*pcs.nb02 + (i13/pcs.r3)*pcs.nb03;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const uint x = offset0 / 2 + pcs.inAOff; // Based from inA
 | 
					    const uint x = offset0 / 2 + pcs.inAOff; // Based from inA
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -52,7 +54,7 @@ void main() {
 | 
				
			|||||||
            break;
 | 
					            break;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        const uint y = (r1*pcs.nb11 + im*pcs.nb12) / 4 + pcs.inBOff; // Based from inB
 | 
					        const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        float sumf = 0;
 | 
					        float sumf = 0;
 | 
				
			||||||
        for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
 | 
					        for (uint i = gl_SubgroupInvocationID.x; i < pcs.ne00; i += gl_SubgroupSize) {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -24,8 +24,14 @@ layout (push_constant) uniform parameter {
 | 
				
			|||||||
    int ne01;
 | 
					    int ne01;
 | 
				
			||||||
    int ne02;
 | 
					    int ne02;
 | 
				
			||||||
    int ne12;
 | 
					    int ne12;
 | 
				
			||||||
    int r2;
 | 
					    uint nb01;
 | 
				
			||||||
    int r3;
 | 
					    uint nb02;
 | 
				
			||||||
 | 
					    uint nb03;
 | 
				
			||||||
 | 
					    uint nb11;
 | 
				
			||||||
 | 
					    uint nb12;
 | 
				
			||||||
 | 
					    uint nb13;
 | 
				
			||||||
 | 
					    uint r2;
 | 
				
			||||||
 | 
					    uint r3;
 | 
				
			||||||
} pcs;
 | 
					} pcs;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void main() {
 | 
					void main() {
 | 
				
			||||||
@@ -50,10 +56,11 @@ void main() {
 | 
				
			|||||||
    const uint i12 = im%pcs.ne12;
 | 
					    const uint i12 = im%pcs.ne12;
 | 
				
			||||||
    const uint i13 = im/pcs.ne12;
 | 
					    const uint i13 = im/pcs.ne12;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const uint offset0 = (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
 | 
					    const uint offset0 = first_row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
 | 
				
			||||||
 | 
					    const uint offset1 =        r1*pcs.nb11 + (i12       )*pcs.nb12 + (i13       )*pcs.nb13;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const uint xblk = ib_row + offset0 + pcs.inAOff;
 | 
					    const uint xblk = offset0 + pcs.inAOff;
 | 
				
			||||||
    const uint y = r1*pcs.ne10 + im*pcs.ne00*pcs.ne1 + pcs.inBOff;
 | 
					    const uint y = (offset1 / 4) + pcs.inBOff;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    float yl[16];
 | 
					    float yl[16];
 | 
				
			||||||
    float yh[16];
 | 
					    float yh[16];
 | 
				
			||||||
@@ -74,7 +81,7 @@ void main() {
 | 
				
			|||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for (int row = 0; row < N_DST; row++) {
 | 
					        for (int row = 0; row < N_DST; row++) {
 | 
				
			||||||
            uint row_idx = row * nb;
 | 
					            uint row_idx = row * (pcs.nb01 / SIZE_OF_BLOCK);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
 | 
					            uint16_t sc_0 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 0);
 | 
				
			||||||
            uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
 | 
					            uint16_t sc_1 = u8BufToU16(inA[blk_idx + row_idx].scales, iq * 2 + 2);
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -21,7 +21,16 @@ layout (push_constant) uniform parameter {
 | 
				
			|||||||
    int ne0;
 | 
					    int ne0;
 | 
				
			||||||
    int ne1;
 | 
					    int ne1;
 | 
				
			||||||
    int ne01;
 | 
					    int ne01;
 | 
				
			||||||
    int gqa;
 | 
					    int ne02;
 | 
				
			||||||
 | 
					    int ne12;
 | 
				
			||||||
 | 
					    uint nb01;
 | 
				
			||||||
 | 
					    uint nb02;
 | 
				
			||||||
 | 
					    uint nb03;
 | 
				
			||||||
 | 
					    uint nb11;
 | 
				
			||||||
 | 
					    uint nb12;
 | 
				
			||||||
 | 
					    uint nb13;
 | 
				
			||||||
 | 
					    uint r2;
 | 
				
			||||||
 | 
					    uint r3;
 | 
				
			||||||
} pcs;
 | 
					} pcs;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
void main() {
 | 
					void main() {
 | 
				
			||||||
@@ -34,12 +43,15 @@ void main() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    const uint r0 = gl_WorkGroupID.x;
 | 
					    const uint r0 = gl_WorkGroupID.x;
 | 
				
			||||||
    const uint r1 = gl_WorkGroupID.y;
 | 
					    const uint r1 = gl_WorkGroupID.y;
 | 
				
			||||||
    const uint r2 = gl_WorkGroupID.z;
 | 
					    const uint im = gl_WorkGroupID.z;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
 | 
					    const uint row = (r0 * gl_NumSubgroups + gl_SubgroupID);
 | 
				
			||||||
    const uint offset0 = r2/pcs.gqa*(nb*pcs.ne0);
 | 
					
 | 
				
			||||||
    const uint x = row * nb + offset0; // Based from inA without base offset
 | 
					    const uint i12 = im%pcs.ne12;
 | 
				
			||||||
    const uint yy = r1*pcs.ne10 + r2*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
 | 
					    const uint i13 = im/pcs.ne12;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const uint x = row*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
 | 
				
			||||||
 | 
					    const uint yy = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    float sumf = 0;
 | 
					    float sumf = 0;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -89,6 +101,6 @@ void main() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    const float tot = subgroupAdd(sumf);
 | 
					    const float tot = subgroupAdd(sumf);
 | 
				
			||||||
    if (subgroupElect()) {
 | 
					    if (subgroupElect()) {
 | 
				
			||||||
        out_[r1*pcs.ne0 + r2*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
 | 
					        out_[r1*pcs.ne0 + im*pcs.ne0*pcs.ne1 + row + pcs.outOff] = tot;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -14,10 +14,15 @@ void main() {
 | 
				
			|||||||
    const uint i12 = im%pcs.ne12;
 | 
					    const uint i12 = im%pcs.ne12;
 | 
				
			||||||
    const uint i13 = im/pcs.ne12;
 | 
					    const uint i13 = im/pcs.ne12;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const uint offset0 = first_row * nb + (i12/pcs.r2)*(nb*pcs.ne01) + (i13/pcs.r3)*(nb*pcs.ne01*pcs.ne02);
 | 
					    // pointers to src0 rows
 | 
				
			||||||
 | 
					    uint ax[N_ROWS];
 | 
				
			||||||
 | 
					    for (int row = 0; row < N_ROWS; ++row) {
 | 
				
			||||||
 | 
					        const uint offset0 = (first_row + row)*(pcs.nb01/SIZE_OF_BLOCK) + (i12/pcs.r2)*(pcs.nb02/SIZE_OF_BLOCK) + (i13/pcs.r3)*(pcs.nb03/SIZE_OF_BLOCK);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    const uint x = offset0; // Based from inA without base offset
 | 
					        ax[row] = offset0 + pcs.inAOff;
 | 
				
			||||||
    const uint y = r1*uint(pcs.ne10)+im*pcs.ne00*pcs.ne1+pcs.inBOff; // Based from inB
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const uint y = (r1*pcs.nb11 + i12*pcs.nb12 + i13*pcs.nb13) / 4 + pcs.inBOff;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};
 | 
					    float sumf[N_ROWS] = {0.0f, 0.0f, 0.0f, 0.0f};
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -32,8 +37,7 @@ void main() {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    for (uint ib = ix; ib < nb; ib += 16) {
 | 
					    for (uint ib = ix; ib < nb; ib += 16) {
 | 
				
			||||||
        for (int row = 0; row < N_ROWS; row++) {
 | 
					        for (int row = 0; row < N_ROWS; row++) {
 | 
				
			||||||
            const uint block_index = x + ib + row * nb;
 | 
					            sumf[row] += block_q_n_dot_y(ax[row] + ib, yb, il);
 | 
				
			||||||
            sumf[row] += block_q_n_dot_y(block_index, yb, il);
 | 
					 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        yb += BLOCKS_IN_QUANT * 16;
 | 
					        yb += BLOCKS_IN_QUANT * 16;
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,5 +1,5 @@
 | 
				
			|||||||
layout(local_size_x_id = 0) in;
 | 
					layout(local_size_x_id = 0) in;
 | 
				
			||||||
layout(local_size_y = 1) in;
 | 
					layout(local_size_y = 8) in;
 | 
				
			||||||
layout(local_size_z = 1) in;
 | 
					layout(local_size_z = 1) in;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
 | 
					layout (binding = 0) readonly buffer tensorInA { uint8_t inA[]; };
 | 
				
			||||||
@@ -17,6 +17,12 @@ layout (push_constant) uniform parameter {
 | 
				
			|||||||
    int  ne12;
 | 
					    int  ne12;
 | 
				
			||||||
    int  ne0;
 | 
					    int  ne0;
 | 
				
			||||||
    int  ne1;
 | 
					    int  ne1;
 | 
				
			||||||
 | 
					    uint nb01;
 | 
				
			||||||
 | 
					    uint nb02;
 | 
				
			||||||
 | 
					    uint nb03;
 | 
				
			||||||
 | 
					    uint nb11;
 | 
				
			||||||
 | 
					    uint nb12;
 | 
				
			||||||
 | 
					    uint nb13;
 | 
				
			||||||
    uint r2;
 | 
					    uint r2;
 | 
				
			||||||
    uint r3;
 | 
					    uint r3;
 | 
				
			||||||
} pcs;
 | 
					} pcs;
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -1,73 +0,0 @@
 | 
				
			|||||||
#version 450
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#include "rope_common.comp"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
layout(binding = 0) buffer restrict readonly  tensorInA { float16_t inA[]; };
 | 
					 | 
				
			||||||
layout(binding = 1) buffer restrict readonly  tensorInB { int       inB[]; };
 | 
					 | 
				
			||||||
layout(binding = 2) buffer restrict writeonly tensorOut { float16_t out_[]; };
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
void main() {
 | 
					 | 
				
			||||||
    const uint i3 = gl_WorkGroupID.z;
 | 
					 | 
				
			||||||
    const uint i2 = gl_WorkGroupID.y;
 | 
					 | 
				
			||||||
    const uint i1 = gl_WorkGroupID.x;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const bool is_neox = (pcs.mode & GGML_ROPE_TYPE_NEOX) != 0;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    float corr_dims[2];
 | 
					 | 
				
			||||||
    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const int p = inB[pcs.inBOff + i2];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    float theta = float(p);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (!is_neox) {
 | 
					 | 
				
			||||||
        for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
 | 
					 | 
				
			||||||
            float cos_theta, sin_theta;
 | 
					 | 
				
			||||||
            rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            theta *= theta_scale;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
 | 
					 | 
				
			||||||
            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            const float x0 = float(inA[src]);
 | 
					 | 
				
			||||||
            const float x1 = float(inA[src+1]);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            out_[dst_data]   = float16_t(x0*cos_theta - x1*sin_theta);
 | 
					 | 
				
			||||||
            out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
        const float inv_ndims = -1.f/pcs.n_dims;
 | 
					 | 
				
			||||||
        for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
 | 
					 | 
				
			||||||
            const uint cur_rot = ic;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            float cos_theta, sin_theta;
 | 
					 | 
				
			||||||
            rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            theta *= theta_scale;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            const uint i0 = ic/2;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
 | 
					 | 
				
			||||||
            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            const float x0 = float(inA[src]);
 | 
					 | 
				
			||||||
            const float x1 = float(inA[src+pcs.n_dims/2]);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            out_[dst_data]              = float16_t(x0*cos_theta - x1*sin_theta);
 | 
					 | 
				
			||||||
            out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
 | 
					 | 
				
			||||||
            const uint i0 = ic;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
 | 
					 | 
				
			||||||
            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            out_[dst_data + 0] = inA[src + 0];
 | 
					 | 
				
			||||||
            out_[dst_data + 1] = inA[src + 1];
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
@@ -1,73 +0,0 @@
 | 
				
			|||||||
#version 450
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
#include "rope_common.comp"
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
layout(binding = 0) buffer restrict readonly  tensorInA { float inA[]; };
 | 
					 | 
				
			||||||
layout(binding = 1) buffer restrict readonly  tensorInB { int   inB[]; };
 | 
					 | 
				
			||||||
layout(binding = 2) buffer restrict writeonly tensorOut { float out_[]; };
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
void main() {
 | 
					 | 
				
			||||||
    const uint i3 = gl_WorkGroupID.z;
 | 
					 | 
				
			||||||
    const uint i2 = gl_WorkGroupID.y;
 | 
					 | 
				
			||||||
    const uint i1 = gl_WorkGroupID.x;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const bool is_neox = (pcs.mode & GGML_ROPE_TYPE_NEOX) != 0;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    float corr_dims[2];
 | 
					 | 
				
			||||||
    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    const int p = inB[pcs.inBOff + i2];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    float theta = float(p);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    if (!is_neox) {
 | 
					 | 
				
			||||||
        for (uint i0 = 0; i0 < pcs.ne0; i0 += 2) {
 | 
					 | 
				
			||||||
            float cos_theta, sin_theta;
 | 
					 | 
				
			||||||
            rope_yarn(theta, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            theta *= theta_scale;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
 | 
					 | 
				
			||||||
            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            const float x0 = inA[src];
 | 
					 | 
				
			||||||
            const float x1 = inA[src+1];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            out_[dst_data]   = x0*cos_theta - x1*sin_theta;
 | 
					 | 
				
			||||||
            out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    } else {
 | 
					 | 
				
			||||||
        const float inv_ndims = -1.f/pcs.n_dims;
 | 
					 | 
				
			||||||
        for (uint ic = 0; ic < pcs.n_dims; ic += 2) {
 | 
					 | 
				
			||||||
            const uint cur_rot = ic;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            float cos_theta, sin_theta;
 | 
					 | 
				
			||||||
            rope_yarn(theta, pcs.freq_scale, corr_dims, cur_rot, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            theta *= theta_scale;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            const uint i0 = ic/2;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
 | 
					 | 
				
			||||||
            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            const float x0 = inA[src];
 | 
					 | 
				
			||||||
            const float x1 = inA[src+pcs.n_dims/2];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            out_[dst_data] = x0*cos_theta - x1*sin_theta;
 | 
					 | 
				
			||||||
            out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        for (uint ic = pcs.n_dims; ic < pcs.ne0; ic += 2) {
 | 
					 | 
				
			||||||
            const uint i0 = ic;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            const uint src = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
 | 
					 | 
				
			||||||
            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0) / 4) + pcs.outOff; // Based from out_
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            out_[dst_data + 0] = inA[src + 0];
 | 
					 | 
				
			||||||
            out_[dst_data + 1] = inA[src + 1];
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
    }
 | 
					 | 
				
			||||||
}
 | 
					 | 
				
			||||||
							
								
								
									
										52
									
								
								ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f16.comp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,52 @@
 | 
				
			|||||||
 | 
					#version 450
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "rope_common.comp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					layout(binding = 0) buffer restrict readonly  tensorInA { float16_t inA[]; };
 | 
				
			||||||
 | 
					layout(binding = 1) buffer restrict readonly  tensorInB { int       inB[]; };
 | 
				
			||||||
 | 
					layout(binding = 2) buffer restrict readonly  tensorInC { float     inC[]; };
 | 
				
			||||||
 | 
					layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void main() {
 | 
				
			||||||
 | 
					    const uint i3 = gl_WorkGroupID.z;
 | 
				
			||||||
 | 
					    const uint i2 = gl_WorkGroupID.y;
 | 
				
			||||||
 | 
					    const uint i1 = gl_WorkGroupID.x;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    float corr_dims[2];
 | 
				
			||||||
 | 
					    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    float theta_base = float(inB[pcs.inBOff + i2]);
 | 
				
			||||||
 | 
					    float inv_ndims = -1.f/pcs.n_dims;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    float cos_theta;
 | 
				
			||||||
 | 
					    float sin_theta;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
 | 
				
			||||||
 | 
					        if (i0 < pcs.n_dims) {
 | 
				
			||||||
 | 
					            uint ic = i0/2;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 2) + pcs.inAOff; // Based from in
 | 
				
			||||||
 | 
					            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + ic*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            const float x0 = float(inA[src]);
 | 
				
			||||||
 | 
					            const float x1 = float(inA[src+pcs.n_dims/2]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            out_[dst_data]              = float16_t(x0*cos_theta - x1*sin_theta);
 | 
				
			||||||
 | 
					            out_[dst_data+pcs.n_dims/2] = float16_t(x0*sin_theta + x1*cos_theta);
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
 | 
				
			||||||
 | 
					            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            out_[dst_data]   = inA[src];
 | 
				
			||||||
 | 
					            out_[dst_data+1] = inA[src+1];
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										52
									
								
								ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								ggml/src/ggml-kompute/kompute-shaders/op_rope_neox_f32.comp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,52 @@
 | 
				
			|||||||
 | 
					#version 450
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "rope_common.comp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					layout(binding = 0) buffer restrict readonly  tensorInA { float inA[]; };
 | 
				
			||||||
 | 
					layout(binding = 1) buffer restrict readonly  tensorInB { int       inB[]; };
 | 
				
			||||||
 | 
					layout(binding = 2) buffer restrict readonly  tensorInC { float inC[]; };
 | 
				
			||||||
 | 
					layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void main() {
 | 
				
			||||||
 | 
					    const uint i3 = gl_WorkGroupID.z;
 | 
				
			||||||
 | 
					    const uint i2 = gl_WorkGroupID.y;
 | 
				
			||||||
 | 
					    const uint i1 = gl_WorkGroupID.x;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    float corr_dims[2];
 | 
				
			||||||
 | 
					    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    float theta_base = float(inB[pcs.inBOff + i2]);
 | 
				
			||||||
 | 
					    float inv_ndims = -1.f/pcs.n_dims;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    float cos_theta;
 | 
				
			||||||
 | 
					    float sin_theta;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
 | 
				
			||||||
 | 
					        if (i0 < pcs.n_dims) {
 | 
				
			||||||
 | 
					            uint ic = i0/2;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + ic*pcs.nb00) / 4) + pcs.inAOff; // Based from in
 | 
				
			||||||
 | 
					            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + ic*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            const float x0 = inA[src];
 | 
				
			||||||
 | 
					            const float x1 = inA[src+pcs.n_dims/2];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            out_[dst_data]              = x0*cos_theta - x1*sin_theta;
 | 
				
			||||||
 | 
					            out_[dst_data+pcs.n_dims/2] = x0*sin_theta + x1*cos_theta;
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
 | 
				
			||||||
 | 
					            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            out_[dst_data]   = inA[src];
 | 
				
			||||||
 | 
					            out_[dst_data+1] = inA[src+1];
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										52
									
								
								ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f16.comp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,52 @@
 | 
				
			|||||||
 | 
					#version 450
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "rope_common.comp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					layout(binding = 0) buffer restrict readonly  tensorInA { float16_t inA[]; };
 | 
				
			||||||
 | 
					layout(binding = 1) buffer restrict readonly  tensorInB { int       inB[]; };
 | 
				
			||||||
 | 
					layout(binding = 2) buffer restrict readonly  tensorInC { float     inC[]; };
 | 
				
			||||||
 | 
					layout(binding = 3) buffer restrict writeonly tensorOut { float16_t out_[]; };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void main() {
 | 
				
			||||||
 | 
					    const uint i3 = gl_WorkGroupID.z;
 | 
				
			||||||
 | 
					    const uint i2 = gl_WorkGroupID.y;
 | 
				
			||||||
 | 
					    const uint i1 = gl_WorkGroupID.x;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    float corr_dims[2];
 | 
				
			||||||
 | 
					    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    float theta_base = float(inB[pcs.inBOff + i2]);
 | 
				
			||||||
 | 
					    float inv_ndims = -1.f/pcs.n_dims;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    float cos_theta;
 | 
				
			||||||
 | 
					    float sin_theta;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
 | 
				
			||||||
 | 
					        if (i0 < pcs.n_dims) {
 | 
				
			||||||
 | 
					            uint ic = i0/2;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
 | 
				
			||||||
 | 
					            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            const float x0 = float(inA[src]);
 | 
				
			||||||
 | 
					            const float x1 = float(inA[src+1]);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            out_[dst_data]   = float16_t(x0*cos_theta - x1*sin_theta);
 | 
				
			||||||
 | 
					            out_[dst_data+1] = float16_t(x0*sin_theta + x1*cos_theta);
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 2) + pcs.inAOff; // Based from in
 | 
				
			||||||
 | 
					            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 2) + pcs.outOff; // Based from out_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            out_[dst_data]   = inA[src];
 | 
				
			||||||
 | 
					            out_[dst_data+1] = inA[src+1];
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
							
								
								
									
										52
									
								
								ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								ggml/src/ggml-kompute/kompute-shaders/op_rope_norm_f32.comp
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,52 @@
 | 
				
			|||||||
 | 
					#version 450
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					#include "rope_common.comp"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					layout(binding = 0) buffer restrict readonly  tensorInA { float inA[]; };
 | 
				
			||||||
 | 
					layout(binding = 1) buffer restrict readonly  tensorInB { int   inB[]; };
 | 
				
			||||||
 | 
					layout(binding = 2) buffer restrict readonly  tensorInC { float inC[]; };
 | 
				
			||||||
 | 
					layout(binding = 3) buffer restrict writeonly tensorOut { float out_[]; };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					void main() {
 | 
				
			||||||
 | 
					    const uint i3 = gl_WorkGroupID.z;
 | 
				
			||||||
 | 
					    const uint i2 = gl_WorkGroupID.y;
 | 
				
			||||||
 | 
					    const uint i1 = gl_WorkGroupID.x;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    float corr_dims[2];
 | 
				
			||||||
 | 
					    rope_yarn_corr_dims(pcs.n_dims, pcs.n_ctx_orig, pcs.freq_base, pcs.beta_fast, pcs.beta_slow, corr_dims);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    const float theta_scale = pow(pcs.freq_base, -2.0/pcs.n_dims);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    float theta_base = float(inB[pcs.inBOff + i2]);
 | 
				
			||||||
 | 
					    float inv_ndims = -1.f/pcs.n_dims;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    float cos_theta;
 | 
				
			||||||
 | 
					    float sin_theta;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for (uint i0 = 2*gl_LocalInvocationIndex; i0 < pcs.ne0; i0 += 2*gl_WorkGroupSize.x) {
 | 
				
			||||||
 | 
					        if (i0 < pcs.n_dims) {
 | 
				
			||||||
 | 
					            uint ic = i0/2;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            float theta = theta_base * pow(pcs.freq_base, inv_ndims*i0);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            const float freq_factor = pcs.has_freq_factors ? inC[pcs.inCOff + ic] : 1.0f;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            rope_yarn(theta/freq_factor, pcs.freq_scale, corr_dims, i0, pcs.ext_factor, pcs.attn_factor, cos_theta, sin_theta);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
 | 
				
			||||||
 | 
					            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            const float x0 = inA[src];
 | 
				
			||||||
 | 
					            const float x1 = inA[src+1];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            out_[dst_data]   = x0*cos_theta - x1*sin_theta;
 | 
				
			||||||
 | 
					            out_[dst_data+1] = x0*sin_theta + x1*cos_theta;
 | 
				
			||||||
 | 
					        } else {
 | 
				
			||||||
 | 
					            const uint src      = uint((i3*pcs.nb03 + i2*pcs.nb02 + i1*pcs.nb01 + i0*pcs.nb00) / 4) + pcs.inAOff; // Based from in
 | 
				
			||||||
 | 
					            const uint dst_data = uint((i3*pcs.nb3  + i2*pcs.nb2  + i1*pcs.nb1  + i0*pcs.nb0)  / 4) + pcs.outOff; // Based from out_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            out_[dst_data]   = inA[src];
 | 
				
			||||||
 | 
					            out_[dst_data+1] = inA[src+1];
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
@@ -18,6 +18,10 @@ layout(push_constant) uniform PushConstants {
 | 
				
			|||||||
    int ne01;
 | 
					    int ne01;
 | 
				
			||||||
    int ne02;
 | 
					    int ne02;
 | 
				
			||||||
    float scale;
 | 
					    float scale;
 | 
				
			||||||
 | 
					    float max_bias;
 | 
				
			||||||
 | 
					    float m0;
 | 
				
			||||||
 | 
					    float m1;
 | 
				
			||||||
 | 
					    uint n_head_log2;
 | 
				
			||||||
    int mask;
 | 
					    int mask;
 | 
				
			||||||
} pcs;
 | 
					} pcs;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -34,17 +38,29 @@ void main() {
 | 
				
			|||||||
    const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
 | 
					    const uint pmask = i01*pcs.ne00 + pcs.inBOff; // Based from inB
 | 
				
			||||||
    const uint pdst = extra_off + pcs.outOff; // Based from out_
 | 
					    const uint pdst = extra_off + pcs.outOff; // Based from out_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    float slope = 1.0f;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    // ALiBi
 | 
				
			||||||
 | 
					    if (pcs.max_bias > 0.0f) {
 | 
				
			||||||
 | 
					        int64_t h = i02;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        float base = h < pcs.n_head_log2 ? pcs.m0 : pcs.m1;
 | 
				
			||||||
 | 
					        int64_t exp = h < pcs.n_head_log2 ? h + 1 : 2*(h - pcs.n_head_log2) + 1;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        slope = pow(base, float(exp));
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // parallel max
 | 
					    // parallel max
 | 
				
			||||||
    float localMax = uintBitsToFloat(0xFF800000);
 | 
					    float localMax = uintBitsToFloat(0xFF800000);
 | 
				
			||||||
    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
 | 
					    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
 | 
				
			||||||
        localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f));
 | 
					        localMax = max(localMax, inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f));
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
    float max_ = subgroupMax(localMax);
 | 
					    float max_ = subgroupMax(localMax);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    // parallel sum
 | 
					    // parallel sum
 | 
				
			||||||
    float localSum = 0.0f;
 | 
					    float localSum = 0.0f;
 | 
				
			||||||
    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
 | 
					    for (uint i00 = gl_SubgroupInvocationID.x; i00 < pcs.ne00; i00 += 32) {
 | 
				
			||||||
        const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? inB[pmask + i00] : 0.0f) - max_);
 | 
					        const float exp_psrc0 = exp(inA[psrc0 + i00]*pcs.scale + (pcs.mask!=0 ? slope*inB[pmask + i00] : 0.0f) - max_);
 | 
				
			||||||
        localSum += exp_psrc0;
 | 
					        localSum += exp_psrc0;
 | 
				
			||||||
        out_[pdst + i00] = exp_psrc0;
 | 
					        out_[pdst + i00] = exp_psrc0;
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -8,12 +8,14 @@ layout(local_size_x = 1) in;
 | 
				
			|||||||
layout (push_constant) uniform parameter {
 | 
					layout (push_constant) uniform parameter {
 | 
				
			||||||
    uint inAOff;
 | 
					    uint inAOff;
 | 
				
			||||||
    uint inBOff;
 | 
					    uint inBOff;
 | 
				
			||||||
 | 
					    uint inCOff;
 | 
				
			||||||
    uint outOff;
 | 
					    uint outOff;
 | 
				
			||||||
    int n_dims;
 | 
					    int n_dims;
 | 
				
			||||||
    int mode;
 | 
					    int mode;
 | 
				
			||||||
    int n_ctx_orig;
 | 
					    int n_ctx_orig;
 | 
				
			||||||
    float freq_base;
 | 
					    float freq_base;
 | 
				
			||||||
    float freq_scale;
 | 
					    float freq_scale;
 | 
				
			||||||
 | 
					    bool has_freq_factors;
 | 
				
			||||||
    float ext_factor;
 | 
					    float ext_factor;
 | 
				
			||||||
    float attn_factor;
 | 
					    float attn_factor;
 | 
				
			||||||
    float beta_fast;
 | 
					    float beta_fast;
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user