Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-11-03 09:22:01 +00:00)
			
		
		
		
	vulkan: Additional type support for unary, binary, and copy (#13266)
Support f16->f32 copy. Support f16->f16 and f32->f32 unary ops. Support all combinations of f16/f32 for src0/src1/dst for add/sub/mul/div.
This commit is contained in:
		@@ -17,5 +17,5 @@ void main() {
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
-    data_d[i] = max(float(data_a[i]), 0);
 | 
			
		||||
+    data_d[i] = D_TYPE(max(float(data_a[i]), 0));
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -16,5 +16,5 @@ void main() {
 | 
			
		||||
    if (i >= p.KX) {
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
-    data_d[i] = D_TYPE(1. / (1 + exp(-1. *data_a[i])));
 | 
			
		||||
+    data_d[i] = D_TYPE(1. / (1 + exp(-1. * float(data_a[i]))));
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -16,5 +16,5 @@ void main() {
 | 
			
		||||
    if (i >= p.KX) {
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
-    data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.));
 | 
			
		||||
+    data_d[i] = D_TYPE(1. - 2. / (exp(2.*float(data_a[i])) + 1.));
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -485,10 +485,12 @@ void process_shaders() {
 | 
			
		||||
    string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 | 
			
		||||
    string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
 | 
			
		||||
    string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
 | 
			
		||||
+    string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
 | 
			
		||||
    string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
 | 
			
		||||
    string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 | 
			
		||||
    string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
 | 
			
		||||
    string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
 | 
			
		||||
+    string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
 | 
			
		||||
    string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
 | 
			
		||||
 | 
			
		||||
    for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
 | 
			
		||||
@@ -497,8 +499,26 @@ void process_shaders() {
 | 
			
		||||
        string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
-    string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 | 
			
		||||
-    string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
 | 
			
		||||
+    auto get_type_str = [](bool f16) {
 | 
			
		||||
+        return f16 ? "float16_t" : "float";
 | 
			
		||||
+    };
 | 
			
		||||
+    auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
 | 
			
		||||
+        std::string s;
 | 
			
		||||
+        s += std::string(src0_f16 ? "_f16" : "_f32");
 | 
			
		||||
+        s += std::string(src1_f16 ? "_f16" : "_f32");
 | 
			
		||||
+        s += std::string(dst_f16 ? "_f16" : "_f32");
 | 
			
		||||
+        return s;
 | 
			
		||||
+    };
 | 
			
		||||
+    for (std::string op : {"add", "sub", "mul", "div"}) {
 | 
			
		||||
+    for (auto src0_f16 : {false, true}) {
 | 
			
		||||
+    for (auto src1_f16 : {false, true}) {
 | 
			
		||||
+    for (auto dst_f16  : {false, true}) {
 | 
			
		||||
+        auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
 | 
			
		||||
+        string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
 | 
			
		||||
+    }
 | 
			
		||||
+    }
 | 
			
		||||
+    }
 | 
			
		||||
+    }
 | 
			
		||||
 | 
			
		||||
    string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 | 
			
		||||
 | 
			
		||||
@@ -533,14 +553,21 @@ void process_shaders() {
 | 
			
		||||
 | 
			
		||||
    string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 | 
			
		||||
 | 
			
		||||
-    string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 | 
			
		||||
-    string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 | 
			
		||||
-    string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 | 
			
		||||
-    string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 | 
			
		||||
-    string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 | 
			
		||||
-    string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 | 
			
		||||
-    string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 | 
			
		||||
-    string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 | 
			
		||||
+    string_to_spv("gelu_f16",       "gelu.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 | 
			
		||||
+    string_to_spv("gelu_f32",       "gelu.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 | 
			
		||||
+    string_to_spv("gelu_quick_f16", "gelu_quick.comp",  {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 | 
			
		||||
+    string_to_spv("gelu_quick_f32", "gelu_quick.comp",  {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 | 
			
		||||
+    string_to_spv("silu_f16",       "silu.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 | 
			
		||||
+    string_to_spv("silu_f32",       "silu.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 | 
			
		||||
+    string_to_spv("relu_f16",       "relu.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 | 
			
		||||
+    string_to_spv("relu_f32",       "relu.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 | 
			
		||||
+    string_to_spv("tanh_f16",       "tanh.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 | 
			
		||||
+    string_to_spv("tanh_f32",       "tanh.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 | 
			
		||||
+    string_to_spv("sigmoid_f16",    "sigmoid.comp",     {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
 | 
			
		||||
+    string_to_spv("sigmoid_f32",    "sigmoid.comp",     {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 | 
			
		||||
 | 
			
		||||
+    string_to_spv("leaky_relu_f32", "leaky_relu.comp",  {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 | 
			
		||||
+    string_to_spv("silu_back_f32",  "silu_back.comp",   {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 | 
			
		||||
 | 
			
		||||
    string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 | 
			
		||||
 | 
			
		||||
@@ -641,7 +668,12 @@ void write_output_files() {
 | 
			
		||||
            std::remove(path.c_str());
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
+    for (const char *op : {"add", "sub", "mul", "div"}) {
 | 
			
		||||
+        fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
 | 
			
		||||
+        fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
 | 
			
		||||
+        fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
 | 
			
		||||
+        fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
 | 
			
		||||
+    }
 | 
			
		||||
    fclose(hdr);
 | 
			
		||||
    fclose(src);
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user