import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import matplotlib.cm as cm
import matplotlib.colors as mcolors

MAX_VAL = 5.0
MODEL_NAME = "ibm-granite/granite-3.3-2b-instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",
)
model.eval()

if not tokenizer.chat_template:
    raise ValueError("Tokenizer does not have a chat template. This only works with chat models.")

# Map a log-prob delta in [-MAX_VAL, MAX_VAL] onto the diverging 'coolwarm' colormap,
# centered at zero so identical probabilities under both prompts render as neutral.
norm = mcolors.TwoSlopeNorm(vmin=-MAX_VAL, vcenter=0, vmax=MAX_VAL)
colormap = cm.get_cmap("coolwarm")


def delta_to_color(delta):
    rgba = colormap(norm(delta))
    rgb = tuple(int(255 * c) for c in rgba[:3])
    return f"rgb{rgb}"


def get_response_logprobs(user_prompt, assistant_reply):
    """Score a fixed assistant reply token-by-token under the given user prompt."""
    chat = [
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": assistant_reply},
    ]
    input_ids = tokenizer.apply_chat_template(
        chat, return_tensors="pt", add_generation_prompt=False
    ).to(model.device)

    with torch.inference_mode():
        logits = model(input_ids).logits
    log_probs = F.log_softmax(logits, dim=-1)

    # The assistant tokens are everything after the prompt plus the generation prefix.
    base_ids = tokenizer.apply_chat_template(
        chat[:-1], return_tensors="pt", add_generation_prompt=True
    )[0]
    assistant_ids = input_ids[0][len(base_ids):]

    log_probs_list = []
    for i, tok_id in enumerate(assistant_ids):
        # Logits at position p predict the token at position p + 1,
        # so read the log-prob of each assistant token from the previous position.
        prev_pos = len(base_ids) + i - 1
        tok_logprob = log_probs[0, prev_pos, tok_id].item()
        log_probs_list.append((tok_id.item(), tok_logprob))
    return log_probs_list, assistant_ids


def compare_prompts(prompt1, prompt2, text):
    lps1, ids1 = get_response_logprobs(prompt1, text)
    lps2, ids2 = get_response_logprobs(prompt2, text)
    assert len(lps1) == len(lps2), "Token length mismatch."

    html_output = '