Skip to content

Commit

Permalink
Converted the TPCH/TPCDS tests to run per feature.
Browse files Browse the repository at this point in the history
  • Loading branch information
ScrapCodes committed Oct 16, 2024
1 parent b85fc3e commit d2a5953
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,15 @@
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Map;
import java.util.Map.Entry;
import java.util.stream.Stream;

import static com.facebook.presto.SystemSessionProperties.OPTIMIZER_USE_HISTOGRAMS;
import static com.facebook.presto.SystemSessionProperties.SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED;
import static com.facebook.presto.spi.plan.JoinDistributionType.REPLICATED;
import static com.facebook.presto.spi.plan.JoinType.INNER;
import static com.facebook.presto.sql.Optimizer.PlanStage.OPTIMIZED_AND_VALIDATED;
import static com.facebook.presto.testing.TestngUtils.toDataProvider;
import static com.facebook.presto.testing.TestngUtils.toDataProviderFromArray;
import static com.google.common.base.Preconditions.checkState;
import static com.google.common.base.Verify.verify;
import static com.google.common.io.Files.createParentDirs;
Expand All @@ -62,9 +63,11 @@
public abstract class AbstractCostBasedPlanTest
extends BasePlanTest
{
private final Map<String, String> featureToOutputDir =
ImmutableMap.of(OPTIMIZER_USE_HISTOGRAMS, "histogram",
SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED, "scalar_function_stats_propagation");
private static final String NO_FEATURE_ENABLED = "no_feature_enabled";

private final Map<String, String> featuresMap =
ImmutableMap.of(NO_FEATURE_ENABLED, "", "histogram", OPTIMIZER_USE_HISTOGRAMS,
"scalar_function_stats_propagation", SCALAR_FUNCTION_STATS_PROPAGATION_ENABLED);

public AbstractCostBasedPlanTest(LocalQueryRunnerSupplier supplier)
{
Expand All @@ -76,33 +79,29 @@ public AbstractCostBasedPlanTest(LocalQueryRunnerSupplier supplier)
@DataProvider
public Object[][] getQueriesDataProvider()
{
return getQueryResourcePaths()
.collect(toDataProvider());
return featuresMap.keySet().stream().flatMap(feature -> getQueryResourcePaths().map(path -> new String[] {feature, path})).collect(toDataProviderFromArray());
}

@Test(dataProvider = "getQueriesDataProvider")
public void test(String queryResourcePath)
{
assertEquals(generateQueryPlan(read(queryResourcePath)), read(getQueryPlanResourcePath(queryResourcePath)));
}

@Test(dataProvider = "getQueriesDataProvider")
public void featureSpecificPlansMatch(String queryResourcePath)
public void test(String feature, String queryResourcePath)
{
String sql = read(queryResourcePath);
for (Map.Entry<String, String> featureEntry : featureToOutputDir.entrySet()) {
if (!feature.equals(NO_FEATURE_ENABLED)) {
Session featureEnabledSession = Session.builder(getQueryRunner().getDefaultSession())
.setSystemProperty(featureEntry.getKey(), "true")
.setSystemProperty(featuresMap.get(feature), "true")
.build();
Session featureDisabledSession = Session.builder(getQueryRunner().getDefaultSession())
.setSystemProperty(featureEntry.getKey(), "false")
.setSystemProperty(featuresMap.get(feature), "false")
.build();
String regularPlan = generateQueryPlan(sql, featureDisabledSession);
String featureEnabledPlan = generateQueryPlan(sql, featureEnabledSession);
if (!regularPlan.equals(featureEnabledPlan)) {
assertEquals(featureEnabledPlan, read(getSpecificPlanResourcePath(featureEntry.getValue(), getQueryPlanResourcePath(queryResourcePath))));
assertEquals(featureEnabledPlan, read(getSpecificPlanResourcePath(feature, getQueryPlanResourcePath(queryResourcePath))));
}
}
else {
assertEquals(generateQueryPlan(sql), read(getQueryPlanResourcePath(queryResourcePath)));
}
}

private String getQueryPlanResourcePath(String queryResourcePath)
Expand All @@ -113,7 +112,7 @@ private String getQueryPlanResourcePath(String queryResourcePath)
private String getSpecificPlanResourcePath(String outDirPath, String regularPlanResourcePath)
{
Path root = Paths.get(regularPlanResourcePath);
return root.getParent().resolve(String.format("%s/%s", outDirPath, root.getFileName())).toString();
return root.getParent().resolve(format("%s/%s", outDirPath, root.getFileName())).toString();
}

private Path getResourceWritePath(String queryResourcePath)
Expand All @@ -133,25 +132,28 @@ public void generate()
.parallel()
.forEach(queryResourcePath -> {
try {
for (Map.Entry<String, String> featureEntry : featureToOutputDir.entrySet()) {
Path queryPlanWritePath = getResourceWritePath(queryResourcePath);
createParentDirs(queryPlanWritePath.toFile());
for (Entry<String, String> featureEntry : featuresMap.entrySet()) {
String sql = read(queryResourcePath);
Session featuredisabledSession = Session.builder(getQueryRunner().getDefaultSession())
.setSystemProperty(featureEntry.getKey(), "false")
.build();
String regularPlan = generateQueryPlan(sql, featuredisabledSession);
Session featureEnabledSession = Session.builder(getQueryRunner().getDefaultSession())
.setSystemProperty(featureEntry.getKey(), "true")
.build();

String featureEnabledPlan = generateQueryPlan(sql, featureEnabledSession);
write(regularPlan.getBytes(UTF_8), queryPlanWritePath.toFile());
// write out the feature enabled plan if it differs
if (!regularPlan.equals(featureEnabledPlan)) {
Path featureEnabledPlanWritePath = getResourceWritePath(getSpecificPlanResourcePath(featureEntry.getValue(), queryResourcePath));
createParentDirs(featureEnabledPlanWritePath.toFile());
write(featureEnabledPlan.getBytes(UTF_8), featureEnabledPlanWritePath.toFile());
if (!featureEntry.getKey().equals(NO_FEATURE_ENABLED)) {
Session featureDisabledSession = Session.builder(getQueryRunner().getDefaultSession())
.setSystemProperty(featureEntry.getValue(), "false")
.build();
String regularPlan = generateQueryPlan(sql, featureDisabledSession);
Session featureEnabledSession = Session.builder(getQueryRunner().getDefaultSession())
.setSystemProperty(featureEntry.getValue(), "true")
.build();
String featureEnabledPlan = generateQueryPlan(sql, featureEnabledSession);
// write out the feature enabled plan if it differs
if (!regularPlan.equals(featureEnabledPlan)) {
Path featureEnabledPlanWritePath = getResourceWritePath(getSpecificPlanResourcePath(featureEntry.getKey(), queryResourcePath));
createParentDirs(featureEnabledPlanWritePath.toFile());
write(featureEnabledPlan.getBytes(UTF_8), featureEnabledPlanWritePath.toFile());
}
}
else {
Path queryPlanWritePath = getResourceWritePath(queryResourcePath);
createParentDirs(queryPlanWritePath.toFile());
write(generateQueryPlan(sql).getBytes(UTF_8), queryPlanWritePath.toFile());
}
System.out.println("Generated expected plan for query: " + queryResourcePath);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.facebook.presto.spi.function.Description;
import com.facebook.presto.spi.function.LiteralParameters;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarPropagateSourceStats;
import com.facebook.presto.spi.function.Signature;
import com.facebook.presto.spi.function.SqlNullable;
import com.facebook.presto.spi.function.SqlType;
Expand Down Expand Up @@ -63,6 +64,7 @@
import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.StandardErrorCode.NUMERIC_VALUE_OUT_OF_RANGE;
import static com.facebook.presto.spi.function.FunctionKind.SCALAR;
import static com.facebook.presto.spi.function.StatsPropagationBehavior.USE_SOURCE_STATS;
import static com.facebook.presto.type.DecimalOperators.modulusScalarFunction;
import static com.facebook.presto.type.DecimalOperators.modulusSignatureBuilder;
import static com.facebook.presto.util.Failures.checkCondition;
Expand Down Expand Up @@ -105,7 +107,8 @@ private MathFunctions() {}
@Description("absolute value")
@ScalarFunction("abs")
@SqlType(StandardTypes.TINYINT)
public static long absTinyint(@SqlType(StandardTypes.TINYINT) long num)
public static long absTinyint(
@ScalarPropagateSourceStats(propagateAllStats = false, nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.TINYINT) long num)
{
checkCondition(num != Byte.MIN_VALUE, NUMERIC_VALUE_OUT_OF_RANGE, "Value -128 is out of range for abs(tinyint)");
return Math.abs(num);
Expand All @@ -114,7 +117,8 @@ public static long absTinyint(@SqlType(StandardTypes.TINYINT) long num)
@Description("absolute value")
@ScalarFunction("abs")
@SqlType(StandardTypes.SMALLINT)
public static long absSmallint(@SqlType(StandardTypes.SMALLINT) long num)
public static long absSmallint(
@ScalarPropagateSourceStats(propagateAllStats = false, nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.SMALLINT) long num)
{
checkCondition(num != Short.MIN_VALUE, NUMERIC_VALUE_OUT_OF_RANGE, "Value -32768 is out of range for abs(smallint)");
return Math.abs(num);
Expand All @@ -123,7 +127,8 @@ public static long absSmallint(@SqlType(StandardTypes.SMALLINT) long num)
@Description("absolute value")
@ScalarFunction("abs")
@SqlType(StandardTypes.INTEGER)
public static long absInteger(@SqlType(StandardTypes.INTEGER) long num)
public static long absInteger(
@ScalarPropagateSourceStats(propagateAllStats = false, nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.INTEGER) long num)
{
checkCondition(num != Integer.MIN_VALUE, NUMERIC_VALUE_OUT_OF_RANGE, "Value -2147483648 is out of range for abs(integer)");
return Math.abs(num);
Expand All @@ -132,7 +137,8 @@ public static long absInteger(@SqlType(StandardTypes.INTEGER) long num)
@Description("absolute value")
@ScalarFunction
@SqlType(StandardTypes.BIGINT)
public static long abs(@SqlType(StandardTypes.BIGINT) long num)
public static long abs(
@ScalarPropagateSourceStats(propagateAllStats = false, nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.BIGINT) long num)
{
checkCondition(num != Long.MIN_VALUE, NUMERIC_VALUE_OUT_OF_RANGE, "Value -9223372036854775808 is out of range for abs(bigint)");
return Math.abs(num);
Expand All @@ -141,7 +147,8 @@ public static long abs(@SqlType(StandardTypes.BIGINT) long num)
@Description("absolute value")
@ScalarFunction
@SqlType(StandardTypes.DOUBLE)
public static double abs(@SqlType(StandardTypes.DOUBLE) double num)
public static double abs(
@ScalarPropagateSourceStats(propagateAllStats = false, nullFraction = USE_SOURCE_STATS) @SqlType(StandardTypes.DOUBLE) double num)
{
return Math.abs(num);
}
Expand All @@ -154,7 +161,8 @@ private Abs() {}

@LiteralParameters({"p", "s"})
@SqlType("decimal(p, s)")
public static long absShort(@SqlType("decimal(p, s)") long arg)
public static long absShort(
@ScalarPropagateSourceStats(propagateAllStats = false, nullFraction = USE_SOURCE_STATS) @SqlType("decimal(p, s)") long arg)
{
return arg > 0 ? arg : -arg;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,16 @@ private TestngUtils() {}
},
builder -> builder.toArray(new Object[][] {}));
}

public static <T> Collector<T, ?, Object[][]> toDataProviderFromArray()
{
return Collector.of(
ArrayList::new,
ArrayList::add,
(left, right) -> {
left.addAll(right);
return left;
},
builder -> builder.toArray(new Object[][] {}));
}
}

0 comments on commit d2a5953

Please sign in to comment.