Skip to content

Commit

Permalink
Corrected typing hints for the FunctionScore query (#1960)
Browse files Browse the repository at this point in the history
Fixes #1957
  • Loading branch information
miguelgrinberg authored Jan 7, 2025
1 parent 5a2cc86 commit 5d2bccd
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 77 deletions.
8 changes: 2 additions & 6 deletions elasticsearch_dsl/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,9 +612,9 @@ class FunctionScore(Query):

name = "function_score"
_param_defs = {
"functions": {"type": "score_function", "multi": True},
"query": {"type": "query"},
"filter": {"type": "query"},
"functions": {"type": "score_function", "multi": True},
}

def __init__(
Expand All @@ -623,11 +623,7 @@ def __init__(
boost_mode: Union[
Literal["multiply", "replace", "sum", "avg", "max", "min"], "DefaultType"
] = DEFAULT,
functions: Union[
Sequence["types.FunctionScoreContainer"],
Sequence[Dict[str, Any]],
"DefaultType",
] = DEFAULT,
functions: Union[Sequence[ScoreFunction], "DefaultType"] = DEFAULT,
max_boost: Union[float, "DefaultType"] = DEFAULT,
min_score: Union[float, "DefaultType"] = DEFAULT,
query: Union[Query, "DefaultType"] = DEFAULT,
Expand Down
70 changes: 1 addition & 69 deletions elasticsearch_dsl/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from elastic_transport.client_utils import DEFAULT, DefaultType

from elasticsearch_dsl import Query, function
from elasticsearch_dsl import Query
from elasticsearch_dsl.document_base import InstrumentedField
from elasticsearch_dsl.utils import AttrDict

Expand Down Expand Up @@ -688,74 +688,6 @@ def __init__(
super().__init__(kwargs)


class FunctionScoreContainer(AttrDict[Any]):
"""
:arg exp: Function that scores a document with a exponential decay,
depending on the distance of a numeric field value of the document
from an origin.
:arg gauss: Function that scores a document with a normal decay,
depending on the distance of a numeric field value of the document
from an origin.
:arg linear: Function that scores a document with a linear decay,
depending on the distance of a numeric field value of the document
from an origin.
:arg field_value_factor: Function allows you to use a field from a
document to influence the score. It’s similar to using the
script_score function, however, it avoids the overhead of
scripting.
:arg random_score: Generates scores that are uniformly distributed
from 0 up to but not including 1. In case you want scores to be
reproducible, it is possible to provide a `seed` and `field`.
:arg script_score: Enables you to wrap another query and customize the
scoring of it optionally with a computation derived from other
numeric field values in the doc using a script expression.
:arg filter:
:arg weight:
"""

exp: Union[function.DecayFunction, DefaultType]
gauss: Union[function.DecayFunction, DefaultType]
linear: Union[function.DecayFunction, DefaultType]
field_value_factor: Union[function.FieldValueFactorScore, DefaultType]
random_score: Union[function.RandomScore, DefaultType]
script_score: Union[function.ScriptScore, DefaultType]
filter: Union[Query, DefaultType]
weight: Union[float, DefaultType]

def __init__(
self,
*,
exp: Union[function.DecayFunction, DefaultType] = DEFAULT,
gauss: Union[function.DecayFunction, DefaultType] = DEFAULT,
linear: Union[function.DecayFunction, DefaultType] = DEFAULT,
field_value_factor: Union[
function.FieldValueFactorScore, DefaultType
] = DEFAULT,
random_score: Union[function.RandomScore, DefaultType] = DEFAULT,
script_score: Union[function.ScriptScore, DefaultType] = DEFAULT,
filter: Union[Query, DefaultType] = DEFAULT,
weight: Union[float, DefaultType] = DEFAULT,
**kwargs: Any,
):
if exp is not DEFAULT:
kwargs["exp"] = exp
if gauss is not DEFAULT:
kwargs["gauss"] = gauss
if linear is not DEFAULT:
kwargs["linear"] = linear
if field_value_factor is not DEFAULT:
kwargs["field_value_factor"] = field_value_factor
if random_score is not DEFAULT:
kwargs["random_score"] = random_score
if script_score is not DEFAULT:
kwargs["script_score"] = script_score
if filter is not DEFAULT:
kwargs["filter"] = filter
if weight is not DEFAULT:
kwargs["weight"] = weight
super().__init__(kwargs)


class FuzzyQuery(AttrDict[Any]):
"""
:arg value: (required) Term you wish to find in the provided field.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,33 @@ def test_function_score_to_dict() -> None:
assert d == q.to_dict()


def test_function_score_class_based_to_dict() -> None:
q = query.FunctionScore(
query=query.Match(title="python"),
functions=[
function.RandomScore(),
function.FieldValueFactor(
field="comment_count",
filter=query.Term(tags="python"),
),
],
)

d = {
"function_score": {
"query": {"match": {"title": "python"}},
"functions": [
{"random_score": {}},
{
"filter": {"term": {"tags": "python"}},
"field_value_factor": {"field": "comment_count"},
},
],
}
}
assert d == q.to_dict()


def test_function_score_with_single_function() -> None:
d = {
"function_score": {
Expand Down
6 changes: 6 additions & 0 deletions utils/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,12 @@ def get_python_type(self, schema_type, for_response=False):
):
# QueryContainer maps to the DSL's Query class
return "Query", {"type": "query"}
elif (
type_name["namespace"] == "_types.query_dsl"
and type_name["name"] == "FunctionScoreContainer"
):
# FunctionScoreContainer maps to the DSL's ScoreFunction class
return "ScoreFunction", {"type": "score_function"}
elif (
type_name["namespace"] == "_types.aggregations"
and type_name["name"] == "Buckets"
Expand Down
1 change: 0 additions & 1 deletion utils/templates/query.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ class {{ k.name }}({{ parent }}):
shortcut property. Until the code generator can support shortcut
properties directly that solution is added here #}
"filter": {"type": "query"},
"functions": {"type": "score_function", "multi": True},
{% endif %}
}
{% endif %}
Expand Down
2 changes: 1 addition & 1 deletion utils/templates/types.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ from typing import Any, Dict, Literal, Mapping, Sequence, Union
from elastic_transport.client_utils import DEFAULT, DefaultType

from elasticsearch_dsl.document_base import InstrumentedField
from elasticsearch_dsl import function, Query
from elasticsearch_dsl import Query
from elasticsearch_dsl.utils import AttrDict

PipeSeparatedFlags = str
Expand Down

0 comments on commit 5d2bccd

Please sign in to comment.