Commit 15e2706: final working version with readme updated

samueleruffino99 committed Aug 2, 2024
1 parent c5c3ef9 commit 15e2706
Showing 146 changed files with 21,332 additions and 10,372 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -1,4 +1,6 @@
/outputs
__pycache__
/wandb
/checkpoints
/checkpoints
/xxx
/Output_i2vgenxl_train_segmentation_nuscenes_defaultcontrolnet_360x640_8_frames
20 changes: 20 additions & 0 deletions README.md
@@ -183,6 +183,26 @@ We list the inference scripts for different tasks mentioned in our paper as follows



## Video Generation with Segmentation Control (NuScenes)

<br>
<img width="800" src="assets/nuscenes.gif"/>
<br>

We currently implement segmentation control on **I2VGen-XL**. Inference runs from a local pretrained checkpoint (no Hugging Face download).

Note that in these bash files you can set the segmentation type ("ade" or "odise") to match the ControlNet you are using; the Prescan segmentation will then be converted to the corresponding format.

NB: for inference without an extracted condition, only segmentation_type "ade" is supported, since the default segmentor extracts the maps in ADE format.
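
Under the hood, this conversion is a per-pixel color remap from the Prescan palette to the palette the chosen ControlNet expects. A minimal sketch of the idea (the palette entries below are hypothetical placeholders, not the repository's actual `map_rgb` tables):

```python
import numpy as np
from PIL import Image

# Hypothetical palettes: Prescan RGB -> ADE-style RGB.
# The real tables live in the repository's map_rgb implementation.
PRESCAN_TO_ADE = {
    (128, 64, 128): (140, 140, 140),  # road
    (70, 70, 70): (180, 120, 120),    # building
    (0, 0, 142): (0, 102, 200),       # car
}

def map_rgb_sketch(img: Image.Image, mapping: dict) -> Image.Image:
    """Remap every known source color to its target palette color."""
    arr = np.array(img.convert("RGB"))
    out = arr.copy()
    for src, dst in mapping.items():
        mask = np.all(arr == np.array(src), axis=-1)  # pixels matching this source color
        out[mask] = dst
    return Image.fromarray(out)
```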

| Inference (w/ extracted condition) | Inference (w/o extracted condition) |
|---------|--------|
| [command](/inference_scripts/i2vgenxl/i2vgenxl_inference_segmentation_nuscenes.sh) | [command](/inference_scripts/i2vgenxl/i2vgenxl_inference_extract_segmentation_from_raw_frames_nuscenes.sh) |

NB: If you want to add more data to test on, place raw frames and segmentations into ```assets/evaluation/frames/raw_input``` and ```assets/evaluation/frames/segmentation```, and add an entry with the corresponding folder name to the captions file passed via `--evaluation_prompt_file` (see the existing examples for reference).
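
For example, a new test sample in a folder named `my_scene` (a hypothetical name) would be laid out like this:

```
assets/evaluation/frames/
├── raw_input/
│   └── my_scene/
│       ├── frame_0000.png
│       └── ...
├── segmentation/
│   └── my_scene/
│       ├── frame_0000.png
│       └── ...
└── captions_segm_nuscenes.json   # add a "my_scene" entry with its caption
```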



## Video Generation with Multi-Condition Control

<br>
3 changes: 0 additions & 3 deletions assets/evaluation/frames/captions_AD_real.json

This file was deleted.

3 changes: 0 additions & 3 deletions assets/evaluation/frames/captions_AD_simul.json

This file was deleted.

3 changes: 3 additions & 0 deletions assets/evaluation/frames/captions_segm_nuscenes.json
@@ -0,0 +1,3 @@
{
"nuscenes_scene_0061_prescan": "A realistic driving scene."
}
Binary files not shown (25 files).
Diffs not rendered (56 files).
Binary file added assets/nuscenes.gif
45 changes: 45 additions & 0 deletions assets/train_guideline.md
@@ -2,6 +2,51 @@

# 🚅 How To Train

## AD Training
### Step 1: Download Training Data
Download the nuScenes dataset and place it in ```/path/to/nuscenes```.
### Step 2: Prepare Data in Specified Format
#### Generate per-scene resized (and optionally FOV-adjusted) frames
Run this command to save resized (and optionally adjusted) CAM_FRONT frames for every scene.
```bash
python -m data.nuscenes.generate_scene_frames_folders --dataroot /path/to/nuscenes --version v1.0-trainval --cam-type CAM_FRONT --output-path /path/to/train/data --scale-factor 0.4 --save-adjusted-fov --fov-from 120 --fov-to 94
# Example
python -m data.nuscenes.generate_scene_frames_folders --dataroot /mnt/d/AD/datasets/nuscenes --version v1.0-trainval --cam-type CAM_FRONT --output-path /mnt/d/z004x7dn/datasets/nuscenes/scenes_frames --scale-factor 0.4 --save-adjusted-fov --fov-from 120 --fov-to 94
```
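For intuition, narrowing a pinhole camera's horizontal FOV from 120° to 94° keeps the central tan(47°)/tan(60°) ≈ 0.62 fraction of the image width. A rough sketch of the corresponding center crop (assuming a simple pinhole model and an equal crop factor vertically; see `data/nuscenes/generate_scene_frames_folders.py` for the actual implementation):

```python
import math
from PIL import Image

def crop_to_fov(img: Image.Image, fov_from_deg: float = 120.0,
                fov_to_deg: float = 94.0) -> Image.Image:
    """Center-crop an image so its horizontal FOV shrinks from fov_from to fov_to."""
    w, h = img.size
    # Pinhole model: image half-width scales with tan(fov / 2) at a fixed focal length.
    ratio = math.tan(math.radians(fov_to_deg) / 2) / math.tan(math.radians(fov_from_deg) / 2)
    new_w, new_h = int(round(w * ratio)), int(round(h * ratio))  # same factor vertically (assumption)
    left, top = (w - new_w) // 2, (h - new_h) // 2
    return img.crop((left, top, left + new_w, top + new_h))
```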
#### Generate training segments and captions
Run this command to generate training segments (.mp4) and captions (a CSV file in ```./sample_data```):
```bash
python -m data.nuscenes.data_preparation --dataroot /path/to/nuscenes --version v1.0-trainval --cam-type CAM_FRONT --json-path /path/to/nuscenes/predictions/mllm/results_nusc_mllm.json --input-path /path/to/nuscenes/scenes_frames --output-path /path/to/nuscenes/scenes_videos_segments --use-adjusted-fov --generate-segments --augment-captions --csv-filename video_captions_nuscenes.csv --segment-length 16
# Example
python -m data.nuscenes.data_preparation --dataroot /mnt/d/AD/datasets/nuscenes --version v1.0-trainval --cam-type CAM_FRONT --json-path /mnt/d/z004x7dn/datasets/nuscenes/predictions/mllm/results_nusc_mllm.json --input-path /mnt/d/z004x7dn/datasets/nuscenes/scenes_frames --output-path /mnt/d/z004x7dn/datasets/nuscenes/scenes_videos_segments --use-adjusted-fov --generate-segments --augment-captions --csv-filename video_captions_nuscenes.csv --segment-length 16
```
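Conceptually, this step slices each scene's frame sequence into fixed-length clips and writes them out as .mp4. A simplified sketch (assuming `imageio` with its ffmpeg plugin is installed; the function below is illustrative, not the repository's `process_videos`):

```python
import os
import imageio.v2 as imageio

def write_segments(frame_dir: str, out_dir: str, segment_length: int = 16, fps: int = 12):
    """Split a scene's frames into non-overlapping segment_length-frame .mp4 clips."""
    frames = sorted(f for f in os.listdir(frame_dir) if f.endswith(".png"))
    os.makedirs(out_dir, exist_ok=True)
    for i in range(0, len(frames) - segment_length + 1, segment_length):
        out_path = os.path.join(out_dir, f"segment_{i // segment_length:04d}.mp4")
        with imageio.get_writer(out_path, fps=fps) as writer:  # requires imageio-ffmpeg
            for name in frames[i:i + segment_length]:
                writer.append_data(imageio.imread(os.path.join(frame_dir, name)))
```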

### Step 3: Run Training
Here is the command we used to start training I2VGen-XL with segmentation maps as the control condition on nuScenes driving scenes. The training scripts for I2VGen-XL and SVD are roughly the same.

```bash
sh train_scripts/i2vgenxl/i2vgenxl_train_segmentation_nuscenes.sh
```

Specifically, in the training scripts:

`--yaml_file`: The configuration file for all hyper-parameters related to **training**.

The rest of the hyper-parameters in the training script are for **evaluation**, which can help you monitor the training process better.

`--save_n_steps`: Save the trained adapter checkpoints every n training steps.

`--save_starting_step`: Start saving the trained adapter checkpoints only after this many training steps.

`--validate_every_steps`: Perform evaluation every x training steps. The evaluation data are placed under ```./assets/evaluation```. If you prefer to evaluate different samples, you can replace them by following the same file structure.

`--num_inference_steps`: The number of denoising steps during inference. Setting it to the backbone model's default number of inference steps works fine.

```--extract_control_conditions```: If you already have condition images/frames extracted from the evaluation image/video (see the Inference Data Structure section above), set this to ```False```. If you only have the raw image/frames, set it to ```True``` and the code will automatically extract the control conditions from them. The default is ```False```.

```--control_guidance_end```: As mentioned above, this is the most important parameter for balancing generated image/video quality against control strength. Since the goal here is simply to verify that the training code works, we recommend setting it to 1.0 so that control is applied across all inference steps; you can lower it once you have a trained model.
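
Putting this together, the launcher script boils down to something like the sketch below; the entry-point name and the numeric values are placeholders, so check `train_scripts/i2vgenxl/i2vgenxl_train_segmentation_nuscenes.sh` for the real ones.

```bash
# Hypothetical sketch of the training launcher; flag names follow the text above.
python train.py \
  --yaml_file "configs/i2vgenxl_train_segmentation_nuscenes.yaml" \
  --save_n_steps 10000 \
  --save_starting_step 10000 \
  --validate_every_steps 5000 \
  --num_inference_steps 50 \
  --extract_control_conditions false \
  --control_guidance_end 1.0
```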

## Ctrl-Adapter
### Step 1: Download Training Data

- For Ctrl-Adapter training on image backbones (e.g., SDXL), we use 300k images from the [LAION POP](https://laion.ai/blog/laion-pop/) dataset. You can download a subset from this dataset [here](https://huggingface.co/datasets/Ejafa/ye-pop).
4 changes: 3 additions & 1 deletion configs/i2vgenxl_train_segmentation_nuscenes.yaml
@@ -4,14 +4,15 @@
DATA_PATH: xxx ##### ACTION NEEDED: set this path before training #####

# the folder where training data is stored
train_data_path: /mnt/d/nuscenes/scenes_videos_segments ##### ACTION NEEDED: path for the training image folder #####
train_data_path: /mnt/d/z004x7dn/datasets/nuscenes/scenes_videos_segments ##### ACTION NEEDED: path for the training image folder #####

# the csv file for training prompts
train_prompt_path: sample_data/video_captions_nuscenes.csv ##### ACTION NEEDED: path for the training prompt csv file #####

# batch size per gpu. we use batch size of 1 by default.
# you can adjust it based on your GPU memory available
train_batch_size: 1
num_workers: 1

# this parameter indicates whether we resize the input image to 512 * 512 before passing it to the SDv1.5 ControlNet
# adding support for different resolutions is left for future work
@@ -22,6 +23,7 @@ height: 360 #144
width: 640 #256

# for video generation models (e.g., I2VGen-XL), need to set both of the following as the default settings in the backbone model
n_ref_frames: 1
n_sample_frames: 8 #16
output_fps: 12 # 12 Hz from Nuscenes dataset

12 changes: 5 additions & 7 deletions data/nuscenes/data_preparation.py
@@ -121,11 +121,11 @@ def process_videos(input_dir, output_dir, segment_length=16):

def main():
parser = argparse.ArgumentParser(description='Prepare Nuscenes data for training the model.')
parser.add_argument('--dataroot', type=str, default='/mnt/d/nuscenes', help='Path to Nuscenes dataset.')
parser.add_argument('--dataroot', type=str, default='/mnt/d/AD/datasets/nuscenes', help='Path to Nuscenes dataset.')
parser.add_argument('--version', type=str, default='v1.0-trainval', help='Nuscenes dataset version.')
parser.add_argument('--json-filename', default='results_nusc_mllm.json', type=str, help='Filename of the json file containing the annotation from mllm LLaVA.')
parser.add_argument('--input-dir', default='scenes_frames', type=str, help='Folder containing the frames per each scene.')
parser.add_argument('--output-dir', default='scenes_videos_segments', type=str, help='folder to save the video segments.')
parser.add_argument('--json-path', default='/mnt/d/z004x7dn/datasets/nuscenes/predictions/mllm/results_nusc_mllm.json', type=str, help='Path to the JSON file containing the annotations from the LLaVA MLLM.')
parser.add_argument('--input-path', default='/mnt/d/z004x7dn/datasets/nuscenes/scenes_frames', type=str, help='Folder containing the frames for each scene.')
parser.add_argument('--output-path', default='/mnt/d/z004x7dn/datasets/nuscenes/scenes_videos_segments', type=str, help='Folder to save the video segments.')
parser.add_argument('--cam-type', type=str, default='CAM_FRONT', help='Camera type to extract images.')
parser.add_argument('--use-adjusted-fov', default=True, action='store_true', help='Use adjusted field of view.')
parser.add_argument('--generate-segments', default=False, action='store_true', help='Generate video segments.')
@@ -136,9 +136,7 @@ def main():
parser.add_argument('--debugpy', action='store_true', help='Enable debugpy for remote debugging')
args = parser.parse_args()
cam_type_foldername = args.cam_type + "_adj_fov" if args.use_adjusted_fov else args.cam_type
args.input_path = os.path.join(args.dataroot, 'scenes_frames', cam_type_foldername)
args.output_path = os.path.join(args.dataroot, args.output_dir)
args.json_path = os.path.join(args.dataroot, args.version, 'predictions', 'mllm', args.json_filename)
args.input_path = os.path.join(args.input_path, cam_type_foldername)
args.csv_path = os.path.join('sample_data', args.csv_filename)

if args.debugpy:
5 changes: 2 additions & 3 deletions data/nuscenes/generate_scene_frames_folders.py
@@ -71,10 +71,10 @@ def process_nuscenes_data(nusc, data_dir, scale_factor=0.5, cam_type="CAM_FRONT"
# Main function to setup nuScenes SDK and path
def main():
parser = argparse.ArgumentParser(description='Generate front view images from nuScenes dataset')
parser.add_argument('--dataroot', type=str, default="/mnt/d/nuscenes", help='Path to nuScenes dataset')
parser.add_argument('--dataroot', type=str, default="/mnt/d/AD/datasets/nuscenes", help='Path to nuScenes dataset')
parser.add_argument('--version', type=str, default="v1.0-trainval", help='nuScenes dataset version')
parser.add_argument('--cam-type', type=str, default="CAM_FRONT", help='Camera type to extract images')
parser.add_argument('--output-dir', type=str, default="scenes_frames", help='Path to output directory')
parser.add_argument('--output-path', type=str, default="/mnt/d/z004x7dn/datasets/nuscenes/scenes_frames", help='Path to output directory')
parser.add_argument('--scale-factor', type=float, default=0.4, help='Scale factor for resizing images; 0.4 corresponds to dividing each dimension by 2.5.')
parser.add_argument('--save-adjusted-fov', action='store_true', help='Save images with adjusted field of view')
parser.add_argument('--fov-from', type=float, default=120, help='Initial field of view in degrees')
@@ -88,7 +88,6 @@ def main():
print("Waiting for debugger attach")
debugpy.wait_for_client()

args.output_path = os.path.join(args.dataroot, args.output_dir)
nusc = NuScenes(version=args.version, dataroot=args.dataroot, verbose=True)
process_nuscenes_data(nusc, args.output_path, args.scale_factor, args.cam_type, args.save_adjusted_fov, args.fov_from, args.fov_to)

@@ -770,7 +770,7 @@ def __call__(
)

height, width = images.shape[-2:]
print("this is the height and the with", height, width)
# print("this is the height and the with", height, width)
elif isinstance(controlnet, MultiControlNetModel):

images = []
33 changes: 24 additions & 9 deletions inference.py
@@ -367,12 +367,12 @@ def inference_main(inference_args):

#SAM
# change segmentation model based on the segmentation_type used
# if inference_args.segmentation_type == 'odise':
# model_paths['segmentation'] = "JaspervanLeuven/controlnet_rect"
# elif inference_args.segmentation_type == 'ade':
# model_paths['segmentation'] = "lllyasviel/control_v11p_sd15_seg"
# else:
# model_paths['segmentation'] = "lllyasviel/control_v11p_sd15_seg"
if inference_args.segmentation_type == 'odise':
model_paths['segmentation'] = "JaspervanLeuven/controlnet_rect"
elif inference_args.segmentation_type == 'ade':
model_paths['segmentation'] = "lllyasviel/control_v11p_sd15_seg"
else:
model_paths['segmentation'] = "lllyasviel/control_v11p_sd15_seg"

for control_type, model_path in model_paths.items():
if (len(inference_args.control_types) == 1 and control_type in inference_args.control_types) or (len(inference_args.control_types) > 1): # single-condition control
@@ -448,8 +448,11 @@ def inference_main(inference_args):
# print images_pil[0].size and type
images_pil = [center_crop_and_resize(img, output_size=(inference_args.width, inference_args.height)) for img in images_pil]
images_pil = images_pil[:inference_args.n_sample_frames]


# convert to RGB if needed
if images_pil[0].mode != 'RGB':
images_pil = [img.convert('RGB') for img in images_pil]

# load or extract condition images
if inference_args.extract_control_conditions:
all_conditioning_images_pil = []
@@ -468,15 +471,26 @@
for cond_dir in condition_input_dir:
condition_images_path = os.path.join(cond_dir, sample)
#SAM prescan segmentation mapping
allowed_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
if 'segmentation' in cond_dir and inference_args.segmentation_type is not None:
if os.path.isdir(condition_images_path):
condition_frames = sorted(os.listdir(condition_images_path))[:inference_args.n_sample_frames]
# if the segmentation comes from Prescan (e.g., ImageSegmentationSensor_1_00001.png), sort frames numerically and take every other one
if "ImageSegmentationSensor" in os.listdir(condition_images_path)[0]:
condition_frames = sorted([f for f in os.listdir(condition_images_path) if f.lower().endswith(allowed_extensions)], key=lambda x: int(x.split('_')[-1].split('.')[0]))[::2]
condition_frames = condition_frames[:inference_args.n_sample_frames]
else:
condition_frames = sorted([f for f in os.listdir(condition_images_path) if f.lower().endswith(allowed_extensions)])[:inference_args.n_sample_frames]
conditioning_images_pil = [map_rgb(Image.open(os.path.join(condition_images_path, frame)), mapping_type=inference_args.segmentation_type) for frame in condition_frames]
else:
conditioning_images_pil = [map_rgb(Image.open(condition_images_path), mapping_type=inference_args.segmentation_type)]
else:
if os.path.isdir(condition_images_path):
condition_frames = sorted(os.listdir(condition_images_path))[:inference_args.n_sample_frames]
# if the segmentation comes from Prescan (e.g., ImageSegmentationSensor_1_00001.png), sort frames numerically and take every other one
if "ImageSegmentationSensor" in os.listdir(condition_images_path)[0]:
condition_frames = sorted([f for f in os.listdir(condition_images_path) if f.lower().endswith(allowed_extensions)], key=lambda x: int(x.split('_')[-1].split('.')[0]))[::2]
condition_frames = condition_frames[:inference_args.n_sample_frames]
else:
condition_frames = sorted([f for f in os.listdir(condition_images_path) if f.lower().endswith(allowed_extensions)])[:inference_args.n_sample_frames]
conditioning_images_pil = [Image.open(os.path.join(condition_images_path, frame)) for frame in condition_frames]
else:
conditioning_images_pil = [Image.open(condition_images_path)]
@@ -511,6 +525,7 @@ def inference_main(inference_args):
if inference_args.model_name == 'i2vgenxl':
num_frames = inference_args.n_sample_frames if 'n_sample_frames' in inference_args else 16 # default
target_fps = inference_args.output_fps if 'output_fps' in inference_args else 16 # default
print(num_frames, target_fps)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
i2vgenxl_outputs = pipe(
prompt=prompt,
@@ -2,17 +2,18 @@ python inference.py \
--model_name "i2vgenxl" \
--control_types "segmentation" \
--segmentation_type "odise" \
--local_checkpoint_path "checkpoints/adapter_jasper" \
--local_checkpoint_path "checkpoints/adapter_odise" \
--eval_input_type "frames" \
--evaluation_input_folder "assets/evaluation/frames" \
--global_step 70000 \
--n_sample_frames 8 \
--n_sample_frames 16 \
--output_fps 12 \
--n_ref_frames 1 \
--num_inference_steps 50 \
--control_guidance_end 0.8 \
--use_size_512 false \
--height 360 \
--width 640 \
--evaluation_prompt_file "captions_AD_simul.json"
--evaluation_prompt_file "captions_segm_nuscenes.json"


@@ -0,0 +1,20 @@
python inference.py \
--model_name "i2vgenxl" \
--control_types "segmentation" \
--segmentation_type "ade" \
--local_checkpoint_path "checkpoints/adapter_ade_360x640" \
--eval_input_type "frames" \
--evaluation_input_folder "assets/evaluation/frames" \
--global_step 100000 \
--n_sample_frames 16 \
--extract_control_conditions True \
--output_fps 12 \
--n_ref_frames 1 \
--num_inference_steps 50 \
--control_guidance_end 0.8 \
--use_size_512 false \
--height 360 \
--width 640 \
--evaluation_prompt_file "captions_segm_nuscenes.json"


@@ -2,15 +2,18 @@ python inference.py \
--model_name "i2vgenxl" \
--control_types "segmentation" \
--segmentation_type "odise" \
--local_checkpoint_path "checkpoints/adapter_jasper" \
--local_checkpoint_path "checkpoints/adapter_ade_360x640" \
--eval_input_type "frames" \
--evaluation_input_folder "assets/evaluation/frames" \
--global_step 70000 \
--global_step 100000 \
--n_sample_frames 16 \
--output_fps 12 \
--n_ref_frames 1 \
--num_inference_steps 50 \
--control_guidance_end 0.8 \
--use_size_512 false \
--height 160 \
--width 256 \
--evaluation_prompt_file "captions_AD_real.json"
--height 360 \
--width 640 \
--evaluation_prompt_file "captions_segm_nuscenes.json"


This file was deleted.
