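"""Evaluate a trained model on a dataset.

Supports single- and multi-GPU (DDP) evaluation and an optional sliding-window
strategy for large images.
"""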
import torch
from torch import nn
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from argparse import ArgumentParser, Namespace
import os

current_dir = os.path.abspath(os.path.dirname(__file__))

from datasets import standardize_dataset_name
from models import get_model
from utils import get_config, get_dataloader, setup, cleanup
from evaluate import evaluate


parser = ArgumentParser(description="Test a trained model on a dataset.")
# Parameters for model
parser.add_argument("--weight_path", type=str, required=True, help="The name of the weight to use.")
parser.add_argument("--output_filename", type=str, default=None, help="The name of the result file.")

# Parameters for evaluation
parser.add_argument("--dataset", type=str, required=True, help="The dataset to evaluate on.")
parser.add_argument("--split", type=str, default="val", choices=["val", "test"], help="The split to evaluate on.")
parser.add_argument("--input_size", type=int, default=224, help="The size of the input image.")
parser.add_argument("--sliding_window", action="store_true", help="Use sliding window strategy for evaluation.")
parser.add_argument("--max_input_size", type=int, default=4096, help="The maximum size of the input image in evaluation. Images larger than this will be processed using sliding window by force to avoid OOM.")
parser.add_argument("--max_num_windows", type=int, default=8, help="The maximum number of windows to be simultaneously processed.")
parser.add_argument("--resize_to_multiple", action="store_true", help="Resize the image to the nearest multiple of the input size.")
parser.add_argument("--stride", type=int, default=None, help="The stride for sliding window strategy.")
parser.add_argument("--amp", action="store_true", help="Use automatic mixed precision for evaluation.")
parser.add_argument("--device", type=str, default="cuda", help="The device to use for evaluation.")
parser.add_argument("--num_workers", type=int, default=8, help="The number of workers for the data loader.")
parser.add_argument("--local_rank", type=int, default=-1, help="The local rank for distributed training.")


def run(local_rank: int, nprocs: int, args: Namespace):
    print(f"Rank {local_rank} process among {nprocs} processes.")
    setup(local_rank, nprocs)
    print(f"Initialized successfully. Training with {nprocs} GPUs.")
    device = f"cuda:{local_rank}" if local_rank != -1 else "cuda:0"
    print(f"Using device: {device}.")

    ddp = nprocs > 1
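    # `get_config` presumably normalizes and (with mute=False) prints the run
    # configuration; its return value is not needed here.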
    _ = get_config(vars(args).copy(), mute=False)

    model = get_model(model_info_path=args.weight_path).to(device)
    # Convert BatchNorm to SyncBatchNorm before wrapping in DDP so that batch
    # statistics stay consistent across processes.
    model = DDP(nn.SyncBatchNorm.convert_sync_batchnorm(model), device_ids=[local_rank], output_device=local_rank) if ddp else model
    model.eval()

    # DDP does not proxy custom attributes, so unwrap `.module` to read `model_name`.
    model_name = (model.module if ddp else model).model_name
    if args.output_filename is None:
        args.output_filename = f"{model_name}_{os.path.splitext(os.path.basename(args.weight_path))[0]}"

    dataloader = get_dataloader(args, split=args.split)
    scores = evaluate(
        model=model,
        data_loader=dataloader,
        sliding_window=args.sliding_window,
        max_input_size=args.max_input_size,
        window_size=args.input_size,
        stride=args.stride,
        max_num_windows=args.max_num_windows,
        amp=args.amp,
        local_rank=local_rank,
        nprocs=nprocs,
    )
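    # `evaluate` is expected to return a dict mapping metric names to values
    # (e.g. counting errors); only rank 0 reports and writes them to disk.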

    if local_rank == 0:
        for k, v in scores.items():
            print(f"{k}: {v}")

        result_dir = os.path.join(current_dir, "results", args.dataset, args.split)
        os.makedirs(result_dir, exist_ok=True)
        with open(os.path.join(result_dir, f"{args.output_filename}.txt"), "w") as f:
            for k, v in scores.items():
                f.write(f"{k}: {v}\n")
    
    cleanup(ddp)


if __name__ == "__main__":
    args = parser.parse_args()
    args.dataset = standardize_dataset_name(args.dataset)

    if args.dataset in ["sha", "shb", "qnrf", "nwpu"]:
        assert args.split == "val", f"Split {args.split} is not available for dataset {args.dataset}."

    # Sliding window prediction will be used if args.sliding_window is True, or when the image size is larger than args.max_input_size
    args.stride = args.stride or args.input_size
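    # With the default stride equal to input_size, windows tile the image without
    # overlap; passing a smaller --stride produces overlapping windows.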
    assert os.path.exists(args.weight_path), f"Weight path {args.weight_path} does not exist."
    args.in_memory_dataset = False

    args.nprocs = torch.cuda.device_count()
    print(f"Using {args.nprocs} GPUs.")
    if args.nprocs > 1:
        mp.spawn(run, nprocs=args.nprocs, args=(args.nprocs, args))
    else:
        run(0, 1, args)
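
# Example usage (assuming this file is saved as test.py; paths and dataset
# names below are placeholders):
#   Single GPU: python test.py --weight_path checkpoints/model.pth --dataset sha --split val --amp
#   Multi GPU:  run the same command on a machine with several visible GPUs;
#               mp.spawn launches one evaluation process per GPU.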