From 6717d5fd250af859de0385eaaf09c0e96db2a343 Mon Sep 17 00:00:00 2001
From: Jackson Newhouse
Date: Thu, 7 Dec 2023 15:56:00 -0800
Subject: [PATCH] check logical plan output schemas for supported types.

---
Reviewer notes (this section is not part of the commit message):
- insert_sql_plan now converts the plan's output schema into a StructDef
  before dispatching on the plan node, so a schema containing an
  unsupported Arrow type is returned as an error from planning.
- TypeDef::try_from_arrow accepts the primitive, binary, string and
  timezone-less timestamp types, recurses into Struct fields, accepts a
  List only when its element maps to TypeDef::DataType, and bails on
  every other DataType variant (Duration, Interval, Date32/64, ...).
- The new test durations_error_out expects planning to fail for a query
  that subtracts a DATE literal from bid.datetime.

 arroyo-sql/src/pipeline.rs |  2 +
 arroyo-sql/src/test.rs     | 17 ++++++++
 arroyo-sql/src/types.rs    | 88 +++++++++++++++++++++++++++++++++++++-
 3 files changed, 106 insertions(+), 1 deletion(-)

diff --git a/arroyo-sql/src/pipeline.rs b/arroyo-sql/src/pipeline.rs
index 0059f6557..3a6f1dc08 100644
--- a/arroyo-sql/src/pipeline.rs
+++ b/arroyo-sql/src/pipeline.rs
@@ -437,6 +437,8 @@ impl<'a> SqlPipelineBuilder<'a> {
     }
 
     pub fn insert_sql_plan(&mut self, plan: &LogicalPlan) -> Result<SqlOperator> {
+        // Check that output types are supported by Arroyo to avoid compile-time errors
+        let _output_struct_def: StructDef = plan.schema().clone().try_into()?;
         match plan {
             LogicalPlan::Projection(projection) => self.insert_projection(projection),
             LogicalPlan::Filter(filter) => self.insert_filter(filter),
diff --git a/arroyo-sql/src/test.rs b/arroyo-sql/src/test.rs
index 1265ea68f..c7361a7dd 100644
--- a/arroyo-sql/src/test.rs
+++ b/arroyo-sql/src/test.rs
@@ -185,6 +185,23 @@ async fn test_no_inserting_updates_into_non_updating() {
         .unwrap_err();
 }
 
+#[tokio::test]
+async fn durations_error_out() {
+    let schema_provider = get_test_schema_provider();
+    let sql = "create table nexmark with (
+        connector = 'nexmark',
+        event_rate = '5'
+    );
+
+    select bid.datetime - DATE '2023-12-03'
+    from nexmark
+    group by 1;
+    ";
+    let _ = parse_and_get_program(sql, schema_provider, SqlConfig::default())
+        .await
+        .unwrap_err();
+}
+
 #[tokio::test]
 async fn test_no_aggregates_in_window() {
     let schema_provider = get_test_schema_provider();
diff --git a/arroyo-sql/src/types.rs b/arroyo-sql/src/types.rs
index ce179d56f..d9024afd9 100644
--- a/arroyo-sql/src/types.rs
+++ b/arroyo-sql/src/types.rs
@@ -23,7 +23,7 @@ use crate::avro;
 use arroyo_rpc::api_types::connections::{
     FieldType, PrimitiveType, SourceField, SourceFieldType, StructType,
 };
-use datafusion_common::ScalarValue;
+use datafusion_common::{DFField, DFSchemaRef, ScalarValue};
 use proc_macro2::{Ident, TokenStream};
 use quote::{format_ident, quote};
 use regex::Regex;
@@ -43,6 +43,18 @@ pub struct StructPair {
     pub right: StructDef,
 }
 
+impl TryFrom<DFSchemaRef> for StructDef {
+    type Error = anyhow::Error;
+    fn try_from(schema: DFSchemaRef) -> Result<Self> {
+        let struct_fields: Vec<StructField> = schema
+            .fields()
+            .iter()
+            .map(|field| field.try_into())
+            .collect::<Result<Vec<_>>>()?;
+        Ok(Self::for_fields(struct_fields))
+    }
+}
+
 impl StructDef {
     pub fn for_fields(fields: Vec<StructField>) -> Self {
         Self {
@@ -569,6 +581,18 @@ pub struct StructField {
     pub original_type: Option<String>,
 }
 
+impl TryFrom<&DFField> for StructField {
+    type Error = anyhow::Error;
+
+    fn try_from(value: &DFField) -> Result<Self> {
+        Ok(StructField::new(
+            value.name().to_string(),
+            value.qualifier().map(|qualifier| qualifier.to_string()),
+            TypeDef::try_from_arrow(value.data_type(), value.is_nullable())?,
+        ))
+    }
+}
+
 impl StructField {
     pub fn new(name: String, alias: Option<String>, data_type: TypeDef) -> Self {
         if let TypeDef::DataType(DataType::Struct(_), _) = &data_type {
@@ -985,6 +1009,68 @@ impl TypeDef {
             TypeDef::DataType(data_type, _) => TypeDef::DataType(data_type.clone(), nullity),
         }
     }
+    pub fn try_from_arrow(data_type: &DataType, nullable: bool) -> Result<Self> {
+        match data_type {
+            DataType::Null
+            | DataType::Boolean
+            | DataType::Int8
+            | DataType::Int16
+            | DataType::Int32
+            | DataType::Int64
+            | DataType::UInt8
+            | DataType::UInt16
+            | DataType::UInt32
+            | DataType::UInt64
+            | DataType::Float16
+            | DataType::Float32
+            | DataType::Float64
+            | DataType::Binary
+            | DataType::LargeBinary
+            | DataType::Utf8
+            | DataType::LargeUtf8
+            | DataType::Timestamp(_, None) => Ok(TypeDef::DataType(data_type.clone(), nullable)),
+
+            DataType::Timestamp(_, Some(_))
+            | DataType::Date32
+            | DataType::Date64
+            | DataType::Time32(_)
+            | DataType::Time64(_)
+            | DataType::Duration(_)
+            | DataType::FixedSizeBinary(_)
+            | DataType::Union(_, _)
+            | DataType::Dictionary(_, _)
+            | DataType::Decimal128(_, _)
+            | DataType::Decimal256(_, _)
+            | DataType::Map(_, _)
+            | DataType::RunEndEncoded(_, _)
+            | DataType::FixedSizeList(_, _)
+            | DataType::LargeList(_)
+            | DataType::Interval(_) => bail!("{:?} not supported as struct type", data_type),
+            DataType::Struct(fields) => Ok(TypeDef::StructDef(
+                StructDef::for_fields(
+                    fields
+                        .iter()
+                        .map(|field| {
+                            Ok(StructField::new(
+                                field.name().to_string(),
+                                None,
+                                Self::try_from_arrow(field.data_type(), field.is_nullable())?,
+                            ))
+                        })
+                        .collect::<Result<Vec<_>>>()?,
+                ),
+                nullable,
+            )),
+            DataType::List(field) => {
+                let TypeDef::DataType(..) =
+                    Self::try_from_arrow(field.data_type(), field.is_nullable())?
+                else {
+                    bail!("List contains unsupported data type {:?}", field);
+                };
+                Ok(TypeDef::DataType(data_type.clone(), nullable))
+            }
+        }
+    }
 }
 
 impl StructField {