-
Notifications
You must be signed in to change notification settings - Fork 483
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Flux Autoencoder #2098
Merged
Merged
Flux Autoencoder #2098
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
7f9641a
flux autoencoder builder
calvinpelletier bfc93fc
flux autoencoder
calvinpelletier 0e6611e
convert flux autoencoder weights
calvinpelletier e7504b6
flux autoencoder unit test
calvinpelletier aa8b751
Merge remote-tracking branch 'origin/main' into flux_ae
calvinpelletier ddd2e7d
Merge remote-tracking branch 'origin/main' into flux_ae
calvinpelletier 794b1bd
addressing comments
calvinpelletier 346c51e
fmt
calvinpelletier c78f485
moving encoder/decoder construction to the model builder
calvinpelletier File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import pytest | ||
import torch | ||
|
||
from torchtune.models.flux import flux_1_autoencoder | ||
from torchtune.training.seed import set_seed | ||
|
||
BSZ = 32 | ||
CH_IN = 3 | ||
RESOLUTION = 16 | ||
CH_MULTS = [1, 2] | ||
CH_Z = 4 | ||
RES_Z = RESOLUTION // len(CH_MULTS) | ||
|
||
|
||
@pytest.fixture(autouse=True) | ||
def random(): | ||
set_seed(0) | ||
|
||
|
||
class TestFluxAutoencoder: | ||
@pytest.fixture | ||
def model(self): | ||
model = flux_1_autoencoder( | ||
resolution=RESOLUTION, | ||
ch_in=CH_IN, | ||
ch_out=3, | ||
ch_base=32, | ||
ch_mults=CH_MULTS, | ||
ch_z=CH_Z, | ||
n_layers_per_resample_block=2, | ||
scale_factor=1.0, | ||
shift_factor=0.0, | ||
) | ||
|
||
for param in model.parameters(): | ||
param.data.uniform_(0, 0.1) | ||
|
||
return model | ||
|
||
@pytest.fixture | ||
def img(self): | ||
return torch.randn(BSZ, CH_IN, RESOLUTION, RESOLUTION) | ||
|
||
@pytest.fixture | ||
def z(self): | ||
return torch.randn(BSZ, CH_Z, RES_Z, RES_Z) | ||
|
||
def test_forward(self, model, img): | ||
actual = model(img) | ||
assert actual.shape == (BSZ, CH_IN, RESOLUTION, RESOLUTION) | ||
|
||
actual = torch.mean(actual, dim=(0, 2, 3)) | ||
expected = torch.tensor([0.4286, 0.4276, 0.4054]) | ||
torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4) | ||
|
||
def test_backward(self, model, img): | ||
y = model(img) | ||
loss = y.mean() | ||
loss.backward() | ||
|
||
def test_encode(self, model, img): | ||
actual = model.encode(img) | ||
assert actual.shape == (BSZ, CH_Z, RES_Z, RES_Z) | ||
|
||
actual = torch.mean(actual, dim=(0, 2, 3)) | ||
expected = torch.tensor([0.6150, 0.7959, 0.7178, 0.7011]) | ||
torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4) | ||
|
||
def test_decode(self, model, z): | ||
actual = model.decode(z) | ||
assert actual.shape == (BSZ, CH_IN, RESOLUTION, RESOLUTION) | ||
|
||
actual = torch.mean(actual, dim=(0, 2, 3)) | ||
expected = torch.tensor([0.4246, 0.4241, 0.4014]) | ||
torch.testing.assert_close(actual, expected, atol=1e-4, rtol=1e-4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
from ._model_builders import flux_1_autoencoder | ||
|
||
__all__ = [ | ||
"flux_1_autoencoder", | ||
] |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These are from the small dummy model, did you run these same tests on the full model with weights against the flux codebase?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, these are the results from the full test against the flux codebase: