Files
llama.cpp/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py
Reese Levine 45363632cb ggml WebGPU: add support for quantization types (#15440)
* Begin work on set_rows

* Work on set rows

* Add error buffers for reporting unsupported SET_ROWS indices

* Remove extra comments

* Work on templating for different types in shaders

* Work on shader type generation

* Working q4_0 mul_mat and some templating for different types

* Add q4_0_f16 matmul and fix device init

* Add matmul support for basic quantization types

* Add q2_k and q3_k quantization

* Add rest of k-quants

* Get first i-quant working

* Closer to supporting all i-quants

* Support rest of i-quants

* Cleanup code

* Fix python formatting

* debug

* Bugfix for memset

* Add padding to end of buffers on creation

* Simplify bit-shifting

* Update usage of StringView
2025-08-22 11:28:03 -07:00

86 lines
2.9 KiB
Python
Executable File

import os
import re
import ast
import argparse
def extract_block(text, name):
    """Return the contents of a ``#define(name) ... #end(name)`` block.

    Args:
        text: The text to search.
        name: The block name. It is matched literally, so regex
            metacharacters in it carry no special meaning.

    Returns:
        The text between the two markers with leading and trailing
        whitespace stripped. Only the first matching block is returned.

    Raises:
        ValueError: If ``text`` contains no block with this name.
    """
    # Escape the name so that characters such as "." or "+" cannot match
    # unrelated marker names.
    marker = re.escape(name)
    # DOTALL lets the lazy body group span multiple lines.
    match = re.search(rf'#define\({marker}\)\s*(.*?)#end\({marker}\)', text, re.DOTALL)
    if match is None:
        raise ValueError(f"Missing block: {name}")
    return match.group(1).strip()
def parse_decls(decls_text):
    """Collect all ``#decl(NAME) ... #enddecl(NAME)`` sections into a dict.

    Args:
        decls_text: Text containing zero or more declaration sections.

    Returns:
        A dict mapping each stripped section name to its stripped body.
        When the same name occurs more than once, the last section wins.
    """
    # The backreference \1 requires the closing marker to repeat the exact
    # name captured from the opening marker.
    section = re.compile(r'#decl\((.*?)\)\s*(.*?)#enddecl\(\1\)', re.DOTALL)
    return {
        found.group(1).strip(): found.group(2).strip()
        for found in section.finditer(decls_text)
    }
def replace_placeholders(shader_text, replacements):
    """Substitute ``{{KEY}}`` placeholders in a shader template.

    Whitespace between the braces and the key is tolerated, so ``{{KEY}}``
    and ``{{ KEY }}`` are both replaced. Replacements are applied one key
    at a time, in the iteration order of ``replacements``.

    Args:
        shader_text: Template text containing placeholders.
        replacements: Mapping of placeholder key to value; each value is
            converted with ``str()`` before insertion.

    Returns:
        The text with every occurrence of each known placeholder replaced.
        Placeholders whose key is not in ``replacements`` are left as-is.
    """
    for key, val in replacements.items():
        # Match {{KEY}} literally, where KEY is escaped
        pattern = r'{{\s*' + re.escape(key) + r'\s*}}'
        # Use a callable so the value is inserted verbatim. A plain string
        # replacement is treated as a template in which backslashes are
        # escapes or group references, which corrupts or rejects values
        # that contain "\".
        shader_text = re.sub(pattern, lambda _match, literal=str(val): literal, shader_text)
    return shader_text
def write_shader(shader_name, shader_code, output_dir, outfile):
    """Emit one shader as a C++ raw-string constant.

    Args:
        shader_name: Used for the ``wgsl_<shader_name>`` variable name and,
            when dumping, for the ``<shader_name>.wgsl`` file name.
        shader_code: The shader source text to embed.
        output_dir: If truthy, the shader source is also written to
            ``<output_dir>/<shader_name>.wgsl``; otherwise no file is made.
        outfile: Writable text stream that receives the declaration.
    """
    declaration = f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n'
    if output_dir:
        # Optional standalone copy of the generated shader.
        dump_path = os.path.join(output_dir, f"{shader_name}.wgsl")
        with open(dump_path, "w", encoding="utf-8") as dump:
            dump.write(shader_code)
    outfile.write(declaration)
def generate_variants(shader_path, output_dir, outfile):
    """Expand one WGSL source file into its shader variant(s) and emit them.

    A file without a ``#define(VARIANTS)`` block is emitted unchanged under
    its base name (the file name up to its first "."). Otherwise the file
    must also contain ``DECLS`` and ``SHADER`` blocks. For every variant,
    the ``{{KEY}}`` placeholders in the SHADER template are filled from the
    variant's ``REPLS``, and the word ``DECLS`` is replaced by the
    declarations listed in the variant's ``DECLS``. Each result is emitted
    as ``<base>_<SRC0_TYPE>_<SRC1_TYPE>``.

    Args:
        shader_path: Path of the shader source file to read.
        output_dir: Passed through to ``write_shader``.
        outfile: Passed through to ``write_shader``.

    Raises:
        ValueError: If the VARIANTS block is not a valid Python literal, a
            DECLS or SHADER block is missing, or a variant lists a
            declaration that the DECLS block does not define.
        KeyError: If a variant lacks ``"DECLS"`` or ``"REPLS"``, or its
            ``REPLS`` lack ``"SRC0_TYPE"`` or ``"SRC1_TYPE"``.
    """
    # basename() honours the platform's path separators; splitting on "/"
    # alone leaves directory components in the name when the path contains
    # the "\" that os.path.join inserts on Windows.
    shader_base_name = os.path.basename(shader_path).split(".")[0]

    with open(shader_path, "r", encoding="utf-8") as f:
        text = f.read()

    try:
        variants_text = extract_block(text, "VARIANTS")
    except ValueError:
        # No VARIANTS block: a plain shader, embedded as-is.
        write_shader(shader_base_name, text, output_dir, outfile)
        return

    # Parse outside the try: literal_eval reports malformed input with
    # ValueError too, and that must fail loudly rather than be mistaken for
    # "no VARIANTS block" and embed the raw template.
    variants = ast.literal_eval(variants_text)

    decls_map = parse_decls(extract_block(text, "DECLS"))
    shader_template = extract_block(text, "SHADER")

    for variant in variants:
        decl_parts = []
        for key in variant["DECLS"]:
            if key not in decls_map:
                raise ValueError(f"DECLS key '{key}' not found.")
            decl_parts.append(decls_map[key] + "\n\n")
        decls_code = "".join(decl_parts)

        repls = variant["REPLS"]
        shader_variant = replace_placeholders(shader_template, repls)
        # Callable replacement: the declarations are code, not a regex
        # replacement template, so backslashes in them must stay literal.
        final_shader = re.sub(r'\bDECLS\b', lambda _match, code=decls_code: code, shader_variant)
        output_name = f"{shader_base_name}_" + "_".join([repls["SRC0_TYPE"], repls["SRC1_TYPE"]])
        write_shader(output_name, final_shader, output_dir, outfile)
def main():
    """Command-line entry point: embed every ``.wgsl`` file in a directory.

    Options:
        --input_dir: Directory scanned (non-recursively) for ``.wgsl`` files.
        --output_file: File that receives the generated declarations.
        --output_dir: Optional directory, created if needed, that is passed
            to ``generate_variants`` for each shader.

    Files are processed in sorted name order.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--input_dir", required=True)
    cli.add_argument("--output_file", required=True)
    cli.add_argument("--output_dir")
    opts = cli.parse_args()

    if opts.output_dir:
        os.makedirs(opts.output_dir, exist_ok=True)

    with open(opts.output_file, "w", encoding="utf-8") as header:
        header.write("// Auto-generated shader embedding\n\n")
        shader_files = (
            entry for entry in sorted(os.listdir(opts.input_dir))
            if entry.endswith(".wgsl")
        )
        for entry in shader_files:
            generate_variants(os.path.join(opts.input_dir, entry), opts.output_dir, header)
# Run the generator only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()