diff --git a/crates/proof-of-sql/src/base/database/accessor.rs b/crates/proof-of-sql/src/base/database/accessor.rs index 45817b013..d9d7dcb50 100644 --- a/crates/proof-of-sql/src/base/database/accessor.rs +++ b/crates/proof-of-sql/src/base/database/accessor.rs @@ -1,6 +1,7 @@ use crate::base::{ commitment::Commitment, - database::{Column, ColumnRef, ColumnType, TableRef}, + database::{Column, ColumnRef, ColumnType, Table, TableOptions, TableRef}, + map::{IndexMap, IndexSet}, scalar::Scalar, }; use alloc::vec::Vec; @@ -85,6 +86,28 @@ pub trait CommitmentAccessor: MetadataAccessor { pub trait DataAccessor: MetadataAccessor { /// Return the data span in the table (not the full-table data) fn get_column(&self, column: ColumnRef) -> Column; + + /// Creates a new [`Table`] from a [`TableRef`] and [`ColumnRef`]s. + /// + /// Columns are retrieved from the [`DataAccessor`] using the provided [`TableRef`] and [`ColumnRef`]s. + /// The only reason why [`table_ref` is needed is because `column_refs` can be empty. + /// # Panics + /// Column length mismatches can occur in theory. In practice, this should not happen. + fn get_table(&self, table_ref: TableRef, column_refs: &IndexSet) -> Table { + if column_refs.is_empty() { + let input_length = self.get_length(table_ref); + Table::::try_new_with_options( + IndexMap::default(), + TableOptions::new(Some(input_length)), + ) + } else { + Table::::try_from_iter(column_refs.into_iter().map(|column_ref| { + let column = self.get_column(*column_ref); + (column_ref.column_id(), column) + })) + } + .expect("Failed to create table from table and column references") + } } /// Access tables and their schemas in a database. diff --git a/crates/proof-of-sql/src/base/database/mod.rs b/crates/proof-of-sql/src/base/database/mod.rs index a8cc8fcd2..82047c320 100644 --- a/crates/proof-of-sql/src/base/database/mod.rs +++ b/crates/proof-of-sql/src/base/database/mod.rs @@ -65,9 +65,9 @@ mod owned_table_test; pub mod owned_table_utility; mod table; -pub use table::Table; #[cfg(test)] pub(crate) use table::TableError; +pub use table::{Table, TableOptions}; #[cfg(test)] mod table_test; pub mod table_utility; diff --git a/crates/proof-of-sql/src/base/database/table.rs b/crates/proof-of-sql/src/base/database/table.rs index 29aebaf36..f2ee2648c 100644 --- a/crates/proof-of-sql/src/base/database/table.rs +++ b/crates/proof-of-sql/src/base/database/table.rs @@ -1,19 +1,44 @@ -use super::{Column, ColumnRef, DataAccessor, TableRef}; -use crate::base::{ - map::{IndexMap, IndexSet}, - scalar::Scalar, -}; -use alloc::vec; -use bumpalo::Bump; +use super::Column; +use crate::base::{map::IndexMap, scalar::Scalar}; use proof_of_sql_parser::Identifier; use snafu::Snafu; +/// Options for creating a table. +/// Inspired by [`RecordBatchOptions`](https://docs.rs/arrow/latest/arrow/record_batch/struct.RecordBatchOptions.html) +#[derive(Debug, Clone, Copy)] +pub struct TableOptions { + /// The number of rows in the table. Mostly useful for tables without columns. + pub row_count: Option, +} + +impl Default for TableOptions { + fn default() -> Self { + Self { row_count: None } + } +} + +impl TableOptions { + /// Creates a new [`TableOptions`]. + #[must_use] + pub fn new(row_count: Option) -> Self { + Self { row_count } + } +} + /// An error that occurs when working with tables. #[derive(Snafu, Debug, PartialEq, Eq)] pub enum TableError { /// The columns have different lengths. #[snafu(display("Columns have different lengths"))] ColumnLengthMismatch, + + /// At least one column has length different from the provided row count. + #[snafu(display("Column has length different from the provided row count"))] + ColumnLengthMismatchWithSpecifiedRowCount, + + /// The table is empty and there is no specified row count. + #[snafu(display("Table is empty and no row count is specified"))] + EmptyTableWithoutSpecifiedRowCount, } /// A table of data, with schema included. This is simply a map from `Identifier` to `Column`, /// where columns order matters. @@ -23,56 +48,61 @@ pub enum TableError { #[derive(Debug, Clone, Eq)] pub struct Table<'a, S: Scalar> { table: IndexMap>, - num_rows: usize, + row_count: usize, } impl<'a, S: Scalar> Table<'a, S> { - /// Creates a new [`Table`]. + /// Creates a new [`Table`] with the given columns and default [`TableOptions`]. pub fn try_new(table: IndexMap>) -> Result { - if table.is_empty() { - // `EmptyExec` should have one row for queries such as `SELECT 1`. - return Ok(Self { table, num_rows: 1 }); - } - let num_rows = table[0].len(); - if table.values().any(|column| column.len() != num_rows) { - Err(TableError::ColumnLengthMismatch) - } else { - Ok(Self { table, num_rows }) + Self::try_new_with_options(table, TableOptions::default()) + } + + /// Creates a new [`Table`] with the given columns and with [`TableOptions`]. + pub fn try_new_with_options( + table: IndexMap>, + options: TableOptions, + ) -> Result { + match (table.is_empty(), options.row_count) { + (true, None) => Err(TableError::EmptyTableWithoutSpecifiedRowCount), + (true, Some(row_count)) => Ok(Self { + table, + row_count, + }), + (false, None) => { + let row_count = table[0].len(); + if table.values().any(|column| column.len() != row_count) { + Err(TableError::ColumnLengthMismatch) + } else { + Ok(Self { table, row_count }) + } + } + (false, Some(row_count)) => { + if table.values().any(|column| column.len() != row_count) { + Err(TableError::ColumnLengthMismatchWithSpecifiedRowCount) + } else { + Ok(Self { + table, + row_count, + }) + } + } } } - /// Creates a new [`Table`]. + + /// Creates a new [`Table`] from an iterator of `(Identifier, Column)` pairs with default [`TableOptions`]. pub fn try_from_iter)>>( iter: T, ) -> Result { - Self::try_new(IndexMap::from_iter(iter)) + Self::try_from_iter_with_options(iter, TableOptions::default()) } - /// Creates a new [`Table`] from a [`DataAccessor`], [`TableRef`] and [`ColumnRef`]s. - /// - /// Columns are retrieved from the [`DataAccessor`] using the provided [`ColumnRef`]s. - /// # Panics - /// Missing columns or column length mismatches can occur if the accessor doesn't - /// contain the necessary columns. In practice, this should not happen. - pub(crate) fn from_columns( - column_refs: &IndexSet, - table_ref: TableRef, - accessor: &'a dyn DataAccessor, - alloc: &'a Bump, - ) -> Self { - if column_refs.is_empty() { - // TODO: Currently we have to have non-empty column references to have a non-empty table - // to evaluate `ProofExpr`s on. Once we restrict [`DataAccessor`] to [`TableExec`] - // and use input `DynProofPlan`s we should no longer need this. - let input_length = accessor.get_length(table_ref); - let bogus_vec = vec![true; input_length]; - let bogus_col = Column::Boolean(alloc.alloc_slice_copy(&bogus_vec)); - Table::<'a, S>::try_from_iter(core::iter::once(("bogus".parse().unwrap(), bogus_col))) - } else { - Table::<'a, S>::try_from_iter(column_refs.into_iter().map(|column_ref| { - let column = accessor.get_column(*column_ref); - (column_ref.column_id(), column) - })) - } - .expect("Failed to create table from column references") + + /// Creates a new [`Table`] from an iterator of `(Identifier, Column)` pairs with [`TableOptions`]. + pub fn try_from_iter_with_options)>>( + iter: T, + options: TableOptions, + ) -> Result { + Self::try_new_with_options(IndexMap::from_iter(iter), options) } + /// Number of columns in the table. #[must_use] pub fn num_columns(&self) -> usize { @@ -81,7 +111,7 @@ impl<'a, S: Scalar> Table<'a, S> { /// Number of rows in the table. #[must_use] pub fn num_rows(&self) -> usize { - self.num_rows + self.row_count } /// Whether the table has no columns. #[must_use] diff --git a/crates/proof-of-sql/src/base/database/table_test.rs b/crates/proof-of-sql/src/base/database/table_test.rs index 179dff057..f61c67efc 100644 --- a/crates/proof-of-sql/src/base/database/table_test.rs +++ b/crates/proof-of-sql/src/base/database/table_test.rs @@ -1,6 +1,6 @@ use crate::base::{ - database::{table_utility::*, Column, Table, TableError}, - map::IndexMap, + database::{table_utility::*, Column, Table, TableError, TableOptions}, + map::{indexmap, IndexMap}, scalar::test_scalar::TestScalar, }; use bumpalo::Bump; @@ -10,13 +10,104 @@ use proof_of_sql_parser::{ }; #[test] -fn we_can_create_a_table_with_no_columns() { - let table = Table::::try_new(IndexMap::default()).unwrap(); +fn we_can_create_a_table_with_no_columns_specifying_row_count() { + let table = + Table::::try_new_with_options(IndexMap::default(), TableOptions::new(Some(1))) + .unwrap(); assert_eq!(table.num_columns(), 0); assert_eq!(table.num_rows(), 1); + + let table = + Table::::try_new_with_options(IndexMap::default(), TableOptions::new(Some(0))) + .unwrap(); + assert_eq!(table.num_columns(), 0); + assert_eq!(table.num_rows(), 0); +} + +#[test] +fn we_can_create_a_table_with_default_options() { + let table = Table::::try_new(indexmap! { + "a".parse().unwrap() => Column::BigInt(&[0, 1]), + "b".parse().unwrap() => Column::Int128(&[0, 1]), + }) + .unwrap(); + assert_eq!(table.num_columns(), 2); + assert_eq!(table.num_rows(), 2); + + let table = Table::::try_new(indexmap! { + "a".parse().unwrap() => Column::BigInt(&[]), + "b".parse().unwrap() => Column::Int128(&[]), + }) + .unwrap(); + assert_eq!(table.num_columns(), 2); + assert_eq!(table.num_rows(), 0); +} + +#[test] +fn we_can_create_a_table_with_specified_row_count() { + let table = Table::::try_new_with_options( + indexmap! { + "a".parse().unwrap() => Column::BigInt(&[0, 1]), + "b".parse().unwrap() => Column::Int128(&[0, 1]), + }, + TableOptions::new(Some(2)), + ) + .unwrap(); + assert_eq!(table.num_columns(), 2); + assert_eq!(table.num_rows(), 2); + + let table = Table::::try_new_with_options( + indexmap! { + "a".parse().unwrap() => Column::BigInt(&[]), + "b".parse().unwrap() => Column::Int128(&[]), + }, + TableOptions::new(Some(0)), + ) + .unwrap(); + assert_eq!(table.num_columns(), 2); + assert_eq!(table.num_rows(), 0); +} + +#[test] +fn we_cannot_create_a_table_with_differing_column_lengths() { + assert!(matches!( + Table::::try_from_iter([ + ("a".parse().unwrap(), Column::BigInt(&[0])), + ("b".parse().unwrap(), Column::BigInt(&[])), + ]), + Err(TableError::ColumnLengthMismatch) + )); } + #[test] -fn we_can_create_an_empty_table() { +fn we_cannot_create_a_table_with_column_length_different_from_specified_row_count() { + assert!(matches!( + Table::::try_from_iter_with_options( + [ + ("a".parse().unwrap(), Column::BigInt(&[0])), + ("b".parse().unwrap(), Column::BigInt(&[1])), + ], + TableOptions::new(Some(0)) + ), + Err(TableError::ColumnLengthMismatchWithSpecifiedRowCount) + )); +} + +#[test] +fn we_cannot_create_a_table_with_no_columns_without_specified_row_count() { + assert!(matches!( + Table::::try_from_iter_with_options([], TableOptions::new(None)), + Err(TableError::EmptyTableWithoutSpecifiedRowCount) + )); + + assert!(matches!( + Table::::try_new(IndexMap::default()), + Err(TableError::EmptyTableWithoutSpecifiedRowCount) + )); +} + +#[test] +fn we_can_create_an_empty_table_with_some_columns() { let alloc = Bump::new(); let borrowed_table = table::([ borrowed_bigint("bigint", [0; 0], &alloc), @@ -193,14 +284,3 @@ fn we_get_inequality_between_tables_with_differing_data() { assert_ne!(table_a, table_b); } - -#[test] -fn we_cannot_create_a_table_with_differing_column_lengths() { - assert!(matches!( - Table::::try_from_iter([ - ("a".parse().unwrap(), Column::BigInt(&[0])), - ("b".parse().unwrap(), Column::BigInt(&[])), - ]), - Err(TableError::ColumnLengthMismatch) - )); -} diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs index 3cd87dac0..b8380c87f 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec.rs @@ -3,7 +3,7 @@ use crate::{ base::{ database::{ filter_util::filter_columns, Column, ColumnField, ColumnRef, DataAccessor, OwnedTable, - Table, TableRef, + TableRef, }, map::{IndexMap, IndexSet}, proof::ProofError, @@ -143,8 +143,7 @@ impl ProverEvaluate for FilterExec { accessor: &'a dyn DataAccessor, ) -> Vec> { let column_refs = self.get_column_references(); - let used_table = - Table::<'a, S>::from_columns(&column_refs, self.table.table_ref, accessor, alloc); + let used_table = accessor.get_table(self.table.table_ref, &column_refs); // 1. selection let selection_column: Column<'a, S> = self.where_clause.result_evaluate(alloc, &used_table); let selection = selection_column @@ -176,8 +175,7 @@ impl ProverEvaluate for FilterExec { accessor: &'a dyn DataAccessor, ) -> Vec> { let column_refs = self.get_column_references(); - let used_table = - Table::<'a, S>::from_columns(&column_refs, self.table.table_ref, accessor, alloc); + let used_table = accessor.get_table(self.table.table_ref, &column_refs); // 1. selection let selection_column: Column<'a, S> = self.where_clause diff --git a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs index 56f7fac9e..07f18b4bf 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/filter_exec_test_dishonest_prover.rs @@ -3,7 +3,7 @@ use crate::{ base::{ database::{ filter_util::*, owned_table_utility::*, Column, DataAccessor, OwnedTableTestAccessor, - Table, TestAccessor, + TestAccessor, }, proof::ProofError, scalar::Scalar, @@ -39,8 +39,7 @@ impl ProverEvaluate for DishonestFilterExec { accessor: &'a dyn DataAccessor, ) -> Vec> { let column_refs = self.get_column_references(); - let used_table = - Table::<'a, S>::from_columns(&column_refs, self.table.table_ref, accessor, alloc); + let used_table = accessor.get_table(self.table.table_ref, &column_refs); // 1. selection let selection_column: Column<'a, S> = self.where_clause.result_evaluate(alloc, &used_table); let selection = selection_column @@ -75,8 +74,7 @@ impl ProverEvaluate for DishonestFilterExec { accessor: &'a dyn DataAccessor, ) -> Vec> { let column_refs = self.get_column_references(); - let used_table = - Table::<'a, S>::from_columns(&column_refs, self.table.table_ref, accessor, alloc); + let used_table = accessor.get_table(self.table.table_ref, &column_refs); // 1. selection let selection_column: Column<'a, S> = self.where_clause diff --git a/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs index a3ad2e8a0..0773da752 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/group_by_exec.rs @@ -5,7 +5,7 @@ use crate::{ group_by_util::{ aggregate_columns, compare_indexes_by_owned_columns, AggregatedColumns, }, - Column, ColumnField, ColumnRef, ColumnType, DataAccessor, OwnedTable, Table, TableRef, + Column, ColumnField, ColumnRef, ColumnType, DataAccessor, OwnedTable, TableRef, }, map::{IndexMap, IndexSet}, proof::ProofError, @@ -202,8 +202,7 @@ impl ProverEvaluate for GroupByExec { accessor: &'a dyn DataAccessor, ) -> Vec> { let column_refs = self.get_column_references(); - let used_table = - Table::<'a, S>::from_columns(&column_refs, self.table.table_ref, accessor, alloc); + let used_table = accessor.get_table(self.table.table_ref, &column_refs); // 1. selection let selection_column: Column<'a, S> = self.where_clause.result_evaluate(alloc, &used_table); @@ -251,8 +250,7 @@ impl ProverEvaluate for GroupByExec { accessor: &'a dyn DataAccessor, ) -> Vec> { let column_refs = self.get_column_references(); - let used_table = - Table::<'a, S>::from_columns(&column_refs, self.table.table_ref, accessor, alloc); + let used_table = accessor.get_table(self.table.table_ref, &column_refs); // 1. selection let selection_column: Column<'a, S> = self.where_clause diff --git a/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs b/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs index 326158c76..6c90b2a47 100644 --- a/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs +++ b/crates/proof-of-sql/src/sql/proof_plans/projection_exec.rs @@ -1,6 +1,6 @@ use crate::{ base::{ - database::{Column, ColumnField, ColumnRef, DataAccessor, OwnedTable, Table, TableRef}, + database::{Column, ColumnField, ColumnRef, DataAccessor, OwnedTable, TableRef}, map::{IndexMap, IndexSet}, proof::ProofError, scalar::Scalar, @@ -91,8 +91,7 @@ impl ProverEvaluate for ProjectionExec { accessor: &'a dyn DataAccessor, ) -> Vec> { let column_refs = self.get_column_references(); - let used_table = - Table::<'a, S>::from_columns(&column_refs, self.table.table_ref, accessor, alloc); + let used_table = accessor.get_table(self.table.table_ref, &column_refs); let columns: Vec<_> = self .aliased_results .iter() @@ -116,8 +115,7 @@ impl ProverEvaluate for ProjectionExec { accessor: &'a dyn DataAccessor, ) -> Vec> { let column_refs = self.get_column_references(); - let used_table = - Table::<'a, S>::from_columns(&column_refs, self.table.table_ref, accessor, alloc); + let used_table = accessor.get_table(self.table.table_ref, &column_refs); // 1. Evaluate result expressions let res: Vec<_> = self .aliased_results