diff --git a/elasticsearch_dsl/query.py b/elasticsearch_dsl/query.py index dfb5518b..868610da 100644 --- a/elasticsearch_dsl/query.py +++ b/elasticsearch_dsl/query.py @@ -612,9 +612,9 @@ class FunctionScore(Query): name = "function_score" _param_defs = { + "functions": {"type": "score_function", "multi": True}, "query": {"type": "query"}, "filter": {"type": "query"}, - "functions": {"type": "score_function", "multi": True}, } def __init__( @@ -623,11 +623,7 @@ def __init__( boost_mode: Union[ Literal["multiply", "replace", "sum", "avg", "max", "min"], "DefaultType" ] = DEFAULT, - functions: Union[ - Sequence["types.FunctionScoreContainer"], - Sequence[Dict[str, Any]], - "DefaultType", - ] = DEFAULT, + functions: Union[Sequence[ScoreFunction], "DefaultType"] = DEFAULT, max_boost: Union[float, "DefaultType"] = DEFAULT, min_score: Union[float, "DefaultType"] = DEFAULT, query: Union[Query, "DefaultType"] = DEFAULT, diff --git a/elasticsearch_dsl/types.py b/elasticsearch_dsl/types.py index ebb18bb2..57521543 100644 --- a/elasticsearch_dsl/types.py +++ b/elasticsearch_dsl/types.py @@ -19,7 +19,7 @@ from elastic_transport.client_utils import DEFAULT, DefaultType -from elasticsearch_dsl import Query, function +from elasticsearch_dsl import Query from elasticsearch_dsl.document_base import InstrumentedField from elasticsearch_dsl.utils import AttrDict @@ -688,74 +688,6 @@ def __init__( super().__init__(kwargs) -class FunctionScoreContainer(AttrDict[Any]): - """ - :arg exp: Function that scores a document with a exponential decay, - depending on the distance of a numeric field value of the document - from an origin. - :arg gauss: Function that scores a document with a normal decay, - depending on the distance of a numeric field value of the document - from an origin. - :arg linear: Function that scores a document with a linear decay, - depending on the distance of a numeric field value of the document - from an origin. - :arg field_value_factor: Function allows you to use a field from a - document to influence the score. It’s similar to using the - script_score function, however, it avoids the overhead of - scripting. - :arg random_score: Generates scores that are uniformly distributed - from 0 up to but not including 1. In case you want scores to be - reproducible, it is possible to provide a `seed` and `field`. - :arg script_score: Enables you to wrap another query and customize the - scoring of it optionally with a computation derived from other - numeric field values in the doc using a script expression. - :arg filter: - :arg weight: - """ - - exp: Union[function.DecayFunction, DefaultType] - gauss: Union[function.DecayFunction, DefaultType] - linear: Union[function.DecayFunction, DefaultType] - field_value_factor: Union[function.FieldValueFactorScore, DefaultType] - random_score: Union[function.RandomScore, DefaultType] - script_score: Union[function.ScriptScore, DefaultType] - filter: Union[Query, DefaultType] - weight: Union[float, DefaultType] - - def __init__( - self, - *, - exp: Union[function.DecayFunction, DefaultType] = DEFAULT, - gauss: Union[function.DecayFunction, DefaultType] = DEFAULT, - linear: Union[function.DecayFunction, DefaultType] = DEFAULT, - field_value_factor: Union[ - function.FieldValueFactorScore, DefaultType - ] = DEFAULT, - random_score: Union[function.RandomScore, DefaultType] = DEFAULT, - script_score: Union[function.ScriptScore, DefaultType] = DEFAULT, - filter: Union[Query, DefaultType] = DEFAULT, - weight: Union[float, DefaultType] = DEFAULT, - **kwargs: Any, - ): - if exp is not DEFAULT: - kwargs["exp"] = exp - if gauss is not DEFAULT: - kwargs["gauss"] = gauss - if linear is not DEFAULT: - kwargs["linear"] = linear - if field_value_factor is not DEFAULT: - kwargs["field_value_factor"] = field_value_factor - if random_score is not DEFAULT: - kwargs["random_score"] = random_score - if script_score is not DEFAULT: - kwargs["script_score"] = script_score - if filter is not DEFAULT: - kwargs["filter"] = filter - if weight is not DEFAULT: - kwargs["weight"] = weight - super().__init__(kwargs) - - class FuzzyQuery(AttrDict[Any]): """ :arg value: (required) Term you wish to find in the provided field. diff --git a/tests/test_query.py b/tests/test_query.py index 35fe957a..5683ca91 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -562,6 +562,33 @@ def test_function_score_to_dict() -> None: assert d == q.to_dict() +def test_function_score_class_based_to_dict() -> None: + q = query.FunctionScore( + query=query.Match(title="python"), + functions=[ + function.RandomScore(), + function.FieldValueFactor( + field="comment_count", + filter=query.Term(tags="python"), + ), + ], + ) + + d = { + "function_score": { + "query": {"match": {"title": "python"}}, + "functions": [ + {"random_score": {}}, + { + "filter": {"term": {"tags": "python"}}, + "field_value_factor": {"field": "comment_count"}, + }, + ], + } + } + assert d == q.to_dict() + + def test_function_score_with_single_function() -> None: d = { "function_score": { diff --git a/utils/generator.py b/utils/generator.py index 653f943b..2bbb34e5 100644 --- a/utils/generator.py +++ b/utils/generator.py @@ -200,6 +200,12 @@ def get_python_type(self, schema_type, for_response=False): ): # QueryContainer maps to the DSL's Query class return "Query", {"type": "query"} + elif ( + type_name["namespace"] == "_types.query_dsl" + and type_name["name"] == "FunctionScoreContainer" + ): + # FunctionScoreContainer maps to the DSL's ScoreFunction class + return "ScoreFunction", {"type": "score_function"} elif ( type_name["namespace"] == "_types.aggregations" and type_name["name"] == "Buckets" diff --git a/utils/templates/query.py.tpl b/utils/templates/query.py.tpl index ca95f5a0..71ef4ff7 100644 --- a/utils/templates/query.py.tpl +++ b/utils/templates/query.py.tpl @@ -174,7 +174,6 @@ class {{ k.name }}({{ parent }}): shortcut property. Until the code generator can support shortcut properties directly that solution is added here #} "filter": {"type": "query"}, - "functions": {"type": "score_function", "multi": True}, {% endif %} } {% endif %} diff --git a/utils/templates/types.py.tpl b/utils/templates/types.py.tpl index 0571e068..7f203ad7 100644 --- a/utils/templates/types.py.tpl +++ b/utils/templates/types.py.tpl @@ -20,7 +20,7 @@ from typing import Any, Dict, Literal, Mapping, Sequence, Union from elastic_transport.client_utils import DEFAULT, DefaultType from elasticsearch_dsl.document_base import InstrumentedField -from elasticsearch_dsl import function, Query +from elasticsearch_dsl import Query from elasticsearch_dsl.utils import AttrDict PipeSeparatedFlags = str