
Convolution operator not marked as quantizable when padding is defined in the class instantiation #398

Open
etrommer opened this issue Nov 28, 2024 · 2 comments

Comments


etrommer commented Nov 28, 2024

Description of the bug:

A torch.nn.Conv2d instance is not marked as quantizable when padding is specified in the constructor call.

Minimal working example:

import torch
import ai_edge_torch
from torch.export import export_for_training
from ai_edge_torch.quantize.pt2e_quantizer import (
    PT2EQuantizer,
    get_symmetric_quantization_config,
)
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from ai_edge_torch.quantize.quant_config import QuantConfig

def aiet_export(model: torch.nn.Module, fname: str):
    sample_input = torch.rand((1, 8, 8, 3))
    model = ai_edge_torch.to_channel_last_io(model, args=[0], outputs=[0]).eval()
    model = export_for_training(model, (sample_input,)).module()
    aiet_quantizer = PT2EQuantizer().set_global(
        get_symmetric_quantization_config(is_per_channel=False, is_dynamic=False)
    )
    model = prepare_pt2e(model, aiet_quantizer)
    with torch.no_grad():
        model(sample_input)
    model = convert_pt2e(model, fold_quantize=False)
    torch.ao.quantization.move_exported_model_to_eval(model)
    aiet_tc = ai_edge_torch.convert(
        model, (sample_input,), quant_config=QuantConfig(pt2e_quantizer=aiet_quantizer)
    )
    aiet_tc.export(fname)

class TinyConv1(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # without explicit padding definition
        self.conv = torch.nn.Conv2d(3, 2, (3, 3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)
        return y

class TinyConv2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # with explicit padding definition
        self.conv = torch.nn.Conv2d(3, 2, (3, 3), padding='same')

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)
        return y

tc1 = TinyConv1()
aiet_export(tc1, 'tinyconv1.tflite')

tc2 = TinyConv2()
aiet_export(tc2, 'tinyconv2.tflite')

Actual vs expected behavior:

Expected:
Both models are quantized identically.

Actual:
TinyConv1 has quantized parameters:
[screenshot: gh_aiet_tc1]

TinyConv2 does not:
[screenshot: gh_aiet_tc2]

Any other information you'd like to share?

The cause of this issue is that the check used to identify a conv node in the FX graph when applying quantization is too restrictive:

if n.op != "call_function" or n.target not in [
    torch.ops.aten.conv1d.default,
    torch.ops.aten.conv2d.default,
    torch.ops.aten.convolution.default,
]:

When padding is passed to the constructor as a string, n.target is a different torch._ops.OpOverload of the same op (e.g. torch.ops.aten.conv2d.padding rather than torch.ops.aten.conv2d.default), so the membership test fails even though it arguably shouldn't.
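One possible relaxation, sketched below, would compare each node's overload packet instead of a single overload, so that every overload of the conv ops (.default, .padding, ...) is matched. This is only an illustration of the idea, not the actual fix applied in ai_edge_torch; the names is_conv_node and _CONV_PACKETS are hypothetical.

```python
import torch

# Hypothetical set of conv OpOverloadPackets; each packet groups all
# overloads of an op (e.g. conv2d.default and conv2d.padding).
_CONV_PACKETS = {
    torch.ops.aten.conv1d,
    torch.ops.aten.conv2d,
    torch.ops.aten.convolution,
}

def is_conv_node(n) -> bool:
    """Return True if the FX node calls any overload of a conv op."""
    return (
        n.op == "call_function"
        and isinstance(n.target, torch._ops.OpOverload)
        and n.target.overloadpacket in _CONV_PACKETS
    )
```

With a check like this, both torch.ops.aten.conv2d.default and torch.ops.aten.conv2d.padding (the overload produced when padding is given as a string) would be recognized as quantizable conv nodes.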

etrommer (Author) commented:

If you accept PRs, I'm happy to open one for this.

pkgoogle (Contributor) commented Dec 2, 2024

Hi @etrommer, this repo is open for contribution, please do so and we will review 😃
