Skip to content

Commit

Permalink
fix mlm:artifact_type check missing + update corresponding tests/exam…
Browse files Browse the repository at this point in the history
…ples (fixes #42)
  • Loading branch information
fmigneault committed Nov 1, 2024
1 parent 443d32b commit ef82f26
Show file tree
Hide file tree
Showing 12 changed files with 114 additions and 44 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- n/a

### Fixed
- n/a
- Fix missing `mlm:artifact_type` property check for a Model Asset definition
(fixes <https://github.com/stac-extensions/mlm/issues/42>).
The `mlm:artifact_type` is now mutually and exclusively required by the corresponding Asset with `mlm:model` role.

## [v1.3.0](https://github.com/stac-extensions/mlm/tree/v1.3.0)

Expand Down
16 changes: 8 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -600,13 +600,13 @@ In order to provide more context, the following roles are also recommended were

### Model Asset

| Field Name | Type | Description |
|-------------------|-------------------------------------------|--------------------------------------------------------------------------------------------------|
| title | string | Description of the model asset. |
| href | string | URI to the model artifact. |
| type | string | The media type of the artifact (see [Model Artifact Media-Type](#model-artifact-media-type). |
| roles | \[string] | **REQUIRED** Specify `mlm:model`. Can include `["mlm:weights", "mlm:checkpoint"]` as applicable. |
| mlm:artifact_type | [Artifact Type Enum](#artifact-type-enum) | Specifies the kind of model artifact. Typically related to a particular ML framework. |
| Field Name | Type | Description |
|-------------------|---------------------------------|--------------------------------------------------------------------------------------------------|
| title | string | Description of the model asset. |
| href | string | URI to the model artifact. |
| type | string | The media type of the artifact (see [Model Artifact Media-Type](#model-artifact-media-type). |
| roles | \[string] | **REQUIRED** Specify `mlm:model`. Can include `["mlm:weights", "mlm:checkpoint"]` as applicable. |
| mlm:artifact_type | [Artifact Type](#artifact-type) | Specifies the kind of model artifact. Typically related to a particular ML framework. |

Recommended Asset `roles` include `mlm:weights` or `mlm:checkpoint` for model weights that need to be loaded by a
model definition and `mlm:compiled` for models that can be loaded directly without an intermediate model definition.
Expand Down Expand Up @@ -642,7 +642,7 @@ official. In order to validate the specific framework and artifact type employed

[iana-media-type]: https://www.iana.org/assignments/media-types/media-types.xhtml

#### Artifact Type Enum
#### Artifact Type

This value can be used to provide additional details about the specific model artifact being described.
For example, PyTorch offers [various strategies][pytorch-frameworks] for providing model definitions,
Expand Down
1 change: 1 addition & 0 deletions examples/item_bands_expression.json
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@
"mlm:model",
"mlm:weights"
],
"mlm:artifact_type": "torch.save",
"$comment": "Following 'eo:bands' is required to fulfil schema validation of 'eo' extension.",
"eo:bands": [
{
Expand Down
3 changes: 2 additions & 1 deletion examples/item_basic.json
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@
"type": "text/html",
"roles": [
"mlm:model"
]
],
"mlm:artifact_type": "torch.save"
}
},
"links": [
Expand Down
2 changes: 1 addition & 1 deletion examples/item_eo_and_raster_bands.json
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@
"mlm:model",
"mlm:weights"
],
"mlm:artifact_type": "torch.save",
"$comment": "Following 'eo:bands' is required to fulfil schema validation of 'eo' extension.",
"eo:bands": [
{
Expand Down Expand Up @@ -546,7 +547,6 @@
"description": "Source code to run the model.",
"type": "text/x-python",
"roles": [
"mlm:model",
"code",
"metadata"
]
Expand Down
1 change: 1 addition & 0 deletions examples/item_eo_bands.json
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@
"mlm:model",
"mlm:weights"
],
"mlm:artifact_type": "torch.save",
"$comment": "Following 'eo:bands' is required to fulfil schema validation of 'eo' extension.",
"eo:bands": [
{
Expand Down
2 changes: 1 addition & 1 deletion examples/item_eo_bands_summarized.json
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,7 @@
"mlm:model",
"mlm:weights"
],
"mlm:artifact_type": "torch.save",
"$comment": "Following 'eo:bands' is required to fulfil schema validation of 'eo' extension.",
"eo:bands": [
{
Expand Down Expand Up @@ -415,7 +416,6 @@
"description": "Source code to run the model.",
"type": "text/x-python",
"roles": [
"mlm:model",
"code",
"metadata"
]
Expand Down
1 change: 1 addition & 0 deletions examples/item_multi_io.json
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@
"mlm:model",
"mlm:weights"
],
"mlm:artifact_type": "torch.save",
"raster:bands": [
{
"name": "B02 - blue",
Expand Down
1 change: 1 addition & 0 deletions examples/item_raster_bands.json
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@
"mlm:model",
"mlm:weights"
],
"mlm:artifact_type": "torch.save",
"raster:bands": [
{
"name": "B01",
Expand Down
89 changes: 66 additions & 23 deletions json-schema/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,12 @@
"$ref": "#/$defs/AnyBandsRef"
},
{
"$comment": "Schema to validate model role requirement.",
"$comment": "Schema to validate that at least one Asset defines a model role.",
"$ref": "#/$defs/AssetModelRoleMinimumOneDefinition"
},
{
"$comment": "Schema to validate that the Asset model properties are mutually exclusive to the model role.",
"$ref": "#/$defs/AssetModelRequiredProperties"
}
]
}
Expand Down Expand Up @@ -369,6 +373,15 @@
"type": "string",
"pattern": "^(0|[1-9]\\d*)\\.(0|[1-9]\\d*)\\.(0|[1-9]\\d*)(?:-((?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*)(?:\\.(?:0|[1-9]\\d*|\\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?(?:\\+([0-9a-zA-Z-]+(?:\\.[0-9a-zA-Z-]+)*))?$"
},
"mlm:artifact_type": {
"type": "string",
"minLength": 1,
"examples": [
"torch.save",
"torch.jit.save",
"torch.export.save"
]
},
"mlm:tasks": {
"type": "array",
"uniqueItems": true,
Expand Down Expand Up @@ -729,6 +742,57 @@
"DataType": {
"$ref": "https://stac-extensions.github.io/raster/v1.1.0/schema.json#/definitions/bands/items/properties/data_type"
},
"HasArtifactType": {
"$comment": "Used to check the artifact type property that is required by a Model Asset annotated by 'mlm:model' role.",
"type": "object",
"required": [
"mlm:artifact_type"
],
"properties": {
"mlm:artifact_type": {
"$ref": "#/$defs/mlm:artifact_type"
}
}
},
"AssetModelRole": {
"$comment": "Used to check the presence of 'mlm:model' role required by a Model Asset.",
"type": "object",
"required": [
"roles"
],
"properties": {
"roles": {
"type": "array",
"contains": {
"const": "mlm:model"
},
"minItems": 1
}
}
},
"AssetModelRequiredProperties": {
"$comment": "Asset containing the model definition must indicate both the 'mlm:model' role and an artifact type.",
"required": [
"assets"
],
"properties": {
"assets": {
"additionalProperties": {
"if": {
"$ref": "#/$defs/AssetModelRole"
},
"then": {
"$ref": "#/$defs/HasArtifactType"
},
"else": {
"not": {
"$ref": "#/$defs/HasArtifactType"
}
}
}
}
}
},
"AssetModelRoleMinimumOneDefinition": {
"$comment": "At least one Asset must provide the model definition indicated by the 'mlm:model' role.",
"required": [
Expand All @@ -739,15 +803,7 @@
"properties": {
"assets": {
"additionalProperties": {
"properties": {
"roles": {
"type": "array",
"items": {
"const": "mlm:model"
},
"minItems": 1
}
}
"$ref": "#/$defs/AssetModelRole"
}
}
}
Expand Down Expand Up @@ -775,19 +831,6 @@
}
]
},
"AssetModelRole": {
"required": [
"roles"
],
"properties": {
"roles": {
"contains": {
"type": "string",
"const": "mlm:model"
}
}
}
},
"ModelBands": {
"description": "List of bands (if any) that compose the input. Band order represents the index position of the bands.",
"$comment": "No 'minItems' here to support model inputs not using any band (other data source).",
Expand Down
1 change: 1 addition & 0 deletions stac_model/examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def eurosat_resnet() -> ItemMLModelExtension:
"mlm:weights",
"data",
],
extra_fields={"mlm:artifact_type": "torch.save"}
),
"source_code": pystac.Asset(
title="Model implementation.",
Expand Down
37 changes: 28 additions & 9 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,25 +119,44 @@ def test_mlm_other_non_mlm_assets_allowed(
["item_basic.json"],
indirect=True,
)
@pytest.mark.parametrize(
["model_asset_extras", "is_valid"],
[
({"roles": ["checkpoint"]}, False),
({"roles": ["checkpoint", "mlm:model"]}, False),
({"roles": ["checkpoint"], "mlm:artifact_type": "test"}, False),
({"roles": ["checkpoint", "mlm:model"], "mlm:artifact_type": "test"}, True),
]
)
def test_mlm_at_least_one_asset_model(
mlm_validator: STACValidator,
mlm_example: Dict[str, JSON],
model_asset_extras: Dict[str, Any],
is_valid: bool,
) -> None:
mlm_data = copy.deepcopy(mlm_example)
mlm_item = pystac.Item.from_dict(mlm_data)
pystac.validation.validate(mlm_item, validator=mlm_validator) # self-check valid beforehand

mlm_data["assets"] = { # needs at least 1 asset with role 'mlm:model'
"model": {
"type": "application/octet-stream; application=pytorch",
"href": "https://example.com/sample/checkpoint.pt",
"roles": ["checkpoint"],
"title": "Model Weights Checkpoint",
}
mlm_model = {
"type": "application/octet-stream; application=pytorch",
"href": "https://example.com/sample/checkpoint.pt",
"title": "Model Weights Checkpoint",
}
with pytest.raises(pystac.errors.STACValidationError):
mlm_item = pystac.Item.from_dict(mlm_data)
mlm_model.update(model_asset_extras)
mlm_data["assets"] = {
"model": mlm_model
}
mlm_item = pystac.Item.from_dict(mlm_data)
if is_valid:
pystac.validation.validate(mlm_item, validator=mlm_validator)
else:
with pytest.raises(pystac.errors.STACValidationError) as exc:
pystac.validation.validate(mlm_item, validator=mlm_validator)
assert exc.value.source[0].schema["$comment"] in [
"At least one Asset must provide the model definition indicated by the 'mlm:model' role.",
"Used to check the artifact type property that is required by a Model Asset annotated by 'mlm:model' role."
]


def test_model_metadata_to_dict(eurosat_resnet):
Expand Down

0 comments on commit ef82f26

Please sign in to comment.