Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto-generate aggregation classes #1918

Merged
merged 2 commits into from
Oct 4, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
feedback
miguelgrinberg committed Oct 4, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit 0b685407b08d5f42b0f6185ecd94f4f770ea7cbc
37 changes: 33 additions & 4 deletions elasticsearch_dsl/aggs.py
Original file line number Diff line number Diff line change
@@ -2788,6 +2788,39 @@ def __init__(
super().__init__(path=path, **kwargs)


class RandomSampler(Bucket[_R]):
    """
    A single-bucket aggregation that samples documents at random into the
    aggregated results, trading some accuracy for a significant speed
    improvement.

    :arg probability: (required) Probability that any given document is
        included in the aggregated data. Must be greater than 0 and either
        less than 0.5 or exactly 1; lower probabilities match fewer
        documents.
    :arg seed: Seed used to generate the random sample of documents.
        Providing a seed makes the sampled subset stable across calls.
    :arg shard_seed: Together with ``seed``, guarantees 100% consistent
        sampling across shards that hold exactly the same data.
    """

    name = "random_sampler"

    def __init__(
        self,
        *,
        probability: Union[float, "DefaultType"] = DEFAULT,
        seed: Union[int, "DefaultType"] = DEFAULT,
        shard_seed: Union[int, "DefaultType"] = DEFAULT,
        **kwargs: Any,
    ):
        # The named parameters can never collide with keys already in
        # **kwargs (Python binds them first), so folding them in and
        # delegating once is equivalent to passing them individually.
        kwargs["probability"] = probability
        kwargs["seed"] = seed
        kwargs["shard_seed"] = shard_seed
        super().__init__(**kwargs)


class Sampler(Bucket[_R]):
"""
A filtering aggregation used to limit any sub aggregations' processing
@@ -3696,7 +3729,3 @@ def __init__(

def result(self, search: "SearchBase[_R]", data: Any) -> AttrDict[Any]:
    # Wrap the raw aggregation response in a FieldBucketData container so
    # callers get attribute-style access to the buckets.
    # NOTE(review): assumes `data` is the raw agg payload for this
    # aggregation as returned by the search response — confirm with callers.
    return FieldBucketData(self, search, data)


class RandomSampler(Bucket[_R]):
    # Minimal hand-written stub: only registers the aggregation name, with
    # no typed constructor or docstring. NOTE(review): this diff removes it
    # in favor of a generated class of the same name — confirm no code
    # depends on this untyped variant.
    name = "random_sampler"
7 changes: 7 additions & 0 deletions tests/test_aggs.py
Original file line number Diff line number Diff line change
@@ -220,6 +220,7 @@ def test_filters_correctly_identifies_the_hash() -> None:


def test_bucket_sort_agg() -> None:
# test the dictionary (type ignored) and fully typed alternatives
bucket_sort_agg = aggs.BucketSort(sort=[{"total_sales": {"order": "desc"}}], size=3) # type: ignore
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For several examples in this file I kept the original dict based solution with ignored typing errors, and right below I've added a correctly typed version, so that we make sure we are backwards compatible.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This deserves to be an in-file comment IMO, because it's tempting to remove in a refactoring.

assert bucket_sort_agg.to_dict() == {
"bucket_sort": {"sort": [{"total_sales": {"order": "desc"}}], "size": 3}
@@ -251,6 +252,7 @@ def test_bucket_sort_agg() -> None:


def test_bucket_sort_agg_only_trnunc() -> None:
# test the dictionary (type ignored) and fully typed alternatives
bucket_sort_agg = aggs.BucketSort(**{"from": 1, "size": 1, "_expand__to_dot": False}) # type: ignore
assert bucket_sort_agg.to_dict() == {"bucket_sort": {"from": 1, "size": 1}}
bucket_sort_agg = aggs.BucketSort(from_=1, size=1, _expand__to_dot=False)
@@ -265,20 +267,23 @@ def test_bucket_sort_agg_only_trnunc() -> None:


def test_geohash_grid_aggregation() -> None:
    # Exercise both the legacy dict-based form (typing suppressed) and the
    # fully typed keyword form, to guard backwards compatibility.
    expected = {"geohash_grid": {"field": "centroid", "precision": 3}}
    agg = aggs.GeohashGrid(**{"field": "centroid", "precision": 3})  # type: ignore
    assert agg.to_dict() == expected
    agg = aggs.GeohashGrid(field="centroid", precision=3)
    assert agg.to_dict() == expected


def test_geohex_grid_aggregation() -> None:
    # Exercise both the legacy dict-based form (typing suppressed) and the
    # fully typed keyword form, to guard backwards compatibility.
    expected = {"geohex_grid": {"field": "centroid", "precision": 3}}
    agg = aggs.GeohexGrid(**{"field": "centroid", "precision": 3})  # type: ignore
    assert agg.to_dict() == expected
    agg = aggs.GeohexGrid(field="centroid", precision=3)
    assert agg.to_dict() == expected


def test_geotile_grid_aggregation() -> None:
# test the dictionary (type ignored) and fully typed alternatives
a = aggs.GeotileGrid(**{"field": "centroid", "precision": 3}) # type: ignore
assert {"geotile_grid": {"field": "centroid", "precision": 3}} == a.to_dict()
a = aggs.GeotileGrid(field="centroid", precision=3)
@@ -318,6 +323,7 @@ def test_variable_width_histogram_aggregation() -> None:


def test_ip_prefix_aggregation() -> None:
# test the dictionary (type ignored) and fully typed alternatives
a = aggs.IPPrefix(**{"field": "ipv4", "prefix_length": 24}) # type: ignore
assert {"ip_prefix": {"field": "ipv4", "prefix_length": 24}} == a.to_dict()
a = aggs.IPPrefix(field="ipv4", prefix_length=24)
@@ -501,6 +507,7 @@ def test_adjancecy_matrix_aggregation() -> None:


def test_top_metrics_aggregation() -> None:
# test the dictionary (type ignored) and fully typed alternatives
a = aggs.TopMetrics(metrics={"field": "m"}, sort={"s": "desc"}) # type: ignore
assert {
"top_metrics": {"metrics": {"field": "m"}, "sort": {"s": "desc"}}
3 changes: 0 additions & 3 deletions utils/templates/aggs.py.tpl
Original file line number Diff line number Diff line change
@@ -318,6 +318,3 @@ class {{ k.name }}({{ k.parent if k.parent else parent }}[_R]):

{% endif %}
{% endfor %}
{# the following aggregation is in technical preview and does not exist in the specification #}
class RandomSampler(Bucket[_R]):
name = "random_sampler"