mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	Docs: script to auto-generate ggml operations docs (#14598)
* Docs: script to auto-generate ggml operations docs * Review: formatting changes + change github action * Use built-in types instead of typing * docs : add BLAS and Metal ops --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
		
							
								
								
									
										40
									
								
								.github/workflows/update-ops-docs.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										40
									
								
								.github/workflows/update-ops-docs.yml
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @@ -0,0 +1,40 @@ | ||||
name: Update Operations Documentation

# Verifies that docs/ops.md matches what scripts/create_ops_docs.py generates
# from the per-backend CSV files in docs/ops/.
# docs/ops.md is listed in the path filters as well, so that a change touching
# only the generated file is still checked against the CSV data.
on:
    push:
        paths:
            - 'docs/ops.md'
            - 'docs/ops/**'
            - 'scripts/create_ops_docs.py'
    pull_request:
        paths:
            - 'docs/ops.md'
            - 'docs/ops/**'
            - 'scripts/create_ops_docs.py'

jobs:
    update-ops-docs:
        runs-on: ubuntu-latest

        steps:
        - name: Checkout repository
          uses: actions/checkout@v4

        - name: Set up Python
          uses: actions/setup-python@v5
          with:
              python-version: '3.x'

        # Generate into a scratch location so the committed docs/ops.md stays untouched.
        - name: Generate operations documentation to temporary file
          run: |
              mkdir -p /tmp/ops_check
              ./scripts/create_ops_docs.py /tmp/ops_check/ops.md

        # Fail the job (and print the full diff) when the committed file is stale.
        - name: Check if docs/ops.md matches generated version
          run: |
              if ! diff -q docs/ops.md /tmp/ops_check/ops.md; then
                  echo "Operations documentation (docs/ops.md) is not up to date with the backend CSV files."
                  echo "To fix: run ./scripts/create_ops_docs.py and commit the updated docs/ops.md along with your changes"
                  echo "Differences found:"
                  diff docs/ops.md /tmp/ops_check/ops.md || true
                  exit 1
              fi
              echo "Operations documentation is up to date."
							
								
								
									
										95
									
								
								docs/ops.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										95
									
								
								docs/ops.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,95 @@ | ||||
| # GGML Operations | ||||
|  | ||||
| List of GGML operations and backend support status. | ||||
|  | ||||
| Legend: | ||||
| - ✅ Fully supported by this backend | ||||
| - 🟡 Partially supported by this backend | ||||
| - ❌ Not supported by this backend | ||||
|  | ||||
| | Operation | BLAS | CPU | CUDA | Metal | | ||||
| |-----------|------|------|------|------| | ||||
| |                              ABS | ❌ | ✅ | 🟡 | ❌ | | ||||
| |                              ACC | ❌ | ✅ | ✅ | ✅ | | ||||
| |                              ADD | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                             ADD1 | ❌ | ✅ | ✅ | ❌ | | ||||
| |                           ARANGE | ❌ | ✅ | ✅ | ✅ | | ||||
| |                           ARGMAX | ❌ | ✅ | ✅ | ✅ | | ||||
| |                          ARGSORT | ❌ | ✅ | ✅ | ✅ | | ||||
| |                            CLAMP | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                           CONCAT | ❌ | ✅ | 🟡 | ✅ | | ||||
| |                             CONT | ❌ | ✅ | 🟡 | ✅ | | ||||
| |                       CONV_2D_DW | ❌ | ✅ | ✅ | ❌ | | ||||
| |                CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | | ||||
| |                CONV_TRANSPOSE_2D | ❌ | ✅ | ✅ | ❌ | | ||||
| |                              COS | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                      COUNT_EQUAL | ❌ | ✅ | ✅ | ❌ | | ||||
| |                              CPY | ❌ | 🟡 | 🟡 | 🟡 | | ||||
| |               CROSS_ENTROPY_LOSS | ❌ | ✅ | ✅ | ❌ | | ||||
| |          CROSS_ENTROPY_LOSS_BACK | ❌ | ✅ | ✅ | ❌ | | ||||
| |                    DIAG_MASK_INF | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                              DIV | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                              DUP | ❌ | ✅ | 🟡 | 🟡 | | ||||
| |                              ELU | ❌ | ✅ | ❌ | 🟡 | | ||||
| |                              EXP | ❌ | ✅ | 🟡 | ❌ | | ||||
| |                   FLASH_ATTN_EXT | ❌ | ✅ | 🟡 | 🟡 | | ||||
| |                GATED_LINEAR_ATTN | ❌ | ✅ | ✅ | ❌ | | ||||
| |                            GEGLU | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                        GEGLU_ERF | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                      GEGLU_QUICK | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                             GELU | ❌ | ✅ | 🟡 | 🟡 | | ||||
| |                         GELU_ERF | ❌ | ✅ | 🟡 | 🟡 | | ||||
| |                       GELU_QUICK | ❌ | ✅ | 🟡 | 🟡 | | ||||
| |                         GET_ROWS | ❌ | ✅ | 🟡 | ✅ | | ||||
| |                    GET_ROWS_BACK | ❌ | 🟡 | 🟡 | ❌ | | ||||
| |                       GROUP_NORM | ❌ | ✅ | ✅ | ✅ | | ||||
| |                      HARDSIGMOID | ❌ | ✅ | 🟡 | ❌ | | ||||
| |                        HARDSWISH | ❌ | ✅ | 🟡 | ❌ | | ||||
| |                           IM2COL | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                          L2_NORM | ❌ | ✅ | ✅ | ✅ | | ||||
| |                       LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | | ||||
| |                              LOG | ❌ | ✅ | ✅ | ❌ | | ||||
| |                             MEAN | ❌ | ✅ | ✅ | ✅ | | ||||
| |                              MUL | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                          MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | | ||||
| |                       MUL_MAT_ID | ❌ | ✅ | ✅ | ✅ | | ||||
| |                              NEG | ❌ | ✅ | 🟡 | 🟡 | | ||||
| |                             NORM | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                   OPT_STEP_ADAMW | ❌ | ✅ | ✅ | ❌ | | ||||
| |                         OUT_PROD | 🟡 | 🟡 | 🟡 | ❌ | | ||||
| |                              PAD | ❌ | ✅ | ✅ | ✅ | | ||||
| |                   PAD_REFLECT_1D | ❌ | ✅ | ❌ | ✅ | | ||||
| |                          POOL_2D | ❌ | ✅ | ✅ | ✅ | | ||||
| |                            REGLU | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                             RELU | ❌ | ✅ | 🟡 | 🟡 | | ||||
| |                           REPEAT | ❌ | ✅ | 🟡 | ✅ | | ||||
| |                      REPEAT_BACK | ❌ | ✅ | ✅ | ❌ | | ||||
| |                         RMS_NORM | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                    RMS_NORM_BACK | ❌ | ✅ | ✅ | ❌ | | ||||
| |                     RMS_NORM_MUL | ❌ | ✅ | ✅ | ✅ | | ||||
| |                             ROPE | ❌ | ✅ | ✅ | ✅ | | ||||
| |                        ROPE_BACK | ❌ | ✅ | ✅ | ❌ | | ||||
| |                        RWKV_WKV6 | ❌ | ✅ | ✅ | ✅ | | ||||
| |                        RWKV_WKV7 | ❌ | ✅ | ✅ | ✅ | | ||||
| |                            SCALE | ❌ | ✅ | ✅ | ✅ | | ||||
| |                              SET | ❌ | ✅ | ❌ | ✅ | | ||||
| |                         SET_ROWS | ❌ | 🟡 | ❌ | 🟡 | | ||||
| |                              SGN | ❌ | ✅ | 🟡 | ❌ | | ||||
| |                          SIGMOID | ❌ | ✅ | 🟡 | 🟡 | | ||||
| |                             SILU | ❌ | ✅ | 🟡 | 🟡 | | ||||
| |                        SILU_BACK | ❌ | ✅ | ✅ | ❌ | | ||||
| |                              SIN | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                         SOFT_MAX | ❌ | ✅ | ✅ | ✅ | | ||||
| |                    SOFT_MAX_BACK | ❌ | 🟡 | 🟡 | ❌ | | ||||
| |                              SQR | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                             SQRT | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                         SSM_CONV | ❌ | ✅ | ✅ | ✅ | | ||||
| |                         SSM_SCAN | ❌ | ✅ | ✅ | ✅ | | ||||
| |                             STEP | ❌ | ✅ | 🟡 | ❌ | | ||||
| |                              SUB | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                              SUM | ❌ | ✅ | ✅ | ❌ | | ||||
| |                         SUM_ROWS | ❌ | ✅ | ✅ | ✅ | | ||||
| |                           SWIGLU | ❌ | ✅ | ✅ | 🟡 | | ||||
| |                             TANH | ❌ | ✅ | 🟡 | 🟡 | | ||||
| |               TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | | ||||
| |                          UPSCALE | ❌ | ✅ | ✅ | 🟡 | | ||||
							
								
								
									
										6534
									
								
								docs/ops/BLAS.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6534
									
								
								docs/ops/BLAS.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										6534
									
								
								docs/ops/CPU.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6534
									
								
								docs/ops/CPU.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										6534
									
								
								docs/ops/CUDA.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6534
									
								
								docs/ops/CUDA.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										6534
									
								
								docs/ops/Metal.csv
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6534
									
								
								docs/ops/Metal.csv
									
									
									
									
									
										Normal file
									
								
							
										
											
												File diff suppressed because it is too large
												Load Diff
											
										
									
								
							
							
								
								
									
										196
									
								
								scripts/create_ops_docs.py
									
									
									
									
									
										Executable file
									
								
							
							
						
						
									
										196
									
								
								scripts/create_ops_docs.py
									
									
									
									
									
										Executable file
									
								
							| @@ -0,0 +1,196 @@ | ||||
| #!/usr/bin/env python3 | ||||
|  | ||||
| """ | ||||
| This script parses docs/ops/*.csv and creates the ops.md, which is a table documenting supported operations on various ggml backends. | ||||
| """ | ||||
| import csv | ||||
| import logging | ||||
| import sys | ||||
| from pathlib import Path | ||||
| from collections import defaultdict | ||||
|  | ||||
|  | ||||
class DocsGenerator:
    """Generate the Markdown support table for ggml operations.

    Reads every ``*.csv`` file in ``<ggml_root>/docs/ops``, keeps the rows whose
    ``test_mode`` column is ``support``, and renders a table with one row per
    operation and one column per backend.

    Args:
        ggml_root: Repository root that contains the ``docs/ops`` directory.
        output_filename: Output path, resolved relative to ``<ggml_root>/docs``
            (an absolute path is used as given).
    """

    def __init__(self, ggml_root: str, output_filename: str = "ops.md"):
        self.ggml_root = Path(ggml_root)
        self.ops_dir = self.ggml_root / "docs" / "ops"
        self.output_filename = output_filename
        # backend -> operation -> one bool per parsed CSV row (True = supported)
        self.backend_support: dict[str, dict[str, list[bool]]] = defaultdict(
            lambda: defaultdict(list)
        )
        self.all_operations: set[str] = set()
        self.all_backends: set[str] = set()
        self.logger = logging.getLogger(__name__)

    def parse_support_files(self) -> None:
        """Parse every CSV file in the ops directory; warn and return if it is missing."""
        if not self.ops_dir.exists():
            self.logger.warning(f"ops directory not found: {self.ops_dir}")
            return

        self.logger.info(f"Parsing support files from {self.ops_dir}...")

        # Sorted so that the processing (and log) order does not depend on the filesystem.
        for support_file in sorted(self.ops_dir.glob("*.csv")):
            self.logger.info(f"  Reading: {support_file.name}")
            self._parse_support_file(support_file)

    @staticmethod
    def _field(row: dict, name: str) -> str:
        """Return the stripped value of column ``name``, or '' when absent.

        csv.DictReader fills the missing columns of a short row with None, so a
        plain ``row.get(name, '')`` is not enough to guarantee a string.
        """
        return (row.get(name) or '').strip()

    def _parse_support_file(self, file_path: Path) -> None:
        """Accumulate the support rows of one CSV file into the generator state.

        Errors are logged and the file is abandoned; they are not propagated.
        """
        try:
            with open(file_path, "r", newline='', encoding="utf-8") as f:
                reader = csv.DictReader(f)

                for row in reader:
                    # Skip rows that don't have support mode
                    if row.get('test_mode') != 'support':
                        continue

                    backend_name = self._field(row, 'backend_name')
                    operation = self._field(row, 'op_name')
                    supported_str = self._field(row, 'error_message')  # "yes" or "no"
                    backend_reg_name = self._field(row, 'backend_reg_name')

                    # Skip invalid or error operations
                    if not operation or not backend_name or operation in [
                        "CONTEXT_ERROR",
                        "BUILD_ERROR",
                    ]:
                        continue

                    is_supported = supported_str.lower() == "yes"

                    # Use backend_reg_name for grouping, fallback to backend_name
                    backend_key = backend_reg_name if backend_reg_name else backend_name

                    self.all_backends.add(backend_key)
                    self.backend_support[backend_key][operation].append(is_supported)
                    self.all_operations.add(operation)

        except Exception as e:
            self.logger.error(f"    Error parsing {file_path}: {e}")

    def get_backend_support_status(self, backend: str, operation: str) -> str:
        """Classify ``operation`` on ``backend``.

        Returns:
            "supported" when every recorded row is supported,
            "partially supported" when only some are,
            "unsupported" when none are or nothing was recorded.
        """
        # .get() rather than indexing: looking up an unknown backend must not
        # create an empty entry in the defaultdict.
        support_list = self.backend_support.get(backend, {}).get(operation, [])

        if not support_list:
            return "unsupported"

        if all(support_list):
            return "supported"
        if any(support_list):
            return "partially supported"
        return "unsupported"

    def get_support_status(self, operation: str) -> str:
        """Classify ``operation`` across all known backends.

        A backend counts as supporting the operation when at least one of its
        recorded rows is supported.

        Returns:
            "supported" when every backend counts, "unsupported" when none
            does (or the operation is unknown), "partially supported" otherwise.
        """
        if operation not in self.all_operations:
            return "unsupported"

        total_backends = len(self.all_backends)

        # any() on the row list: a non-empty list holding only False values
        # must not be counted as support.
        support_count = sum(
            1
            for backend in self.all_backends
            if any(self.backend_support.get(backend, {}).get(operation, []))
        )

        if support_count == 0:
            return "unsupported"
        elif support_count == total_backends:
            return "supported"
        else:
            return "partially supported"

    def get_support_symbol(self, status: str) -> str:
        """Map a status string to its table symbol ("❓" for unknown statuses)."""
        symbols = {"supported": "✅", "partially supported": "🟡", "unsupported": "❌"}
        return symbols.get(status, "❓")

    def generate_markdown(self) -> str:
        """Render the full document: title, legend, and the support table.

        Backends (columns) and operations (rows) are both sorted by name.
        """
        lines = []

        lines.append("# GGML Operations")
        lines.append("")
        lines.append("List of GGML operations and backend support status.")
        lines.append("")
        lines.append("Legend:")
        lines.append("- ✅ Fully supported by this backend")
        lines.append("- 🟡 Partially supported by this backend")
        lines.append("- ❌ Not supported by this backend")
        lines.append("")

        backends = sorted(self.all_backends)
        header = "| Operation |"
        for backend in backends:
            header += f" {backend} |"

        separator = "|-----------|"
        for _ in backends:
            separator += "------|"

        lines.append(header)
        lines.append(separator)

        for operation in sorted(self.all_operations):
            row = f"| {operation:>32} |"

            for backend in backends:
                status = self.get_backend_support_status(backend, operation)
                row += f" {self.get_support_symbol(status)} |"

            lines.append(row)

        lines.append("")

        return "\n".join(lines)

    def run(self) -> None:
        """Parse the CSV files and write the generated Markdown to the output file.

        Logs an error and writes nothing when no operations were found.
        """
        self.logger.info("Parsing GGML operation support files...")
        self.parse_support_files()

        if not self.all_operations:
            self.logger.error(
                "No operations found. Make sure to run test-backend-ops support --output csv > docs/ops/file.csv first."
            )
            return

        self.logger.info(
            f"Found {len(self.all_operations)} operations across {len(self.all_backends)} backends"
        )

        self.logger.info("Generating markdown...")
        markdown_content = self.generate_markdown()

        docs_dir = self.ggml_root / "docs"
        # An absolute output_filename replaces docs_dir in the join below.
        ops_file = docs_dir / self.output_filename
        # Create the directory the file actually lands in (docs/ by default).
        ops_file.parent.mkdir(parents=True, exist_ok=True)

        # Explicit UTF-8: the table contains emoji that a locale-dependent
        # default encoding may be unable to represent.
        with open(ops_file, "w", encoding="utf-8") as f:
            f.write(markdown_content)

        self.logger.info(f"Generated: {ops_file}")
        self.logger.info(f"Operations: {len(self.all_operations)}")
        self.logger.info(f"Backends: {len(self.all_backends)}")
|  | ||||
|  | ||||
def main():
    """Command-line entry point.

    Configures INFO-level logging and runs the generator from the current
    directory. The first command-line argument, when given, is used as the
    output file name; otherwise "ops.md" is used.
    """
    logging.basicConfig(level=logging.INFO)

    cli_args = sys.argv[1:]
    output_filename = cli_args[0] if cli_args else "ops.md"

    DocsGenerator(".", output_filename).run()


if __name__ == "__main__":
    main()
| @@ -317,10 +317,11 @@ enum test_mode { | ||||
|     MODE_TEST, | ||||
|     MODE_PERF, | ||||
|     MODE_GRAD, | ||||
|     MODE_SUPPORT, | ||||
| }; | ||||
|  | ||||
| // Output format support similar to llama-bench | ||||
| enum output_formats { CONSOLE, SQL }; | ||||
| enum output_formats { CONSOLE, SQL, CSV }; | ||||
|  | ||||
| static const char * output_format_str(output_formats format) { | ||||
|     switch (format) { | ||||
| @@ -328,6 +329,8 @@ static const char * output_format_str(output_formats format) { | ||||
|             return "console"; | ||||
|         case SQL: | ||||
|             return "sql"; | ||||
|         case CSV: | ||||
|             return "csv"; | ||||
|         default: | ||||
|             GGML_ABORT("invalid output format"); | ||||
|     } | ||||
| @@ -338,6 +341,8 @@ static bool output_format_from_str(const std::string & s, output_formats & forma | ||||
|         format = CONSOLE; | ||||
|     } else if (s == "sql") { | ||||
|         format = SQL; | ||||
|     } else if (s == "csv") { | ||||
|         format = CSV; | ||||
|     } else { | ||||
|         return false; | ||||
|     } | ||||
| @@ -360,6 +365,8 @@ struct test_result { | ||||
|     double      bandwidth_gb_s; | ||||
|     size_t      memory_kb; | ||||
|     int         n_runs; | ||||
|     std::string device_description; | ||||
|     std::string backend_reg_name; | ||||
|  | ||||
|     test_result() { | ||||
|         // Initialize with default values | ||||
| @@ -384,7 +391,7 @@ struct test_result { | ||||
|     test_result(const std::string & backend_name, const std::string & op_name, const std::string & op_params, | ||||
|                 const std::string & test_mode, bool supported, bool passed, const std::string & error_message = "", | ||||
|                 double time_us = 0.0, double flops = 0.0, double bandwidth_gb_s = 0.0, size_t memory_kb = 0, | ||||
|                 int n_runs = 0) : | ||||
|                 int n_runs = 0, const std::string & device_description = "", const std::string & backend_reg_name = "") : | ||||
|         backend_name(backend_name), | ||||
|         op_name(op_name), | ||||
|         op_params(op_params), | ||||
| @@ -396,7 +403,9 @@ struct test_result { | ||||
|         flops(flops), | ||||
|         bandwidth_gb_s(bandwidth_gb_s), | ||||
|         memory_kb(memory_kb), | ||||
|         n_runs(n_runs) { | ||||
|         n_runs(n_runs), | ||||
|         device_description(device_description), | ||||
|         backend_reg_name(backend_reg_name) { | ||||
|         // Set test time | ||||
|         time_t t = time(NULL); | ||||
|         char   buf[32]; | ||||
| @@ -410,7 +419,8 @@ struct test_result { | ||||
|     static const std::vector<std::string> & get_fields() { | ||||
|         static const std::vector<std::string> fields = { | ||||
|             "test_time", "build_commit",  "backend_name", "op_name", "op_params",      "test_mode", "supported", | ||||
|             "passed",    "error_message", "time_us",      "flops",   "bandwidth_gb_s", "memory_kb", "n_runs" | ||||
|             "passed",    "error_message", "time_us",      "flops",   "bandwidth_gb_s", "memory_kb", "n_runs", | ||||
|             "device_description", "backend_reg_name" | ||||
|         }; | ||||
|         return fields; | ||||
|     } | ||||
| @@ -444,7 +454,9 @@ struct test_result { | ||||
|                  std::to_string(flops), | ||||
|                  std::to_string(bandwidth_gb_s), | ||||
|                  std::to_string(memory_kb), | ||||
|                  std::to_string(n_runs) }; | ||||
|                  std::to_string(n_runs), | ||||
|                  device_description, | ||||
|                  backend_reg_name }; | ||||
|     } | ||||
| }; | ||||
|  | ||||
| @@ -633,6 +645,8 @@ struct console_printer : public printer { | ||||
|             print_test_console(result); | ||||
|         } else if (result.test_mode == "perf") { | ||||
|             print_perf_console(result); | ||||
|         } else if (result.test_mode == "support") { | ||||
|             print_support_console(result); | ||||
|         } | ||||
|     } | ||||
|  | ||||
| @@ -799,6 +813,17 @@ struct console_printer : public printer { | ||||
|         } | ||||
|         printf("\n"); | ||||
|     } | ||||
|  | ||||
|     void print_support_console(const test_result & result) { | ||||
|         printf("  %s(%s): ", result.op_name.c_str(), result.op_params.c_str()); | ||||
|         fflush(stdout); | ||||
|  | ||||
|         if (result.supported) { | ||||
|             printf("\033[1;32mSUPPORTED\033[0m\n"); | ||||
|         } else { | ||||
|             printf("\033[1;31mNOT SUPPORTED\033[0m\n"); | ||||
|         } | ||||
|     } | ||||
| }; | ||||
|  | ||||
| struct sql_printer : public printer { | ||||
| @@ -841,12 +866,39 @@ struct sql_printer : public printer { | ||||
|     } | ||||
| }; | ||||
|  | ||||
| struct csv_printer : public printer { | ||||
|     void print_header() override { | ||||
|         std::vector<std::string> fields = test_result::get_fields(); | ||||
|         for (size_t i = 0; i < fields.size(); i++) { | ||||
|             printf("\"%s\"%s", fields[i].c_str(), i < fields.size() - 1 ? "," : ""); | ||||
|         } | ||||
|         printf("\n"); | ||||
|     } | ||||
|  | ||||
|     void print_test_result(const test_result & result) override { | ||||
|         std::vector<std::string> values = result.get_values(); | ||||
|         for (size_t i = 0; i < values.size(); i++) { | ||||
|             // Escape quotes and wrap in quotes for CSV | ||||
|             std::string escaped_value = values[i]; | ||||
|             size_t pos = 0; | ||||
|             while ((pos = escaped_value.find("\"", pos)) != std::string::npos) { | ||||
|                 escaped_value.replace(pos, 1, "\"\""); | ||||
|                 pos += 2; | ||||
|             } | ||||
|             printf("\"%s\"%s", escaped_value.c_str(), i < values.size() - 1 ? "," : ""); | ||||
|         } | ||||
|         printf("\n"); | ||||
|     } | ||||
| }; | ||||
|  | ||||
| static std::unique_ptr<printer> create_printer(output_formats format) { | ||||
|     switch (format) { | ||||
|         case CONSOLE: | ||||
|             return std::make_unique<console_printer>(); | ||||
|         case SQL: | ||||
|             return std::make_unique<sql_printer>(); | ||||
|         case CSV: | ||||
|             return std::make_unique<csv_printer>(); | ||||
|     } | ||||
|     GGML_ABORT("invalid output format"); | ||||
| } | ||||
| @@ -928,7 +980,7 @@ struct test_case { | ||||
|     std::vector<ggml_tensor *> sentinels; | ||||
|  | ||||
|     void add_sentinel(ggml_context * ctx) { | ||||
|         if (mode == MODE_PERF || mode == MODE_GRAD) { | ||||
|         if (mode == MODE_PERF || mode == MODE_GRAD || mode == MODE_SUPPORT) { | ||||
|             return; | ||||
|         } | ||||
|         ggml_tensor * sentinel = ::ggml_new_tensor_1d(ctx, GGML_TYPE_F32, sentinel_size); | ||||
| @@ -1153,15 +1205,12 @@ struct test_case { | ||||
|             return true; | ||||
|         } | ||||
|  | ||||
|         // check if backends support op | ||||
|         if (!ggml_backend_supports_op(backend, out)) { | ||||
|             // Create test result for unsupported performance test | ||||
|             test_result result(ggml_backend_name(backend), current_op_name, vars(), "perf", false, false, | ||||
|                                "not supported"); | ||||
|  | ||||
|             if (output_printer) { | ||||
|                 output_printer->print_test_result(result); | ||||
|             } | ||||
|             output_printer->print_test_result(result); | ||||
|  | ||||
|             return true; | ||||
|         } | ||||
| @@ -1266,6 +1315,38 @@ struct test_case { | ||||
|         return true; | ||||
|     } | ||||
|  | ||||
|     bool eval_support(ggml_backend_t backend, const char * op_name, printer * output_printer) { | ||||
|         mode = MODE_SUPPORT; | ||||
|  | ||||
|         static const size_t graph_nodes = 8192; | ||||
|  | ||||
|         ggml_init_params params = { | ||||
|             /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead_custom(graph_nodes, false), | ||||
|             /* .mem_base = */ NULL, | ||||
|             /* .no_alloc = */ true, | ||||
|         }; | ||||
|         ggml_context_ptr ctx(ggml_init(params)); // smart ptr | ||||
|         GGML_ASSERT(ctx); | ||||
|  | ||||
|         ggml_tensor * out             = build_graph(ctx.get()); | ||||
|         std::string   current_op_name = op_desc(out); | ||||
|         if (op_name != nullptr && current_op_name != op_name) { | ||||
|             return true; | ||||
|         } | ||||
|  | ||||
|         bool supported = ggml_backend_supports_op(backend, out); | ||||
|  | ||||
|         std::string device_desc = ggml_backend_dev_description(ggml_backend_get_device(backend)); | ||||
|         std::string backend_reg_name = ggml_backend_reg_name(ggml_backend_dev_backend_reg(ggml_backend_get_device(backend))); | ||||
|  | ||||
|         test_result result(ggml_backend_name(backend), current_op_name, vars(), "support", supported, supported, | ||||
|                            supported ? "yes" : "no", 0.0, 0.0, 0.0, 0, 0, device_desc, backend_reg_name); | ||||
|  | ||||
|         output_printer->print_test_result(result); | ||||
|  | ||||
|         return true; | ||||
|     } | ||||
|  | ||||
|     bool eval_grad(ggml_backend_t backend, const char * op_name, printer * output_printer) { | ||||
|         mode = MODE_GRAD; | ||||
|         const std::vector<float> expect = grad_expect(); | ||||
| @@ -5599,17 +5680,27 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op | ||||
|         return true; | ||||
|     } | ||||
|  | ||||
|     if (mode == MODE_SUPPORT) { | ||||
|         auto test_cases = make_test_cases_eval(); | ||||
|         filter_test_cases(test_cases, params_filter); | ||||
|         for (auto & test : test_cases) { | ||||
|             test->eval_support(backend, op_name, output_printer); | ||||
|         } | ||||
|         return true; | ||||
|     } | ||||
|  | ||||
|     GGML_ABORT("fatal error"); | ||||
| } | ||||
|  | ||||
| static void usage(char ** argv) { | ||||
|     printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>] [--output <console|sql>]\n", argv[0]); | ||||
|     printf("Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>]\n", argv[0]); | ||||
|     printf("    valid modes:\n"); | ||||
|     printf("      - test (default, compare with CPU backend for correctness)\n"); | ||||
|     printf("      - grad (compare gradients from backpropagation with method of finite differences)\n"); | ||||
|     printf("      - perf (performance evaluation)\n"); | ||||
|     printf("      - support (probe backend operation support)\n"); | ||||
|     printf("    op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc)\n"); | ||||
|     printf("    --output specifies output format (default: console)\n"); | ||||
|     printf("    --output specifies output format (default: console, options: console, sql, csv)\n"); | ||||
| } | ||||
|  | ||||
| int main(int argc, char ** argv) { | ||||
| @@ -5626,6 +5717,8 @@ int main(int argc, char ** argv) { | ||||
|             mode = MODE_PERF; | ||||
|         } else if (strcmp(argv[i], "grad") == 0) { | ||||
|             mode = MODE_GRAD; | ||||
|         } else if (strcmp(argv[i], "support") == 0) { | ||||
|             mode = MODE_SUPPORT; | ||||
|         } else if (strcmp(argv[i], "-o") == 0) { | ||||
|             if (i + 1 < argc) { | ||||
|                 op_name_filter = argv[++i]; | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Aman Gupta
					Aman Gupta