import io
import gc
import os
import json
import struct
import torch
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import gradio as gr
import PIL.Image
from transformers import AutoModelForCausalLM, AutoConfig
from huggingface_hub import hf_hub_download, hf_hub_url, snapshot_download
from huggingface_hub.utils import build_hf_headers
from safetensors import safe_open
import requests

# Set style for matplotlib
sns.set_theme(style="whitegrid")

# Cache for metadata only
_metadata_cache = {}


def calculate_weight_diff(base_weight, chat_weight):
    """Calculates the mean absolute difference between two tensors."""
    b_w = base_weight.float()
    c_w = chat_weight.float()
    result = torch.abs(b_w - c_w).mean().item()
    del b_w, c_w
    return result


def get_safetensor_index(repo_id, token=None):
    """Download and parse the safetensors index."""
    cache_key = f"{repo_id}_index"
    if cache_key in _metadata_cache:
        return _metadata_cache[cache_key]
    try:
        index_path = hf_hub_download(repo_id, "model.safetensors.index.json", token=token)
        with open(index_path, 'r') as f:
            index_data = json.load(f)
        weight_map = index_data.get("weight_map", {})
        _metadata_cache[cache_key] = weight_map
        return weight_map
    except Exception:
        _metadata_cache[cache_key] = None
        return None


# =============================================================================
# STREAMING MODE (Ultra Low Memory - No disk usage)
# =============================================================================

def get_safetensor_header(repo_id, filename, token=None):
    """Fetch only the header of a safetensor file using HTTP range requests."""
    cache_key = f"{repo_id}_{filename}_header"
    if cache_key in _metadata_cache:
        return _metadata_cache[cache_key]
    url = hf_hub_url(repo_id, filename)
    headers = build_hf_headers(token=token)
    # First, get the header size (first 8 bytes: a little-endian uint64,
    # per the safetensors file format).
    headers["Range"] = "bytes=0-7"
    response = requests.get(url, headers=headers)
    response.raise_for_status()
    header_size = struct.unpack('<Q', response.content)[0]
    # The JSON header occupies bytes 8 .. 8 + header_size - 1; each tensor
    # entry's "data_offsets" are relative to the byte after the header.
    headers["Range"] = f"bytes=8-{7 + header_size}"
    response = requests.get(url, headers=headers)
    response.raise_for_status()
    header = json.loads(response.content)
    _metadata_cache[cache_key] = header
    return header
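
# The helpers below reconstruct functionality whose source was lost; they are
# minimal sketches inferred from the surrounding code and imports, not the
# original implementation. Names prefixed with "_" or noted as assumed were
# introduced here for illustration.

# safetensors dtype strings -> torch dtypes (assumed subset; extend as needed).
_DTYPE_MAP = {
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
}

# Per-layer components to compare. This list is an assumption matching
# Llama-style decoder layers; adjust it for other architectures.
_COMPONENTS = [
    "self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj",
    "self_attn.o_proj", "mlp.gate_proj", "mlp.up_proj", "mlp.down_proj",
]


def _get_data_start(repo_id, filename, token=None):
    """Absolute byte offset where tensor data begins (assumed helper)."""
    cache_key = f"{repo_id}_{filename}_data_start"
    if cache_key in _metadata_cache:
        return _metadata_cache[cache_key]
    url = hf_hub_url(repo_id, filename)
    headers = build_hf_headers(token=token)
    headers["Range"] = "bytes=0-7"
    response = requests.get(url, headers=headers)
    response.raise_for_status()
    data_start = 8 + struct.unpack('<Q', response.content)[0]
    _metadata_cache[cache_key] = data_start
    return data_start


def fetch_tensor_streaming(repo_id, filename, tensor_name, token=None):
    """Stream a single tensor's bytes via an HTTP range request (assumed helper)."""
    header = get_safetensor_header(repo_id, filename, token=token)
    info = header[tensor_name]
    start, end = info["data_offsets"]
    data_start = _get_data_start(repo_id, filename, token=token)
    url = hf_hub_url(repo_id, filename)
    headers = build_hf_headers(token=token)
    headers["Range"] = f"bytes={data_start + start}-{data_start + end - 1}"
    response = requests.get(url, headers=headers)
    response.raise_for_status()
    # bytearray() gives torch a writable buffer, avoiding a frombuffer warning.
    tensor = torch.frombuffer(bytearray(response.content), dtype=_DTYPE_MAP[info["dtype"]])
    return tensor.reshape(info["shape"])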
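
# Sketch of the streaming diff pass. The original body was lost; only the
# signature is taken from the call site in process_models(). Weight names
# assume the Llama-style pattern "model.layers.{i}.{component}.weight".
def calculate_layer_diffs_streaming(base_repo, chat_repo, token=None, progress=None):
    """Compare two repos layer by layer without touching disk or loading models."""
    config = AutoConfig.from_pretrained(base_repo, token=token, trust_remote_code=True)
    num_layers = config.num_hidden_layers
    base_map = get_safetensor_index(base_repo, token=token)
    chat_map = get_safetensor_index(chat_repo, token=token)
    if base_map is None or chat_map is None:
        raise ValueError("Streaming mode requires model.safetensors.index.json in both repos.")
    layer_diffs = []
    for i in range(num_layers):
        if progress:
            progress(i / num_layers, desc=f"Streaming layer {i + 1}/{num_layers}...")
        row = {}
        for component in _COMPONENTS:
            name = f"model.layers.{i}.{component}.weight"
            base_t = fetch_tensor_streaming(base_repo, base_map[name], name, token=token)
            chat_t = fetch_tensor_streaming(chat_repo, chat_map[name], name, token=token)
            row[component] = calculate_weight_diff(base_t, chat_t)
            del base_t, chat_t
        layer_diffs.append(row)
        gc.collect()
    return layer_diffs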
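
# =============================================================================
# DISK CACHE MODE (Low Memory - Downloads shards to disk)
# =============================================================================

# Sketch reconstructed from the imports (snapshot_download, safe_open, os):
# download the safetensors shards once, then read one tensor at a time.
def _load_tensor_from_disk(repo_dir, filename, tensor_name):
    """Load a single tensor from a downloaded shard (assumed helper)."""
    with safe_open(os.path.join(repo_dir, filename), framework="pt", device="cpu") as f:
        return f.get_tensor(tensor_name)


def calculate_layer_diffs_disk_cache(base_repo, chat_repo, token=None, progress=None):
    """Compare two repos tensor by tensor from locally cached shards."""
    if progress:
        progress(0.05, desc=f"Downloading {base_repo}...")
    base_dir = snapshot_download(base_repo, allow_patterns=["*.safetensors", "*.json"], token=token)
    if progress:
        progress(0.25, desc=f"Downloading {chat_repo}...")
    chat_dir = snapshot_download(chat_repo, allow_patterns=["*.safetensors", "*.json"], token=token)
    config = AutoConfig.from_pretrained(base_repo, token=token, trust_remote_code=True)
    num_layers = config.num_hidden_layers
    # Single-shard models have no index file; fall back to the default name.
    base_map = get_safetensor_index(base_repo, token=token) or {}
    chat_map = get_safetensor_index(chat_repo, token=token) or {}
    layer_diffs = []
    for i in range(num_layers):
        if progress:
            progress(0.4 + 0.5 * i / num_layers, desc=f"Comparing layer {i + 1}/{num_layers}...")
        row = {}
        for component in _COMPONENTS:
            name = f"model.layers.{i}.{component}.weight"
            base_t = _load_tensor_from_disk(base_dir, base_map.get(name, "model.safetensors"), name)
            chat_t = _load_tensor_from_disk(chat_dir, chat_map.get(name, "model.safetensors"), name)
            row[component] = calculate_weight_diff(base_t, chat_t)
            del base_t, chat_t
        layer_diffs.append(row)
        gc.collect()
    return layer_diffs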
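
# =============================================================================
# STANDARD MODE (Full models in memory)
# =============================================================================

# Sketch of the in-memory diff pass; the original body was lost. It assumes a
# Llama-style module tree (model.model.layers), matching _COMPONENTS above.
def calculate_layer_diffs_standard(base_model, chat_model, progress=None):
    """Compare two fully loaded models layer by layer."""
    base_layers = base_model.model.layers
    chat_layers = chat_model.model.layers
    num_layers = len(base_layers)
    layer_diffs = []
    for i, (base_layer, chat_layer) in enumerate(zip(base_layers, chat_layers)):
        if progress:
            progress(i / num_layers, desc=f"Comparing layer {i + 1}/{num_layers}...")
        row = {}
        for component in _COMPONENTS:
            # get_submodule resolves dotted paths like "self_attn.q_proj".
            base_w = base_layer.get_submodule(component).weight
            chat_w = chat_layer.get_submodule(component).weight
            row[component] = calculate_weight_diff(base_w, chat_w)
        layer_diffs.append(row)
    return layer_diffs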
# =============================================================================
# VISUALIZATION
# =============================================================================

def visualize_2d_heatmap(layer_diffs, base_model_name, chat_model_name):
    """Generates a grid of layer-wise heatmaps as a PIL image."""
    if not layer_diffs:
        return None
    components = list(layer_diffs[0].keys())
    num_components = len(components)
    num_layers = len(layer_diffs)
    # Assumed sizing heuristic: scale width with the component count and
    # height with the layer count.
    width = max(12, 2 * num_components)
    height = max(8, 0.35 * num_layers)
    if num_components > 6:
        nrows = 2
        ncols = (num_components + 1) // 2
    else:
        nrows = 1
        ncols = num_components
    fig, axs = plt.subplots(nrows, ncols, figsize=(width, height * (1.2 if nrows > 1 else 1)))
    axs = axs.flatten() if num_components > 1 else [axs]
    fig.suptitle(f"Weight Differences: {base_model_name} vs {chat_model_name}", fontsize=16, y=0.98)
    tick_font_size = max(6, min(10, 300 / num_layers))
    for i, component in enumerate(components):
        data = [[row[component]] for row in layer_diffs]
        sns.heatmap(data, annot=True, fmt=".6f", cmap="viridis", ax=axs[i], cbar=False,
                    annot_kws={'size': tick_font_size * 0.8})
        axs[i].set_title(component, fontsize=12, fontweight='bold')
        axs[i].set_yticks(range(num_layers))
        axs[i].set_yticklabels(range(num_layers), fontsize=tick_font_size)
        axs[i].set_xticks([])
        axs[i].invert_yaxis()
    for j in range(i + 1, len(axs)):
        fig.delaxes(axs[j])
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    buf = io.BytesIO()
    fig.savefig(buf, format='png', dpi=150, bbox_inches='tight')
    buf.seek(0)
    plt.close(fig)
    return PIL.Image.open(buf)


def generate_3d_plot(layer_diffs):
    """Generates an interactive 3D Surface plot as a Plotly Figure."""
    if not layer_diffs:
        return None
    df = pd.DataFrame(layer_diffs)
    x_labels = df.columns.tolist()
    y_labels = df.index.tolist()
    z_data = df.values
    fig = go.Figure(data=[go.Surface(z=z_data, x=x_labels, y=y_labels, colorscale='Viridis')])
    fig.update_layout(
        title='3D Landscape of Weight Differences',
        scene=dict(
            xaxis_title='Model Components',
            yaxis_title='Layer Index',
            zaxis_title='Mean Weight Diff',
            xaxis=dict(tickangle=45),
        ),
        autosize=True,
        height=700,
        margin=dict(l=65, r=50, b=65, t=90)
    )
    return fig


# =============================================================================
# MAIN PROCESSING
# =============================================================================

def process_models(base_name, chat_name, hf_token, memory_mode, progress=gr.Progress()):
    if not base_name or not chat_name:
        raise gr.Error("Please provide both model names.")
    token = hf_token if hf_token else None
    try:
        if memory_mode == "streaming":
            # Streaming mode - ultra low memory, no disk
            progress(0, desc="Starting streaming mode (ultra low memory)...")
            diffs = calculate_layer_diffs_streaming(
                base_name, chat_name, token=token, progress=progress
            )
        elif memory_mode == "disk_cache":
            # Disk cache mode - downloads to disk, loads tensors one at a time
            progress(0, desc="Starting disk cache mode...")
            diffs = calculate_layer_diffs_disk_cache(
                base_name, chat_name, token=token, progress=progress
            )
        else:
            # Standard mode - full models in memory
            progress(0, desc=f"Loading {base_name}...")
            print(f"Loading {base_name}...")
            base_model = AutoModelForCausalLM.from_pretrained(
                base_name,
                torch_dtype=torch.bfloat16,
                token=token,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            progress(0.3, desc=f"Loading {chat_name}...")
            print(f"Loading {chat_name}...")
            chat_model = AutoModelForCausalLM.from_pretrained(
                chat_name,
                torch_dtype=torch.bfloat16,
                token=token,
                trust_remote_code=True,
                low_cpu_mem_usage=True
            )
            progress(0.5, desc="Calculating differences...")
            diffs = calculate_layer_diffs_standard(base_model, chat_model, progress=None)
            del base_model
            del chat_model
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        progress(0.9, desc="Generating visualizations...")
        img_2d = visualize_2d_heatmap(diffs, base_name, chat_name)
        plot_3d = generate_3d_plot(diffs)
        progress(1.0, desc="Complete!")
        return img_2d, plot_3d
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise gr.Error(f"Error processing models: {str(e)}")


# =============================================================================
# GRADIO UI
# =============================================================================

with gr.Blocks(title="Model Diff Visualizer") as demo:
    gr.Markdown("# 🧠 LLM Weight Difference Visualizer")
    gr.Markdown("Compare the weights of a Base model vs. its Instruct/Chat tuned version layer by layer.")
    with gr.Row():
        with gr.Column(scale=1):
            base_input = gr.Textbox(
                label="Base Model Name",
                placeholder="e.g., meta-llama/Llama-3.3-70B-Instruct"
            )
            chat_input = gr.Textbox(
                label="Chat/Tuned Model Name",
                placeholder="e.g., CrucibleLab/L3.3-70B-Loki-V2.0"
            )
            token_input = gr.Textbox(
                label="Hugging Face Token (Optional)",
                type="password",
                placeholder="hf_..."
            )
            memory_mode = gr.Radio(
                label="Memory Mode",
                choices=[
                    ("🚀 Standard (Fast, High RAM)", "standard"),
                    ("💾 Disk Cache (Medium Speed, Low RAM, Uses Disk)", "disk_cache"),
                    ("🐢 Streaming (Slow, Ultra Low RAM, No Disk)", "streaming"),
                ],
                value="standard",
                info="Choose based on your available RAM and disk space"
            )
            with gr.Accordion("Memory Mode Details", open=False):
                gr.Markdown("""
### 🚀 Standard Mode
- **RAM Usage:** ~2x model size (e.g., ~280GB for 70B models)
- **Disk Usage:** HuggingFace cache only
- **Speed:** Fastest
- **Best for:** Machines with lots of RAM

### 💾 Disk Cache Mode
- **RAM Usage:** ~2-4GB (only one tensor at a time)
- **Disk Usage:** ~2x model size (downloads full safetensors)
- **Speed:** Medium (disk I/O bound)
- **Best for:** Machines with limited RAM but plenty of disk space

### 🐢 Streaming Mode
- **RAM Usage:** ~1-2GB (streams bytes directly)
- **Disk Usage:** Minimal (only metadata cached)
- **Speed:** Slowest (many HTTP requests)
- **Best for:** Very constrained environments, or when disk space is also limited
""")
            submit_btn = gr.Button("🚀 Analyze Differences", variant="primary")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 2D Layer-wise Heatmap")
            output_2d = gr.Image(label="2D Visualization", type="pil")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### 3D Interactive Landscape")
            output_3d = gr.Plot(label="3D Visualization")

    submit_btn.click(
        fn=process_models,
        inputs=[base_input, chat_input, token_input, memory_mode],
        outputs=[output_2d, output_3d]
    )

if __name__ == "__main__":
    demo.launch(share=False, server_port=7860)