Skip to content

Commit

Permalink
Phase 1. Implementation for RFC-0005: Scalar function stats propagation.
Browse files Browse the repository at this point in the history
1. Support for annotating functions with both constant stats and propagating source stats.
2. Added tests for the same.
3. Added Scalar stats calculation based on annotation and tests for the same.

Not added SQLInvokedScalarFunctions.
Not annotated builtin functions, as that is covered in next implementation phase.
Not added C++ changes as this phase only covers Java side of changes.

Added documentation for the new properties and ...
 1. Previously, if any of the source stats were missing, we would still compute the max/min/sum of argument stats etc..
  now we propagate NaNs if any one of the arguments' stats are missing.

2. For distinct values count, upper bounding it to row count is as good as unknown. Therefore, the approach here is, when distinctValuesCount is greater than row count and is provided via annotation we set it to unknown.
A function developer has full control here, for example developer can choose to upper bound or not by selecting the appropriate StatsPropagationBehavior value.

 3. For average row size,
    a) If average row size is provided via ScalarFunctionConstantStats annotation, then we allow even if the size is greater than functions return type width.
    b) If average row size is provided via one of the StatsPropagationBehavior values, then we upper bound it to functions return type width - if available.
    If both (a) and (b) is unknown, then we default it to functions return type width if available.

This way the function developer has greater control.

Added new behaviour SUM_ARGUMENTS_UPPER_BOUND_ROW_COUNT which would upper bound the values to row count, so that summing distinct values count not exceed row counts.
  • Loading branch information
ScrapCodes committed Oct 15, 2024
1 parent 27eb666 commit 9d77026
Show file tree
Hide file tree
Showing 20 changed files with 1,672 additions and 31 deletions.
11 changes: 11 additions & 0 deletions presto-docs/src/main/sphinx/admin/properties.rst
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,17 @@ This can also be specified on a per-query basis using the ``confidence_based_bro
Enable treating ``LOW`` confidence, zero estimations as ``UNKNOWN`` during joins. This can also be specified
on a per-query basis using the ``treat-low-confidence-zero-estimation-as-unknown`` session property.

``optimizer.scalar-function-stats-propagation-enabled``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^


* **Type:** ``boolean``
* **Default value:** ``false``

Enable scalar functions stats propagation using annotations. Annotations define the behavior of the scalar
function's stats characteristics. When set to ``true``, this property enables the stats propagation through annotations.
This can also be specified on a per-query basis using the ``scalar_function_stats_propagation_enabled`` session property.

``optimizer.retry-query-with-history-based-optimization``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ public final class SystemSessionProperties
public static final String OPTIMIZER_USE_HISTOGRAMS = "optimizer_use_histograms";
public static final String WARN_ON_COMMON_NAN_PATTERNS = "warn_on_common_nan_patterns";
public static final String INLINE_PROJECTIONS_ON_VALUES = "inline_projections_on_values";
public static final String SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED = "scalar_function_stats_propagation_enabled";

private final List<PropertyMetadata<?>> sessionProperties;

Expand Down Expand Up @@ -2077,6 +2078,10 @@ public SystemSessionProperties(
booleanProperty(INLINE_PROJECTIONS_ON_VALUES,
"Whether to evaluate project node on values node",
featuresConfig.getInlineProjectionsOnValues(),
false),
booleanProperty(SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED,
"whether or not to respect stats propagation annotation for scalar functions (or UDF)",
featuresConfig.isScalarFunctionStatsPropagationEnabled(),
false));
}

Expand Down Expand Up @@ -3414,4 +3419,9 @@ public static boolean isInlineProjectionsOnValues(Session session)
{
return session.getSystemProperty(INLINE_PROJECTIONS_ON_VALUES, Boolean.class);
}

public static boolean shouldEnableScalarFunctionStatsPropagation(Session session)
{
return session.getSystemProperty(SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED, Boolean.class);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.facebook.presto.cost;

import com.facebook.presto.common.type.FixedWidthType;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.common.type.VarcharType;
import com.facebook.presto.spi.function.ScalarPropagateSourceStats;
import com.facebook.presto.spi.function.ScalarStatsHeader;
import com.facebook.presto.spi.function.StatsPropagationBehavior;
import com.facebook.presto.spi.relation.CallExpression;
import com.facebook.presto.spi.relation.RowExpression;

import java.util.List;
import java.util.Map;

import static com.facebook.presto.spi.function.StatsPropagationBehavior.NON_NULL_ROW_COUNT;
import static com.facebook.presto.spi.function.StatsPropagationBehavior.ROW_COUNT;
import static com.facebook.presto.spi.function.StatsPropagationBehavior.SUM_ARGUMENTS;
import static com.facebook.presto.spi.function.StatsPropagationBehavior.SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT;
import static com.facebook.presto.spi.function.StatsPropagationBehavior.UNKNOWN;
import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_SOURCE_STATS;
import static com.facebook.presto.util.MoreMath.max;
import static com.facebook.presto.util.MoreMath.min;
import static com.facebook.presto.util.MoreMath.minExcludingNaNs;
import static com.facebook.presto.util.MoreMath.nearlyEqual;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static java.lang.Double.NaN;
import static java.lang.Double.isFinite;
import static java.lang.Double.isNaN;

public final class ScalarStatsAnnotationProcessor
{
private ScalarStatsAnnotationProcessor()
{
}

public static VariableStatsEstimate process(
double outputRowCount,
CallExpression callExpression,
List<VariableStatsEstimate> sourceStats,
ScalarStatsHeader scalarStatsHeader)
{
double nullFraction = scalarStatsHeader.getNullFraction();
double distinctValuesCount = NaN;
double averageRowSize = NaN;
double maxValue = scalarStatsHeader.getMax();
double minValue = scalarStatsHeader.getMin();
for (Map.Entry<Integer, ScalarPropagateSourceStats> paramIndexToStatsMap : scalarStatsHeader.getArgumentStats().entrySet()) {
ScalarPropagateSourceStats scalarPropagateSourceStats = paramIndexToStatsMap.getValue();
boolean propagateAllStats = scalarPropagateSourceStats.propagateAllStats();
nullFraction = min(firstFiniteValue(nullFraction, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression,
sourceStats.stream().map(VariableStatsEstimate::getNullsFraction).collect(toImmutableList()),
paramIndexToStatsMap.getKey(),
applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.nullFraction()))), 1.0);
distinctValuesCount = firstFiniteValue(distinctValuesCount, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression,
sourceStats.stream().map(VariableStatsEstimate::getDistinctValuesCount).collect(toImmutableList()),
paramIndexToStatsMap.getKey(),
applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.distinctValuesCount())));
StatsPropagationBehavior averageRowSizeStatsBehaviour = applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.avgRowSize());
averageRowSize = minExcludingNaNs(firstFiniteValue(averageRowSize, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression,
sourceStats.stream().map(VariableStatsEstimate::getAverageRowSize).collect(toImmutableList()),
paramIndexToStatsMap.getKey(),
averageRowSizeStatsBehaviour)), returnNaNIfTypeWidthUnknown(getReturnTypeWidth(callExpression, averageRowSizeStatsBehaviour)));
maxValue = firstFiniteValue(maxValue, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression,
sourceStats.stream().map(VariableStatsEstimate::getHighValue).collect(toImmutableList()),
paramIndexToStatsMap.getKey(),
applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.maxValue())));
minValue = firstFiniteValue(minValue, processSingleArgumentStatistic(outputRowCount, nullFraction, callExpression,
sourceStats.stream().map(VariableStatsEstimate::getLowValue).collect(toImmutableList()),
paramIndexToStatsMap.getKey(),
applyPropagateAllStats(propagateAllStats, scalarPropagateSourceStats.minValue())));
}
if (isNaN(maxValue) || isNaN(minValue)) {
minValue = NaN;
maxValue = NaN;
}
return VariableStatsEstimate.builder()
.setLowValue(minValue)
.setHighValue(maxValue)
.setNullsFraction(nullFraction)
.setAverageRowSize(firstFiniteValue(scalarStatsHeader.getAvgRowSize(), averageRowSize, returnNaNIfTypeWidthUnknown(getReturnTypeWidth(callExpression, UNKNOWN))))
.setDistinctValuesCount(processDistinctValuesCount(outputRowCount, nullFraction, scalarStatsHeader.getDistinctValuesCount(), distinctValuesCount)).build();
}

private static double processDistinctValuesCount(double outputRowCount, double nullFraction, double distinctValuesCountFromConstant, double distinctValuesCount)
{
if (isFinite(distinctValuesCountFromConstant)) {
if (nearlyEqual(distinctValuesCountFromConstant, NON_NULL_ROW_COUNT.getValue(), 0.1)) {
distinctValuesCountFromConstant = outputRowCount * (1 - firstFiniteValue(nullFraction, 0.0));
}
else if (nearlyEqual(distinctValuesCount, ROW_COUNT.getValue(), 0.1)) {
distinctValuesCountFromConstant = outputRowCount;
}
}
double distinctValuesCountFinal = firstFiniteValue(distinctValuesCountFromConstant, distinctValuesCount);
if (distinctValuesCountFinal > outputRowCount) {
distinctValuesCountFinal = NaN;
}
return distinctValuesCountFinal;
}

private static double processSingleArgumentStatistic(
double outputRowCount,
double nullFraction,
CallExpression callExpression,
List<Double> sourceStats,
int sourceStatsArgumentIndex,
StatsPropagationBehavior operation)
{
// sourceStatsArgumentIndex is index of the argument on which
// ScalarPropagateSourceStats annotation was applied.
double statValue = NaN;
if (operation.isMultiArgumentStat()) {
for (int i = 0; i < sourceStats.size(); i++) {
if (i == 0 && operation.isSourceStatsDependentStats() && isFinite(sourceStats.get(i))) {
statValue = sourceStats.get(i);
}
else {
switch (operation) {
case MAX_TYPE_WIDTH_VARCHAR:
statValue = returnNaNIfTypeWidthUnknown(getTypeWidthVarchar(callExpression.getArguments().get(i).getType()));
break;
case USE_MIN_ARGUMENT:
statValue = min(statValue, sourceStats.get(i));
break;
case USE_MAX_ARGUMENT:
statValue = max(statValue, sourceStats.get(i));
break;
case SUM_ARGUMENTS:
statValue = statValue + sourceStats.get(i);
break;
case SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT:
statValue = min(statValue + sourceStats.get(i), outputRowCount);
break;
}
}
}
}
else {
switch (operation) {
case USE_SOURCE_STATS:
statValue = sourceStats.get(sourceStatsArgumentIndex);
break;
case ROW_COUNT:
statValue = outputRowCount;
break;
case NON_NULL_ROW_COUNT:
statValue = outputRowCount * (1 - firstFiniteValue(nullFraction, 0.0));
break;
case USE_TYPE_WIDTH_VARCHAR:
statValue = returnNaNIfTypeWidthUnknown(getTypeWidthVarchar(callExpression.getArguments().get(sourceStatsArgumentIndex).getType()));
break;
}
}
return statValue;
}

private static int getTypeWidthVarchar(Type argumentType)
{
if (argumentType instanceof VarcharType) {
if (!((VarcharType) argumentType).isUnbounded()) {
return ((VarcharType) argumentType).getLengthSafe();
}
}
return -VarcharType.MAX_LENGTH;
}

private static double returnNaNIfTypeWidthUnknown(int typeWidthValue)
{
if (typeWidthValue <= 0) {
return NaN;
}
return typeWidthValue;
}

private static int getReturnTypeWidth(CallExpression callExpression, StatsPropagationBehavior operation)
{
if (callExpression.getType() instanceof FixedWidthType) {
return ((FixedWidthType) callExpression.getType()).getFixedSize();
}
if (callExpression.getType() instanceof VarcharType) {
VarcharType returnType = (VarcharType) callExpression.getType();
if (!returnType.isUnbounded()) {
return returnType.getLengthSafe();
}
if (operation == SUM_ARGUMENTS || operation == SUM_ARGUMENTS_UPPER_BOUNDED_TO_ROW_COUNT) {
// since return type is an unbounded varchar and operation is SUM_ARGUMENTS,
// calculating the type width by doing a SUM of each argument's varchar type bounds - if available.
int sum = 0;
for (RowExpression r : callExpression.getArguments()) {
int typeWidth;
if (r instanceof CallExpression) { // argument is another function call
typeWidth = getReturnTypeWidth((CallExpression) r, UNKNOWN);
}
else {
typeWidth = getTypeWidthVarchar(r.getType());
}
if (typeWidth < 0) {
return -VarcharType.MAX_LENGTH;
}
sum += typeWidth;
}
return sum;
}
}
return -VarcharType.MAX_LENGTH;
}

// Return first 'finite' value from values, else return values[0]
private static double firstFiniteValue(double... values)
{
checkArgument(values.length > 1);
for (double v : values) {
if (isFinite(v)) {
return v;
}
}
return values[0];
}

private static StatsPropagationBehavior applyPropagateAllStats(
boolean propagateAllStats, StatsPropagationBehavior operation)
{
if (operation == UNKNOWN && propagateAllStats) {
return USE_SOURCE_STATS;
}
return operation;
}
}
Loading

0 comments on commit 9d77026

Please sign in to comment.