ValueError occurs when parameter "layer_scale" is used in torch #2103

Open
nalnez13 opened this issue Dec 30, 2024 · 1 comment

If I define a parameter named "layer_scale" in a PyTorch nn.Module, as in the following code, a ValueError occurs during quantization.

import torch
import torch.nn as nn
from timm.layers import DropPath  # DropPath comes from timm in this model


class ConvEncoder(nn.Module):
    """
    Implementation of ConvEncoder with 3*3 and 1*1 convolutions.
    Input: tensor with shape [B, C, H, W]
    Output: tensor with shape [B, C, H, W]
    """

    def __init__(
        self, dim, hidden_dim=64, kernel_size=3, drop_path=0.0, use_layer_scale=True
    ):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=dim
        )
        self.norm = nn.BatchNorm2d(dim)
        self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.use_layer_scale = use_layer_scale
        if use_layer_scale:
            self.layer_scale = nn.Parameter(
                torch.ones(dim).unsqueeze(-1).unsqueeze(-1), requires_grad=True
            )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        pass  # stub; the original weight init was omitted from the report

    def forward(self, x):  # not in the original report; standard residual pattern
        shortcut = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.use_layer_scale:
            x = shortcut + self.drop_path(self.layer_scale * x)
        else:
            x = shortcut + self.drop_path(x)
        return x
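
For context, quantization was invoked roughly like this (a minimal sketch using neural_compressor's 2.x post-training static quantization API; the dataloader, shapes, and dim=48 are placeholders matching the 48-element layer_scale shown below, not code from the original report):

import torch
from torch.utils.data import DataLoader, TensorDataset
from neural_compressor import PostTrainingQuantConfig, quantization

model = ConvEncoder(dim=48).eval()
# Dummy calibration data; any dataloader yielding [B, C, H, W] tensors works.
calib_loader = DataLoader(TensorDataset(torch.randn(8, 48, 32, 32)), batch_size=4)

q_model = quantization.fit(
    model=model,
    conf=PostTrainingQuantConfig(approach="static"),
    calib_dataloader=calib_loader,
)  # crashes in _get_module_scale_zeropoint with the ValueError shown below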

In "adaptor/pytorch" Line 4174

        for node in model.graph.nodes:
            if node.op == "get_attr":
                if prefix:
                    sub_name = prefix + "--" + node.target
                else:
                    sub_name = node.target
                if not hasattr(model, node.target):
                    continue
                if "scale" in node.target: #### This condition is not suitable
                    tune_cfg["get_attr"][sub_name] = float(getattr(model, node.target))
                elif "zero_point" in node.target:
                    tune_cfg["get_attr"][sub_name] = int(getattr(model, node.target))
                else:
                    pass

Running quantization then fails with the following traceback:
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/utils/utility.py", line 347, in fi
    res = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 3658, in quantize
    self._get_scale_zeropoint(q_model._model, q_model.q_config)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4217, in _get_scale_zeropoint
    self._get_sub_module_scale_zeropoint(model, tune_cfg)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4199, in _get_sub_module_scale_zeropoint
    self._get_sub_module_scale_zeropoint(module, tune_cfg, op_name)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4199, in _get_sub_module_scale_zeropoint
    self._get_sub_module_scale_zeropoint(module, tune_cfg, op_name)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4199, in _get_sub_module_scale_zeropoint
    self._get_sub_module_scale_zeropoint(module, tune_cfg, op_name)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4197, in _get_sub_module_scale_zeropoint
    self._get_module_scale_zeropoint(module, tune_cfg, op_name)
  File "/root/.pyenv/versions/3.11.11/lib/python3.11/site-packages/neural_compressor/adaptor/pytorch.py", line 4175, in _get_module_scale_zeropoint
    tune_cfg["get_attr"][sub_name] = float(getattr(model, node.target))
                                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: only one element tensors can be converted to Python scalars
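
The failure mechanism is easy to reproduce in isolation (illustrative snippet, not from the codebase): the substring test matches the model parameter as well, and float() then fails because layer_scale has more than one element.

import torch

targets = ["dwconv_input_scale_0", "pwconv2_input_scale_0", "layer_scale"]
print([t for t in targets if "scale" in t])  # all three match

float(torch.tensor(0.0355))  # OK: a quantization scale is a one-element tensor
float(torch.ones(48, 1, 1))  # ValueError: only one element tensors can be
                             # converted to Python scalars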

When I inspect node.target and the corresponding tensor values, the layer_scale parameter is being treated as a quantization scale, as follows.

feature_extractor.patch_embed--0_input_scale_0 tensor(0.0355)
feature_extractor.network.0.0--dwconv_input_scale_0 tensor(0.0338)
feature_extractor.network.0.0--pwconv2_input_scale_0 tensor(0.2263)
feature_extractor.network.0.0--layer_scale Parameter containing:
tensor([[[ 0.0336]],

        [[ 0.0153]],

        [[ 0.0214]],

        [[ 0.0068]],

        [[ 0.0229]],

        [[ 0.0136]],

        [[ 0.0491]],

        [[ 0.0202]],

        [[ 0.0420]],

        [[ 0.0495]],

        [[ 0.0060]],

        [[ 0.0225]],

        [[ 0.0311]],

        [[ 0.0303]],

        [[ 0.0556]],

        [[ 0.0290]],

        [[ 0.0222]],

        [[ 0.0153]],

        [[ 0.0332]],

        [[ 0.0667]],

        [[ 0.0168]],

        [[ 0.0416]],

        [[ 0.0258]],

        [[ 0.0200]],

        [[ 0.0259]],

        [[ 0.0044]],

        [[ 0.0514]],

        [[ 0.0190]],

        [[ 0.0545]],

        [[ 0.0119]],

        [[ 0.0220]],

        [[ 0.0481]],

        [[ 0.0115]],

        [[ 0.0707]],

        [[ 0.0299]],

        [[ 0.0105]],

        [[ 0.0266]],

        [[ 0.0156]],

        [[ 0.0380]],

        [[ 0.0160]],

        [[ 0.0521]],

        [[ 0.0094]],

        [[-0.0133]],

        [[ 0.0585]],

        [[ 0.0216]],

        [[ 0.0102]],

        [[ 0.0297]],

        [[ 0.0104]]], requires_grad=True)

Modifying the condition as below works around the problem, but it does not look like a robust fix.

if "scale" in node.target and "layer_scale" not in node.target:
xin3he self-assigned this on Dec 30, 2024
xin3he (Contributor) commented on Dec 30, 2024:

Hi @nalnez13, thanks for the root-cause analysis.
I agree there should be a better way to distinguish module-specific parameter names from quantization scale names. I would prefer matching on _input_zero_point_ and _input_scale_, as in the case below:
https://github.com/pytorch/pytorch/blob/2ed4d65af0a1993c0df7b081f4088d0f3614283e/torch/ao/ns/fx/graph_passes.py#L211
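
For illustration, a stricter match could anchor on the indexed suffixes that the FX quantization flow generates (a sketch under that assumption, based on the names visible in the log above; not the actual fix):

import re

# Observer attributes follow an indexed pattern such as
# "dwconv_input_scale_0" / "dwconv_input_zero_point_0", while user
# parameters like "layer_scale" carry no trailing index.
_SCALE_RE = re.compile(r"scale_\d+$")
_ZERO_POINT_RE = re.compile(r"zero_point_\d+$")

def is_quant_scale(target: str) -> bool:
    return _SCALE_RE.search(target) is not None

def is_quant_zero_point(target: str) -> bool:
    return _ZERO_POINT_RE.search(target) is not None

assert is_quant_scale("dwconv_input_scale_0")
assert not is_quant_scale("layer_scale")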
