Commit be2a415

more work

jacksonrnewhouse committed Dec 7, 2023
1 parent c219b5d commit be2a415
Showing 8 changed files with 210 additions and 17 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

32 changes: 32 additions & 0 deletions arroyo-datastream/src/lib.rs
@@ -388,6 +388,14 @@ pub enum Operator {
name: String,
config: Vec<u8>,
},
ArrowValue {
name: String,
config: Vec<u8>,
},
ArrowKey {
name: String,
config: Vec<u8>,
}
}

#[derive(Clone, Encode, Decode, Debug, Serialize, Deserialize, PartialEq, Eq)]
@@ -518,6 +526,8 @@ impl Debug for Operator {
Operator::ArrowProjection { name, config: _ } => {
write!(f, "arrow_projection<{}>", name)
}
Operator::ArrowValue { name, config: _ } => write!(f, "arrow_value<{}>", name),
Operator::ArrowKey { name, config: _ } => write!(f, "arrow_key<{}>", name),
}
}
}
@@ -1896,6 +1906,20 @@ impl Program {
ProjectionOperator::from_config(#name.to_string(), hex::decode(#hex_string).unwrap()).unwrap())
}
},
Operator::ArrowValue { name, config } => {
let hex_string = hex::encode(config);
quote! {
Box::new(arroyo_worker::arrow::
ProjectionOperator::from_config(#name.to_string(), hex::decode(#hex_string).unwrap()).unwrap())
}
},
Operator::ArrowKey { name, config } => {
let hex_string = hex::encode(config);
quote! {
Box::new(arroyo_worker::arrow::
ProjectionOperator::from_config(#name.to_string(), hex::decode(#hex_string).unwrap()).unwrap())
}
},
};

(node.operator_id.clone(), description, body, node.parallelism)
@@ -2260,6 +2284,12 @@ impl From<Operator> for GrpcApi::operator::Operator {
config: Some(config),
})
}
Operator::ArrowValue { name, config } => {
GrpcOperator::NamedOperator(GrpcApi::NamedOperator { name, operator: GrpcApi::OperatorName::ArrowValuePlan.into(), config: Some(config) })
},
Operator::ArrowKey { name, config } => {
GrpcOperator::NamedOperator(GrpcApi::NamedOperator { name, operator: GrpcApi::OperatorName::ArrowKeyPlan.into(), config: Some(config) })
},
}
}
}
@@ -2582,6 +2612,8 @@ impl TryFrom<arroyo_rpc::grpc::api::Operator> for Operator {
name,
config: named_operator.config.unwrap(),
},
GrpcApi::OperatorName::ArrowValuePlan => Operator::ArrowValue { name, config: named_operator.config.unwrap() },
GrpcApi::OperatorName::ArrowKeyPlan => Operator::ArrowKey { name, config: named_operator.config.unwrap() },
}
}
},
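
The new variants carry an opaque config blob. A minimal sketch of how an ArrowValue operator would be populated — the helper and the physical_plan_bytes parameter are assumptions for illustration; the ValuePlanOperator message is defined in api.proto below:

use prost::Message;

// Sketch: the config bytes for Operator::ArrowValue are a prost-encoded
// ValuePlanOperator. `physical_plan_bytes` is assumed to be a serialized
// DataFusion PhysicalPlanNode (see the api.proto sketch below).
fn arrow_value_op(physical_plan_bytes: Vec<u8>) -> Operator {
    let value_plan = arroyo_rpc::grpc::api::ValuePlanOperator {
        name: "value_plan".to_string(),
        physical_plan: physical_plan_bytes,
    };
    Operator::ArrowValue {
        name: "value_plan".to_string(),
        config: value_plan.encode_to_vec(),
    }
}
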
7 changes: 7 additions & 0 deletions arroyo-rpc/proto/api.proto
@@ -108,10 +108,17 @@ message ProjectionOperator {
bytes output_schema = 5;
}

message ValuePlanOperator {
string name = 1;
bytes physical_plan = 2;
}

enum OperatorName {
STRUCT_TO_RECORD_BATCH = 0;
RECORD_BATCH_TO_STRUCT = 1;
ARROW_PROJECTION = 2;
ARROW_VALUE_PLAN = 3;
ARROW_KEY_PLAN = 4;
}

message WasmUdfs {
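
The physical_plan field holds a serialized DataFusion physical plan. One plausible way to produce those bytes with datafusion-proto — a sketch under the assumption that a planned Arc<dyn ExecutionPlan> is already in hand; the helper name is ours, not this commit's:

use std::sync::Arc;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_proto::physical_plan::{AsExecutionPlan, DefaultPhysicalExtensionCodec};
use datafusion_proto::protobuf::PhysicalPlanNode;
use prost::Message;

// Serialize a physical plan into the bytes carried by ValuePlanOperator.
fn plan_to_bytes(plan: Arc<dyn ExecutionPlan>) -> datafusion_common::Result<Vec<u8>> {
    let node = PhysicalPlanNode::try_from_physical_plan(plan, &DefaultPhysicalExtensionCodec {})?;
    Ok(node.encode_to_vec())
}
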
73 changes: 63 additions & 10 deletions arroyo-sql/src/lib.rs
@@ -10,7 +10,7 @@ use arroyo_datastream::{Program, WindowType};

use arroyo_rpc::grpc::api::window;
use datafusion::physical_plan::functions::make_scalar_function;
use datafusion_common::{Column, OwnedTableReference, Result as DFResult};
use datafusion_common::{Column, OwnedTableReference, Result as DFResult, DFField};
pub mod avro;
pub(crate) mod code_gen;
pub mod expressions;
@@ -56,6 +56,7 @@ use arroyo_rpc::formats::{Format, JsonFormat};
use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError};
use prettyplease::unparse;
use regex::Regex;
use std::borrow::BorrowMut;
use std::collections::HashSet;
use std::fmt::Debug;

@@ -478,19 +479,63 @@ pub(crate) struct QueryToGraphVisitor {
table_source_to_nodes: HashMap<OwnedTableReference, NodeIndex>,
}

#[derive(Default)]
struct KeyTimestampRewriter {}

impl TreeNodeRewriter for KeyTimestampRewriter {
type N = LogicalPlan;

fn mutate(&mut self, mut node: Self::N) -> DFResult<Self::N> {
match node {
LogicalPlan::Projection(ref mut projection) => {
projection.schema = add_timestamp_field(projection.schema.clone())?;
projection.expr.push(Expr::Column(Column { relation: None, name: "_timestamp".to_string()}));
},
LogicalPlan::Aggregate(ref mut aggregate) => {
aggregate.schema = add_timestamp_field(aggregate.schema.clone())?;
},
LogicalPlan::Join(ref mut join) => {
join.schema = add_timestamp_field(join.schema.clone())?;
},
LogicalPlan::Union(ref mut union) => {
union.schema = add_timestamp_field(union.schema.clone())?;
},
LogicalPlan::TableScan(ref mut table_scan) => {
table_scan.projected_schema = add_timestamp_field(table_scan.projected_schema.clone())?;
},
LogicalPlan::SubqueryAlias(ref mut subquery_alias) => {
let timestamp_field = DFField::new(
Some(subquery_alias.alias.clone()),
"_timestamp",
DataType::Timestamp(TimeUnit::Nanosecond, None),
false,
);
subquery_alias.schema = Arc::new(subquery_alias.schema.join(
&DFSchema::new_with_metadata(vec![timestamp_field], HashMap::new())?,
)?);
}
_ => {}
}
Ok(node)
}
}

fn add_timestamp_field(schema: DFSchemaRef) -> DFResult<DFSchemaRef> {
let timestamp_field = DFField::new_unqualified(
"_timestamp",
DataType::Timestamp(TimeUnit::Nanosecond, None),
false,
);
Ok(Arc::new(schema.join(&DFSchema::new_with_metadata(
vec![timestamp_field],
HashMap::new(),
)?)?))
}
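
The rewriter plugs into DataFusion's TreeNode machinery; a minimal sketch of driving it, assuming a LogicalPlan in scope (rewrite_experiment below does the equivalent with QueryToGraphVisitor):

use datafusion_common::tree_node::TreeNode;

// Sketch: rewrite() walks the plan and calls KeyTimestampRewriter::mutate
// on each node, widening every schema with the _timestamp column.
fn add_timestamps(plan: LogicalPlan) -> DFResult<LogicalPlan> {
    plan.rewrite(&mut KeyTimestampRewriter::default())
}
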

#[derive(Debug)]
enum LogicalPlanExtension {
ValueCalculation(LogicalPlan),
KeyCalculation(LogicalPlan),
AggregateCalculation(AggregateCalculation),
Sink,
}

impl LogicalPlanExtension {
// used for finding input TableScans, if the variant already manually crafts its edges, return None.
fn inner_logical_plan(&self) -> Option<&LogicalPlan> {
match self {
LogicalPlanExtension::ValueCalculation(inner_plan)
| LogicalPlanExtension::KeyCalculation(inner_plan) => Some(inner_plan),
LogicalPlanExtension::AggregateCalculation(_) => None,
LogicalPlanExtension::Sink => None,
}
}
fn outgoing_edge(&self) -> DataFusionEdge {
@@ -511,6 +556,7 @@ impl LogicalPlanExtension {
value_schema: aggregate_calculation.aggregate.schema.clone(),
key_schema: None,
},
LogicalPlanExtension::Sink => unreachable!()
}
}
}
@@ -767,6 +813,7 @@ pub fn rewrite_experiment(
)?);
};
}
let mut rewriter = QueryToGraphVisitor::default();
for insert in inserts {
let mut plan = match insert {
Insert::InsertQuery {
@@ -775,13 +822,19 @@ pub fn rewrite_experiment(
} => logical_plan,
Insert::Anonymous { logical_plan } => logical_plan,
};
let mut rewriter = QueryToGraphVisitor::default();
println!("plan {:?}", plan);
println!("plan {:?}\n\n", plan);
let plan_rewrite = plan.rewrite(&mut rewriter).unwrap();
println!("plan rewrite {:?}", plan_rewrite);
rewriter
let extended_plan_node = LogicalPlanExtension::ValueCalculation(plan_rewrite);
let edge = extended_plan_node.outgoing_edge();
let plan_index = rewriter
.local_logical_plan_graph
.add_node(LogicalPlanExtension::ValueCalculation(plan_rewrite));
.add_node(extended_plan_node);

let sink_index = rewriter.local_logical_plan_graph.add_node(LogicalPlanExtension::Sink);
rewriter.local_logical_plan_graph.add_edge(plan_index, sink_index, edge);


let mut edges = vec![];
for (node_index, node) in rewriter.local_logical_plan_graph.node_references() {
let Some(logical_plan) = node.inner_logical_plan() else {
@@ -811,12 +864,12 @@ pub fn rewrite_experiment(
for (a, b, weight) in edges {
rewriter.local_logical_plan_graph.add_edge(a, b, weight);
}
println!("rewriter: {:?}", rewriter);
println!(
"graph: {:?}",
petgraph::dot::Dot::with_config(&rewriter.local_logical_plan_graph, &[])
);
}
println!("rewriter: {:?}", rewriter);
println!(
"graph: {:?}",
petgraph::dot::Dot::with_config(&rewriter.local_logical_plan_graph, &[])
);
Ok(())
}

21 changes: 19 additions & 2 deletions arroyo-sql/src/plan_graph.rs
@@ -15,11 +15,11 @@ use datafusion::{
physical_plan::PhysicalExpr,
physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner},
};
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::{graph::{DiGraph, NodeIndex}, visit::Topo};
use quote::{quote, ToTokens};
use syn::{parse_quote, parse_str, Type};

use crate::expressions::AggregateComputation;
use crate::{expressions::AggregateComputation, QueryToGraphVisitor};
use crate::{
code_gen::{
BinAggregatingContext, CodeGenerator, CombiningContext, JoinListsContext, JoinPairContext,
@@ -2112,6 +2112,23 @@ impl From<PlanGraph> for DiGraph<StreamNode, StreamEdge> {
}
}

pub(crate) fn get_arrow_program(
mut rewriter: QueryToGraphVisitor,
schema_provider: ArroyoSchemaProvider,
) -> Result<CompiledSql> {
let mut topo = Topo::new(&rewriter.local_logical_plan_graph);
let program_graph: DiGraph<StreamNode, StreamEdge> = DiGraph::new();
while let Some(node_index) = topo.next(&rewriter.local_logical_plan_graph) {
let logical_extension = rewriter.local_logical_plan_graph.node_weight(node_index).unwrap();
match logical_extension {
crate::LogicalPlanExtension::ValueCalculation(_logical_plan) => {}
crate::LogicalPlanExtension::KeyCalculation(_) => todo!(),
crate::LogicalPlanExtension::AggregateCalculation(_) => todo!(),
crate::LogicalPlanExtension::Sink => todo!(),
}
}
todo!()
}

pub fn get_program(
mut plan_graph: PlanGraph,
schema_provider: ArroyoSchemaProvider,
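
get_arrow_program above walks the graph with petgraph's Topo, which yields a node only after all of its inputs have been visited — that is what lets each LogicalPlanExtension be translated in a single pass. A self-contained sketch of the traversal order (node names illustrative):

use petgraph::graph::DiGraph;
use petgraph::visit::Topo;

fn main() {
    let mut g: DiGraph<&str, ()> = DiGraph::new();
    let scan = g.add_node("table_scan");
    let value = g.add_node("value_calculation");
    let sink = g.add_node("sink");
    g.add_edge(scan, value, ());
    g.add_edge(value, sink, ());

    // Prints table_scan, value_calculation, sink: every node appears
    // only after all of its inputs.
    let mut topo = Topo::new(&g);
    while let Some(idx) = topo.next(&g) {
        println!("{}", g[idx]);
    }
}
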
12 changes: 7 additions & 5 deletions arroyo-sql/src/test.rs
@@ -207,12 +207,14 @@ GROUP BY bidder, HOP(INTERVAL '3 second', INTERVAL '10' minute)) WHERE distinct_
#[tokio::test]
async fn test_query_parsing() {
let schema_provider = get_test_schema_provider();
let sql = "WITH bids as (
SELECT bid.auction as auction, bid.price as price, bid.bidder as bidder, bid.extra as extra, bid.datetime as datetime
FROM nexmark where bid is not null)
let sql = "
CREATE TABLE impulse WITH (
connector = 'impulse',
event_rate = '10'
);
SELECT * FROM (
SELECT bidder, hop(interval '3 second', interval '10 minute'), count(*) as bids from bids GROUP BY 1,2)
WHERE bids > 10";
SELECT counter, hop(interval '3 second', interval '10 minute'), count(*) as rows from impulse GROUP BY 1,2)
WHERE rows > 10";
rewrite_experiment(sql.to_string(), schema_provider, SqlConfig::default()).unwrap();
}

1 change: 1 addition & 0 deletions arroyo-worker/Cargo.toml
@@ -79,6 +79,7 @@ memchr = "2.6.3"
apache-avro = "0.16.0"
redis = { version = "0.23.3", features = ["default", "tokio-rustls-comp", "cluster-async", "connection-manager"] }

datafusion = "31.0"
datafusion-proto = "31.0.0"
datafusion-expr = "31.0.0"
datafusion-physical-expr = "31.0"
80 changes: 80 additions & 0 deletions arroyo-worker/src/arrow.rs
@@ -13,13 +13,22 @@ use arroyo_types::KeyValueTimestampRecordBatch;
use arroyo_types::KeyValueTimestampRecordBatchBuilder;
use arroyo_types::RecordBatchBuilder;
use arroyo_types::{Key, Record, RecordBatchData};
use datafusion::execution::context::SessionContext;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_execution::FunctionRegistry;
use datafusion_execution::TaskContext;
use datafusion_execution::runtime_env::RuntimeConfig;
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_expr::AggregateUDF;
use datafusion_expr::ScalarUDF;
use datafusion_expr::WindowUDF;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_proto::physical_plan::AsExecutionPlan;
use datafusion_proto::physical_plan::DefaultPhysicalExtensionCodec;
use datafusion_proto::physical_plan::from_proto::parse_physical_expr;
use datafusion_proto::protobuf::PhysicalExprNode;
use datafusion_proto::protobuf::PhysicalPlanNode;
use futures::StreamExt;
use prost::Message;
use std::collections::HashSet;
use std::marker::PhantomData;
@@ -120,6 +129,77 @@ impl ProcessFuncTrait for ProjectionOperator {
}
}

pub struct ValueExecutionOperator {
name: String,
execution_plan: Arc<dyn ExecutionPlan>,
}

impl ValueExecutionOperator {
pub fn from_config(name: String, config: Vec<u8>) -> Result<Self> {
let proto_config =
arroyo_rpc::grpc::api::ValuePlanOperator::decode(&mut config.as_slice()).unwrap();

let registry = Registry {};

let plan = PhysicalPlanNode::decode(&mut proto_config.physical_plan.as_slice()).unwrap();
let execution_plan = plan.try_into_physical_plan(
&registry,
&RuntimeEnv::new(RuntimeConfig::new())?,
&DefaultPhysicalExtensionCodec {},
)?;

Ok(Self {
name,
execution_plan,
})
}
}

#[async_trait::async_trait]
impl ProcessFuncTrait for ValueExecutionOperator {
type InKey = ();
type InT = ();
type OutKey = ();
type OutT = ();

fn name(&self) -> String {
self.name.clone()
}

async fn process_element(&mut self, _record: &Record<(), ()>, _ctx: &mut Context<(), ()>) {
unimplemented!("only record batches supported");
}

async fn process_record_batch(
&mut self,
record_batch: &RecordBatchData,
ctx: &mut Context<(), ()>,
) {
info!("incoming record batch {:?}", record_batch);
let batch = &record_batch.0;
let mut data: KeyValueTimestampRecordBatch = batch.try_into().unwrap();

// Run the deserialized physical plan against the value batch and emit
// each resulting batch downstream.
let session_context = SessionContext::new();
session_context
.register_batch("memory", data.value_batch.clone())
.unwrap();
let mut records = self
.execution_plan
.execute(0, session_context.task_ctx())
.unwrap();
while let Some(batch) = records.next().await {
data.value_batch = batch.unwrap();
ctx.collect_record_batch((&data).into()).await;
}
}
}

pub struct StructToRecordBatch<K: RecordBatchBuilder, T: RecordBatchBuilder>
where
K::Data: Key,
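
Taken together, a sketch of the intended round trip for a value plan — all names are hypothetical, plan_to_bytes is the helper sketched after the api.proto section, and the error types are assumed to be compatible through ?:

use std::sync::Arc;
use datafusion::physical_plan::ExecutionPlan;
use prost::Message;

// Encode the plan-time config, then rehydrate it worker-side into an
// executable operator.
fn build_value_operator(physical_plan: Arc<dyn ExecutionPlan>) -> Result<ValueExecutionOperator> {
    let config = arroyo_rpc::grpc::api::ValuePlanOperator {
        name: "value_plan".to_string(),
        physical_plan: plan_to_bytes(physical_plan)?,
    }
    .encode_to_vec();
    ValueExecutionOperator::from_config("value_plan".to_string(), config)
}
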
