diff --git a/captum/optim/models/__init__.py b/captum/optim/models/__init__.py
index 0f809d5ef5..a349ccfd8d 100755
--- a/captum/optim/models/__init__.py
+++ b/captum/optim/models/__init__.py
@@ -6,6 +6,10 @@
     replace_layers,
     skip_layers,
 )
+from ._image.clip_resnet50x4_image import CLIP_ResNet50x4Image  # noqa: F401
+from ._image.clip_resnet50x4_image import clip_resnet50x4_image  # noqa: F401
+from ._image.clip_resnet50x4_text import CLIP_ResNet50x4Text  # noqa: F401
+from ._image.clip_resnet50x4_text import clip_resnet50x4_text  # noqa: F401
 from ._image.inception5h_classes import INCEPTION5H_CLASSES  # noqa: F401
 from ._image.inception_v1 import InceptionV1, googlenet  # noqa: F401
 from ._image.inception_v1_places365 import (  # noqa: F401
@@ -16,6 +20,7 @@
     INCEPTIONV1_PLACES365_CLASSES,
 )
 
+
 __all__ = [
     "RedirectedReluLayer",
     "SkipLayer",
@@ -29,4 +34,8 @@
     "InceptionV1Places365",
     "googlenet_places365",
     "INCEPTIONV1_PLACES365_CLASSES",
+    "CLIP_ResNet50x4Image",
+    "clip_resnet50x4_image",
+    "CLIP_ResNet50x4Text",
+    "clip_resnet50x4_text",
 ]
diff --git a/captum/optim/models/_image/clip_resnet50x4_image.py b/captum/optim/models/_image/clip_resnet50x4_image.py
new file mode 100644
index 0000000000..14c3cc4ed0
--- /dev/null
+++ b/captum/optim/models/_image/clip_resnet50x4_image.py
@@ -0,0 +1,382 @@
+from typing import Any, Optional, Type
+from warnings import warn
+
+import torch
+import torch.nn as nn
+from captum.optim.models._common import RedirectedReluLayer, SkipLayer
+
+GS_SAVED_WEIGHTS_URL = (
+    "https://pytorch.s3.amazonaws.com/models/captum/clip_resnet50x4_image.pt"
+)
+
+
+def clip_resnet50x4_image(
+    pretrained: bool = False,
+    progress: bool = True,
+    model_path: Optional[str] = None,
+    **kwargs: Any,
+) -> "CLIP_ResNet50x4Image":
+    """
+    The visual portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
+    Visual Models From Natural Language Supervision': https://arxiv.org/abs/2103.00020
+
+    This model can be combined with the CLIP ResNet 50x4 Text model to create the full
+    CLIP ResNet 50x4 model.
+
+    Note that the model was trained on inputs with a shape of: [B, 3, 288, 288].
+
+    Example::
+
+        >>> model = opt.models.clip_resnet50x4_image(pretrained=True)
+        >>> output = model(torch.zeros(1, 3, 288, 288))
+
+    See here for more details:
+    https://github.com/openai/CLIP
+    https://github.com/mlfoundations/open_clip
+
+    Args:
+
+        pretrained (bool, optional): If ``True``, returns a pre-trained model.
+            Default: ``False``
+        progress (bool, optional): If ``True``, displays a progress bar of the download
+            to stderr.
+            Default: ``True``
+        model_path (str, optional): Optional path for the model file.
+            Default: ``None``
+        replace_relus_with_redirectedrelu (bool, optional): If ``True``, return
+            pretrained model with Redirected ReLU in place of ReLU layers.
+            Default: *``True``* when ``pretrained`` is ``True`` otherwise *``False``*
+        use_linear_modules_only (bool, optional): If ``True``, return model
+            with all nonlinear layers replaced with linear equivalents.
+            Default: ``False``
+        transform_input (bool, optional): If ``True``, preprocesses the input according
+            to the method with which it was trained.
+            Default: *``True``* when ``pretrained`` is ``True`` otherwise *``False``*
+        use_attnpool (bool, optional): Whether or not to use the final
+            ``AttentionPool2d`` layer in the forward function. If set to ``True``,
+            model inputs are required to have a shape of: [B, 3, 288, 288] or
+            [3, 288, 288].
+            Default: *``False``* when ``pretrained`` is ``True`` otherwise *``True``*
+
+    Returns:
+        model (CLIP_ResNet50x4Image): An instance of a CLIP ResNet 50x4 model's
+            image portion.
+    """
+    if pretrained:
+        if "transform_input" not in kwargs:
+            kwargs["transform_input"] = True
+        if "replace_relus_with_redirectedrelu" not in kwargs:
+            kwargs["replace_relus_with_redirectedrelu"] = True
+        if "use_linear_modules_only" not in kwargs:
+            kwargs["use_linear_modules_only"] = False
+        if "use_attnpool" not in kwargs:
+            kwargs["use_attnpool"] = False
+
+        model = CLIP_ResNet50x4Image(**kwargs)
+
+        if model_path is None:
+            state_dict = torch.hub.load_state_dict_from_url(
+                GS_SAVED_WEIGHTS_URL, progress=progress, check_hash=False
+            )
+        else:
+            state_dict = torch.load(model_path, map_location="cpu")
+        model.load_state_dict(state_dict)
+        return model
+
+    return CLIP_ResNet50x4Image(**kwargs)
+
+
+class CLIP_ResNet50x4Image(nn.Module):
+    """
+    The visual portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
+    Visual Models From Natural Language Supervision': https://arxiv.org/abs/2103.00020
+    """
+
+    __constants__ = ["transform_input", "use_attnpool"]
+
+    def __init__(
+        self,
+        transform_input: bool = False,
+        replace_relus_with_redirectedrelu: bool = False,
+        use_linear_modules_only: bool = False,
+        use_attnpool: bool = True,
+    ) -> None:
+        """
+        Args:
+
+            replace_relus_with_redirectedrelu (bool, optional): If ``True``, return
+                model with Redirected ReLU in place of ReLU layers.
+                Default: ``False``
+            use_linear_modules_only (bool, optional): If ``True``, return model with
+                all nonlinear layers replaced with linear equivalents.
+                Default: ``False``
+            transform_input (bool, optional): If ``True``, preprocesses the input
+                according to the method with which it was trained.
+                Default: ``False``
+            use_attnpool (bool, optional): Whether or not to use the final
+                ``AttentionPool2d`` layer in the forward function. If set to ``True``,
+                model inputs are required to have a shape of: [B, 3, 288, 288] or
+                [3, 288, 288].
+                Default: ``True``
+        """
+        super().__init__()
+        if use_linear_modules_only:
+            activ = SkipLayer
+        else:
+            if replace_relus_with_redirectedrelu:
+                activ = RedirectedReluLayer
+            else:
+                activ = nn.ReLU
+
+        self.transform_input = transform_input
+        self.use_attnpool = use_attnpool
+
+        # Stem layers
+        self.conv1 = nn.Conv2d(3, 40, kernel_size=3, stride=2, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(40)
+        self.relu1 = activ()
+        self.conv2 = nn.Conv2d(40, 40, kernel_size=3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(40)
+        self.relu2 = activ()
+        self.conv3 = nn.Conv2d(40, 80, kernel_size=3, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(80)
+        self.relu3 = activ()
+        self.avgpool = nn.AvgPool2d(2)
+
+        # Residual layers
+        self.layer1 = self._build_layer(80, 80, blocks=4, stride=1, activ=activ)
+        self.layer2 = self._build_layer(320, 160, blocks=6, stride=2, activ=activ)
+        self.layer3 = self._build_layer(640, 320, blocks=10, stride=2, activ=activ)
+        self.layer4 = self._build_layer(1280, 640, blocks=6, stride=2, activ=activ)
+
+        # Attention Pooling
+        self.attnpool = AttentionPool2d(9, 2560, out_features=640, num_heads=40)
+
+    def _build_layer(
+        self,
+        inplanes: int = 80,
+        planes: int = 80,
+        blocks: int = 4,
+        stride: int = 1,
+        activ: Type[nn.Module] = nn.ReLU,
+    ) -> nn.Module:
+        """
+        Residual layer creation helper function.
+
+        Args:
+
+            inplanes (int, optional): The number of input channels / features to use
+                for the first layer.
+                Default: ``80``
+            planes (int, optional): The number of output channels / features to use
+                for the first layer. This variable is then multiplied by 4 to get the
+                number of input channels / features to use for the subsequent layers.
+                Default: ``80``
+            blocks (int, optional): The number of Bottleneck layers to create.
+                Default: ``4``
+            stride (int, optional): The stride value to use for the Bottleneck layers.
+                Default: ``1``
+            activ (type of nn.Module, optional): The nn.Module class type to use for
+                activation layers.
+                Default: ``nn.ReLU``
+
+        Returns:
+            residual_layer (nn.Sequential): A full residual layer instance.
+        """
+        layers = [Bottleneck(inplanes, planes, stride, activ=activ)]
+        for _ in range(blocks - 1):
+            layers += [Bottleneck(planes * 4, planes, activ=activ)]
+        return nn.Sequential(*layers)
+
+    def _transform_input(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+
+            x (torch.Tensor): An input tensor to normalize the values of.
+
+        Returns:
+            x (torch.Tensor): A normalized tensor.
+        """
+        assert x.dim() == 3 or x.dim() == 4
+        if self.transform_input:
+            if x.min() < 0.0 or x.max() > 1.0:
+                warn("Model input has values outside of the range [0, 1].")
+            x = x.unsqueeze(0) if x.dim() == 3 else x
+            x = x - torch.tensor(
+                [0.48145466, 0.4578275, 0.40821073], device=x.device
+            ).view(3, 1, 1)
+            x = x / torch.tensor(
+                [0.26862954, 0.26130258, 0.27577711], device=x.device
+            ).view(3, 1, 1)
+        return x
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+
+            x (torch.Tensor): An input tensor to run through the model.
+
+        Returns:
+            x (torch.Tensor): The model output.
+        """
+        x = self._transform_input(x)
+
+        # Stem layers
+        x = self.relu1(self.bn1(self.conv1(x)))
+        x = self.relu2(self.bn2(self.conv2(x)))
+        x = self.relu3(self.bn3(self.conv3(x)))
+        x = self.avgpool(x)
+
+        # Residual layers
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        # Attention Pooling
+        if self.use_attnpool:
+            x = self.attnpool(x)
+        return x
+
+
+class Bottleneck(nn.Module):
+    def __init__(
+        self,
+        inplanes: int = 80,
+        planes: int = 80,
+        stride: int = 1,
+        activ: Type[nn.Module] = nn.ReLU,
+    ) -> None:
+        """
+        Args:
+
+            inplanes (int, optional): The number of input channels / features to use
+                for the first layer.
+                Default: ``80``
+            planes (int, optional): The number of output channels / features to use
+                for the subsequent layers.
+                Default: ``80``
+            stride (int, optional): The stride value to use for the Bottleneck layers.
+                Default: ``1``
+            activ (type of nn.Module, optional): The nn.Module class type to use for
+                activation layers.
+                Default: ``nn.ReLU``
+        """
+        super().__init__()
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu1 = activ()
+
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.relu2 = activ()
+
+        self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * 4)
+        self.relu3 = activ()
+
+        if stride > 1 or inplanes != planes * 4:
+            self.downsample = nn.Sequential(
+                nn.AvgPool2d(stride),
+                nn.Conv2d(inplanes, planes * 4, kernel_size=1, stride=1, bias=False),
+                nn.BatchNorm2d(planes * 4),
+            )
+        else:
+            self.downsample = None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+
+            x (torch.Tensor): An input tensor to run through the module.
+
+        Returns:
+            x (torch.Tensor): The module output.
+        """
+        assert x.dim() == 4
+        if self.downsample is not None:
+            identity = self.downsample(x)
+        else:
+            identity = x.clone()
+
+        x = self.relu1(self.bn1(self.conv1(x)))
+        x = self.relu2(self.bn2(self.conv2(x)))
+        x = self.avgpool(x)
+
+        x = self.bn3(self.conv3(x)) + identity
+        x = self.relu3(x)
+        return x
+
+
+class AttentionPool2d(nn.Module):
+    def __init__(
+        self,
+        spacial_size: int = 9,
+        in_features: int = 2560,
+        out_features: int = 640,
+        num_heads: int = 40,
+    ) -> None:
+        """
+        Args:
+
+            spacial_size (int, optional): The desired size to use for the positional
+                embedding.
+                Default: ``9``
+            in_features (int, optional): The desired input size for the nn.Linear
+                layers.
+                Default: ``2560``
+            out_features (int, optional): The desired output size for the nn.Linear
+                layers.
+                Default: ``640``
+            num_heads (int, optional): The number of heads to use.
+                Default: ``40``
+        """
+        super().__init__()
+        self.positional_embedding = nn.Parameter(
+            torch.randn(spacial_size**2 + 1, in_features) / in_features**0.5
+        )
+        self.k_proj = nn.Linear(in_features, in_features)
+        self.q_proj = nn.Linear(in_features, in_features)
+        self.v_proj = nn.Linear(in_features, in_features)
+        self.c_proj = nn.Linear(in_features, out_features)
+        self.num_heads = num_heads
+
+    @torch.jit.ignore
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+
+            x (torch.Tensor): An input tensor to run through the module.
+
+        Returns:
+            x (torch.Tensor): The module output.
+        """
+        assert x.dim() == 4
+        x = x.reshape(*x.shape[:2], -1).permute(2, 0, 1)
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)
+        x = x + self.positional_embedding[:, None, :]
+        return torch.nn.functional.multi_head_attention_forward(
+            query=x,
+            key=x,
+            value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat(
+                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]
+            ),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0.0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False,
+        )[0][0]
diff --git a/captum/optim/models/_image/clip_resnet50x4_text.py b/captum/optim/models/_image/clip_resnet50x4_text.py
new file mode 100644
index 0000000000..8fdbcc5179
--- /dev/null
+++ b/captum/optim/models/_image/clip_resnet50x4_text.py
@@ -0,0 +1,195 @@
+import math
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+
+
+GS_SAVED_WEIGHTS_URL = (
+    "https://pytorch.s3.amazonaws.com/models/captum/clip_resnet50x4_text.pt"
+)
+
+
+def clip_resnet50x4_text(
+    pretrained: bool = False,
+    progress: bool = True,
+    model_path: Optional[str] = None,
+    **kwargs: Any,
+) -> "CLIP_ResNet50x4Text":
+    """
+    The text portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
+    Visual Models From Natural Language Supervision': https://arxiv.org/abs/2103.00020
+
+    This model can be combined with the CLIP ResNet 50x4 Image model to create the full
+    CLIP ResNet 50x4 model.
+
+    Example::
+
+        >>> model = opt.models.clip_resnet50x4_text(pretrained=True)
+        >>> clip_tokenizer = opt.transforms.CLIPTokenizer(pretrained_merges=True)
+        >>> tokenized_input = clip_tokenizer("Some example text.")
+        >>> output = model(tokenized_input)
+
+    See here for more details:
+    https://github.com/openai/CLIP
+    https://github.com/mlfoundations/open_clip
+
+    Args:
+
+        pretrained (bool, optional): If ``True``, returns a pre-trained model.
+            Default: ``False``
+        progress (bool, optional): If ``True``, displays a progress bar of the download
+            to stderr.
+            Default: ``True``
+        model_path (str, optional): Optional path for the model file.
+            Default: ``None``
+        width (int, optional): The desired width size to use for the model.
+            Default: ``640``
+        num_heads (int, optional): The number of heads to use for the model.
+            Default: ``10``
+        num_residual_layers (int, optional): The number of residual layers to use for
+            each residual attention block in the model.
+            Default: ``12``
+        content_length (int, optional): The expected size of text inputs to the model.
+            Default: ``77``
+        vocab_size (int, optional): The size of the vocab used to train the model.
+            Default: ``49408``
+
+    Returns:
+        model (CLIP_ResNet50x4Text): An instance of a CLIP ResNet 50x4 model's text
+            portion.
+    """
+    if pretrained:
+        model = CLIP_ResNet50x4Text(**kwargs)
+
+        if model_path is None:
+            state_dict = torch.hub.load_state_dict_from_url(
+                GS_SAVED_WEIGHTS_URL, progress=progress, check_hash=False
+            )
+        else:
+            state_dict = torch.load(model_path, map_location="cpu")
+        model.load_state_dict(state_dict)
+        return model
+
+    return CLIP_ResNet50x4Text(**kwargs)
+
+
+class CLIP_ResNet50x4Text(nn.Module):
+    """
+    The text portion of OpenAI's ResNet 50x4 CLIP model from 'Learning Transferable
+    Visual Models From Natural Language Supervision': https://arxiv.org/abs/2103.00020
+    """
+
+    def __init__(
+        self,
+        width: int = 640,
+        num_heads: int = 10,
+        num_residual_layers: int = 12,
+        content_length: int = 77,
+        vocab_size: int = 49408,
+    ) -> None:
+        """
+        Args:
+
+            width (int, optional): The desired width size to use for the model.
+                Default: ``640``
+            num_heads (int, optional): The number of heads to use for the model.
+                Default: ``10``
+            num_residual_layers (int, optional): The number of residual layers to use
+                for each residual attention block.
+                Default: ``12``
+            content_length (int, optional): The expected size of text inputs to the
+                model.
+                Default: ``77``
+            vocab_size (int, optional): The size of the vocab used to train the model.
+                Default: ``49408``
+        """
+        super().__init__()
+        self.transformer = nn.Sequential(
+            *[
+                ResidualAttentionBlock(width, num_heads, content_length)
+                for _ in range(num_residual_layers)
+            ]
+        )
+        self.token_embedding = nn.Embedding(vocab_size, width)
+        self.positional_embedding = nn.Parameter(torch.empty(content_length, width))
+        self.ln_final = nn.LayerNorm(width)
+        self.text_projection = nn.Parameter(torch.empty(width, width))
+
+        # logit_scale is only used when combining Text & Image models
+        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1 / 0.07))
+
+    def forward(self, text: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+
+            text (torch.Tensor): An input tensor to run through the model.
+
+        Returns:
+            x (torch.Tensor): The model output.
+        """
+        x = self.token_embedding(text)
+        x = x + self.positional_embedding.to(device=x.device, dtype=x.dtype)
+        x = self.transformer(x.permute(1, 0, 2)).permute(1, 0, 2)
+        x = self.ln_final(x)
+        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)]
+        return x @ self.text_projection.to(device=x.device, dtype=x.dtype)
+
+
+class QuickGELU(nn.Module):
+    """
+    OpenAI's models use a slightly different GELU than PyTorch's default GELU.
+    """
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+
+            x (torch.Tensor): An input tensor to run through the module.
+
+        Returns:
+            x (torch.Tensor): The module output.
+        """
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(
+        self, width: int = 640, num_heads: int = 10, content_length: int = 77
+    ) -> None:
+        """
+        Args:
+
+            width (int, optional): The desired width size to use.
+                Default: ``640``
+            num_heads (int, optional): The number of heads to use.
+                Default: ``10``
+            content_length (int, optional): The desired ``content_length`` to use.
+                Default: ``77``
+        """
+        super().__init__()
+        self.attn = nn.MultiheadAttention(width, num_heads)
+        self.ln_1 = nn.LayerNorm(width)
+        self.mlp = nn.Sequential(
+            nn.Linear(width, width * 4), QuickGELU(), nn.Linear(width * 4, width)
+        )
+        self.ln_2 = nn.LayerNorm(width)
+        self.attn_mask = (
+            torch.empty(content_length, content_length).fill_(float("-inf")).triu_(1)
+        )
+
+    def attention(self, x: torch.Tensor) -> torch.Tensor:
+        attn_mask = self.attn_mask.to(device=x.device, dtype=x.dtype)
+        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+
+            x (torch.Tensor): An input tensor to run through the module.
+
+        Returns:
+            x (torch.Tensor): The module output.
+        """
+        x = x + self.attention(self.ln_1(x))
+        return x + self.mlp(self.ln_2(x))
diff --git a/tests/optim/models/test_clip_resnet50x4_image.py b/tests/optim/models/test_clip_resnet50x4_image.py
new file mode 100644
index 0000000000..ab5f22e52c
--- /dev/null
+++ b/tests/optim/models/test_clip_resnet50x4_image.py
@@ -0,0 +1,185 @@
+#!/usr/bin/env python3
+import unittest
+
+import torch
+from captum.optim.models import clip_resnet50x4_image
+from captum.optim.models._common import RedirectedReluLayer, SkipLayer
+from packaging import version
+from tests.helpers.basic import BaseTest, assertTensorAlmostEqual
+from tests.optim.helpers.models import check_layer_in_model
+
+
+class TestCLIPResNet50x4Image(BaseTest):
+    def test_load_clip_resnet50x4_image_with_redirected_relu(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping load pretrained CLIP ResNet 50x4 Image due to insufficient"
+                + " Torch version."
+            )
+        model = clip_resnet50x4_image(
+            pretrained=True, replace_relus_with_redirectedrelu=True
+        )
+        self.assertTrue(check_layer_in_model(model, RedirectedReluLayer))
+
+    def test_load_clip_resnet50x4_image_no_redirected_relu(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping load pretrained CLIP ResNet 50x4 Image RedirectedRelu test"
+                + " due to insufficient Torch version."
+            )
+        model = clip_resnet50x4_image(
+            pretrained=True, replace_relus_with_redirectedrelu=False
+        )
+        self.assertFalse(check_layer_in_model(model, RedirectedReluLayer))
+        self.assertTrue(check_layer_in_model(model, torch.nn.ReLU))
+
+    def test_load_clip_resnet50x4_image_linear(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping load pretrained CLIP ResNet 50x4 Image linear test due to"
+                + " insufficient Torch version."
+            )
+        model = clip_resnet50x4_image(pretrained=True, use_linear_modules_only=True)
+        self.assertFalse(check_layer_in_model(model, RedirectedReluLayer))
+        self.assertFalse(check_layer_in_model(model, torch.nn.ReLU))
+        self.assertTrue(check_layer_in_model(model, SkipLayer))
+
+    def test_clip_resnet50x4_image_transform(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping CLIP ResNet 50x4 Image internal transform test due to"
+                + " insufficient Torch version."
+            )
+        x = torch.randn(1, 3, 288, 288).clamp(0, 1)
+        model = clip_resnet50x4_image(pretrained=True)
+        output = model._transform_input(x)
+        expected_output = x.clone() - torch.tensor(
+            [0.48145466, 0.4578275, 0.40821073]
+        ).view(3, 1, 1)
+        expected_output = expected_output / torch.tensor(
+            [0.26862954, 0.26130258, 0.27577711]
+        ).view(3, 1, 1)
+        assertTensorAlmostEqual(self, output, expected_output, 0)
+
+    def test_clip_resnet50x4_image_transform_warning(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping CLIP ResNet 50x4 Image internal transform warning test due"
+                + " to insufficient Torch version."
+            )
+        x = torch.stack(
+            [torch.ones(3, 288, 288) * -1, torch.ones(3, 288, 288) * 2], dim=0
+        )
+        model = clip_resnet50x4_image(pretrained=True)
+        with self.assertWarns(UserWarning):
+            model._transform_input(x)
+
+    def test_clip_resnet50x4_image_load_and_forward(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping basic pretrained CLIP ResNet 50x4 Image forward test due to"
+                + " insufficient Torch version."
+            )
+        x = torch.zeros(1, 3, 288, 288)
+        model = clip_resnet50x4_image(pretrained=True, use_attnpool=True)
+        output = model(x)
+        self.assertEqual(list(output.shape), [1, 640])
+        self.assertTrue(model.use_attnpool)
+
+    def test_untrained_clip_resnet50x4_image_load_and_forward(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping basic untrained CLIP ResNet 50x4 Image forward test due to"
+                + " insufficient Torch version."
+            )
+        x = torch.zeros(1, 3, 288, 288)
+        model = clip_resnet50x4_image(pretrained=False, use_attnpool=True)
+        output = model(x)
+        self.assertEqual(list(output.shape), [1, 640])
+        self.assertTrue(model.use_attnpool)
+
+    def test_clip_resnet50x4_image_warning(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping pretrained CLIP ResNet 50x4 Image transform input"
+                + " warning test due to insufficient Torch version."
+            )
+        x = torch.stack(
+            [torch.ones(3, 288, 288) * -1, torch.ones(3, 288, 288) * 2], dim=0
+        )
+        model = clip_resnet50x4_image(pretrained=True)
+        with self.assertWarns(UserWarning):
+            _ = model._transform_input(x)
+
+    def test_clip_resnet50x4_image_use_attnpool_false(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping basic pretrained CLIP ResNet 50x4 Image use_attnpool"
+                + " forward due to insufficient Torch version."
+            )
+        x = torch.zeros(1, 3, 288, 288)
+        model = clip_resnet50x4_image(pretrained=True, use_attnpool=False)
+        output = model(x)
+        self.assertEqual(list(output.shape), [1, 2560, 9, 9])
+        self.assertFalse(model.use_attnpool)
+
+    def test_clip_resnet50x4_image_use_attnpool_false_size_128(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping basic pretrained CLIP ResNet 50x4 Image use_attnpool"
+                + " forward with 128x128 input due to insufficient Torch version."
+            )
+        x = torch.zeros(1, 3, 128, 128)
+        model = clip_resnet50x4_image(pretrained=True, use_attnpool=False)
+        output = model(x)
+        self.assertEqual(list(output.shape), [1, 2560, 4, 4])
+        self.assertFalse(model.use_attnpool)
+
+    def test_clip_resnet50x4_image_forward_cuda(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping pretrained CLIP ResNet 50x4 Image forward CUDA test due to"
+                + " insufficient Torch version."
+            )
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest(
+                "Skipping pretrained CLIP ResNet 50x4 Image forward CUDA test due to"
+                + " not supporting CUDA."
+            )
+        x = torch.zeros(1, 3, 288, 288).cuda()
+        model = clip_resnet50x4_image(pretrained=True, use_attnpool=True).cuda()
+        output = model(x)
+
+        self.assertTrue(output.is_cuda)
+        self.assertEqual(list(output.shape), [1, 640])
+        self.assertTrue(model.use_attnpool)
+
+    def test_clip_resnet50x4_image_jit_module_no_redirected_relu(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.8.0"):
+            raise unittest.SkipTest(
+                "Skipping pretrained CLIP ResNet 50x4 Image load & JIT module with"
+                + " no redirected relu test due to insufficient Torch version."
+            )
+        x = torch.zeros(1, 3, 288, 288)
+        model = clip_resnet50x4_image(
+            pretrained=True, replace_relus_with_redirectedrelu=False, use_attnpool=True
+        )
+        jit_model = torch.jit.script(model)
+        output = jit_model(x)
+        self.assertEqual(list(output.shape), [1, 640])
+        self.assertTrue(model.use_attnpool)
+
+    def test_clip_resnet50x4_image_jit_module_with_redirected_relu(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.8.0"):
+            raise unittest.SkipTest(
+                "Skipping pretrained CLIP ResNet 50x4 Image load & JIT module with"
+                + " redirected relu test due to insufficient Torch version."
+            )
+        x = torch.zeros(1, 3, 288, 288)
+        model = clip_resnet50x4_image(
+            pretrained=True, replace_relus_with_redirectedrelu=True, use_attnpool=True
+        )
+        jit_model = torch.jit.script(model)
+        output = jit_model(x)
+        self.assertEqual(list(output.shape), [1, 640])
+        self.assertTrue(model.use_attnpool)
diff --git a/tests/optim/models/test_clip_resnet50x4_text.py b/tests/optim/models/test_clip_resnet50x4_text.py
new file mode 100644
index 0000000000..3d7f9d7cd5
--- /dev/null
+++ b/tests/optim/models/test_clip_resnet50x4_text.py
@@ -0,0 +1,64 @@
+#!/usr/bin/env python3
+import unittest
+
+import torch
+from captum.optim.models import clip_resnet50x4_text
+from packaging import version
+from tests.helpers.basic import BaseTest, assertTensorAlmostEqual
+
+
+class TestCLIPResNet50x4Text(BaseTest):
+    def test_clip_resnet50x4_text_logit_scale(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping basic pretrained CLIP ResNet 50x4 Text logit scale test due"
+                + " to insufficient Torch version."
+            )
+        model = clip_resnet50x4_text(pretrained=True)
+        expected_logit_scale = torch.tensor(4.605170249938965)
+        assertTensorAlmostEqual(self, model.logit_scale, expected_logit_scale)
+
+    def test_clip_resnet50x4_text_load_and_forward(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping basic pretrained CLIP ResNet 50x4 Text forward test due to"
+                + " insufficient Torch version."
+            )
+        # Start & End tokens: 49405, 49406
+        x = torch.cat([torch.tensor([49405, 49406]), torch.zeros(77 - 2)])
+        x = x[None, :].long()
+        model = clip_resnet50x4_text(pretrained=True)
+        output = model(x)
+        self.assertEqual(list(output.shape), [1, 640])
+
+    def test_clip_resnet50x4_text_forward_cuda(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.6.0"):
+            raise unittest.SkipTest(
+                "Skipping pretrained CLIP ResNet 50x4 Text forward CUDA test due to"
+                + " insufficient Torch version."
+            )
+        if not torch.cuda.is_available():
+            raise unittest.SkipTest(
+                "Skipping pretrained CLIP ResNet 50x4 Text forward CUDA test due to"
+                + " not supporting CUDA."
+            )
+        x = torch.cat([torch.tensor([49405, 49406]), torch.zeros(77 - 2)]).cuda()
+        x = x[None, :].long()
+        model = clip_resnet50x4_text(pretrained=True).cuda()
+        output = model(x)
+
+        self.assertTrue(output.is_cuda)
+        self.assertEqual(list(output.shape), [1, 640])
+
+    def test_clip_resnet50x4_text_jit_module(self) -> None:
+        if version.parse(torch.__version__) <= version.parse("1.8.0"):
+            raise unittest.SkipTest(
+                "Skipping pretrained CLIP ResNet 50x4 Text load & JIT module"
+                + " test due to insufficient Torch version."
+            )
+        x = torch.cat([torch.tensor([49405, 49406]), torch.zeros(77 - 2)])
+        x = x[None, :].long()
+        model = clip_resnet50x4_text(pretrained=True)
+        jit_model = torch.jit.script(model)
+        output = jit_model(x)
+        self.assertEqual(list(output.shape), [1, 640])
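
Usage sketch: a minimal example of how the image and text halves added above might be combined into a full CLIP model. This is not part of the diff itself; it assumes the standard CLIP similarity formulation (L2-normalized embeddings, cosine similarities scaled by logit_scale.exp()), the CLIPTokenizer call shown in the clip_resnet50x4_text docstring, and the `import captum.optim as opt` alias used in the docstring examples.

# Sketch only: combining the CLIP ResNet 50x4 image and text halves.
# Assumptions: standard CLIP similarity (normalized embeddings scaled by
# logit_scale.exp()); CLIPTokenizer usage mirrors the text model docstring.
import torch
import captum.optim as opt

image_model = opt.models.clip_resnet50x4_image(pretrained=True, use_attnpool=True)
text_model = opt.models.clip_resnet50x4_text(pretrained=True)
tokenizer = opt.transforms.CLIPTokenizer(pretrained_merges=True)

image = torch.zeros(1, 3, 288, 288)  # placeholder image input in [0, 1]
tokens = tokenizer("Some example text.")

with torch.no_grad():
    image_features = image_model(image)  # shape: [1, 640]
    text_features = text_model(tokens)  # shape: [1, 640]

# Cosine similarities scaled by the learned temperature
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
logits_per_image = text_model.logit_scale.exp() * image_features @ text_features.T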