from pathlib import Path
from types import SimpleNamespace
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

torch.backends.cudnn.deterministic = True


@torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
def normalize_keypoints(
    kpts: torch.Tensor, size: Optional[torch.Tensor] = None
) -> torch.Tensor:
    if size is None:
        size = 1 + kpts.max(-2).values - kpts.min(-2).values
    elif not isinstance(size, torch.Tensor):
        size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
    size = size.to(kpts)
    shift = size / 2
    scale = size.max(-1).values / 2
    kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
    return kpts

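# Worked example (illustration only, hypothetical values): with an assumed
# image size of (640, 480), shift = (320, 240) and scale = 320, so keypoints
# land roughly in [-1, 1] along the longer image side:
#
#   pts = torch.tensor([[[0.0, 0.0], [639.0, 479.0]]])
#   normalize_keypoints(pts, size=(640, 480))
#   # -> tensor([[[-1.0000, -0.7500], [ 0.9969,  0.7469]]])

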
class LearnableFourierPositionalEncoding(nn.Module):
    def __init__(self, M: int, head_dim: int, gamma: float = 1.0) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.gamma = gamma
        self.Wr = nn.Linear(M, head_dim // 2, bias=False)
        nn.init.normal_(self.Wr.weight.data, mean=0, std=self.gamma**-2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """encode position vector"""
        projected = self.Wr(x)
        cosines, sines = torch.cos(projected), torch.sin(projected)
        emb = torch.stack([cosines, sines], 0).unsqueeze(-1)
        # emb.shape == (2, 1, N, 32, 1)
        emb = torch.cat((emb, emb), dim=-1)
        # emb.shape == (2, 1, N, 32, 2)
        emb = emb.reshape(2, 1, 1, -1, self.head_dim)
        return emb

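# Shape sketch (illustrative, assuming head_dim=64 as in the comments above):
# keypoints of shape (1, N, 2) are projected by Wr to (1, N, 32); stacking
# cosines/sines and duplicating each value along the last dim gives the cached
# rotary encoding of shape (2, 1, 1, N, 64), i.e. [cos, sin] in pairwise order.
#
#   posenc = LearnableFourierPositionalEncoding(2, 64)
#   enc = posenc(torch.rand(1, 100, 2))  # enc.shape == (2, 1, 1, 100, 64)

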
class TokenConfidence(nn.Module):
    def __init__(self, dim: int) -> None:
        super(TokenConfidence, self).__init__()
        self.token = nn.Sequential(nn.Linear(dim, 1), nn.Sigmoid())

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
        """get confidence tokens"""
        return (
            self.token(desc0.detach()).squeeze(-1),
            self.token(desc1.detach()).squeeze(-1),
        )


class Attention(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, q, k, v) -> torch.Tensor:
        return F.scaled_dot_product_attention(q, k, v)


class SelfBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, bias: bool = True) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.batch = 1
        self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.inner_attn = Attention()
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )

    def forward(self, x: torch.Tensor, encoding: torch.Tensor) -> torch.Tensor:
        qkv: torch.Tensor = self.Wqkv(x)
        qkv = qkv.reshape(self.batch, -1, self.num_heads, self.head_dim, 3)
        qkv = qkv.transpose(1, 2)
        q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
        q = self.apply_cached_rotary_emb(encoding, q)
        k = self.apply_cached_rotary_emb(encoding, k)
        context = self.inner_attn(q, k, v)
        # context.shape == (1, 4, N, 64)
        context = context.transpose(1, 2)
        context = context.reshape(self.batch, -1, self.embed_dim)
        message = self.out_proj(context)
        return x + self.ffn(torch.cat((x, message), -1))

    def rotate_half(self, t: torch.Tensor) -> torch.Tensor:
        t = t.reshape(self.batch, self.num_heads, -1, self.head_dim // 2, 2)
        t = torch.stack((-t[..., 1], t[..., 0]), dim=-1)
        t = t.reshape(self.batch, self.num_heads, -1, self.head_dim)
        return t

    def apply_cached_rotary_emb(
        self, freqs: torch.Tensor, t: torch.Tensor
    ) -> torch.Tensor:
        return (t * freqs[0]) + (self.rotate_half(t) * freqs[1])

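# Note on SelfBlock's rotary step above (descriptive comment, no new behavior):
# with the cached encoding freqs = [cos, sin] produced by
# LearnableFourierPositionalEncoding, apply_cached_rotary_emb computes the
# usual RoPE update q' = q * cos + rotate_half(q) * sin, where rotate_half
# maps every feature pair (a, b) to (-b, a).

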
class CrossBlock(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int, bias: bool = True) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.batch = 1
        self.to_qk = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.to_v = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.inner_attn = Attention()  # Q, K, V dot product
        self.to_out = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.ffn = nn.Sequential(
            nn.Linear(2 * embed_dim, 2 * embed_dim),
            nn.LayerNorm(2 * embed_dim, elementwise_affine=True),
            nn.GELU(),
            nn.Linear(2 * embed_dim, embed_dim),
        )

    def forward(
        self, x0: torch.Tensor, x1: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        qk0, qk1 = map(self.to_qk, (x0, x1))
        v0, v1 = map(self.to_v, (x0, x1))
        qk0, qk1, v0, v1 = map(
            lambda t: t.reshape(
                self.batch, -1, self.num_heads, self.head_dim
            ).transpose(1, 2),
            (qk0, qk1, v0, v1),
        )

        m0 = self.inner_attn(qk0, qk1, v1)
        m1 = self.inner_attn(qk1, qk0, v0)

        m0, m1 = map(
            lambda t: t.transpose(1, 2).reshape(self.batch, -1, self.embed_dim),
            (m0, m1),
        )
        m0, m1 = map(self.to_out, (m0, m1))
        x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
        x1 = x1 + self.ffn(torch.cat([x1, m1], -1))
        return x0, x1


class TransformerLayer(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super().__init__()
        self.self_attn = SelfBlock(embed_dim, num_heads)
        self.cross_attn = CrossBlock(embed_dim, num_heads)

    def forward(
        self,
        desc0: torch.Tensor,
        desc1: torch.Tensor,
        encoding0: torch.Tensor,
        encoding1: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        desc0 = self.self_attn(desc0, encoding0)
        desc1 = self.self_attn(desc1, encoding1)
        return self.cross_attn(desc0, desc1)


def sigmoid_log_double_softmax(
    sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
) -> torch.Tensor:
    """create the log assignment matrix from logits and similarity"""
    certainties = F.logsigmoid(z0) + F.logsigmoid(z1).transpose(1, 2)
    scores0 = F.log_softmax(sim, 2)
    scores1 = F.log_softmax(sim, 1)
    scores = scores0 + scores1 + certainties
    return scores

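# Restating the score above (descriptive comment): entry (b, i, j) of the
# returned matrix is
#     log_softmax(sim[b, i, :])[j] + log_softmax(sim[b, :, j])[i]
#     + log(sigmoid(z0[b, i])) + log(sigmoid(z1[b, j]))
# so a pair scores highly only if the two descriptors select each other and
# both are predicted as matchable.

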
class MatchAssignment(nn.Module):
    def __init__(self, dim: int) -> None:
        super(MatchAssignment, self).__init__()
        self.dim = dim
        self.scale = dim**0.25
        self.final_proj = nn.Linear(dim, dim, bias=True)
        self.matchability = nn.Linear(dim, 1, bias=True)

    def forward(self, desc0: torch.Tensor, desc1: torch.Tensor) -> torch.Tensor:
        """build assignment matrix from descriptors"""
        mdesc0, mdesc1 = map(self.final_proj, (desc0, desc1))
        mdesc0, mdesc1 = map(lambda t: t / self.scale, (mdesc0, mdesc1))
        sim = mdesc0 @ mdesc1.transpose(1, 2)
        z0 = self.matchability(desc0)
        z1 = self.matchability(desc1)
        scores = sigmoid_log_double_softmax(sim, z0, z1)
        return scores

    def get_matchability(self, desc: torch.Tensor):
        return torch.sigmoid(self.matchability(desc)).squeeze(-1)


# def filter_matches(scores: torch.Tensor, th: float):
#     """obtain matches from a log assignment matrix [BxMxN]"""
#     max0 = torch.topk(scores, k=1, dim=2, sorted=False)  # scores.max(2)
#     max1 = torch.topk(scores, k=1, dim=1, sorted=False)  # scores.max(1)
#     m0, m1 = max0.indices[:, :, 0], max1.indices[:, 0, :]
#     indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
#     # indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
#     mutual0 = indices0 == m1.gather(1, m0)
#     # mutual1 = indices1 == m0.gather(1, m1)
#     max0_exp = max0.values[:, :, 0].exp()
#     zero = max0_exp.new_tensor(0)
#     mscores0 = torch.where(mutual0, max0_exp, zero)
#     # mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
#     valid0 = mscores0 > th
#     # valid1 = mutual1 & valid0.gather(1, m1)
#     # m0 = torch.where(valid0, m0, -1)
#     # m1 = torch.where(valid1, m1, -1)
#     # return m0, m1, mscores0, mscores1
#
#     m_indices_0 = indices0[valid0]
#     m_indices_1 = m0[0][m_indices_0]
#
#     matches = torch.stack([m_indices_0, m_indices_1], -1)
#     mscores = mscores0[0][m_indices_0]
#     return matches, mscores


def filter_matches(scores: torch.Tensor):
    """obtain matches from a log assignment matrix [BxMxN]"""
    max0 = torch.topk(scores, k=1, dim=2, sorted=False)  # scores.max(2)
    max1 = torch.topk(scores, k=1, dim=1, sorted=False)  # scores.max(1)
    m0, m1 = max0.indices[:, :, 0], max1.indices[:, 0, :]
    indices0 = torch.arange(m0.shape[1], device=m0.device)[None]

    mutual0 = indices0 == m1.gather(1, m0)
    max0_exp = max0.values[:, :, 0].exp()

    zero = max0_exp.new_tensor(0)
    mscores0 = torch.where(mutual0, max0_exp, zero)

    m_indices_0 = indices0[mutual0]
    m_indices_1 = m0[0][m_indices_0]

    matches = torch.stack([m_indices_0, m_indices_1], -1)
    mscores = mscores0[0][m_indices_0]
    return matches, mscores

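# Toy example (illustration only, hypothetical values): a 1x2x2 log-assignment
# where the two points mutually prefer each other cross-wise.
#
#   scores = torch.log(torch.tensor([[[0.1, 0.8], [0.7, 0.2]]]))
#   matches, mscores = filter_matches(scores)
#   # matches -> tensor([[0, 1], [1, 0]]), mscores -> tensor([0.8000, 0.7000])

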
class LightGlue(nn.Module):
    default_conf = {
        "name": "lightglue",  # just for interfacing
        "input_dim": 256,  # input descriptor dimension (autoselected from weights)
        "descriptor_dim": 256,
        "n_layers": 9,
        "num_heads": 4,
        "filter_threshold": 0.1,  # match threshold
        "depth_confidence": -1,  # -1 is no early stopping, recommend: 0.95
        "width_confidence": -1,  # -1 is no point pruning, recommend: 0.99
        "weights": None,
    }

    # lighterglue L3
    _l3_conf_xfeat = {
        "name": "xfeat",
        "n_layers": 3,
        "num_heads": 1,
        "input_dim": 64,
        "descriptor_dim": 96,
        "add_scale_ori": False,
        "add_laf": False,  # for KeyNetAffNetHardNet
        "scale_coef": 1.0,  # to compensate for SIFT scales being larger than KeyNet's
        "flash": True,  # enable FlashAttention if available.
        "mp": False,  # enable mixed precision
        "depth_confidence": -1,  # early stopping, disable with -1
        # "width_confidence": 0.95,  # point pruning, disable with -1
        "width_confidence": -1,  # disabled because ONNX does not support dynamic control flow
        "filter_threshold": 0.1,  # match threshold
        "weights": None,
    }

    version = "v0.1_arxiv"
    url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"

    features = {
        "superpoint": ("superpoint_lightglue", 256),
        "disk": ("disk_lightglue", 128),
        "xfeat": ("xfeat_lightglue", 64),
    }

    def __init__(self, features="superpoint", **conf) -> None:
        super().__init__()
        self.conf = {**self.default_conf, **conf}
        if features is not None:
            assert features in self.features
            self.conf["weights"], self.conf["input_dim"] = self.features[features]
        self.conf = conf = SimpleNamespace(**self.conf)

        if conf.input_dim != conf.descriptor_dim:
            self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
        else:
            self.input_proj = nn.Identity()

        head_dim = conf.descriptor_dim // conf.num_heads
        self.posenc = LearnableFourierPositionalEncoding(2, head_dim)

        # 1, 3, 96 for l3 lighterglue with xfeat
        h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim

        self.transformers = nn.ModuleList([TransformerLayer(d, h) for _ in range(n)])

        self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])

        self.token_confidence = nn.ModuleList(
            [TokenConfidence(d) for _ in range(n - 1)]
        )
        self.register_buffer(
            "confidence_thresholds",
            torch.Tensor([self.confidence_threshold(i) for i in range(n)]),
        )

        state_dict = None
        if features is not None:
            fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
            state_dict = torch.hub.load_state_dict_from_url(
                self.url.format(self.version, features), file_name=fname
            )
        elif conf.weights is not None:
            path = Path(__file__).parent
            path = path / "weights/{}.pth".format(self.conf.weights)
            state_dict = torch.load(str(path), map_location="cpu")

        if state_dict is not None:
            # rename old state dict entries
            for i in range(n):
                pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
                state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
                pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
                state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
            self.load_state_dict(state_dict, strict=False)

        print("Loaded LightGlue model")

    def forward(
        self,
        kpts0: torch.Tensor,
        kpts1: torch.Tensor,
        desc0: torch.Tensor,
        desc1: torch.Tensor,
    ):
        b, m, _ = kpts0.shape
        b, n, _ = kpts1.shape

        desc0 = self.input_proj(desc0)
        desc1 = self.input_proj(desc1)

        # cache positional embeddings
        encoding0 = self.posenc(kpts0)
        encoding1 = self.posenc(kpts1)

        for i in range(self.conf.n_layers):
            # self+cross attention
            desc0, desc1 = self.transformers[i](desc0, desc1, encoding0, encoding1)
            if i == self.conf.n_layers - 1:
                continue  # no early stopping or adaptive width at last layer

        scores = self.log_assignment[i](desc0, desc1)
        matches, mscores = filter_matches(scores)
        return matches, mscores

    def confidence_threshold(self, layer_index: int) -> float:
        """scaled confidence threshold"""
        threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
        return np.clip(threshold, 0, 1)

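    # Worked values (illustrative): with the default n_layers=9 the threshold
    # decays from 0.8 + 0.1 * exp(0) = 0.90 at layer 0 down to roughly
    # 0.8 + 0.1 * exp(-32/9) ≈ 0.803 at layer 8, so later layers need less
    # confidence before a point counts as certain.
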
    def get_pruning_mask(
        self,
        confidences: Optional[torch.Tensor],
        scores: torch.Tensor,
        layer_index: int,
    ) -> torch.Tensor:
        """mask points which should be removed"""
        keep = scores > (1 - self.conf.width_confidence)
        if confidences is not None:  # Low-confidence points are never pruned.
            keep |= confidences <= self.confidence_thresholds[layer_index]
        return keep

    def check_if_stop(
        self,
        confidences0: torch.Tensor,
        confidences1: torch.Tensor,
        layer_index: int,
        num_points: int,
    ) -> torch.Tensor:
        """evaluate stopping condition"""
        confidences = torch.cat([confidences0, confidences1], -1)
        threshold = self.confidence_thresholds[layer_index]
        ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
        return ratio_confident > self.conf.depth_confidence
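

if __name__ == "__main__":
    # Minimal usage sketch (an assumption about intended use, not part of the
    # upstream API): match two sets of random keypoints/descriptors with the
    # default SuperPoint configuration. Instantiating with features="superpoint"
    # downloads the pretrained weights from the release URL above.
    matcher = LightGlue(features="superpoint").eval()

    image_size = (640, 480)
    kpts0 = torch.rand(1, 512, 2) * torch.tensor(image_size, dtype=torch.float32)
    kpts1 = torch.rand(1, 512, 2) * torch.tensor(image_size, dtype=torch.float32)
    desc0 = torch.rand(1, 512, 256)
    desc1 = torch.rand(1, 512, 256)

    with torch.no_grad():
        matches, mscores = matcher(
            normalize_keypoints(kpts0, size=image_size),
            normalize_keypoints(kpts1, size=image_size),
            desc0,
            desc1,
        )
    print(matches.shape, mscores.shape)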