import torch
import torch.nn.functional as F
from argparse import ArgumentParser, Namespace
import os
from tqdm import tqdm

current_dir = os.path.abspath(os.path.dirname(__file__))

from datasets import NWPUTest, Resize2Multiple
from models import get_model
from utils import get_config, sliding_window_predict


parser = ArgumentParser(description="Generate the test results of a trained model on the NWPU-Crowd test set.")

parser.add_argument("--weight_path", type=str, required=True, help="The directory containing the checkpoints. It should also include the model_info.pkl file.")
parser.add_argument("--output_filename", type=str, default="test_results", help="The name of the output file (without the .txt extension).")
parser.add_argument("--input_size", type=int, default=224, help="The size of the input image.")
parser.add_argument("--sliding_window", action="store_true", help="Use the sliding-window strategy for evaluation.")
parser.add_argument("--max_input_size", type=int, default=4096, help="The maximum input image size during evaluation. Images whose area exceeds this value squared are forced through the sliding-window path to avoid OOM.")
parser.add_argument("--max_num_windows", type=int, default=8, help="The maximum number of windows processed simultaneously.")
parser.add_argument("--resize_to_multiple", action="store_true", help="Resize the image to the nearest multiple of the input size.")
parser.add_argument("--stride", type=int, default=None, help="The stride for the sliding-window strategy. Defaults to the input size.")
parser.add_argument("--amp", action="store_true", help="Use automatic mixed precision for evaluation.")
parser.add_argument("--device", type=str, default="cuda", help="The device to use for evaluation.")


def main(args: Namespace):
    print("Testing a trained model on the NWPU-Crowd test set.")
    device = torch.device(args.device)
    _ = get_config(vars(args).copy(), mute=False)

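    # Load the trained model from the checkpoint directory and switch to eval mode.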
    model = get_model(model_info_path=args.weight_path).to(device)
    model.eval()

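    # Optionally resize each image so its sides are multiples of the model input size.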
    sliding_window = args.sliding_window
    if args.resize_to_multiple:
        transforms = Resize2Multiple(base=args.input_size)
    else:
        transforms = None

    dataset = NWPUTest(transforms=transforms, return_filename=True)

    image_ids = []
    preds = []
    input_size = args.input_size

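    # Run inference image by image and collect the predicted counts.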
    for idx in tqdm(range(len(dataset)), desc="Testing on NWPU"):
        image, image_path = dataset[idx]
        image = image.unsqueeze(0)
        image = image.to(device)
        image_height, image_width = image.shape[-2:]

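        # Upscale images smaller than the model input size while preserving the
        # aspect ratio, so that both sides are at least `input_size`.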
        aspect_ratio = image_width / image_height
        if image_height < input_size:
            new_height = input_size
            new_width = int(new_height * aspect_ratio)
            image = F.interpolate(image, size=(new_height, new_width), mode="bicubic", align_corners=False)
            image_height, image_width = new_height, new_width
        if image_width < input_size:
            new_width = input_size
            new_height = int(new_width / aspect_ratio)
            image = F.interpolate(image, size=(new_height, new_width), mode="bicubic", align_corners=False)
            image_height, image_width = new_height, new_width

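        # Use sliding-window prediction when requested, or force it for very large
        # images (area above max_input_size**2) to avoid running out of memory.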
        with torch.set_grad_enabled(False), torch.autocast(device_type=device.type, enabled=args.amp):
            if sliding_window or (args.max_input_size is not None and (image_height * image_width) > args.max_input_size ** 2):
                pred_den_map = sliding_window_predict(model, image, input_size, args.stride, args.max_num_windows)
            else:
                pred_den_map = model(image)

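        # The predicted count is the sum (integral) of the predicted density map.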
        pred_count = pred_den_map.sum(dim=(1, 2, 3)).item()

        image_ids.append(os.path.basename(image_path).split(".")[0])
        preds.append(pred_count)

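    # Write one "<image_id> <predicted_count>" line per image (NWPU-Crowd submission format).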
    result_dir = os.path.join(current_dir, "nwpu_test_results")
    os.makedirs(result_dir, exist_ok=True)
    with open(os.path.join(result_dir, f"{args.output_filename}.txt"), "w") as f:
        for idx, (image_id, pred) in enumerate(zip(image_ids, preds)):
            if idx != len(image_ids) - 1:
                f.write(f"{image_id} {pred}\n")
            else:
                f.write(f"{image_id} {pred}")


if __name__ == "__main__":
    args = parser.parse_args()

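    # If no stride is given for the sliding window, default to the input size.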
    args.stride = args.stride or args.input_size
    assert os.path.exists(args.weight_path), f"Checkpoint directory {args.weight_path} does not exist."
    main(args)