Error with perturbations on some elements in x_L and x_U #67

Open
nnjnjn opened this issue Mar 16, 2024 · 1 comment

nnjnjn commented Mar 16, 2024

I extracted a small subnet from ResNet18 as my target network. I specified the upper and lower bounds of the perturbation and used auto_LiRPA to compute bounds, but an error occurred. My code and the error are below. What I don't understand is that when I use torch.rand_like to generate bound1 and bound2, no error occurs.

import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
from auto_LiRPA import BoundedModule, BoundedTensor
from auto_LiRPA.perturbations import PerturbationLpNorm

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )
        self.show = False
        
    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.flatten = nn.Flatten()
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avgpool(out)
        out = self.flatten(out)
        out = self.linear(out)
        return out

    def split(self):
        # Split the network into two halves, cutting inside layer4.
        first = nn.Sequential(
            self.conv1,
            self.bn1,
            nn.ReLU(),
            self.layer1, self.layer2, self.layer3, self.layer4[0])
        second = nn.Sequential(
            self.layer4[1],
            self.avgpool,
            self.flatten,
            self.linear)
        return first, second

def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])

def test():
    net = ResNet18()
    net.eval()
    # The second half of the split network is the model to be bounded.
    split_before, split_after = net.split()
    # h is the concrete (unperturbed) input to the second half.
    h = split_before(torch.randn(1, 3, 32, 32))

    device = 'cuda:0'
    lirpa_model = BoundedModule(split_after, torch.empty_like(h), device=device)

    # Elementwise lower (x_L) and upper (x_U) bounds on the input.
    bound1 = torch.zeros_like(h).to(device)
    bound2 = torch.ones_like(h).to(device)
    ptb = PerturbationLpNorm(x_L=bound1, x_U=bound2)
    true_input = BoundedTensor(h, ptb)

    # Also request the linear coefficients (A matrices) from the output
    # node back to the input node.
    required_A = defaultdict(set)
    required_A[lirpa_model.output_name[0]].add(lirpa_model.input_name[0])
    lb, ub, A_dict = lirpa_model.compute_bounds(x=(true_input,), method='backward',
                                                return_A=True, needed_A_dict=required_A)
    print(lb)
    print(ub)

if __name__ == "__main__":
    test()

And my error:

Traceback (most recent call last):
    lb, ub, A_dict = lirpa_model.compute_bounds(x=(true_input,), method='backward', return_A=True,
  File "/home/hdu/.conda/envs/vere/lib/python3.9/site-packages/auto_LiRPA-0.4.0-py3.9.egg/auto_LiRPA/bound_general.py", line 1206, in compute_bounds
    return self._compute_bounds_main(C=C,
  File "/home/hdu/.conda/envs/vere/lib/python3.9/site-packages/auto_LiRPA-0.4.0-py3.9.egg/auto_LiRPA/bound_general.py", line 1303, in _compute_bounds_main
    self.check_prior_bounds(final)
  File "/home/hdu/.conda/envs/vere/lib/python3.9/site-packages/auto_LiRPA-0.4.0-py3.9.egg/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/home/hdu/.conda/envs/vere/lib/python3.9/site-packages/auto_LiRPA-0.4.0-py3.9.egg/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/home/hdu/.conda/envs/vere/lib/python3.9/site-packages/auto_LiRPA-0.4.0-py3.9.egg/auto_LiRPA/bound_general.py", line 800, in check_prior_bounds
    self.check_prior_bounds(n)
  File "/home/hdu/.conda/envs/vere/lib/python3.9/site-packages/auto_LiRPA-0.4.0-py3.9.egg/auto_LiRPA/bound_general.py", line 804, in check_prior_bounds
    self.compute_intermediate_bounds(
  File "/home/hdu/.conda/envs/vere/lib/python3.9/site-packages/auto_LiRPA-0.4.0-py3.9.egg/auto_LiRPA/bound_general.py", line 904, in compute_intermediate_bounds
    node.lower, node.upper, _ = self.backward_general(
  File "/home/hdu/.conda/envs/vere/lib/python3.9/site-packages/auto_LiRPA-0.4.0-py3.9.egg/auto_LiRPA/backward_bound.py", line 326, in backward_general
    lb, ub = concretize(self, batch_size, output_dim, lb, ub,
  File "/home/hdu/.conda/envs/vere/lib/python3.9/site-packages/auto_LiRPA-0.4.0-py3.9.egg/auto_LiRPA/backward_bound.py", line 686, in concretize
    lb = lb + roots[i].perturbation.concretize(
RuntimeError: The size of tensor a (4) must match the size of tensor b (8192) at non-singleton dimension 3

nnjnjn commented Mar 16, 2024

Besides, I did some simple debugging and found that the variables lower_b and upper_b in the function backward_general in backward_bound.py seem to have different shapes when I generate the bounds with torch.zeros_like(h) and torch.ones_like(h) than when I generate them with torch.rand_like.
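
For reference, here is roughly the variant that does not trigger the error for me (a sketch; taking the elementwise minimum/maximum is just my way of guaranteeing x_L <= x_U):

    # Random bounds instead of constant 0/1 bounds; no error in this case.
    r1 = torch.rand_like(h).to(device)
    r2 = torch.rand_like(h).to(device)
    bound1 = torch.minimum(r1, r2)  # x_L
    bound2 = torch.maximum(r1, r2)  # x_U
    ptb = PerturbationLpNorm(x_L=bound1, x_U=bound2)
    true_input = BoundedTensor(h, ptb)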

I also found that if I set the bias argument of the convolutional layers in BasicBlock to True, the error seems to disappear. Can you explain why this is?
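
Concretely, the change I mean is just flipping the bias flag on the two convolutions in BasicBlock (a sketch of the modified lines; with the BatchNorm layers that follow, the bias should normally be redundant):

    # Enabling conv biases makes the error disappear for me.
    self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                           stride=stride, padding=1, bias=True)
    self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                           stride=1, padding=1, bias=True)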
