import html

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import matplotlib
import matplotlib.colors as mcolors

# Log-prob differences beyond +/- MAX_VAL saturate at the ends of the colormap.
MAX_VAL = 5.0
MODEL_NAME = "ibm-granite/granite-3.3-2b-instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
model.eval()

if not tokenizer.chat_template:
    raise ValueError("Tokenizer does not have a chat template. This only works with chat models.")

# Diverging colormap centered at 0: negative deltas shade toward blue, positive toward red.
norm = mcolors.TwoSlopeNorm(vmin=-MAX_VAL, vcenter=0, vmax=MAX_VAL)
colormap = matplotlib.colormaps["coolwarm"]  # cm.get_cmap() was removed in matplotlib >= 3.9


def delta_to_color(delta):
    """Map a log-prob difference to a CSS rgb() color string."""
    rgba = colormap(norm(delta))
    rgb = tuple(int(255 * c) for c in rgba[:3])
    return f"rgb{rgb}"


def get_response_logprobs(user_prompt, assistant_reply):
    """Return [(token_id, log_prob), ...] for the assistant reply conditioned on the user prompt."""
    chat = [{"role": "user", "content": user_prompt},
            {"role": "assistant", "content": assistant_reply}]
    input_ids = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=False)
    with torch.inference_mode():
        # Move inputs to the model's device (device_map="auto" may place the model on GPU).
        logits = model(input_ids.to(model.device)).logits
    # Compute log-softmax in float32 and bring the result back to CPU for indexing.
    log_probs = F.log_softmax(logits.float(), dim=-1).cpu()
    # The prompt-only prefix (with generation prompt) marks where the assistant tokens start.
    base_ids = tokenizer.apply_chat_template(chat[:-1], return_tensors="pt", add_generation_prompt=True)[0]
    assistant_ids = input_ids[0][len(base_ids):]
    log_probs_list = []
    for i, tok_id in enumerate(assistant_ids):
        # Logits at position p predict the token at position p + 1.
        prev_pos = len(base_ids) + i - 1
        tok_logprob = log_probs[0, prev_pos, tok_id].item()
        log_probs_list.append((tok_id.item(), tok_logprob))
    return log_probs_list, assistant_ids


def compare_prompts(prompt1, prompt2, text):
    """Render `text` as HTML, coloring each token by its log-prob difference between the two prompts."""
    lps1, ids1 = get_response_logprobs(prompt1, text)
    lps2, ids2 = get_response_logprobs(prompt2, text)
    assert len(lps1) == len(lps2), "Token length mismatch."
    # Wrapper element for the highlighted tokens (markup reconstructed; inline style is illustrative).
    html_output = '<div style="line-height: 1.9; font-size: 16px;">'
    for (tok_id1, lp1), (_, lp2) in zip(lps1, lps2):
        if tok_id1 in tokenizer.all_special_ids:
            continue
        delta = lp1 - lp2
        token_str = tokenizer.decode([tok_id1], clean_up_tokenization_spaces=False)
        if not token_str.strip():
            # Pass whitespace-only tokens through unhighlighted.
            html_output += token_str
            continue
        color = delta_to_color(delta)
        # Span markup reconstructed: the background color encodes the delta, the tooltip shows its value.
        html_output += (
            f'<span style="background-color: {color};" '
            f'title="logprob delta: {delta:+.3f}">{html.escape(token_str)}</span>'
        )
    html_output += '</div>'
    return html_output
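
# Optional helper (a minimal sketch, not part of the original app): summing the per-token
# log-probs gives one score per prompt, and their difference is a log-likelihood ratio
# indicating which prompt makes the reply more probable overall.
def total_logprob_delta(prompt1, prompt2, text):
    lps1, _ = get_response_logprobs(prompt1, text)
    lps2, _ = get_response_logprobs(prompt2, text)
    return sum(lp for _, lp in lps1) - sum(lp for _, lp in lps2)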

with gr.Blocks() as demo:
    gr.Markdown("## Chat Prompt Sensitivity")
    with gr.Row():
        with gr.Column():
            # Field labels rendered via gr.HTML (markup reconstructed; styling is illustrative).
            gr.HTML('<div style="font-weight: bold;">Prompt 1</div>')
            prompt1 = gr.Textbox(lines=2, placeholder='Explain flowers as an engineer')
        with gr.Column():
            gr.HTML('<div style="font-weight: bold;">Prompt 2</div>')
            prompt2 = gr.Textbox(lines=2, placeholder='Explain flowers as a poet')
    gr.HTML('<div style="font-weight: bold;">Text (Assistant reply)</div>')
    text = gr.Textbox(
        lines=3,
        placeholder='Flowers are biologically optimized structures engineered by evolution for '
                    'efficient reproduction through pollinator attraction and seed development.',
    )
    status = gr.HTML("")
    output = gr.HTML(label="Text with highlighted log probability differences", show_label=True)
    btn = gr.Button("Compare")

    def wrapper(p1, p2, t):
        # First yield: show a loading message while the model runs.
        yield "⏳ Computing...", ""
        # Second yield: clear the status and show the highlighted comparison.
        result = compare_prompts(p1, p2, t)
        yield "", result

    btn.click(
        fn=wrapper,
        inputs=[prompt1, prompt2, text],
        outputs=[status, output],
        queue=True,
    )

if __name__ == '__main__':
    demo.launch()
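
# Quick programmatic check (optional; not in the original script): compare_prompts() can be
# called directly to inspect the generated HTML without launching the Gradio UI, e.g.:
#
#   snippet = compare_prompts(
#       "Explain flowers as an engineer",
#       "Explain flowers as a poet",
#       "Flowers are biologically optimized structures.",
#   )
#   print(snippet[:300])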