Skip to content

Commit

Permalink
Account for nerfacto vs splatfacto spherical harmonics differences
Browse files Browse the repository at this point in the history
  • Loading branch information
brentyi committed Jan 21, 2025
1 parent 43abb0f commit a739636
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 17 deletions.
13 changes: 7 additions & 6 deletions nerfstudio/scripts/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@
import open3d as o3d
import torch
import tyro
from scipy.spatial.transform import Rotation as ScR
from typing_extensions import Annotated, Literal

from nerfstudio.cameras.rays import RayBundle
from nerfstudio.data.datamanagers.base_datamanager import VanillaDataManager
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanager
Expand All @@ -50,6 +47,8 @@
from nerfstudio.utils.eval_utils import eval_setup
from nerfstudio.utils.rich_utils import CONSOLE
from nerfstudio.utils.spherical_harmonics import rotate_spherical_harmonics
from scipy.spatial.transform import Rotation as ScR
from typing_extensions import Annotated, Literal


@dataclass

Check failure on line 54 in nerfstudio/scripts/exporter.py

View workflow job for this annotation

GitHub Actions / build

Ruff (I001)

nerfstudio/scripts/exporter.py:19:1: I001 Import block is un-sorted or un-formatted
Expand Down Expand Up @@ -636,9 +635,11 @@ def main(self) -> None:
dim_sh_all = shs_rest.shape[-1] + 1
shs_coeffs_all = torch.zeros((n, 3, dim_sh_all), device=shs_rest.device)
shs_coeffs_all[:, :, 1:] = shs_rest
# TODO: check output rotation
output_rotation = pipeline.datamanager.train_dataparser_outputs.dataparser_transform[:3, :3]
shs_rest = rotate_spherical_harmonics(output_rotation, shs_coeffs_all)[:, :, 1:]
shs_rest = rotate_spherical_harmonics(
pipeline.datamanager.train_dataparser_outputs.dataparser_transform[:3, :3].T,
shs_coeffs_all,
component_convention="-y,+z,-x",
)[:, :, 1:]

shs_rest = shs_rest.cpu().numpy().reshape((n, -1))

Expand Down
32 changes: 29 additions & 3 deletions nerfstudio/utils/spherical_harmonics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
"""Sphecal Harmonics utils."""

import math
from typing import Literal

import torch
from e3nn.o3 import Irreps
from jaxtyping import Float
from torch import Tensor
from typing_extensions import assert_never

MAX_SH_DEGREE = 4

Expand Down Expand Up @@ -117,12 +119,16 @@ def SH2RGB(sh):
def rotate_spherical_harmonics(
rotation_matrix: Float[Tensor, "3 3"],
coeffs: Float[Tensor, "*batch dim_sh"],
component_convention: Literal["-y,+z,-x", "+y,+z,+x"],
) -> Float[Tensor, "*batch dim_sh"]:
"""Rotates real spherical harmonic coefficients using a given 3x3 rotation matrix.
Args:
rotation_matrix : A 3x3 rotation matrix.
coeffs : SH coefficients
component_convention: Component convention for spherical harmonics.
Nerfstudio (nerfacto) uses +y,+z,+x, while gsplat (splatfacto) uses
-y,+z,-x.
Returns:
The rotated SH coefficients
Expand All @@ -132,11 +138,31 @@ def rotate_spherical_harmonics(
sh_degree = int(math.sqrt(dim_sh)) - 1

# e3nn uses the xyz ordering instead of the standard yzx used in ns, equivalent to a change of basis
R_yzx_to_xyz = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]], dtype=torch.float32)
R_total = (R_yzx_to_xyz.T @ rotation_matrix @ R_yzx_to_xyz).cpu()
if component_convention == "+y,+z,+x":
R_xyz_from_yzx = torch.tensor(
[
[0, 0, 1],
[1, 0, 0],
[0, 1, 0],
],
dtype=torch.float32,
)
rotation_matrix = (R_xyz_from_yzx.T @ rotation_matrix @ R_xyz_from_yzx).cpu()
elif component_convention == "-y,+z,-x":
R_xyz_from_negyznegx = torch.tensor(
[
[0, 0, -1],
[-1, 0, 0],
[0, 1, 0],
],
dtype=torch.float32,
)
rotation_matrix = (R_xyz_from_negyznegx.T @ rotation_matrix @ R_xyz_from_negyznegx).cpu()
else:
assert_never(component_convention)

irreps = Irreps(" + ".join([f"{i}e" for i in range(sh_degree + 1)])) # Even parity spherical harmonics of degree l
D_matrix = irreps.D_from_matrix(R_total).to(coeffs.device) # Construct Wigner D-matrix
D_matrix = irreps.D_from_matrix(rotation_matrix).to(coeffs.device) # Construct Wigner D-matrix

# Multiply last dimension of coeffs (..., dim_sh) with the Wigner D-matrix (dim_sh, dim_sh)
rotated_coeffs = coeffs @ D_matrix.T
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ include = ["nerfstudio*"]
"*" = ["*.cu", "*.json", "py.typed", "setup.bash", "setup.zsh"]

[tool.pytest.ini_options]
addopts = "-n=4 --typeguard-packages=nerfstudio --disable-warnings"
# addopts = "-n=4 --typeguard-packages=nerfstudio --disable-warnings"
testpaths = [
"tests",
]
Expand Down
52 changes: 45 additions & 7 deletions tests/utils/test_spherical_harmonics.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from typing import Literal

import numpy as np
import pytest
import torch
from scipy.spatial.transform import Rotation as ScR

from gsplat.cuda._torch_impl import _eval_sh_bases_fast as gsplat_eval_sh_bases
from gsplat.cuda._torch_impl import _spherical_harmonics as gsplat_spherical_harmonics
from nerfstudio.utils.spherical_harmonics import (
components_from_spherical_harmonics,
num_sh_bases,
rotate_spherical_harmonics,
)
from scipy.spatial.transform import Rotation as ScR


@pytest.mark.parametrize("degree", list(range(0, 5)))

Check failure on line 16 in tests/utils/test_spherical_harmonics.py

View workflow job for this annotation

GitHub Actions / build

Ruff (I001)

tests/utils/test_spherical_harmonics.py:1:1: I001 Import block is un-sorted or un-formatted
Expand All @@ -23,7 +26,7 @@ def test_spherical_harmonics_components(degree):


@pytest.mark.parametrize("sh_degree", list(range(0, 4)))
def test_spherical_harmonics_rotation(sh_degree):
def test_spherical_harmonics_rotation_nerfacto(sh_degree):
"""Test if rotating both the view direction and SH coefficients by the same rotation
produces the same color output as the original.
Expand All @@ -43,7 +46,7 @@ def test_spherical_harmonics_rotation(sh_degree):
color_original = (sh_coeffs * y_lm[..., None, :]).sum(dim=-1)

rot_matrix = torch.tensor(ScR.random().as_matrix(), dtype=torch.float32)
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs)
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs, component_convention="+y,+z,+x")
dirs_rotated = (rot_matrix @ dirs.T).T
y_lm_rotated = components_from_spherical_harmonics(sh_degree, dirs_rotated)
color_rotated = (sh_coeffs_rotated * y_lm_rotated[..., None, :]).sum(dim=-1)
Expand All @@ -52,7 +55,42 @@ def test_spherical_harmonics_rotation(sh_degree):


@pytest.mark.parametrize("sh_degree", list(range(0, 4)))
def test_spherical_harmonics_rotation_properties(sh_degree):
def test_spherical_harmonics_rotation_splatfacto(sh_degree):
"""Test if rotating both the view direction and SH coefficients by the same rotation
produces the same color output as the original.
In other words, for any rotation R:
color(dir, coeffs) = color(R @ dir, rotate_sh(R, coeffs))
"""
torch.manual_seed(0)
np.random.seed(0)

N = 1000
num_coeffs = (sh_degree + 1) ** 2
sh_coeffs = torch.rand(N, 3, num_coeffs)
dirs = torch.rand(N, 3)
dirs = dirs / torch.linalg.norm(dirs, dim=-1, keepdim=True)

assert dirs.shape == (N, 3)
y_lm = gsplat_eval_sh_bases(num_coeffs, dirs)
color_original = (sh_coeffs * y_lm[..., None, :]).sum(dim=-1)

Check failure on line 76 in tests/utils/test_spherical_harmonics.py

View workflow job for this annotation

GitHub Actions / build

Ruff (F841)

tests/utils/test_spherical_harmonics.py:76:5: F841 Local variable `color_original` is assigned to but never used

rot_matrix = torch.tensor(ScR.random().as_matrix(), dtype=torch.float32)
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs, component_convention="-y,+z,-x")
dirs_rotated = (rot_matrix @ dirs.T).T
assert dirs_rotated.shape == (N, 3)
y_lm_rotated = gsplat_eval_sh_bases(num_coeffs, dirs_rotated)
color_rotated = (sh_coeffs_rotated * y_lm_rotated[..., None, :]).sum(dim=-1)

Check failure on line 83 in tests/utils/test_spherical_harmonics.py

View workflow job for this annotation

GitHub Actions / build

Ruff (F841)

tests/utils/test_spherical_harmonics.py:83:5: F841 Local variable `color_rotated` is assigned to but never used

torch.testing.assert_close(
gsplat_spherical_harmonics(sh_degree, coeffs=sh_coeffs.swapaxes(-1, -2), dirs=dirs),
gsplat_spherical_harmonics(sh_degree, coeffs=sh_coeffs_rotated.swapaxes(-1, -2), dirs=dirs_rotated),
)


@pytest.mark.parametrize("sh_degree", list(range(0, 4)))
@pytest.mark.parametrize("component_convention", ["+y,+z,+x", "-y,+z,-x"])
def test_spherical_harmonics_rotation_properties(sh_degree: int, component_convention: Literal["+y,+z,+x", "-y,+z,-x"]):
"""Test properties of the SH rotation"""
torch.manual_seed(0)
np.random.seed(0)
Expand All @@ -61,7 +99,7 @@ def test_spherical_harmonics_rotation_properties(sh_degree):
num_coeffs = (sh_degree + 1) ** 2
sh_coeffs = torch.rand(N, 3, num_coeffs)
rot_matrix = torch.tensor(ScR.random().as_matrix(), dtype=torch.float32)
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs)
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs, component_convention)

# Norm preserving
norm_original = torch.norm(sh_coeffs, dim=-1)
Expand All @@ -73,5 +111,5 @@ def test_spherical_harmonics_rotation_properties(sh_degree):

# Identity rotation
rot_matrix = torch.eye(3)
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs)
sh_coeffs_rotated = rotate_spherical_harmonics(rot_matrix, sh_coeffs, ordering)
torch.testing.assert_close(sh_coeffs, sh_coeffs_rotated, rtol=0, atol=1e-6)

0 comments on commit a739636

Please sign in to comment.