Skip to content

Commit

Permalink
[feat] Add huggingface/datasets integration (#2454)
Browse files Browse the repository at this point in the history
* [feat] Add huggingface/datasets integration

* [fix] Apply black formatting

* [fix] Make get value actions safe

* [fix] Fix flake8 errors

* [fix] Check if keys exist in dataset info

* [fix] Update class name
  • Loading branch information
tamohannes authored Jan 13, 2023
1 parent 81b4ece commit 4e04e35
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Fix plotly and matplotlib compatibility (tmynn)
- Add Stable-Baselines3 integration (tmynn)
- Add Acme integration (tmynn)
- Add huggingface/datasets integration (tmynn)

### Fixes

Expand Down
2 changes: 2 additions & 0 deletions aim/hf_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Alias to SDK Hugging Face Datasets interface
from aim.sdk.objects.plugins.hf_datasets_metadata import HFDataset # noqa F401
73 changes: 73 additions & 0 deletions aim/sdk/objects/plugins/hf_datasets_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from datasets import DatasetDict
from aim.storage.object import CustomObject
from logging import getLogger

logger = getLogger(__name__)


@CustomObject.alias("hf_datasets.metadata")
class HFDataset(CustomObject):
AIM_NAME = "hf_datasets.metadata"

def __init__(self, dataset: DatasetDict):
super().__init__()
self.storage["dataset"] = {
"source": "huggingface_datasets",
"meta": self._get_ds_meta(dataset),
}

def _get_ds_meta(self, dataset: DatasetDict):
dataset_info = vars(dataset[list(dataset.keys())[0]]._info)

return {
"description": dataset_info.get("description"),
"citation": dataset_info.get("citation"),
"homepage": dataset_info.get("homepage"),
"license": dataset_info.get("license"),
"features": self._get_features(dataset_info),
"post_processed": str(dataset_info.get("post_processed")),
"supervised_keys": str(dataset_info.get("supervised_keys")),
"task_templates": self._get_task_templates(dataset_info),
"builder_name": dataset_info.get("builder_name"),
"config_name": dataset_info.get("config_name"),
"version": str(dataset_info.get("version")),
"splits": self._get_splits(dataset_info),
"download_checksums": dataset_info.get("download_checksums"),
"download_size": dataset_info.get("download_size"),
"post_processing_size": dataset_info.get("post_processing_size"),
"dataset_size": dataset_info.get("dataset_size"),
"size_in_bytes": dataset_info.get("size_in_bytes"),
}

def _get_features(self, dataset_info):
try:
if dataset_info.get("features"):
return [
{feature: str(dataset_info.get("features")[feature])}
for feature in dataset_info.get("features").keys()
]
except LookupError:
logger.warning("Failed to get features information")

def _get_task_templates(self, dataset_info):
try:
if dataset_info.get("task_templates"):
return [str(template) for template in dataset_info.get("task_templates")]
except LookupError:
logger.warning("Failed to get task templates information")

def _get_splits(self, dataset_info):
try:
if dataset_info.get("splits"):
return [
{
subset: {
"num_bytes": dataset_info.get("splits")[subset].num_bytes,
"num_examples": dataset_info.get("splits")[subset].num_examples,
"dataset_name": dataset_info.get("splits")[subset].dataset_name,
}
}
for subset in dataset_info.get("splits")
]
except LookupError:
logger.warning("Failed to get splits information")
19 changes: 19 additions & 0 deletions docs/source/quick_start/supported_types.md
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,22 @@ If we apply our previous code snippet on the same repo - we can observe the same
]
}
```
### Logging huggingface/datasets dataset info with Aim
Aim provides wrapper object for `datasets`. It allows to store the dataset info as a `Run`
parameter and retrieve it later just as any other `Run` param. Here is an example of using Aim to log dataset info:
```python
from datasets import load_dataset
from aim import Run
from aim.hf_dataset import HFDataset
# create dataset object
dataset = load_dataset('rotten_tomatoes')
# store dataset metadata
run = Run()
run['datasets_info'] = HFDataset(dataset)
```
33 changes: 33 additions & 0 deletions tests/integrations/test_hf_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest

from tests.base import TestBase
from tests.utils import is_package_installed


class TestHFDatasetsIntegration(TestBase):
@pytest.mark.skipif(
not is_package_installed("datasets"),
reason="'datasets' is not installed. skipping.",
)
def test_datasets_as_run_param(self):
from datasets import load_dataset

from aim.sdk.objects.plugins.hf_datasets_metadata import HFDataset
from aim.sdk import Run

# create dataset object
dataset = load_dataset("rotten_tomatoes")

# log dataset metadata
# log dataset metadata
run = Run(repo=".hf_datasets", system_tracking_interval=None)
run["datasets_info"] = HFDataset(dataset)

# get dataset metadata
ds_object = run["datasets_info"]
ds_dict = run.get("datasets_info", resolve_objects=True)

self.assertTrue(isinstance(ds_object, HFDataset))
self.assertTrue(isinstance(ds_dict, dict))
self.assertIn("meta", ds_dict["dataset"].keys())
self.assertIn("source", ds_dict["dataset"].keys())

0 comments on commit 4e04e35

Please sign in to comment.