diff --git a/notebooks/logging-examples/tagging.ipynb b/notebooks/logging-examples/tagging.ipynb index bacb3c60..fcd24f21 100644 --- a/notebooks/logging-examples/tagging.ipynb +++ b/notebooks/logging-examples/tagging.ipynb @@ -50,16 +50,16 @@ "name": "stdout", "output_type": "stream", "text": [ - "`experiment_a` ID: 09b2ff25-5152-4e82-9532-c1a09de65409, tags: ['a']\n", - "`experiment_b` ID: 932f21ff-e839-437f-a7e8-6c05a1186294, tags: ['b']\n", - "`experiment_c` ID: ef64c07f-a7ba-4248-bde2-a4a323a09428, tags: ['a', 'b']\n" + "`experiment_a` ID: 7e08c0bf-7f88-46d1-89de-4e0da7e2a448, tags: ['tag_a']\n", + "`experiment_b` ID: 75e54061-5910-43a0-a036-c5a6bdd77ca1, tags: ['other_tag_a', 'tag_b']\n", + "`experiment_c` ID: cc09ea5c-18df-48b1-888e-e692f5d9e71a, tags: ['tag_a', 'tag_b']\n" ] } ], "source": [ - "experiment_a = project.log_experiment(tags=[\"a\"])\n", - "experiment_b = project.log_experiment(tags=[\"b\"])\n", - "experiment_c = project.log_experiment(tags=[\"a\", \"b\"])\n", + "experiment_a = project.log_experiment(tags=[\"tag_a\"])\n", + "experiment_b = project.log_experiment(tags=[\"other_tag_a\", \"tag_b\"])\n", + "experiment_c = project.log_experiment(tags=[\"tag_a\", \"tag_b\"])\n", "\n", "print(f\"`experiment_a` ID: {experiment_a.id}, tags: {experiment_a.tags}\")\n", "print(f\"`experiment_b` ID: {experiment_b.id}, tags: {experiment_b.tags}\")\n", @@ -84,14 +84,14 @@ "import pandas as pd\n", "\n", "artifact = experiment_a.log_artifact(\n", - " data_bytes=b\"artifact\", name=\"artifact\", tags=[\"c\"]\n", + " data_bytes=b\"artifact\", name=\"artifact\", tags=[\"tag_c\"]\n", ")\n", "dataframe = experiment_a.log_dataframe(\n", - " df=pd.DataFrame([[0], [1]]), tags=[\"d\"]\n", + " df=pd.DataFrame([[0], [1]]), tags=[\"tag_d\"]\n", ")\n", - "feature = experiment_a.log_feature(name=\"var_0\", tags=[\"e\"])\n", - "parameter = experiment_a.log_parameter(name=\"input\", value=0, tags=[\"f\"])\n", - "metric = experiment_a.log_metric(name=\"output\", value=1, tags=[\"g\"])" + "feature = experiment_a.log_feature(name=\"var_0\", tags=[\"tag_e\"])\n", + "parameter = experiment_a.log_parameter(name=\"input\", value=0, tags=[\"tag_f\"])\n", + "metric = experiment_a.log_metric(name=\"output\", value=1, tags=[\"tag_g\"])" ] }, { @@ -116,8 +116,8 @@ { "data": { "text/plain": [ - "['09b2ff25-5152-4e82-9532-c1a09de65409',\n", - " 'ef64c07f-a7ba-4248-bde2-a4a323a09428']" + "[\"7e08c0bf-7f88-46d1-89de-4e0da7e2a448: ['tag_a']\",\n", + " \"cc09ea5c-18df-48b1-888e-e692f5d9e71a: ['tag_a', 'tag_b']\"]" ] }, "execution_count": 4, @@ -126,7 +126,7 @@ } ], "source": [ - "[e.id for e in project.experiments(tags=[\"a\"])]" + "[f\"{e.id}: {e.tags}\" for e in project.experiments(tags=[\"tag_a\"])]" ] }, { @@ -147,8 +147,8 @@ { "data": { "text/plain": [ - "['932f21ff-e839-437f-a7e8-6c05a1186294',\n", - " 'ef64c07f-a7ba-4248-bde2-a4a323a09428']" + "[\"75e54061-5910-43a0-a036-c5a6bdd77ca1: ['other_tag_a', 'tag_b']\",\n", + " \"cc09ea5c-18df-48b1-888e-e692f5d9e71a: ['tag_a', 'tag_b']\"]" ] }, "execution_count": 5, @@ -157,7 +157,7 @@ } ], "source": [ - "[e.id for e in project.experiments(tags=[\"b\"])]" + "[f\"{e.id}: {e.tags}\" for e in project.experiments(tags=[\"tag_b\"])]" ] }, { @@ -177,9 +177,9 @@ { "data": { "text/plain": [ - "['09b2ff25-5152-4e82-9532-c1a09de65409',\n", - " '932f21ff-e839-437f-a7e8-6c05a1186294',\n", - " 'ef64c07f-a7ba-4248-bde2-a4a323a09428']" + "[\"7e08c0bf-7f88-46d1-89de-4e0da7e2a448: ['tag_a']\",\n", + " \"75e54061-5910-43a0-a036-c5a6bdd77ca1: ['other_tag_a', 'tag_b']\",\n", + " \"cc09ea5c-18df-48b1-888e-e692f5d9e71a: ['tag_a', 'tag_b']\"]" ] }, "execution_count": 6, @@ -188,7 +188,7 @@ } ], "source": [ - "[e.id for e in project.experiments(tags=[\"a\", \"b\"])]" + "[f\"{e.id}: {e.tags}\" for e in project.experiments(tags=[\"tag_a\", \"tag_b\"])]" ] }, { @@ -208,7 +208,7 @@ { "data": { "text/plain": [ - "['ef64c07f-a7ba-4248-bde2-a4a323a09428']" + "[\"cc09ea5c-18df-48b1-888e-e692f5d9e71a: ['tag_a', 'tag_b']\"]" ] }, "execution_count": 7, @@ -217,7 +217,70 @@ } ], "source": [ - "[e.id for e in project.experiments(tags=[\"a\", \"b\"], qtype=\"and\")]" + "[f\"{e.id}: {e.tags}\" for e in project.experiments(tags=[\"tag_a\", \"tag_b\"], qtype=\"and\")]" + ] + }, + { + "cell_type": "markdown", + "id": "6099dacf-51f1-4f08-98e1-a62d6e19d7f9", + "metadata": {}, + "source": [ + "### Wildcards\n", + "\n", + "Retrieval by tags also supports wildcards (``*``) while querying." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8a3ee1b5-0340-481a-8d9e-151caff8e7f8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[\"7e08c0bf-7f88-46d1-89de-4e0da7e2a448: ['tag_a']\",\n", + " \"75e54061-5910-43a0-a036-c5a6bdd77ca1: ['other_tag_a', 'tag_b']\",\n", + " \"cc09ea5c-18df-48b1-888e-e692f5d9e71a: ['tag_a', 'tag_b']\"]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[f\"{e.id}: {e.tags}\" for e in project.experiments(tags=[\"*_a\"])]" + ] + }, + { + "cell_type": "markdown", + "id": "a5339a13-f08e-4680-be9b-0bb2bf68607c", + "metadata": {}, + "source": [ + "Multiple wildcards can be used in a single query. A single wildcard character will match any number of\n", + "characters in the tag." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "bb81395a-1b9a-48f9-887f-67f549813226", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[\"75e54061-5910-43a0-a036-c5a6bdd77ca1: ['other_tag_a', 'tag_b']\"]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[f\"{e.id}: {e.tags}\" for e in project.experiments(tags=[\"*_*_*\"])]" ] }, { @@ -232,17 +295,17 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "id": "1fd16d81-e94c-408e-9275-623a35f48af8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['a', 'b']" + "['tag_a', 'tag_b']" ] }, - "execution_count": 8, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -262,23 +325,23 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "id": "b8c55a1d-9955-4bc2-9384-117b3444a968", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['i', 'h', 'a', 'b']" + "['tag_i', 'tag_h', 'tag_b', 'tag_a']" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "experiment_c.add_tags([\"h\", \"i\"])\n", + "experiment_c.add_tags([\"tag_h\", \"tag_i\"])\n", "experiment_c.tags" ] }, @@ -292,23 +355,23 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "id": "4fe6eba0-d58e-4bb9-81ed-886bfc611506", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['i', 'h']" + "['tag_i', 'tag_h']" ] }, - "execution_count": 10, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "experiment_c.remove_tags([\"a\", \"b\"])\n", + "experiment_c.remove_tags([\"tag_a\", \"tag_b\"])\n", "experiment_c.tags" ] }, @@ -322,7 +385,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "id": "a9be6033-775c-4d11-b9d6-55364474eb87", "metadata": {}, "outputs": [ @@ -332,13 +395,13 @@ "[]" ] }, - "execution_count": 11, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "[e.id for e in project.experiments(tags=[\"a\", \"b\"], qtype=\"and\")]" + "[f\"{e.id}: {e.tags}\" for e in project.experiments(tags=[\"tag_a\", \"tag_b\"], qtype=\"and\")]" ] }, { @@ -353,13 +416,13 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "id": "d68b29d3-436b-476b-9cc3-cf8d4c7e14ac", "metadata": {}, "outputs": [], "source": [ - "experiment_d = project.log_experiment(tags=[\"j:k\"])\n", - "experiment_e = project.log_experiment(tags=[\"l:m\", \"l:n\"])" + "experiment_d = project.log_experiment(tags=[\"tag_j:k\"])\n", + "experiment_e = project.log_experiment(tags=[\"tag_j:l\", \"tag_m:n\", \"tag_m:o\"])" ] }, { @@ -373,17 +436,17 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "id": "3b9ff6f5-b90b-4870-8ffc-0ffe74a8c0f8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'j:k'" + "'tag_j:k'" ] }, - "execution_count": 13, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -403,7 +466,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "id": "622925d8-26fd-4961-b7d5-0625544d736c", "metadata": {}, "outputs": [ @@ -413,13 +476,13 @@ "'k'" ] }, - "execution_count": 14, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "experiment_d.tags[\"j\"]" + "experiment_d.tags[\"tag_j\"]" ] }, { @@ -432,23 +495,53 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "id": "07328cb3-9947-4fe9-8359-244530016a7e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['m', 'n']" + "['n', 'o']" ] }, - "execution_count": 15, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "experiment_e.tags[\"tag_m\"]" + ] + }, + { + "cell_type": "markdown", + "id": "a85a3307-eb62-4bd1-b0eb-9367f6edf164", + "metadata": {}, + "source": [ + "Combine key-value tags and wildcards to examine the value of _\"tag_j\"_ on every experiment that has one." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "334e5096-b816-4981-b6c2-6e741a323a3e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['0962a69f-1db5-4f42-884c-60ad179bdb5c: k',\n", + " 'd5c9d775-e546-4b02-93ac-f5cee6d17aea: l']" + ] + }, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "experiment_e.tags[\"l\"]" + "[f\"{e.id}: {e.tags['tag_j']}\" for e in project.experiments(tags=[\"tag_j:*\"])]" ] }, { @@ -464,19 +557,19 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 19, "id": "ab2c5a29-b8ee-45e7-b4b9-6dc22c3abdd7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['a',\n", - " 'child:f080134a-b118-4ac4-b400-25fc097366a8',\n", - " 'child:9443daa7-0bee-4af8-8f31-79a85017bcd5']" + "['child:0962a69f-1db5-4f42-884c-60ad179bdb5c',\n", + " 'child:d5c9d775-e546-4b02-93ac-f5cee6d17aea',\n", + " 'tag_a']" ] }, - "execution_count": 16, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -493,23 +586,25 @@ "id": "7c3367cc-77f3-4c20-b7ab-5029e64b8452", "metadata": {}, "source": [ - "The experiment IDs themselves can be retrieved by indexing into the tags with the \"child\" key." + "Now let's say we've only been given `experiment_a` and we don't know anything about its children or how they were logged.\n", + "\n", + "The child experiment IDs themselves can be retrieved by indexing into the tags with the \"child\" key." ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 20, "id": "b3c04d92-6f52-4c19-ab0e-6907b4f8ab17", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "['f080134a-b118-4ac4-b400-25fc097366a8',\n", - " '9443daa7-0bee-4af8-8f31-79a85017bcd5']" + "['0962a69f-1db5-4f42-884c-60ad179bdb5c',\n", + " 'd5c9d775-e546-4b02-93ac-f5cee6d17aea']" ] }, - "execution_count": 17, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -523,23 +618,23 @@ "id": "5aae42a9-2ebd-4444-98ff-0127f3ed737c", "metadata": {}, "source": [ - "From there, we can use the IDs grab the entire experiments from the original project." + "From there, we can use the IDs grab the complete child experiments from the original project." ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 21, "id": "876a3358-2277-47b4-910f-1a1db89bda9d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[,\n", - " ]" + "[,\n", + " ]" ] }, - "execution_count": 18, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } diff --git a/rubicon_ml/client/utils/tags.py b/rubicon_ml/client/utils/tags.py index 059e10c6..ab64d229 100644 --- a/rubicon_ml/client/utils/tags.py +++ b/rubicon_ml/client/utils/tags.py @@ -1,12 +1,14 @@ -from typing import List +import re +from typing import List, Optional, Union class TagContainer(list): - """List-based container for tags that allows indexing into tags - with colons in them by string, like a dictionary. + """List-based container for tags. + + Allows indexing into tags with colons in them by string, like a dictionary. """ - def __getitem__(self, index_or_key): + def __getitem__(self, index_or_key: Union[int, str]): if isinstance(index_or_key, str): values = [] @@ -29,33 +31,38 @@ def __getitem__(self, index_or_key): def has_tag_requirements(tags: List[str], required_tags: List[str], qtype: str) -> bool: - """Returns True if `tags` meets the requirements based on - the values of `required_tags` and `qtype`. False otherwise. + """Determines if the `required_tags` are in `tags`. + + Returns True if all `required_tags` are in `tags` and `qtype` is "and" or if + any `required_tags` are in `tags` and `qtype` is "or". The tags in + `required_tags` may contain contain wildcard (*) characters. """ - has_tag_requirements = False + qtype_func = any if qtype == "or" else all - tag_intersection = set(required_tags).intersection(set(tags)) - if qtype == "or": - if len(tag_intersection) > 0: - has_tag_requirements = True - if qtype == "and": - if len(tag_intersection) == len(required_tags): - has_tag_requirements = True + if any(["*" in tag for tag in required_tags]): - return has_tag_requirements + def _wildcard_match(pattern, tag): + return re.match(f"^{pattern.replace('*', '.*')}$", tag) is not None + return qtype_func( + [ + any([_wildcard_match(required_tag, tag) for tag in tags]) + for required_tag in required_tags + ] + ) + else: + return qtype_func(tag in tags for tag in required_tags) -def filter_children(children, tags, qtype, name): - """Filters the provided rubicon objects by `tags` using - query type `qtype` and by `name`. - """ - filtered_children = children +def filter_children(children, tags: List[str], qtype: str, name: Optional[str]): + """Return the children in `children` with the given tags or name. + + If both are provided, children are first filtered by tags and then by names. + """ if len(tags) > 0: - filtered_children = [ - c for c in filtered_children if has_tag_requirements(c.tags, tags, qtype) - ] + children = [c for c in children if has_tag_requirements(c.tags, tags, qtype)] + if name is not None: - filtered_children = [c for c in filtered_children if c.name == name] + children = [c for c in children if c.name == name] - return filtered_children + return children diff --git a/tests/unit/client/utils/test_tags.py b/tests/unit/client/utils/test_tags.py index 873058dd..1ee1f450 100644 --- a/tests/unit/client/utils/test_tags.py +++ b/tests/unit/client/utils/test_tags.py @@ -3,32 +3,45 @@ from rubicon_ml.client.utils.tags import TagContainer, has_tag_requirements -def test_or_single_success(): - assert has_tag_requirements(["x", "y", "z"], ["y"], "or") - - -def test_or_multiple_success(): - assert has_tag_requirements(["x", "y", "z"], ["a", "y"], "or") - - -def test_or_single_failure(): - assert not has_tag_requirements(["x", "y", "z"], ["a"], "or") - - -def test_and_single_success(): - assert has_tag_requirements(["x", "y", "z"], ["y"], "and") - - -def test_and_multiple_success(): - assert has_tag_requirements(["x", "y", "z"], ["y", "z"], "and") - - -def test_and_single_failure(): - assert not has_tag_requirements(["x", "y", "z"], ["a"], "and") - - -def test_and_multiple_failure(): - assert not has_tag_requirements(["x", "y", "z"], ["a", "z"], "and") +@pytest.mark.parametrize( + ["required_tags", "qtype"], + [ + (["pre_y_post"], "or"), + (["missing", "pre_y_post"], "or"), + (["pre_y_post"], "and"), + (["pre_y_post", "pre_z"], "and"), + (["x_post", "pre_y_post"], "and"), + (["*y*"], "or"), + (["missing", "*y*"], "or"), + (["*y*"], "and"), + (["*y*", "*z"], "and"), + (["x*", "*y*"], "and"), + (["*_*_*", "*_*"], "and"), + ], +) +def test_has_tags(required_tags, qtype): + tags = ["x_post", "pre_y_post", "pre_z"] + + assert has_tag_requirements(tags, required_tags, qtype) + + +@pytest.mark.parametrize( + ["required_tags", "qtype"], + [ + (["missing"], "or"), + (["missing"], "and"), + (["missing", "pre_z"], "and"), + (["*y"], "or"), + (["y*"], "or"), + (["*y"], "and"), + (["y*"], "and"), + (["missing", "*z"], "and"), + ], +) +def test_not_has_tags(required_tags, qtype): + tags = ["x_post", "pre_y_post", "pre_z"] + + assert not has_tag_requirements(tags, required_tags, qtype) def test_tag_container():