import os
from glob import glob
from typing import Callable, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm
from turbojpeg import TurboJPEG, TJPF_RGB

from .utils import get_id, generate_density_map

# Shared TurboJPEG handle used to decode JPEG images quickly.
jpeg_decoder = TurboJPEG()

# Directory of this file; dataset roots are resolved relative to it.
curr_dir = os.path.dirname(os.path.abspath(__file__))


# Accepted dataset name aliases; standardize_dataset_name maps each to a canonical short name.
available_datasets = [
    "shanghaitech_a", "sha",
    "shanghaitech_b", "shb",
    "shanghaitech", "sh",
    "ucf_qnrf", "qnrf", "ucf-qnrf",
    "nwpu", "nwpu_crowd", "nwpu-crowd",
]

# Per-channel RGB statistics used to normalize images (and to un-normalize them in InMemoryCrowd).
mean = (0.48145466, 0.4578275, 0.40821073)
std = (0.26862954, 0.26130258, 0.27577711)


def standardize_dataset_name(dataset: str) -> str:
    """Map any accepted dataset alias to its canonical short name."""
    dataset = dataset.lower()
    assert dataset in available_datasets, f"Dataset {dataset} is not available."
    if dataset in ["shanghaitech_a", "sha"]:
        return "sha"
    elif dataset in ["shanghaitech_b", "shb"]:
        return "shb"
    elif dataset in ["shanghaitech", "sh"]:
        return "sh"
    elif dataset in ["ucf_qnrf", "qnrf", "ucf-qnrf"]:
        return "qnrf"
    else:
        assert dataset in ["nwpu", "nwpu_crowd", "nwpu-crowd"], f"Dataset {dataset} is not available."
        return "nwpu"


class Crowd(Dataset):
    def __init__(
        self,
        dataset: str,
        split: str,
        transforms: Optional[Callable] = None,
        sigma: Optional[float] = None,
        return_filename: bool = False,
        num_crops: int = 1,
    ) -> None:
        """
        Dataset for crowd counting. Each item yields `num_crops` normalized crops of one
        image together with their point annotations and generated density maps.
        """
        assert dataset.lower() in available_datasets, f"Dataset {dataset} is not available."
        assert dataset.lower() not in ["shanghaitech", "sh"], "For the combined ShanghaiTech dataset, use the ShanghaiTech class."
        assert split in ["train", "val", "test"], f"Split {split} is not available."
        assert num_crops > 0, f"num_crops should be positive, got {num_crops}."

        self.dataset = standardize_dataset_name(dataset)
        self.split = split

        self.__find_root__()
        self.__make_dataset__()
        self.__check_sanity__()

        self.to_tensor = ToTensor()
        self.normalize = Normalize(mean=mean, std=std)
        self.transforms = transforms

        self.sigma = sigma
        self.return_filename = return_filename
        self.num_crops = num_crops

    def __find_root__(self) -> None:
        # Dataset root, e.g. <package>/../data/sha.
        self.root = os.path.join(curr_dir, "..", "data", self.dataset)

    def __make_dataset__(self) -> None:
        # Collect image and label file names and pair them by their numeric id.
        image_names = glob(os.path.join(self.root, self.split, "images", "*.jpg"))
        label_names = glob(os.path.join(self.root, self.split, "labels", "*.npy"))
        image_names = [os.path.basename(image_name) for image_name in image_names]
        label_names = [os.path.basename(label_name) for label_name in label_names]
        image_names.sort(key=get_id)
        label_names.sort(key=get_id)
        image_ids = tuple(get_id(image_name) for image_name in image_names)
        label_ids = tuple(get_id(label_name) for label_name in label_names)
        assert image_ids == label_ids, "image_ids and label_ids do not match."
        self.image_names = tuple(image_names)
        self.label_names = tuple(label_names)

    def __check_sanity__(self) -> None:
        # Expected number of image/label pairs per dataset and split. Only the train and
        # val splits are labeled here, so "test" is rejected.
        expected_counts = {
            "sha": ("ShanghaiTech_A", {"train": 300, "val": 182}),
            "shb": ("ShanghaiTech_B", {"train": 399, "val": 316}),
            "nwpu": ("NWPU", {"train": 3109, "val": 500}),
            "qnrf": ("UCF_QNRF", {"train": 1201, "val": 334}),
        }
        name, counts = expected_counts[self.dataset]
        assert self.split in counts, f"Split {self.split} is not available for dataset {self.dataset}."
        expected = counts[self.split]
        assert len(self.image_names) == len(self.label_names) == expected, (
            f"{name} {self.split} split should have {expected} images, but found {len(self.image_names)}."
        )

    def __len__(self) -> int:
        return len(self.image_names)

    def __getitem__(self, idx: int) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, str]]:
        image_name = self.image_names[idx]
        label_name = self.label_names[idx]

        image_path = os.path.join(self.root, self.split, "images", image_name)
        label_path = os.path.join(self.root, self.split, "labels", label_name)

        # Decode the JPEG to an RGB array and convert it to a float tensor in [0, 1].
        with open(image_path, "rb") as f:
            image = jpeg_decoder.decode(f.read(), pixel_format=TJPF_RGB)
        image = self.to_tensor(image)

        # Point annotations (head locations) stored as a NumPy array.
        with open(label_path, "rb") as f:
            label = np.load(f)
        label = torch.from_numpy(label).float()

        # Produce num_crops copies, applying the (possibly random) transforms independently to each.
        if self.transforms is not None:
            images_labels = [self.transforms(image.clone(), label.clone()) for _ in range(self.num_crops)]
            images, labels = zip(*images_labels)
        else:
            images = [image.clone() for _ in range(self.num_crops)]
            labels = [label.clone() for _ in range(self.num_crops)]

        images = [self.normalize(img) for img in images]
        density_maps = torch.stack(
            [generate_density_map(label, image.shape[-2], image.shape[-1], sigma=self.sigma)
             for image, label in zip(images, labels)],
            0,
        )
        image_names = [image_name] * len(images)
        images = torch.stack(images, 0)

        if self.return_filename:
            return images, labels, density_maps, image_names
        else:
            return images, labels, density_maps
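

# Usage sketch (not executed on import; the parameter values below are illustrative and
# assume the data has been prepared under data/<dataset>/<split>/{images,labels}).
# Because each item returns a stack of `num_crops` crops plus per-crop point annotations
# of varying length, a DataLoader will typically need a custom collate_fn.
#
#     train_set = Crowd("sha", split="train", sigma=4.0, num_crops=2, return_filename=True)
#     images, labels, density_maps, names = train_set[0]
#     # images: (num_crops, 3, H, W); density_maps: one map per crop from generate_density_map.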


class InMemoryCrowd(Dataset):
    def __init__(
        self,
        dataset: str,
        split: str,
        transforms: Optional[Callable] = None,
        sigma: Optional[float] = None,
        return_filename: bool = False,
        num_crops: int = 1,
    ) -> None:
        """
        Dataset for crowd counting, with all images and labels loaded into memory up front.
        """
        crowd = Crowd(
            dataset=dataset,
            split=split,
            transforms=None,
            sigma=sigma,
            return_filename=True,
            num_crops=1,
        )
        print(f"Loading {len(crowd)} samples from {dataset} {split} split into memory...")
        self.images, self.labels, self.image_names = [], [], []
        # Invert the normalization applied by Crowd so that raw images are cached;
        # they are re-normalized on access in __getitem__.
        self.unnormalize = Compose([
            Normalize(mean=(0., 0., 0.), std=(1. / std[0], 1. / std[1], 1. / std[2]), inplace=True),
            Normalize(mean=(-mean[0], -mean[1], -mean[2]), std=(1., 1., 1.), inplace=True),
        ])

        for i in tqdm(range(len(crowd)), desc="Loading images and labels into memory"):
            image, label, _, image_name = crowd[i]
            self.images.append(self.unnormalize(image[0]))
            self.labels.append(label[0])
            self.image_names.append(image_name[0])

        assert len(self.images) == len(self.labels) == len(self.image_names), "Mismatch in number of images, labels, and image names."

        self.transforms = transforms
        self.sigma = sigma
        self.num_crops = num_crops
        self.return_filename = return_filename
        self.normalize = Normalize(mean=mean, std=std, inplace=False)

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, str]]:
        image, label, image_name = self.images[idx].clone(), self.labels[idx].clone(), self.image_names[idx]

        # Same cropping / normalization / density-map pipeline as Crowd.__getitem__,
        # but starting from the cached in-memory image.
        if self.transforms is not None:
            images_labels = [self.transforms(image.clone(), label.clone()) for _ in range(self.num_crops)]
            images, labels = zip(*images_labels)
        else:
            images = [image.clone() for _ in range(self.num_crops)]
            labels = [label.clone() for _ in range(self.num_crops)]

        images = [self.normalize(img) for img in images]
        density_maps = torch.stack(
            [generate_density_map(label, image.shape[-2], image.shape[-1], sigma=self.sigma)
             for image, label in zip(images, labels)],
            0,
        )
        image_names = [image_name] * len(images)
        images = torch.stack(images, 0)

        if self.return_filename:
            return images, labels, density_maps, image_names
        else:
            return images, labels, density_maps
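

# Usage sketch (illustrative values): InMemoryCrowd decodes every image once up front and
# keeps the un-normalized tensors in RAM, trading memory for faster epochs on the smaller
# datasets; transforms, normalization, and density-map generation still run per __getitem__.
#
#     cached_set = InMemoryCrowd("shb", split="train", sigma=4.0, num_crops=4)
#     images, labels, density_maps = cached_set[0]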


class NWPUTest(Dataset):
    def __init__(
        self,
        transforms: Optional[Callable] = None,
        return_filename: bool = False,
    ) -> None:
        """
        The test split of the NWPU-Crowd dataset. The test split is not labeled, so only
        images (and optionally file names) are returned.
        """
        self.root = os.path.join(curr_dir, "..", "data", "nwpu")
        image_names = glob(os.path.join(self.root, "test", "images", "*.jpg"))

        image_names = [os.path.basename(image_name) for image_name in image_names]
        assert len(image_names) == 1500, f"NWPU test split should have 1500 images, but found {len(image_names)}."
        image_names.sort(key=get_id)
        self.image_names = tuple(image_names)

        self.to_tensor = ToTensor()
        self.normalize = Normalize(mean=mean, std=std)
        self.transforms = transforms
        self.return_filename = return_filename

    def __len__(self) -> int:
        return len(self.image_names)

    def __getitem__(self, idx: int) -> Union[Tensor, Tuple[Tensor, str]]:
        image_name = self.image_names[idx]
        image_path = os.path.join(self.root, "test", "images", image_name)

        # Decode the JPEG to an RGB array and convert it to a float tensor in [0, 1].
        with open(image_path, "rb") as f:
            image = jpeg_decoder.decode(f.read(), pixel_format=TJPF_RGB)
        image = self.to_tensor(image)

        # Transforms expect an (image, label) pair, so pass an empty label tensor.
        label = torch.tensor([], dtype=torch.float)
        image, _ = self.transforms(image, label) if self.transforms is not None else (image, label)
        image = self.normalize(image)

        if self.return_filename:
            return image, image_name
        else:
            return image
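

# Usage sketch (illustrative): NWPUTest only returns images because the test split is
# unlabeled, so it is typically used to produce predictions for external evaluation
# of the NWPU-Crowd benchmark.
#
#     test_set = NWPUTest(return_filename=True)
#     image, name = test_set[0]
#     # image: (3, H, W), already normalized with the module-level mean/std.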


class ShanghaiTech(Dataset):
    def __init__(
        self,
        split: str,
        transforms: Optional[Callable] = None,
        sigma: Optional[float] = None,
        return_filename: bool = False,
        num_crops: int = 1,
    ) -> None:
        """
        The combined ShanghaiTech dataset: Part A followed by Part B, indexed as one dataset.
        """
        super().__init__()
        self.sha = Crowd(
            dataset="sha",
            split=split,
            transforms=transforms,
            sigma=sigma,
            return_filename=return_filename,
            num_crops=num_crops,
        )
        self.shb = Crowd(
            dataset="shb",
            split=split,
            transforms=transforms,
            sigma=sigma,
            return_filename=return_filename,
            num_crops=num_crops,
        )
        self.dataset = "sh"
        self.split = split
        self.transforms = transforms
        self.sigma = sigma
        self.return_filename = return_filename
        self.num_crops = num_crops

    def __len__(self) -> int:
        return len(self.sha) + len(self.shb)

    def __getitem__(self, idx: int) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, str]]:
        # Indices [0, len(sha)) map to Part A; the rest map to Part B.
        if idx < len(self.sha):
            return self.sha[idx]
        else:
            return self.shb[idx - len(self.sha)]
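

if __name__ == "__main__":
    # Minimal smoke test, assuming the ShanghaiTech Part A data has been prepared under
    # data/sha/{train,val}/{images,labels}. The sigma value below is illustrative.
    # Run as a module (python -m <package>.<this_module>) so the relative import resolves.
    dataset = Crowd("sha", split="train", sigma=4.0, num_crops=1, return_filename=True)
    images, labels, density_maps, names = dataset[0]
    print(f"{names[0]}: crops {tuple(images.shape)}, points {tuple(labels[0].shape)}, "
          f"density maps {tuple(density_maps.shape)}")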