Skip to content

Commit

Permalink
feat: import predictions in __init__
Browse files Browse the repository at this point in the history
  • Loading branch information
antwxne committed Jun 8, 2022
1 parent a65d91c commit 2e0f437
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 11 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="CVAT wrapper",
version="0.0.2",
version="0.0.3",
author="antwxne",
author_email="[email protected]",
description="Python wrapper for CVAT API",
Expand Down
1 change: 1 addition & 0 deletions src/CVAT/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ._put import Put
from ._static import Static
from .data_types import Task
from .Prediction import *


class CVAT(Get, Post, Delete, Patch, Put, Static):
Expand Down
66 changes: 64 additions & 2 deletions src/CVAT/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,77 @@
#!/bin/python3
# Created by [email protected] at 5/23/22
from pathlib import Path
import re
from os import walk
from typing import Optional

from tqdm import tqdm

from src.CVAT import CVAT


def tryint(s):
"""
Return an int if possible, or `s` unchanged.
"""
try:
return int(s)
except ValueError:
return s


def alphanum_key(s):
"""
Turn a string into a list of string and number chunks.
#>>> alphanum_key("z23a")
["z", 23, "a"]
"""
return [tryint(c) for c in re.split('([0-9]+)', s)]


def human_sort(l):
"""
Sort a list in the way that humans expect.
"""
l.sort(key=alphanum_key)


def get_files_from_path(path: str, to_sort: Optional[bool] = True) -> list[str]:
"""
It takes a path and returns a list of files in that path
Args:
path (str): The path to the directory you want to get the files from.
to_sort (Optional[bool]): If True, the files will be sorted in a human-friendly way. Defaults to True
Returns:
A list of files in the path
"""
f = next(walk(path), (None, None, []))[2]
if to_sort:
human_sort(f)
return [f'{path}/{file}' for file in f]


def image_content_from_kili_prediction(prediction: list[dict], directory: str) -> list[str]:
"""
It downloads the images from the Kili API and saves them in the directory you specify
Args:
prediction (list[dict]): list[dict]
directory (str): str = "./images"
Returns:
A list of paths to the images.
"""
paths: list[str] = []
current_files: list[str] = get_files_from_path(directory)
files_without_extension: list[str] = [file.split(".")[0] for file in current_files]
for elem in tqdm(prediction, unit="Image"):
path: str = f'{directory}/{elem["externalId"]}'
paths.append(CVAT.download_image(elem["content"], path))
if path not in files_without_extension:
paths.append(CVAT.download_image(elem["content"], path))
else:
paths.append(current_files[files_without_extension.index(path)])
return paths
15 changes: 7 additions & 8 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from src.CVAT import CVAT
from src.CVAT.data_types import Task
from src.CVAT.utils import image_content_from_kili_prediction
from src.CVAT.Predictions.Interface import IPrediction

API: CVAT = CVAT()

Expand All @@ -19,12 +18,12 @@
directory: str = "data/Images/Fromages"
interface: dict = CVAT.get_json_from_file("../interface_foodvisor.json")
images_path: list[str] = image_content_from_kili_prediction(prediction_json, directory)
project_id: int = API.create_project("TEST_API2", interface=interface)
task: Task = Task("jqefhgfdqf", project_id=project_id)
# project_id: int = API.create_project("TEST_API2", interface=interface)
# task: Task = Task("jqefhgfdqf", project_id=project_id)
# task: Task = API.get_task_by_name("qsfqsfqsf")
task = API.create_task(task)
API.add_local_images_to_task(task=task, images_path=images_path)
prediction: IPrediction = API.get_prediction_from_file(task, "foodvisor",
"./foodvisor_valid_12_2021_Fromages_gmd_predictions.json")
API.upload_predictions(task, prediction)
# task = API.create_task(task)
# API.add_local_images_to_task(task=task, images_path=images_path)
# prediction: IPrediction = API.get_prediction_from_file(task, "foodvisor",
# "./foodvisor_valid_12_2021_Fromages_gmd_predictions.json")
# API.upload_predictions(task, prediction)

0 comments on commit 2e0f437

Please sign in to comment.