diff --git a/presto-docs/src/main/sphinx/admin/properties.rst b/presto-docs/src/main/sphinx/admin/properties.rst index 0b4a7ccc8393..dd6c2efab9c5 100644 --- a/presto-docs/src/main/sphinx/admin/properties.rst +++ b/presto-docs/src/main/sphinx/admin/properties.rst @@ -884,6 +884,17 @@ This can also be specified on a per-query basis using the ``confidence_based_bro Enable treating ``LOW`` confidence, zero estimations as ``UNKNOWN`` during joins. This can also be specified on a per-query basis using the ``treat-low-confidence-zero-estimation-as-unknown`` session property. +``optimizer.scalar-function-stats-propagation-enabled`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +Enable stats propagation for scalar functions using annotations. Annotations describe how the statistics of a scalar +function's result are derived, for example from the statistics of its arguments. When set to ``true``, the optimizer uses these annotations when estimating statistics for scalar function calls. +This can also be specified on a per-query basis using the ``scalar_function_stats_propagation_enabled`` session property.
+ ``optimizer.retry-query-with-history-based-optimization`` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java index d039bd559574..61f7174a9aaa 100644 --- a/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -371,6 +371,7 @@ public final class SystemSessionProperties public static final String OPTIMIZER_USE_HISTOGRAMS = "optimizer_use_histograms"; public static final String WARN_ON_COMMON_NAN_PATTERNS = "warn_on_common_nan_patterns"; public static final String INLINE_PROJECTIONS_ON_VALUES = "inline_projections_on_values"; + public static final String SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED = "scalar_function_stats_propagation_enabled"; private final List> sessionProperties; @@ -2077,6 +2078,10 @@ public SystemSessionProperties( booleanProperty(INLINE_PROJECTIONS_ON_VALUES, "Whether to evaluate project node on values node", featuresConfig.getInlineProjectionsOnValues(), + false), + booleanProperty(SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED, + "whether or not to respect stats propagation annotation for scalar functions (or UDF)", + featuresConfig.isScalarFunctionStatsPropagationEnabled(), false)); } @@ -3414,4 +3419,9 @@ public static boolean isInlineProjectionsOnValues(Session session) { return session.getSystemProperty(INLINE_PROJECTIONS_ON_VALUES, Boolean.class); } + + public static boolean shouldEnableScalarFunctionStatsPropagation(Session session) + { + return session.getSystemProperty(SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED, Boolean.class); + } } diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsAnnotationProcessor.java b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsAnnotationProcessor.java new file mode 100644 index 
000000000000..56fca4c9bf8d --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsAnnotationProcessor.java @@ -0,0 +1,243 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.cost; + +import com.facebook.presto.common.type.FixedWidthType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.function.ScalarPropagateSourceStats; +import com.facebook.presto.spi.function.ScalarStatsHeader; +import com.facebook.presto.spi.function.StatsPropagationBehavior; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; + +import java.util.List; +import java.util.Map; + +import static com.facebook.presto.spi.function.StatsPropagationBehavior.NON_NULL_ROW_COUNT; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.ROW_COUNT; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.SUM_ARGUMENTS; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.UNKNOWN; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_SOURCE_STATS; +import static com.facebook.presto.util.MoreMath.max; +import static com.facebook.presto.util.MoreMath.min; +import static 
com.facebook.presto.util.MoreMath.minExcludingNaNs; +import static com.facebook.presto.util.MoreMath.nearlyEqual; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.Double.NaN; +import static java.lang.Double.isFinite; +import static java.lang.Double.isNaN; + +public final class ScalarStatsAnnotationProcessor +{ + private ScalarStatsAnnotationProcessor() + { + } + + public static VariableStatsEstimate process( + double outputRowCount, + CallExpression callExpression, + List sourceStats, + ScalarStatsHeader scalarStatsHeader) + { + double nullFraction = scalarStatsHeader.getNullFraction(); + double distinctValuesCount = NaN; + double averageRowSize = NaN; + double maxValue = scalarStatsHeader.getMax(); + double minValue = scalarStatsHeader.getMin(); + for (Map.Entry paramIndexToStatsMap : scalarStatsHeader.getArgumentStats().entrySet()) { + ScalarPropagateSourceStats scalarPropagateSourceStats = paramIndexToStatsMap.getValue(); + boolean propagateAllStats = scalarPropagateSourceStats.propagateAllStats(); + nullFraction = min(firstFiniteValue(nullFraction, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression, + sourceStats.stream().map(VariableStatsEstimate::getNullsFraction).collect(toImmutableList()), + paramIndexToStatsMap.getKey(), + applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.nullFraction()))), 1.0); + distinctValuesCount = firstFiniteValue(distinctValuesCount, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression, + sourceStats.stream().map(VariableStatsEstimate::getDistinctValuesCount).collect(toImmutableList()), + paramIndexToStatsMap.getKey(), + applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.distinctValuesCount()))); + StatsPropagationBehavior averageRowSizeStatsBehaviour = applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.avgRowSize()); 
+ averageRowSize = minExcludingNaNs(firstFiniteValue(averageRowSize, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression, + sourceStats.stream().map(VariableStatsEstimate::getAverageRowSize).collect(toImmutableList()), + paramIndexToStatsMap.getKey(), + averageRowSizeStatsBehaviour)), returnNaNIfTypeWidthUnknown(getReturnTypeWidth(callExpression, averageRowSizeStatsBehaviour))); + maxValue = firstFiniteValue(maxValue, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression, + sourceStats.stream().map(VariableStatsEstimate::getHighValue).collect(toImmutableList()), + paramIndexToStatsMap.getKey(), + applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.maxValue()))); + minValue = firstFiniteValue(minValue, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression, + sourceStats.stream().map(VariableStatsEstimate::getLowValue).collect(toImmutableList()), + paramIndexToStatsMap.getKey(), + applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.minValue()))); + } + if (isNaN(maxValue) || isNaN(minValue)) { + minValue = NaN; + maxValue = NaN; + } + return VariableStatsEstimate.builder() + .setLowValue(minValue) + .setHighValue(maxValue) + .setNullsFraction(nullFraction) + .setAverageRowSize(firstFiniteValue(scalarStatsHeader.getAvgRowSize(), averageRowSize, returnNaNIfTypeWidthUnknown(getReturnTypeWidth(callExpression, UNKNOWN)))) + .setDistinctValuesCount(processDistinctValuesCount(outputRowCount, nullFraction, scalarStatsHeader.getDistinctValuesCount(), distinctValuesCount)).build(); + } + + private static double processDistinctValuesCount(double outputRowCount, double nullFraction, double distinctValuesCountFromConstant, double distinctValuesCount) + { + if (isFinite(distinctValuesCountFromConstant)) { + if (nearlyEqual(distinctValuesCountFromConstant, NON_NULL_ROW_COUNT.getValue(), 0.1)) { + distinctValuesCountFromConstant = outputRowCount * (1 - firstFiniteValue(nullFraction, 
0.0)); + } + else if (nearlyEqual(distinctValuesCountFromConstant, ROW_COUNT.getValue(), 0.1)) { // the ROW_COUNT sentinel comes from the annotation constant, not from the propagated value + distinctValuesCountFromConstant = outputRowCount; + } + } + double distinctValuesCountFinal = firstFiniteValue(distinctValuesCountFromConstant, distinctValuesCount); + if (distinctValuesCountFinal > outputRowCount) { + distinctValuesCountFinal = NaN; + } + return distinctValuesCountFinal; + } + + private static double processSingleArgumentStatistic( + double outputRowCount, + double nullFraction, + CallExpression callExpression, + List sourceStats, + int sourceStatsArgumentIndex, + StatsPropagationBehavior operation) + { + // sourceStatsArgumentIndex is index of the argument on which + // ScalarPropagateSourceStats annotation was applied. + double statValue = NaN; + if (operation.isMultiArgumentStat()) { + for (int i = 0; i < sourceStats.size(); i++) { + if (i == 0 && operation.isSourceStatsDependentStats() && isFinite(sourceStats.get(i))) { + statValue = sourceStats.get(i); + } + else { + switch (operation) { + case MAX_TYPE_WIDTH_VARCHAR: + statValue = returnNaNIfTypeWidthUnknown(getTypeWidthVarchar(callExpression.getArguments().get(i).getType())); + break; + case USE_MIN_ARGUMENT: + statValue = min(statValue, sourceStats.get(i)); + break; + case USE_MAX_ARGUMENT: + statValue = max(statValue, sourceStats.get(i)); + break; + case SUM_ARGUMENTS: + statValue = statValue + sourceStats.get(i); + break; + case SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT: + statValue = min(statValue + sourceStats.get(i), outputRowCount); + break; + } + } + } + } + else { + switch (operation) { + case USE_SOURCE_STATS: + statValue = sourceStats.get(sourceStatsArgumentIndex); + break; + case ROW_COUNT: + statValue = outputRowCount; + break; + case NON_NULL_ROW_COUNT: + statValue = outputRowCount * (1 - firstFiniteValue(nullFraction, 0.0)); + break; + case USE_TYPE_WIDTH_VARCHAR: + statValue = 
returnNaNIfTypeWidthUnknown(getTypeWidthVarchar(callExpression.getArguments().get(sourceStatsArgumentIndex).getType())); + break; + } + } + return statValue; + } + + private static int getTypeWidthVarchar(Type argumentType) + { + if (argumentType instanceof VarcharType) { + if (!((VarcharType) argumentType).isUnbounded()) { + return ((VarcharType) argumentType).getLengthSafe(); + } + } + return -VarcharType.MAX_LENGTH; + } + + private static double returnNaNIfTypeWidthUnknown(int typeWidthValue) + { + if (typeWidthValue <= 0) { + return NaN; + } + return typeWidthValue; + } + + private static int getReturnTypeWidth(CallExpression callExpression, StatsPropagationBehavior operation) + { + if (callExpression.getType() instanceof FixedWidthType) { + return ((FixedWidthType) callExpression.getType()).getFixedSize(); + } + if (callExpression.getType() instanceof VarcharType) { + VarcharType returnType = (VarcharType) callExpression.getType(); + if (!returnType.isUnbounded()) { + return returnType.getLengthSafe(); + } + if (operation == SUM_ARGUMENTS || operation == SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT) { + // since return type is an unbounded varchar and operation is SUM_ARGUMENTS, + // calculating the type width by doing a SUM of each argument's varchar type bounds - if available. + int sum = 0; + for (RowExpression r : callExpression.getArguments()) { + int typeWidth; + if (r instanceof CallExpression) { // argument is another function call + typeWidth = getReturnTypeWidth((CallExpression) r, UNKNOWN); + } + else { + typeWidth = getTypeWidthVarchar(r.getType()); + } + if (typeWidth < 0) { + return -VarcharType.MAX_LENGTH; + } + sum = (int) Math.min((long) sum + typeWidth, VarcharType.MAX_LENGTH); // saturate: summed varchar bounds can exceed the int range + } + return sum; + } + } + return -VarcharType.MAX_LENGTH; + } + + // Return first 'finite' value from values, else return values[0] + private static double firstFiniteValue(double... 
values) + { + checkArgument(values.length > 1); + for (double v : values) { + if (isFinite(v)) { + return v; + } + } + return values[0]; + } + + private static StatsPropagationBehavior applyPropagateAllStats( + boolean propagateAllStats, StatsPropagationBehavior operation) + { + if (operation == UNKNOWN && propagateAllStats) { + return USE_SOURCE_STATS; + } + return operation; + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java index 044785ca0a44..ccd692ff8fca 100644 --- a/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java +++ b/presto-main/src/main/java/com/facebook/presto/cost/ScalarStatsCalculator.java @@ -13,14 +13,19 @@ */ package com.facebook.presto.cost; +import com.facebook.presto.FullConnectorSession; import com.facebook.presto.Session; +import com.facebook.presto.SystemSessionProperties; import com.facebook.presto.common.function.OperatorType; import com.facebook.presto.common.type.Type; import com.facebook.presto.common.type.TypeSignature; +import com.facebook.presto.metadata.BuiltInFunctionHandle; import com.facebook.presto.metadata.Metadata; import com.facebook.presto.spi.ConnectorSession; import com.facebook.presto.spi.WarningCollector; import com.facebook.presto.spi.function.FunctionMetadata; +import com.facebook.presto.spi.function.ScalarStatsHeader; +import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.relation.CallExpression; import com.facebook.presto.spi.relation.ConstantExpression; import com.facebook.presto.spi.relation.InputReferenceExpression; @@ -53,8 +58,11 @@ import javax.inject.Inject; +import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.OptionalDouble; +import java.util.stream.IntStream; import static com.facebook.presto.common.function.OperatorType.DIVIDE; import static 
com.facebook.presto.common.function.OperatorType.MODULUS; @@ -66,7 +74,9 @@ import static com.facebook.presto.sql.relational.Expressions.isNull; import static com.facebook.presto.util.MoreMath.max; import static com.facebook.presto.util.MoreMath.min; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.lang.Double.NaN; import static java.lang.Double.isFinite; import static java.lang.Double.isNaN; @@ -107,11 +117,15 @@ private class RowExpressionStatsVisitor private final PlanNodeStatsEstimate input; private final ConnectorSession session; private final FunctionResolution resolution = new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()); + private final boolean isStatsPropagationEnabled; public RowExpressionStatsVisitor(PlanNodeStatsEstimate input, ConnectorSession session) { this.input = requireNonNull(input, "input is null"); this.session = requireNonNull(session, "session is null"); + // casting session to FullConnectorSession is not ideal. 
+ this.isStatsPropagationEnabled = session instanceof FullConnectorSession && // propagation stays disabled for other ConnectorSession implementations instead of failing with ClassCastException + SystemSessionProperties.shouldEnableScalarFunctionStatsPropagation(((FullConnectorSession) session).getSession()); } @Override @@ -136,11 +150,12 @@ public VariableStatsEstimate visitCall(CallExpression call, Void context) return value.accept(this, context); } - // value is not a constant but we can still propagate estimation through cast + // value is not a constant, but we can still propagate estimation through cast if (resolution.isCastFunction(call.getFunctionHandle())) { return computeCastStatistics(call, context); } - return VariableStatsEstimate.unknown(); + + return computeStatsViaAnnotations(call, context, functionMetadata); } @Override @@ -199,10 +214,41 @@ public VariableStatsEstimate visitSpecialForm(SpecialFormExpression specialForm, return VariableStatsEstimate.unknown(); } + private VariableStatsEstimate computeStatsViaAnnotations(CallExpression call, Void context, FunctionMetadata functionMetadata) + { + if (isStatsPropagationEnabled) { + if (functionMetadata.hasStatsHeader() && call.getFunctionHandle() instanceof BuiltInFunctionHandle) { + Signature signature = ((BuiltInFunctionHandle) call.getFunctionHandle()).getSignature().canonicalization(); + Optional statsHeader = functionMetadata.getScalarStatsHeader(signature); + if (statsHeader.isPresent()) { + return computeCallStatistics(call, context, statsHeader.get()); + } + } + } + return VariableStatsEstimate.unknown(); + } + + private VariableStatsEstimate getSourceStats(CallExpression call, Void context, int argumentIndex) + { + checkArgument(argumentIndex < call.getArguments().size(), // message template is rendered only when the check fails + "function argument index: %s >= %s (call argument size) for %s", argumentIndex, call.getArguments().size(), call); + return call.getArguments().get(argumentIndex).accept(this, context); + } + + private VariableStatsEstimate computeCallStatistics(CallExpression call, Void context, ScalarStatsHeader scalarStatsHeader) + { + requireNonNull(call, "call is null"); + List 
sourceStatsList = + IntStream.range(0, call.getArguments().size()).mapToObj(argumentIndex -> getSourceStats(call, context, argumentIndex)).collect(toImmutableList()); + VariableStatsEstimate result = + ScalarStatsAnnotationProcessor.process(input.getOutputRowCount(), call, sourceStatsList, scalarStatsHeader); + return result; + } + private VariableStatsEstimate computeCastStatistics(CallExpression call, Void context) { requireNonNull(call, "call is null"); - VariableStatsEstimate sourceStats = call.getArguments().get(0).accept(this, context); + VariableStatsEstimate sourceStats = getSourceStats(call, context, 0); // todo - make this general postprocessing rule. double distinctValuesCount = sourceStats.getDistinctValuesCount(); @@ -236,7 +282,7 @@ private VariableStatsEstimate computeCastStatistics(CallExpression call, Void co private VariableStatsEstimate computeNegationStatistics(CallExpression call, Void context) { requireNonNull(call, "call is null"); - VariableStatsEstimate stats = call.getArguments().get(0).accept(this, context); + VariableStatsEstimate stats = getSourceStats(call, context, 0); if (resolution.isNegateFunction(call.getFunctionHandle())) { return VariableStatsEstimate.buildFrom(stats) .setLowValue(-stats.getHighValue()) @@ -249,14 +295,13 @@ private VariableStatsEstimate computeNegationStatistics(CallExpression call, Voi private VariableStatsEstimate computeArithmeticBinaryStatistics(CallExpression call, Void context) { requireNonNull(call, "call is null"); - VariableStatsEstimate left = call.getArguments().get(0).accept(this, context); - VariableStatsEstimate right = call.getArguments().get(1).accept(this, context); + VariableStatsEstimate left = getSourceStats(call, context, 0); + VariableStatsEstimate right = getSourceStats(call, context, 1); VariableStatsEstimate.Builder result = VariableStatsEstimate.builder() .setAverageRowSize(Math.max(left.getAverageRowSize(), right.getAverageRowSize())) .setNullsFraction(left.getNullsFraction() + 
right.getNullsFraction() - left.getNullsFraction() * right.getNullsFraction()) .setDistinctValuesCount(min(left.getDistinctValuesCount() * right.getDistinctValuesCount(), input.getOutputRowCount())); - FunctionMetadata functionMetadata = metadata.getFunctionAndTypeManager().getFunctionMetadata(call.getFunctionHandle()); checkState(functionMetadata.getOperatorType().isPresent()); OperatorType operatorType = functionMetadata.getOperatorType().get(); diff --git a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java index ae76fd532db6..8cbadf028402 100644 --- a/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java +++ b/presto-main/src/main/java/com/facebook/presto/metadata/BuiltInTypeAndFunctionNamespaceManager.java @@ -184,6 +184,7 @@ import com.facebook.presto.operator.scalar.MathFunctions; import com.facebook.presto.operator.scalar.MathFunctions.LegacyLogFunction; import com.facebook.presto.operator.scalar.MultimapFromEntriesFunction; +import com.facebook.presto.operator.scalar.ParametricScalar; import com.facebook.presto.operator.scalar.QuantileDigestFunctions; import com.facebook.presto.operator.scalar.Re2JRegexpFunctions; import com.facebook.presto.operator.scalar.Re2JRegexpReplaceLambdaFunction; @@ -1181,6 +1182,19 @@ else if (function instanceof SqlInvokedFunction) { sqlFunction.getVersion(), sqlFunction.getComplexTypeFunctionDescriptor()); } + else if (function instanceof ParametricScalar) { + ParametricScalar sqlFunction = (ParametricScalar) function; + return new FunctionMetadata( + signature.getName(), + signature.getArgumentTypes(), + signature.getReturnType(), + signature.getKind(), + JAVA, + function.isDeterministic(), + function.isCalledOnNullInput(), + sqlFunction.getComplexTypeFunctionDescriptor(), + 
sqlFunction.getScalarHeader().getSignatureToScalarStatsHeadersMap()); + } else { return new FunctionMetadata( signature.getName(), diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java index 3950c075e9fe..7a09287c341a 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ParametricScalar.java @@ -30,47 +30,62 @@ import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_MISSING; import static com.facebook.presto.util.Failures.checkCondition; +import static com.google.common.base.MoreObjects.toStringHelper; import static java.lang.String.format; import static java.util.Objects.requireNonNull; public class ParametricScalar extends SqlScalarFunction { - private final ScalarHeader details; + private final ScalarHeader scalarHeader; private final ParametricImplementationsGroup implementations; public ParametricScalar( Signature signature, - ScalarHeader details, + ScalarHeader scalarHeader, ParametricImplementationsGroup implementations) { super(signature); - this.details = requireNonNull(details); + this.scalarHeader = requireNonNull(scalarHeader); this.implementations = requireNonNull(implementations); } @Override public SqlFunctionVisibility getVisibility() { - return details.getVisibility(); + return scalarHeader.getVisibility(); + } + + public ScalarHeader getScalarHeader() + { + return scalarHeader; } @Override public boolean isDeterministic() { - return details.isDeterministic(); + return scalarHeader.isDeterministic(); } @Override public boolean isCalledOnNullInput() { - return details.isCalledOnNullInput(); + return scalarHeader.isCalledOnNullInput(); + } + + @Override + public String toString() + { + return 
toStringHelper(this) + .add("signature", getSignature()) + .add("implementation", implementations) + .add("scalarHeader", scalarHeader).toString(); } @Override public String getDescription() { - return details.getDescription().isPresent() ? details.getDescription().get() : ""; + return scalarHeader.getDescription().isPresent() ? scalarHeader.getDescription().get() : ""; } @VisibleForTesting diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java index b6b5e33fa6c0..0959a6684a01 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/ScalarHeader.java @@ -13,16 +13,24 @@ */ package com.facebook.presto.operator.scalar; +import com.facebook.presto.spi.function.ScalarStatsHeader; +import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlFunctionVisibility; +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableMap; +import java.util.Map; import java.util.Optional; +import static com.google.common.base.MoreObjects.toStringHelper; + public class ScalarHeader { private final Optional description; private final SqlFunctionVisibility visibility; private final boolean deterministic; private final boolean calledOnNullInput; + private final Map signatureToScalarStatsHeaders; public ScalarHeader(Optional description, SqlFunctionVisibility visibility, boolean deterministic, boolean calledOnNullInput) { @@ -30,6 +38,29 @@ public ScalarHeader(Optional description, SqlFunctionVisibility visibili this.visibility = visibility; this.deterministic = deterministic; this.calledOnNullInput = calledOnNullInput; + this.signatureToScalarStatsHeaders = ImmutableMap.of(); + } + + public ScalarHeader(Optional description, SqlFunctionVisibility visibility, boolean deterministic, boolean calledOnNullInput, + Map 
signatureToScalarStatsHeaders) + { + this.description = description; + this.visibility = visibility; + this.deterministic = deterministic; + this.calledOnNullInput = calledOnNullInput; + this.signatureToScalarStatsHeaders = ImmutableMap.copyOf(signatureToScalarStatsHeaders); // defensive copy; rejects a null map at construction time + } + + @Override + public String toString() + { + return toStringHelper(this) + .add("description", this.description) + .add("visibility", this.visibility) + .add("deterministic", this.deterministic) + .add("calledOnNullInput", this.calledOnNullInput) + .add("signatureToScalarStatsHeadersMap", Joiner.on(" , ").withKeyValueSeparator(" -> ").join(this.signatureToScalarStatsHeaders)) + .toString(); } public Optional getDescription() @@ -51,4 +82,9 @@ public boolean isCalledOnNullInput() { return calledOnNullInput; } + + public Map getSignatureToScalarStatsHeadersMap() + { + return signatureToScalarStatsHeaders; + } } diff --git a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java index eb2f5a3ae346..1b8ef7892c99 100644 --- a/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java +++ b/presto-main/src/main/java/com/facebook/presto/operator/scalar/annotations/ScalarFromAnnotationsParser.java @@ -17,23 +17,32 @@ import com.facebook.presto.operator.ParametricImplementationsGroup; import com.facebook.presto.operator.annotations.FunctionsParserHelper; import com.facebook.presto.operator.scalar.ParametricScalar; +import com.facebook.presto.operator.scalar.ScalarHeader; import com.facebook.presto.operator.scalar.annotations.ParametricScalarImplementation.SpecializedSignature; import com.facebook.presto.spi.function.CodegenScalarFunction; import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionConstantStats; import 
com.facebook.presto.spi.function.ScalarOperator; +import com.facebook.presto.spi.function.ScalarPropagateSourceStats; +import com.facebook.presto.spi.function.ScalarStatsHeader; import com.facebook.presto.spi.function.Signature; import com.facebook.presto.spi.function.SqlInvokedScalarFunction; import com.facebook.presto.spi.function.SqlType; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import java.lang.reflect.Constructor; import java.lang.reflect.Method; +import java.lang.reflect.Parameter; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.IntStream; import static com.facebook.presto.operator.scalar.annotations.OperatorValidator.validateOperator; import static com.facebook.presto.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; @@ -99,11 +108,38 @@ private static List findScalarsInFunctionSetClass(Class< return builder.build(); } + private static Optional getScalarStatsHeader(Method annotated) + { + Optional scalarStatsHeader; + ScalarFunctionConstantStats constantStatsAnnotation = + annotated.getAnnotation(ScalarFunctionConstantStats.class); + List params = + Arrays.stream(annotated.getParameters()) + .filter(param -> param.getAnnotation(SqlType.class) != null) + .collect(Collectors.toList()); + // Map of (function argument position index) -> (ScalarPropagateSourceStats annotation) + ImmutableMap.Builder argumentIndexToStatsAnnotationMapBuilder = new ImmutableMap.Builder<>(); + + IntStream.range(0, params.size()) + .filter(paramIndex -> params.get(paramIndex).getAnnotation(ScalarPropagateSourceStats.class) != null) + .forEachOrdered(paramIndex -> argumentIndexToStatsAnnotationMapBuilder.put(paramIndex, + params.get(paramIndex).getAnnotation(ScalarPropagateSourceStats.class))); + + Map 
argumentIndexToStatsAnnotation = argumentIndexToStatsAnnotationMapBuilder.build(); + scalarStatsHeader = Optional.ofNullable(constantStatsAnnotation) + .map(statsAnnotation -> new ScalarStatsHeader(statsAnnotation, argumentIndexToStatsAnnotation)); + if (!argumentIndexToStatsAnnotation.isEmpty() && !scalarStatsHeader.isPresent()) { + scalarStatsHeader = Optional.of(new ScalarStatsHeader(argumentIndexToStatsAnnotation)); + } + return scalarStatsHeader; + } + private static SqlScalarFunction parseParametricScalar(ScalarHeaderAndMethods scalar, Optional> constructor) { ScalarImplementationHeader header = scalar.getHeader(); Map signatures = new HashMap<>(); + ImmutableMap.Builder signatureToStatsHeaderMapBuilder = new ImmutableMap.Builder<>(); for (Method method : scalar.getMethods()) { ParametricScalarImplementation implementation = ParametricScalarImplementation.Parser.parseImplementation(header, method, constructor); if (!signatures.containsKey(implementation.getSpecializedSignature())) { @@ -119,6 +155,8 @@ private static SqlScalarFunction parseParametricScalar(ScalarHeaderAndMethods sc ParametricScalarImplementation.Builder builder = signatures.get(implementation.getSpecializedSignature()); builder.addChoices(implementation); } + Optional scalarStatsHeader = getScalarStatsHeader(method); + scalarStatsHeader.ifPresent(statsHeader -> signatureToStatsHeaderMapBuilder.put(implementation.getSignature().canonicalization(), statsHeader)); } ParametricImplementationsGroup.Builder implementationsBuilder = ParametricImplementationsGroup.builder(); @@ -131,7 +169,11 @@ private static SqlScalarFunction parseParametricScalar(ScalarHeaderAndMethods sc header.getOperatorType().ifPresent(operatorType -> validateOperator(operatorType, scalarSignature.getReturnType(), scalarSignature.getArgumentTypes())); - return new ParametricScalar(scalarSignature, header.getHeader(), implementations); + ScalarHeader scalarHeader = header.getHeader(); + ScalarHeader headerWithStats = + new 
ScalarHeader(scalarHeader.getDescription(), scalarHeader.getVisibility(), scalarHeader.isDeterministic(), + scalarHeader.isCalledOnNullInput(), signatureToStatsHeaderMapBuilder.build()); + return new ParametricScalar(scalarSignature, headerWithStats, implementations); } private static class ScalarHeaderAndMethods diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index 88d1b3248cd7..8bd78582c184 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -293,7 +293,7 @@ public class FeaturesConfig private boolean removeCrossJoinWithSingleConstantRow = true; private CreateView.Security defaultViewSecurityMode = DEFINER; private boolean useHistograms; - + private boolean isScalarFunctionStatsPropagationEnabled; private boolean isInlineProjectionsOnValuesEnabled; private boolean eagerPlanValidationEnabled; @@ -2947,6 +2947,19 @@ public FeaturesConfig setRemoveCrossJoinWithSingleConstantRow(boolean removeCros return this; } + public boolean isScalarFunctionStatsPropagationEnabled() + { + return isScalarFunctionStatsPropagationEnabled; + } + + @Config("optimizer.scalar-function-stats-propagation-enabled") + @ConfigDescription("Respect scalar function statistics annotation for cost-based calculations in the optimizer") + public FeaturesConfig setScalarFunctionStatsPropagationEnabled(boolean scalarFunctionStatsPropagation) + { + this.isScalarFunctionStatsPropagationEnabled = scalarFunctionStatsPropagation; + return this; + } + public boolean isUseHistograms() { return useHistograms; diff --git a/presto-main/src/main/java/com/facebook/presto/util/MoreMath.java b/presto-main/src/main/java/com/facebook/presto/util/MoreMath.java index fd5beebb602b..ccdefea2c1bc 100644 --- a/presto-main/src/main/java/com/facebook/presto/util/MoreMath.java 
+++ b/presto-main/src/main/java/com/facebook/presto/util/MoreMath.java @@ -15,6 +15,8 @@ import java.util.stream.DoubleStream; +import static java.lang.Double.NaN; +import static java.lang.Double.isFinite; import static java.lang.Double.isNaN; public final class MoreMath @@ -79,6 +81,23 @@ public static double max(double... values) .getAsDouble(); } + /** + * Returns the minimum of the finite arguments; NaN and infinite arguments are both skipped. Returns NaN if there are no arguments or none of them is finite. + */ + public static double minExcludingNaNs(double... values) + { + double min = NaN; + for (double v : values) { + if (isFinite(v)) { + if (isNaN(min)) { + min = v; + } + min = Math.min(min, v); + } + } + return min; + } + public static double rangeMin(double left, double right) { if (isNaN(left)) { diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsAnnotationProcessor.java b/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsAnnotationProcessor.java new file mode 100644 index 000000000000..a52365259c2f --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsAnnotationProcessor.java @@ -0,0 +1,311 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.presto.cost; + +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.TypeSignature; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.metadata.BuiltInFunctionHandle; +import com.facebook.presto.spi.function.ScalarFunctionConstantStats; +import com.facebook.presto.spi.function.ScalarPropagateSourceStats; +import com.facebook.presto.spi.function.ScalarStatsHeader; +import com.facebook.presto.spi.function.Signature; +import com.facebook.presto.spi.function.StatsPropagationBehavior; +import com.facebook.presto.spi.relation.CallExpression; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.VariableReferenceExpression; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.lang.annotation.Annotation; +import java.util.List; +import java.util.Optional; + +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createUnboundedVarcharType; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.spi.function.FunctionKind.SCALAR; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.MAX_TYPE_WIDTH_VARCHAR; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.NON_NULL_ROW_COUNT; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.ROW_COUNT; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.SUM_ARGUMENTS; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.UNKNOWN; +import static 
com.facebook.presto.spi.function.StatsPropagationBehavior.USE_MAX_ARGUMENT; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_MIN_ARGUMENT; +import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_SOURCE_STATS; +import static java.lang.Double.NEGATIVE_INFINITY; +import static java.lang.Double.NaN; +import static java.lang.Double.POSITIVE_INFINITY; +import static org.testng.Assert.assertEquals; + +public class TestScalarStatsAnnotationProcessor +{ + private static final VariableStatsEstimate STATS_ESTIMATE_FINITE = VariableStatsEstimate.builder() + .setLowValue(1.0) + .setHighValue(120.0) + .setNullsFraction(0.1) + .setAverageRowSize(15.0) + .setDistinctValuesCount(23.0) + .build(); + private static final ScalarFunctionConstantStats CONSTANT_STATS_UNKNOWN = createScalarFunctionConstantStatsInstance(NEGATIVE_INFINITY, POSITIVE_INFINITY, NaN, NaN, NaN); + private static final VariableStatsEstimate STATS_ESTIMATE_UNKNOWN = VariableStatsEstimate.unknown(); + private static final List STATS_ESTIMATE_LIST = ImmutableList.of(STATS_ESTIMATE_FINITE, STATS_ESTIMATE_FINITE); + private static final List STATS_ESTIMATE_LIST_WITH_UNKNOWN = ImmutableList.of(STATS_ESTIMATE_FINITE, STATS_ESTIMATE_UNKNOWN); + private static final TypeSignature VARCHAR_TYPE_10 = createVarcharType(10).getTypeSignature(); + private static final List TWO_ARGUMENTS = ImmutableList.of( + new VariableReferenceExpression(Optional.empty(), "x", createVarcharType(10)), + new VariableReferenceExpression(Optional.empty(), "y", createVarcharType(10))); + + @Test + public void testProcessConstantStatsTakePrecedence() + { + Signature signature = new Signature(QualifiedObjectName.valueOf("presto.default.test"), SCALAR, VARCHAR_TYPE_10); + CallExpression callExpression = + new CallExpression("test", new BuiltInFunctionHandle(signature), createVarcharType(10), + ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "y", VARCHAR))); + ScalarStatsHeader 
scalarStatsHeader = new ScalarStatsHeader( + createScalarFunctionConstantStatsInstance(1, 10, 0.1, 2.3, 25), + ImmutableMap.of(1, createScalarPropagateSourceStatsInstance(true, UNKNOWN, ROW_COUNT, UNKNOWN, UNKNOWN, UNKNOWN))); + VariableStatsEstimate actualStats = ScalarStatsAnnotationProcessor.process(1000, callExpression, STATS_ESTIMATE_LIST, scalarStatsHeader); + VariableStatsEstimate expectedStats = VariableStatsEstimate.builder() + .setLowValue(1) + .setHighValue(10) + .setNullsFraction(0.1) + .setAverageRowSize(2.3) + .setDistinctValuesCount(25) + .build(); + assertEquals(actualStats, expectedStats); + } + + @Test + public void testProcessNaNSourceStats() + { + Signature signature = new Signature(QualifiedObjectName.valueOf("presto.default.test"), SCALAR, VARCHAR_TYPE_10, + VARCHAR_TYPE_10, VARCHAR_TYPE_10); + CallExpression callExpression = + new CallExpression("test", new BuiltInFunctionHandle(signature), createVarcharType(10), TWO_ARGUMENTS); + ScalarStatsHeader scalarStatsHeader = new ScalarStatsHeader( + CONSTANT_STATS_UNKNOWN, + ImmutableMap.of(1, createScalarPropagateSourceStatsInstance(true, USE_SOURCE_STATS, USE_MAX_ARGUMENT, SUM_ARGUMENTS, SUM_ARGUMENTS, NON_NULL_ROW_COUNT))); + VariableStatsEstimate actualStats = ScalarStatsAnnotationProcessor.process(1000, callExpression, STATS_ESTIMATE_LIST_WITH_UNKNOWN, scalarStatsHeader); + VariableStatsEstimate expectedStats = VariableStatsEstimate + .buildFrom(VariableStatsEstimate.unknown()) + .setDistinctValuesCount(1000) + .setAverageRowSize(10.0).build(); + assertEquals(actualStats, expectedStats); + } + + @Test + public void testProcessTypeWidthBoundaryConditions() + { + VariableStatsEstimate statsEstimateLarge = + VariableStatsEstimate.builder() + .setNullsFraction(0.0) + .setAverageRowSize(8.0) + .setDistinctValuesCount(Double.MAX_VALUE - 1) + .build(); + Signature signature = new Signature(QualifiedObjectName.valueOf("presto.default.test"), SCALAR, VARCHAR_TYPE_10, + 
createVarcharType(VarcharType.MAX_LENGTH).getTypeSignature(), createVarcharType(VarcharType.MAX_LENGTH).getTypeSignature()); + + List largeVarcharArguments = ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "x", createVarcharType(VarcharType.MAX_LENGTH)), + new VariableReferenceExpression(Optional.empty(), "y", createVarcharType(VarcharType.MAX_LENGTH))); + CallExpression callExpression = new CallExpression("test", new BuiltInFunctionHandle(signature), createUnboundedVarcharType(), largeVarcharArguments); + ScalarStatsHeader scalarStatsHeader = new ScalarStatsHeader( + CONSTANT_STATS_UNKNOWN, + ImmutableMap.of(1, createScalarPropagateSourceStatsInstance(false, USE_SOURCE_STATS, SUM_ARGUMENTS, SUM_ARGUMENTS, SUM_ARGUMENTS, MAX_TYPE_WIDTH_VARCHAR))); + VariableStatsEstimate actualStats = + ScalarStatsAnnotationProcessor.process(Double.MAX_VALUE - 1, callExpression, ImmutableList.of(statsEstimateLarge, statsEstimateLarge), scalarStatsHeader); + VariableStatsEstimate expectedStats = VariableStatsEstimate + .builder() + .setNullsFraction(0.0) + .setDistinctValuesCount(VarcharType.MAX_LENGTH) + .setAverageRowSize(16.0).build(); + assertEquals(actualStats, expectedStats); + } + + @Test + public void testProcessTypeWidthBoundaryConditions2() + { + VariableStatsEstimate statsEstimateLarge = + VariableStatsEstimate.builder() + .setLowValue(Double.MIN_VALUE) + .setHighValue(Double.MAX_VALUE) + .setNullsFraction(0.0) + .setAverageRowSize(8.0) + .setDistinctValuesCount(Double.MAX_VALUE - 1) + .build(); + Signature signature = new Signature(QualifiedObjectName.valueOf("presto.default.test"), SCALAR, VARCHAR_TYPE_10, + DoubleType.DOUBLE.getTypeSignature(), DoubleType.DOUBLE.getTypeSignature()); + + List doubleArguments = ImmutableList.of(new VariableReferenceExpression(Optional.empty(), "x", DoubleType.DOUBLE), + new VariableReferenceExpression(Optional.empty(), "y", DoubleType.DOUBLE)); + CallExpression callExpression = new CallExpression("test", new 
BuiltInFunctionHandle(signature), createUnboundedVarcharType(), doubleArguments); + ScalarStatsHeader scalarStatsHeader = new ScalarStatsHeader( + CONSTANT_STATS_UNKNOWN, + ImmutableMap.of(1, createScalarPropagateSourceStatsInstance(false, + USE_MIN_ARGUMENT, SUM_ARGUMENTS, SUM_ARGUMENTS, SUM_ARGUMENTS, + SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT))); + VariableStatsEstimate actualStats = + ScalarStatsAnnotationProcessor.process(Double.MAX_VALUE - 1, callExpression, ImmutableList.of(statsEstimateLarge, statsEstimateLarge), scalarStatsHeader); + VariableStatsEstimate expectedStats = VariableStatsEstimate + .builder() + .setLowValue(Double.MIN_VALUE) + .setHighValue(POSITIVE_INFINITY) + .setNullsFraction(0.0) + .setAverageRowSize(16.0) + .setDistinctValuesCount(Double.MAX_VALUE - 1).build(); + assertEquals(actualStats, expectedStats); + } + + @Test + public void testProcessConstantStats() + { + Signature signature = new Signature(QualifiedObjectName.valueOf("presto.default.test"), SCALAR, VARCHAR_TYPE_10); + CallExpression callExpression = + new CallExpression("test", new BuiltInFunctionHandle(signature), createVarcharType(10), ImmutableList.of()); + ScalarStatsHeader scalarStatsHeader = new ScalarStatsHeader( + createScalarFunctionConstantStatsInstance(0, 1, 0.1, 8, NON_NULL_ROW_COUNT.getValue()), + ImmutableMap.of()); + VariableStatsEstimate actualStats = ScalarStatsAnnotationProcessor.process(1000, callExpression, STATS_ESTIMATE_LIST, scalarStatsHeader); + VariableStatsEstimate expectedStats = VariableStatsEstimate.builder() + .setLowValue(0) + .setHighValue(1) + .setNullsFraction(0.1) + .setAverageRowSize(8.0) + .setDistinctValuesCount(900) + .build(); + assertEquals(actualStats, expectedStats); + } + + @Test + public void testProcessConstantNDVWithNullFractionFromArgumentStats() + { + Signature signature = new Signature(QualifiedObjectName.valueOf("presto.default.test"), SCALAR, VARCHAR_TYPE_10, VARCHAR_TYPE_10, VARCHAR_TYPE_10); + CallExpression callExpression = 
new CallExpression("test", new BuiltInFunctionHandle(signature), createVarcharType(10), TWO_ARGUMENTS); + ScalarStatsHeader scalarStatsHeader = new ScalarStatsHeader( + createScalarFunctionConstantStatsInstance(0, 1, NaN, NaN, NON_NULL_ROW_COUNT.getValue()), + ImmutableMap.of(0, createScalarPropagateSourceStatsInstance(false, UNKNOWN, UNKNOWN, SUM_ARGUMENTS, USE_SOURCE_STATS, UNKNOWN))); + VariableStatsEstimate actualStats = ScalarStatsAnnotationProcessor.process(1000, callExpression, STATS_ESTIMATE_LIST, scalarStatsHeader); + VariableStatsEstimate expectedStats = VariableStatsEstimate.builder() + .setLowValue(0) + .setHighValue(1) + .setNullsFraction(0.1) + .setAverageRowSize(10) + .setDistinctValuesCount(900) + .build(); + assertEquals(actualStats, expectedStats); + } + + private static ScalarFunctionConstantStats createScalarFunctionConstantStatsInstance( + double min, double max, double nullFraction, double avgRowSize, + double distinctValuesCount) + { + return new ScalarFunctionConstantStats() + { + @Override + public Class annotationType() + { + return ScalarFunctionConstantStats.class; + } + + @Override + public double minValue() + { + return min; + } + + @Override + public double maxValue() + { + return max; + } + + @Override + public double distinctValuesCount() + { + return distinctValuesCount; + } + + @Override + public double nullFraction() + { + return nullFraction; + } + + @Override + public double avgRowSize() + { + return avgRowSize; + } + }; + } + + private ScalarPropagateSourceStats createScalarPropagateSourceStatsInstance( + Boolean propagateAllStats, + StatsPropagationBehavior minValue, + StatsPropagationBehavior maxValue, + StatsPropagationBehavior avgRowSize, + StatsPropagationBehavior nullFraction, + StatsPropagationBehavior distinctValuesCount) + { + return new ScalarPropagateSourceStats() + { + @Override + public Class annotationType() + { + return ScalarPropagateSourceStats.class; + } + + @Override + public boolean propagateAllStats() + { 
+ return propagateAllStats; + } + + @Override + public StatsPropagationBehavior minValue() + { + return minValue; + } + + @Override + public StatsPropagationBehavior maxValue() + { + return maxValue; + } + + @Override + public StatsPropagationBehavior distinctValuesCount() + { + return distinctValuesCount; + } + + @Override + public StatsPropagationBehavior avgRowSize() + { + return avgRowSize; + } + + @Override + public StatsPropagationBehavior nullFraction() + { + return nullFraction; + } + }; + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java b/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java index c7951ffbeebb..8b39de9c8345 100644 --- a/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java +++ b/presto-main/src/test/java/com/facebook/presto/cost/TestScalarStatsCalculator.java @@ -14,8 +14,18 @@ package com.facebook.presto.cost; import com.facebook.presto.Session; +import com.facebook.presto.common.type.StandardTypes; import com.facebook.presto.common.type.Type; +import com.facebook.presto.metadata.FunctionListBuilder; import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.spi.function.LiteralParameters; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionConstantStats; +import com.facebook.presto.spi.function.ScalarPropagateSourceStats; +import com.facebook.presto.spi.function.SqlFunction; +import com.facebook.presto.spi.function.SqlNullable; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.spi.function.StatsPropagationBehavior; import com.facebook.presto.spi.relation.RowExpression; import com.facebook.presto.spi.relation.VariableReferenceExpression; import com.facebook.presto.sql.TestingRowExpressionTranslator; @@ -34,16 +44,20 @@ import com.facebook.presto.sql.tree.SymbolReference; import com.google.common.collect.ImmutableList; import 
com.google.common.collect.ImmutableMap; +import io.airlift.slice.Slice; import io.airlift.slice.Slices; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; +import java.util.List; import java.util.Map; import java.util.Optional; +import static com.facebook.presto.SystemSessionProperties.SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED; import static com.facebook.presto.common.type.BigintType.BIGINT; import static com.facebook.presto.common.type.DoubleType.DOUBLE; import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; import static com.facebook.presto.common.type.VarcharType.createVarcharType; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; import static com.facebook.presto.sql.ExpressionUtils.rewriteIdentifiersToSymbolReferences; @@ -51,18 +65,57 @@ import static java.lang.Double.NEGATIVE_INFINITY; import static java.lang.Double.POSITIVE_INFINITY; import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertThrows; public class TestScalarStatsCalculator { + public static final Map SESSION_CONFIG = ImmutableMap.of(SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED, "true"); private static final Map DEFAULT_SYMBOL_TYPES = ImmutableMap.of( "a", BIGINT, "x", BIGINT, "y", BIGINT, "all_null", BIGINT); - + private static final PlanNodeStatsEstimate TWO_ARGUMENTS_BIGINT_SOURCE_STATS = PlanNodeStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "x", BIGINT), VariableStatsEstimate.builder() + .setLowValue(-1) + .setHighValue(10) + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .build()) + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "y", BIGINT), VariableStatsEstimate.builder() + .setLowValue(-2) + .setHighValue(5) + .setDistinctValuesCount(3) + .setNullsFraction(0.2) + .build()) + .setOutputRowCount(10) + .build(); + + private static final 
PlanNodeStatsEstimate BIGINT_SOURCE_STATS = PlanNodeStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "x", BIGINT), + VariableStatsEstimate.builder() + .setLowValue(-2) + .setHighValue(5) + .setDistinctValuesCount(3) + .setNullsFraction(0.2) + .build()) + .setOutputRowCount(10) + .build(); + private static final PlanNodeStatsEstimate VARCHAR_SOURCE_STATS_10_ROWS = PlanNodeStatsEstimate.builder() + .addVariableStatistics( + new VariableReferenceExpression(Optional.empty(), "x", createVarcharType(20)), + VariableStatsEstimate.builder() + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .setAverageRowSize(14) + .build()) + .setOutputRowCount(10) + .build(); + + private static final Map TWO_ARGUMENTS_BIGINT_NAME_TO_TYPE_MAP = ImmutableMap.of("x", BIGINT, "y", BIGINT); + private final SqlParser sqlParser = new SqlParser(); private ScalarStatsCalculator calculator; private Session session; - private final SqlParser sqlParser = new SqlParser(); private TestingRowExpressionTranslator translator; @BeforeClass @@ -73,6 +126,235 @@ public void setUp() translator = new TestingRowExpressionTranslator(MetadataManager.createTestMetadataManager()); } + @Test + public void testStatsPropagationForCustomAdd() + { + assertCalculate(SESSION_CONFIG, + expression("custom_add(x, y)"), + TWO_ARGUMENTS_BIGINT_SOURCE_STATS, + TypeProvider.viewOf(TWO_ARGUMENTS_BIGINT_NAME_TO_TYPE_MAP)) + .distinctValuesCount(7) + .lowValue(-3) + .highValue(15) + .nullsFraction(0.3) + .averageRowSize(8.0); + } + + @Test + public void testStatsPropagationForUnknownSourceStats() + { + PlanNodeStatsEstimate statsWithUnknowns = PlanNodeStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "x", BIGINT), VariableStatsEstimate.builder() + .setLowValue(-1) + .setHighValue(10) + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .build()) + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "y", 
BIGINT), + VariableStatsEstimate.unknown()) + .setOutputRowCount(10) + .build(); + assertCalculate(SESSION_CONFIG, + expression("custom_add(x, y)"), + statsWithUnknowns, + TypeProvider.viewOf(TWO_ARGUMENTS_BIGINT_NAME_TO_TYPE_MAP)) + .lowValueUnknown() + .highValueUnknown() + .distinctValuesCountUnknown() + .nullsFractionUnknown() + .averageRowSize(8.0); + + PlanNodeStatsEstimate varcharStatsUnknown = PlanNodeStatsEstimate.buildFrom(PlanNodeStatsEstimate.unknown()) + .setOutputRowCount(10) + .build(); + assertCalculate(SESSION_CONFIG, + expression("custom_str_len(x)"), + varcharStatsUnknown, + TypeProvider.viewOf(ImmutableMap.of("x", createVarcharType(20)))) + .lowValue(0.0) + .highValue(20.0) + .distinctValuesCountUnknown() + .nullsFractionUnknown() + .averageRowSize(8.0); + } + + @Test + public void testStatsPropagationForCustomStrLen() + { + PlanNodeStatsEstimate varcharStats100Rows = PlanNodeStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "x", createVarcharType(20)), VariableStatsEstimate.builder() + .setDistinctValuesCount(4) + .setNullsFraction(0.1) + .setAverageRowSize(14) + .build()) + .setOutputRowCount(100) + .build(); + + assertCalculate(SESSION_CONFIG, + expression("custom_str_len(x)"), + varcharStats100Rows, + TypeProvider.viewOf(ImmutableMap.of("x", createVarcharType(20)))) + .distinctValuesCount(20.0) + .lowValue(0.0) + .highValue(20.0) + .nullsFraction(0.1) + .averageRowSize(8.0); + assertCalculate(SESSION_CONFIG, + expression("custom_str_len(x)"), + VARCHAR_SOURCE_STATS_10_ROWS, + TypeProvider.viewOf(ImmutableMap.of("x", createVarcharType(20)))) + .lowValue(0.0) + .highValue(20.0) + .distinctValuesCountUnknown() // When computed NDV is > output row count, it is set to unknown. 
+ .nullsFraction(0.1) + .averageRowSize(8.0); + } + + @Test + public void testStatsPropagationForCustomPrng() + { + assertCalculate(SESSION_CONFIG, + expression("custom_prng(x, y)"), + TWO_ARGUMENTS_BIGINT_SOURCE_STATS, + TypeProvider.viewOf(TWO_ARGUMENTS_BIGINT_NAME_TO_TYPE_MAP)) + .lowValue(-1) + .highValue(5) + .distinctValuesCount(10) + .nullsFraction(0.0) + .averageRowSize(8.0); + } + + @Test + public void testStatsPropagationForCustomStringEditDistance() + { + PlanNodeStatsEstimate.Builder sourceStatsBuilder = PlanNodeStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "x", createVarcharType(10)), + VariableStatsEstimate.builder() + .setDistinctValuesCount(4) + .setNullsFraction(0.213) + .setAverageRowSize(9.44) + .build()) + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "y", createVarcharType(20)), + VariableStatsEstimate.builder() + .setDistinctValuesCount(6) + .setNullsFraction(0.4) + .setAverageRowSize(19.333) + .build()); + PlanNodeStatsEstimate sourceStats10Rows = sourceStatsBuilder.setOutputRowCount(10).build(); + PlanNodeStatsEstimate sourceStats100Rows = sourceStatsBuilder.setOutputRowCount(100).build(); + Map referenceNameToVarcharType = ImmutableMap.of("x", createVarcharType(10), "y", createVarcharType(20)); + Map referenceNameToUnboundedVarcharType = ImmutableMap.of("x", createVarcharType(10), "y", VARCHAR); + assertCalculate(SESSION_CONFIG, + expression("custom_str_edit_distance(x, y)"), + sourceStats10Rows, + TypeProvider.viewOf(referenceNameToVarcharType)) + .lowValue(0) + .highValue(20) + .distinctValuesCountUnknown() + .nullsFraction(0.4) + .averageRowSize(8.0); + assertCalculate(SESSION_CONFIG, + expression("custom_str_edit_distance(x, y)"), + sourceStats100Rows, + TypeProvider.viewOf(referenceNameToVarcharType)) + .distinctValuesCount(20) + .lowValue(0) + .highValue(20) + .nullsFraction(0.4) + .averageRowSize(8.0); + assertCalculate(SESSION_CONFIG, + 
expression("custom_str_edit_distance(x, y)"), + sourceStats100Rows, + TypeProvider.viewOf(referenceNameToUnboundedVarcharType)) + .lowValue(0) + .highValueUnknown() + .distinctValuesCountUnknown() + .nullsFractionUnknown() + .averageRowSize(8.0); + } + + @Test + public void testStatsPropagationForCustomIsNull() + { + assertCalculate(SESSION_CONFIG, + expression("custom_is_null(x)"), + BIGINT_SOURCE_STATS, + TypeProvider.viewOf(ImmutableMap.of("x", BIGINT))) + .lowValueUnknown() + .highValueUnknown() + .distinctValuesCount(3.19) + .nullsFraction(0.0) + .averageRowSize(1.0); + assertCalculate(SESSION_CONFIG, + expression("custom_is_null(x)"), + VARCHAR_SOURCE_STATS_10_ROWS, + TypeProvider.viewOf(ImmutableMap.of("x", createVarcharType(10)))) + .lowValueUnknown() + .highValueUnknown() + .distinctValuesCount(2.0) + .nullsFraction(0.0) + .averageRowSize(1.0); + } + + @Test + public void testConstantStatsBoundaryConditions() + { + assertThrows(IllegalArgumentException.class, () -> assertCalculate(SESSION_CONFIG, + expression("custom_is_null2(x)"), + BIGINT_SOURCE_STATS, + TypeProvider.viewOf(ImmutableMap.of("x", BIGINT)))); + assertThrows(IllegalArgumentException.class, () -> assertCalculate(SESSION_CONFIG, + expression("custom_is_null3(x)"), + BIGINT_SOURCE_STATS, + TypeProvider.viewOf(ImmutableMap.of("x", BIGINT)))); + assertThrows(IllegalArgumentException.class, () -> assertCalculate(SESSION_CONFIG, + expression("custom_is_null4(x)"), + BIGINT_SOURCE_STATS, + TypeProvider.viewOf(ImmutableMap.of("x", BIGINT)))); + assertThrows(IllegalArgumentException.class, () -> assertCalculate(SESSION_CONFIG, + expression("custom_is_null5(x)"), + BIGINT_SOURCE_STATS, + TypeProvider.viewOf(ImmutableMap.of("x", BIGINT)))); + assertThrows(IllegalArgumentException.class, () -> assertCalculate(SESSION_CONFIG, + expression("custom_is_null6(x)"), + BIGINT_SOURCE_STATS, + TypeProvider.viewOf(ImmutableMap.of("x", BIGINT)))); + assertCalculate(SESSION_CONFIG, + 
expression("custom_is_null7(x)"), + BIGINT_SOURCE_STATS, + TypeProvider.viewOf(ImmutableMap.of("x", BIGINT))) + .averageRowSize(8.0); + } + + @Test + public void testStatsPropagationForSourceStatsBoundaryConditions() + { + PlanNodeStatsEstimate sourceStats = PlanNodeStatsEstimate.builder() + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "x", BIGINT), VariableStatsEstimate.builder() + .setLowValue(NEGATIVE_INFINITY) + .setHighValue(-7) + .setDistinctValuesCount(10) + .setNullsFraction(0.1) + .build()) + .addVariableStatistics(new VariableReferenceExpression(Optional.empty(), "y", BIGINT), VariableStatsEstimate.builder() + .setLowValue(-2) + .setHighValue(-1) + .setDistinctValuesCount(10) + .setNullsFraction(1.0) + .build()) + .setOutputRowCount(10) + .build(); + assertCalculate(SESSION_CONFIG, + expression("custom_add(x, y)"), + sourceStats, + TypeProvider.viewOf(TWO_ARGUMENTS_BIGINT_NAME_TO_TYPE_MAP)) + .distinctValuesCount(10) + .lowValueUnknown() + .highValue(-8) + .nullsFraction(1.0) + .averageRowSize(8.0); + } + @Test public void testLiteral() { @@ -524,8 +806,163 @@ public void testCoalesceExpression() .averageRowSize(2.0); } + private VariableStatsAssertion assertCalculate( + Map sessionConfigs, + Expression scalarExpression, + PlanNodeStatsEstimate inputStatistics, + TypeProvider types) + { + MetadataManager metadata = createTestMetadataManager(); + List functions = new FunctionListBuilder() + .scalars(CustomFunctions.class) + .getFunctions(); + Session.SessionBuilder sessionBuilder = testSessionBuilder(); + for (Map.Entry entry : sessionConfigs.entrySet()) { + sessionBuilder.setSystemProperty(entry.getKey(), entry.getValue()); + } + Session session1 = sessionBuilder.build(); + metadata.getFunctionAndTypeManager().registerBuiltInFunctions(functions); + ScalarStatsCalculator statsCalculator = new ScalarStatsCalculator(metadata); + TestingRowExpressionTranslator translator = new TestingRowExpressionTranslator(metadata); + 
RowExpression scalarRowExpression = translator.translate(scalarExpression, types); + VariableStatsEstimate rowExpressionVariableStatsEstimate = statsCalculator.calculate(scalarRowExpression, inputStatistics, session1); + return VariableStatsAssertion.assertThat(rowExpressionVariableStatsEstimate); + } + private Expression expression(String sqlExpression) { return rewriteIdentifiersToSymbolReferences(sqlParser.createExpression(sqlExpression)); } + + public static final class CustomFunctions + { + private CustomFunctions() {} + + @ScalarFunction(value = "custom_add", calledOnNullInput = false) + @ScalarFunctionConstantStats(avgRowSize = 8.0) + @SqlType(StandardTypes.BIGINT) + public static long customAdd( + @ScalarPropagateSourceStats( + propagateAllStats = false, + nullFraction = StatsPropagationBehavior.SUM_ARGUMENTS, + distinctValuesCount = StatsPropagationBehavior.SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT, + minValue = StatsPropagationBehavior.SUM_ARGUMENTS, + maxValue = StatsPropagationBehavior.SUM_ARGUMENTS) @SqlType(StandardTypes.BIGINT) long x, + @SqlType(StandardTypes.BIGINT) long y) + { + return x + y; + } + + @ScalarFunction(value = "custom_is_null", calledOnNullInput = true) + @LiteralParameters("x") + @SqlType(StandardTypes.BOOLEAN) + @ScalarFunctionConstantStats(distinctValuesCount = 2.0, nullFraction = 0.0) + public static boolean customIsNullVarchar(@SqlNullable @SqlType("varchar(x)") Slice slice) + { + return slice == null; + } + + @ScalarFunction(value = "custom_is_null", calledOnNullInput = true) + @SqlType(StandardTypes.BOOLEAN) + @ScalarFunctionConstantStats(distinctValuesCount = 3.19, nullFraction = 0.0) + public static boolean customIsNullBigint(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value) + { + return value == null; + } + + @ScalarFunction(value = "custom_str_len") + @SqlType(StandardTypes.BIGINT) + @LiteralParameters("x") + @ScalarFunctionConstantStats(minValue = 0) + public static long customStrLength( + 
@ScalarPropagateSourceStats( + propagateAllStats = false, + nullFraction = StatsPropagationBehavior.USE_SOURCE_STATS, + distinctValuesCount = StatsPropagationBehavior.USE_TYPE_WIDTH_VARCHAR, + maxValue = StatsPropagationBehavior.USE_TYPE_WIDTH_VARCHAR) @SqlType("varchar(x)") Slice value) + { + return value.length(); + } + + @ScalarFunction(value = "custom_str_edit_distance") + @SqlType(StandardTypes.BIGINT) + @LiteralParameters({"x", "y"}) + @ScalarFunctionConstantStats(minValue = 0) + public static long customStrEditDistance( + @ScalarPropagateSourceStats( + propagateAllStats = false, + nullFraction = StatsPropagationBehavior.USE_MAX_ARGUMENT, + distinctValuesCount = StatsPropagationBehavior.MAX_TYPE_WIDTH_VARCHAR, + maxValue = StatsPropagationBehavior.MAX_TYPE_WIDTH_VARCHAR) @SqlType("varchar(x)") Slice str1, + @SqlType("varchar(y)") Slice str2) + { + return 100; + } + + @ScalarFunction(value = "custom_prng", calledOnNullInput = true) + @SqlType(StandardTypes.BIGINT) + @LiteralParameters("x") + @ScalarFunctionConstantStats(nullFraction = 0) + public static long customPrng( + @SqlNullable + @ScalarPropagateSourceStats( + propagateAllStats = false, + distinctValuesCount = StatsPropagationBehavior.ROW_COUNT, + minValue = StatsPropagationBehavior.USE_SOURCE_STATS) @SqlType(StandardTypes.BIGINT) Long min, + @ScalarPropagateSourceStats( + propagateAllStats = false, + maxValue = StatsPropagationBehavior.USE_SOURCE_STATS + ) @SqlNullable @SqlType(StandardTypes.BIGINT) Long max) + { + return (long) ((Math.random() * (max - min)) + min); + } + + @ScalarFunction(value = "custom_is_null2", calledOnNullInput = true) + @SqlType(StandardTypes.BOOLEAN) + @ScalarFunctionConstantStats(distinctValuesCount = -3.19, nullFraction = 0.0) + public static boolean customIsNullBigint2(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value) + { + return value == null; + } + + @ScalarFunction(value = "custom_is_null3", calledOnNullInput = true) + @SqlType(StandardTypes.BOOLEAN) + 
@ScalarFunctionConstantStats(minValue = -3.19, maxValue = -6.19, nullFraction = 0.0) + public static boolean customIsNullBigint3(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value) + { + return value == null; + } + + @ScalarFunction(value = "custom_is_null4", calledOnNullInput = true) + @SqlType(StandardTypes.BOOLEAN) + @ScalarFunctionConstantStats(nullFraction = 1.1) + public static boolean customIsNullBigint4(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value) + { + return value == null; + } + + @ScalarFunction(value = "custom_is_null5", calledOnNullInput = true) + @SqlType(StandardTypes.BOOLEAN) + @ScalarFunctionConstantStats(nullFraction = -1) + public static boolean customIsNullBigint5(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value) + { + return value == null; + } + + @ScalarFunction(value = "custom_is_null6", calledOnNullInput = true) + @SqlType(StandardTypes.BOOLEAN) + @ScalarFunctionConstantStats(avgRowSize = -1) + public static boolean customIsNullBigint6(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value) + { + return value == null; + } + + @ScalarFunction(value = "custom_is_null7", calledOnNullInput = true) + @SqlType(StandardTypes.BOOLEAN) + @ScalarFunctionConstantStats(avgRowSize = 8) + public static boolean customIsNullBigint7(@SqlNullable @SqlType(StandardTypes.BIGINT) Long value) + { + return value == null; + } + } } diff --git a/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestStatsAnnotationScalarFunctions.java b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestStatsAnnotationScalarFunctions.java new file mode 100644 index 000000000000..a44037772282 --- /dev/null +++ b/presto-main/src/test/java/com/facebook/presto/operator/scalar/TestStatsAnnotationScalarFunctions.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.operator.scalar; + +import com.facebook.presto.common.type.StandardTypes; +import com.facebook.presto.metadata.SqlScalarFunction; +import com.facebook.presto.operator.scalar.annotations.ScalarFromAnnotationsParser; +import com.facebook.presto.spi.function.Description; +import com.facebook.presto.spi.function.LiteralParameters; +import com.facebook.presto.spi.function.ScalarFunction; +import com.facebook.presto.spi.function.ScalarFunctionConstantStats; +import com.facebook.presto.spi.function.ScalarPropagateSourceStats; +import com.facebook.presto.spi.function.ScalarStatsHeader; +import com.facebook.presto.spi.function.Signature; +import com.facebook.presto.spi.function.SqlType; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import io.airlift.slice.Slice; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Map; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +public class TestStatsAnnotationScalarFunctions + extends AbstractTestFunctions +{ + public TestStatsAnnotationScalarFunctions() + { + } + + protected TestStatsAnnotationScalarFunctions(FeaturesConfig config) + { + super(config); + } + + @Description("Functions with stats annotation") + public static class TestScalarFunction + { + @SqlType(StandardTypes.BOOLEAN) + @LiteralParameters("x") + @ScalarFunction + @ScalarFunctionConstantStats(avgRowSize = 2) + public static boolean fun1(@SqlType("varchar(x)") Slice slice) + { + return true; + } + + @SqlType(StandardTypes.BOOLEAN) 
+ @LiteralParameters("x") + @ScalarFunction + @ScalarFunctionConstantStats(avgRowSize = 2, distinctValuesCount = 2.0) + public static boolean fun2( + @ScalarPropagateSourceStats(propagateAllStats = true) @SqlType(StandardTypes.BIGINT) Slice slice) + { + return true; + } + } + + @Test + public void testAnnotations() + { + List sqlScalarFunctions = ScalarFromAnnotationsParser.parseFunctionDefinitions(TestScalarFunction.class); + assertEquals(sqlScalarFunctions.size(), 2); + for (SqlScalarFunction function : sqlScalarFunctions) { + assertTrue(function instanceof ParametricScalar); + ParametricScalar parametricScalar = (ParametricScalar) function; + Signature signature = parametricScalar.getSignature().canonicalization(); + Map scalarStatsHeaderMap = parametricScalar.getScalarHeader().getSignatureToScalarStatsHeadersMap(); + ScalarStatsHeader scalarStatsHeader = scalarStatsHeaderMap.get(signature); + assertEquals(scalarStatsHeader.getAvgRowSize(), 2); + if (function.getSignature().getName().toString().equals("fun2")) { + assertEquals(scalarStatsHeader.getDistinctValuesCount(), 2); + Map argumentStatsActual = scalarStatsHeader.getArgumentStats(); + assertEquals(argumentStatsActual.size(), 1); + assertTrue(argumentStatsActual.get(0).propagateAllStats()); + } + } + } +} diff --git a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index ebe971197a36..c0bf37ab12ee 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -257,7 +257,8 @@ public void testDefaults() .setUseHistograms(false) .setInlineProjectionsOnValues(false) .setEagerPlanValidationEnabled(false) - .setEagerPlanValidationThreadPoolSize(20)); + .setEagerPlanValidationThreadPoolSize(20) + .setScalarFunctionStatsPropagationEnabled(false)); } @Test @@ -464,6 +465,7 
@@ public void testExplicitPropertyMappings() .put("optimizer.inline-projections-on-values", "true") .put("eager-plan-validation-enabled", "true") .put("eager-plan-validation-thread-pool-size", "2") + .put("optimizer.scalar-function-stats-propagation-enabled", "true") .build(); FeaturesConfig expected = new FeaturesConfig() @@ -667,7 +669,8 @@ public void testExplicitPropertyMappings() .setUseHistograms(true) .setInlineProjectionsOnValues(true) .setEagerPlanValidationEnabled(true) - .setEagerPlanValidationThreadPoolSize(2); + .setEagerPlanValidationThreadPoolSize(2) + .setScalarFunctionStatsPropagationEnabled(true); assertFullMapping(properties, expected); } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadata.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadata.java index a82a744b32fb..bc769a45833d 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadata.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/FunctionMetadata.java @@ -20,11 +20,13 @@ import java.util.ArrayList; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; import static com.facebook.presto.spi.function.ComplexTypeFunctionDescriptor.defaultFunctionDescriptor; import static com.facebook.presto.spi.function.FunctionVersion.notVersioned; +import static java.util.Collections.emptyMap; import static java.util.Collections.unmodifiableList; import static java.util.Objects.requireNonNull; @@ -42,6 +44,7 @@ public class FunctionMetadata private final boolean calledOnNullInput; private final FunctionVersion version; private final ComplexTypeFunctionDescriptor descriptor; + private final Map signatureToScalarStatsHeaders; public FunctionMetadata( QualifiedObjectName name, @@ -52,7 +55,22 @@ public FunctionMetadata( boolean deterministic, boolean calledOnNullInput) { - this(name, Optional.empty(), argumentTypes, Optional.empty(), returnType, 
functionKind, Optional.empty(), implementationType, deterministic, calledOnNullInput, notVersioned()); + this(name, Optional.empty(), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, + deterministic, calledOnNullInput, notVersioned(), emptyMap()); + } + + public FunctionMetadata( + QualifiedObjectName name, + List argumentTypes, + TypeSignature returnType, + FunctionKind functionKind, + FunctionImplementationType implementationType, + boolean deterministic, + boolean calledOnNullInput, + Map signatureToScalarStatsHeaders) + { + this(name, Optional.empty(), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, + deterministic, calledOnNullInput, notVersioned(), signatureToScalarStatsHeaders); } public FunctionMetadata( @@ -65,7 +83,23 @@ public FunctionMetadata( boolean calledOnNullInput, ComplexTypeFunctionDescriptor functionDescriptor) { - this(name, Optional.empty(), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, deterministic, calledOnNullInput, notVersioned(), functionDescriptor); + this(name, Optional.empty(), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, + deterministic, calledOnNullInput, notVersioned(), functionDescriptor, emptyMap()); + } + + public FunctionMetadata( + QualifiedObjectName name, + List argumentTypes, + TypeSignature returnType, + FunctionKind functionKind, + FunctionImplementationType implementationType, + boolean deterministic, + boolean calledOnNullInput, + ComplexTypeFunctionDescriptor functionDescriptor, + Map signatureToScalarStatsHeaders) + { + this(name, Optional.empty(), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, + deterministic, calledOnNullInput, notVersioned(), functionDescriptor, signatureToScalarStatsHeaders); } public FunctionMetadata( @@ -80,7 +114,8 @@ public FunctionMetadata( 
boolean calledOnNullInput, FunctionVersion version) { - this(name, Optional.empty(), argumentTypes, Optional.of(argumentNames), returnType, functionKind, Optional.of(language), implementationType, deterministic, calledOnNullInput, version); + this(name, Optional.empty(), argumentTypes, Optional.of(argumentNames), returnType, functionKind, Optional.of(language), implementationType, deterministic, + calledOnNullInput, version, emptyMap()); } public FunctionMetadata( @@ -97,7 +132,25 @@ public FunctionMetadata( ComplexTypeFunctionDescriptor functionDescriptor) { this(name, Optional.empty(), argumentTypes, Optional.of(argumentNames), returnType, functionKind, Optional.of(language), implementationType, deterministic, - calledOnNullInput, version, functionDescriptor); + calledOnNullInput, version, functionDescriptor, emptyMap()); + } + + public FunctionMetadata( + QualifiedObjectName name, + List argumentTypes, + List argumentNames, + TypeSignature returnType, + FunctionKind functionKind, + Language language, + FunctionImplementationType implementationType, + boolean deterministic, + boolean calledOnNullInput, + FunctionVersion version, + ComplexTypeFunctionDescriptor functionDescriptor, + Map signatureToScalarStatsHeaders) + { + this(name, Optional.empty(), argumentTypes, Optional.of(argumentNames), returnType, functionKind, Optional.of(language), implementationType, deterministic, + calledOnNullInput, version, functionDescriptor, signatureToScalarStatsHeaders); } public FunctionMetadata( @@ -109,7 +162,8 @@ public FunctionMetadata( boolean deterministic, boolean calledOnNullInput) { - this(operatorType.getFunctionName(), Optional.of(operatorType), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, deterministic, calledOnNullInput, notVersioned()); + this(operatorType.getFunctionName(), Optional.of(operatorType), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), + implementationType, deterministic, 
calledOnNullInput, notVersioned(), emptyMap()); } public FunctionMetadata( @@ -122,7 +176,8 @@ public FunctionMetadata( boolean calledOnNullInput, ComplexTypeFunctionDescriptor functionDescriptor) { - this(operatorType.getFunctionName(), Optional.of(operatorType), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), implementationType, deterministic, calledOnNullInput, notVersioned(), functionDescriptor); + this(operatorType.getFunctionName(), Optional.of(operatorType), argumentTypes, Optional.empty(), returnType, functionKind, Optional.empty(), + implementationType, deterministic, calledOnNullInput, notVersioned(), functionDescriptor, emptyMap()); } private FunctionMetadata( @@ -136,7 +191,8 @@ private FunctionMetadata( FunctionImplementationType implementationType, boolean deterministic, boolean calledOnNullInput, - FunctionVersion version) + FunctionVersion version, + Map signatureToScalarStatsHeaders) { this( name, @@ -150,7 +206,8 @@ private FunctionMetadata( deterministic, calledOnNullInput, version, - defaultFunctionDescriptor()); + defaultFunctionDescriptor(), + signatureToScalarStatsHeaders); } private FunctionMetadata( @@ -165,7 +222,8 @@ private FunctionMetadata( boolean deterministic, boolean calledOnNullInput, FunctionVersion version, - ComplexTypeFunctionDescriptor functionDescriptor) + ComplexTypeFunctionDescriptor functionDescriptor, + Map signatureToScalarStatsHeaders) { this.name = requireNonNull(name, "name is null"); this.operatorType = requireNonNull(operatorType, "operatorType is null"); @@ -185,7 +243,9 @@ private FunctionMetadata( functionDescriptor.getArgumentIndicesContainingMapOrArray(), functionDescriptor.getOutputToInputTransformationFunction(), argumentTypes); + this.signatureToScalarStatsHeaders = signatureToScalarStatsHeaders; } + public FunctionKind getFunctionKind() { return functionKind; @@ -246,6 +306,16 @@ public ComplexTypeFunctionDescriptor getDescriptor() return descriptor; } + public boolean 
hasStatsHeader() + { + return !signatureToScalarStatsHeaders.isEmpty(); + } + + public Optional getScalarStatsHeader(Signature signature) + { + return Optional.ofNullable(signatureToScalarStatsHeaders.get(signature)); + } + @Override public boolean equals(Object obj) { @@ -267,12 +337,14 @@ public boolean equals(Object obj) Objects.equals(this.deterministic, other.deterministic) && Objects.equals(this.calledOnNullInput, other.calledOnNullInput) && Objects.equals(this.version, other.version) && - Objects.equals(this.descriptor, other.descriptor); + Objects.equals(this.descriptor, other.descriptor) && + Objects.equals(this.signatureToScalarStatsHeaders, other.signatureToScalarStatsHeaders); } @Override public int hashCode() { - return Objects.hash(name, operatorType, argumentTypes, argumentNames, returnType, functionKind, language, implementationType, deterministic, calledOnNullInput, version, descriptor); + return Objects.hash(name, operatorType, argumentTypes, argumentNames, returnType, functionKind, language, implementationType, deterministic, calledOnNullInput, version, + descriptor, signatureToScalarStatsHeaders); } } diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionConstantStats.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionConstantStats.java new file mode 100644 index 000000000000..19dc8dc1828a --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarFunctionConstantStats.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.presto.spi.function; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +/** + * By default, a function is just a “black box” that the database system knows very little about the behavior of. + * However, that means that queries using the function may be executed much less efficiently than they could be. + * It is possible to supply additional knowledge that helps the planner optimize function calls. + * Scalar functions are straight forward to optimize and can have impact on the overall query performance. + * Use this annotation to provide information regarding how this function impacts following query statistics. + *

+ * A function may take one or more input columns or constants as parameters. Precise stats may depend on the input + * parameters. This annotation does not cover all possible cases; it only allows constant values for the following fields. + * Double.NaN means unknown for distinctValuesCount, nullFraction and avgRowSize; minValue and maxValue default to negative and positive infinity when unknown. + *

+ */ +@Retention(RUNTIME) +@Target(METHOD) +public @interface ScalarFunctionConstantStats +{ + // Min max value is Infinity if unknown. + double minValue() default Double.NEGATIVE_INFINITY; + double maxValue() default Double.POSITIVE_INFINITY; + + /** + * A constant value for Distinct values count regardless of `input column`'s source stats. + * e.g. a perfectly random generator may result in distinctValuesCount of `ScalarFunctionStatsUtils.ROW_COUNT`. + */ + double distinctValuesCount() default Double.NaN; + + /** + * A constant value for nullFraction, e.g. is_null(Slice) will alter column's null fraction + * value to 0.0, regardless of input column's source stats. + */ + double nullFraction() default Double.NaN; + + /** + * A constant value for `avgRowSize` e.g. a function like md5 may produce a + * constant row size. + */ + double avgRowSize() default Double.NaN; +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarPropagateSourceStats.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarPropagateSourceStats.java new file mode 100644 index 000000000000..7341418e74ee --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarPropagateSourceStats.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spi.function; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@Target(PARAMETER) +public @interface ScalarPropagateSourceStats +{ + boolean propagateAllStats() default true; + + StatsPropagationBehavior minValue() default StatsPropagationBehavior.UNKNOWN; + StatsPropagationBehavior maxValue() default StatsPropagationBehavior.UNKNOWN; + StatsPropagationBehavior distinctValuesCount() default StatsPropagationBehavior.UNKNOWN; + StatsPropagationBehavior avgRowSize() default StatsPropagationBehavior.UNKNOWN; + StatsPropagationBehavior nullFraction() default StatsPropagationBehavior.UNKNOWN; +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarStatsHeader.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarStatsHeader.java new file mode 100644 index 000000000000..9c65b4cbb0fb --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/ScalarStatsHeader.java @@ -0,0 +1,97 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.presto.spi.function; + +import java.util.Map; + +public class ScalarStatsHeader +{ + private final Map argumentStatsResolver; + private final double min; + private final double max; + private final double distinctValuesCount; + private final double nullFraction; + private final double avgRowSize; + + private ScalarStatsHeader(Map argumentStatsResolver, + double min, + double max, + double distinctValuesCount, + double nullFraction, + double avgRowSize) + { + this.min = min; + this.max = max; + this.argumentStatsResolver = argumentStatsResolver; + this.distinctValuesCount = distinctValuesCount; + this.nullFraction = nullFraction; + this.avgRowSize = avgRowSize; + } + + public ScalarStatsHeader(ScalarFunctionConstantStats methodConstantStats, Map argumentStatsResolver) + { + this(argumentStatsResolver, + methodConstantStats.minValue(), + methodConstantStats.maxValue(), + methodConstantStats.distinctValuesCount(), + methodConstantStats.nullFraction(), + methodConstantStats.avgRowSize()); + } + + public ScalarStatsHeader(Map argumentStatsResolver) + { + this(argumentStatsResolver, Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, Double.NaN, Double.NaN, Double.NaN); + } + + @Override + public String toString() + { + return String.format("distinctValuesCount: %g , nullFraction: %g, avgRowSize: %g, min: %g, max: %g", + distinctValuesCount, nullFraction, avgRowSize, min, max); + } + + /* + * Get stats annotation for each of the scalar function argument, where key is the index of the position + * of functions' argument and value is the ScalarPropagateSourceStats annotation. 
+ */ + public Map getArgumentStats() + { + return argumentStatsResolver; + } + + public double getMin() + { + return min; + } + + public double getMax() + { + return max; + } + + public double getAvgRowSize() + { + return avgRowSize; + } + + public double getNullFraction() + { + return nullFraction; + } + + public double getDistinctValuesCount() + { + return distinctValuesCount; + } +} diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java index 024ba43035bd..d39d08033ccd 100644 --- a/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/Signature.java @@ -168,6 +168,17 @@ public String toString() "(" + String.join(",", argumentTypes.stream().map(TypeSignature::toString).collect(toList())) + "):" + returnType; } + /* + * Canonical (normalized i.e. erased type size bounds) form of signature instance. + */ + public Signature canonicalization() + { + return new Signature(this.name, this.kind, new TypeSignature(this.returnType.getBase(), emptyList()), + argumentTypes + .stream() + .map(argumentTypeSignature -> new TypeSignature(argumentTypeSignature.getBase(), emptyList())).collect(toList())); + } + /* * similar to T extends MyClass, if Java supported varargs wildcards */ diff --git a/presto-spi/src/main/java/com/facebook/presto/spi/function/StatsPropagationBehavior.java b/presto-spi/src/main/java/com/facebook/presto/spi/function/StatsPropagationBehavior.java new file mode 100644 index 000000000000..1b6c64373d73 --- /dev/null +++ b/presto-spi/src/main/java/com/facebook/presto/spi/function/StatsPropagationBehavior.java @@ -0,0 +1,76 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.spi.function; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; + +import static java.util.Collections.unmodifiableSet; + +public enum StatsPropagationBehavior +{ + /** Use the max value across all arguments to derive the new stats value */ + USE_MAX_ARGUMENT(0), + /** Use the min value across all arguments to derive the new stats value */ + USE_MIN_ARGUMENT(0), + /** Sum the stats value of all arguments to derive the new stats value */ + SUM_ARGUMENTS(0), + /** Sum the stats value of all arguments to derive the new stats value, but upper bounded to row count. */ + SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT(0), + /** Propagate the source stats as-is */ + USE_SOURCE_STATS(0), + // Following stats are independent of source stats. + /** Use the value of output row count. */ + ROW_COUNT(-1), + /** Use the value of row_count * (1 - null_fraction). */ + NON_NULL_ROW_COUNT(-10), + /** use the value of TYPE_WIDTH in varchar(TYPE_WIDTH) */ + USE_TYPE_WIDTH_VARCHAR(0), + /** Take max of type width of arguments with varchar type. */ + MAX_TYPE_WIDTH_VARCHAR(0), + /** Stats are unknown and thus no action is performed. */ + UNKNOWN(0); + /* + * Stats are multi argument when their value is calculated by operating on stats from source stats or other properties of the all the arguments. 
+ */ + private static final Set MULTI_ARGUMENT_STATS = + unmodifiableSet( + new HashSet<>(Arrays.asList(MAX_TYPE_WIDTH_VARCHAR, USE_MAX_ARGUMENT, USE_MIN_ARGUMENT, SUM_ARGUMENTS, SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT))); + private static final Set SOURCE_STATS_DEPENDENT_STATS = + unmodifiableSet( + new HashSet<>(Arrays.asList(USE_MAX_ARGUMENT, USE_MIN_ARGUMENT, SUM_ARGUMENTS, SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT, USE_SOURCE_STATS))); + + private final int value; + + StatsPropagationBehavior(int value) + { + this.value = value; + } + + public int getValue() + { + return this.value; + } + + public boolean isMultiArgumentStat() + { + return MULTI_ARGUMENT_STATS.contains(this); + } + + public boolean isSourceStatsDependentStats() + { + return SOURCE_STATS_DEPENDENT_STATS.contains(this); + } +}