Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch to match dtype and device of problematic layers to same as input #324

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dinov2/layers/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(
self.sample_drop_ratio = drop_path

def forward(self, x: Tensor) -> Tensor:
self.to(x, non_blocking=True)
def attn_residual_func(x: Tensor) -> Tensor:
return self.ls1(self.attn(self.norm1(x)))

Expand Down
1 change: 1 addition & 0 deletions dinov2/layers/dino_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def _init_weights(self, m):
nn.init.constant_(m.bias, 0)

def forward(self, x):
self.mlp.to(x, non_blocking=True)
x = self.mlp(x)
eps = 1e-6 if x.dtype == torch.float16 else 1e-12
x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
Expand Down
1 change: 1 addition & 0 deletions dinov2/layers/patch_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def forward(self, x: Tensor) -> Tensor:
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

self.proj.to(x, non_blocking=True)
x = self.proj(x) # B C H W
H, W = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2) # B HW C
Expand Down
6 changes: 4 additions & 2 deletions dinov2/models/vision_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ def prepare_tokens_with_masks(self, x, masks=None):
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h)
x = torch.cat((self.cls_token.to(x, non_blocking=True).expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h).to(x, non_blocking=True)

if self.register_tokens is not None:
x = torch.cat(
Expand Down Expand Up @@ -249,8 +249,10 @@ def forward_features_list(self, x_list, masks_list):

def forward_features(self, x, masks=None):
if isinstance(x, list):
self.norm.to(x[0], non_blocking=True)
return self.forward_features_list(x, masks)

self.norm.to(x, non_blocking=True)
x = self.prepare_tokens_with_masks(x, masks)

for blk in self.blocks:
Expand Down