
Verification of HuggingFace models #21

Open

freejen opened this issue Mar 7, 2022 · 1 comment

freejen commented Mar 7, 2022

Hi 😄

I'm trying to run auto_LiRPA on a transformer model for token classification (named entity recognition) from the Hugging Face library, but I have run into some problems.

Here's the code I am trying to run:

import torch
import torch.nn as nn
from transformers import AutoModelForTokenClassification, AutoTokenizer, AutoConfig
from auto_LiRPA import BoundedModule


class TransformerFromEmbeddings(nn.Module):
    """A wrapper around a Hugging face model that ."""

    def __init__(self, model):
        super(TransformerFromEmbeddings, self).__init__()
        self.model = model

    def forward(self, inputs_embeds, attention_mask):
        model_out = self.model(attention_mask=attention_mask, 
            inputs_embeds=inputs_embeds)
        return model_out.logits


def verify_huggingface_model(text):
    bert_config = AutoConfig.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
    bert_config.hidden_act = 'relu'
    model = AutoModelForTokenClassification.from_config(bert_config)
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    
    inputs = tokenizer(text, return_tensors="pt")

    embedding_layer = model.get_input_embeddings()
    batch_embeddings = embedding_layer(inputs["input_ids"])
    attention_mask = inputs["attention_mask"]
    embeds_model = TransformerFromEmbeddings(model)

    # Runs as expected
    outputs = embeds_model(batch_embeddings, attention_mask)
    predictions = torch.argmax(outputs, dim=2)
    tokens = inputs.tokens()
    for token, prediction in zip(tokens, predictions[0].numpy()):
        print((token, model.config.id2label[prediction]))
    print()

    # Throws an exception
    bounded_model = BoundedModule(embeds_model, (batch_embeddings, attention_mask))
    

if __name__ == "__main__":
    text = (
        "Hugging Face Inc. is a company based in New York City. Its headquarters are in DUMBO, "
        "therefore very close to the Manhattan Bridge."
    )
    verify_huggingface_model(text)

As you can see, I create a BoundedModule at the end of the verify_huggingface_model function, and this step raises an exception. In the "Errors" section below I present the two errors I have run into. I was able to fix the first one, but not the second.

Note that I created the TransformerFromEmbeddings class in an attempt to make the Transformer model from the Hugging Face library compatible with auto_LiRPA: it takes input embeddings directly and returns plain logits instead of a model output object.

Errors

Running the script with an unmodified auto_LiRPA library produces the following error:

Traceback (most recent call last):
  File "c:\users\a-sjenko\documents\code\huggingfaceverification\auto_lirpa\auto_LiRPA\bound_general.py", line 265, in forward
    l.batch_dim = l.infer_batch_dim(self.init_batch_size, *inp_batch_dim)
  File "c:\users\a-sjenko\documents\code\huggingfaceverification\auto_lirpa\auto_LiRPA\operators\logical.py", line 46, in infer_batch_dim
    return BoundMul.infer_batch_dim(batch_size, *x[1:])
NameError: name 'BoundMul' is not defined

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File ".\verify_transformer.py", line 54, in <module>
    verify_huggingface_model(text)
  File ".\verify_transformer.py", line 46, in verify_huggingface_model
    bounded_model = BoundedModule(embeds_model, (batch_embeddings, attention_mask))
  File "c:\users\a-sjenko\documents\code\huggingfaceverification\auto_lirpa\auto_LiRPA\bound_general.py", line 80, in __init__
    self._convert(model, global_input)
  File "c:\users\a-sjenko\documents\code\huggingfaceverification\auto_lirpa\auto_LiRPA\bound_general.py", line 599, in _convert
    self.forward(*global_input)  # running means/vars changed
  File "c:\users\a-sjenko\documents\code\huggingfaceverification\auto_lirpa\auto_LiRPA\bound_general.py", line 269, in forward
    l, l.name, l.forward_value.shape, inp_batch_dim))
Exception: Fail to infer the batch dimension of (BoundWhere())[/419]: forward_value shape torch.Size([2]), input batch dimensions [-1, -1, -1]

The error originates in this function:
https://github.com/KaidiXu/auto_LiRPA/blob/214710e457213a03b24650c089424b8706abe01a/auto_LiRPA/operators/logical.py#L45-L46

It seems that BoundMul is not imported in "operators/logical.py". To fix this, I added from .bivariate import BoundMul to "operators/logical.py".
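
For reference, the entire workaround is the single import line below, added alongside the other imports at the top of "operators/logical.py":

# auto_LiRPA/operators/logical.py
from .bivariate import BoundMul  # used by infer_batch_dim (lines 45-46 linked above)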

With this addition, I no longer get the previous error, but now I get a new one:

Traceback (most recent call last):
  File "c:\users\a-sjenko\documents\code\huggingfaceverification\auto_lirpa\auto_LiRPA\bound_general.py", line 265, in forward
    l.batch_dim = l.infer_batch_dim(self.init_batch_size, *inp_batch_dim)
  File "c:\users\a-sjenko\documents\code\huggingfaceverification\auto_lirpa\auto_LiRPA\operators\shape.py", line 517, in infer_batch_dim
    raise NotImplementedError('forward_value shape {}'.format(self.forward_value.shape))
NotImplementedError: forward_value shape torch.Size([1, 32])

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File ".\verify_transformer.py", line 53, in <module>
    verify_huggingface_model(text)
  File ".\verify_transformer.py", line 45, in verify_huggingface_model
    bounded_model = BoundedModule(embeds_model, (batch_embeddings, attention_mask))
  File "c:\users\a-sjenko\documents\code\huggingfaceverification\auto_lirpa\auto_LiRPA\bound_general.py", line 80, in __init__
    self._convert(model, global_input)
  File "c:\users\a-sjenko\documents\code\huggingfaceverification\auto_lirpa\auto_LiRPA\bound_general.py", line 599, in _convert
    self.forward(*global_input)  # running means/vars changed
  File "c:\users\a-sjenko\documents\code\huggingfaceverification\auto_lirpa\auto_LiRPA\bound_general.py", line 269, in forward
    l, l.name, l.forward_value.shape, inp_batch_dim))
Exception: Fail to infer the batch dimension of (BoundExpand())[/420]: forward_value shape torch.Size([1, 32]), input batch dimensions [-1, -1]

The error mentions a tensor of shape [1, 32]. It is probably referring to the attention_mask tensor in my code, which has that shape. It might be worth noting that the input embeddings (the batch_embeddings variable) in my code have the shape [1, 32, 1024].
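
A quick shape check confirms this (the sizes follow from the 32 tokens produced by the tokenizer and BERT-large's hidden size of 1024):

print(attention_mask.shape)    # torch.Size([1, 32])
print(batch_embeddings.shape)  # torch.Size([1, 32, 1024])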

I have looked into your example script for Transformers and saw that the mask you use is a 4-dimensional tensor of shape [32, 1, 1, 32], for a batch of input embeddings of shape [32, 32, 64]. I tried reshaping the attention mask in my code to follow the same pattern (I reshaped my mask to [1, 1, 1, 32]), but then I get an error from the Hugging Face library, because it expects my mask to have the shape [1, 32].
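
This is roughly what that attempt looked like (a sketch; the [1, 1, 1, 32] shape mirrors the mask layout from your example):

# Sketch: reshape the 2-D mask to the 4-D layout used in the auto_LiRPA example.
# The Hugging Face model rejects this shape, since it expects a [1, 32] mask.
extended_mask = attention_mask.view(1, 1, 1, -1)  # [1, 32] -> [1, 1, 1, 32]
outputs = embeds_model(batch_embeddings, extended_mask)  # shape error inside transformers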

Questions

  1. How can I solve this problem?
  2. Also, I have noticed that the classes you use in the language examples seem to have been copied from the Hugging Face library. Is this correct, and how much did you have to modify the copied classes to make them compatible with auto_LiRPA?

My setup

I am working on Microsoft Windows Server 2019 and ran the following commands to set everything up:

conda create --name auto_lirpa_dev python=3.7 -y
conda activate auto_lirpa_dev
conda install pytorch torchvision torchaudio cpuonly -c pytorch-lts -y
conda install -c huggingface transformers -y

git clone https://github.com/KaidiXu/auto_LiRPA
cd auto_LiRPA
python setup.py develop
cd ..
@huanzhang12 (Member) commented

Thank you for providing such a detailed report! We will take a look at your problem and get back to you as soon as possible.
@shizhouxing Can you take a look at the example and see if it is easy to fix?
