diff --git a/4_generate_labels.py b/4_generate_labels.py index 569bce1..c2e7372 100644 --- a/4_generate_labels.py +++ b/4_generate_labels.py @@ -3,9 +3,13 @@ import fiftyone as fo import fiftyone.zoo as foz - if __name__ == '__main__': dataset = fo.load_dataset("play_photos") + + # Clean up from previous runs + if "labeled_dataset" in fo.list_datasets(): + fo.delete_dataset("labeled_dataset") + clip = foz.load_zoo_model( "clip-vit-base32-torch", text_prompt="A photo of a", @@ -14,14 +18,20 @@ # alexnet = foz.load_zoo_model("alexnet-imagenet-torch") # dense201 = foz.load_zoo_model("densenet201-imagenet-torch") # fasterrcnn = foz.load_zoo_model("faster-rcnn-resnet50-fpn-coco-torch") - # yoloseg = foz.load_zoo_model("yolo11x-seg-coco-torch") + #yoloseg = foz.load_zoo_model("yolo11x-seg-coco-torch") - dataset.apply_model(clip, label_field="default_prediction") - dataset.set_values("ground_truth", fo.Classification(label=str(dataset.values("default_prediction.label")))) + dataset.apply_model(clip, label_field="prediction") # dataset.apply_model(dense201, label_field="dense201") # dataset.apply_model(alexnet, label_field="alexnet") # dataset.apply_model(fasterrcnn, label_field="faster_rcnn") - # dataset.apply_model(yoloseg, label_field="yolo_seg") + #dataset.apply_model(yoloseg, label_field="yolo_seg") + + #Alright time to make our dataset with cleaned labels + labeled_dataset = dataset.clone(name="labeled_dataset", persistent=True) + labeled_dataset.rename_sample_field("prediction", "ground_truth") + labeled_dataset.set_field("ground_truth.detections.confidence", None).save() + + # Now time to go to 5_clean_ground_truth session = fo.launch_app(dataset) session.wait() diff --git a/5_clean_ground_truth.py b/5_clean_ground_truth.py new file mode 100644 index 0000000..003bd14 --- /dev/null +++ b/5_clean_ground_truth.py @@ -0,0 +1,4 @@ +import fiftyone as fo + +if __name__ == '__main__': + fo.load_dataset("labeled_dataset") diff --git a/6_fine_tuning.py b/6_fine_tuning.py 
"""
Ultralytics YOLOv8*-cls model training script
for generating confidence-based noise labels for a dataset.

| Copyright 2017-2024, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|

Requires `ultralytics` and `fiftyone>=0.25.0` to be installed.
"""
import argparse
import os
import re
import tempfile

import wandb
# NOTE(review): StepsValidator is not referenced anywhere in this script; it
# looks like an accidental IDE auto-import. Kept so the import set is unchanged.
from plotly.validators.layout.slider import StepsValidator
import torch
from ultralytics import YOLO
from ultralytics import settings
import fiftyone as fo

DEFAULT_MODEL_SIZE = "s"
DEFAULT_IMAGE_SIZE = 128
DEFAULT_EPOCHS = 10

# Size suffixes accepted in the ``yolov8{size}-cls.pt`` checkpoint name.
VALID_MODEL_SIZES = ("n", "s", "m", "l", "x")

# Oldest FiftyOne release this script accepts, as a comparable tuple.
MIN_FIFTYONE_VERSION = (0, 25, 0)

wandb.require("core")


def _version_tuple(version):
    """Return the leading numeric components of *version* as a tuple of ints.

    Parsing stops at the first component that does not start with a digit, and
    the result is zero-padded to three entries, so ``"0.25.0rc1"`` becomes
    ``(0, 25, 0)`` and ``"1.2"`` becomes ``(1, 2, 0)``.
    """
    numbers = []
    for piece in str(version).split("."):
        digits = re.match(r"\d+", piece)
        if digits is None:
            break
        numbers.append(int(digits.group()))

    numbers.extend([0] * (3 - len(numbers)))
    return tuple(numbers)


def get_torch_device():
    """Return the torch device to train on: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")

    # ``torch.backends.mps`` is absent on older torch builds; treat that the
    # same as "not available" instead of raising AttributeError.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")

    return torch.device("cpu")


def train_classifier(
    dataset_name=None,
    model_size=DEFAULT_MODEL_SIZE,
    image_size=DEFAULT_IMAGE_SIZE,
    epochs=DEFAULT_EPOCHS,
    project_name="mislabel_confidence_noise",
    gt_field="ground_truth",
    train_split=None,
    test_split=None,
    **kwargs
):
    """Train a YOLOv8 classification model on FiftyOne samples.

    The samples come either from the dataset named ``dataset_name`` (its
    ``"train"`` and ``"test"`` tagged samples) or, when no name is given, from
    the explicitly supplied ``train_split`` / ``test_split`` collections.

    Args:
        dataset_name: name of a FiftyOne dataset to load. When falsy, the
            ``train_split`` and ``test_split`` arguments are used instead
        model_size: one of ``n``, ``s``, ``m``, ``l``, ``x``. ``None`` falls
            back to ``DEFAULT_MODEL_SIZE``
        image_size: value passed to the trainer as ``imgsz``
        epochs: number of training epochs
        project_name: value passed to the trainer as ``project``
        gt_field: label field that is exported as the class label
        train_split: collection to train on when ``dataset_name`` is not given
        test_split: collection exported as both the ``val`` and ``test``
            folders when ``dataset_name`` is not given
        **kwargs: accepted for call-site compatibility and ignored

    Returns:
        the trained ``YOLO`` model

    Raises:
        ValueError: if ``model_size`` is not a supported size, or if neither a
            dataset name nor both splits were provided
    """
    if model_size is None:
        model_size = DEFAULT_MODEL_SIZE
    elif model_size not in VALID_MODEL_SIZES:
        raise ValueError("model_size must be one of ['n', 's', 'm', 'l', 'x']")

    # Turn off the W&B entry in the Ultralytics settings for this run
    settings.update({"wandb": False})

    if dataset_name:
        dataset = fo.load_dataset(dataset_name)
        train = dataset.match_tags("train")
        test = dataset.match_tags("test")
    else:
        train = train_split
        test = test_split

    # Fail with a clear message rather than an AttributeError on ``None.export``
    if train is None or test is None:
        raise ValueError(
            "Provide either dataset_name or both train_split and test_split"
        )

    # The held-out split is exported twice, once as "val" and once as "test"
    splits_dict = {
        "train": train,
        "val": test,
        "test": test,
    }

    # NOTE: the directory is intentionally left on disk after training; the
    # returned model was trained against this path.
    data_dir = tempfile.mkdtemp()

    for key, split in splits_dict.items():
        split_dir = os.path.join(data_dir, key)
        os.makedirs(split_dir)
        split.export(
            export_dir=split_dir,
            dataset_type=fo.types.ImageClassificationDirectoryTree,
            label_field=gt_field,
            export_media="symlink",
        )

    # Load a pre-trained YOLOv8 model for classification
    model = YOLO(f"yolov8{model_size}-cls.pt")

    # Train the model
    model.train(
        data=data_dir,  # Path to the dataset
        epochs=epochs,  # Number of epochs
        imgsz=image_size,  # Image size
        device=get_torch_device(),
        project=project_name,
    )

    return model


def main():
    """Parse the command-line options and train a classifier with them.

    Raises:
        ValueError: if the installed FiftyOne is older than
            ``MIN_FIFTYONE_VERSION``
    """
    # Compare numerically: as plain strings, "0.9.0" sorts after "0.25.0"
    if _version_tuple(fo.__version__) < MIN_FIFTYONE_VERSION:
        raise ValueError("Please upgrade to the latest version of FiftyOne")

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, required=True)
    parser.add_argument(
        "--model_size", type=str, default=None, choices=VALID_MODEL_SIZES
    )
    parser.add_argument("--image_size", type=int, default=DEFAULT_IMAGE_SIZE)
    parser.add_argument("--epochs", type=int, default=DEFAULT_EPOCHS)
    parser.add_argument(
        "--project_name", type=str, default="mislabel_confidence_noise"
    )
    args = parser.parse_args()

    train_classifier(
        dataset_name=args.dataset_name,
        model_size=args.model_size,
        image_size=args.image_size,
        epochs=args.epochs,
        project_name=args.project_name,
    )


if __name__ == "__main__":
    main()