File size: 1,121 Bytes
7771996
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
from typing import Dict
from utils.distributed import DistributedMetric
from tqdm import tqdm
from torchpack import distributed as dist
from utils import accuracy
import copy
import torch.nn.functional as F
import numpy as np
def eval_local_lip(model, x, xp, top_norm=1, btm_norm=float('inf'), reduction='mean'):
    model.eval()
    down = torch.flatten(x - xp, start_dim=1)
    with torch.no_grad():
        if top_norm == "kl":
            criterion_kl = nn.KLDivLoss(reduction='none')
            top = criterion_kl(F.log_softmax(model(xp), dim=1),
                               F.softmax(model(x), dim=1))
            ret = torch.sum(top, dim=1) / torch.norm(down + 1e-6, dim=1, p=btm_norm)
        else:
            top = torch.flatten(model(x), start_dim=1) - torch.flatten(model(xp), start_dim=1)
            ret = torch.norm(top, dim=1, p=top_norm) / torch.norm(down + 1e-6, dim=1, p=btm_norm)

    if reduction == 'mean':
        return torch.mean(ret)
    elif reduction == 'sum':
        return torch.sum(ret)
    else:
        raise ValueError("Not supported reduction")