Skip to content

Commit

Permalink
Moved Gumbel-Softmax parameters to config
Browse files Browse the repository at this point in the history
  • Loading branch information
lisusdaniil committed Jul 9, 2024
1 parent 37f14c4 commit db95563
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 34 deletions.
4 changes: 3 additions & 1 deletion config/dICP_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ dICP:


functionality:
gumbel: False # If true, use Gumbel-Softmax trick for nearest neighour
gumbel: False # If true, use Gumbel-Softmax trick for nearest neighour
gumbel_eps: 1.0e-10 # Epsilon for Gumbel-Softmax trick
gumbel_tau: 0.1 # Temperature for Gumbel-Softmax trick
# Not yet implemented
# svd: False # If true, use SVD to solve pt2pt problem, no effect for pt2pl

Expand Down
4 changes: 3 additions & 1 deletion dICP/ICP.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ def load_config(file_path):
self.diff = differentiable

use_gumbel = self.config['dICP']['functionality']['gumbel']
eps = self.config['dICP']['functionality']['gumbel_eps']
tau = self.config['dICP']['functionality']['gumbel_tau']

self.nn = nn(self.diff, use_gumbel=use_gumbel)
self.nn = nn(self.diff, use_gumbel=use_gumbel, eps=eps, tau=tau)

def icp(self, source, target, T_init, weight=None, trim_dist=None, loss_fn=None, dim=3):
return self.dICP(source, target, T_init, weight, trim_dist, loss_fn, dim)
Expand Down
66 changes: 34 additions & 32 deletions dICP/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import torch.nn.functional as F

class nn:
def __init__(self, differentiable=True, use_gumbel=True):
def __init__(self, differentiable=True, use_gumbel=True, eps=1e-20, tau=0.1):
self.differentiable = differentiable
self.use_gumbel = use_gumbel
self.eps = eps
self.tau = tau

def find_nn(self, x, y):
x_use, y_use = self.__handle_dimensions(x, y)
Expand Down Expand Up @@ -37,6 +39,36 @@ def __diff_nn(self, x, y):

return neighbors

# Nearest neighbour using Gumbel-Softmax trick, not yet integrated
def __diff_nn_gumbel(self, x, y):
"""
Computes the differentiable nearest neighbor of all entries in source x
to the target point cloud y using softmax.
:param x: Source points (N, n, 3).
:param y: Target points (N, m, 3/6).
"""
# Expand x and y to have an additional dimension for broadcasting
x_use = x.unsqueeze(2) # shape: (N, n, 1, 3)
y_use = y.unsqueeze(1) # shape: (N, 1, m, 3/6)

# If y has 6 elements, then normals are included, in this case extract first 3 for operations
# Compute the squared Euclidean distances between x and each point in y
distances = torch.sum((x_use - y_use[:,:,:,:3])**2, dim=3) # shape: (N, n, m)

# Apply the Gumbel-Softmax trick to obtain a differentiable approximation of the argmax operation
logits = -distances
U = torch.rand(logits.shape, device=logits.device) # sample from uniform distribution
eps = self.eps
noise = -torch.log(-torch.log(U + eps) + eps) # sample from Gumbel distribution
tau = self.tau # temperature
noisy_logits = (logits + noise) / tau # divide by temperature, shape: (N, n, m)
probs = torch.softmax(noisy_logits, dim=2) # shape: (N, n, m)

# Compute the weighted average of the points in y using the probabilities
neighbor = probs @ y

return neighbor

def __non_diff_nn(self, x, y):
"""
Computes the differentiable nearest neighbor of all entries in source x
Expand Down Expand Up @@ -90,34 +122,4 @@ def __handle_dimensions(self, x, y):

assert y_use.shape[2] == 3 or y_use.shape[2] == 6, "y must have 3 or 6 elements in the second dimension."

return x_use, y_use

# Nearest neighbour using Gumbel-Softmax trick, not yet integrated
def __diff_nn_gumbel(self, x, y):
"""
Computes the differentiable nearest neighbor of all entries in source x
to the target point cloud y using softmax.
:param x: Source points (N, n, 3).
:param y: Target points (N, m, 3/6).
"""
# Expand x and y to have an additional dimension for broadcasting
x_use = x.unsqueeze(2) # shape: (N, n, 1, 3)
y_use = y.unsqueeze(1) # shape: (N, 1, m, 3/6)

# If y has 6 elements, then normals are included, in this case extract first 3 for operations
# Compute the squared Euclidean distances between x and each point in y
distances = torch.sum((x_use - y_use[:,:,:,:3])**2, dim=3) # shape: (N, n, m)

# Apply the Gumbel-Softmax trick to obtain a differentiable approximation of the argmax operation
logits = -distances
U = torch.rand(logits.shape, device=logits.device) # sample from uniform distribution
eps = 1e-20
noise = -torch.log(-torch.log(U + eps) + eps) # sample from Gumbel distribution
tau = 0.1 # temperature
noisy_logits = (logits + noise) / tau # divide by temperature, shape: (N, n, m)
probs = torch.softmax(noisy_logits, dim=2) # shape: (N, n, m)

# Compute the weighted average of the points in y using the probabilities
neighbor = probs @ y

return neighbor
return x_use, y_use

0 comments on commit db95563

Please sign in to comment.