#!/usr/bin/env python3
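"""Lightweight sanity check comparing original PyTorch and llama.cpp logits.

Reads raw float32 logit dumps for the model named by the MODEL_PATH
environment variable (data/pytorch-<model>.bin and data/llamacpp-<model>.bin,
produced by scripts/run-org-model.sh and scripts/run-converted-model.sh) and
verifies they roughly agree before the more expensive NMSE check is run.
Exits 0 on success and 1 on failure.

Example invocation (the script filename here is illustrative):
    MODEL_PATH=/path/to/model python3 quick-logits-check.py
"""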

import numpy as np
import sys
import os
from pathlib import Path

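# Note on the input format: the logit dumps are raw float32 arrays (read back
# below with np.fromfile). A PyTorch-side export could look like the following
# hedged sketch (variable names and the exact path are assumptions):
#
#     logits.detach().to(torch.float32).cpu().numpy().tofile(
#         "data/pytorch-<model-name>.bin")
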
def quick_logits_check(pytorch_file, llamacpp_file):
    """Lightweight sanity check before the full NMSE comparison."""

    try:
        # The dumps are raw float32 arrays, so both load as 1-D vectors.
        pytorch_logits = np.fromfile(pytorch_file, dtype=np.float32)
        llamacpp_logits = np.fromfile(llamacpp_file, dtype=np.float32)
    except Exception as e:
        print(f"❌ NOK: Failed to load files - {e}")
        return False

    # Check shapes match (for these 1-D dumps this is an element-count check)
    if pytorch_logits.shape != llamacpp_logits.shape:
        print(f"❌ NOK: Shape mismatch - PyTorch: {pytorch_logits.shape}, llama.cpp: {llamacpp_logits.shape}")
        return False

    # Calculate the key metric: the largest element-wise deviation
    diff = pytorch_logits - llamacpp_logits
    abs_diff = np.abs(diff)
    max_diff = np.max(abs_diff)

    # Show the top 10 logits from both models for manual inspection
    pytorch_top10 = np.argsort(pytorch_logits)[-10:][::-1]
    llamacpp_top10 = np.argsort(llamacpp_logits)[-10:][::-1]

    print(f"Top 10 PyTorch logits: {pytorch_logits[pytorch_top10]}")
    print(f"Top 10 llama.cpp logits: {llamacpp_logits[llamacpp_top10]}")
    print(f"Max absolute difference: {max_diff:.4f}")

    if max_diff > 1.0:
        print(f"❌ NOK: Large differences detected - max diff: {max_diff:.4f}")
        return False

    return True

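
# For context, a minimal sketch of the NMSE metric that the follow-up check
# is expected to compute (mean squared error normalized by the power of the
# reference logits). This helper and any threshold applied to it are
# illustrative assumptions, not the actual NMSE tooling.
def nmse(reference, candidate):
    # Cast to float64 so the normalization stays stable for large vocabularies.
    ref64 = reference.astype(np.float64)
    denom = np.mean(ref64 ** 2)
    if denom == 0.0:
        return float("inf")
    return float(np.mean((ref64 - candidate) ** 2) / denom)
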
def main():
    model_path = os.getenv('MODEL_PATH')
    if not model_path:
        print("Error: MODEL_PATH environment variable not set")
        sys.exit(1)

    if not os.path.exists(model_path):
        print(f"Error: Model file not found: {model_path}")
        sys.exit(1)

    model_name = os.path.basename(model_path)
    data_dir = Path("data")

    pytorch_file = data_dir / f"pytorch-{model_name}.bin"
    llamacpp_file = data_dir / f"llamacpp-{model_name}.bin"

    if not pytorch_file.exists():
        print(f"Error: PyTorch logits file not found: {pytorch_file}")
        print("Please run scripts/run-org-model.sh first to generate this file.")
        sys.exit(1)

    if not llamacpp_file.exists():
        print(f"Error: llama.cpp logits file not found: {llamacpp_file}")
        print("Please run scripts/run-converted-model.sh first to generate this file.")
        sys.exit(1)

    print("Checked that all required files were found. Proceeding...\n")

    print(f"🔍 GGML Model Validation for model {model_name}")
    print("=" * 40)
    print(f"PyTorch logits  : {pytorch_file}")
    print(f"llama.cpp logits: {llamacpp_file}")
    print()

    success = quick_logits_check(pytorch_file, llamacpp_file)

    # Exit with the appropriate code
    if success:
        print("✅ OK: Lightweight model check successful!")
        print("       Ok to proceed with NMSE check...")
        sys.exit(0)
    else:
        print("❌ NOK: Logits differ too much - generation will differ")
        sys.exit(1)


if __name__ == "__main__":
    main()