|
|
|
|
|
import torch |
|
|
from torch.amp import autocast |
|
|
from torch import Tensor |
|
|
from typing import Union, Tuple, Dict |
|
|
|
|
|
# Small additive constant that keeps the divisions in the scaling updates and
# the logarithms of the scaling vectors away from zero.
M_EPS = 1e-16


# NOTE(review): float32 is not a reduced-precision autocast dtype; presumably
# the decorator is meant to keep this routine in full precision even when the
# caller runs under mixed precision -- confirm against the PyTorch version used.
@torch.no_grad()
@autocast(device_type="cuda", enabled=True, dtype=torch.float32)
def sinkhorn(
    a: Tensor,
    b: Tensor,
    C: Tensor,
    reg: float = 1e-1,
    maxIter: int = 1000,
    stopThr: float = 1e-9,
    verbose: bool = False,
    log: bool = True,
    eval_freq: int = 10,
    print_freq: int = 200,
) -> Union[Tensor, Tuple[Tensor, Dict[str, Tensor]]]:
    """Compute a transport plan between ``a`` and ``b`` by Sinkhorn scaling.

    The kernel ``K = exp(-C / reg)`` is rescaled by two vectors ``u`` and ``v``
    that are updated alternately (``v = b / (K^T u)``, ``u = a / (K v)``) until
    the column marginal of ``diag(u) K diag(v)`` is close enough to ``b`` or
    the iteration budget is exhausted.

    Args:
        a: 1-D tensor of nonnegative source weights, length ``C.shape[0]``.
        b: 1-D tensor of nonnegative target weights, length ``C.shape[1]``.
        C: 2-D cost matrix of shape ``(len(a), len(b))``.
        reg: Regularisation strength; must be strictly positive.
        maxIter: Maximum number of scaling iterations.
        stopThr: Iteration stops once the squared column-marginal error is
            no larger than this threshold.
        verbose: If true, print the current error every ``print_freq``
            iterations.
        log: If true, also return a dictionary with diagnostics.
        eval_freq: The marginal error is evaluated every ``eval_freq``
            iterations. The stopping test runs whether or not ``log`` is set.
            ``0`` disables the evaluation, so the loop runs ``maxIter``
            iterations.
        print_freq: Printing period used when ``verbose`` is true.

    Returns:
        The transport plan ``P = diag(u) K diag(v)``. When ``log`` is true, a
        tuple ``(P, info)`` where ``info`` holds ``"err"`` (list of evaluated
        errors as Python floats), the scaling vectors ``"u"`` and ``"v"``, and
        ``"alpha" = reg * log(u + M_EPS)``, ``"beta" = reg * log(v + M_EPS)``.

    Raises:
        AssertionError: If the shapes of ``a``/``b`` do not match ``C``, if
            ``reg`` is not positive, or if ``a``/``b`` contain negative values.
    """
    device = a.device
    na, nb = C.shape
    assert na == a.shape[0] and nb == b.shape[0], f"Shapes of a ({a.shape}) or b ({b.shape}) do not match that of C ({C.shape})"
    assert reg > 0, f"reg should be greater than 0. Found reg = {reg}"
    assert a.min() >= 0. and b.min() >= 0., f"Elements in a and b should be nonnegative. Found a.min() = {a.min()}, b.min() = {b.min()}"

    # Keep the diagnostics in their own variable instead of rebinding the
    # boolean ``log`` parameter to a dictionary.
    history = {"err": []} if log else None

    u = torch.ones(na, dtype=a.dtype, device=device) / na
    v = torch.ones(nb, dtype=b.dtype, device=device) / nb
    K = torch.exp(-C / reg)

    it, err = 1, 1
    while err > stopThr and it <= maxIter:
        # ``u`` and ``v`` are rebound below, never modified in place, so plain
        # references are enough to roll back after a numerical failure.
        u_prev, v_prev = u, v

        v = b / (torch.matmul(K.T, u) + M_EPS)
        u = a / (torch.matmul(K, v) + M_EPS)

        if not (torch.isfinite(u).all() and torch.isfinite(v).all()):
            print("Warning: numerical errors at iteration", it)
            u, v = u_prev, v_prev
            break

        # The stopping criterion must not depend on ``log``: previously the
        # error was refreshed only when logging, so ``log=False`` ignored
        # ``stopThr`` and always ran ``maxIter`` iterations.
        if eval_freq and it % eval_freq == 0:
            b_hat = torch.matmul(u, K) * v
            err = (b - b_hat).pow(2).sum().item()
            if history is not None:
                history["err"].append(err)

        if verbose and it % print_freq == 0:
            print(f"Iteration {it}, constraint error {err}")

        it += 1

    P = u.view(-1, 1) * K * v.view(1, -1)
    if history is None:
        return P

    history["u"] = u
    history["v"] = v
    history["alpha"] = reg * torch.log(u + M_EPS)
    history["beta"] = reg * torch.log(v + M_EPS)
    return P, history
|
|
|