Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ML Decoder #2045

Draft
wants to merge 50 commits into
base: main
Choose a base branch
from
Draft

Conversation

fffffgggg54
Copy link
Contributor

@fffffgggg54 fffffgggg54 commented Nov 27, 2023

Update ML Decoder's TransformerDecoderLayerOptimal module to comply with what nn.TransformerDecoder expects. Current changes work with resnet50.

add_ml_decoder_head needs to be updated for other models. In my limited testing, the following case works with RegNet:

elif hasattr(model, 'head'):    # ClassifierHead and ConvNext
    if hasattr(model.head, 'flatten'):  # ConvNext case
        model.head.flatten = nn.Identity()
    model.head.global_pool = nn.Identity()
    del model.head.fc
    num_classes = model.num_classes
    num_features = model.num_features
    model.head.fc = MLDecoder(num_classes=num_classes, initial_num_features=num_features)

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@fffffgggg54
Copy link
Contributor Author

Currently have 3 variants: original with compatibility changes for pt 2.1+, a version identical to the original but with a performance fix that's giving ~35% training speedup, and a WIP reimplementation that updates styling and changes the decoder block implementation. The last one involves architectural changes that removes some of odd components from #1012 (residual on dropouts and queries, mlp residual location, norm locations). I'm in the process of testing these changes.

@mrT23 do you have any comments? @rwightman are you aware of any pytorch transformer decoder/cross attention implementations that follow the style of this library or that I could reference? I inferred the cross attention impl based on the ViT attn impl and the block structure from the original impl, but I would rather this impl follow something standard (less the the self attn in the decoder), than introduce other odd implementation choices.

Model compatibility is iffy, there are a few with odd architectures (combining multiple feature maps, distillation architectures, etc) that would be a pain to special case and probably won't be used and a few other more prominent architectures that don't work because they use nhwc. Overall seems like it would be difficult to maintain along with #2048. @rwightman Do you think a revised classifierhead that supports additional pooling and head formats would be better? There are quite a few structures that can be placed there (pool->ffn, various ml-decoder-like mechanisms). Since many models already have forward split into forward_features and forward_head, I hope that a change like this will help abstract the design of the head from the model and provide a convenient and unified way to modify/swap the head.

@rwightman
Copy link
Collaborator

@fffffgggg54 curious what your goals are for this impl, what sort of applications, etc.

  1. For cross attn, don't have many examples in this codebase, but prefer to keep the naming of inputs parallel... ie x_q/x_kv, inputs_q/inputs_kv

  2. Making these sorts of pool + head or pool only additions universally applicable to all the timm models is a bit of a pain due to decisions in the past to keep some compatibility with original weights, would have been easier if I'd pulled all head related weights out into their own submodule for ALL models. Still have some ideas to work on here but currently focused on a document ai project.

  3. The most common type of attention pooling / head impl I've been working with in recent months is the attention pooling of the CLIP variety (attention_pool2d.py) and especially SigLIP style with a latent q (attention_pool.py), which isn't too far different from the perceiver resampler. Been fiddling with the latent style with a non singular q seq_len and qk norms for another project...

@fffffgggg54
Copy link
Contributor Author

Goals are primarily performance, compatibility, and consistent styling with newer timm implementations. Legacy version provides support and improved performance and reimplementation attempts to match other timm models and removes nn.TransformerDecoder and nn.MultiheadAttention. Original impl gives me around 2300 img/s, new impls give me 2850-2900 img/s, gap head gives me around 3000 img/s at 100 groups and 1588 classes.

I work almost exclusively with a dataset (danbooru) that poses a multi-label positive-unlabeled problem. I use MLDecoder for this sometimes and recently noticed the pt 2.1 and extensive model compat issues along the slow groupFC impl, odd dropout, etc, prompting both versions. In addition to PU-specific techniques (often in math-heavy papers that are a bit of a headache to read), some researchers focus on aspects of the model, often what comes after the backbone (GNNs, text towers, MLDecoder, activation functions). The labeling scheme of danbooru is set up such that the labels present in an image can be mapped out hierarchically, similar to a scene graph. This is also done internally via label implications. I'm also working on implementing a from-scratch impl of DependencyViT for this, not going well, hopefully can exploit the tree structure for this.

I have a drop-in replacement for ClassifierHead in a notebook right now that works as long as self.head(x) and self.head.reset style call is used (hence #2050, #2051). My plan for the head submodule would be to gradually change models to follow regular ClassifierHead models and maintain compat with original weights either via aliases or modified weight remap fn. I hope that this change will help standardize head impls for current/new models and help with integrating different head cfgs. I'm down to impl this, just not sure of how you'd like it done.

@fffffgggg54
Copy link
Contributor Author

fffffgggg54 commented Dec 26, 2023

Added tests from my own testing script, will fail because there are models that don't work, 95 variants specifically, mostly due to distill/multiple feature maps/wrong input shape. The code to add the head is messy, universal head should fix. MLDecoderHead would basically be how the universal head would have to be configured in order for MLDecoder to work.

@fffffgggg54
Copy link
Contributor Author

Experimental feature. Want to wait to merge this until a universal head is implemented. This and other things I'm working on are a pain to implement/use/add to timm without a universal head.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants