Skip to content

Commit

Permalink
feat: support simple window functions (#787)
Browse files Browse the repository at this point in the history
* support basic window function

Signed-off-by: Runji Wang <[email protected]>

* support aggregate functions as window function

Signed-off-by: Runji Wang <[email protected]>

* forbid window functions in WHERE and HAVING clause

Signed-off-by: Runji Wang <[email protected]>

* forbid nested window function

Signed-off-by: Runji Wang <[email protected]>

* ignore window function test in v1

Signed-off-by: Runji Wang <[email protected]>

* fix clippy

Signed-off-by: Runji Wang <[email protected]>

---------

Signed-off-by: Runji Wang <[email protected]>
  • Loading branch information
wangrunji0408 authored Jul 11, 2023
1 parent 0e81535 commit 604b4a1
Show file tree
Hide file tree
Showing 14 changed files with 231 additions and 31 deletions.
31 changes: 30 additions & 1 deletion src/binder_v2/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ impl Binder {
Ok(id)
}

/// Binds every expression in `exprs` and wraps the resulting ids in a single list node.
///
/// Expressions are bound in order; the first binding error is returned immediately and
/// no list node is added in that case.
pub fn bind_exprs(&mut self, exprs: Vec<Expr>) -> Result {
    let mut bound = Vec::with_capacity(exprs.len());
    for expr in exprs {
        bound.push(self.bind_expr(expr)?);
    }
    let list = Node::List(bound.into());
    Ok(self.egraph.add(list))
}

fn bind_ident(&mut self, idents: impl IntoIterator<Item = Ident>) -> Result {
let idents = idents
.into_iter()
Expand Down Expand Up @@ -221,9 +230,29 @@ impl Binder {
"first" => Node::First(args[0]),
"last" => Node::Last(args[0]),
"replace" => Node::Replace([args[0], args[1], args[2]]),
"row_number" => Node::RowNumber,
name => todo!("Unsupported function: {}", name),
};
Ok(self.egraph.add(node))
let mut id = self.egraph.add(node);
if let Some(window) = func.over {
id = self.bind_window_function(id, window)?;
}
Ok(id)
}

/// Wraps the already-bound function `func` in an `over` node described by `window`.
///
/// # Errors
///
/// * `BindError::NotAgg` if the node of `func` is not a window function.
/// * `BindError::NestedWindow` if `func` already contains an `over` node.
/// * Any error raised while binding the PARTITION BY or ORDER BY expressions.
///
/// # Panics
///
/// Panics via `todo!` if a window frame is specified (not supported yet).
fn bind_window_function(&mut self, func: Id, window: WindowSpec) -> Result {
    let callee = self.node(func);
    if !callee.is_window_function() {
        return Err(BindError::NotAgg(callee.to_string()));
    }
    let has_inner_window = !self.overs(func).is_empty();
    if has_inner_window {
        return Err(BindError::NestedWindow);
    }
    // Keys are bound before the frame check, so binding errors take precedence over it.
    let partition_keys = self.bind_exprs(window.partition_by)?;
    let order_keys = self.bind_orderby(window.order_by)?;
    if window.window_frame.is_some() {
        todo!("support window frame");
    }
    // Operand order: (over window_function [partition_key..] [order_key..])
    let over = Node::Over([func, partition_keys, order_keys]);
    Ok(self.egraph.add(over))
}
}

Expand Down
12 changes: 12 additions & 0 deletions src/binder_v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,20 @@ pub enum BindError {
AggInWhere,
#[error("GROUP BY clause cannot contain aggregates")]
AggInGroupBy,
#[error("window function calls cannot be nested")]
NestedWindow,
#[error("WHERE clause cannot contain window functions")]
WindowInWhere,
#[error("HAVING clause cannot contain window functions")]
WindowInHaving,
#[error("column {0} must appear in the GROUP BY clause or be used in an aggregate function")]
ColumnNotInAgg(String),
#[error("ORDER BY items must appear in the select list if DISTINCT is specified")]
OrderKeyNotInDistinct,
#[error("operation on internal table is not supported")]
NotSupportedOnInternalTable,
#[error("{0} is not an aggregate function")]
NotAgg(String),
}

/// The binder resolves all expressions referring to schema objects such as
Expand Down Expand Up @@ -194,6 +202,10 @@ impl Binder {
&self.egraph[id].data.aggs
}

/// Returns the `over` nodes recorded in the analysis data of e-class `id`.
fn overs(&self, id: Id) -> &[Node] {
    let class = &self.egraph[id];
    &class.data.overs
}

/// Returns the first node stored in e-class `id`.
fn node(&self, id: Id) -> &Node {
    let class = &self.egraph[id];
    &class.nodes[0]
}
Expand Down
50 changes: 39 additions & 11 deletions src/binder_v2/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ impl Binder {
plan = self.plan_agg(&mut to_rewrite, groupby, plan)?;
let [mut projection, distinct, having, orderby] = to_rewrite;
plan = self.egraph.add(Node::Filter([having, plan]));
plan = self.plan_window(projection, distinct, orderby, plan)?;
plan = self.plan_distinct(distinct, orderby, &mut projection, plan)?;
plan = self.egraph.add(Node::Order([orderby, plan]));
plan = self.egraph.add(Node::Proj([projection, plan]));
Expand Down Expand Up @@ -91,15 +92,27 @@ impl Binder {
///
/// There should be no aggregation or window function in the expression, otherwise an error will be returned.
pub(super) fn bind_where(&mut self, selection: Option<Expr>) -> Result {
let id = self.bind_having(selection)?;
let id = self.bind_selection(selection)?;
if !self.aggs(id).is_empty() {
return Err(BindError::AggInWhere);
}
if !self.overs(id).is_empty() {
return Err(BindError::WindowInWhere);
}
Ok(id)
}

/// Binds the HAVING clause and returns the id of its condition expression.
///
/// Window functions are not allowed here: if the bound condition contains any `over`
/// node, `BindError::WindowInHaving` is returned.
fn bind_having(&mut self, selection: Option<Expr>) -> Result {
    let condition = self.bind_selection(selection)?;
    if self.overs(condition).is_empty() {
        Ok(condition)
    } else {
        Err(BindError::WindowInHaving)
    }
}

/// Binds a selection. Returns a `true` node if no selection.
fn bind_selection(&mut self, selection: Option<Expr>) -> Result {
Ok(match selection {
Some(expr) => self.bind_expr(expr)?,
None => self.egraph.add(Node::true_()),
Expand All @@ -110,18 +123,15 @@ impl Binder {
///
/// There should be no aggregation in the expressions, otherwise an error will be returned.
fn bind_groupby(&mut self, group_by: Vec<Expr>) -> Result {
let list = (group_by.into_iter())
.map(|key| self.bind_expr(key))
.try_collect()?;
let id = self.egraph.add(Node::List(list));
let id = self.bind_exprs(group_by)?;
if !self.aggs(id).is_empty() {
return Err(BindError::AggInGroupBy);
}
Ok(id)
}

/// Binds the ORDER BY clause. Returns a list of expressions.
fn bind_orderby(&mut self, order_by: Vec<OrderByExpr>) -> Result {
pub(super) fn bind_orderby(&mut self, order_by: Vec<OrderByExpr>) -> Result {
let mut orderby = Vec::with_capacity(order_by.len());
for e in order_by {
let expr = self.bind_expr(e.expr)?;
Expand Down Expand Up @@ -149,11 +159,7 @@ impl Binder {
"VALUES lists must all be the same length".into(),
));
}
let mut bound_row = Vec::with_capacity(column_len);
for expr in row {
bound_row.push(self.bind_expr(expr)?);
}
bound_values.push(self.egraph.add(Node::List(bound_row.into())));
bound_values.push(self.bind_exprs(row)?);
}
let id = self.egraph.add(Node::Values(bound_values.into()));
self.check_type(id)?;
Expand Down Expand Up @@ -265,4 +271,26 @@ impl Binder {
*projection = self.egraph.add(Node::List(projs.into()));
Ok(self.egraph.add(Node::Agg([aggs, distinct, plan])))
}

/// Extracts all over nodes from `projection`, `distinct` and `orderby`.
/// Generates an [`Window`](Node::Window) plan if any over node is found.
/// Otherwise returns the original `plan`.
fn plan_window(&mut self, projection: Id, distinct: Id, orderby: Id, plan: Id) -> Result {
let mut overs = vec![];
overs.extend_from_slice(self.overs(projection));
overs.extend_from_slice(self.overs(distinct));
overs.extend_from_slice(self.overs(orderby));

if overs.is_empty() {
return Ok(plan);
}
let mut list: Vec<_> = overs
.into_iter()
.map(|over| self.egraph.add(over))
.collect();
list.sort();
list.dedup();
let overs = self.egraph.add(Node::List(list.into()));
Ok(self.egraph.add(Node::Window([overs, plan])))
}
}
6 changes: 4 additions & 2 deletions src/executor_v2/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ impl<'a> Evaluator<'a> {
fn init_agg_state(&self) -> DataValue {
use Expr::*;
match self.node() {
RowCount | Count(_) => DataValue::Int32(0),
Over([window, _, _]) => self.next(*window).init_agg_state(),
RowCount | RowNumber | Count(_) => DataValue::Int32(0),
Sum(_) | Min(_) | Max(_) | First(_) | Last(_) => DataValue::Null,
t => panic!("not aggregation: {t}"),
}
Expand Down Expand Up @@ -197,7 +198,8 @@ impl<'a> Evaluator<'a> {
fn agg_append(&self, state: DataValue, value: DataValue) -> DataValue {
use Expr::*;
match self.node() {
RowCount => state.add(DataValue::Int32(1)),
Over([window, _, _]) => self.next(*window).agg_append(state, value),
RowCount | RowNumber => state.add(DataValue::Int32(1)),
Count(_) => state.add(DataValue::Int32(!value.is_null() as _)),
Sum(_) => state.add(value),
Min(_) => state.min(value),
Expand Down
7 changes: 7 additions & 0 deletions src/executor_v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ use self::simple_agg::*;
use self::table_scan::*;
use self::top_n::TopNExecutor;
use self::values::*;
use self::window::*;
use crate::array::DataChunk;
use crate::catalog::RootCatalogRef;
use crate::planner::{Expr, RecExpr, TypeSchemaAnalysis};
Expand Down Expand Up @@ -75,6 +76,7 @@ mod simple_agg;
mod table_scan;
mod top_n;
mod values;
mod window;

/// Join types for generating join code during the compilation.
#[derive(Copy, Clone, Eq, PartialEq)]
Expand Down Expand Up @@ -297,6 +299,11 @@ impl<S: Storage> Builder<S> {
.execute(self.build_id(child))
}
}
Window([exprs, child]) => WindowExecutor {
exprs: self.resolve_column_index(exprs, child),
types: self.plan_types(exprs).to_vec(),
}
.execute(self.build_id(child)),

CreateTable(plan) => CreateTableExecutor {
plan,
Expand Down
33 changes: 33 additions & 0 deletions src/executor_v2/window.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// Copyright 2023 RisingLight Project Authors. Licensed under Apache-2.0.

use super::*;
use crate::array::DataChunkBuilder;

/// The executor of window functions.
///
/// Each chunk it yields is an input chunk with the window function columns
/// concatenated to it.
pub struct WindowExecutor {
/// A list of over window functions.
///
/// e.g. `(list (over (lag #0) list list))`
pub exprs: RecExpr,
/// The types of window function columns.
///
/// Used as the column types of the chunk builder that collects the window results.
pub types: Vec<DataType>,
}

impl WindowExecutor {
/// Evaluates the window functions over every row produced by `child`.
///
/// For each input chunk, one row of window results is produced per input row and the
/// resulting chunk is concatenated to the input chunk before being yielded.
/// Errors from `child` are propagated.
///
/// NOTE(review): nothing here reads the partition or order keys of the `over` nodes;
/// the states simply accumulate over all rows in arrival order — confirm this is the
/// intended behavior of the "simple" window support.
#[try_stream(boxed, ok = DataChunk, error = ExecutorError)]
pub async fn execute(self, child: BoxedExecutor) {
// States are created once, outside the loop, so they carry over across chunks.
let mut states = Evaluator::new(&self.exprs).init_agg_states::<Vec<_>>();

#[for_await]
for chunk in child {
let chunk = chunk?;
// Capacity is one more than the number of rows pushed below.
// NOTE(review): presumably so that `push_row` never hands back a full chunk,
// since its return value is discarded — confirm against `DataChunkBuilder`.
let mut builder = DataChunkBuilder::new(&self.types, chunk.cardinality() + 1);
for i in 0..chunk.cardinality() {
// Feed the row into the running states, then emit a snapshot of them as this row's output.
Evaluator::new(&self.exprs).agg_list_append(&mut states, chunk.row(i).values());
_ = builder.push_row(states.clone());
}
// NOTE(review): `unwrap` assumes `take` returns a chunk even when the input chunk
// has zero rows — TODO confirm.
let window_chunk = builder.take().unwrap();
yield chunk.row_concat(window_chunk);
}
}
}
26 changes: 19 additions & 7 deletions src/planner/explain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,22 +162,29 @@ impl<'a> Explain<'a> {
],
),
Field(field) => Pretty::display(field),
Replace([a, b, c]) => Pretty::childless_record(
"Replace",
vec![
("in", self.expr(a).pretty()),
("from", self.expr(b).pretty()),
("to", self.expr(c).pretty()),
],
),

// aggregations
RowCount => "rowcount".into(),
RowCount | RowNumber => enode.to_string().into(),
Max(a) | Min(a) | Sum(a) | Avg(a) | Count(a) | First(a) | Last(a) => {
let name = enode.to_string();
let v = vec![self.expr(a).pretty()];
Pretty::fieldless_record(name, v)
}

Replace([a, b, c]) => Pretty::childless_record(
"Replace",
// Operand order is (over window_function [partition_key..] [order_key..]):
// the partition keys come before the order keys.
Over([f, partitionby, orderby]) => Pretty::simple_record(
"Over",
vec![
("in", self.expr(a).pretty()),
("from", self.expr(b).pretty()),
("to", self.expr(c).pretty()),
("order_by", self.expr(orderby).pretty()),
("partition_by", self.expr(partitionby).pretty()),
],
vec![self.expr(f).pretty()],
),

Exists(a) => {
Expand Down Expand Up @@ -279,6 +286,11 @@ impl<'a> Explain<'a> {
.with_cost(cost),
vec![self.child(child).pretty()],
),
Window([windows, child]) => Pretty::simple_record(
"Window",
vec![("windows", self.expr(windows).pretty())].with_cost(cost),
vec![self.child(child).pretty()],
),
CreateTable(t) => {
let fields = t.pretty_table().with_cost(cost);
Pretty::childless_record("CreateTable", fields)
Expand Down
25 changes: 22 additions & 3 deletions src/planner/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ define_language! {
// functions
"extract" = Extract([Id; 2]), // (extract field expr)
Field(DateTimeField),
"replace" = Replace([Id; 3]), // (replace expr pattern replacement)

// aggregations
"max" = Max(Id),
Expand All @@ -75,8 +76,11 @@ define_language! {
"rowcount" = RowCount,
"first" = First(Id),
"last" = Last(Id),

"replace" = Replace([Id; 3]),
// window functions
"over" = Over([Id; 3]), // (over window_function [partition_key..] [order_key..])
// TODO: support frame clause
// "range" = Range([Id; 2]), // (range start end)
"row_number" = RowNumber,

// subquery related
"exists" = Exists(Id),
Expand All @@ -102,8 +106,10 @@ define_language! {
"right_outer" = RightOuter,
"full_outer" = FullOuter,
"agg" = Agg([Id; 3]), // (agg aggs=[expr..] group_keys=[expr..] child)
// expressions must be agg
// expressions must be aggregate functions
// output = aggs || group_keys
"window" = Window([Id; 2]), // (window [over..] child)
// output = child || exprs
CreateTable(CreateTable),
Drop(BoundDrop),
"insert" = Insert([Id; 3]), // (insert table [column..] child)
Expand Down Expand Up @@ -196,6 +202,19 @@ impl Expr {
_ => return None,
})
}

/// Returns `true` if this node is one of the aggregate functions:
/// `RowCount`, `Max`, `Min`, `Sum`, `Avg`, `Count`, `First` or `Last`.
pub const fn is_aggregate_function(&self) -> bool {
use Expr::*;
matches!(
self,
RowCount | Max(_) | Min(_) | Sum(_) | Avg(_) | Count(_) | First(_) | Last(_)
)
}

/// Returns `true` if this node can be used as a window function:
/// either `RowNumber` or any aggregate function (see `is_aggregate_function`).
pub const fn is_window_function(&self) -> bool {
use Expr::*;
matches!(self, RowNumber) || self.is_aggregate_function()
}
}

trait ExprExt {
Expand Down
24 changes: 19 additions & 5 deletions src/planner/rules/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,24 @@ pub type AggSet = Vec<Expr>;
/// Note: if there is an agg over agg, e.g. `sum(count(a))`, only the upper one will be returned.
pub fn analyze_aggs(enode: &Expr, x: impl Fn(&Id) -> AggSet) -> AggSet {
use Expr::*;
if let RowCount | Max(_) | Min(_) | Sum(_) | Avg(_) | Count(_) | First(_) | Last(_) = enode {
return vec![enode.clone()];
match enode {
_ if enode.is_aggregate_function() => vec![enode.clone()],
Over(_) | Ref(_) => vec![],
// merge the set from all children
_ => enode.children().iter().flat_map(x).collect(),
}
}

/// The data type of over analysis.
pub type OverSet = Vec<Expr>;

/// Returns all over nodes in the tree.
pub fn analyze_overs(enode: &Expr, x: impl Fn(&Id) -> OverSet) -> OverSet {
use Expr::*;
match enode {
Over(_) => vec![enode.clone()],
Ref(_) => vec![],
// merge the set from all children
_ => enode.children().iter().flat_map(x).collect(),
}
// merge the set from all children
// TODO: ignore plan nodes
enode.children().iter().flat_map(x).collect()
}
Loading

0 comments on commit 604b4a1

Please sign in to comment.