From 24cf86fa1be86d1e09695e7514486239da26520a Mon Sep 17 00:00:00 2001 From: Ryan Soley Date: Fri, 12 Apr 2024 13:49:48 -0400 Subject: [PATCH] adds a `TagContainer` to handle key-value tags (#429) * add `TagContainer` * test `TagContainer` * linting & formatting * handle multiple keys * use new tags for parent/child management * update tagging notebook * more tests * fix test * linting & formatting * update return type --- notebooks/logging-examples/tagging.ipynb | 625 ++++++++++++++------ rubicon_ml/client/experiment.py | 13 +- rubicon_ml/client/mixin.py | 6 +- rubicon_ml/client/utils/tags.py | 27 + tests/unit/client/test_experiment_client.py | 8 + tests/unit/client/utils/test_tags.py | 32 +- 6 files changed, 531 insertions(+), 180 deletions(-) diff --git a/notebooks/logging-examples/tagging.ipynb b/notebooks/logging-examples/tagging.ipynb index c8717397..bacb3c60 100644 --- a/notebooks/logging-examples/tagging.ipynb +++ b/notebooks/logging-examples/tagging.ipynb @@ -2,265 +2,550 @@ "cells": [ { "cell_type": "markdown", - "id": "4768d0d8", + "id": "0f32303d-101b-48cc-9e18-3c7466b3ea2f", "metadata": {}, "source": [ "# Tagging\n", "\n", - "Sometimes we might want to tag **experiments** and objects with distinct values to organize\n", - "and filter them later on. For example, tags could be used to differentiate between\n", - "the type of model or classifier used during the **experiment** (i.e. `linear regression`\n", - "or `random forest`). Besides, **experiments**, ``rubicon_ml`` can tag artifacts, dataframes, \n", - "features, metrics, and parameters. \n", + "Tags can be used to group and identify specific rubicon-ml entities by shared characteristics.\n", + "Any rubicon-ml entity can be tagged when logged with any number of tags. Later, tags can be leveraged\n", + "to query rubicon-ml logs during retrieval.\n", "\n", - "Below, we'll see examples of tagging functionality." 
- ] - }, - { - "cell_type": "markdown", - "id": "bfc54ec5", - "metadata": {}, - "source": [ - "### Adding tags when logging\n", - "By utilizing the tags parameter:" + "In general, a tag is any arbitrary string. rubicon-ml provides additional functionality for tags that\n", + "follow a ``<key>:<value>`` format.\n", + "\n", + "## Logging with tags\n", + "\n", + "First, create a ``Rubicon`` entrypoint." ] }, { "cell_type": "code", "execution_count": 1, - "id": "59a475f6", + "id": "10393959-57dd-4a3d-8b1e-d7fffb24b421", "metadata": {}, "outputs": [], "source": [ "from rubicon_ml import Rubicon\n", - "import pandas as pd\n", "\n", "rubicon = Rubicon(persistence=\"memory\")\n", - "project = rubicon.get_or_create_project(\"Tagging\")\n", - "\n", - "#logging experiments with tags\n", - "experiment1 = project.log_experiment(name=\"experiment1\", tags=[\"odd_num_exp\"])\n", - "experiment2 = project.log_experiment(name=\"experiment2\", tags=[\"even_num_exp\"])\n", - "\n", - "#logging artifacts, dataframes, features, metrics and parameters with tags\n", - "first_artifact = experiment1.log_artifact(data_bytes=b\"bytes\", name=\"data_path\", tags=[\"data\"])\n", - "\n", - "confusion_matrix = pd.DataFrame([[5, 0, 0], [0, 5, 1], [0, 0, 4]], columns=[\"x\", \"y\", \"z\"])\n", - "first_dataframe = experiment1.log_dataframe(confusion_matrix, tags=[\"three_column\"])\n", - "\n", - "first_feature = experiment1.log_feature(\"year\", tags=[\"time\"])\n", - "\n", - "first_metric = experiment1.log_metric(\"accuracy\", .8, tags=[\"scalar\"])\n", - "\n", - "#can add multiple tags at logging (works for all objects)\n", - "first_parameter = experiment1.log_parameter(\"n_estimators\", tags=[\"tag1\", \"tag2\"])" + "project = rubicon.create_project(\"tagging\")" ] }, { "cell_type": "markdown", - "id": "61d0926a", + "id": "67307879-6526-4382-a503-ee94c49e9e74", "metadata": {}, "source": [ - "### Viewing tags\n", - "Use the .tags attribute to view tags associated with an object:" + "Now we'll log three 
experiments with tags \"a\" and \"b\"." ] }, { "cell_type": "code", "execution_count": 2, - "id": "064edb9c", + "id": "ae4c1a0b-6172-4250-ba55-c745d84c0813", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "['odd_num_exp']\n", - "['even_num_exp']\n", - "['data']\n", - "['three_column']\n", - "['time']\n", - "['scalar']\n", - "['tag1', 'tag2']\n" + "`experiment_a` ID: 09b2ff25-5152-4e82-9532-c1a09de65409, tags: ['a']\n", + "`experiment_b` ID: 932f21ff-e839-437f-a7e8-6c05a1186294, tags: ['b']\n", + "`experiment_c` ID: ef64c07f-a7ba-4248-bde2-a4a323a09428, tags: ['a', 'b']\n" ] } ], "source": [ - "print(experiment1.tags)\n", - "print(experiment2.tags)\n", - "print(first_artifact.tags)\n", - "print(first_dataframe.tags)\n", - "print(first_feature.tags)\n", - "print(first_metric.tags)\n", - "print(first_parameter.tags)" + "experiment_a = project.log_experiment(tags=[\"a\"])\n", + "experiment_b = project.log_experiment(tags=[\"b\"])\n", + "experiment_c = project.log_experiment(tags=[\"a\", \"b\"])\n", + "\n", + "print(f\"`experiment_a` ID: {experiment_a.id}, tags: {experiment_a.tags}\")\n", + "print(f\"`experiment_b` ID: {experiment_b.id}, tags: {experiment_b.tags}\")\n", + "print(f\"`experiment_c` ID: {experiment_c.id}, tags: {experiment_c.tags}\")" ] }, { "cell_type": "markdown", - "id": "86bda4bf", + "id": "df6289a5-4681-44df-953a-f1c1f0d39b87", "metadata": {}, "source": [ - "### Adding tags to existing objects\n", - "Use the object's add_tags() method. Works the same for all taggable objects. Here's an example:" + "Any other entity logged to an experiment can also be tagged." 
] }, { "cell_type": "code", "execution_count": 3, - "id": "2e5bc3eb", + "id": "71913363-5b98-4e95-9cae-eafa05793a2c", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "artifact = experiment_a.log_artifact(\n", + " data_bytes=b\"artifact\", name=\"artifact\", tags=[\"c\"]\n", + ")\n", + "dataframe = experiment_a.log_dataframe(\n", + " df=pd.DataFrame([[0], [1]]), tags=[\"d\"]\n", + ")\n", + "feature = experiment_a.log_feature(name=\"var_0\", tags=[\"e\"])\n", + "parameter = experiment_a.log_parameter(name=\"input\", value=0, tags=[\"f\"])\n", + "metric = experiment_a.log_metric(name=\"output\", value=1, tags=[\"g\"])" + ] + }, + { + "cell_type": "markdown", + "id": "179f8230-414f-4881-9dc9-7b42f3de23b1", + "metadata": {}, + "source": [ + "## Retrieving with tags\n", + "\n", + "Each of the retrieval functions on a project or experiment (``experiments``, ``metrics``, etc.)\n", + "accept the ``tags`` and ``qtype`` (\"or\" or \"and\", default \"or\") arguments to filter retrieval.\n", + "\n", + "First, grab all the experiments with tag \"a\"." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "8db675d1-b5c1-4471-8f27-ad4741f8f130", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "['linear regression', 'odd_num_exp']\n", - "['even_num_exp', 'random forrest']\n", - "['data', 'added_tag']\n", - "['added_tag', 'three_column']\n", - "['time', 'added_tag']\n", - "['added_tag', 'scalar']\n", - "['added_tag2', 'tag1', 'tag2', 'added_tag1']\n" - ] + "data": { + "text/plain": [ + "['09b2ff25-5152-4e82-9532-c1a09de65409',\n", + " 'ef64c07f-a7ba-4248-bde2-a4a323a09428']" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "experiment1.add_tags([\"linear regression\"])\n", - "experiment2.add_tags([\"random forrest\"])\n", - "first_artifact.add_tags([\"added_tag\"])\n", - "first_dataframe.add_tags([\"added_tag\"])\n", - "first_feature.add_tags([\"added_tag\"])\n", - "first_metric.add_tags([\"added_tag\"])\n", - "\n", - "#can add multiple tags (works for all objects)\n", - "first_parameter.add_tags([\"added_tag1\", \"added_tag2\"])\n", - "\n", - "\n", - "print(experiment1.tags)\n", - "print(experiment2.tags)\n", - "print(first_artifact.tags)\n", - "print(first_dataframe.tags)\n", - "print(first_feature.tags)\n", - "print(first_metric.tags)\n", - "print(first_parameter.tags)" + "[e.id for e in project.experiments(tags=[\"a\"])]" ] }, { "cell_type": "markdown", - "id": "2527eb1a", + "id": "59118c51-be5f-4d8a-89be-c18dcbf411eb", "metadata": {}, "source": [ - "### Removing tags from existing objects\n", - "Use the object's remove_tags() method. Works the same for all taggable objects. Here's an example:" + "Next, get each experiment with tag \"b\". Note that the final experiment is the same as the last\n", + "output since it has both tags \"a\" and \"b\"." 
] }, { "cell_type": "code", - "execution_count": 4, - "id": "356a6089", + "execution_count": 5, + "id": "e3e93248-30f5-4c60-bdea-752b25cc9057", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "['odd_num_exp']\n", - "['even_num_exp']\n", - "['data']\n", - "['three_column']\n", - "['time']\n", - "['scalar']\n", - "['tag2', 'tag1']\n" - ] + "data": { + "text/plain": [ + "['932f21ff-e839-437f-a7e8-6c05a1186294',\n", + " 'ef64c07f-a7ba-4248-bde2-a4a323a09428']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "experiment1.remove_tags([\"linear regression\"])\n", - "experiment2.remove_tags([\"random forrest\"])\n", - "first_artifact.remove_tags([\"added_tag\"])\n", - "first_dataframe.remove_tags([\"added_tag\"])\n", - "first_feature.remove_tags([\"added_tag\"])\n", - "first_metric.remove_tags([\"added_tag\"])\n", - "\n", - "#can remove multiple tags (works for all objects)\n", - "first_parameter.remove_tags([\"added_tag2\", \"added_tag1\"])\n", - "\n", - "print(experiment1.tags)\n", - "print(experiment2.tags)\n", - "print(first_artifact.tags)\n", - "print(first_dataframe.tags)\n", - "print(first_feature.tags)\n", - "print(first_metric.tags)\n", - "print(first_parameter.tags)" + "[e.id for e in project.experiments(tags=[\"b\"])]" ] }, { "cell_type": "markdown", - "id": "c8edad59", + "id": "ae7007b0-425f-45f0-8fa2-675ecafedb80", "metadata": {}, "source": [ - "### Retreiving objects by their tags\n", - "After logging objects, here's how we can include tags as a paramter to filter our results. We can specify the `qtype` parameter to change the search type to \"and\" from \"or\" (default). Here this is only shown with experiments, but works for any taggable object when doing parentObject.retrievalObjects():" + "Querying with multiple tags uses a logical _or_ to return results by default." 
] }, { "cell_type": "code", - "execution_count": 5, - "id": "43cfa3ed", + "execution_count": 6, + "id": "e5664866-bef5-48a0-acc3-281d40a618d6", "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "old experiments: experiment1, experiment2\n", - "\n", - "new experiments: experiment3\n", - "\n", - "odd experiments: experiment1, experiment3\n", - "\n", - "same experiments: experiment1, experiment3\n", - "\n", - "expected experiment: experiment3\n", - "\n" - ] + "data": { + "text/plain": [ + "['09b2ff25-5152-4e82-9532-c1a09de65409',\n", + " '932f21ff-e839-437f-a7e8-6c05a1186294',\n", + " 'ef64c07f-a7ba-4248-bde2-a4a323a09428']" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "experiment1.add_tags([\"old_exp\"])\n", - "experiment2.add_tags([\"old_exp\"])\n", - "experiment3 = project.log_experiment(name=\"experiment3\", tags=[\"odd_num_exp\",\"new_exp\"])\n", - "\n", - "#want just old experiments\n", - "old_experiments = project.experiments(tags=[\"old_exp\"])\n", - "\n", - "#want just new experiments\n", - "new_experiments = project.experiments(tags=[\"new_exp\"])\n", - "\n", - "#want just the odd number experiments\n", - "odd_experiments = project.experiments(tags=[\"odd_num_exp\"])\n", - "\n", - "#this will return the same result as above since qtype=\"or\" by default\n", - "same_experiments = project.experiments(tags=[\"odd_num_exp\", \"new_exp\"])\n", - "\n", - "#this will return just experiment3\n", - "expected_experiment = project.experiments(tags=[\"odd_num_exp\", \"new_exp\"], qtype=\"and\")\n", - "\n", - "\n", - "#getting both the old experiments 1 and 2\n", - "print(\"old experiments: \" + str(old_experiments[0].name) + \", \" + str(old_experiments[1].name) + \"\\n\")\n", + "[e.id for e in project.experiments(tags=[\"a\", \"b\"])]" + ] + }, + { + "cell_type": "markdown", + "id": "03298720-9017-4314-8c86-855ab8be8e32", + "metadata": {}, + "source": [ + "This can be 
switched to a logical _and_ with the ``qtype`` argument." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "76dec4ef-8ced-4f5a-945c-6a0e88ea8f42", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['ef64c07f-a7ba-4248-bde2-a4a323a09428']" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[e.id for e in project.experiments(tags=[\"a\", \"b\"], qtype=\"and\")]" + ] + }, + { + "cell_type": "markdown", + "id": "42e9652c-9264-4cdd-b230-0fd19345100e", + "metadata": {}, + "source": [ + "## Updating tags\n", "\n", - "#getting just the new experiment 3\n", - "print(\"new experiments: \" + str(new_experiments[0].name) + \"\\n\")\n", + "Tags can be updated later, after logging, as well." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1fd16d81-e94c-408e-9275-623a35f48af8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['a', 'b']" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiment_c.tags" + ] + }, + { + "cell_type": "markdown", + "id": "f0564014-5aa0-40ca-861d-9ff485e95b8c", + "metadata": {}, + "source": [ + "`add_tags` adds any number of new tags to an existing entity. Each entity that allows\n", + "tagging will have both the ``add_tags`` and ``remove_tags`` functions." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b8c55a1d-9955-4bc2-9384-117b3444a968", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['i', 'h', 'a', 'b']" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiment_c.add_tags([\"h\", \"i\"])\n", + "experiment_c.tags" + ] + }, + { + "cell_type": "markdown", + "id": "7f4fe05e-d4cf-41ad-b3bc-ec78e0207c6b", + "metadata": {}, + "source": [ + "Removal works similarly." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "4fe6eba0-d58e-4bb9-81ed-886bfc611506", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['i', 'h']" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiment_c.remove_tags([\"a\", \"b\"])\n", + "experiment_c.tags" + ] + }, + { + "cell_type": "markdown", + "id": "7815fb91-f310-415c-9590-d26d0c9fbda8", + "metadata": {}, + "source": [ + "Now, the same query from above for an experiment with tags \"a\" and \"b\" returns no results." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "a9be6033-775c-4d11-b9d6-55364474eb87", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[e.id for e in project.experiments(tags=[\"a\", \"b\"], qtype=\"and\")]" + ] + }, + { + "cell_type": "markdown", + "id": "1394c855-6fd4-426b-9fce-dae44caee95b", + "metadata": {}, + "source": [ + "## Key-value tags\n", "\n", - "#getting both odd experiments 1 and 3\n", - "print(\"odd experiments: \" + str(odd_experiments[0].name) + \", \" + str(odd_experiments[1].name) + \"\\n\")\n", + "rubicon-ml provides extended support for tags that follow the ``<key>:<value>`` format." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "d68b29d3-436b-476b-9cc3-cf8d4c7e14ac", + "metadata": {}, + "outputs": [], + "source": [ + "experiment_d = project.log_experiment(tags=[\"j:k\"])\n", + "experiment_e = project.log_experiment(tags=[\"l:m\", \"l:n\"])" + ] + }, + { + "cell_type": "markdown", + "id": "aa040516-3cbf-4dbd-8100-0f0f6f4d23b6", + "metadata": {}, + "source": [ + "The list returned by the `tags` property of any entity can be indexed into like a\n", + "regular list to retrieve the full tags, just like with normal tags." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "3b9ff6f5-b90b-4870-8ffc-0ffe74a8c0f8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'j:k'" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiment_d.tags[0]" + ] + }, + { + "cell_type": "markdown", + "id": "cc7f9837-bedf-48d8-ac18-2479a0bc2df4", + "metadata": {}, + "source": [ + "But it also supports string indexing, like a dictionary. To retrieve the value of a\n", + "key-value tag, just index into the `tags` property with its key." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "622925d8-26fd-4961-b7d5-0625544d736c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'k'" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiment_d.tags[\"j\"]" + ] + }, + { + "cell_type": "markdown", + "id": "726f8cd6-2f8c-44d4-a264-a7106bda49e0", + "metadata": {}, + "source": [ + "If there are multiple keys, a list containing each value will be returned." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "07328cb3-9947-4fe9-8359-244530016a7e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['m', 'n']" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiment_e.tags[\"l\"]" + ] + }, + { + "cell_type": "markdown", + "id": "26f96aa5-cc31-4adf-86c5-d1b608741259", + "metadata": {}, + "source": [ + "### Managing experiment relationships\n", "\n", - "#again getting both experiments 1 and 3\n", - "print(\"same experiments: \" + str(same_experiments[0].name) + \", \" + str(same_experiments[1].name) + \"\\n\")\n", + "A common use for key-value tags is managing relationships between experiments. rubicon-ml\n", + "has built-in support for managing such relationships in this manner." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "ab2c5a29-b8ee-45e7-b4b9-6dc22c3abdd7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['a',\n", + " 'child:f080134a-b118-4ac4-b400-25fc097366a8',\n", + " 'child:9443daa7-0bee-4af8-8f31-79a85017bcd5']" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiment_a.add_child_experiment(experiment_d)\n", + "experiment_a.add_child_experiment(experiment_e)\n", "\n", - "#getting just experiment 3\n", - "print(\"expected experiment: \" + str(expected_experiment[0].name) + \"\\n\")" + "experiment_a.tags" + ] + }, + { + "cell_type": "markdown", + "id": "7c3367cc-77f3-4c20-b7ab-5029e64b8452", + "metadata": {}, + "source": [ + "The experiment IDs themselves can be retrieved by indexing into the tags with the \"child\" key." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "b3c04d92-6f52-4c19-ab0e-6907b4f8ab17", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['f080134a-b118-4ac4-b400-25fc097366a8',\n", + " '9443daa7-0bee-4af8-8f31-79a85017bcd5']" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiment_a.tags[\"child\"]" + ] + }, + { + "cell_type": "markdown", + "id": "5aae42a9-2ebd-4444-98ff-0127f3ed737c", + "metadata": {}, + "source": [ + "From there, we can use the IDs to grab the entire experiments from the original project." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "876a3358-2277-47b4-910f-1a1db89bda9d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[,\n", + " ]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[project.experiment(id=exp_id) for exp_id in experiment_a.tags[\"child\"]]" ] } ], @@ -280,7 +565,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.14" } }, "nbformat": 4, diff --git a/rubicon_ml/client/experiment.py b/rubicon_ml/client/experiment.py index 8f1185e4..e555b1b0 100644 --- a/rubicon_ml/client/experiment.py +++ b/rubicon_ml/client/experiment.py @@ -451,14 +451,15 @@ def _get_experiments_from_tags(self, tag_key: str): list of rubicon_ml.client.Experiment The experiments with `experiment_id`s in this experiment's tags. """ - experiments = [] + try: + experiment_ids = self.tags[tag_key] + except KeyError: + return [] - for tag in self.tags: - if f"{tag_key}:" in tag: - experiment_id = tag.split(":")[-1] - experiments.append(self.project.experiment(id=experiment_id)) + if not isinstance(experiment_ids, list): + experiment_ids = [experiment_ids] - return experiments + return [self.project.experiment(id=exp_id) for exp_id in experiment_ids] def get_child_experiments(self) -> List[Experiment]: """Get the experiments that are tagged as children of this experiment. 
diff --git a/rubicon_ml/client/mixin.py b/rubicon_ml/client/mixin.py index 797e6e3c..48c29bc5 100644 --- a/rubicon_ml/client/mixin.py +++ b/rubicon_ml/client/mixin.py @@ -16,7 +16,7 @@ from rubicon_ml import client, domain from rubicon_ml.client.utils.exception_handling import failsafe -from rubicon_ml.client.utils.tags import filter_children +from rubicon_ml.client.utils.tags import TagContainer, filter_children from rubicon_ml.domain import Artifact as ArtifactDomain from rubicon_ml.exceptions import RubiconException @@ -687,7 +687,7 @@ def _update_tags(self, tag_data): self._domain.remove_tags(tag.get("removed_tags", [])) @property - def tags(self) -> List[str]: + def tags(self) -> TagContainer: """Get this client object's tags.""" project_name, experiment_id, entity_identifier = self._get_taggable_identifiers() return_err = None @@ -704,7 +704,7 @@ def tags(self) -> List[str]: else: self._update_tags(tag_data) - return self._domain.tags + return TagContainer(self._domain.tags) self._raise_rubicon_exception(return_err) diff --git a/rubicon_ml/client/utils/tags.py b/rubicon_ml/client/utils/tags.py index bb444bce..059e10c6 100644 --- a/rubicon_ml/client/utils/tags.py +++ b/rubicon_ml/client/utils/tags.py @@ -1,6 +1,33 @@ from typing import List +class TagContainer(list): + """List-based container for tags that allows indexing into tags + with colons in them by string, like a dictionary. 
+ """ + + def __getitem__(self, index_or_key): + if isinstance(index_or_key, str): + values = [] + + for tag in self: + key_value = tag.split(":", 1) + + if len(key_value) > 1 and key_value[0] == index_or_key: + values.append(key_value[1]) + + if len(values) == 0: + raise KeyError(index_or_key) + elif len(values) == 1: + return values[0] + else: + return values + else: + item = super().__getitem__(index_or_key) + + return TagContainer(item) if isinstance(item, list) else item + + def has_tag_requirements(tags: List[str], required_tags: List[str], qtype: str) -> bool: """Returns True if `tags` meets the requirements based on the values of `required_tags` and `qtype`. False otherwise. diff --git a/tests/unit/client/test_experiment_client.py b/tests/unit/client/test_experiment_client.py index 90a70fc6..cbdbf1c6 100644 --- a/tests/unit/client/test_experiment_client.py +++ b/tests/unit/client/test_experiment_client.py @@ -440,3 +440,11 @@ def test_get_parent_experiments(project_client): parent.add_child_experiment(child) assert child.get_parent_experiments()[0].id == parent.id + + +def test_get_relative_experiments_none(project_client): + project = project_client + experiment = project.log_experiment() + + assert experiment.get_child_experiments() == [] + assert experiment.get_parent_experiments() == [] diff --git a/tests/unit/client/utils/test_tags.py b/tests/unit/client/utils/test_tags.py index f48b4bf8..873058dd 100644 --- a/tests/unit/client/utils/test_tags.py +++ b/tests/unit/client/utils/test_tags.py @@ -1,4 +1,6 @@ -from rubicon_ml.client.utils.tags import has_tag_requirements +import pytest + +from rubicon_ml.client.utils.tags import TagContainer, has_tag_requirements def test_or_single_success(): @@ -27,3 +29,31 @@ def test_and_single_failure(): def test_and_multiple_failure(): assert not has_tag_requirements(["x", "y", "z"], ["a", "z"], "and") + + +def test_tag_container(): + tags = TagContainer(["a", "b:c", "d:e", "d:f"]) + + assert tags[0] == "a" + assert 
tags[1] == "b:c" + assert tags[2] == "d:e" + assert tags[3] == "d:f" + assert tags["b"] == "c" + assert tags["d"] == ["e", "f"] + + +def test_tag_container_nested(): + tags = TagContainer(["a", "b:c", "d:e", "d:f"]) + + assert tags[1:] == ["b:c", "d:e", "d:f"] + assert tags[1:2]["b"] == "c" + assert tags[2:4]["d"] == ["e", "f"] + + +def test_tag_container_errors(): + tags = TagContainer([]) + + with pytest.raises(KeyError) as error: + tags["missing"] + + assert "KeyError('missing')" in str(error)