Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flux Autoencoder #2098

Merged
9 commits merged on Jan 8, 2025
5 changes: 5 additions & 0 deletions tests/torchtune/models/flux/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
81 changes: 81 additions & 0 deletions tests/torchtune/models/flux/test_flux_autoencoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import pytest
import torch

from torchtune.models.flux import flux_1_autoencoder
from torchtune.training.seed import set_seed

# Test-configuration constants shared by the fixtures and assertions below.
BSZ = 32  # batch size
CH_IN = 3  # input image channels (RGB)
RESOLUTION = 16  # input image height/width in pixels
CH_MULTS = [1, 2]  # channel multipliers; one resolution level per entry
CH_Z = 4  # latent (z) channels
# Latent spatial size; presumably each level beyond the first halves the
# spatial resolution, so 16 // 2 == 8 here — TODO confirm against the builder.
RES_Z = RESOLUTION // len(CH_MULTS)


@pytest.fixture(autouse=True)
def random():
    """Seed the RNGs before every test so inputs and weights are deterministic."""
    set_seed(0)


class TestFluxAutoencoder:
    """Smoke/regression tests for a small dummy Flux autoencoder.

    NOTE(review): the hard-coded expected channel means below were recorded
    from this small dummy configuration; per the PR discussion, the full-size
    model was additionally validated against the reference flux codebase
    separately (parity ~3.36e-5, ~10% faster).
    """

    @pytest.fixture
    def model(self):
        # Build a tiny autoencoder, then overwrite every parameter with values
        # drawn uniformly from [0, 0.1) so the outputs are reproducible.
        ae = flux_1_autoencoder(
            resolution=RESOLUTION,
            ch_in=CH_IN,
            ch_out=3,
            ch_base=32,
            ch_mults=CH_MULTS,
            ch_z=CH_Z,
            n_layers_per_resample_block=2,
            scale_factor=1.0,
            shift_factor=0.0,
        )
        for p in ae.parameters():
            p.data.uniform_(0, 0.1)
        return ae

    @pytest.fixture
    def img(self):
        # Random batch of input images.
        return torch.randn(BSZ, CH_IN, RESOLUTION, RESOLUTION)

    @pytest.fixture
    def z(self):
        # Random batch of latents.
        return torch.randn(BSZ, CH_Z, RES_Z, RES_Z)

    @staticmethod
    def _check_channel_means(out, want):
        # Reduce over batch and spatial dims, leaving one mean per channel,
        # and compare against the recorded reference values.
        got = torch.mean(out, dim=(0, 2, 3))
        torch.testing.assert_close(got, torch.tensor(want), atol=1e-4, rtol=1e-4)

    def test_forward(self, model, img):
        out = model(img)
        assert out.shape == (BSZ, CH_IN, RESOLUTION, RESOLUTION)
        self._check_channel_means(out, [0.4286, 0.4276, 0.4054])

    def test_backward(self, model, img):
        # Only verifies that gradients flow end-to-end without raising.
        model(img).mean().backward()

    def test_encode(self, model, img):
        latent = model.encode(img)
        assert latent.shape == (BSZ, CH_Z, RES_Z, RES_Z)
        self._check_channel_means(latent, [0.6150, 0.7959, 0.7178, 0.7011])

    def test_decode(self, model, z):
        out = model.decode(z)
        assert out.shape == (BSZ, CH_IN, RESOLUTION, RESOLUTION)
        self._check_channel_means(out, [0.4246, 0.4241, 0.4014])
10 changes: 10 additions & 0 deletions torchtune/models/flux/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from ._model_builders import flux_1_autoencoder

# Public API of the flux model package.
__all__ = [
    "flux_1_autoencoder",
]
Loading
Loading