ZIP-B / losses /utils.py
Yiming-M's picture
2025-07-31 15:53 🐣
0ecb9aa verified
import torch
from torch import Tensor
from typing import List, Tuple
def _reshape_density(density: Tensor, block_size: int) -> Tensor:
assert len(density.shape) == 4, f"Expected 4D (B, 1, H, W) tensor, got {density.shape}"
assert density.shape[1] == 1, f"Expected 1 channel, got {density.shape[1]}"
assert density.shape[2] % block_size == 0, f"Expected height to be divisible by {block_size}, got {density.shape[2]}"
assert density.shape[3] % block_size == 0, f"Expected width to be divisible by {block_size}, got {density.shape[3]}"
return density.reshape(density.shape[0], 1, density.shape[2] // block_size, block_size, density.shape[3] // block_size, block_size).sum(dim=(-1, -3))
def _bin_count(density_map: Tensor, bins: List[Tuple[int, int]]) -> Tensor:
class_map = torch.zeros_like(density_map, dtype=torch.long)
for idx, (low, high) in enumerate(bins):
mask = (density_map >= low) & (density_map <= high)
class_map[mask] = idx
return class_map.squeeze(1) # remove channel dimension