Skip to content

Commit

Permalink
Initial SessionContextExt skeleton
Browse files Browse the repository at this point in the history
relates to apache#1081
  • Loading branch information
milenkovicm committed Oct 17, 2024
1 parent e3af1d4 commit d774263
Show file tree
Hide file tree
Showing 9 changed files with 617 additions and 23 deletions.
7 changes: 6 additions & 1 deletion ballista/client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ edition = "2021"
rust-version = "1.72"

[dependencies]
async-trait = { workspace = true }
ballista-core = { path = "../core", version = "0.12.0" }
ballista-executor = { path = "../executor", version = "0.12.0", optional = true }
ballista-scheduler = { path = "../scheduler", version = "0.12.0", optional = true }
Expand All @@ -40,8 +41,12 @@ sqlparser = { workspace = true }
tempfile = { workspace = true }
tokio = { workspace = true }

[dev-dependencies]
ctor = { version = "0.2" }
env_logger = { workspace = true }

[features]
azure = ["ballista-core/azure"]
default = []
default = ["standalone"]
s3 = ["ballista-core/s3"]
standalone = ["ballista-executor", "ballista-scheduler"]
172 changes: 172 additions & 0 deletions ballista/client/src/extension.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use ballista_core::{
config::BallistaConfig,
serde::protobuf::{
scheduler_grpc_client::SchedulerGrpcClient, CreateSessionParams, KeyValuePair,
},
utils::{create_df_ctx_with_ballista_query_planner, create_grpc_client_connection},
};
use datafusion::{error::DataFusionError, prelude::SessionContext};
use datafusion_proto::protobuf::LogicalPlanNode;

#[async_trait::async_trait]
pub trait SessionContextExt {
#[cfg(feature = "standalone")]
async fn standalone(
config: &BallistaConfig,
) -> datafusion::error::Result<SessionContext>;
// To be added at the later stage
// #[cfg(feature = "standalone")]
// async fn standalone_with_state(
// config: &BallistaConfig,
// session_state: SessionState,
// ) -> datafusion::error::Result<SessionContext>;

async fn remote(
host: &str,
port: u16,
config: &BallistaConfig,
) -> datafusion::error::Result<SessionContext>;
// To be added at the later stage
// async fn remote_with_state(
// host: &str,
// port: u16,
// config: &BallistaConfig,
// session_state: SessionState,
// ) -> datafusion::error::Result<SessionContext>;
}

#[async_trait::async_trait]
impl SessionContextExt for SessionContext {
async fn remote(
host: &str,
port: u16,
config: &BallistaConfig,
) -> datafusion::error::Result<SessionContext> {
let scheduler_url = format!("http://{}:{}", &host, port);
log::info!(
"Connecting to Ballista scheduler at {}",
scheduler_url.clone()
);
let connection = create_grpc_client_connection(scheduler_url.clone())
.await
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;

let limit = config.default_grpc_client_max_message_size();
let mut scheduler = SchedulerGrpcClient::new(connection)
.max_encoding_message_size(limit)
.max_decoding_message_size(limit);

let remote_session_id = scheduler
.create_session(CreateSessionParams {
settings: config
.settings()
.iter()
.map(|(k, v)| KeyValuePair {
key: k.to_owned(),
value: v.to_owned(),
})
.collect::<Vec<_>>(),
})
.await
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
.into_inner()
.session_id;

log::info!(
"Server side SessionContext created with session id: {}",
remote_session_id
);

let ctx = {
create_df_ctx_with_ballista_query_planner::<LogicalPlanNode>(
scheduler_url,
remote_session_id,
&config,
)
};

Ok(ctx)
}

#[cfg(feature = "standalone")]
async fn standalone(config: &BallistaConfig) -> datafusion::error::Result<Self> {
use ballista_core::serde::BallistaCodec;
use datafusion_proto::protobuf::PhysicalPlanNode;

log::info!("Running in local mode. Scheduler will be run in-proc");

let addr = ballista_scheduler::standalone::new_standalone_scheduler()
.await
.map_err(|e| DataFusionError::Configuration(e.to_string()))?;

let scheduler_url = format!("http://localhost:{}", addr.port());
let mut scheduler = loop {
match SchedulerGrpcClient::connect(scheduler_url.clone()).await {
Err(_) => {
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
log::info!("Attempting to connect to in-proc scheduler...");
}
Ok(scheduler) => break scheduler,
}
};

let remote_session_id = scheduler
.create_session(CreateSessionParams {
settings: config
.settings()
.iter()
.map(|(k, v)| KeyValuePair {
key: k.to_owned(),
value: v.to_owned(),
})
.collect::<Vec<_>>(),
})
.await
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?
.into_inner()
.session_id;

log::info!(
"Server side SessionContext created with session id: {}",
remote_session_id
);

let ctx = {
create_df_ctx_with_ballista_query_planner::<LogicalPlanNode>(
scheduler_url,
remote_session_id,
&config,
)
};

let default_codec: BallistaCodec<LogicalPlanNode, PhysicalPlanNode> =
BallistaCodec::default();

let parallelism = std::thread::available_parallelism()
.map(|v| v.get())
.unwrap_or(2);

ballista_executor::new_standalone_executor(scheduler, parallelism, default_codec)
.await
.map_err(|e| DataFusionError::Configuration(e.to_string()))?;

Ok(ctx)
}
}
1 change: 1 addition & 0 deletions ballista/client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@
#![doc = include_str!("../README.md")]

pub mod context;
pub mod extension;
pub mod prelude;
99 changes: 99 additions & 0 deletions ballista/client/tests/common/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use std::env;
use std::error::Error;
use std::path::PathBuf;

/// Returns the parquet test data directory, which is by default
/// stored in a git submodule rooted at
/// `examples/testdata`.
///
/// The default can be overridden by the optional environment variable
/// `EXAMPLES_TEST_DATA`
///
/// panics when the directory can not be found.
///
/// Example:
/// ```
/// use ballista_examples::test_util;
/// let testdata = test_util::examples_test_data();
/// let filename = format!("{testdata}/aggregate_test_100.csv");
/// assert!(std::path::PathBuf::from(filename).exists());
/// ```
pub fn example_test_data() -> String {
match get_data_dir("EXAMPLES_TEST_DATA", "testdata") {
Ok(pb) => pb.display().to_string(),
Err(err) => panic!("failed to get examples test data dir: {err}"),
}
}

/// Returns a directory path for finding test data.
///
/// udf_env: name of an environment variable
///
/// submodule_dir: fallback path (relative to CARGO_MANIFEST_DIR)
///
/// Returns either:
/// The path referred to in `udf_env` if that variable is set and refers to a directory
/// The submodule_data directory relative to CARGO_MANIFEST_PATH
fn get_data_dir(udf_env: &str, submodule_data: &str) -> Result<PathBuf, Box<dyn Error>> {
// Try user defined env.
if let Ok(dir) = env::var(udf_env) {
let trimmed = dir.trim().to_string();
if !trimmed.is_empty() {
let pb = PathBuf::from(trimmed);
if pb.is_dir() {
return Ok(pb);
} else {
return Err(format!(
"the data dir `{}` defined by env {udf_env} not found",
pb.display()
)
.into());
}
}
}

// The env is undefined or its value is trimmed to empty, let's try default dir.

// env "CARGO_MANIFEST_DIR" is "the directory containing the manifest of your package",
// set by `cargo run` or `cargo test`, see:
// https://doc.rust-lang.org/cargo/reference/environment-variables.html
let dir = env!("CARGO_MANIFEST_DIR");

let pb = PathBuf::from(dir).join(submodule_data);
if pb.is_dir() {
Ok(pb)
} else {
Err(format!(
"env `{udf_env}` is undefined or has empty value, and the pre-defined data dir `{}` not found\n\
HINT: try running `git submodule update --init`",
pb.display(),
).into())
}
}

#[ctor::ctor]
fn init() {
// Enable RUST_LOG logging configuration for test
let _ = env_logger::builder()
.filter_level(log::LevelFilter::Info)
.parse_filters("ballista=debug,ballista_scheduler-rs=debug,ballista_executor=debug,datafusion=debug")
.is_test(true)
.try_init();
}
Loading

0 comments on commit d774263

Please sign in to comment.