
Commit

Help from Jacob
thesteve0 committed Nov 27, 2024
1 parent 117021c commit 35446a4
Showing 3 changed files with 138 additions and 5 deletions.
20 changes: 15 additions & 5 deletions 4_generate_labels.py
@@ -3,9 +3,13 @@
import fiftyone as fo
import fiftyone.zoo as foz


if __name__ == '__main__':
dataset = fo.load_dataset("play_photos")

    # Clean up from previous runs
    if "labeled_dataset" in fo.list_datasets():
        fo.delete_dataset("labeled_dataset")

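    # CLIP from the FiftyOne model zoo runs as a zero-shot classifier; the
    # text_prompt is prepended to each candidate class name when scoring images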
    clip = foz.load_zoo_model(
        "clip-vit-base32-torch",
        text_prompt="A photo of a",
@@ -14,14 +18,20 @@
    # alexnet = foz.load_zoo_model("alexnet-imagenet-torch")
    # dense201 = foz.load_zoo_model("densenet201-imagenet-torch")
    # fasterrcnn = foz.load_zoo_model("faster-rcnn-resnet50-fpn-coco-torch")
    # yoloseg = foz.load_zoo_model("yolo11x-seg-coco-torch")

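    # apply_model runs the model on every sample and stores its output in the
    # given label field (a Classification label for the CLIP classifier)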
dataset.apply_model(clip, label_field="default_prediction")
dataset.set_values("ground_truth", fo.Classification(label=str(dataset.values("default_prediction.label"))))
dataset.apply_model(clip, label_field="prediction")
# dataset.apply_model(dense201, label_field="dense201")
# dataset.apply_model(alexnet, label_field="alexnet")
# dataset.apply_model(fasterrcnn, label_field="faster_rcnn")
# dataset.apply_model(yoloseg, label_field="yolo_seg")
#dataset.apply_model(yoloseg, label_field="yolo_seg")

    # Alright, time to make our dataset with cleaned labels
    labeled_dataset = dataset.clone(name="labeled_dataset", persistent=True)
    labeled_dataset.rename_sample_field("prediction", "ground_truth")
    labeled_dataset.set_field("ground_truth.detections.confidence", None).save()
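    # With the confidences cleared, these labels now serve as curated ground
    # truth rather than raw model predictions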

    # Now time to go to 5_clean_ground_truth

    session = fo.launch_app(dataset)
    session.wait()
4 changes: 4 additions & 0 deletions 5_clean_ground_truth.py
@@ -0,0 +1,4 @@
import fiftyone as fo

if __name__ == '__main__':
fo.load_dataset("labeled_dataset")
119 changes: 119 additions & 0 deletions 6_fine_tuning.py
@@ -0,0 +1,119 @@
"""
Ultralytics YOLOv8*-cls model training script
for generating confidence-based noise labels for a dataset.
| Copyright 2017-2024, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
Requires `ultralytics` and `fiftyone>=0.25.0` to be installed.
"""
import argparse
import os
import tempfile

import torch
import wandb
from ultralytics import YOLO
from ultralytics import settings

import fiftyone as fo

DEFAULT_MODEL_SIZE = "s"
DEFAULT_IMAGE_SIZE = 128
DEFAULT_EPOCHS = 10

wandb.require("core")


def get_torch_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        return torch.device("cpu")


def train_classifier(
    dataset_name=None,
    model_size=DEFAULT_MODEL_SIZE,
    image_size=DEFAULT_IMAGE_SIZE,
    epochs=DEFAULT_EPOCHS,
    project_name="mislabel_confidence_noise",
    gt_field="ground_truth",
    train_split=None,
    test_split=None,
    **kwargs
):

settings.update({"wandb": False})
if dataset_name:
dataset = fo.load_dataset(dataset_name)
train = dataset.match_tags("train")
test = dataset.match_tags("test")
else:
train = train_split
test = test_split

    if model_size is None:
        model_size = "s"
    elif model_size not in ["n", "s", "m", "l", "x"]:
        raise ValueError("model_size must be one of ['n', 's', 'm', 'l', 'x']")

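    # Note: the test split doubles as the validation split below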
    splits_dict = {
        "train": train,
        "val": test,
        "test": test,
    }

    data_dir = tempfile.mkdtemp()

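    # Export each split in the class-per-folder layout (ImageClassificationDirectoryTree)
    # that Ultralytics classification training expects; media are symlinked, not copied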
    for key, split in splits_dict.items():
        split_dir = os.path.join(data_dir, key)
        os.makedirs(split_dir)
        split.export(
            export_dir=split_dir,
            dataset_type=fo.types.ImageClassificationDirectoryTree,
            label_field=gt_field,
            export_media="symlink",
        )

    # Load a pre-trained YOLOv8 model for classification
    model = YOLO(f"yolov8{model_size}-cls.pt")

    # Train the model
    model.train(
        data=data_dir,  # Path to the dataset
        epochs=epochs,  # Number of epochs
        imgsz=image_size,  # Image size
        device=get_torch_device(),
        project=project_name,
    )
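    # Ultralytics writes weights and metrics under the project directory
    # (<project_name>/<run name>)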

    return model


def main():
if fo.__version__ < "0.25.0":
raise ValueError("Please upgrade to the latest version of FiftyOne")

    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_name", type=str, required=True)
    parser.add_argument("--model_size", type=str, default=None)
    parser.add_argument("--image_size", type=int, default=128)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--project_name", type=str, default="mislabel_confidence_noise")
    args = parser.parse_args()

    train_classifier(
        dataset_name=args.dataset_name,
        model_size=args.model_size,
        image_size=args.image_size,
        epochs=args.epochs,
        project_name=args.project_name,
    )
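    # Example invocation (assuming the dataset created by 4_generate_labels.py):
    #   python 6_fine_tuning.py --dataset_name labeled_dataset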


if __name__ == "__main__":
    main()
