Skip to content

Commit

Permalink
fix(sql): table aliases (#11988)
Browse files Browse the repository at this point in the history
Co-authored-by: cory.grinstead <[email protected]>
  • Loading branch information
universalmind303 and cgrins authored Oct 24, 2023
1 parent 60adaef commit 58511d4
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 9 deletions.
21 changes: 18 additions & 3 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub struct SQLContext {
pub(crate) table_map: PlHashMap<String, LazyFrame>,
pub(crate) function_registry: Arc<dyn FunctionRegistry>,
cte_map: RefCell<PlHashMap<String, LazyFrame>>,
aliases: RefCell<PlHashMap<String, String>>,
}

impl Default for SQLContext {
Expand All @@ -33,6 +34,7 @@ impl Default for SQLContext {
function_registry: Arc::new(DefaultFunctionRegistry {}),
table_map: Default::default(),
cte_map: Default::default(),
aliases: Default::default(),
}
}
}
Expand Down Expand Up @@ -113,6 +115,7 @@ impl SQLContext {
let res = self.execute_statement(ast.get(0).unwrap());
// Every execution should clear the CTE map.
self.cte_map.borrow_mut().clear();
self.aliases.borrow_mut().clear();
res
}

Expand All @@ -139,9 +142,16 @@ impl SQLContext {
self.cte_map.borrow_mut().insert(name.to_owned(), lf);
}

fn get_table_from_current_scope(&self, name: &str) -> Option<LazyFrame> {
pub(super) fn get_table_from_current_scope(&self, name: &str) -> Option<LazyFrame> {
let table_name = self.table_map.get(name).cloned();
table_name.or_else(|| self.cte_map.borrow().get(name).cloned())
table_name
.or_else(|| self.cte_map.borrow().get(name).cloned())
.or_else(|| {
self.aliases
.borrow()
.get(name)
.and_then(|alias| self.table_map.get(alias).cloned())
})
}

pub(crate) fn execute_statement(&mut self, stmt: &Statement) -> PolarsResult<LazyFrame> {
Expand Down Expand Up @@ -574,7 +584,12 @@ impl SQLContext {
let tbl_name = name.0.get(0).unwrap().value.as_str();
if let Some(lf) = self.get_table_from_current_scope(tbl_name) {
match alias {
Some(alias) => Ok((alias.to_string(), lf)),
Some(alias) => {
self.aliases
.borrow_mut()
.insert(alias.name.value.clone(), tbl_name.to_string());
Ok((alias.to_string(), lf))
},
None => Ok((tbl_name.to_string(), lf)),
}
} else {
Expand Down
15 changes: 9 additions & 6 deletions crates/polars-sql/src/sql_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,15 @@ impl SqlExprVisitor<'_> {
fn visit_compound_identifier(&self, idents: &[Ident]) -> PolarsResult<Expr> {
match idents {
[tbl_name, column_name] => {
let lf = self.ctx.table_map.get(&tbl_name.value).ok_or_else(|| {
polars_err!(
ComputeError: "no table named '{}' found",
tbl_name
)
})?;
let lf = self
.ctx
.get_table_from_current_scope(&tbl_name.value)
.ok_or_else(|| {
polars_err!(
ComputeError: "no table or alias named '{}' found",
tbl_name
)
})?;

let schema = lf.schema()?;
if let Some((_, name, _)) = schema.get_full(&column_name.value) {
Expand Down
13 changes: 13 additions & 0 deletions crates/polars-sql/tests/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,19 @@ fn create_ctx() -> SQLContext {
ctx
}

#[test]
fn tbl_alias() {
let mut ctx = create_ctx();
let sql = r#"
SELECT
tbl.a,
tbl.b,
FROM df as tbl
"#;
let actual = ctx.execute(sql);
assert!(actual.is_ok());
}

#[test]
fn trailing_commas_allowed() {
let mut ctx = create_ctx();
Expand Down

0 comments on commit 58511d4

Please sign in to comment.