Skip to content

Commit

Permalink
feat: Added local testing, optional extraction of version from resour…
Browse files Browse the repository at this point in the history
…ce_path
  • Loading branch information
jmetz committed Feb 8, 2024
1 parent 434d6b1 commit 4c076b6
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 47 deletions.
33 changes: 28 additions & 5 deletions .github/scripts/s3_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import dataclass, field
from datetime import timedelta
from pathlib import Path
from typing import Iterator
from typing import Iterator, Optional

# import requests # type: ignore
from loguru import logger # type: ignore
Expand Down Expand Up @@ -39,12 +39,12 @@ def __post_init__(self):
raise Exception("target bucket does not exist: {self.bucket}")
logger.debug("Created S3-Client: {}", self)

def bucket_exists(self, bucket):
def bucket_exists(self, bucket) -> bool:
return self._client.bucket_exists(bucket)

def put(
self, path, file_object, length=-1, content_type="application/octet-stream"
):
) -> None:
# For unknown length (ie without reading file into mem) give `part_size`
part_size = 0
if length == -1:
Expand Down Expand Up @@ -171,7 +171,7 @@ def get_status(self, resource_path: str, version: str) -> dict:
status = json.loads(status_str)
return status

def put_status(self, resource_path: str, version: str, status: dict):
def put_status(self, resource_path: str, version: str, status: dict) -> None:
logger.debug(
"Updating status for {}-{}, with {}", resource_path, version, status
)
Expand All @@ -197,7 +197,7 @@ def get_log(self, resource_path: str, version: str) -> dict:
log = {}
return log

def put_log(self, resource_path: str, version: str, log: dict):
def put_log(self, resource_path: str, version: str, log: dict) -> None:
logger.debug("Updating log for {}-{}, with {}", resource_path, version, log)
contents = json.dumps(log).encode()
file_object = io.BytesIO(contents)
Expand All @@ -209,6 +209,11 @@ def put_log(self, resource_path: str, version: str, log: dict):
content_type="application/json",
)

def get_url_for_file(self, resource_path: str, filename: str, version: Optional[str] = None) -> str:
if version is None:
resource_path, version = version_from_resource_path_or_s3(resource_path, self)
return f"https://{self.host}/{self.bucket}/{self.prefix}/{resource_path}/{version}/files/{filename}"


def create_client() -> Client:
"""
Expand All @@ -234,3 +239,21 @@ def create_client() -> Client:
secret_key=secret_access_key,
)
return client


def version_from_resource_path_or_s3(resource_path, client : Optional[Client] = None) -> tuple[str, str]:
"""
Extract version from resource_path if present
Otherwise try and determine from model folder
"""
parts = resource_path.split("/")
if len(parts) == 2:
resource_path = parts[0]
version = parts[1]
logger.info("Version: {}", version)
else:
if client is None:
client = create_client()
version = client.get_unpublished_version(resource_path)
logger.info("Version detected: {}", version)
return resource_path, version
14 changes: 4 additions & 10 deletions .github/scripts/update_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
from typing import Optional

from loguru import logger
from s3_client import create_client
from s3_client import create_client, version_from_resource_path_or_s3


def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("resource_path", help="Resource name")
parser.add_argument("category", help="Log category")
parser.add_argument("summary", help="Log summary")
parser.add_argument("--version", help="Version")
return parser


Expand All @@ -28,11 +27,10 @@ def main():
resource_path = args.resource_path
category = args.category
summary = args.summary
version = args.version
add_log_entry(resource_path, category, summary, version=version)
add_log_entry(resource_path, category, summary)


def add_log_entry(resource_path, category, summary, version=None):
def add_log_entry(resource_path, category, summary):
timenow = datetime.datetime.now().isoformat()
client = create_client()
logger.info(
Expand All @@ -42,11 +40,7 @@ def add_log_entry(resource_path, category, summary, version=None):
summary,
)

if version is None:
version = client.get_unpublished_version(resource_path)
logger.info("Version detected: {}", version)
else:
logger.info("Version requested: {}", version)
resource_path, version = version_from_resource_path_or_s3(resource_path)
log = client.get_log(resource_path, version)

if category not in log:
Expand Down
15 changes: 4 additions & 11 deletions .github/scripts/update_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@
from typing import Optional

from loguru import logger
from s3_client import create_client

from s3_client import create_client, version_from_resource_path_or_s3

def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("resource_path", help="Model name")
parser.add_argument("status", help="Status")
parser.add_argument("--version", help="Version")
parser.add_argument("--step", help="Step", default=0, type=int)
parser.add_argument("--num_steps", help="Status", default=0, type=int)
return parser
Expand All @@ -27,14 +25,13 @@ def get_args(argv: Optional[list] = None):
def main():
args = get_args()
resource_path = args.resource_path
version = args.version
step = args.step
num_steps = args.num_steps
status = args.status
update_status(resource_path, status, version=version, step=step, num_steps=num_steps)
update_status(resource_path, status, step=step, num_steps=num_steps)


def update_status(resource_path: str, status_text: str, version: Optional[str] = None, step: Optional[int]=None, num_steps: int = 6):
def update_status(resource_path: str, status_text: str, step: Optional[int]=None, num_steps: int = 6):
assert step is None or step <= num_steps
timenow = datetime.datetime.now().isoformat()
client = create_client()
Expand All @@ -46,11 +43,7 @@ def update_status(resource_path: str, status_text: str, version: Optional[str] =
num_steps,
)

if version is None:
version = client.get_unpublished_version(resource_path)
logger.info("Version detected: {}", version)
else:
logger.info("Version requested: {}", version)
resource_path, version = version_from_resource_path_or_s3(resource_path, client)
status = client.get_status(resource_path, version)

if "messages" not in status:
Expand Down
12 changes: 4 additions & 8 deletions .github/scripts/upload_model_to_zenodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from loguru import logger # type: ignore
from packaging.version import parse as parse_version
from ruyaml import YAML # type: ignore
from s3_client import create_client
from s3_client import create_client, version_from_resource_path_or_s3
from update_status import update_status

yaml = YAML(typ="safe")
Expand Down Expand Up @@ -57,8 +57,7 @@ def assert_good_response(response, message, info=None):

def create_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument("--resource_path", help="Model name", required=True)
parser.add_argument("--version", help="Version", nargs="?", default=None)
parser.add_argument("resource_path", help="Resource path")
return parser


Expand All @@ -76,12 +75,9 @@ def main():
params = {"access_token": ACCESS_TOKEN}

client = create_client()
resource_path, version = version_from_resource_path_or_s3(args.resource_path, client)

# TODO: GET THE CURRENT VERSION
if args.version is None:
version = client.get_unpublished_version(args.resource_path)

s3_path = f"{args.resource_path}/{version}/files"
s3_path = f"{resource_path}/{version}/files"

# List the files at the model URL
file_urls = client.get_file_urls(path=s3_path)
Expand Down
26 changes: 16 additions & 10 deletions .github/scripts/validate_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,23 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import requests
import typer
from bioimageio.spec import load_raw_resource_description, validate
from bioimageio.spec.model.raw_nodes import Model, WeightsFormat
from bioimageio.spec.rdf.raw_nodes import RDF_Base
from bioimageio.spec.shared import yaml
from bioimageio.spec.shared.raw_nodes import URI, Dependencies
import requests # type: ignore
import typer # type: ignore
from bioimageio.spec import load_raw_resource_description, validate # type: ignore
from bioimageio.spec.model.raw_nodes import Model, WeightsFormat # type: ignore
from bioimageio.spec.rdf.raw_nodes import RDF_Base # type: ignore
from bioimageio.spec.shared import yaml # type: ignore
from bioimageio.spec.shared.raw_nodes import URI, Dependencies # type: ignore
from marshmallow import missing
from marshmallow.utils import _Missing
from packaging.version import Version
from tqdm import tqdm
from tqdm import tqdm # type: ignore
from update_log import add_log_entry
from s3_client import create_client, version_from_resource_path_or_s3

tqdm.__init__ = partialmethod(tqdm.__init__, disable=True) # silence tqdm


from update_log import add_log_entry


def set_multiple_gh_actions_outputs(outputs: Dict[str, Union[str, Any]]):
Expand Down Expand Up @@ -250,7 +251,12 @@ def prepare_dynamic_test_cases(descr_id: str, rd: RDF_Base) -> List[Dict[str, st
return validation_cases


def validate_format(descr_id: str, source: str):
def validate_format(descr_id: str):

client = create_client()
resource_path, version = version_from_resource_path_or_s3(descr_id, client)
source = client.get_url_for_file(resource_path, "rdf.yaml", version=version)

dynamic_test_cases: List[Dict[str, str]] = []

summaries = [validate(source)]
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/ci_runner.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
id: validate
run: |
python .github/scripts/update_status.py "${{ inputs.resource_path }}" "Starting validation" "2"
python .github/scripts/validate_format.py "${{ inputs.resource_path }}" "${{env.S3_HOST}}/${{env.S3_BUCKET}}/${{env.S3_FOLDER}}/${{inputs.resource_path}}/files/rdf.yaml"
python .github/scripts/validate_format.py "${{ inputs.resource_path }}"
- run: |
python .github/scripts/update_status.py "${{ inputs.resource_path }}" "Starting additional tests" "3"
if: steps.validate.outputs.has_dynamic_test_cases == 'yes'
Expand Down Expand Up @@ -80,7 +80,7 @@ jobs:
run: pip install typer bioimageio.spec
- name: dynamic validation
shell: bash -l {0}
run: python scripts/test_dynamically.py "${{env.S3_HOST}}/${{env.S3_BUCKET}}/${{env.S3_FOLDER}}/${{inputs.resource_path}}/files/rdf.yaml" ${{ matrix.weight_format }} --create-env-outcome ${{ steps.create_env.outcome }} --${{ contains(inputs.deploy_to, 'gh-pages') && 'no-ignore' || 'ignore' }}-rdf-source-field-in-validation
run: python scripts/test_dynamically.py "https://${{env.S3_HOST}}/${{env.S3_BUCKET}}/${{env.S3_FOLDER}}/${{inputs.resource_path}}/files/rdf.yaml" ${{ matrix.weight_format }} --create-env-outcome ${{ steps.create_env.outcome }} --${{ contains(inputs.deploy_to, 'gh-pages') && 'no-ignore' || 'ignore' }}-rdf-source-field-in-validation
timeout-minutes: 60

conclude:
Expand Down
2 changes: 1 addition & 1 deletion .local/test_zenodo_upload.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ set -o allexport
source $SCRIPT_DIR/.env
set +o allexport

python $SCRIPT_DIR/../.github/scripts/upload_model_to_zenodo.py --resource_path=willing-pig
python $SCRIPT_DIR/../.github/scripts/upload_model_to_zenodo.py "willing-pig"

0 comments on commit 4c076b6

Please sign in to comment.