Opacus for 3D Segmentation #671

Open
Ziva1011 opened this issue Sep 9, 2024 · 2 comments


Ziva1011 commented Sep 9, 2024

Dear Opacus,

I've been looking into 3D segmentation models for medical imaging. I use several architectures, half of them from the Monai library.

When I try to run Opacus in combination with these architectures I get the following error:

Traceback (most recent call last):
  File "/vol/aimspace/users/viulapir/Documents/dp-thesis/test2.py", line 473, in main
    training_losses, validation_losses, lr_rates = trainer.run_trainer()
  File "/vol/aimspace/users/viulapir/Documents/dp-thesis/trainer.py", line 48, in run_trainer
    self._train()
  File "/vol/aimspace/users/viulapir/Documents/dp-thesis/trainer.py", line 106, in _train
    self.optimizer.step()  # update the parameters
  File "/u/home/viulapir/.conda/envs/torchsegmentation2/lib/python3.10/site-packages/opacus/optimizers/optimizer.py", line 553, in step
    if self.pre_step():
  File "/u/home/viulapir/.conda/envs/torchsegmentation2/lib/python3.10/site-packages/opacus/optimizers/optimizer.py", line 536, in pre_step
    if self.grad_samples is None or len(self.grad_samples) == 0:
  File "/u/home/viulapir/.conda/envs/torchsegmentation2/lib/python3.10/site-packages/opacus/optimizers/optimizer.py", line 342, in grad_samples
    ret.append(self._get_flat_grad_sample(p))
  File "/u/home/viulapir/.conda/envs/torchsegmentation2/lib/python3.10/site-packages/opacus/optimizers/optimizer.py", line 279, in _get_flat_grad_sample
    raise ValueError(
ValueError: Per sample gradient is not initialized. Not updated in backward pass?

The model that generated the error above was VNet, but the same thing happens with other 3D architectures too.

I know that Monai uses batch normalization layers, so I used ModuleValidator.fix(model) to convert them to group normalization layers. Although I read that the Monai models already have their weights initialized, I also tried initializing the weights with torch.nn.init, but it produced the same error.

A simplified version of the code is below.

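# Imports inferred from the names used below.
import torch
from monai.data import list_data_collate
from monai.networks.nets import VNet
from opacus import PrivacyEngine
from opacus.data_loader import wrap_collate_with_empty
from opacus.validators import ModuleValidator

# learning_rate, epochs, and train_dl are defined elsewhere in the training script.
model = VNet(spatial_dims=3, in_channels=1, out_channels=3)
# Convert BatchNorm layers to GroupNorm, as required by Opacus.
model = ModuleValidator.fix(model)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

privacy_engine = PrivacyEngine(accountant="gdp")

model, optimizer, train_dl = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=train_dl,
    target_epsilon=8,
    target_delta=1e-2,
    epochs=epochs,
    max_grad_norm=1,
)
# Wrap the collate function so that empty batches are handled.
train_dl.collate_fn = wrap_collate_with_empty(
    collate_fn=list_data_collate,
    sample_empty_shapes={x: train_dl.dataset[0][x].shape for x in ["img", "seg"]},
    dtypes={x: train_dl.dataset[0][x].dtype for x in ["img", "seg"]},
)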

I'm not sure whether this is a feature you haven't implemented, a bug, or something I am missing, but no type of initialization seems to work in combination with Opacus.

Thank you!

@EnayatUllah
Contributor

Thanks for raising the issue. Can you use the bug report Colab to reproduce the issue?

Contributor

iden-kalemaj commented Oct 2, 2024

Thanks again for raising this issue.

Please see this colab that reproduces the issue and suggests a fix.

In summary, the error occurs because grad_sample is not being computed for the parameters of the GroupNorm layers. Opacus requires grad_sample to be computed for every model parameter that has requires_grad = True, which includes the GroupNorm layer parameters.
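One way to see which parameters are affected is to list those that require gradients but have no grad_sample after a backward pass. The following is a minimal sketch, not the code from the colab: the helper name params_missing_grad_sample and the toy model are hypothetical, and the grad_sample attribute is set by hand here to simulate what a backward pass would populate.

```python
import torch
from torch import nn


def params_missing_grad_sample(model: nn.Module) -> list[str]:
    """Return names of trainable parameters without a populated grad_sample."""
    missing = []
    for name, param in model.named_parameters():
        # Treat an absent attribute and a None value the same way.
        if param.requires_grad and getattr(param, "grad_sample", None) is None:
            missing.append(name)
    return missing


# Hypothetical toy model standing in for a 3D segmentation network.
toy_model = nn.Sequential(
    nn.Conv3d(1, 2, kernel_size=3),
    nn.GroupNorm(1, 2),
    nn.ELU(),
)

# Simulate a backward pass that populated grad_sample for the conv layer only.
for param in toy_model[0].parameters():
    param.grad_sample = torch.zeros(1, *param.shape)

missing = params_missing_grad_sample(toy_model)
print(missing)
```

On a real run, the helper would be called on the model returned by make_private_with_epsilon, after loss.backward() and before optimizer.step().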

We are still not certain of the source of the issue, but found that changing the inplace parameter of the ELU activation layers from True to False fixes it. Please see the colab for how to set the parameter to False.
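For reference, the workaround can be sketched generically as follows. This is an illustration under assumptions rather than the colab's exact code: the helper name disable_inplace_elu and the toy model are hypothetical, and the approach relies only on nn.ELU exposing an inplace attribute.

```python
from torch import nn


def disable_inplace_elu(model: nn.Module) -> int:
    """Set inplace=False on every nn.ELU in the model; return how many changed."""
    changed = 0
    for module in model.modules():
        if isinstance(module, nn.ELU) and module.inplace:
            module.inplace = False
            changed += 1
    return changed


# Hypothetical toy model with in-place ELU activations.
toy_model = nn.Sequential(
    nn.Conv3d(1, 2, kernel_size=3),
    nn.ELU(inplace=True),
    nn.Conv3d(2, 2, kernel_size=3),
    nn.ELU(inplace=True),
)

n_changed = disable_inplace_elu(toy_model)
```

In the setup from the original report, this would presumably be applied to the model before passing it to make_private_with_epsilon.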

For further context, when debugging with a segmentation network similar to VNet, even training without privacy raises an error during the backward pass because of inplace=True. Thus the issue is not limited to Opacus.
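The general class of failure can be shown in plain PyTorch, without Opacus. This is a generic sketch of how an in-place activation can break the backward pass, not a diagnosis of what happens inside VNet; the function name backward_fails is hypothetical.

```python
import torch
import torch.nn.functional as F


def backward_fails(inplace: bool) -> bool:
    """Return True if backward raises a RuntimeError for this setting."""
    x = torch.randn(4, requires_grad=True)
    y = x.exp()  # exp saves its output y for use in the backward pass
    z = F.elu(y, inplace=inplace)  # inplace=True overwrites the saved y
    try:
        z.sum().backward()
    except RuntimeError:
        # Autograd detected that a tensor needed for backward was modified.
        return True
    return False


print("inplace=True fails: ", backward_fails(True))
print("inplace=False fails:", backward_fails(False))
```

Autograd tracks a version counter on each tensor, so overwriting a tensor that an earlier operation saved for its gradient is detected at backward time.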

Please let us know if this fixes your issue.


3 participants