Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30 08:42:00 +00:00)

	model-conversion : add extra debugging support for model conversion (#15877)
* feat: Extra debugging support for model conversion - added BF16 support for llama-callback-eval and support for dumping intermediate steps in run-org-model.py
Author:       Piotr Wilkin (ilintar)
Committed by: GitHub
Parent:       7057faf64b
Commit:       acc1b008cf
@@ -28,6 +28,15 @@ static std::string ggml_ne_string(const ggml_tensor * t) {
     return str;
 }
 
+static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
+    union {
+        float f;
+        uint32_t i;
+    } u;
+    u.i = (uint32_t)h.bits << 16;
+    return u.f;
+}
+
 static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
     size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
     float v;
@@ -43,6 +52,8 @@ static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t *
         v = (float) *(int16_t *) &data[i];
     } else if (type == GGML_TYPE_I8) {
         v = (float) *(int8_t *) &data[i];
+    } else if (type == GGML_TYPE_BF16) {
+        v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]);
     } else {
         GGML_ABORT("fatal error");
     }
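The new helper is the standard bfloat16-to-float32 widening: the 16 bf16 bits become the upper half of a 32-bit IEEE-754 float and the lower 16 bits are zero. With this case added, llama-callback-eval can print BF16 tensors instead of hitting the GGML_ABORT branch. A minimal Python sketch of the same bit manipulation (illustrative only, not part of the patch; the function name is made up):

import struct

def bf16_bits_to_fp32(bits: int) -> float:
    # Place the 16 bf16 bits in the upper half of a 32-bit word and
    # reinterpret the result as an IEEE-754 float32, mirroring
    # ggml_compute_bf16_to_fp32 in the diff above.
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]

assert bf16_bits_to_fp32(0x3F80) == 1.0   # bf16 bit pattern for 1.0
assert bf16_bits_to_fp32(0xC000) == -2.0  # bf16 bit pattern for -2.0
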
@@ -1,5 +1,6 @@
 --extra-index-url https://download.pytorch.org/whl/cpu
-torch~=2.6.0
-torchvision~=0.21.0
-transformers~=4.55.0
-huggingface-hub~=0.34.0
+torch
+torchvision
+transformers
+huggingface-hub
+accelerate
@@ -9,15 +9,134 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
 import torch
 import numpy as np
 
-unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
-
-parser = argparse.ArgumentParser(description='Process model with specified path')
-parser.add_argument('--model-path', '-m', help='Path to the model')
+### If you want to dump RoPE activations, apply this monkey patch to the model
+### class from Transformers that you are running (replace apertus.modeling_apertus
+### with the proper package and class for your model
+### === START ROPE DEBUG ===
+# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
+
+# orig_rope = apply_rotary_pos_emb
+# torch.set_printoptions(threshold=float('inf'))
+# torch.set_printoptions(precision=6, sci_mode=False)
+
+# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+#     # log inputs
+#     summarize(q, "RoPE.q_in")
+#     summarize(k, "RoPE.k_in")
+
+#     # call original
+#     q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
+
+#     # log outputs
+#     summarize(q_out, "RoPE.q_out")
+#     summarize(k_out, "RoPE.k_out")
+
+#     return q_out, k_out
+
+# # Patch it
+# import transformers.models.apertus.modeling_apertus as apertus_mod  # noqa: E402
+# apertus_mod.apply_rotary_pos_emb = debug_rope
+### == END ROPE DEBUG ===
+
+
+def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
+    """
+    Print a tensor in llama.cpp debug style.
+
+    Supports:
+    - 2D tensors (seq, hidden)
+    - 3D tensors (batch, seq, hidden)
+    - 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
+
+    Shows first and last max_vals of each vector per sequence position.
+    """
+    t = tensor.detach().to(torch.float32).cpu()
+
+    # Determine dimensions
+    if t.ndim == 3:
+        _, s, _ = t.shape
+    elif t.ndim == 2:
+        _, s = 1, t.shape[0]
+        t = t.unsqueeze(0)
+    elif t.ndim == 4:
+        _, s, _, _ = t.shape
+    else:
+        print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
+        return
+
+    ten_shape = t.shape
+
+    print(f"ggml_debug: {name} = (f32)  ... = {{{ten_shape}}}")
+    print("                                     [")
+    print("                                      [")
+
+    # Determine indices for first and last sequences
+    first_indices = list(range(min(s, max_seq)))
+    last_indices = list(range(max(0, s - max_seq), s))
+
+    # Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
+    has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
+
+    # Combine indices
+    if has_overlap:
+        # If there's overlap, just use the combined unique indices
+        indices = sorted(list(set(first_indices + last_indices)))
+        separator_index = None
+    else:
+        # If no overlap, we'll add a separator between first and last sequences
+        indices = first_indices + last_indices
+        separator_index = len(first_indices)
+
+    for i, si in enumerate(indices):
+        # Add separator if needed
+        if separator_index is not None and i == separator_index:
+            print("                                       ...")
+
+        # Extract appropriate slice
+        vec = t[0, si]
+        if vec.ndim == 2:  # 4D case: flatten heads × dim_per_head
+            flat = vec.flatten().tolist()
+        else:  # 2D or 3D case
+            flat = vec.tolist()
+
+        # First and last slices
+        first = flat[:max_vals]
+        last = flat[-max_vals:] if len(flat) >= max_vals else flat
+        first_str = ", ".join(f"{v:12.4f}" for v in first)
+        last_str = ", ".join(f"{v:12.4f}" for v in last)
+
+        print(f"                                       [{first_str}, ..., {last_str}]")
+
+    print("                                      ],")
+    print("                                     ]")
+    print(f"                                     sum = {t.sum().item():.6f}\n")
+
+
+def debug_hook(name):
+    def fn(_m, input, output):
+        if isinstance(input, torch.Tensor):
+            summarize(input, name + "_in")
+        elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor):
+            summarize(input[0], name + "_in")
+        if isinstance(output, torch.Tensor):
+            summarize(output, name + "_out")
+        elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor):
+            summarize(output[0], name + "_out")
+
+    return fn
+
+
+unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
+
+parser = argparse.ArgumentParser(description="Process model with specified path")
+parser.add_argument("--model-path", "-m", help="Path to the model")
 args = parser.parse_args()
 
-model_path = os.environ.get('MODEL_PATH', args.model_path)
+model_path = os.environ.get("MODEL_PATH", args.model_path)
 if model_path is None:
-    parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
+    parser.error(
+        "Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
+    )
 
 config = AutoConfig.from_pretrained(model_path)
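The ROPE DEBUG block is left commented out and hard-wired to the Apertus modeling module; adapting it to another architecture means applying the same patch to that architecture's modeling module instead. Below is a hedged sketch for a LLaMA-family checkpoint. It assumes the summarize() helper from this script is in scope and that your installed transformers version exposes apply_rotary_pos_emb at the usual module path; verify both before relying on it.

import transformers.models.llama.modeling_llama as llama_mod  # swap in your model family

orig_rope = llama_mod.apply_rotary_pos_emb

def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    summarize(q, "RoPE.q_in")   # summarize() comes from run-org-model.py
    summarize(k, "RoPE.k_in")
    q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
    summarize(q_out, "RoPE.q_out")
    summarize(k_out, "RoPE.k_out")
    return q_out, k_out

# Reassign on the module so the attention layers pick up the wrapper
llama_mod.apply_rotary_pos_emb = debug_rope

The patch has to be in place before the forward pass runs, and should be reverted (or the module reloaded) once the dump has been collected.
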
@@ -34,18 +153,30 @@ config = AutoConfig.from_pretrained(model_path)
 
 if unreleased_model_name:
     model_name_lower = unreleased_model_name.lower()
-    unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
+    unreleased_module_path = (
+        f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
+    )
     class_name = f"{unreleased_model_name}ForCausalLM"
     print(f"Importing unreleased model module: {unreleased_module_path}")
 
     try:
-        model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
-        model = model_class.from_pretrained(model_path)  # Note: from_pretrained, not fromPretrained
+        model_class = getattr(
+            importlib.import_module(unreleased_module_path), class_name
+        )
+        model = model_class.from_pretrained(
+            model_path
+        )  # Note: from_pretrained, not fromPretrained
     except (ImportError, AttributeError) as e:
         print(f"Failed to import or load model: {e}")
         exit(1)
 else:
-    model = AutoModelForCausalLM.from_pretrained(model_path)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path, device_map="auto", offload_folder="offload"
+    )
+
+for name, module in model.named_modules():
+    if len(list(module.children())) == 0:  # only leaf modules
+        module.register_forward_hook(debug_hook(name))
 
 model_name = os.path.basename(model_path)
 # Printing the Model class to allow for easier debugging. This can be useful
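Hooks are registered only on leaf modules (modules with no children), so each intermediate activation is printed once rather than re-printed by every enclosing container; the device_map="auto" / offload_folder load path is also why accelerate was added to the requirements file. A self-contained sketch of the same hook pattern on a toy model (names are illustrative), useful for checking the output format without loading a full checkpoint:

import torch
import torch.nn as nn

def print_hook(name):
    def fn(_module, _inputs, output):
        if isinstance(output, torch.Tensor):
            print(f"{name}_out shape={tuple(output.shape)} sum={output.sum().item():.6f}")
    return fn

toy = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
for name, module in toy.named_modules():
    if len(list(module.children())) == 0:  # leaf modules only, as in the diff
        module.register_forward_hook(print_hook(name))

toy(torch.randn(1, 4))  # prints one line per leaf module ("0", "1", "2")
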