Artistic-style-transfer

Unofficial Pytorch implementation of Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization [Huang+, ICCV2017]

Original Lua implementation from the author can be found here.

This implementation uses NVIDIA DALI and AMP to accelerate the training process, with WandB used for monitoring.
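At the core of the method is the AdaIN layer, which re-normalizes each channel of the content features to carry the per-channel mean and standard deviation of the style features: AdaIN(x, y) = σ(y)·(x − μ(x))/σ(x) + μ(y). A minimal pure-Python sketch of a single channel (the real implementation operates on 4-D PyTorch feature tensors):

```python
from statistics import mean, pstdev

def adain(content, style, eps=1e-5):
    """Adaptive Instance Normalization for one feature channel.

    Normalizes `content` (a flat list of activations) to zero mean and
    unit std, then rescales/shifts it to the style channel's statistics.
    """
    mu_c, sigma_c = mean(content), pstdev(content)
    mu_s, sigma_s = mean(style), pstdev(style)
    return [sigma_s * (x - mu_c) / (sigma_c + eps) + mu_s for x in content]

content = [1.0, 2.0, 3.0, 4.0]
style = [10.0, 20.0, 30.0, 40.0]
out = adain(content, style)
# `out` keeps the content's relative structure but now has the
# style channel's mean (25.0) and standard deviation.
```

The decoder is then trained to invert these re-normalized features back into an image, which is what makes arbitrary (per-pair) style transfer possible in a single forward pass.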

Prerequisites

  1. Clone this repository

    git clone https://github.com/jackdaw213/Artistic-style-transfer
    cd Artistic-style-transfer
  2. Install Conda and create an environment

    conda create -n artistic_style_transfer python=3.12
  3. Install all dependencies from requirements.txt

    conda activate artistic_style_transfer
    pip install nvidia-pyindex
    pip install -r requirements.txt

This should prepare the Conda environment for both training and testing (a pretrained model is available below).

Train

  1. Download the COCO dataset for content images and the Wikiart dataset for style images. Extract the files and organize them into the 'data' folder, with subfolders 'train_content', 'val_content', 'train_style', and 'val_style'.
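The expected layout can be created up front. This stdlib-only sketch (folder names taken from step 1; the helper itself is not part of the repository) builds the empty directory tree:

```python
from pathlib import Path

# Sub-folder names expected under data/ (see step 1).
SUBDIRS = ["train_content", "val_content", "train_style", "val_style"]

def make_data_tree(root="data"):
    """Create the data/ folder with the four expected sub-folders."""
    root = Path(root)
    for name in SUBDIRS:
        (root / name).mkdir(parents=True, exist_ok=True)
    return sorted(p.name for p in root.iterdir() if p.is_dir())
```

COCO images then go under `train_content`/`val_content`, and WikiArt images under `train_style`/`val_style`.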

  2. Preprocess the dataset

    The WikiArt dataset contains corrupted JPEG images (the file ends prematurely) and images with pixel counts of up to 105× that of a 4K image. This step removes most of the corrupted images and downscales any image with more than 3840 × 2160 pixels.

    python preprocess.py
    preprocess.py [-h]
                  [--train_style TRAIN_STYLE_FOLDER]
                  [--val_style VAL_STYLE_FOLDER]
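A JPEG that "ends prematurely" is missing its end-of-image marker, so a cheap first-pass filter needs only the stdlib. This is a sketch of the idea, not the repository's actual preprocessing code (which may decode images with PIL instead); the helper names are hypothetical:

```python
JPEG_SOI = b"\xff\xd8"  # JPEG start-of-image marker
JPEG_EOI = b"\xff\xd9"  # JPEG end-of-image marker

def looks_truncated(data: bytes) -> bool:
    """Heuristic: a JPEG byte stream that does not end with the EOI
    marker was most likely cut off mid-write."""
    return data.startswith(JPEG_SOI) and not data.rstrip(b"\0").endswith(JPEG_EOI)

def needs_resize(width: int, height: int, max_pixels: int = 3840 * 2160) -> bool:
    """True if the image exceeds the 4K pixel-count cap used in this step."""
    return width * height > max_pixels
```

Files flagged by `looks_truncated` would be deleted, and images flagged by `needs_resize` downscaled before training.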
    
  3. Train the model.

    python train.py --enable_dali --enable_amp --enable_wandb
    train.py [-h]
             [--epochs EPOCHS]
             [--batch_size BATCH_SIZE]
             [--num_workers NUM_WORKERS]
             [--train_dir_content TRAIN_DIR_CONTENT]
             [--val_dir_content VAL_DIR_CONTENT]
             [--train_dir_style TRAIN_DIR_STYLE]
             [--val_dir_style VAL_DIR_STYLE]
             [--optimizer OPTIMIZER]
             [--learning_rate LEARNING_RATE]
             [--momentum MOMENTUM]
             [--resume_id RESUME_ID]
             [--checkpoint_freq CHECKPOINT_FREQ]
             [--amp_dtype AMP_DTYPE]
             [--enable_dali]
             [--enable_amp]
             [--enable_wandb]
    

    The model was trained on an RTX 3080 10 GB for 10 epochs.

    | Training setup   | Batch size | GPU memory usage | Training time |
    |------------------|------------|------------------|---------------|
    | DALI             | 4          | 6 GB             | 3.8 hours     |
    | DALI + AMP       | 8          | 6.5 GB           | 2.2 hours     |
    | DataLoader       | 8          | 9 GB             | 4.4 hours     |
    | DataLoader + AMP | 8          | 4 GB             | 2.4 hours     |
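The training objective from the AdaIN paper is L = Lc + λ·Ls: a content loss against the AdaIN output target, plus a style loss that matches the per-channel mean and standard deviation of VGG features (rather than Gram matrices). Sketched per channel on flat lists, with λ = 10 as in the paper (the real code operates on 4-D feature tensors over several VGG layers):

```python
from statistics import mean, pstdev

def mse(a, b):
    """Mean squared error between two equal-length feature lists."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def style_loss(feat_out, feat_style):
    """Match first- and second-order channel statistics (mean/std),
    as in the AdaIN paper, instead of Gram matrices."""
    return ((mean(feat_out) - mean(feat_style)) ** 2
            + (pstdev(feat_out) - pstdev(feat_style)) ** 2)

def total_loss(feat_out, adain_target, feat_style, lam=10.0):
    """Content loss on the AdaIN target plus weighted style loss."""
    return mse(feat_out, adain_target) + lam * style_loss(feat_out, feat_style)
```

When the decoder output's features exactly equal the AdaIN target and already carry the style statistics, the loss is zero; any statistical mismatch with the style features is penalized λ times more heavily.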

    WARNING: NVIDIA DALI only supports NVIDIA GPUs. BFloat16 is supported only on Ampere (RTX 30-series) GPUs and newer, and GPU Direct Storage (GDS) only on server-class GPUs. Training with Float16 can produce NaN losses, whereas BFloat16 does not.
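The Float16 NaN issue is a range problem, not a precision problem: float16 overflows past 65504, so large activations or gradients become inf and then NaN, while bfloat16 keeps float32's 8-bit exponent (range up to roughly 3.4e38) at the cost of mantissa precision. The stdlib can demonstrate the float16 limit via `struct`'s half-precision format:

```python
import struct

def fits_in_float16(x: float) -> bool:
    """True if x can be packed as IEEE binary16 without overflowing."""
    try:
        struct.pack("e", x)  # 'e' = IEEE 754 half precision
        return True
    except OverflowError:
        return False

# 65504 is the largest finite float16 value. Un-scaled gradients can
# easily exceed it, which is why float16 AMP relies on loss scaling,
# while bfloat16 (float32 exponent range) does not overflow here.
```

This is why the table above pairs AMP with either loss scaling (Float16) or BFloat16 on supported hardware.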

Test

  1. Download the pretrained model here and put it in the model folder

  2. Generate the output image using the command below.

    python test.py -c content_image_path -s style_image_path
    test.py [-h] 
            [--content CONTENT] 
            [--style STYLE]
            [--model MODEL_PATH] 
    

Result

Sample stylized outputs from the pretrained model (content/style pairs and their results).

References

Huang, X. and Belongie, S. "Arbitrary Style Transfer in Real-Time with Adaptive Instance Normalization." ICCV 2017.
