diff --git a/config/dICP_config.yaml b/config/dICP_config.yaml index 05928df..0afab47 100644 --- a/config/dICP_config.yaml +++ b/config/dICP_config.yaml @@ -8,7 +8,9 @@ dICP: functionality: - gumbel: False # If true, use Gumbel-Softmax trick for nearest neighour + gumbel: False # If true, use Gumbel-Softmax trick for nearest neighour + gumbel_eps: 1.0e-10 # Epsilon for Gumbel-Softmax trick + gumbel_tau: 0.1 # Temperature for Gumbel-Softmax trick # Not yet implemented # svd: False # If true, use SVD to solve pt2pt problem, no effect for pt2pl diff --git a/dICP/ICP.py b/dICP/ICP.py index dcc2e12..420ac69 100644 --- a/dICP/ICP.py +++ b/dICP/ICP.py @@ -37,8 +37,10 @@ def load_config(file_path): self.diff = differentiable use_gumbel = self.config['dICP']['functionality']['gumbel'] + eps = self.config['dICP']['functionality']['gumbel_eps'] + tau = self.config['dICP']['functionality']['gumbel_tau'] - self.nn = nn(self.diff, use_gumbel=use_gumbel) + self.nn = nn(self.diff, use_gumbel=use_gumbel, eps=eps, tau=tau) def icp(self, source, target, T_init, weight=None, trim_dist=None, loss_fn=None, dim=3): return self.dICP(source, target, T_init, weight, trim_dist, loss_fn, dim) diff --git a/dICP/nn.py b/dICP/nn.py index 155778a..5e8c3ff 100644 --- a/dICP/nn.py +++ b/dICP/nn.py @@ -2,9 +2,11 @@ import torch.nn.functional as F class nn: - def __init__(self, differentiable=True, use_gumbel=True): + def __init__(self, differentiable=True, use_gumbel=True, eps=1e-20, tau=0.1): self.differentiable = differentiable self.use_gumbel = use_gumbel + self.eps = eps + self.tau = tau def find_nn(self, x, y): x_use, y_use = self.__handle_dimensions(x, y) @@ -37,6 +39,36 @@ def __diff_nn(self, x, y): return neighbors + # Nearest neighbour using Gumbel-Softmax trick, not yet integrated + def __diff_nn_gumbel(self, x, y): + """ + Computes the differentiable nearest neighbor of all entries in source x + to the target point cloud y using softmax. + :param x: Source points (N, n, 3). + :param y: Target points (N, m, 3/6). + """ + # Expand x and y to have an additional dimension for broadcasting + x_use = x.unsqueeze(2) # shape: (N, n, 1, 3) + y_use = y.unsqueeze(1) # shape: (N, 1, m, 3/6) + + # If y has 6 elements, then normals are included, in this case extract first 3 for operations + # Compute the squared Euclidean distances between x and each point in y + distances = torch.sum((x_use - y_use[:,:,:,:3])**2, dim=3) # shape: (N, n, m) + + # Apply the Gumbel-Softmax trick to obtain a differentiable approximation of the argmax operation + logits = -distances + U = torch.rand(logits.shape, device=logits.device) # sample from uniform distribution + eps = self.eps + noise = -torch.log(-torch.log(U + eps) + eps) # sample from Gumbel distribution + tau = self.tau # temperature + noisy_logits = (logits + noise) / tau # divide by temperature, shape: (N, n, m) + probs = torch.softmax(noisy_logits, dim=2) # shape: (N, n, m) + + # Compute the weighted average of the points in y using the probabilities + neighbor = probs @ y + + return neighbor + def __non_diff_nn(self, x, y): """ Computes the differentiable nearest neighbor of all entries in source x @@ -90,34 +122,4 @@ def __handle_dimensions(self, x, y): assert y_use.shape[2] == 3 or y_use.shape[2] == 6, "y must have 3 or 6 elements in the second dimension." - return x_use, y_use - - # Nearest neighbour using Gumbel-Softmax trick, not yet integrated - def __diff_nn_gumbel(self, x, y): - """ - Computes the differentiable nearest neighbor of all entries in source x - to the target point cloud y using softmax. - :param x: Source points (N, n, 3). - :param y: Target points (N, m, 3/6). - """ - # Expand x and y to have an additional dimension for broadcasting - x_use = x.unsqueeze(2) # shape: (N, n, 1, 3) - y_use = y.unsqueeze(1) # shape: (N, 1, m, 3/6) - - # If y has 6 elements, then normals are included, in this case extract first 3 for operations - # Compute the squared Euclidean distances between x and each point in y - distances = torch.sum((x_use - y_use[:,:,:,:3])**2, dim=3) # shape: (N, n, m) - - # Apply the Gumbel-Softmax trick to obtain a differentiable approximation of the argmax operation - logits = -distances - U = torch.rand(logits.shape, device=logits.device) # sample from uniform distribution - eps = 1e-20 - noise = -torch.log(-torch.log(U + eps) + eps) # sample from Gumbel distribution - tau = 0.1 # temperature - noisy_logits = (logits + noise) / tau # divide by temperature, shape: (N, n, m) - probs = torch.softmax(noisy_logits, dim=2) # shape: (N, n, m) - - # Compute the weighted average of the points in y using the probabilities - neighbor = probs @ y - - return neighbor \ No newline at end of file + return x_use, y_use \ No newline at end of file