impl limit pushdown for MemoryExec
zjregee committed Feb 5, 2025
1 parent e8d9b62 commit 02efd60
Showing 4 changed files with 105 additions and 28 deletions.
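This change lets MemoryExec absorb a LIMIT itself instead of being wrapped in a GlobalLimitExec: the operator now reports and accepts a row budget through fetch/with_fetch, and MemoryStream stops emitting rows once that budget is exhausted. The snippet below is a rough sketch of how a limit-pushdown pass could use the new API; the try_push_limit helper is hypothetical (not part of this commit) and assumes the datafusion_physical_plan crate paths.

use std::sync::Arc;

use datafusion_physical_plan::ExecutionPlan;

// Hypothetical helper: try to push `limit` into `plan`. with_fetch returns
// Some(new_plan) when the operator (now including MemoryExec) can enforce the
// limit itself; otherwise the original plan is kept and a separate limit
// operator has to stay on top.
fn try_push_limit(plan: Arc<dyn ExecutionPlan>, limit: usize) -> Arc<dyn ExecutionPlan> {
    plan.with_fetch(Some(limit)).unwrap_or(plan)
}

This is why, in the limit.slt plans below, the GlobalLimitExec nodes directly above MemoryExec disappear and the limit is shown on the MemoryExec line instead.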
117 changes: 99 additions & 18 deletions datafusion/physical-plan/src/memory.rs
@@ -65,6 +65,8 @@ pub struct MemoryExec {
cache: PlanProperties,
/// if partition sizes should be displayed
show_sizes: bool,
/// Maximum number of rows to return
fetch: Option<usize>,
}

impl fmt::Debug for MemoryExec {
@@ -74,6 +76,7 @@ impl fmt::Debug for MemoryExec {
.field("schema", &self.schema)
.field("projection", &self.projection)
.field("sort_information", &self.sort_information)
.field("fetch", &self.fetch)
.finish()
}
}
@@ -100,16 +103,20 @@ impl DisplayAs for MemoryExec {
format!(", {}", constraints)
};

let limit = self
.fetch
.map_or(String::new(), |limit| format!(", limit={}", limit));

if self.show_sizes {
write!(
f,
"MemoryExec: partitions={}, partition_sizes={partition_sizes:?}{output_ordering}{constraints}",
"MemoryExec: partitions={}, partition_sizes={partition_sizes:?}{limit}{output_ordering}{constraints}",
partition_sizes.len(),
)
} else {
write!(
f,
"MemoryExec: partitions={}{output_ordering}{constraints}",
"MemoryExec: partitions={}{limit}{output_ordering}{constraints}",
partition_sizes.len(),
)
}
@@ -154,11 +161,14 @@ impl ExecutionPlan for MemoryExec {
partition: usize,
_context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
Ok(Box::pin(MemoryStream::try_new(
self.partitions[partition].clone(),
Arc::clone(&self.projected_schema),
self.projection.clone(),
)?))
Ok(Box::pin(
MemoryStream::try_new(
self.partitions[partition].clone(),
Arc::clone(&self.projected_schema),
self.projection.clone(),
)?
.with_fetch(self.fetch),
))
}

/// We recompute the statistics dynamically from the arrow metadata as it is pretty cheap to do so
@@ -193,6 +203,23 @@ impl ExecutionPlan for MemoryExec {
})
.transpose()
}

fn fetch(&self) -> Option<usize> {
self.fetch
}

fn with_fetch(&self, limit: Option<usize>) -> Option<Arc<dyn ExecutionPlan>> {
Some(Arc::new(Self {
partitions: self.partitions.clone(),
schema: self.schema.clone(),
projected_schema: self.projected_schema.clone(),
projection: self.projection.clone(),
sort_information: self.sort_information.clone(),
cache: self.cache.clone(),
show_sizes: self.show_sizes,
fetch: limit,
}))
}
}

impl MemoryExec {
@@ -219,6 +246,7 @@ impl MemoryExec {
sort_information: vec![],
cache,
show_sizes: true,
fetch: None,
})
}

@@ -314,6 +342,7 @@ impl MemoryExec {
sort_information: vec![],
cache,
show_sizes: true,
fetch: None,
})
}

@@ -462,6 +491,8 @@ pub struct MemoryStream {
projection: Option<Vec<usize>>,
/// Index into the data
index: usize,
/// The remaining number of rows to return
fetch: Option<usize>,
}

impl MemoryStream {
@@ -477,6 +508,7 @@ impl MemoryStream {
schema,
projection,
index: 0,
fetch: None,
})
}

@@ -485,6 +517,12 @@ impl MemoryStream {
self.reservation = Some(reservation);
self
}

/// Set the number of rows to produce
pub(super) fn with_fetch(mut self, fetch: Option<usize>) -> Self {
self.fetch = fetch;
self
}
}

impl Stream for MemoryStream {
@@ -494,20 +532,35 @@ impl Stream for MemoryStream {
mut self: std::pin::Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
Poll::Ready(if self.index < self.data.len() {
    self.index += 1;
    let batch = &self.data[self.index - 1];

    // return just the columns requested
    let batch = match self.projection.as_ref() {
        Some(columns) => batch.project(columns)?,
        None => batch.clone(),
    };

    Some(Ok(batch))
} else {
    None
})
if self.index >= self.data.len() {
    return Poll::Ready(None);
}

self.index += 1;
let batch = &self.data[self.index - 1];

// return just the columns requested
let batch = match self.projection.as_ref() {
    Some(columns) => batch.project(columns)?,
    None => batch.clone(),
};

if self.fetch.is_none() {
    return Poll::Ready(Some(Ok(batch)));
}

let fetch = self.fetch.unwrap();
if fetch == 0 {
    return Poll::Ready(None);
}

let batch = if batch.num_rows() > fetch {
    batch.slice(0, fetch)
} else {
    batch
};
self.fetch = Some(fetch - batch.num_rows());
Poll::Ready(Some(Ok(batch)))
}

fn size_hint(&self) -> (usize, Option<usize>) {
@@ -859,7 +912,9 @@ mod tests {
use crate::test::{self, make_partition};

use arrow_schema::{DataType, Field};
use datafusion_common::assert_batches_eq;
use datafusion_common::stats::{ColumnStatistics, Precision};
use futures::StreamExt;

#[tokio::test]
async fn values_empty_case() -> Result<()> {
@@ -944,4 +999,30 @@

Ok(())
}

#[tokio::test]
async fn exec_with_limit() -> Result<()> {
let task_ctx = Arc::new(TaskContext::default());
let batch = make_partition(7);
let schema = batch.schema();
let batches = vec![batch.clone(), batch];

let exec = MemoryExec::try_new_from_batches(schema, batches).unwrap();
assert_eq!(exec.fetch(), None);

let exec = exec.with_fetch(Some(4)).unwrap();
assert_eq!(exec.fetch(), Some(4));

let mut it = exec.execute(0, task_ctx)?;
let mut results = vec![];
while let Some(batch) = it.next().await {
results.push(batch?);
}

let expected = [
"+---+", "| i |", "+---+", "| 0 |", "| 1 |", "| 2 |", "| 3 |", "+---+",
];
assert_batches_eq!(expected, &results);
Ok(())
}
}
12 changes: 4 additions & 8 deletions datafusion/sqllogictest/test_files/limit.slt
@@ -739,10 +739,8 @@ physical_plan
01)ProjectionExec: expr=[a@2 as a, b@3 as b, a@0 as a, b@1 as b]
02)--GlobalLimitExec: skip=0, fetch=10
03)----CrossJoinExec
04)------GlobalLimitExec: skip=0, fetch=1
05)--------MemoryExec: partitions=1, partition_sizes=[1]
06)------GlobalLimitExec: skip=0, fetch=10
07)--------MemoryExec: partitions=1, partition_sizes=[1]
04)------MemoryExec: partitions=1, partition_sizes=[1], limit=1
05)------MemoryExec: partitions=1, partition_sizes=[1], limit=10


query IIII
@@ -765,10 +763,8 @@ logical_plan
physical_plan
01)GlobalLimitExec: skip=0, fetch=2
02)--CrossJoinExec
03)----GlobalLimitExec: skip=0, fetch=2
04)------MemoryExec: partitions=1, partition_sizes=[1]
05)----GlobalLimitExec: skip=0, fetch=2
06)------MemoryExec: partitions=1, partition_sizes=[1]
03)----MemoryExec: partitions=1, partition_sizes=[1], limit=2
04)----MemoryExec: partitions=1, partition_sizes=[1], limit=2

statement ok
drop table testSubQueryLimit;
@@ -67,4 +67,4 @@ statement ok
drop table test_substr_base;

statement ok
drop table test_datetime_base;
drop table test_datetime_base;
2 changes: 1 addition & 1 deletion datafusion/sqllogictest/test_files/string/large_string.slt
@@ -75,4 +75,4 @@ statement ok
drop table test_substr_base;

statement ok
drop table test_datetime_base;
drop table test_datetime_base;
