Seems there is a port on Kaggle…?
What “tokens are wrong but change with the image” usually means (in cross-framework ports)
When weights match layer-by-layer but generation is consistently “off,” the cause is almost always one of these:
- Input/preprocessing mismatch (most common in VLM ports)
  - Channel order / data format mismatch: HF PaliGemma expects pixel_values shaped like (batch, channels, H, W) in its PyTorch path. (Hugging Face) If your TF model is NHWC but you feed NCHW (or vice versa), you’ll get image-dependent outputs that are consistently wrong.
  - Normalization mismatch: HF’s SigLIP image processor defaults to rescale_factor=1/255, image_mean=[0.5]*3, image_std=[0.5]*3, and RGB conversion. (Hugging Face) If you normalize like CLIP (ImageNet-style mean/std) or omit rescaling, logits can shift enough to flip the argmax early.
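As a concrete check, here is a minimal numpy sketch of what the TF side must reproduce exactly. The constants mirror the SigLIP defaults quoted above; the function name and shapes are illustrative, not from any library:

```python
import numpy as np

def siglip_style_preprocess(image_u8: np.ndarray) -> np.ndarray:
    """Sketch of SigLIP-style normalization: rescale by 1/255,
    then (x - 0.5) / 0.5 per channel, then NHWC -> NCHW."""
    x = image_u8.astype(np.float32) / 255.0  # rescale_factor = 1/255
    x = (x - 0.5) / 0.5                      # image_mean = image_std = [0.5]*3
    return x.transpose(0, 3, 1, 2)           # NHWC -> NCHW for the PyTorch path

# A mid-gray pixel (128) maps to ~0.004, not 0.5 -- easy to get wrong
# if rescaling is skipped or CLIP/ImageNet stats are used instead.
img = np.full((1, 224, 224, 3), 128, dtype=np.uint8)
out = siglip_style_preprocess(img)
assert out.shape == (1, 3, 224, 224)
```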
- Masking / cache / position handling mismatch (very common)
  - PaliGemma uses full block attention over image tokens + input text tokens. (Hugging Face)
  - HF’s implementation also has special handling described as a bidirectional mask on prompt tokens (and causal behavior for generated tokens). (GitHub)
  - A real HF bug showed use_cache=True breaking generation due to an attention-mask computation issue; outputs went wrong even though the model “worked.” (Hugging Face) If your TF port has a KV cache, an off-by-one in cache_position / position_ids / mask broadcasting is enough to derail tokens.
- Attention shape logic mismatch: Gemma uses different Q vs K/V sizes
  - In the reference architecture, q_proj has output dimension 2048 while k_proj/v_proj are much smaller (e.g., 256). (Google Developers Blog) That implies grouped-query / multi-query attention. A TF port that reshapes K/V as if they had the same head count as Q (or repeats along the wrong axis) will produce systematically wrong logits while the weights still “match.”
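A minimal numpy sketch of the grouped-query repeat step. The head counts here are illustrative (8 query heads sharing 1 KV head, head_dim 256), chosen to match the projection sizes quoted above:

```python
import numpy as np

B, T, n_heads, n_kv_heads, head_dim = 1, 4, 8, 1, 256

q = np.random.randn(B, T, n_heads, head_dim)
k = np.random.randn(B, T, n_kv_heads, head_dim)
v = np.random.randn(B, T, n_kv_heads, head_dim)

# The crucial step: repeat K/V along the HEAD axis so every group of
# query heads attends to its shared KV head. Repeating along the wrong
# axis (e.g., time) yields valid-looking shapes but wrong attention.
group = n_heads // n_kv_heads
k_rep = np.repeat(k, group, axis=2)  # (B, T, n_heads, head_dim)
v_rep = np.repeat(v, group, axis=2)

assert k_rep.shape == q.shape
```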
- Activation function mismatch (quiet but impactful)
  - Both SigLIP and Gemma commonly use the GELU tanh approximation (gelu_pytorch_tanh / PytorchGELUTanh). (Hugging Face) Using “exact” GELU vs the tanh approximation can change logits enough to flip tokens, especially early in decoding.
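The gap between the two GELU variants is easy to quantify in isolation; a pure numpy/math sketch:

```python
import math
import numpy as np

def gelu_exact(x: np.ndarray) -> np.ndarray:
    # "exact" GELU: x * Phi(x), via the error function
    return x * 0.5 * (1.0 + np.vectorize(math.erf)(x / math.sqrt(2.0)))

def gelu_tanh(x: np.ndarray) -> np.ndarray:
    # the gelu_pytorch_tanh approximation
    return 0.5 * x * (1.0 + np.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

x = np.linspace(-4.0, 4.0, 81)
diff = float(np.max(np.abs(gelu_exact(x) - gelu_tanh(x))))
# diff is tiny per element, but it compounds across every MLP in every
# layer, which is enough to flip a near-tied argmax early in decoding.
```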
- Precision / backend differences (Mac Metal can amplify)
  - There are reports of inconsistent results on Mac M4 vs NVIDIA in Keras. (GitHub) This usually shouldn’t totally scramble tokens on its own, but it can make debugging harder if you’re already near decision boundaries.
The fastest way to debug your port (what I would do with your constraints)
The winning strategy is a binary search on the forward pass, before you touch sampling/decoding.
Step 0 — Make it deterministic and remove “generation complexity”
Do this first:
- Force greedy decoding (argmax): no temperature/top-p.
- Run in float32 everywhere for debugging.
- Temporarily disable the KV cache (use_cache=False equivalent): recompute the full forward pass each step.
- If outputs become correct (or much closer), the bug is in cache/mask/positions (matching the HF use_cache failure mode). (Hugging Face)
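Step 0 can be sketched as a cache-free greedy loop. Here `forward_fn` is a stand-in for your full model (the toy one below just prefers `last_id + 1`, so the logic is checkable):

```python
import numpy as np

def greedy_decode_no_cache(forward_fn, input_ids, max_new_tokens=5):
    """Debugging decode loop: no KV cache, no sampling.
    forward_fn(ids) must return logits of shape (seq_len, vocab).
    Every step recomputes the whole sequence, so cache bugs
    cannot affect the result."""
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        logits = forward_fn(np.array(ids))
        ids.append(int(np.argmax(logits[-1])))  # greedy: pure argmax
    return ids

# Toy forward pass: always prefers token (last_id + 1) % vocab.
def toy_forward(ids, vocab=10):
    logits = np.zeros((len(ids), vocab))
    logits[-1, (ids[-1] + 1) % vocab] = 1.0
    return logits

print(greedy_decode_no_cache(toy_forward, [3], max_new_tokens=3))  # [3, 4, 5, 6]
```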
Step 1 — Lock inputs so preprocessing can’t be the culprit
In PyTorch/HF:
- Use the official processor and save the exact tensors you feed the model:
  - input_ids, attention_mask
  - pixel_values (as produced by the processor)
- Then, in TF, load those saved arrays and run your TF model on them.
Why: this eliminates every difference in resizing/normalization/tokenization in one move. HF’s SigLIP processor defaults are easy to miss. (Hugging Face)
If TF outputs are still wrong using HF-produced pixel_values and input_ids, preprocessing is not the problem.
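A sketch of the save/load handshake. The token ids and shapes below are placeholders; on the HF side the dict would come from the real processor output (e.g., `processor(text=prompt, images=image, return_tensors="pt")`, each tensor converted with `.numpy()`):

```python
import numpy as np

# PyTorch/HF side: save the processor outputs once.
inputs = {
    "input_ids": np.array([[1, 2, 3]]),                      # placeholder ids
    "attention_mask": np.ones((1, 3), np.int64),
    "pixel_values": np.zeros((1, 3, 224, 224), np.float32),  # placeholder image tensor
}
np.savez("hf_inputs.npz", **inputs)

# TF side: load the exact same arrays and feed them to your port,
# bypassing your own tokenizer/resizing/normalization entirely.
loaded = np.load("hf_inputs.npz")
ok = all(np.array_equal(inputs[k], loaded[k]) for k in inputs)
```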
Step 2 — Compare intermediate activations (layerwise “tripwires”)
You want the first layer where TF diverges from PyTorch.
Do it in this order (cheap → expensive):
- Vision tower output (end of SigLIP)
- Multimodal projector output
- Text token embeddings (embedding table lookup)
- One decoder layer at a time:
  - input RMSNorm output
  - Q/K/V tensors (after projection, after reshape)
  - attention scores (pre-softmax)
  - attention probs (post-softmax)
  - attention output projection
  - MLP pre-activation, post-activation, output
- Final logits
This is exactly the “single forward pass validation / binary search” workflow TF recommends for migrations: narrow scope by checking equivalence at intermediate steps. (TensorFlow)
Practical note: for each checkpoint, compute:
- max absolute diff
- mean absolute diff
- cosine similarity (for large vectors)
and log shapes/dtypes.
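A small numpy helper for those per-checkpoint metrics (the function name and log format are mine):

```python
import numpy as np

def compare(name, a: np.ndarray, b: np.ndarray):
    """Log equivalence metrics for one checkpoint: max/mean abs diff,
    cosine similarity, plus shapes and dtypes."""
    assert a.shape == b.shape, f"{name}: shape mismatch {a.shape} vs {b.shape}"
    a64, b64 = a.astype(np.float64).ravel(), b.astype(np.float64).ravel()
    max_abs = float(np.max(np.abs(a64 - b64)))
    mean_abs = float(np.mean(np.abs(a64 - b64)))
    cos = float(np.dot(a64, b64) / (np.linalg.norm(a64) * np.linalg.norm(b64) + 1e-12))
    print(f"{name}: shape={a.shape} dtype={a.dtype} "
          f"max|d|={max_abs:.3e} mean|d|={mean_abs:.3e} cos={cos:.6f}")
    return max_abs, mean_abs, cos

m, _, c = compare("projector", np.ones((2, 4)), np.ones((2, 4)) + 1e-6)
```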
Step 3 — If divergence starts inside attention, check these specific traps
Given Gemma’s Q vs KV projection sizes (Google Developers Blog), I would audit:
- Head math
  - q: (B, T, n_heads, head_dim)
  - k/v: (B, T, n_kv_heads, head_dim)
  - Then repeat/broadcast k/v across query heads (grouped-query logic).
- Transpose conventions (TF code often uses (B, heads, T, head_dim) vs (B, T, heads, head_dim))
  - One wrong transpose produces “valid-looking” tensors and totally wrong logits.
- Mask application point
  - The mask must be added to the attention scores before softmax, as a large negative value.
- RoPE / positions
  - In the HF forward signature you’ll see cache_position and position_ids concerns. (Hugging Face) With caching, position handling is the #1 source of off-by-one errors.
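As a quick self-check of the mask application point, a numpy sketch of additive masking before softmax:

```python
import numpy as np

def masked_softmax(scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Additive mask applied BEFORE softmax: disallowed positions get a
    large negative value so they receive ~0 probability."""
    scores = scores + np.where(mask, 0.0, -1e9)           # mask: True = may attend
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(scores)
    return e / e.sum(axis=-1, keepdims=True)

# Causal mask for 3 query positions over 3 keys.
T = 3
mask = np.tril(np.ones((T, T), dtype=bool))
probs = masked_softmax(np.zeros((T, T)), mask)
# Row 0 attends only to key 0. Masking AFTER softmax (or multiplying
# probabilities) would give a different, wrong distribution.
assert probs[0, 1] < 1e-6 and abs(probs[2].sum() - 1.0) < 1e-9
```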
Step 4 — Only after forward-pass matches, debug generation
Once a single forward pass matches closely, re-enable generation features one at a time. If re-enabling the cache breaks it, compare against the known HF failure mode: use_cache=True causing wrong outputs due to attention-mask computation. (Hugging Face)
“Similar cases” and issues worth reading (directly relevant)
Cache/mask issues in PaliGemma generation
- HF issue/discussion: use_cache=True breaks PaliGemma generation (attention mask miscomputed; outputs wrong). (Hugging Face)
This is extremely aligned with your symptoms if you have a KV cache.
Prompt formatting pitfalls (especially newline + ordering)
- Keras PaliGemma model card examples include a trailing newline in prompts (e.g., "caption en\n"). (Hugging Face)
- Google’s prompt-format guide highlights the ordering: image first, then prompt text.
- HF forum discussion about newline tokenization behavior (easy to misunderstand when copying prompts between implementations). (Hugging Face Forums)
Weight conversion / shape pitfalls (good for “what can go wrong”)
- HF issue about converting PaliGemma NPZ → HF hit a reshape mismatch in attention projection handling. (GitHub)
Not your exact path, but it’s another data point that projection shapes/head reshaping are a common failure point.
Mac-specific “equivalence noise”
- Keras issue reporting inconsistent results on Mac M4 vs NVIDIA. (GitHub)
For debugging, run CPU/float32 first to avoid backend-specific numeric quirks.
“Has anyone ported it to TF?”
Yes—just not necessarily as a HF-style TF transformers model:
- Google’s official “inference with Keras” guide uses PaliGemmaCausalLM from Keras Hub. (Google AI for Developers)
- Kaggle hosts a Keras implementation of PaliGemma 2 that runs on JAX, TensorFlow, and PyTorch (Keras 3 multi-backend). (kaggle.com)
- HF also hosts Keras-formatted checkpoints (your earlier link shows a “*-keras” variant). (Hugging Face)
If your goal is “TF inference on laptop,” using the Keras Hub model as a reference oracle is valuable even if you keep your own port.
Tooling: how to debug “complex TF models” in practice
1) TensorBoard Debugger V2 (for shapes/NaNs/Inf and execution traces)
TensorFlow provides a debugger workflow via tf.debugging.experimental.enable_dump_debug_info(...) to inspect tensor health, shapes, and execution history. (TensorFlow)
This won’t directly tell you “your transpose is wrong,” but it helps catch silent numeric pathologies and gives you visibility into what ran.
2) TF migration debugging playbook (applies directly to PyTorch→TF ports)
TF’s migration debugging guide explicitly recommends:
- single forward-pass equivalence
- disabling randomness
- binary search over the model to localize divergence (TensorFlow)
That workflow maps almost perfectly to debugging a transformer port.
3) NNsight / NDIF (useful, but not the first hammer here)
NNsight/NDIF is great for activation patching and interpretability workflows. (arXiv)
For your problem (“my port is numerically wrong”), you’ll usually get faster answers from layerwise numeric equivalence tests than from interpretability tooling.
The single most likely bug class in your specific case
Given:
- weights match,
- outputs are image-dependent but wrong,
- you likely implemented KV cache + attention yourself (your repo mentions KVCache and custom Gemma bits),
…the highest-probability root cause is:
Grouped-query/multi-query attention + cache_position/mask handling (shape/broadcast/off-by-one).
The architectural clue is that Q and KV projection sizes differ (e.g., q_proj: 2048, k_proj/v_proj: 256). (Google Developers Blog)
If KV heads are mishandled, every subsequent token distribution will be “reasonable-looking” but wrong.
Second most likely: input layout/normalization mismatch (NHWC vs NCHW, or skipping SigLIP’s 1/255 rescale and 0.5 mean/std).
Third: GELU mismatch (exact GELU where gelu_pytorch_tanh is expected).
If you want a very concrete next move
- Disable the cache and do 1-step greedy decoding.
- Feed TF the exact pixel_values and input_ids saved from HF.
- Compare:
  - projector output
  - logits at step 0
That 3-step test will tell you, with high confidence, whether the bug is:
- preprocessing/layout,
- core forward pass (attention/MLP),
- or cache/mask/positions.