Allow for >1 batch size in Splatfacto #3582

Open · wants to merge 4 commits into base: main
2 changes: 1 addition & 1 deletion nerfstudio/cameras/camera_optimizers.py
@@ -152,7 +152,7 @@ def apply_to_raybundle(self, raybundle: RayBundle) -> None:
raybundle.origins = raybundle.origins + correction_matrices[:, :3, 3]
raybundle.directions = torch.bmm(correction_matrices[:, :3, :3], raybundle.directions[..., None]).squeeze()

def apply_to_camera(self, camera: Cameras) -> torch.Tensor:
def apply_to_camera(self, camera: Cameras) -> Float[Tensor, "b 3 4"]:
"""Apply the pose correction to the world-to-camera matrix in a Camera object"""
if self.config.mode == "off":
return camera.camera_to_worlds
42 changes: 22 additions & 20 deletions nerfstudio/data/datamanagers/full_images_datamanager.py
@@ -45,6 +45,7 @@
from nerfstudio.data.datasets.base_dataset import InputDataset
from nerfstudio.data.utils.data_utils import identity_collate
from nerfstudio.data.utils.dataloaders import ImageBatchStream, _undistort_image
from nerfstudio.data.utils.nerfstudio_collate import nerfstudio_collate
from nerfstudio.utils.misc import get_orig_class
from nerfstudio.utils.rich_utils import CONSOLE

@@ -89,6 +90,8 @@ class FullImageDatamanagerConfig(DataManagerConfig):
More details are described here: https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader"""
cache_compressed_images: bool = False
"""If True, cache raw image files as byte strings to RAM."""
batch_size: int = 1
"""The batch size for the dataloader."""


class FullImageDatamanager(DataManager, Generic[TDataset]):
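
For orientation, a minimal sketch of how the new field might be set when building the config. Only `batch_size` is introduced by this PR; the other field names are taken from the existing config shown in this diff, and the values are illustrative:

```python
from nerfstudio.data.datamanagers.full_images_datamanager import FullImageDatamanagerConfig

# Illustrative values only: request 4 full images per training step.
datamanager_config = FullImageDatamanagerConfig(
    cache_images="disk",        # stream images from disk through the DataLoader path
    dataloader_num_workers=4,
    batch_size=4,               # the new field added in this diff
)
```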
@@ -148,7 +151,7 @@ def __init__(
assert len(self.train_unseen_cameras) > 0, "No data found in dataset"
super().__init__()

def sample_train_cameras(self):
def sample_train_cameras(self) -> List[int]:
"""Return a list of camera indices sampled using the strategy specified by
self.config.train_cameras_sampling_strategy"""
num_train_cameras = len(self.train_dataset)
@@ -322,9 +325,9 @@ def setup_train(self):
)
self.train_image_dataloader = DataLoader(
self.train_imagebatch_stream,
batch_size=1,
batch_size=self.config.batch_size,
num_workers=self.config.dataloader_num_workers,
collate_fn=identity_collate,
collate_fn=nerfstudio_collate,
)
self.iter_train_image_dataloader = iter(self.train_image_dataloader)
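
Stripped of the datamanager plumbing, the `DataLoader` change is the standard PyTorch batching pattern: the stream yields one sample at a time and the collate function stacks the samples along a new leading dimension. A toy sketch of that idea (`ToyStream` and `toy_collate` are stand-ins, not nerfstudio classes):

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class ToyStream(IterableDataset):
    """Stand-in for ImageBatchStream: yields one sample dict at a time."""

    def __iter__(self):
        for idx in range(8):
            yield {"image": torch.rand(32, 48, 3), "image_idx": idx}


def toy_collate(samples):
    """Stand-in for nerfstudio_collate: stack tensors, gather scalars into a tensor."""
    return {
        "image": torch.stack([s["image"] for s in samples], dim=0),    # [B, H, W, C]
        "image_idx": torch.tensor([s["image_idx"] for s in samples]),  # [B]
    }


loader = DataLoader(ToyStream(), batch_size=4, collate_fn=toy_collate)
batch = next(iter(loader))
print(batch["image"].shape)  # torch.Size([4, 32, 48, 3])
print(batch["image_idx"])    # tensor([0, 1, 2, 3])
```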

@@ -380,33 +383,32 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
def get_train_rays_per_batch(self) -> int:
"""Returns resolution of the image returned from datamanager."""
camera = self.train_dataset.cameras[0].reshape(())
return int(camera.width[0].item() * camera.height[0].item())
return int(camera.width[0].item() * camera.height[0].item()) * self.config.batch_size

def next_train(self, step: int) -> Tuple[Cameras, Dict]:
"""Returns the next training batch
Returns a Camera instead of raybundle"""

self.train_count += 1
if self.config.cache_images == "disk":
camera, data = next(self.iter_train_image_dataloader)[0]
return camera, data
cameras, data = next(self.iter_train_image_dataloader)
return cameras, data

image_idx = self.train_unseen_cameras.pop(0)
# Make sure to re-populate the unseen cameras list if we have exhausted it
if len(self.train_unseen_cameras) == 0:
self.train_unseen_cameras = self.sample_train_cameras()
camera_indices = []
for _ in range(self.config.batch_size):
# Make sure to re-populate the unseen cameras list if we have exhausted it
if len(self.train_unseen_cameras) == 0:
self.train_unseen_cameras = self.sample_train_cameras()
camera_indices.append(self.train_unseen_cameras.pop(0))

data = self.cached_train[image_idx]
# We're going to copy to make sure we don't mutate the cached dictionary.
# NOTE: We're going to copy the data to make sure we don't mutate the cached dictionary.
# This can cause a memory leak: https://github.com/nerfstudio-project/nerfstudio/issues/3335
data = data.copy()
data["image"] = data["image"].to(self.device)
data = nerfstudio_collate(
[self.cached_train[i].copy() for i in camera_indices]
) # Note that this must happen before indexing cameras, as it can modify the cameras in the dataset during undistortion
cameras = nerfstudio_collate([self.train_dataset.cameras[i : i + 1].to(self.device) for i in camera_indices])

assert len(self.train_cameras.shape) == 1, "Assumes single batch dimension"
camera = self.train_cameras[image_idx : image_idx + 1].to(self.device)
if camera.metadata is None:
camera.metadata = {}
camera.metadata["cam_idx"] = image_idx
return camera, data
return cameras, data

def next_eval(self, step: int) -> Tuple[Cameras, Dict]:
"""Returns the next evaluation batch
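For the cached-images branch of `next_train` above, the sampling loop now pops `batch_size` camera indices per step and re-fills the unseen list whenever it runs out. A standalone sketch of that loop (the helper names are made up for illustration; only the re-populate-then-pop logic mirrors the diff):

```python
import random
from typing import List


def sample_train_cameras(num_cameras: int) -> List[int]:
    """Stand-in for self.sample_train_cameras(): a freshly shuffled index list."""
    indices = list(range(num_cameras))
    random.shuffle(indices)
    return indices


def pop_camera_indices(unseen: List[int], batch_size: int, num_cameras: int) -> List[int]:
    """Pop batch_size indices, re-populating the unseen list whenever it is exhausted."""
    picked = []
    for _ in range(batch_size):
        if len(unseen) == 0:
            unseen.extend(sample_train_cameras(num_cameras))
        picked.append(unseen.pop(0))
    return picked


unseen = sample_train_cameras(num_cameras=10)
print(pop_camera_indices(unseen, batch_size=4, num_cameras=10))  # e.g. [7, 2, 9, 0]
```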
6 changes: 4 additions & 2 deletions nerfstudio/data/utils/nerfstudio_collate.py
@@ -98,7 +98,7 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum(x.numel() for x in batch)
storage = elem.storage()._new_shared(numel, device=elem.device)
storage = elem.untyped_storage()._new_shared(numel, device=str(elem.device))
out = elem.new(storage).resize_(len(batch), *list(elem.size()))
return torch.stack(batch, 0, out=out)
elif elem_type.__module__ == "numpy" and elem_type.__name__ not in ("str_", "string_"):
@@ -179,7 +179,9 @@ def nerfstudio_collate(batch: Any, extra_mappings: Union[Dict[type, Callable], N

# Create metadata dictionary
if batch[0].metadata is not None:
metadata = {key: op([cam.metadata[key] for cam in batch], dim=0) for key in batch[0].metadata.keys()}
metadata = {
key: op([torch.tensor([cam.metadata[key]]) for cam in batch], dim=0) for key in batch[0].metadata.keys()
}
else:
metadata = None

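The metadata change above wraps each scalar entry (such as `cam_idx`) in a one-element tensor before the collation op runs, so plain Python ints end up as a `[B]`-shaped tensor instead of breaking `torch.cat`/`torch.stack`. A standalone sketch of the effect, using `torch.cat` as the op purely for illustration:

```python
import torch

# Per-camera metadata as it arrives from the dataset: plain Python scalars.
per_camera_metadata = [{"cam_idx": 3}, {"cam_idx": 7}, {"cam_idx": 1}]

op = torch.cat  # illustrative choice; the real collate picks the op from the camera shapes

collated = {
    key: op([torch.tensor([m[key]]) for m in per_camera_metadata], dim=0)
    for key in per_camera_metadata[0].keys()
}
print(collated["cam_idx"])  # tensor([3, 7, 1])

# Without the torch.tensor([...]) wrapper, the op would receive raw ints:
# torch.cat([3, 7, 1], dim=0) -> TypeError
```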
77 changes: 46 additions & 31 deletions nerfstudio/models/splatfacto.py
@@ -46,20 +46,28 @@
from nerfstudio.utils.spherical_harmonics import RGB2SH, SH2RGB, num_sh_bases


def resize_image(image: torch.Tensor, d: int):
def resize_image(image: torch.Tensor, d: int) -> torch.Tensor:
"""
Downscale images using the same 'area' method in opencv

:param image shape [H, W, C]
:param image shape [B, H, W, C]
:param d downscale factor (must be 2, 4, 8, etc.)

return downscaled image in shape [H//d, W//d, C]
return downscaled image in shape [B, H//d, W//d, C]
"""
import torch.nn.functional as tf

image = image.to(torch.float32)
weight = (1.0 / (d * d)) * torch.ones((1, 1, d, d), dtype=torch.float32, device=image.device)
return tf.conv2d(image.permute(2, 0, 1)[:, None, ...], weight, stride=d).squeeze(1).permute(1, 2, 0)

B, H, W, C = image.shape
image = image.permute(0, 3, 1, 2) # [B, C, H, W]
image = image.reshape(B * C, 1, H, W) # Combine batch and channel dimensions for Conv2D

downscaled = tf.conv2d(image, weight, stride=d)
downscaled = downscaled.reshape(B, C, downscaled.shape[-2], downscaled.shape[-1])
downscaled = downscaled.permute(0, 2, 3, 1) # [B, H//d, W//d, C]

return downscaled
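
A quick standalone check of the reworked `resize_image` under its new `[B, H, W, C]` contract; the function body is restated from the diff so the snippet runs on its own, and the tensor sizes are arbitrary:

```python
import torch
import torch.nn.functional as tf


def resize_image(image: torch.Tensor, d: int) -> torch.Tensor:
    """Area-style downscale of a [B, H, W, C] batch, as in the diff above."""
    image = image.to(torch.float32)
    weight = (1.0 / (d * d)) * torch.ones((1, 1, d, d), dtype=torch.float32, device=image.device)
    B, H, W, C = image.shape
    image = image.permute(0, 3, 1, 2).reshape(B * C, 1, H, W)  # fold batch and channels
    down = tf.conv2d(image, weight, stride=d)
    return down.reshape(B, C, down.shape[-2], down.shape[-1]).permute(0, 2, 3, 1)


batch = torch.rand(4, 128, 96, 3)      # [B, H, W, C]
print(resize_image(batch, d=2).shape)  # torch.Size([4, 64, 48, 3])
```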


@torch_compile()
@@ -482,32 +490,31 @@ def _apply_bilateral_grid(self, rgb: torch.Tensor, cam_idx: int, H: int, W: int)
)
return out["rgb"]

def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
"""Takes in a camera and returns a dictionary of outputs.
def get_outputs(self, cameras: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
"""Takes in cameras and returns a dictionary of outputs.

Args:
camera: The camera(s) for which output images are rendered. It should have
cameras: The camera(s) for which output images are rendered. It should have
all the needed information to compute the outputs.

Returns:
Outputs of model. (ie. rendered colors)
"""
if not isinstance(camera, Cameras):
if not isinstance(cameras, Cameras):
print("Called get_outputs with not a camera")
return {}

if self.training:
assert camera.shape[0] == 1, "Only one camera at a time"
optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)
optimized_camera_to_world = self.camera_optimizer.apply_to_camera(cameras)
else:
optimized_camera_to_world = camera.camera_to_worlds
optimized_camera_to_world = cameras.camera_to_worlds

# cropping
if self.crop_box is not None and not self.training:
crop_ids = self.crop_box.within(self.means).squeeze()
if crop_ids.sum() == 0:
return self.get_empty_outputs(
int(camera.width.item()), int(camera.height.item()), self.background_color
int(cameras.width.item()), int(cameras.height.item()), self.background_color
)
else:
crop_ids = None
@@ -530,12 +537,16 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
colors_crop = torch.cat((features_dc_crop[:, None, :], features_rest_crop), dim=1)

camera_scale_fac = self._get_downscale_factor()
camera.rescale_output_resolution(1 / camera_scale_fac)
viewmat = get_viewmat(optimized_camera_to_world)
K = camera.get_intrinsics_matrices().cuda()
W, H = int(camera.width.item()), int(camera.height.item())
cameras.rescale_output_resolution(1 / camera_scale_fac)
viewmats = get_viewmat(optimized_camera_to_world)
Ks = cameras.get_intrinsics_matrices().cuda()

W, H = (
int(cameras.width[0]),
int(cameras.height[0]),
) # assume all cameras have the same resolution
self.last_size = (H, W)
camera.rescale_output_resolution(camera_scale_fac) # type: ignore
cameras.rescale_output_resolution(camera_scale_fac) # type: ignore

# apply the compensation of screen space blurring to gaussians
if self.config.rasterize_mode not in ["antialiased", "classic"]:
@@ -558,8 +569,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
scales=torch.exp(scales_crop),
opacities=torch.sigmoid(opacities_crop).squeeze(-1),
colors=colors_crop,
viewmats=viewmat, # [1, 4, 4]
Ks=K, # [1, 3, 3]
viewmats=viewmats, # [B, 4, 4]
Ks=Ks, # [B, 3, 3]
width=W,
height=H,
packed=False,
@@ -585,24 +596,28 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:

# apply bilateral grid
if self.config.use_bilateral_grid and self.training:
if camera.metadata is not None and "cam_idx" in camera.metadata:
rgb = self._apply_bilateral_grid(rgb, camera.metadata["cam_idx"], H, W)
if cameras.metadata is not None and "cam_idx" in cameras.metadata:
rgb = self._apply_bilateral_grid(rgb, cameras.metadata["cam_idx"], H, W)

if render_mode == "RGB+ED":
depth_im = render[:, ..., 3:4]
depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max()).squeeze(0)
depth_im = torch.where(alpha > 0, depth_im, depth_im.detach().max())
else:
depth_im = None

if background.shape[0] == 3 and not self.training:
background = background.expand(H, W, 3)

return {
"rgb": rgb.squeeze(0), # type: ignore
"depth": depth_im, # type: ignore
"accumulation": alpha.squeeze(0), # type: ignore
"background": background, # type: ignore
} # type: ignore
outputs = {
"rgb": rgb,
"depth": depth_im,
"accumulation": alpha,
"background": background,
}

if self.training:
return outputs
return {k: v.squeeze(0) if k != "background" else v for k, v in outputs.items()}
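
The shape convention this return implies: during training the rendered tensors keep their leading batch dimension, while at eval time (where the batch is a single camera) everything except the background color is squeezed back to the old per-image shapes. A small illustration with assumed sizes:

```python
import torch

H, W = 4, 6
# Eval-time outputs arrive with a leading batch dimension of 1.
outputs = {
    "rgb": torch.rand(1, H, W, 3),
    "depth": torch.rand(1, H, W, 1),
    "accumulation": torch.rand(1, H, W, 1),
    "background": torch.rand(3),
}

# Same squeeze as in the diff: drop the batch dim everywhere except the background color.
eval_outputs = {k: v.squeeze(0) if k != "background" else v for k, v in outputs.items()}
print(eval_outputs["rgb"].shape)         # torch.Size([4, 6, 3])
print(eval_outputs["background"].shape)  # torch.Size([3])
```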

def get_gt_img(self, image: torch.Tensor):
"""Compute groundtruth image with iteration dependent downscale factor for evaluation purpose
@@ -622,7 +637,7 @@ def composite_with_background(self, image, background) -> torch.Tensor:
image: the image to composite
background: the background color
"""
if image.shape[2] == 4:
if image.shape[-1] == 4:
alpha = image[..., -1].unsqueeze(-1).repeat((1, 1, 3))
return alpha * image[..., :3] + (1 - alpha) * background
else:
@@ -671,7 +686,7 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
pred_img = pred_img * mask

Ll1 = torch.abs(gt_img - pred_img).mean()
simloss = 1 - self.ssim(gt_img.permute(2, 0, 1)[None, ...], pred_img.permute(2, 0, 1)[None, ...])
simloss = 1 - self.ssim(gt_img.permute(0, 3, 1, 2), pred_img.permute(0, 3, 1, 2))
if self.config.use_scale_regularization and self.step % 10 == 0:
scale_exp = torch.exp(self.scales)
scale_reg = (
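With `[B, H, W, C]` batches, the SSIM term above no longer needs the `[None, ...]` unsqueeze; a channels-first permute is enough. A standalone sketch using torchmetrics (the use of torchmetrics and its constructor arguments here are assumptions, not taken from this diff):

```python
import torch
from torchmetrics.image import StructuralSimilarityIndexMeasure

ssim = StructuralSimilarityIndexMeasure(data_range=1.0)

B, H, W = 4, 64, 64
gt_img = torch.rand(B, H, W, 3)
pred_img = torch.rand(B, H, W, 3)

Ll1 = torch.abs(gt_img - pred_img).mean()
# Channels-first for the metric; the batch dimension is already present, so no [None, ...].
simloss = 1 - ssim(pred_img.permute(0, 3, 1, 2), gt_img.permute(0, 3, 1, 2))
print(Ll1.item(), simloss.item())
```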