add a new image_size parameter in train_dalle and generate #310

Draft · wants to merge 1 commit into base: main
8 changes: 4 additions & 4 deletions dalle_pytorch/vae.py
@@ -96,14 +96,14 @@ def download(url, filename = None, root = CACHE_PATH):
 # pretrained Discrete VAE from OpenAI

 class OpenAIDiscreteVAE(nn.Module):
-    def __init__(self):
+    def __init__(self, image_size=256):
         super().__init__()

         self.enc = load_model(download(OPENAI_VAE_ENCODER_PATH))
         self.dec = load_model(download(OPENAI_VAE_DECODER_PATH))

         self.num_layers = 3
-        self.image_size = 256
+        self.image_size = image_size
         self.num_tokens = 8192

     @torch.no_grad()
@@ -142,7 +142,7 @@ def instantiate_from_config(config):
     return get_obj_from_str(config["target"])(**config.get("params", dict()))

 class VQGanVAE(nn.Module):
-    def __init__(self, vqgan_model_path=None, vqgan_config_path=None):
+    def __init__(self, image_size=256, vqgan_model_path=None, vqgan_config_path=None):
         super().__init__()

         if vqgan_model_path is None:
@@ -170,7 +170,7 @@ def __init__(self, vqgan_model_path=None, vqgan_config_path=None):
         # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
         f = config.model.params.ddconfig.resolution / config.model.params.ddconfig.attn_resolutions[0]
         self.num_layers = int(log(f)/log(2))
-        self.image_size = 256
+        self.image_size = image_size
         self.num_tokens = config.model.params.n_embed
         self.is_gumbel = isinstance(self.model, GumbelVQ)

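For orientation, here is a brief usage sketch of the patched constructor (an editorial illustration, not part of the diff; the 128×128 size and the shape check are assumptions based on the class's `num_layers = 3`):

```python
import torch
from dalle_pytorch.vae import OpenAIDiscreteVAE

# Hypothetical use of the patched constructor: image_size is now an
# argument instead of the hard-coded 256.
vae = OpenAIDiscreteVAE(image_size=128)

images = torch.rand(4, 3, vae.image_size, vae.image_size)  # pixels in [0, 1]
codes = vae.get_codebook_indices(images)

# num_layers = 3, so the token grid is image_size // 2**3 per side;
# 128x128 inputs should yield (128 // 8)**2 = 256 codes per image.
print(codes.shape)  # expected: torch.Size([4, 256])
```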
13 changes: 8 additions & 5 deletions generate.py
@@ -43,6 +43,9 @@
 parser.add_argument('--top_k', type = float, default = 0.9, required = False,
                     help='top k filter threshold')

+parser.add_argument('--image_size', type = int, default = 256, required = False,
+                    help='image size')
+
 parser.add_argument('--outputs_dir', type = str, default = './outputs', required = False,
                     help='output directory')

@@ -81,12 +84,14 @@ def exists(val):

 dalle_params.pop('vae', None) # cleanup later

+IMAGE_SIZE = args.image_size
+
 if vae_params is not None:
-    vae = DiscreteVAE(**vae_params)
+    vae = DiscreteVAE(**{**vae_params, 'image_size': IMAGE_SIZE})  # vae_params is a dict, so override the key rather than slice
 elif not args.taming:
-    vae = OpenAIDiscreteVAE()
+    vae = OpenAIDiscreteVAE(IMAGE_SIZE)
 else:
-    vae = VQGanVAE(args.vqgan_model_path, args.vqgan_config_path)
+    vae = VQGanVAE(IMAGE_SIZE, args.vqgan_model_path, args.vqgan_config_path)


 dalle = DALLE(vae = vae, **dalle_params).cuda()
@@ -95,8 +100,6 @@ def exists(val):

 # generate images

-image_size = vae.image_size
-
 texts = args.text.split('|')

 for text in tqdm(texts):
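One caveat worth flagging here (an editorial sketch, not code from the PR): a trained DALL-E's image sequence length is determined by the image size and the VAE's `num_layers`, so a saved checkpoint only loads cleanly when `--image_size` reproduces the token count the model was trained with. A hypothetical helper makes the arithmetic concrete:

```python
# Editorial sketch: DALLE's image sequence length is fixed when the model
# is built, so --image_size must reproduce the token count the loaded
# checkpoint was trained with.
def image_seq_len(image_size: int, num_layers: int) -> int:
    fmap_size = image_size // (2 ** num_layers)  # each VAE layer halves the grid
    return fmap_size * fmap_size

assert image_seq_len(256, 3) == 1024  # the previous hard-coded default
assert image_seq_len(128, 3) == 256   # smaller images -> shorter sequences
```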
18 changes: 9 additions & 9 deletions train_dalle.py
@@ -128,6 +128,8 @@

 model_group.add_argument('--loss_img_weight', default = 7, type = int, help = 'Image loss weight')

+model_group.add_argument('--image_size', default = 256, type = int, help = 'Image size')
+
 model_group.add_argument('--attn_types', default = 'full', type = str, help = 'comma separated list of attention types. attention type can be: full or sparse or axial_row or axial_col or conv_like.')

 args = parser.parse_args()
@@ -173,6 +175,7 @@ def cp_path_to_dir(cp_path, tag):
 SAVE_EVERY_N_STEPS = args.save_every_n_steps
 KEEP_N_CHECKPOINTS = args.keep_n_checkpoints

+IMAGE_SIZE = args.image_size
 MODEL_DIM = args.dim
 TEXT_SEQ_LEN = args.text_seq_len
 DEPTH = args.depth
@@ -242,17 +245,16 @@ def cp_path_to_dir(cp_path, tag):
     scheduler_state = loaded_obj.get('scheduler_state')

     if vae_params is not None:
-        vae = DiscreteVAE(**vae_params)
+        vae = DiscreteVAE(**{**vae_params, 'image_size': IMAGE_SIZE})  # vae_params is a dict, so override the key rather than slice
     else:
         if args.taming:
-            vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
+            vae = VQGanVAE(IMAGE_SIZE, VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
         else:
-            vae = OpenAIDiscreteVAE()
+            vae = OpenAIDiscreteVAE(IMAGE_SIZE)

     dalle_params = dict(
         **dalle_params
     )
-    IMAGE_SIZE = vae.image_size
     resume_epoch = loaded_obj.get('epoch', 0)
 else:
     if exists(VAE_PATH):
@@ -268,19 +270,17 @@ def cp_path_to_dir(cp_path, tag):

         vae_params, weights = loaded_obj['hparams'], loaded_obj['weights']

-        vae = DiscreteVAE(**vae_params)
+        vae = DiscreteVAE(**{**vae_params, 'image_size': IMAGE_SIZE})  # vae_params is a dict, so override the key rather than slice
         vae.load_state_dict(weights)
     else:
         if distr_backend.is_root_worker():
             print('using pretrained VAE for encoding images to tokens')
         vae_params = None

         if args.taming:
-            vae = VQGanVAE(VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
+            vae = VQGanVAE(IMAGE_SIZE, VQGAN_MODEL_PATH, VQGAN_CONFIG_PATH)
         else:
-            vae = OpenAIDiscreteVAE()
-
-    IMAGE_SIZE = vae.image_size
+            vae = OpenAIDiscreteVAE(IMAGE_SIZE)

     dalle_params = dict(
         num_text_tokens=tokenizer.vocab_size,
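Since `vae_params` is a dict of hyperparameters saved with the checkpoint, the `DiscreteVAE(...)` hunks above override its `image_size` entry with a dict merge rather than positional arguments (a dict cannot be sliced like a list). A minimal sketch of the pattern, with assumed values:

```python
# Sketch of the dict-merge override used above; the stored values are
# illustrative, not taken from a real checkpoint.
vae_params = {'image_size': 256, 'num_layers': 3, 'num_tokens': 8192}

IMAGE_SIZE = 128
merged = {**vae_params, 'image_size': IMAGE_SIZE}  # later keys win

print(merged)  # {'image_size': 128, 'num_layers': 3, 'num_tokens': 8192}
```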