diff --git a/Dockerfile b/Dockerfile index 1f25cdb..e49837d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,4 +1,4 @@ -FROM python:3.9.16-slim-bullseye +FROM python:3.9.18-slim-bullseye WORKDIR /work RUN apt-get update -y && \ diff --git a/basemodels/manifest/manifest.py b/basemodels/manifest/manifest.py index 73974a9..45fc25f 100644 --- a/basemodels/manifest/manifest.py +++ b/basemodels/manifest/manifest.py @@ -1,9 +1,11 @@ import json import uuid +from datetime import datetime + import requests from requests.exceptions import RequestException from typing_extensions import Literal -from typing import Dict, Union, List, Optional +from typing import Dict, Union, List, Optional, Any from enum import Enum from uuid import UUID, uuid4 from .data.groundtruth import validate_groundtruth_entry @@ -84,7 +86,7 @@ def to_primitive(self): # Helper function for using in the unittest def check(self, return_new=False): - self.__class__.validate(self) + self.__class__.validate(self.dict()) out_dict, _, validation_error = validate_model(self.__class__, self.__dict__) if validation_error: raise validation_error @@ -280,6 +282,7 @@ class Manifest(Model): task_bid_price: float oracle_stake: Decimal expiration_date: Optional[int] + start_date: Optional[int] requester_accuracy_target: float = 0.1 manifest_smart_bounty_addr: Optional[str] hmtoken_addr: Optional[str] @@ -323,6 +326,26 @@ class Manifest(Model): ##### Validators + @root_validator + def validate(cls, values: Dict[str, Any]) -> Dict[str, Any]: + start_date = values["start_date"] + expiration_date = values["expiration_date"] + if not start_date and not expiration_date: + # Timestamps are not passed + return values + has_both_dates = bool(start_date) and bool(expiration_date) + if not has_both_dates: + raise ValueError("You must specify both start_date and expiration_date") + + if not start_date < expiration_date: + raise ValueError("start_date must be before expiration_date") + + duration = datetime.utcfromtimestamp(expiration_date) - datetime.utcfromtimestamp(start_date) + if 7 < duration.days: + raise ValueError("Max job duration is 7 days.") + + return values + @validator("requester_min_repeats") def validate_min_repeats(cls, v, values): """min repeats are required to be at least 4 if ilmc""" diff --git a/tests/test_manifest_validation.py b/tests/test_manifest_validation.py index dfccdc9..8ead9e7 100644 --- a/tests/test_manifest_validation.py +++ b/tests/test_manifest_validation.py @@ -1,19 +1,18 @@ #!/usr/bin/env python3 -from pydantic import ValidationError -from typing import Any - +import json import logging -import basemodels -from uuid import uuid4 +import unittest from copy import deepcopy +from datetime import datetime +from typing import Any +from uuid import uuid4 -from basemodels.manifest.data.taskdata import TaskDataEntry - -import unittest import httpretty -import json +from pydantic import ValidationError +import basemodels +from basemodels.manifest.data.taskdata import TaskDataEntry from basemodels.manifest.restricted_audience import RestrictedAudience CALLBACK_URL = "http://google.com/webback" @@ -56,18 +55,18 @@ def validate_func(model): test_models = basemodels -def a_manifest( - number_of_tasks=100, - bid_amount=1.0, - oracle_stake=0.05, - expiration_date=0, - minimum_trust=0.1, - request_type=IMAGE_LABEL_BINARY, - request_config=None, - job_mode="batch", - multi_challenge_manifests=None, - is_verification=None, -) -> Any: +def get_data( + number_of_tasks=100, + bid_amount=1.0, + oracle_stake=0.05, + expiration_date=0, + minimum_trust=0.1, + request_type=IMAGE_LABEL_BINARY, + request_config=None, + job_mode="batch", + multi_challenge_manifests=None, + is_verification=None, +) -> dict: internal_config = {"exchange": {"a": 1, "b": "c"}} model = { "requester_restricted_answer_set": { @@ -104,7 +103,34 @@ def a_manifest( if is_verification is not None: model.update({"is_verification": is_verification}) - manifest = create_manifest(model) + return model + + +def a_manifest( + number_of_tasks=100, + bid_amount=1.0, + oracle_stake=0.05, + expiration_date=0, + minimum_trust=0.1, + request_type=IMAGE_LABEL_BINARY, + request_config=None, + job_mode="batch", + multi_challenge_manifests=None, + is_verification=None, +) -> Any: + data = get_data( + number_of_tasks=number_of_tasks, + bid_amount=bid_amount, + oracle_stake=oracle_stake, + expiration_date=expiration_date, + minimum_trust=minimum_trust, + request_type=request_type, + request_config=request_config, + job_mode=job_mode, + multi_challenge_manifests=multi_challenge_manifests, + is_verification=is_verification, + ) + manifest = create_manifest(data) validate_func(manifest)() return manifest @@ -498,6 +524,60 @@ def test_realistic_multi_challenge_example(self): # print(model.to_primitive()) self.assertTrue(validate_func(model)() is None) + def test_both_timestamps_are_required(self): + """validates both start_date & expiration_date must be passed at the same time.""" + # Given + data = get_data() + del data["expiration_date"] + data["start_date"] = int(datetime(2023, 10, 20).timestamp()) + + # When/Then + with self.assertRaises(ValueError): + basemodels.Manifest(**data) + + # When/Then + del data["start_date"] + data["expiration_date"] = int(datetime(2023, 10, 21).timestamp()) + with self.assertRaises(ValueError): + basemodels.Manifest(**data) + + # When/Then + data["start_date"] = int(datetime(2023, 10, 20).timestamp()) + data["expiration_date"] = int(datetime(2023, 10, 21).timestamp()) + basemodels.Manifest(**data) + + def test_timestamps_validation(self): + """validates start_date must be before expiration_date.""" + # Given + data = get_data() + data["start_date"] = int(datetime(2023, 10, 22).timestamp()) + data["expiration_date"] = int(datetime(2023, 10, 21).timestamp()) + + # When/Then + with self.assertRaises(ValueError): + basemodels.Manifest(**data) + + # When/Then + data["start_date"] = int(datetime(2023, 10, 20).timestamp()) + data["expiration_date"] = int(datetime(2023, 10, 21).timestamp()) + basemodels.Manifest(**data) + + def test_timestamps_duration_validation(self): + """validates max duration.""" + # Given + data = get_data() + data["start_date"] = datetime(2023, 10, 20).timestamp() + data["expiration_date"] = datetime(2023, 10, 28).timestamp() + + # When/Then + with self.assertRaises(ValueError): + basemodels.Manifest(**data) + + # When/Then + data["start_date"] = datetime(2023, 10, 20).timestamp() + data["expiration_date"] = datetime(2023, 10, 27).timestamp() + basemodels.Manifest(**data) + def test_webhook(self): """Test that webhook is correct""" webhook = {