-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpenumbral.py
39 lines (26 loc) · 1.04 KB
/
penumbral.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn.functional as F
def map_psi(x, r):
x_x = x[..., :-1]
x_y = F.sigmoid(x[..., -1])
return x_x * x_y.unsqueeze(-1) * r, x_y * r
def penumbral(q, k, r=1, gamma=1, eps=1e-6):
q_x, q_y = map_psi(q, r)
k_x, k_y = map_psi(k, r)
q_y = q_y.unsqueeze(2)
k_y = k_y.unsqueeze(1)
x_q_y = torch.sqrt(r**2 - q_y**2 + eps)
x_k_y = torch.sqrt(r**2 - k_y**2 + eps)
pairwise_dist = torch.cdist(q_x, k_x)
lca_height = torch.maximum(torch.maximum(q_y**2, k_y**2),
r**2 - ((x_q_y + x_k_y - pairwise_dist) / 2)**2)
lca_height_outcone = ((pairwise_dist**2 + k_y**2 - q_y**2) /
(2 * pairwise_dist + eps))**2 + q_y**2
exists_cone = torch.logical_or(pairwise_dist <= x_q_y,
(pairwise_dist - x_q_y)**2 + k_y**2 <= r**2)
return -gamma * torch.where(exists_cone, lca_height, lca_height_outcone)
# [b, n, d]
q = torch.randn(100, 10, 2)
k = torch.randn(100, 10, 2)
# [100, 10, 10]
print(penumbral(q, k).shape)