mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	vulkan: add RTE variants for glu/add/sub/mul/div (#14653)
This commit is contained in:
		| @@ -2835,10 +2835,11 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
|         return s; | ||||
|     }; | ||||
|  | ||||
|     bool rte = device->float_controls_rte_fp16; | ||||
| #define CREATE_BINARY(name, namemod, spec) \ | ||||
|     for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ | ||||
|                                 #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \ | ||||
|                                 #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \ | ||||
|                                 "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); | ||||
|  | ||||
|     CREATE_BINARY(add, , {0}) | ||||
| @@ -2890,8 +2891,13 @@ static void ggml_vk_load_shaders(vk_device& device) { | ||||
| #undef CREATE_UNARY | ||||
|  | ||||
| #define CREATE_GLU(name)  \ | ||||
|     if (device->float_controls_rte_fp16) {  \ | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \ | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \ | ||||
|     } else {    \ | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \ | ||||
|     ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); | ||||
|         ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \ | ||||
|     } | ||||
|  | ||||
|     CREATE_GLU(geglu) | ||||
|     CREATE_GLU(reglu) | ||||
|   | ||||
| @@ -1,10 +1,6 @@ | ||||
| #version 450 | ||||
|  | ||||
| #if RTE16 | ||||
| #extension GL_EXT_spirv_intrinsics : enable | ||||
| spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits | ||||
| #endif // RTE16 | ||||
|  | ||||
| #include "rte.comp" | ||||
| #include "types.comp" | ||||
|  | ||||
| #if defined(SET_ROWS) && QUANT_K == 1 | ||||
|   | ||||
| @@ -1,6 +1,8 @@ | ||||
| #extension GL_EXT_shader_16bit_storage : require | ||||
| #extension GL_EXT_control_flow_attributes : require | ||||
|  | ||||
| #include "rte.comp" | ||||
|  | ||||
| layout (push_constant) uniform parameter | ||||
| { | ||||
|     uint ne; | ||||
|   | ||||
| @@ -1,5 +1,7 @@ | ||||
| #extension GL_EXT_shader_16bit_storage : require | ||||
|  | ||||
| #include "rte.comp" | ||||
|  | ||||
| layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; | ||||
|  | ||||
| layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; | ||||
|   | ||||
| @@ -1,12 +1,9 @@ | ||||
| #version 450 | ||||
|  | ||||
| #extension GL_EXT_shader_16bit_storage : require | ||||
| #extension GL_EXT_spirv_intrinsics: enable | ||||
| #extension GL_EXT_control_flow_attributes : require | ||||
|  | ||||
| #if RTE16 | ||||
| spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits | ||||
| #endif | ||||
| #include "rte.comp" | ||||
|  | ||||
| layout (push_constant) uniform parameter | ||||
| { | ||||
|   | ||||
| @@ -1,11 +1,8 @@ | ||||
| #include "types.comp" | ||||
|  | ||||
| #extension GL_EXT_shader_16bit_storage : require | ||||
| #extension GL_EXT_spirv_intrinsics: enable | ||||
|  | ||||
| #if RTE16 | ||||
| spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits | ||||
| #endif | ||||
| #include "rte.comp" | ||||
|  | ||||
| layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; | ||||
|  | ||||
|   | ||||
							
								
								
									
										5
									
								
								ggml/src/ggml-vulkan/vulkan-shaders/rte.comp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								ggml/src/ggml-vulkan/vulkan-shaders/rte.comp
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
|  | ||||
| #if RTE16 | ||||
| #extension GL_EXT_spirv_intrinsics : enable | ||||
| spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits | ||||
| #endif // RTE16 | ||||
| @@ -537,8 +537,10 @@ void process_shaders() { | ||||
|     for (auto src0_f16 : {false, true}) { | ||||
|     for (auto src1_f16 : {false, true}) { | ||||
|     for (auto dst_f16  : {false, true}) { | ||||
|         auto name = op + get_suffix(src0_f16, src1_f16, dst_f16); | ||||
|         string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}}); | ||||
|     for (auto rte      : {false, true}) { | ||||
|         auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : ""); | ||||
|         string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); | ||||
|     } | ||||
|     } | ||||
|     } | ||||
|     } | ||||
| @@ -592,16 +594,19 @@ void process_shaders() { | ||||
|     string_to_spv("sigmoid_f16",    "sigmoid.comp",     {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}}); | ||||
|     string_to_spv("sigmoid_f32",    "sigmoid.comp",     {{"A_TYPE", "float"},       {"D_TYPE", "float"}}); | ||||
|  | ||||
|     string_to_spv("geglu_f16",      "geglu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}}); | ||||
|     string_to_spv("geglu_f32",      "geglu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"}}); | ||||
|     string_to_spv("reglu_f16",      "reglu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}}); | ||||
|     string_to_spv("reglu_f32",      "reglu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"}}); | ||||
|     string_to_spv("swiglu_f16",     "swiglu.comp",      {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}}); | ||||
|     string_to_spv("swiglu_f32",     "swiglu.comp",      {{"A_TYPE", "float"},       {"D_TYPE", "float"}}); | ||||
|     string_to_spv("geglu_erf_f16",  "geglu_erf.comp",   {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}}); | ||||
|     string_to_spv("geglu_erf_f32",  "geglu_erf.comp",   {{"A_TYPE", "float"},       {"D_TYPE", "float"}}); | ||||
|     string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}}); | ||||
|     string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"},       {"D_TYPE", "float"}}); | ||||
|     for (auto rte : {false, true}) { | ||||
|         std::string suffix = rte ? "_rte" : ""; | ||||
|         string_to_spv("geglu_f16" + suffix,      "geglu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}}); | ||||
|         string_to_spv("geglu_f32" + suffix,      "geglu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}}); | ||||
|         string_to_spv("reglu_f16" + suffix,      "reglu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}}); | ||||
|         string_to_spv("reglu_f32" + suffix,      "reglu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}}); | ||||
|         string_to_spv("swiglu_f16" + suffix,     "swiglu.comp",      {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}}); | ||||
|         string_to_spv("swiglu_f32" + suffix,     "swiglu.comp",      {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}}); | ||||
|         string_to_spv("geglu_erf_f16" + suffix,  "geglu_erf.comp",   {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}}); | ||||
|         string_to_spv("geglu_erf_f32" + suffix,  "geglu_erf.comp",   {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}}); | ||||
|         string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}}); | ||||
|         string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}}); | ||||
|     } | ||||
|  | ||||
|     string_to_spv("leaky_relu_f32", "leaky_relu.comp",  {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); | ||||
|     string_to_spv("silu_back_f32",  "silu_back.comp",   {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); | ||||
| @@ -709,11 +714,59 @@ void write_output_files() { | ||||
|             std::remove(path.c_str()); | ||||
|         } | ||||
|     } | ||||
|  | ||||
|     std::string suffixes[2] = {"_f32", "_f16"}; | ||||
|     for (const char *op : {"add", "sub", "mul", "div"}) { | ||||
|         fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op); | ||||
|         fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op); | ||||
|         fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op); | ||||
|         fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op); | ||||
|         fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op); | ||||
|         fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op); | ||||
|         std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = "; | ||||
|         std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = "; | ||||
|         for (uint32_t t0 = 0; t0 < 2; ++t0) { | ||||
|             if (t0 == 0) { | ||||
|                 data += "{"; | ||||
|                 len += "{"; | ||||
|             } | ||||
|             for (uint32_t t1 = 0; t1 < 2; ++t1) { | ||||
|                 if (t1 == 0) { | ||||
|                     data += "{"; | ||||
|                     len += "{"; | ||||
|                 } | ||||
|                 for (uint32_t t2 = 0; t2 < 2; ++t2) { | ||||
|                     if (t2 == 0) { | ||||
|                         data += "{"; | ||||
|                         len += "{"; | ||||
|                     } | ||||
|                     for (uint32_t rte = 0; rte < 2; ++rte) { | ||||
|                         if (rte == 0) { | ||||
|                             data += "{"; | ||||
|                             len += "{"; | ||||
|                         } | ||||
|                         data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : ""); | ||||
|                         len  += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : ""); | ||||
|                         data += "_data,"; | ||||
|                         len  += "_len,"; | ||||
|                         if (rte == 1) { | ||||
|                             data += "}, "; | ||||
|                             len += "}, "; | ||||
|                         } | ||||
|                     } | ||||
|                     if (t2 == 1) { | ||||
|                         data += "}, "; | ||||
|                         len += "}, "; | ||||
|                     } | ||||
|                 } | ||||
|                 if (t1 == 1) { | ||||
|                     data += "}, "; | ||||
|                     len += "}, "; | ||||
|                 } | ||||
|             } | ||||
|             if (t0 == 1) { | ||||
|                 data += "};\n"; | ||||
|                 len += "};\n"; | ||||
|             } | ||||
|         } | ||||
|         fprintf(src, data.c_str()); | ||||
|         fprintf(src, len.c_str()); | ||||
|     } | ||||
|     fclose(hdr); | ||||
|     fclose(src); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Jeff Bolz
					Jeff Bolz