Skip to content

Commit

Permalink
[Enhancement] Estimate numWorkers of connector scan nodes for query queue (#56053)
Browse files Browse the repository at this point in the history

Signed-off-by: zihe.liu <[email protected]>
  • Loading branch information
ZiheLiu authored Feb 19, 2025
1 parent 56c20ed commit 1c19057
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 19 deletions.
7 changes: 7 additions & 0 deletions fe/fe-core/src/main/java/com/starrocks/planner/ScanNode.java
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,13 @@ public boolean isRunningAsConnectorOperator() {
public void setScanSampleStrategy(RemoteFilesSampleStrategy strategy) {
}

/**
 * Returns whether this scan node reads from an external data source through the
 * connector framework (as opposed to a native OLAP/internal scan).
 */
public boolean isConnectorScanNode() {
    if (this instanceof HdfsScanNode || this instanceof IcebergScanNode) {
        return true;
    }
    if (this instanceof HudiScanNode || this instanceof DeltaLakeScanNode) {
        return true;
    }
    if (this instanceof FileTableScanNode || this instanceof PaimonScanNode) {
        return true;
    }
    return this instanceof OdpsScanNode || this instanceof IcebergMetadataScanNode;
}

protected String explainColumnDict(String prefix) {
StringBuilder output = new StringBuilder();
if (!appliedDictStringColumns.isEmpty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,7 @@

package com.starrocks.qe.scheduler.assignment;

import com.starrocks.planner.DeltaLakeScanNode;
import com.starrocks.planner.FileTableScanNode;
import com.starrocks.planner.HdfsScanNode;
import com.starrocks.planner.HudiScanNode;
import com.starrocks.planner.IcebergMetadataScanNode;
import com.starrocks.planner.IcebergScanNode;
import com.starrocks.planner.OdpsScanNode;
import com.starrocks.planner.OlapScanNode;
import com.starrocks.planner.PaimonScanNode;
import com.starrocks.planner.ScanNode;
import com.starrocks.planner.SchemaScanNode;
import com.starrocks.qe.BackendSelector;
Expand Down Expand Up @@ -68,10 +60,7 @@ public static BackendSelector create(ScanNode scanNode,

if (scanNode instanceof SchemaScanNode) {
return new NormalBackendSelector(scanNode, locations, assignment, workerProvider, false);
} else if (scanNode instanceof HdfsScanNode || scanNode instanceof IcebergScanNode ||
scanNode instanceof HudiScanNode || scanNode instanceof DeltaLakeScanNode ||
scanNode instanceof FileTableScanNode || scanNode instanceof PaimonScanNode
|| scanNode instanceof OdpsScanNode || scanNode instanceof IcebergMetadataScanNode) {
} else if (scanNode.isConnectorScanNode()) {
return new HDFSBackendSelector(scanNode, locations, assignment, workerProvider,
sessionVariable.getForceScheduleLocal(),
sessionVariable.getHDFSBackendSelectorScanRangeShuffle(), useIncrementalScanRanges);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import com.starrocks.planner.PlanFragment;
import com.starrocks.planner.PlanFragmentId;
import com.starrocks.planner.PlanNode;
import com.starrocks.planner.ScanNode;
import com.starrocks.qe.ConnectContext;
import com.starrocks.qe.DefaultCoordinator;
import com.starrocks.sql.optimizer.cost.feature.CostPredictor;
Expand Down Expand Up @@ -55,7 +56,11 @@ public int estimateSlots(QueryQueueOptions opts, ConnectContext context, Default
if (CostPredictor.getServiceBasedCostPredictor().isAvailable() && coord.getPredictedCost() > 0) {
memCost = coord.getPredictedCost();
} else {
memCost = (long) context.getAuditEventBuilder().build().planMemCosts;
// The estimate of planMemCosts is typically an underestimation, often several orders of magnitude smaller than
// the actual memory usage, whereas planCpuCosts tends to be relatively larger.
// Therefore, the maximum value between the two is used as the estimate for memory.
memCost = (long) Math.max(context.getAuditEventBuilder().build().planMemCosts,
context.getAuditEventBuilder().build().planCpuCosts);
}
long numSlotsPerWorker = memCost / opts.v2().getMemBytesPerSlot();
numSlotsPerWorker = Math.max(numSlotsPerWorker, 0);
Expand All @@ -72,7 +77,7 @@ public int estimateSlots(QueryQueueOptions opts, ConnectContext context, Default
public static class ParallelismBasedSlotsEstimator implements SlotEstimator {
@Override
public int estimateSlots(QueryQueueOptions opts, ConnectContext context, DefaultCoordinator coord) {
Map<PlanFragmentId, FragmentContext> fragmentContexts = collectFragmentContexts(coord);
Map<PlanFragmentId, FragmentContext> fragmentContexts = collectFragmentContexts(opts, coord);
int numSlots = fragmentContexts.values().stream()
.mapToInt(fragmentContext -> estimateFragmentSlots(opts, fragmentContext))
.max().orElse(1);
Expand Down Expand Up @@ -117,13 +122,14 @@ private static int estimateNumSlotsBySourceNode(QueryQueueOptions opts, PlanNode
return (int) (sourceNode.getCardinality() / opts.v2().getNumRowsPerSlot());
}

private static Map<PlanFragmentId, FragmentContext> collectFragmentContexts(DefaultCoordinator coord) {
private static Map<PlanFragmentId, FragmentContext> collectFragmentContexts(QueryQueueOptions opts,
DefaultCoordinator coord) {
PlanFragment rootFragment = coord.getExecutionDAG().getRootFragment().getPlanFragment();
PlanNode rootNode = rootFragment.getPlanRoot();

Map<PlanFragmentId, FragmentContext> contexts = Maps.newHashMap();
collectFragmentSourceNodes(rootNode, contexts);
calculateFragmentWorkers(rootFragment, contexts);
calculateFragmentWorkers(opts, rootFragment, contexts);

return contexts;
}
Expand All @@ -138,8 +144,9 @@ private static void collectFragmentSourceNodes(PlanNode node, Map<PlanFragmentId
node.getChildren().forEach(child -> collectFragmentSourceNodes(child, contexts));
}

private static void calculateFragmentWorkers(PlanFragment fragment, Map<PlanFragmentId, FragmentContext> contexts) {
fragment.getChildren().forEach(child -> calculateFragmentWorkers(child, contexts));
private static void calculateFragmentWorkers(QueryQueueOptions opts, PlanFragment fragment,
Map<PlanFragmentId, FragmentContext> contexts) {
fragment.getChildren().forEach(child -> calculateFragmentWorkers(opts, child, contexts));

FragmentContext context = contexts.get(fragment.getFragmentId());
if (context == null) {
Expand All @@ -154,6 +161,10 @@ private static void calculateFragmentWorkers(PlanFragment fragment, Map<PlanFrag
.map(TScanRangeLocation::getBackend_id)
.collect(Collectors.toSet())
.size();
} else if (leftMostNode instanceof ScanNode && ((ScanNode) leftMostNode).isConnectorScanNode()) {
// TODO: get the actual number of files for connector scan nodes.
int numWorkers = (int) leftMostNode.getCardinality() / opts.v2().getNumRowsPerSlot();
context.numWorkers = Math.max(1, Math.min(numWorkers, opts.v2().getNumWorkers()));
} else if (fragment.isGatherFragment()) {
context.numWorkers = 1;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,24 @@

package com.starrocks.qe.scheduler.slot;

import com.starrocks.connector.hive.MockedHiveMetadata;
import com.starrocks.planner.PlanNode;
import com.starrocks.qe.DefaultCoordinator;
import com.starrocks.qe.scheduler.SchedulerTestBase;
import com.starrocks.sql.optimizer.statistics.Statistics;
import com.starrocks.sql.plan.ConnectorPlanTestBase;
import org.junit.BeforeClass;
import org.junit.Test;

import static org.assertj.core.api.Assertions.assertThat;

public class SlotEstimatorTest extends SchedulerTestBase {
@BeforeClass
public static void beforeClass() throws Exception {
// Run the base scheduler setup first, then register the mocked Hive catalog so
// that tests in this class can plan queries against hive0.* connector tables.
SchedulerTestBase.beforeClass();
ConnectorPlanTestBase.mockCatalog(connectContext, MockedHiveMetadata.MOCKED_HIVE_CATALOG_NAME);
}

@Test
public void testDefaultSlotEstimator() {
SlotEstimatorFactory.DefaultSlotEstimator estimator = new SlotEstimatorFactory.DefaultSlotEstimator();
Expand Down Expand Up @@ -65,7 +74,6 @@ public void testParallelismBasedSlotsEstimator() throws Exception {
QueryQueueOptions opts = new QueryQueueOptions(true,
new QueryQueueOptions.V2(4, numWorkers, numCoresPerWorker, 64L * 1024 * 1024 * 1024, numRowsPerWorker, 100));


{
DefaultCoordinator coordinator = getScheduler("SELECT * FROM lineitem");
connectContext.getAuditEventBuilder().setPlanCpuCosts(100 * 10000);
Expand Down Expand Up @@ -152,6 +160,33 @@ public void testParallelismBasedSlotsEstimator() throws Exception {
}
}

@Test
public void testParallelismBasedSlotsEstimatorForConnectorScan() throws Exception {
    final int numWorkers = 3;
    final int numCoresPerWorker = 16;
    final int numRowsPerWorker = 4096;
    final int dop = numCoresPerWorker / 2;
    // The estimator should settle on one slot per pipeline driver across the cluster.
    final int expectedSlots = dop * numWorkers;

    SlotEstimatorFactory.ParallelismBasedSlotsEstimator estimator =
            new SlotEstimatorFactory.ParallelismBasedSlotsEstimator();
    QueryQueueOptions opts = new QueryQueueOptions(true,
            new QueryQueueOptions.V2(4, numWorkers, numCoresPerWorker, 64L * 1024 * 1024 * 1024, numRowsPerWorker, 100));

    String sql = "SELECT /*+SET_VAR(pipeline_dop=8)*/ " +
            "count(1) FROM hive0.tpch.lineitem t1 join [shuffle] hive0.tpch.lineitem t2 on t1.l_orderkey = t2.l_orderkey";

    // Cardinality that exactly saturates the cluster: numWorkers * numRowsPerWorker * dop rows.
    {
        DefaultCoordinator coordinator = getScheduler(sql);
        setNodeCardinality(coordinator, 1, numWorkers * numRowsPerWorker * dop);
        connectContext.getAuditEventBuilder().setPlanCpuCosts(100 * 10000);
        assertThat(estimator.estimateSlots(opts, connectContext, coordinator)).isEqualTo(expectedSlots);
    }
    // A 10x larger cardinality must yield the same estimate: the connector scan's
    // worker count is capped by the configured cluster parallelism.
    {
        DefaultCoordinator coordinator = getScheduler(sql);
        setNodeCardinality(coordinator, 1, numWorkers * numRowsPerWorker * dop * 10);
        connectContext.getAuditEventBuilder().setPlanCpuCosts(100 * 10000);
        assertThat(estimator.estimateSlots(opts, connectContext, coordinator)).isEqualTo(expectedSlots);
    }
}

@Test
public void testMaxSlotsEstimator() {
SlotEstimator estimator1 = (opts, context, coord) -> 1;
Expand Down

0 comments on commit 1c19057

Please sign in to comment.