
Model and checkpoint mismatch for BEVFusion with pre-trained weights in BEVFusion #3081

Closed · 3 tasks done
ZhanghaiFA opened this issue Feb 18, 2025 · 2 comments

@ZhanghaiFA

Prerequisite

Task

I'm using the official example scripts/configs for the officially supported tasks/models/datasets.

Branch

main branch https://github.com/open-mmlab/mmdetection3d

Environment

mmcv 2.0.0rc4
mmdet 3.0.0
mmengine 0.10.6
torch 1.10.0+cu111
torchaudio 0.10.0+rocm4.1
torchvision 0.11.0+cu111

Reproduces the problem - code sample

init_cfg=dict(
    type='Pretrained',
    checkpoint='...pretrain/swint-nuimages-pretrained.pth')),

Reproduces the problem - command or script

bash tools/dist_train.sh projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py 1 --cfg-options load_from='/home/zhf/mmdet_bev_mit/mmdetection3d/pretrain/bevfusion_lidar_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d-2628f933.pth' model.img_backbone.init_cfg.checkpoint='/home/zhf/mmdet_bev_mit/mmdetection3d/pretrain/swint-nuimages-pretrained.pth'

Reproduces the problem - error message

I am currently using BEVFusion for multimodal training. I encountered a model mismatch error while loading the pre-trained weights for BEVFusion. The issue arises when I load the checkpoints from the official repository (bevfusion_lidar_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d-2628f933.pth and swint-nuimages-pretrained.pth), and I receive the following error message:

The model and loaded state dict do not match exactly.

size mismatch for pts_middle_encoder.conv_input.0.weight: copying a param with shape torch.Size([16, 3, 3, 3, 5]) from checkpoint, the shape in current model is torch.Size([3, 3, 3, 5, 16]).
...
(many more size mismatch errors)
...
missing keys in source state_dict: img_backbone.patch_embed.projection.weight, img_backbone.patch_embed.projection.bias, img_backbone.patch_embed.norm.weight, img_backbone.patch_embed.norm.bias, img_backbone.stages.0.blocks.0.norm1.weight, img_backbone.stages.0.blocks.0.norm1.bias, img_backbone.stages.0.blocks.0.attn.w_msa.relative_position_bias_table, img_backbone.stages.0.blocks.0.attn.w_msa.relative_position_index, img_backbone.stages.0.blocks.0.attn.w_msa.qkv.weight, img_backbone.stages.0.blocks.0.attn.w_msa.qkv.bias, img_backbone.stages.0.blocks.0.attn.w_msa.proj.weight, img_backbone.stages.0.blocks.0.attn.w_msa.proj.bias, ... (every img_backbone.* key is reported missing)
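One way to pin down which side of a mismatch like this is wrong is to dump the parameter names and shapes from both the checkpoint and the model and diff them. Below is a minimal sketch of that comparison logic; it uses plain dicts of name → shape so it runs anywhere, and with real checkpoints you would build these dicts from torch.load(path)['state_dict'] and model.state_dict() (the toy key names here just mirror the error above):

```python
def diff_state_dicts(ckpt_shapes, model_shapes):
    """Compare two {param_name: shape_tuple} mappings.

    Returns (missing_from_ckpt, unexpected_in_ckpt, shape_mismatches).
    """
    missing = sorted(set(model_shapes) - set(ckpt_shapes))
    unexpected = sorted(set(ckpt_shapes) - set(model_shapes))
    mismatched = {
        name: (ckpt_shapes[name], model_shapes[name])
        for name in set(ckpt_shapes) & set(model_shapes)
        if ckpt_shapes[name] != model_shapes[name]
    }
    return missing, unexpected, mismatched

# Toy example: the lidar-only checkpoint has no img_backbone.* keys,
# and one conv weight is stored with a permuted shape.
ckpt = {
    'pts_middle_encoder.conv_input.0.weight': (16, 3, 3, 3, 5),
}
model = {
    'pts_middle_encoder.conv_input.0.weight': (3, 3, 3, 5, 16),
    'img_backbone.patch_embed.projection.weight': (96, 3, 4, 4),
}
missing, unexpected, mismatched = diff_state_dicts(ckpt, model)
print(missing)      # keys the checkpoint lacks
print(mismatched)   # keys whose shapes disagree
```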

Additional information

I wonder if this is normal.

@AinaraC

AinaraC commented Feb 19, 2025

Hi!

I encountered a similar error when using MVX-Net pre-trained weights. It seems to be an issue with the layer shapes: the checkpoint stores torch.Size([16, 3, 3, 3, 5]) while the current model expects torch.Size([3, 3, 3, 5, 16]).

Try the following code to solve it:

import torch

path = '/home/zhf/mmdet_bev_mit/mmdetection3d/pretrain/swint-nuimages-pretrained.pth'
model = torch.load(path)

def transpose_weights(model, layer_names):
    for layer in layer_names:
        if layer in model['state_dict']:
            weight = model['state_dict'][layer]
            # Permute weights from [N, C, D, H, W] to [C, D, H, W, N]
            # (e.g. from torch.Size([16, 3, 3, 3, 5]) to torch.Size([3, 3, 3, 5, 16]))
            weight = weight.permute(1, 2, 3, 4, 0)
            model['state_dict'][layer] = weight

# Fill with BEVFusion's layers, an example:
layer_names = [
    'pts_middle_encoder.conv_input.0.weight',
]

transpose_weights(model, layer_names)
torch.save(model, './bevfusion_fixed.pth')

Don't forget to complete layer_names with your model's layers.
I hope it helps! Let me know if it works.

@ZhanghaiFA (Author)

Thank you for your suggestion, but unfortunately it didn't resolve the issue. However, your input did help me find the root cause. I first printed the names and shapes of the weights in both the pre-trained checkpoint and my model. The shapes were in fact the same, and to my surprise the weight values were identical as well. I had originally thought the pre-trained weights weren't loading properly, but they actually were.

This led me to ask why the warning was still showing up. My model loads two pre-trained checkpoints: one for the camera branch and one for the lidar branch. On investigating the lidar checkpoint's weight loading, I found that the warning originated there. The lidar-only checkpoint simply does not contain the camera network's weights, so even though the warning listed camera-network keys as missing, the message came from loading the lidar checkpoint, not from a real problem with the camera weights.

The warning is emitted from the load_state_dict function in mmcv's runner/checkpoint.py, specifically by this call:

module._load_from_state_dict(state_dict, prefix, local_metadata, True,
                             all_missing_keys, unexpected_keys,
                             err_msg)

This passes strict=True to _load_from_state_dict, which is why the warnings appear. However, these warnings don't actually affect the loading of the weights that are present.
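The same behaviour is easy to reproduce with plain PyTorch: Module.load_state_dict with strict=False loads whatever keys match and reports the rest instead of raising. A minimal sketch, using a hypothetical two-layer model standing in for the camera/lidar split and a deliberately partial checkpoint:

```python
import torch.nn as nn

# Hypothetical minimal model; layer '1' plays the role of img_backbone.
model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))

# A checkpoint missing the second layer's weights, analogous to a
# lidar-only checkpoint that has no img_backbone.* keys.
partial_sd = {k: v for k, v in model.state_dict().items()
              if k.startswith('0.')}

# strict=False loads the matching keys and returns the rest as a report.
result = model.load_state_dict(partial_sd, strict=False)
print(result.missing_keys)     # ['1.weight', '1.bias']
print(result.unexpected_keys)  # []
```

The first layer's weights are loaded normally; the report of missing keys is informational, which matches the observation above that the warning does not affect the weights that are present.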
