mirror of
https://github.com/ggml-org/llama.cpp.git
synced 2025-10-29 08:41:22 +00:00
121 lines
3.9 KiB
Python
121 lines
3.9 KiB
Python
#!/usr/bin/env python3
|
|
# generated by Claude
|
|
"""
|
|
Script to inspect SafeTensors model files and print tensor information.
|
|
"""
|
|
|
|
import json
|
|
from safetensors import safe_open
|
|
import os
|
|
from pathlib import Path
|
|
|
|
def inspect_safetensors_model(model_dir="."):
    """Print the name, shape and dtype of every tensor in a SafeTensors model.

    When ``model.safetensors.index.json`` exists in *model_dir*, its
    ``metadata.total_size`` is reported and only the shard files named in its
    ``weight_map`` are inspected; otherwise every ``*.safetensors`` file in the
    directory is inspected. Shapes and dtypes come from
    ``safe_open(...).get_slice(...)``, so tensor data is not loaded.

    Args:
        model_dir: Directory holding the model files (default: current directory).

    Raises:
        FileNotFoundError: If there is no index file and *model_dir* does not
            exist (raised by ``os.listdir``).
    """
    import sys

    model_path = Path(model_dir)
    index_file = model_path / "model.safetensors.index.json"

    if index_file.exists():
        with open(index_file, "r", encoding="utf-8") as f:
            index_data = json.load(f)

        print("=== Model Structure ===")
        # ``total_size`` in the index metadata is the combined tensor size in
        # bytes, not a parameter count, so label it accordingly.
        total_size = index_data.get("metadata", {}).get("total_size", "Unknown")
        print(f"Total size (bytes): {total_size}")
        print()

        # Many tensors share one shard file; a set keeps each file once.
        safetensor_files = set(index_data.get("weight_map", {}).values())
    else:
        # No index file: inspect every .safetensors file in the directory.
        safetensor_files = [name for name in os.listdir(model_dir) if name.endswith(".safetensors")]

    print("=== Tensor Information ===")
    print(f"{'Tensor Name':<50} {'Shape':<25} {'Data Type':<15} {'File'}")
    print("-" * 110)

    total_tensors = 0

    # Sort files for consistent output.
    for filename in sorted(safetensor_files):
        filepath = model_path / filename
        if not filepath.is_file():
            # Report shards that are referenced but absent instead of silently
            # printing an incomplete listing.
            print(f"Warning: skipping {filename}: file not found in {model_dir}", file=sys.stderr)
            continue

        print(f"\n--- {filename} ---")

        with safe_open(filepath, framework="pt") as f:  # Use PyTorch framework for better dtype support
            for tensor_name in sorted(f.keys()):
                # Get tensor metadata without loading the full tensor.
                tensor_slice = f.get_slice(tensor_name)
                shape_str = str(tuple(tensor_slice.get_shape()))
                dtype_str = str(tensor_slice.get_dtype())
                print(f"{tensor_name:<50} {shape_str:<25} {dtype_str:<15} {filename}")
                total_tensors += 1

    print(f"\nTotal tensors found: {total_tensors}")
|
|
|
|
def main(argv=None):
    """Parse command-line arguments and run the requested inspection.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]``. ``None``
            (the default) keeps the normal command-line behaviour; passing a
            list lets the CLI be driven programmatically.

    If ``--model-dir`` is not an existing directory, this exits through
    ``argparse`` with status 2 and a usage message. Without the check, the
    directory listing would raise a raw traceback.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Inspect SafeTensors model files")
    parser.add_argument("--model-dir", "-d", default=".",
                        help="Directory containing the model files (default: current directory)")
    parser.add_argument("--summary", "-s", action="store_true",
                        help="Show only summary statistics")

    args = parser.parse_args(argv)

    # Fail early with a clear usage error rather than a FileNotFoundError traceback.
    if not os.path.isdir(args.model_dir):
        parser.error(f"model directory not found: {args.model_dir}")

    if args.summary:
        print_summary_only(args.model_dir)
    else:
        inspect_safetensors_model(args.model_dir)
|
|
|
|
def print_summary_only(model_dir="."):
    """Print aggregate statistics for the SafeTensors files in *model_dir*.

    Reports the total tensor count, the total parameter count (the sum over
    tensors of the product of each shape), and how many tensors use each dtype.

    File selection matches the detailed inspection mode. When
    ``model.safetensors.index.json`` exists, only the shards listed in its
    ``weight_map`` are counted; otherwise every ``*.safetensors`` file in the
    directory is used. Honouring the index prevents double counting when the
    directory also holds an un-indexed copy of the weights, for example a
    single consolidated file next to the shards.

    Args:
        model_dir: Directory containing the model files (default: current directory).

    Raises:
        FileNotFoundError: If there is no index file and *model_dir* does not
            exist (raised by ``os.listdir``).
    """
    import math
    import sys
    from collections import Counter

    model_path = Path(model_dir)
    index_file = model_path / "model.safetensors.index.json"

    if index_file.exists():
        with open(index_file, "r", encoding="utf-8") as f:
            index_data = json.load(f)
        # Many tensors share one shard file; a set keeps each file once.
        safetensor_files = set(index_data.get("weight_map", {}).values())
    else:
        safetensor_files = [name for name in os.listdir(model_dir) if name.endswith(".safetensors")]

    total_tensors = 0
    total_params = 0
    dtype_counts = Counter()

    for filename in sorted(safetensor_files):
        filepath = model_path / filename
        if not filepath.is_file():
            # A shard referenced by the index is absent: say so rather than
            # silently under-reporting the totals.
            print(f"Warning: skipping {filename}: file not found in {model_dir}", file=sys.stderr)
            continue

        with safe_open(filepath, framework="pt") as f:  # Use PyTorch framework
            for tensor_name in f.keys():
                # Read only the tensor's metadata, not its data.
                tensor_slice = f.get_slice(tensor_name)
                total_tensors += 1
                dtype_counts[str(tensor_slice.get_dtype())] += 1
                # Element count is the product of the dimensions (1 for an empty shape).
                total_params += math.prod(tensor_slice.get_shape())

    print("=== Model Summary ===")
    print(f"Total tensors: {total_tensors}")
    print(f"Total parameters: {total_params:,}")
    print("Data type distribution:")
    for dtype, count in sorted(dtype_counts.items()):
        print(f" {dtype}: {count} tensors")
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported. main() returns None, so the exit status stays 0
# on success.
if __name__ == "__main__":
    raise SystemExit(main())
|