From f9523476353bfeb7bbe836fbc50ccb3f51d99858 Mon Sep 17 00:00:00 2001 From: Francis Charette Migneault Date: Mon, 17 Feb 2025 21:47:20 -0500 Subject: [PATCH 1/3] add HuggingFace SafeTensors to MLM Artifact Types Best Practices (fixes #68) --- CHANGELOG.md | 4 +++- README.md | 18 +++++++++--------- best-practices.md | 21 +++++++++++++-------- json-schema/schema.json | 5 ++++- 4 files changed, 29 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d9fad41..417668a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased](https://github.com/stac-extensions/mlm/tree/main) ### Added -- n/a +- Add [`huggingface/safetensors`](https://github.com/huggingface/safetensors) + recommendations for ``mlm:artifact_type`` and corresponding ``mlm:framework`` values + (fixes [#68](https://github.com/stac-extensions/mlm/issues/68)). ### Changed - n/a diff --git a/README.md b/README.md index a9053c8..2433d8a 100644 --- a/README.md +++ b/README.md @@ -121,7 +121,7 @@ The fields in the table below can be used in these parts of STAC documents: | mlm:name [\[1\]][1] | string | **REQUIRED** A name for the model. This can include, but must be distinct, from simply naming the model architecture. If there is a publication or other published work related to the model, use the official name of the model. | | mlm:architecture | [Model Architecture](#model-architecture) string | **REQUIRED** A generic and well established architecture name of the model. | | mlm:tasks | \[[Task Enum](#task-enum)] | **REQUIRED** Specifies the Machine Learning tasks for which the model can be used for. If multi-tasks outputs are provided by distinct model heads, specify all available tasks under the main properties and specify respective tasks in each [Model Output Object](#model-output-object). | -| mlm:framework | string | Framework used to train the model (ex: PyTorch, TensorFlow). | +| mlm:framework | string | Framework used to train the model (ex: PyTorch, TensorFlow). Typically, this will align with the applied `mlm:artifact_type` of the [Model Asset](#model-asset). | | mlm:framework_version | string | The `framework` library version. Some models require a specific version of the machine learning `framework` to run. | | mlm:memory_size | integer | The in-memory size of the model on the accelerator during inference (bytes). | | mlm:total_parameters | integer | Total number of model parameters, including trainable and non-trainable parameters. | @@ -661,14 +661,14 @@ In order to provide more context, the following roles are also recommended were ### Model Asset -| Field Name | Type | Description | -|-------------------|---------------------------------|--------------------------------------------------------------------------------------------------| -| title | string | Description of the model asset. | -| href | string | URI to the model artifact. | -| type | string | The media type of the artifact (see [Model Artifact Media-Type](#model-artifact-media-type). | -| roles | \[string] | **REQUIRED** Specify `mlm:model`. Can include `["mlm:weights", "mlm:checkpoint"]` as applicable. | -| mlm:artifact_type | [Artifact Type](./best-practices.md#framework-specific-artifact-types) | Specifies the kind of model artifact, any string is allowed. Typically related to a particular ML framework, see [Best Practices - Framework Specific Artifact Types](./best-practices.md#framework-specific-artifact-types) for **RECOMMENDED** values. This field is **REQUIRED** if the `mlm:model` role is specified. | -| mlm:compile_method | [Compile Method](#compile-method) \| null | Describes the method used to compile the ML model either when the model is saved or at model runtime prior to inference. | +| Field Name | Type | Description | +|--------------------|------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| title | string | Description of the model asset. | +| href | string | URI to the model artifact. | +| type | string | The media type of the artifact (see [Model Artifact Media-Type](#model-artifact-media-type). | +| roles | \[string] | **REQUIRED** Specify `mlm:model`. Can include `["mlm:weights", "mlm:checkpoint"]` as applicable. | +| mlm:artifact_type | [Artifact Type](./best-practices.md#framework-specific-artifact-types) | Specifies the kind of model artifact, any string is allowed. Typically related to a particular ML framework, see [Best Practices - Framework Specific Artifact Types](./best-practices.md#framework-specific-artifact-types) for **RECOMMENDED** values. This field is **REQUIRED** if the `mlm:model` role is specified. | +| mlm:compile_method | [Compile Method](#compile-method) \| null | Describes the method used to compile the ML model either when the model is saved or at model runtime prior to inference. | Recommended Asset `roles` include `mlm:weights` or `mlm:checkpoint` for model weights that need to be loaded by a model definition and `mlm:compiled` for models that can be loaded directly without an intermediate model definition. diff --git a/best-practices.md b/best-practices.md index 7484b45..9807697 100644 --- a/best-practices.md +++ b/best-practices.md @@ -301,14 +301,15 @@ permitted, as these values are not validated by the schema. Note that the names framework-specific definitions to help the users understand how the model artifact was created, although these exact names are not strictly required either. -| Artifact Type | Description | -|--------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| `torch.save` | A [serialized python pickle object][pytorch-save] (i.e.: `.pt`) which can represent a model or state_dict. | -| `torch.jit.save` | A [`TorchScript`][pytorch-jit-script] model artifact obtained with one or more of the graph export options Torchscript Tracing and Torchscript Scripting. | -| `torch.export.save` | A model artifact storing an [ExportedProgram][exported-program] obtained by [`torch.export.export`][pytorch-export] (i.e.: `.pt2`). | -| `tf.keras.Model.save` | Saves a [.keras model file][keras-model], a unified zip archive format containing the architecture, weights, optimizer, losses, and metrics. | -| `tf.keras.Model.save_weights` | A [.weights.h5][keras-save-weights] file containing only model weights for use by Tensorflow or Keras. | -| `tf.keras.Model.export` | [TF Saved Model][tf-saved-model] is the [recommended format][tf-keras-recommended] by the Tensorflow team for whole model saving/loading for inference. See the docs for [different save methods][keras-methods] in TF and Keras. | +| Artifact Type | Description | +|------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `torch.save` | A [serialized python pickle object][pytorch-save] (i.e.: `.pt`) which can represent a model or state_dict. | +| `torch.jit.save` | A [`TorchScript`][pytorch-jit-script] model artifact obtained with one or more of the graph export options TorchScript Tracing and Scripting. | +| `torch.export.save` | A model artifact storing an [ExportedProgram][exported-program] obtained by [`torch.export.export`][pytorch-export] (i.e.: `.pt2`). | +| `tf.keras.Model.save` | Saves a [.keras model file][keras-model], a unified zip archive format containing the architecture, weights, optimizer, losses, and metrics. | +| `tf.keras.Model.save_weights` | A [.weights.h5][keras-save-weights] file containing only model weights for use by Tensorflow or Keras. | +| `tf.keras.Model.export` | [TF Saved Model][tf-saved-model] is the [recommended format][tf-keras-recommended] by the Tensorflow team for whole model saving/loading for inference. See the docs for [different save methods][keras-methods] in TF and Keras. | +| `safetensors.{framework}.{method}` | Model weights saved as [HuggingFace SafeTensors][hf-st], where `{framework}` matches the [`mlm:framework`][mlm-framework] of a [*supported framework*][hf-st-support] and `{method}` matches the applicable method from SafeTensors. For example, a PyTorch model saved this way would indicate [`safetensors.torch.save_file`][hf-st-torch]. | [exported-program]: https://pytorch.org/docs/main/export.html#serialization [pytorch-aot-inductor]: https://pytorch.org/docs/main/torch.compiler_aot_inductor.html @@ -321,3 +322,7 @@ names are not strictly required either. [tf-keras-recommended]: https://www.tensorflow.org/guide/saved_model#creating_a_savedmodel_from_keras [keras-methods]: https://keras.io/2.16/api/models/model_saving_apis/ [keras-model]: https://keras.io/api/models/model_saving_apis/model_saving_and_loading/ +[hf-st]: https://github.com/huggingface/safetensors +[hf-st-support]: https://huggingface.co/docs/safetensors/index#featured-projects +[hf-st-torch]: https://huggingface.co/docs/safetensors/api/torch#safetensors.torch.save_file +[mlm-framework]: CHANGELOG.md#item-properties-and-collection-fields diff --git a/json-schema/schema.json b/json-schema/schema.json index 5a76347..c26307e 100644 --- a/json-schema/schema.json +++ b/json-schema/schema.json @@ -464,7 +464,10 @@ "torch.export.save", "tf.keras.Model.save", "tf.keras.Model.save_weights", - "tf.saved_model.export(format='tf_saved_model')" + "tf.keras.Model.export", + "safetensors.torch.save_file", + "safetensors.tensorflow.save_file", + "safetensors.paddle.save_file" ] }, "mlm:compile_method": { From ae921c90d26d1ed84b496b4f5b84e79e6a80d3d9 Mon Sep 17 00:00:00 2001 From: Francis Charette Migneault Date: Mon, 17 Feb 2025 21:51:23 -0500 Subject: [PATCH 2/3] fix md link --- best-practices.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/best-practices.md b/best-practices.md index 9807697..14d1235 100644 --- a/best-practices.md +++ b/best-practices.md @@ -325,4 +325,4 @@ names are not strictly required either. [hf-st]: https://github.com/huggingface/safetensors [hf-st-support]: https://huggingface.co/docs/safetensors/index#featured-projects [hf-st-torch]: https://huggingface.co/docs/safetensors/api/torch#safetensors.torch.save_file -[mlm-framework]: CHANGELOG.md#item-properties-and-collection-fields +[mlm-framework]: README.md#item-properties-and-collection-fields From 489c9f7b9bc97484020df8b1b48f430e3b685865 Mon Sep 17 00:00:00 2001 From: Francis Charette Migneault Date: Tue, 18 Feb 2025 13:31:58 -0500 Subject: [PATCH 3/3] add flax framework and its safetensors backend to JSON schema examples --- CHANGELOG.md | 4 +++- README.md | 1 + json-schema/schema.json | 2 ++ 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e7973d8..51e3bd9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,8 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Add [`huggingface/safetensors`](https://github.com/huggingface/safetensors) - recommendations for ``mlm:artifact_type`` and corresponding ``mlm:framework`` values + recommendations for `mlm:artifact_type` and corresponding ``mlm:framework`` values (fixes [#68](https://github.com/stac-extensions/mlm/issues/68)). +- Add [`Flax`](https://github.com/google/flax) to the list of `mlm:framework` and + the corresponding `mlm:artifact_type` SafeTensors backend in the JSON schema examples. - Add [`Paddle`](https://github.com/PaddlePaddle/Paddle) to the list of `mlm:framework` (fixes [#69](https://github.com/stac-extensions/mlm/issues/69)). diff --git a/README.md b/README.md index 800d7d0..b51c59c 100644 --- a/README.md +++ b/README.md @@ -238,6 +238,7 @@ to use common names when applicable. Below are a few notable entries. - [`rgee`](https://github.com/r-spatial/rgee) - [`spatialRF`](https://github.com/BlasBenito/spatialRF) - [`JAX`](https://github.com/jax-ml/jax) +- [`Flax`](https://github.com/google/flax) - [`MXNet`](https://github.com/apache/mxnet) - [`Caffe`](https://github.com/BVLC/caffe) - [`PyMC`](https://github.com/pymc-devs/pymc) diff --git a/json-schema/schema.json b/json-schema/schema.json index bc5cbfa..5a70356 100644 --- a/json-schema/schema.json +++ b/json-schema/schema.json @@ -436,6 +436,7 @@ "rgee", "spatialRF", "JAX", + "Flax", "MXNet", "Caffe", "PyMC", @@ -468,6 +469,7 @@ "tf.keras.Model.export", "safetensors.torch.save_file", "safetensors.tensorflow.save_file", + "safetensors.flax.save_file", "safetensors.paddle.save_file" ] },