This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git
The following commit(s) were added to refs/heads/main by this push:
new cd800952 Executor configuration extended .. (#1099)
cd800952 is described below
commit cd80095239f071f615c9002ce56e1f020a876e13
Author: Marko Milenković <[email protected]>
AuthorDate: Thu Oct 31 19:39:13 2024 +0000
Executor configuration extended .. (#1099)
* Executor configuration accepts SessionState ..
... this makes many more options configurable
* add testcontainer to verify s3 access
* fix codec configuration for executor
* change when and where codecs are created
* revert standalone from the default feature set
* disable testcontainers
* Change config option for query planner override
* expose RuntimeProducer and SessionConfigProducer
so executors can configure the runtime per task
and the session config they use
* rename extensions, removing the Ballista prefix
* promote ballista function registry
* add a few extra configuration options to the executor
... process (logical and physical codecs)
* disable `datafusion.optimizer.enable_round_robin_repartition`
in the default configuration, as ballista disables it
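
The resulting user-facing flow, end to end (a minimal sketch distilled from
the tests added in this commit; the runtime setup is illustrative):

    use std::sync::Arc;

    use ballista::extension::{SessionConfigExt, SessionContextExt};
    use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
    use datafusion::execution::SessionStateBuilder;
    use datafusion::prelude::{SessionConfig, SessionContext};

    #[tokio::main]
    async fn main() -> datafusion::error::Result<()> {
        // object stores, memory pools, etc. registered on this RuntimeEnv
        // are picked up by the standalone scheduler and executor
        let runtime_env = Arc::new(RuntimeEnv::new(RuntimeConfig::new())?);

        let state = SessionStateBuilder::new()
            .with_config(SessionConfig::new_with_ballista())
            .with_runtime_env(runtime_env)
            .with_default_features()
            .build();

        let ctx: SessionContext = SessionContext::standalone_with_state(state).await?;
        ctx.sql("SELECT 1").await?.show().await?;
        Ok(())
    }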
---
ballista/client/Cargo.toml | 3 +
ballista/client/src/extension.rs | 66 ++++---
ballista/client/tests/common/mod.rs | 89 +++++++++
ballista/client/tests/object_store.rs | 201 +++++++++++++++++++++
ballista/client/tests/remote.rs | 40 ++++
ballista/client/tests/setup.rs | 17 +-
ballista/core/src/config.rs | 13 --
.../core/src/execution_plans/distributed_query.rs | 25 ++-
ballista/core/src/lib.rs | 24 +++
ballista/core/src/object_store_registry/mod.rs | 1 +
ballista/core/src/serde/mod.rs | 2 +-
ballista/core/src/serde/scheduler/from_proto.rs | 37 ++--
ballista/core/src/serde/scheduler/mod.rs | 53 +++++-
ballista/core/src/utils.rs | 113 +++++++++---
ballista/executor/Cargo.toml | 2 +-
ballista/executor/src/bin/main.rs | 5 +
ballista/executor/src/execution_loop.rs | 65 ++-----
ballista/executor/src/executor.rs | 102 +++++++----
ballista/executor/src/executor_process.rs | 67 +++++--
ballista/executor/src/executor_server.rs | 33 ++--
ballista/executor/src/lib.rs | 1 +
ballista/executor/src/standalone.rs | 107 ++++++++++-
ballista/scheduler/Cargo.toml | 2 +-
ballista/scheduler/src/cluster/memory.rs | 18 +-
ballista/scheduler/src/cluster/mod.rs | 2 +-
ballista/scheduler/src/scheduler_server/grpc.rs | 2 +
ballista/scheduler/src/scheduler_server/mod.rs | 2 +-
ballista/scheduler/src/standalone.rs | 36 +++-
ballista/scheduler/src/test_utils.rs | 2 +-
29 files changed, 871 insertions(+), 259 deletions(-)
diff --git a/ballista/client/Cargo.toml b/ballista/client/Cargo.toml
index da61dab9..038c62c8 100644
--- a/ballista/client/Cargo.toml
+++ b/ballista/client/Cargo.toml
@@ -47,9 +47,12 @@ ballista-executor = { path = "../executor", version = "0.12.0" }
ballista-scheduler = { path = "../scheduler", version = "0.12.0" }
ctor = { version = "0.2" }
env_logger = { workspace = true }
+object_store = { workspace = true, features = ["aws"] }
+testcontainers-modules = { version = "0.11", features = ["minio"] }
[features]
azure = ["ballista-core/azure"]
default = []
s3 = ["ballista-core/s3"]
standalone = ["ballista-executor", "ballista-scheduler"]
+testcontainers = []
diff --git a/ballista/client/src/extension.rs b/ballista/client/src/extension.rs
index 99c8a88f..38931e28 100644
--- a/ballista/client/src/extension.rs
+++ b/ballista/client/src/extension.rs
@@ -15,13 +15,13 @@
// specific language governing permissions and limitations
// under the License.
-pub use ballista_core::utils::BallistaSessionConfigExt;
+pub use ballista_core::utils::SessionConfigExt;
use ballista_core::{
config::BallistaConfig,
serde::protobuf::{
scheduler_grpc_client::SchedulerGrpcClient, CreateSessionParams,
KeyValuePair,
},
- utils::{create_grpc_client_connection, BallistaSessionStateExt},
+ utils::{create_grpc_client_connection, SessionStateExt},
};
use datafusion::{
error::DataFusionError, execution::SessionState, prelude::SessionContext,
@@ -65,6 +65,7 @@ const DEFAULT_SCHEDULER_PORT: u16 = 50050;
/// There are still a few limitations on query distribution, thus not all
/// [SessionContext] functionalities are supported.
///
+
#[async_trait::async_trait]
pub trait SessionContextExt {
/// Creates a context for executing queries against a standalone Ballista scheduler instance
@@ -144,14 +145,8 @@ impl SessionContextExt for SessionContext {
) -> datafusion::error::Result<SessionContext> {
let config = state.ballista_config();
- let codec_logical = state.config().ballista_logical_extension_codec();
- let codec_physical = state.config().ballista_physical_extension_codec();
-
- let ballista_codec =
- ballista_core::serde::BallistaCodec::new(codec_logical, codec_physical);
-
let (remote_session_id, scheduler_url) =
- Extension::setup_standalone(config, ballista_codec).await?;
+ Extension::setup_standalone(config, Some(&state)).await?;
let session_state =
state.upgrade_for_ballista(scheduler_url, remote_session_id.clone())?;
@@ -170,10 +165,8 @@ impl SessionContextExt for SessionContext {
let config = BallistaConfig::new()
.map_err(|e| DataFusionError::Configuration(e.to_string()))?;
- let ballista_codec = ballista_core::serde::BallistaCodec::default();
-
let (remote_session_id, scheduler_url) =
- Extension::setup_standalone(config, ballista_codec).await?;
+ Extension::setup_standalone(config, None).await?;
let session_state =
SessionState::new_ballista_state(scheduler_url, remote_session_id.clone())?;
@@ -205,14 +198,22 @@ impl Extension {
#[cfg(feature = "standalone")]
async fn setup_standalone(
config: BallistaConfig,
- ballista_codec: ballista_core::serde::BallistaCodec<
- datafusion_proto::protobuf::LogicalPlanNode,
- datafusion_proto::protobuf::PhysicalPlanNode,
- >,
+ session_state: Option<&SessionState>,
) -> datafusion::error::Result<(String, String)> {
- let addr = ballista_scheduler::standalone::new_standalone_scheduler()
- .await
- .map_err(|e| DataFusionError::Configuration(e.to_string()))?;
+ use ballista_core::serde::BallistaCodec;
+
+ let addr = match session_state {
+ None => ballista_scheduler::standalone::new_standalone_scheduler()
+ .await
+ .map_err(|e| DataFusionError::Configuration(e.to_string()))?,
+ Some(session_state) => {
+ ballista_scheduler::standalone::new_standalone_scheduler_from_state(
+ session_state,
+ )
+ .await
+ .map_err(|e| DataFusionError::Configuration(e.to_string()))?
+ }
+ };
let scheduler_url = format!("http://localhost:{}", addr.port());
@@ -243,13 +244,26 @@ impl Extension {
.session_id;
let concurrent_tasks = config.default_standalone_parallelism();
- ballista_executor::new_standalone_executor(
- scheduler,
- concurrent_tasks,
- ballista_codec,
- )
- .await
- .map_err(|e| DataFusionError::Configuration(e.to_string()))?;
+
+ match session_state {
+ None => {
+ ballista_executor::new_standalone_executor(
+ scheduler,
+ concurrent_tasks,
+ BallistaCodec::default(),
+ )
+ .await
+ .map_err(|e| DataFusionError::Configuration(e.to_string()))?;
+ }
+ Some(session_state) => {
+ ballista_executor::new_standalone_executor_from_state::<
+ datafusion_proto::protobuf::LogicalPlanNode,
+ datafusion_proto::protobuf::PhysicalPlanNode,
+ >(scheduler, concurrent_tasks, session_state)
+ .await
+ .map_err(|e| DataFusionError::Configuration(e.to_string()))?;
+ }
+ }
Ok((remote_session_id, scheduler_url))
}
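
For a remote deployment the same caller-built state can be reused. A minimal
sketch, assuming a scheduler is already reachable at the placeholder URL:

    use ballista::extension::SessionContextExt;
    use datafusion::execution::SessionState;
    use datafusion::prelude::SessionContext;

    // connect to a running scheduler, keeping the caller's SessionState
    // (codecs, object stores, config); the URL is a placeholder
    async fn connect(state: SessionState) -> datafusion::error::Result<SessionContext> {
        SessionContext::remote_with_state("df://localhost:50050", state).await
    }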
diff --git a/ballista/client/tests/common/mod.rs b/ballista/client/tests/common/mod.rs
index 02f25d7b..afc32aea 100644
--- a/ballista/client/tests/common/mod.rs
+++ b/ballista/client/tests/common/mod.rs
@@ -23,6 +23,53 @@ use ballista::prelude::BallistaConfig;
use ballista_core::serde::{
protobuf::scheduler_grpc_client::SchedulerGrpcClient, BallistaCodec,
};
+use datafusion::execution::SessionState;
+use object_store::aws::AmazonS3Builder;
+use testcontainers_modules::minio::MinIO;
+use testcontainers_modules::testcontainers::core::{CmdWaitFor, ExecCommand};
+use testcontainers_modules::testcontainers::ContainerRequest;
+use testcontainers_modules::{minio, testcontainers::ImageExt};
+
+pub const REGION: &str = "eu-west-1";
+pub const BUCKET: &str = "ballista";
+pub const ACCESS_KEY_ID: &str = "MINIO";
+pub const SECRET_KEY: &str = "MINIOMINIO";
+
+#[allow(dead_code)]
+pub fn create_s3_store(
+ port: u16,
+) -> std::result::Result<object_store::aws::AmazonS3, object_store::Error> {
+ AmazonS3Builder::new()
+ .with_endpoint(format!("http://localhost:{port}"))
+ .with_region(REGION)
+ .with_bucket_name(BUCKET)
+ .with_access_key_id(ACCESS_KEY_ID)
+ .with_secret_access_key(SECRET_KEY)
+ .with_allow_http(true)
+ .build()
+}
+
+#[allow(dead_code)]
+pub fn create_minio_container() -> ContainerRequest<minio::MinIO> {
+ MinIO::default()
+ .with_env_var("MINIO_ACCESS_KEY", ACCESS_KEY_ID)
+ .with_env_var("MINIO_SECRET_KEY", SECRET_KEY)
+}
+
+#[allow(dead_code)]
+pub fn create_bucket_command() -> ExecCommand {
+ // this is a hack to create a bucket without creating an s3 client.
+ // this works with the current testcontainer (and image) version 'RELEASE.2022-02-07T08-17-33Z'.
+ // (testcontainer does not await properly on the latest image version)
+ //
+ // if the testcontainer image version changes to something newer we should use "mc mb /data/ballista"
+ // to create a bucket.
+ ExecCommand::new(vec![
+ "mkdir".to_string(),
+ format!("/data/{}", crate::common::BUCKET),
+ ])
+ .with_cmd_ready_condition(CmdWaitFor::seconds(1))
+}
// /// Remote ballista cluster to be used for local testing.
// static BALLISTA_CLUSTER: tokio::sync::OnceCell<(String, u16)> =
@@ -136,6 +183,48 @@ pub async fn setup_test_cluster() -> (String, u16) {
(host, addr.port())
}
+/// starts a ballista cluster for integration tests
+#[allow(dead_code)]
+pub async fn setup_test_cluster_with_state(session_state: SessionState) -> (String, u16) {
+ let config = BallistaConfig::builder().build().unwrap();
+ //let default_codec = BallistaCodec::default();
+
+ let addr = ballista_scheduler::standalone::new_standalone_scheduler_from_state(
+ &session_state,
+ )
+ .await
+ .expect("scheduler to be created");
+
+ let host = "localhost".to_string();
+
+ let scheduler_url = format!("http://{}:{}", host, addr.port());
+
+ let scheduler = loop {
+ match SchedulerGrpcClient::connect(scheduler_url.clone()).await {
+ Err(_) => {
+ tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+ log::info!("Attempting to connect to test scheduler...");
+ }
+ Ok(scheduler) => break scheduler,
+ }
+ };
+
+ ballista_executor::new_standalone_executor_from_state::<
+ datafusion_proto::protobuf::LogicalPlanNode,
+ datafusion_proto::protobuf::PhysicalPlanNode,
+ >(
+ scheduler,
+ config.default_standalone_parallelism(),
+ &session_state,
+ )
+ .await
+ .expect("executor to be created");
+
+ log::info!("test scheduler created at: {}:{}", host, addr.port());
+
+ (host, addr.port())
+}
+
#[ctor::ctor]
fn init() {
// Enable RUST_LOG logging configuration for test
diff --git a/ballista/client/tests/object_store.rs b/ballista/client/tests/object_store.rs
new file mode 100644
index 00000000..b58bcb90
--- /dev/null
+++ b/ballista/client/tests/object_store.rs
@@ -0,0 +1,201 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! # Object Store Support
+//!
+//! Tests demonstrate how to set up object stores with ballista.
+//!
+//! Tests depend on a MinIO testcontainer acting as an S3 object
+//! store.
+//!
+//! Testcontainers require Docker to run.
+
+mod common;
+
+#[cfg(test)]
+#[cfg(feature = "standalone")]
+#[cfg(feature = "testcontainers")]
+mod standalone {
+
+ use ballista::extension::SessionContextExt;
+ use datafusion::{assert_batches_eq, prelude::SessionContext};
+ use datafusion::{
+ error::DataFusionError,
+ execution::{
+ runtime_env::{RuntimeConfig, RuntimeEnv},
+ SessionStateBuilder,
+ },
+ };
+ use std::sync::Arc;
+ use testcontainers_modules::testcontainers::runners::AsyncRunner;
+
+ #[tokio::test]
+ async fn should_execute_sql_write() -> datafusion::error::Result<()> {
+ let container = crate::common::create_minio_container();
+ let node = container.start().await.unwrap();
+
+ node.exec(crate::common::create_bucket_command())
+ .await
+ .unwrap();
+
+ let port = node.get_host_port_ipv4(9000).await.unwrap();
+
+ let object_store = crate::common::create_s3_store(port)
+ .map_err(|e| DataFusionError::External(e.into()))?;
+
+ let test_data = crate::common::example_test_data();
+ let config = RuntimeConfig::new();
+ let runtime_env = RuntimeEnv::new(config)?;
+
+ runtime_env.register_object_store(
+ &format!("s3://{}", crate::common::BUCKET)
+ .as_str()
+ .try_into()
+ .unwrap(),
+ Arc::new(object_store),
+ );
+ let state = SessionStateBuilder::new()
+ .with_runtime_env(runtime_env.into())
+ .build();
+
+ let ctx: SessionContext = SessionContext::standalone_with_state(state).await?;
+ ctx.register_parquet(
+ "test",
+ &format!("{test_data}/alltypes_plain.parquet"),
+ Default::default(),
+ )
+ .await?;
+
+ let write_dir_path =
+ &format!("s3://{}/write_test.parquet", crate::common::BUCKET);
+
+ ctx.sql("select * from test")
+ .await?
+ .write_parquet(write_dir_path, Default::default(), Default::default())
+ .await?;
+
+ ctx.register_parquet("written_table", write_dir_path,
Default::default())
+ .await?;
+
+ let result = ctx
+ .sql("select id, string_col, timestamp_col from written_table
where id > 4")
+ .await?
+ .collect()
+ .await?;
+ let expected = [
+ "+----+------------+---------------------+",
+ "| id | string_col | timestamp_col |",
+ "+----+------------+---------------------+",
+ "| 5 | 31 | 2009-03-01T00:01:00 |",
+ "| 6 | 30 | 2009-04-01T00:00:00 |",
+ "| 7 | 31 | 2009-04-01T00:01:00 |",
+ "+----+------------+---------------------+",
+ ];
+
+ assert_batches_eq!(expected, &result);
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+#[cfg(feature = "testcontainers")]
+mod remote {
+
+ use ballista::extension::SessionContextExt;
+ use datafusion::{assert_batches_eq, prelude::SessionContext};
+ use datafusion::{
+ error::DataFusionError,
+ execution::{
+ runtime_env::{RuntimeConfig, RuntimeEnv},
+ SessionStateBuilder,
+ },
+ };
+ use std::sync::Arc;
+ use testcontainers_modules::testcontainers::runners::AsyncRunner;
+
+ #[tokio::test]
+ async fn should_execute_sql_write() -> datafusion::error::Result<()> {
+ let test_data = crate::common::example_test_data();
+
+ let container = crate::common::create_minio_container();
+ let node = container.start().await.unwrap();
+
+ node.exec(crate::common::create_bucket_command())
+ .await
+ .unwrap();
+
+ let port = node.get_host_port_ipv4(9000).await.unwrap();
+
+ let object_store = crate::common::create_s3_store(port)
+ .map_err(|e| DataFusionError::External(e.into()))?;
+
+ let config = RuntimeConfig::new();
+ let runtime_env = RuntimeEnv::new(config)?;
+
+ runtime_env.register_object_store(
+ &format!("s3://{}", crate::common::BUCKET)
+ .as_str()
+ .try_into()
+ .unwrap(),
+ Arc::new(object_store),
+ );
+ let state = SessionStateBuilder::new()
+ .with_runtime_env(runtime_env.into())
+ .build();
+
+ let (host, port) = crate::common::setup_test_cluster_with_state(state.clone()).await;
+ let url = format!("df://{host}:{port}");
+
+ let ctx: SessionContext = SessionContext::remote_with_state(&url, state).await?;
+ ctx.register_parquet(
+ "test",
+ &format!("{test_data}/alltypes_plain.parquet"),
+ Default::default(),
+ )
+ .await?;
+
+ let write_dir_path =
+ &format!("s3://{}/write_test.parquet", crate::common::BUCKET);
+
+ ctx.sql("select * from test")
+ .await?
+ .write_parquet(write_dir_path, Default::default(), Default::default())
+ .await?;
+
+ ctx.register_parquet("written_table", write_dir_path,
Default::default())
+ .await?;
+
+ let result = ctx
+ .sql("select id, string_col, timestamp_col from written_table
where id > 4")
+ .await?
+ .collect()
+ .await?;
+ let expected = [
+ "+----+------------+---------------------+",
+ "| id | string_col | timestamp_col |",
+ "+----+------------+---------------------+",
+ "| 5 | 31 | 2009-03-01T00:01:00 |",
+ "| 6 | 30 | 2009-04-01T00:00:00 |",
+ "| 7 | 31 | 2009-04-01T00:01:00 |",
+ "+----+------------+---------------------+",
+ ];
+
+ assert_batches_eq!(expected, &result);
+ Ok(())
+ }
+}
diff --git a/ballista/client/tests/remote.rs b/ballista/client/tests/remote.rs
index 619c4cd6..b0184b26 100644
--- a/ballista/client/tests/remote.rs
+++ b/ballista/client/tests/remote.rs
@@ -142,4 +142,44 @@ mod remote {
Ok(())
}
+
+ #[tokio::test]
+ async fn should_execute_sql_app_name_show() -> datafusion::error::Result<()> {
+ let (host, port) = crate::common::setup_test_cluster().await;
+ let url = format!("df://{host}:{port}");
+
+ let test_data = crate::common::example_test_data();
+ let ctx: SessionContext = SessionContext::remote(&url).await?;
+
+ ctx.sql("SET ballista.job.name = 'Super Cool Ballista App'")
+ .await?
+ .show()
+ .await?;
+
+ ctx.register_parquet(
+ "test",
+ &format!("{test_data}/alltypes_plain.parquet"),
+ Default::default(),
+ )
+ .await?;
+
+ let result = ctx
+ .sql("select string_col, timestamp_col from test where id > 4")
+ .await?
+ .collect()
+ .await?;
+ let expected = [
+ "+------------+---------------------+",
+ "| string_col | timestamp_col |",
+ "+------------+---------------------+",
+ "| 31 | 2009-03-01T00:01:00 |",
+ "| 30 | 2009-04-01T00:00:00 |",
+ "| 31 | 2009-04-01T00:01:00 |",
+ "+------------+---------------------+",
+ ];
+
+ assert_batches_eq!(expected, &result);
+
+ Ok(())
+ }
}
diff --git a/ballista/client/tests/setup.rs b/ballista/client/tests/setup.rs
index 30a6df84..10b48290 100644
--- a/ballista/client/tests/setup.rs
+++ b/ballista/client/tests/setup.rs
@@ -20,7 +20,7 @@ mod common;
#[cfg(test)]
mod remote {
use ballista::{
- extension::{BallistaSessionConfigExt, SessionContextExt},
+ extension::{SessionConfigExt, SessionContextExt},
prelude::BALLISTA_JOB_NAME,
};
use datafusion::{
@@ -109,12 +109,10 @@ mod standalone {
use std::sync::{atomic::AtomicBool, Arc};
use ballista::{
- extension::{BallistaSessionConfigExt, SessionContextExt},
+ extension::{SessionConfigExt, SessionContextExt},
prelude::BALLISTA_JOB_NAME,
};
- use ballista_core::{
- config::BALLISTA_PLANNER_OVERRIDE, serde::BallistaPhysicalExtensionCodec,
- };
+ use ballista_core::serde::BallistaPhysicalExtensionCodec;
use datafusion::{
assert_batches_eq,
common::exec_err,
@@ -243,12 +241,11 @@ mod standalone {
async fn should_override_planner() -> datafusion::error::Result<()> {
let session_config = SessionConfig::new_with_ballista()
.with_information_schema(true)
- .set_str(BALLISTA_PLANNER_OVERRIDE, "false");
+ .with_ballista_query_planner(Arc::new(BadPlanner::default()));
let state = SessionStateBuilder::new()
.with_default_features()
.with_config(session_config)
- .with_query_planner(Arc::new(BadPlanner::default()))
.build();
let ctx: SessionContext = SessionContext::standalone_with_state(state).await?;
@@ -257,14 +254,12 @@ mod standalone {
assert!(result.is_err());
- let session_config = SessionConfig::new_with_ballista()
- .with_information_schema(true)
- .set_str(BALLISTA_PLANNER_OVERRIDE, "true");
+ let session_config =
+ SessionConfig::new_with_ballista().with_information_schema(true);
let state = SessionStateBuilder::new()
.with_default_features()
.with_config(session_config)
- .with_query_planner(Arc::new(BadPlanner::default()))
.build();
let ctx: SessionContext = SessionContext::standalone_with_state(state).await?;
diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs
index 782b8b9d..88cba1d9 100644
--- a/ballista/core/src/config.rs
+++ b/ballista/core/src/config.rs
@@ -43,11 +43,6 @@ pub const BALLISTA_REPARTITION_WINDOWS: &str = "ballista.repartition.windows";
pub const BALLISTA_PARQUET_PRUNING: &str = "ballista.parquet.pruning";
pub const BALLISTA_COLLECT_STATISTICS: &str = "ballista.collect_statistics";
pub const BALLISTA_STANDALONE_PARALLELISM: &str = "ballista.standalone.parallelism";
-/// If set to false, planner will not be overridden by ballista.
-/// This allows user to replace ballista planner
-// this is a bit of a hack, as we can't detect if there is a
-// custom planner provided
-pub const BALLISTA_PLANNER_OVERRIDE: &str = "ballista.planner.override";
pub const BALLISTA_WITH_INFORMATION_SCHEMA: &str = "ballista.with_information_schema";
@@ -221,10 +216,6 @@ impl BallistaConfig {
"Configuration for max message size in gRPC
clients".to_string(),
DataType::UInt64,
Some((16 * 1024 * 1024).to_string())),
- ConfigEntry::new(BALLISTA_PLANNER_OVERRIDE.to_string(),
- "Disable overriding provided planner".to_string(),
- DataType::Boolean,
- Some((true).to_string())),
];
entries
.iter()
@@ -280,10 +271,6 @@ impl BallistaConfig {
self.get_bool_setting(BALLISTA_WITH_INFORMATION_SCHEMA)
}
- pub fn planner_override(&self) -> bool {
- self.get_bool_setting(BALLISTA_PLANNER_OVERRIDE)
- }
-
fn get_usize_setting(&self, key: &str) -> usize {
if let Some(v) = self.settings.get(key) {
// infallible because we validate all configs in the constructor
diff --git a/ballista/core/src/execution_plans/distributed_query.rs b/ballista/core/src/execution_plans/distributed_query.rs
index 050ba877..dae4bb8e 100644
--- a/ballista/core/src/execution_plans/distributed_query.rs
+++ b/ballista/core/src/execution_plans/distributed_query.rs
@@ -194,7 +194,7 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
fn execute(
&self,
partition: usize,
- _context: Arc<TaskContext>,
+ context: Arc<TaskContext>,
) -> Result<SendableRecordBatchStream> {
assert_eq!(0, partition);
@@ -210,17 +210,22 @@ impl<T: 'static + AsLogicalPlan> ExecutionPlan for DistributedQueryExec<T> {
DataFusionError::Execution(format!("failed to encode logical plan: {e:?}"))
})?;
+ let settings = context
+ .session_config()
+ .options()
+ .entries()
+ .iter()
+ .map(
+ |datafusion::config::ConfigEntry { key, value, .. }| KeyValuePair {
+ key: key.to_owned(),
+ value: value.clone().unwrap_or_else(|| String::from("")),
+ },
+ )
+ .collect();
+
let query = ExecuteQueryParams {
query: Some(Query::LogicalPlan(buf)),
- settings: self
- .config
- .settings()
- .iter()
- .map(|(k, v)| KeyValuePair {
- key: k.to_owned(),
- value: v.to_owned(),
- })
- .collect::<Vec<_>>(),
+ settings,
optional_session_id: Some(OptionalSessionId::SessionId(
self.session_id.clone(),
)),
diff --git a/ballista/core/src/lib.rs b/ballista/core/src/lib.rs
index c52d2ef4..8ae5dfb5 100644
--- a/ballista/core/src/lib.rs
+++ b/ballista/core/src/lib.rs
@@ -16,6 +16,10 @@
// under the License.
#![doc = include_str!("../README.md")]
+
+use std::sync::Arc;
+
+use datafusion::{execution::runtime_env::RuntimeEnv, prelude::SessionConfig};
pub const BALLISTA_VERSION: &str = env!("CARGO_PKG_VERSION");
pub fn print_version() {
@@ -33,3 +37,23 @@ pub mod utils;
#[macro_use]
pub mod serde;
+
+///
+/// [RuntimeProducer] is a factory which creates a [RuntimeEnv]
+/// from a [SessionConfig]. As the [SessionConfig] is propagated
+/// from the client to the executors, this makes it possible to
+/// create [RuntimeEnv] components and configure them according to
+/// the [SessionConfig] or one of its config extensions.
+///
+/// It is intended to be used with executor configuration.
+///
+pub type RuntimeProducer = Arc<
+ dyn Fn(&SessionConfig) -> datafusion::error::Result<Arc<RuntimeEnv>> + Send + Sync,
+>;
+///
+/// [ConfigProducer] is a factory which creates a [SessionConfig], with
+/// additional extensions or configuration codecs.
+///
+/// It is intended to be used with executor configuration.
+///
+pub type ConfigProducer = Arc<dyn Fn() -> SessionConfig + Send + Sync>;
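
A minimal sketch of custom producers matching these aliases; the bodies
mirror the executor defaults and are illustrative only:

    use std::sync::Arc;

    use ballista_core::{ConfigProducer, RuntimeProducer};
    use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
    use datafusion::prelude::SessionConfig;

    fn producers() -> (RuntimeProducer, ConfigProducer) {
        // build a fresh RuntimeEnv per task from the propagated SessionConfig
        let runtime: RuntimeProducer = Arc::new(|_config: &SessionConfig| {
            Ok(Arc::new(RuntimeEnv::new(RuntimeConfig::new())?))
        });
        // supply the executor's default SessionConfig
        let config: ConfigProducer = Arc::new(SessionConfig::new);
        (runtime, config)
    }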
diff --git a/ballista/core/src/object_store_registry/mod.rs b/ballista/core/src/object_store_registry/mod.rs
index aedccc5e..e7fbee21 100644
--- a/ballista/core/src/object_store_registry/mod.rs
+++ b/ballista/core/src/object_store_registry/mod.rs
@@ -31,6 +31,7 @@ use std::sync::Arc;
use url::Url;
/// Get a RuntimeConfig with specific ObjectStoreRegistry
+// TODO: #[deprecated] this method
pub fn with_object_store_registry(config: RuntimeConfig) -> RuntimeConfig {
let registry = Arc::new(BallistaObjectStoreRegistry::default());
config.with_object_store_registry(registry)
diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs
index 7464fe68..5400b00c 100644
--- a/ballista/core/src/serde/mod.rs
+++ b/ballista/core/src/serde/mod.rs
@@ -89,7 +89,7 @@ impl Default for BallistaCodec {
fn default() -> Self {
Self {
logical_extension_codec: Arc::new(BallistaLogicalExtensionCodec::default()),
- physical_extension_codec: Arc::new(BallistaPhysicalExtensionCodec {}),
+ physical_extension_codec: Arc::new(BallistaPhysicalExtensionCodec::default()),
logical_plan_repr: PhantomData,
physical_plan_repr: PhantomData,
}
diff --git a/ballista/core/src/serde/scheduler/from_proto.rs b/ballista/core/src/serde/scheduler/from_proto.rs
index 4821eab2..28a1e8a5 100644
--- a/ballista/core/src/serde/scheduler/from_proto.rs
+++ b/ballista/core/src/serde/scheduler/from_proto.rs
@@ -17,12 +17,13 @@
use chrono::{TimeZone, Utc};
use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode};
-use datafusion::execution::runtime_env::RuntimeEnv;
+
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
use datafusion::physical_plan::metrics::{
Count, Gauge, MetricValue, MetricsSet, Time, Timestamp,
};
use datafusion::physical_plan::{ExecutionPlan, Metric};
+use datafusion::prelude::SessionConfig;
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
use std::collections::HashMap;
@@ -32,11 +33,13 @@ use std::time::Duration;
use crate::error::BallistaError;
use crate::serde::scheduler::{
- Action, ExecutorData, ExecutorMetadata, ExecutorSpecification, PartitionId,
- PartitionLocation, PartitionStats, SimpleFunctionRegistry, TaskDefinition,
+ Action, BallistaFunctionRegistry, ExecutorData, ExecutorMetadata,
+ ExecutorSpecification, PartitionId, PartitionLocation, PartitionStats,
+ TaskDefinition,
};
use crate::serde::{protobuf, BallistaCodec};
+use crate::RuntimeProducer;
use protobuf::{operator_metric, NamedCount, NamedGauge, NamedTime};
impl TryInto<Action> for protobuf::Action {
@@ -281,17 +284,17 @@ impl Into<ExecutorData> for protobuf::ExecutorData {
pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
task: protobuf::TaskDefinition,
- runtime: Arc<RuntimeEnv>,
+ produce_runtime: RuntimeProducer,
+ session_config: SessionConfig,
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
window_functions: HashMap<String, Arc<WindowUDF>>,
codec: BallistaCodec<T, U>,
) -> Result<TaskDefinition, BallistaError> {
- let mut props = HashMap::new();
+ let mut session_config = session_config;
for kv_pair in task.props {
- props.insert(kv_pair.key, kv_pair.value);
+ session_config = session_config.set_str(&kv_pair.key, &kv_pair.value);
}
- let props = Arc::new(props);
let mut task_scalar_functions = HashMap::new();
let mut task_aggregate_functions = HashMap::new();
@@ -306,12 +309,12 @@ pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP
for agg_func in window_functions {
task_window_functions.insert(agg_func.0, agg_func.1);
}
- let function_registry = Arc::new(SimpleFunctionRegistry {
+ let function_registry = Arc::new(BallistaFunctionRegistry {
scalar_functions: task_scalar_functions,
aggregate_functions: task_aggregate_functions,
window_functions: task_window_functions,
});
-
+ let runtime = produce_runtime(&session_config)?;
let encoded_plan = task.plan.as_slice();
let plan: Arc<dyn ExecutionPlan> =
U::try_decode(encoded_plan).and_then(|proto| {
proto.try_into_physical_plan(
@@ -340,7 +343,7 @@ pub fn get_task_definition<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP
plan,
launch_time,
session_id,
- props,
+ session_config,
function_registry,
})
}
@@ -350,17 +353,17 @@ pub fn get_task_definition_vec<
U: 'static + AsExecutionPlan,
>(
multi_task: protobuf::MultiTaskDefinition,
- runtime: Arc<RuntimeEnv>,
+ runtime_producer: RuntimeProducer,
+ session_config: SessionConfig,
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
window_functions: HashMap<String, Arc<WindowUDF>>,
codec: BallistaCodec<T, U>,
) -> Result<Vec<TaskDefinition>, BallistaError> {
- let mut props = HashMap::new();
+ let mut session_config = session_config;
for kv_pair in multi_task.props {
- props.insert(kv_pair.key, kv_pair.value);
+ session_config = session_config.set_str(&kv_pair.key, &kv_pair.value);
}
- let props = Arc::new(props);
let mut task_scalar_functions = HashMap::new();
let mut task_aggregate_functions = HashMap::new();
@@ -375,12 +378,14 @@ pub fn get_task_definition_vec<
for agg_func in window_functions {
task_window_functions.insert(agg_func.0, agg_func.1);
}
- let function_registry = Arc::new(SimpleFunctionRegistry {
+ let function_registry = Arc::new(BallistaFunctionRegistry {
scalar_functions: task_scalar_functions,
aggregate_functions: task_aggregate_functions,
window_functions: task_window_functions,
});
+ let runtime = runtime_producer(&session_config)?;
+
let encoded_plan = multi_task.plan.as_slice();
let plan: Arc<dyn ExecutionPlan> =
U::try_decode(encoded_plan).and_then(|proto| {
proto.try_into_physical_plan(
@@ -410,7 +415,7 @@ pub fn get_task_definition_vec<
plan: reset_metrics_for_execution_plan(plan.clone())?,
launch_time,
session_id: session_id.clone(),
- props: props.clone(),
+ session_config: session_config.clone(),
function_registry: function_registry.clone(),
})
})
diff --git a/ballista/core/src/serde/scheduler/mod.rs b/ballista/core/src/serde/scheduler/mod.rs
index 23c9c425..2905455e 100644
--- a/ballista/core/src/serde/scheduler/mod.rs
+++ b/ballista/core/src/serde/scheduler/mod.rs
@@ -24,11 +24,15 @@ use datafusion::arrow::array::{
};
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::common::DataFusionError;
-use datafusion::execution::FunctionRegistry;
+use datafusion::execution::{FunctionRegistry, SessionState};
+use datafusion::functions::all_default_functions;
+use datafusion::functions_aggregate::all_default_aggregate_functions;
+use datafusion::functions_window::all_default_window_functions;
use datafusion::logical_expr::planner::ExprPlanner;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
use datafusion::physical_plan::ExecutionPlan;
use datafusion::physical_plan::Partitioning;
+use datafusion::prelude::SessionConfig;
use serde::Serialize;
use crate::error::BallistaError;
@@ -288,18 +292,43 @@ pub struct TaskDefinition {
pub plan: Arc<dyn ExecutionPlan>,
pub launch_time: u64,
pub session_id: String,
- pub props: Arc<HashMap<String, String>>,
- pub function_registry: Arc<SimpleFunctionRegistry>,
+ pub session_config: SessionConfig,
+ pub function_registry: Arc<BallistaFunctionRegistry>,
}
#[derive(Debug)]
-pub struct SimpleFunctionRegistry {
+pub struct BallistaFunctionRegistry {
pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
pub window_functions: HashMap<String, Arc<WindowUDF>>,
}
-impl FunctionRegistry for SimpleFunctionRegistry {
+impl Default for BallistaFunctionRegistry {
+ fn default() -> Self {
+ let scalar_functions = all_default_functions()
+ .into_iter()
+ .map(|f| (f.name().to_string(), f))
+ .collect();
+
+ let aggregate_functions = all_default_aggregate_functions()
+ .into_iter()
+ .map(|f| (f.name().to_string(), f))
+ .collect();
+
+ let window_functions = all_default_window_functions()
+ .into_iter()
+ .map(|f| (f.name().to_string(), f))
+ .collect();
+
+ Self {
+ scalar_functions,
+ aggregate_functions,
+ window_functions,
+ }
+ }
+}
+
+impl FunctionRegistry for BallistaFunctionRegistry {
fn expr_planners(&self) -> Vec<Arc<dyn ExprPlanner>> {
vec![]
}
@@ -338,3 +367,17 @@ impl FunctionRegistry for SimpleFunctionRegistry {
})
}
}
+
+impl From<&SessionState> for BallistaFunctionRegistry {
+ fn from(state: &SessionState) -> Self {
+ let scalar_functions = state.scalar_functions().clone();
+ let aggregate_functions = state.aggregate_functions().clone();
+ let window_functions = state.window_functions().clone();
+
+ Self {
+ scalar_functions,
+ aggregate_functions,
+ window_functions,
+ }
+ }
+}
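
Illustrative use of the two registry constructors added above (the
non-emptiness of the default sets is assumed from datafusion's defaults):

    use ballista_core::serde::scheduler::BallistaFunctionRegistry;
    use datafusion::execution::SessionStateBuilder;

    fn registries() {
        // the default registry carries datafusion's built-in functions
        let defaults = BallistaFunctionRegistry::default();
        assert!(!defaults.scalar_functions.is_empty());

        // or derive a registry from an existing SessionState, as the executor does
        let state = SessionStateBuilder::new().with_default_features().build();
        let from_state = BallistaFunctionRegistry::from(&state);
        assert!(!from_state.aggregate_functions.is_empty());
    }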
diff --git a/ballista/core/src/utils.rs b/ballista/core/src/utils.rs
index 8be32c40..3f8f6bfe 100644
--- a/ballista/core/src/utils.rs
+++ b/ballista/core/src/utils.rs
@@ -277,7 +277,7 @@ pub fn create_df_ctx_with_ballista_query_planner<T: 'static + AsLogicalPlan>(
SessionContext::new_with_state(session_state)
}
-pub trait BallistaSessionStateExt {
+pub trait SessionStateExt {
fn new_ballista_state(
scheduler_url: String,
session_id: String,
@@ -291,7 +291,7 @@ pub trait BallistaSessionStateExt {
fn ballista_config(&self) -> BallistaConfig;
}
-impl BallistaSessionStateExt for SessionState {
+impl SessionStateExt for SessionState {
fn ballista_config(&self) -> BallistaConfig {
self.config()
.options()
@@ -313,7 +313,9 @@ impl BallistaSessionStateExt for SessionState {
let session_config = SessionConfig::new()
.with_information_schema(true)
- .with_option_extension(config.clone());
+ .with_option_extension(config.clone())
+ // Ballista disables this option
+ .with_round_robin_repartition(false);
let runtime_config = RuntimeConfig::default();
let runtime_env = RuntimeEnv::new(runtime_config)?;
@@ -334,6 +336,7 @@ impl BallistaSessionStateExt for SessionState {
session_id: String,
) -> datafusion::error::Result<SessionState> {
let codec_logical = self.config().ballista_logical_extension_codec();
+ let planner_override = self.config().ballista_query_planner();
let new_config = self
.config()
@@ -346,39 +349,31 @@ impl BallistaSessionStateExt for SessionState {
let session_config = self
.config()
.clone()
- .with_option_extension(new_config.clone());
-
- // at the moment we don't have a way to detect if
- // user set planner so we provide a configuration to
- // user to disable planner override
- let planner_override = self
- .config()
- .options()
- .extensions
- .get::<BallistaConfig>()
- .map(|config| config.planner_override())
- .unwrap_or(true);
+ .with_option_extension(new_config.clone())
+ // Ballista disables this option
+ .with_round_robin_repartition(false);
let builder = SessionStateBuilder::new_from_existing(self)
.with_config(session_config)
.with_session_id(session_id);
- let builder = if planner_override {
- let query_planner = BallistaQueryPlanner::<LogicalPlanNode>::with_extension(
- scheduler_url,
- new_config,
- codec_logical,
- );
- builder.with_query_planner(Arc::new(query_planner))
- } else {
- builder
+ let builder = match planner_override {
+ Some(planner) => builder.with_query_planner(planner),
+ None => {
+ let planner = BallistaQueryPlanner::<LogicalPlanNode>::with_extension(
+ scheduler_url,
+ new_config,
+ codec_logical,
+ );
+ builder.with_query_planner(Arc::new(planner))
+ }
};
Ok(builder.build())
}
}
-pub trait BallistaSessionConfigExt {
+pub trait SessionConfigExt {
/// Creates session config which has
/// ballista configuration initialized
fn new_with_ballista() -> SessionConfig;
@@ -402,9 +397,20 @@ pub trait BallistaSessionConfigExt {
/// returns [PhysicalExtensionCodec] if set
/// or default ballista codec if not
fn ballista_physical_extension_codec(&self) -> Arc<dyn PhysicalExtensionCodec>;
+
+ /// Overrides ballista's [QueryPlanner]
+ fn with_ballista_query_planner(
+ self,
+ planner: Arc<dyn QueryPlanner + Send + Sync + 'static>,
+ ) -> SessionConfig;
+
+ /// Returns ballista's [QueryPlanner] if overridden
+ fn ballista_query_planner(
+ &self,
+ ) -> Option<Arc<dyn QueryPlanner + Send + Sync + 'static>>;
}
-impl BallistaSessionConfigExt for SessionConfig {
+impl SessionConfigExt for SessionConfig {
fn new_with_ballista() -> SessionConfig {
SessionConfig::new().with_option_extension(BallistaConfig::new().unwrap())
}
@@ -433,6 +439,21 @@ impl BallistaSessionConfigExt for SessionConfig {
.map(|c| c.codec())
.unwrap_or_else(|| Arc::new(BallistaPhysicalExtensionCodec::default()))
}
+
+ fn with_ballista_query_planner(
+ self,
+ planner: Arc<dyn QueryPlanner + Send + Sync + 'static>,
+ ) -> SessionConfig {
+ let extension = BallistaQueryPlannerExtension::new(planner);
+ self.with_extension(Arc::new(extension))
+ }
+
+ fn ballista_query_planner(
+ &self,
+ ) -> Option<Arc<dyn QueryPlanner + Send + Sync + 'static>> {
+ self.get_extension::<BallistaQueryPlannerExtension>()
+ .map(|c| c.planner())
+ }
}
/// Wrapper for [SessionConfig] extension
@@ -465,6 +486,21 @@ impl BallistaConfigExtensionPhysicalCodec {
}
}
+/// Wrapper for [SessionConfig] extension
+/// holding overridden [QueryPlanner]
+struct BallistaQueryPlannerExtension {
+ planner: Arc<dyn QueryPlanner + Send + Sync + 'static>,
+}
+
+impl BallistaQueryPlannerExtension {
+ fn new(planner: Arc<dyn QueryPlanner + Send + Sync + 'static>) -> Self {
+ Self { planner }
+ }
+ fn planner(&self) -> Arc<dyn QueryPlanner + Send + Sync + 'static> {
+ self.planner.clone()
+ }
+}
+
pub struct BallistaQueryPlanner<T: AsLogicalPlan> {
scheduler_url: String,
config: BallistaConfig,
@@ -656,12 +692,12 @@ mod test {
error::Result,
execution::{
runtime_env::{RuntimeConfig, RuntimeEnv},
- SessionStateBuilder,
+ SessionState, SessionStateBuilder,
},
prelude::{SessionConfig, SessionContext},
};
- use crate::utils::LocalRun;
+ use crate::utils::{LocalRun, SessionStateExt};
fn context() -> SessionContext {
let runtime_environment =
RuntimeEnv::new(RuntimeConfig::new()).unwrap();
@@ -738,4 +774,25 @@ mod test {
Ok(())
}
+
+ // Ballista disables round-robin repartitioning
+ #[tokio::test]
+ async fn should_disable_round_robin_repartition() {
+ let state = SessionState::new_ballista_state(
+ "scheduler_url".to_string(),
+ "session_id".to_string(),
+ )
+ .unwrap();
+
+ assert!(!state.config().round_robin_repartition());
+
+ let state = SessionStateBuilder::new().build();
+
+ assert!(state.config().round_robin_repartition());
+ let state = state
+ .upgrade_for_ballista("scheduler_url".to_string(),
"session_id".to_string())
+ .unwrap();
+
+ assert!(!state.config().round_robin_repartition());
+ }
}
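
A sketch of the planner-override path exercised by the setup.rs test;
`planner` stands in for any user-provided QueryPlanner implementation:

    use std::sync::Arc;

    use ballista::extension::{SessionConfigExt, SessionContextExt};
    use datafusion::execution::{context::QueryPlanner, SessionStateBuilder};
    use datafusion::prelude::{SessionConfig, SessionContext};

    async fn with_custom_planner(
        planner: Arc<dyn QueryPlanner + Send + Sync>,
    ) -> datafusion::error::Result<SessionContext> {
        // the planner travels inside the SessionConfig, so upgrade_for_ballista
        // installs it instead of the default BallistaQueryPlanner
        let config = SessionConfig::new_with_ballista()
            .with_information_schema(true)
            .with_ballista_query_planner(planner);

        let state = SessionStateBuilder::new()
            .with_default_features()
            .with_config(config)
            .build();

        SessionContext::standalone_with_state(state).await
    }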
diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml
index ed7c4318..b04abd9d 100644
--- a/ballista/executor/Cargo.toml
+++ b/ballista/executor/Cargo.toml
@@ -41,7 +41,7 @@ anyhow = "1"
arrow = { workspace = true }
arrow-flight = { workspace = true }
async-trait = { workspace = true }
-ballista-core = { path = "../core", version = "0.12.0", features = ["s3"] }
+ballista-core = { path = "../core", version = "0.12.0" }
configure_me = { workspace = true }
dashmap = { workspace = true }
datafusion = { workspace = true }
diff --git a/ballista/executor/src/bin/main.rs b/ballista/executor/src/bin/main.rs
index ba56b333..9f5ed12f 100644
--- a/ballista/executor/src/bin/main.rs
+++ b/ballista/executor/src/bin/main.rs
@@ -87,6 +87,11 @@ async fn main() -> Result<()> {
cache_capacity: opt.cache_capacity,
cache_io_concurrency: opt.cache_io_concurrency,
execution_engine: None,
+ function_registry: None,
+ config_producer: None,
+ runtime_producer: None,
+ logical_codec: None,
+ physical_codec: None,
};
start_executor_process(Arc::new(config)).await
diff --git a/ballista/executor/src/execution_loop.rs b/ballista/executor/src/execution_loop.rs
index 591c5c45..8056d6c5 100644
--- a/ballista/executor/src/execution_loop.rs
+++ b/ballista/executor/src/execution_loop.rs
@@ -15,40 +15,30 @@
// specific language governing permissions and limitations
// under the License.
-use datafusion::config::ConfigOptions;
-use datafusion::physical_plan::ExecutionPlan;
-
-use ballista_core::serde::protobuf::{
- scheduler_grpc_client::SchedulerGrpcClient, PollWorkParams, PollWorkResult,
- TaskDefinition, TaskStatus,
-};
-use datafusion::prelude::SessionConfig;
-use tokio::sync::{OwnedSemaphorePermit, Semaphore};
-
use crate::cpu_bound_executor::DedicatedExecutor;
use crate::executor::Executor;
use crate::{as_task_status, TaskExecutionTimes};
use ballista_core::error::BallistaError;
+use ballista_core::serde::protobuf::{
+ scheduler_grpc_client::SchedulerGrpcClient, PollWorkParams, PollWorkResult,
+ TaskDefinition, TaskStatus,
+};
use ballista_core::serde::scheduler::{ExecutorSpecification, PartitionId};
use ballista_core::serde::BallistaCodec;
use datafusion::execution::context::TaskContext;
-use datafusion::functions::datetime::date_part;
-use datafusion::functions::unicode::substr;
-use datafusion::functions_aggregate::covariance::{covar_pop_udaf, covar_samp_udaf};
-use datafusion::functions_aggregate::sum::sum_udaf;
-use datafusion::functions_aggregate::variance::var_samp_udaf;
+use datafusion::physical_plan::ExecutionPlan;
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
use futures::FutureExt;
use log::{debug, error, info, warn};
use std::any::Any;
-use std::collections::HashMap;
use std::convert::TryInto;
use std::error::Error;
use std::ops::Deref;
use std::sync::mpsc::{Receiver, Sender, TryRecvError};
use std::time::{SystemTime, UNIX_EPOCH};
use std::{sync::Arc, time::Duration};
+use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tonic::transport::Channel;
pub async fn poll_loop<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan>(
@@ -172,43 +162,20 @@ async fn run_received_task<T: 'static + AsLogicalPlan, U: 'static + AsExecutionP
let task_identity = format!(
"TID {task_id}
{job_id}/{stage_id}.{stage_attempt_num}/{partition_id}.{task_attempt_num}"
);
- info!("Received task {}", task_identity);
-
- let mut task_props = HashMap::new();
+ info!(
+ "Received task: {}, task_properties: {:?}",
+ task_identity, task.props
+ );
+ let mut session_config = executor.produce_config();
for kv_pair in task.props {
- task_props.insert(kv_pair.key, kv_pair.value);
+ session_config = session_config.set_str(&kv_pair.key, &kv_pair.value);
}
- let mut config = ConfigOptions::new();
- for (k, v) in task_props {
- config.set(&k, &v)?;
- }
- let session_config = SessionConfig::from(config);
- let mut task_scalar_functions = HashMap::new();
- let mut task_aggregate_functions = HashMap::new();
- let mut task_window_functions = HashMap::new();
- // TODO combine the functions from Executor's functions and TaskDefintion's function resources
- for scalar_func in executor.scalar_functions.clone() {
- task_scalar_functions.insert(scalar_func.0.clone(), scalar_func.1);
- }
- for agg_func in executor.aggregate_functions.clone() {
- task_aggregate_functions.insert(agg_func.0, agg_func.1);
- }
- // since DataFusion 38 some internal functions were converted to UDAF, so
- // we have to register them manually
- task_aggregate_functions.insert("var".to_string(), var_samp_udaf());
- task_aggregate_functions.insert("covar_samp".to_string(),
covar_samp_udaf());
- task_aggregate_functions.insert("covar_pop".to_string(), covar_pop_udaf());
- task_aggregate_functions.insert("SUM".to_string(), sum_udaf());
+ let task_scalar_functions = executor.function_registry.scalar_functions.clone();
+ let task_aggregate_functions = executor.function_registry.aggregate_functions.clone();
+ let task_window_functions = executor.function_registry.window_functions.clone();
- // TODO which other functions need adding here?
- task_scalar_functions.insert("date_part".to_string(), date_part());
- task_scalar_functions.insert("substr".to_string(), substr());
-
- for window_func in executor.window_functions.clone() {
- task_window_functions.insert(window_func.0, window_func.1);
- }
- let runtime = executor.get_runtime();
+ let runtime = executor.produce_runtime(&session_config)?;
let session_id = task.session_id.clone();
let task_context = Arc::new(TaskContext::new(
Some(task_identity.clone()),
diff --git a/ballista/executor/src/executor.rs b/ballista/executor/src/executor.rs
index 8ae8e6aa..53a36855 100644
--- a/ballista/executor/src/executor.rs
+++ b/ballista/executor/src/executor.rs
@@ -21,18 +21,19 @@ use crate::execution_engine::DefaultExecutionEngine;
use crate::execution_engine::ExecutionEngine;
use crate::execution_engine::QueryStageExecutor;
use crate::metrics::ExecutorMetricsCollector;
+use crate::metrics::LoggingMetricsCollector;
use ballista_core::error::BallistaError;
use ballista_core::serde::protobuf;
use ballista_core::serde::protobuf::ExecutorRegistration;
+use ballista_core::serde::scheduler::BallistaFunctionRegistry;
use ballista_core::serde::scheduler::PartitionId;
+use ballista_core::ConfigProducer;
+use ballista_core::RuntimeProducer;
use dashmap::DashMap;
use datafusion::execution::context::TaskContext;
use datafusion::execution::runtime_env::RuntimeEnv;
-use datafusion::functions::all_default_functions;
-use datafusion::functions_aggregate::all_default_aggregate_functions;
-use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
+use datafusion::prelude::SessionConfig;
use futures::future::AbortHandle;
-use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
@@ -63,17 +64,14 @@ pub struct Executor {
/// Directory for storing partial results
pub work_dir: String,
- /// Scalar functions that are registered in the Executor
- pub scalar_functions: HashMap<String, Arc<ScalarUDF>>,
+ /// Function registry
+ pub function_registry: Arc<BallistaFunctionRegistry>,
- /// Aggregate functions registered in the Executor
- pub aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
+ /// Creates [RuntimeEnv] based on [SessionConfig]
+ pub runtime_producer: RuntimeProducer,
- /// Window functions registered in the Executor
- pub window_functions: HashMap<String, Arc<WindowUDF>>,
-
- /// Runtime environment for Executor
- runtime: Arc<RuntimeEnv>,
+ /// Creates default [SessionConfig]
+ pub config_producer: ConfigProducer,
/// Collector for runtime execution metrics
pub metrics_collector: Arc<dyn ExecutorMetricsCollector>,
@@ -90,33 +88,47 @@ pub struct Executor {
}
impl Executor {
- /// Create a new executor instance
+ /// Create a new executor instance with the given [RuntimeProducer] and [ConfigProducer].
+ /// It will use default scalar, aggregate and window functions
+ pub fn new_basic(
+ metadata: ExecutorRegistration,
+ work_dir: &str,
+ runtime_producer: RuntimeProducer,
+ config_producer: ConfigProducer,
+ concurrent_tasks: usize,
+ ) -> Self {
+ Self::new(
+ metadata,
+ work_dir,
+ runtime_producer,
+ config_producer,
+ Arc::new(BallistaFunctionRegistry::default()),
+ Arc::new(LoggingMetricsCollector::default()),
+ concurrent_tasks,
+ None,
+ )
+ }
+
+ /// Create a new executor instance with the given [RuntimeProducer],
+ /// [ConfigProducer] and [BallistaFunctionRegistry]
+
+ #[allow(clippy::too_many_arguments)]
pub fn new(
metadata: ExecutorRegistration,
work_dir: &str,
- runtime: Arc<RuntimeEnv>,
+ runtime_producer: RuntimeProducer,
+ config_producer: ConfigProducer,
+ function_registry: Arc<BallistaFunctionRegistry>,
metrics_collector: Arc<dyn ExecutorMetricsCollector>,
concurrent_tasks: usize,
execution_engine: Option<Arc<dyn ExecutionEngine>>,
) -> Self {
- let scalar_functions = all_default_functions()
- .into_iter()
- .map(|f| (f.name().to_string(), f))
- .collect();
-
- let aggregate_functions = all_default_aggregate_functions()
- .into_iter()
- .map(|f| (f.name().to_string(), f))
- .collect();
-
Self {
metadata,
work_dir: work_dir.to_owned(),
- scalar_functions,
- aggregate_functions,
- // TODO: set to default window functions when they are moved to udwf
- window_functions: HashMap::new(),
- runtime,
+ function_registry,
+ runtime_producer,
+ config_producer,
metrics_collector,
concurrent_tasks,
abort_handles: Default::default(),
@@ -127,8 +139,15 @@ impl Executor {
}
impl Executor {
- pub fn get_runtime(&self) -> Arc<RuntimeEnv> {
- self.runtime.clone()
+ pub fn produce_runtime(
+ &self,
+ config: &SessionConfig,
+ ) -> datafusion::error::Result<Arc<RuntimeEnv>> {
+ (self.runtime_producer)(config)
+ }
+
+ pub fn produce_config(&self) -> SessionConfig {
+ (self.config_producer)()
}
/// Execute one partition of a query stage and persist the result to disk in IPC format. On
@@ -197,12 +216,13 @@ impl Executor {
mod test {
use crate::execution_engine::DefaultQueryStageExec;
use crate::executor::Executor;
- use crate::metrics::LoggingMetricsCollector;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
+ use ballista_core::config::BallistaConfig;
use ballista_core::execution_plans::ShuffleWriterExec;
use ballista_core::serde::protobuf::ExecutorRegistration;
use ballista_core::serde::scheduler::PartitionId;
+ use ballista_core::RuntimeProducer;
use datafusion::error::{DataFusionError, Result};
use datafusion::execution::context::TaskContext;
@@ -210,7 +230,7 @@ mod test {
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
RecordBatchStream, SendableRecordBatchStream, Statistics,
};
- use datafusion::prelude::SessionContext;
+ use datafusion::prelude::{SessionConfig, SessionContext};
use futures::Stream;
use std::any::Any;
use std::pin::Pin;
@@ -341,16 +361,20 @@ mod test {
specification: None,
optional_host: None,
};
-
+ let config_producer = Arc::new(|| {
+ SessionConfig::new().with_option_extension(BallistaConfig::new().unwrap())
+ });
let ctx = SessionContext::new();
+ let runtime_env = ctx.runtime_env().clone();
+ let runtime_producer: RuntimeProducer =
+ Arc::new(move |_| Ok(runtime_env.clone()));
- let executor = Executor::new(
+ let executor = Executor::new_basic(
executor_registration,
&work_dir,
- ctx.runtime_env(),
- Arc::new(LoggingMetricsCollector {}),
+ runtime_producer,
+ config_producer,
2,
- None,
);
let (sender, receiver) = tokio::sync::oneshot::channel();
diff --git a/ballista/executor/src/executor_process.rs b/ballista/executor/src/executor_process.rs
index c19f0656..a15bfadb 100644
--- a/ballista/executor/src/executor_process.rs
+++ b/ballista/executor/src/executor_process.rs
@@ -25,6 +25,10 @@ use std::{env, io};
use anyhow::{Context, Result};
use arrow_flight::flight_service_server::FlightServiceServer;
+use ballista_core::serde::scheduler::BallistaFunctionRegistry;
+use datafusion::prelude::SessionConfig;
+use datafusion_proto::logical_plan::LogicalExtensionCodec;
+use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use futures::stream::FuturesUnordered;
use futures::StreamExt;
use log::{error, info, warn};
@@ -38,11 +42,11 @@ use tracing_subscriber::EnvFilter;
use uuid::Uuid;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
-use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
-use ballista_core::config::{DataCachePolicy, LogRotationPolicy, TaskSchedulingPolicy};
+use ballista_core::config::{
+ BallistaConfig, DataCachePolicy, LogRotationPolicy, TaskSchedulingPolicy,
+};
use ballista_core::error::BallistaError;
-use ballista_core::object_store_registry::with_object_store_registry;
use ballista_core::serde::protobuf::executor_resource::Resource;
use ballista_core::serde::protobuf::executor_status::Status;
use ballista_core::serde::protobuf::{
@@ -50,11 +54,13 @@ use ballista_core::serde::protobuf::{
ExecutorRegistration, ExecutorResource, ExecutorSpecification,
ExecutorStatus,
ExecutorStoppedParams, HeartBeatParams,
};
-use ballista_core::serde::BallistaCodec;
+use ballista_core::serde::{
+ BallistaCodec, BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec,
+};
use ballista_core::utils::{
create_grpc_client_connection, create_grpc_server, get_time_before,
};
-use ballista_core::BALLISTA_VERSION;
+use ballista_core::{ConfigProducer, RuntimeProducer, BALLISTA_VERSION};
use crate::execution_engine::ExecutionEngine;
use crate::executor::{Executor, TasksDrainedFuture};
@@ -96,6 +102,16 @@ pub struct ExecutorProcessConfig {
/// Optional execution engine to use to execute physical plans, will default to
/// DataFusion if none is provided.
pub execution_engine: Option<Arc<dyn ExecutionEngine>>,
+ /// Overrides default function registry
+ pub function_registry: Option<Arc<BallistaFunctionRegistry>>,
+ /// [RuntimeProducer] override option
+ pub runtime_producer: Option<RuntimeProducer>,
+ /// [ConfigProducer] override option
+ pub config_producer: Option<ConfigProducer>,
+ /// [LogicalExtensionCodec] override option
+ pub logical_codec: Option<Arc<dyn LogicalExtensionCodec>>,
+ /// [PhysicalExtensionCodec] override option
+ pub physical_codec: Option<Arc<dyn PhysicalExtensionCodec>>,
}
pub async fn start_executor_process(opt: Arc<ExecutorProcessConfig>) -> Result<()> {
@@ -181,20 +197,40 @@ pub async fn start_executor_process(opt: Arc<ExecutorProcessConfig>) -> Result<(
}),
};
- let config = RuntimeConfig::new().with_temp_file_path(work_dir.clone());
- let runtime = {
- let config = with_object_store_registry(config.clone());
- Arc::new(RuntimeEnv::new(config).map_err(|_| {
- BallistaError::Internal("Failed to init Executor
RuntimeEnv".to_owned())
- })?)
- };
-
+ // put them into the session config
let metrics_collector = Arc::new(LoggingMetricsCollector::default());
+ let config_producer = opt.config_producer.clone().unwrap_or_else(|| {
+ Arc::new(|| {
+ SessionConfig::new().with_option_extension(BallistaConfig::new().unwrap())
+ })
+ });
+ let wd = work_dir.clone();
+ let runtime_producer: RuntimeProducer = Arc::new(move |_| {
+ let config = RuntimeConfig::new().with_temp_file_path(wd.clone());
+ Ok(Arc::new(RuntimeEnv::new(config)?))
+ });
+
+ let logical = opt
+ .logical_codec
+ .clone()
+ .unwrap_or_else(|| Arc::new(BallistaLogicalExtensionCodec::default()));
+
+ let physical = opt
+ .physical_codec
+ .clone()
+ .unwrap_or_else(|| Arc::new(BallistaPhysicalExtensionCodec::default()));
+
+ let default_codec: BallistaCodec<
+ datafusion_proto::protobuf::LogicalPlanNode,
+ datafusion_proto::protobuf::PhysicalPlanNode,
+ > = BallistaCodec::new(logical, physical);
let executor = Arc::new(Executor::new(
executor_meta,
&work_dir,
- runtime,
+ runtime_producer,
+ config_producer,
+ opt.function_registry.clone().unwrap_or_default(),
metrics_collector,
concurrent_tasks,
opt.execution_engine.clone(),
@@ -244,9 +280,6 @@ pub async fn start_executor_process(opt: Arc<ExecutorProcessConfig>) -> Result<(
.max_encoding_message_size(opt.grpc_max_encoding_message_size as usize)
.max_decoding_message_size(opt.grpc_max_decoding_message_size as usize);
- let default_codec: BallistaCodec<LogicalPlanNode, PhysicalPlanNode> =
- BallistaCodec::default();
-
let scheduler_policy = opt.task_scheduling_policy;
let job_data_ttl_seconds = opt.job_data_ttl_seconds;
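
A sketch of wiring the new override fields; `base` is assumed to be an
otherwise fully populated ExecutorProcessConfig, and the executor_process
module path is assumed:

    use std::sync::Arc;

    use ballista_core::serde::scheduler::BallistaFunctionRegistry;
    use ballista_core::{ConfigProducer, RuntimeProducer};
    use ballista_executor::executor_process::{
        start_executor_process, ExecutorProcessConfig,
    };

    async fn run(
        base: ExecutorProcessConfig,
        runtime_producer: RuntimeProducer,
        config_producer: ConfigProducer,
    ) -> anyhow::Result<()> {
        let config = ExecutorProcessConfig {
            function_registry: Some(Arc::new(BallistaFunctionRegistry::default())),
            runtime_producer: Some(runtime_producer),
            config_producer: Some(config_producer),
            // None falls back to the default Ballista codecs
            logical_codec: None,
            physical_codec: None,
            ..base
        };
        start_executor_process(Arc::new(config)).await
    }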
diff --git a/ballista/executor/src/executor_server.rs b/ballista/executor/src/executor_server.rs
index 6e3d5589..cfbc2bd4 100644
--- a/ballista/executor/src/executor_server.rs
+++ b/ballista/executor/src/executor_server.rs
@@ -46,9 +46,7 @@ use ballista_core::serde::scheduler::TaskDefinition;
use ballista_core::serde::BallistaCodec;
use ballista_core::utils::{create_grpc_client_connection, create_grpc_server};
use dashmap::DashMap;
-use datafusion::config::ConfigOptions;
use datafusion::execution::TaskContext;
-use datafusion::prelude::SessionConfig;
use datafusion_proto::{logical_plan::AsLogicalPlan, physical_plan::AsExecutionPlan};
use tokio::sync::mpsc::error::TryRecvError;
use tokio::task::JoinHandle;
@@ -342,22 +340,13 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorServer<T,
.unwrap();
let task_context = {
- let task_props = task.props;
- let mut config = ConfigOptions::new();
- for (k, v) in task_props.iter() {
- if let Err(e) = config.set(k, v) {
- debug!("Fail to set session config for ({},{}): {:?}", k,
v, e);
- }
- }
- let session_config = SessionConfig::from(config);
-
let function_registry = task.function_registry;
- let runtime = self.executor.get_runtime();
+ let runtime = self.executor.produce_runtime(&task.session_config).unwrap();
Arc::new(TaskContext::new(
Some(task_identity.clone()),
task.session_id,
- session_config,
+ task.session_config,
function_registry.scalar_functions.clone(),
function_registry.aggregate_functions.clone(),
function_registry.window_functions.clone(),
@@ -641,10 +630,11 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorGrpc
scheduler_id: scheduler_id.clone(),
task: get_task_definition(
task,
- self.executor.get_runtime(),
- self.executor.scalar_functions.clone(),
- self.executor.aggregate_functions.clone(),
- self.executor.window_functions.clone(),
+ self.executor.runtime_producer.clone(),
+ self.executor.produce_config(),
+                        self.executor.function_registry.scalar_functions.clone(),
+                        self.executor.function_registry.aggregate_functions.clone(),
+                        self.executor.function_registry.window_functions.clone(),
self.codec.clone(),
)
.map_err(|e| Status::invalid_argument(format!("{e}")))?,
@@ -669,10 +659,11 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> ExecutorGrpc
for multi_task in multi_tasks {
let multi_task: Vec<TaskDefinition> = get_task_definition_vec(
multi_task,
- self.executor.get_runtime(),
- self.executor.scalar_functions.clone(),
- self.executor.aggregate_functions.clone(),
- self.executor.window_functions.clone(),
+ self.executor.runtime_producer.clone(),
+ self.executor.produce_config(),
+ self.executor.function_registry.scalar_functions.clone(),
+ self.executor.function_registry.aggregate_functions.clone(),
+ self.executor.function_registry.window_functions.clone(),
self.codec.clone(),
)
.map_err(|e| Status::invalid_argument(format!("{e}")))?;
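
The task context is now built from the task's own SessionConfig, with the
runtime produced on demand rather than read from a shared executor field. A
self-contained sketch of that shape using DataFusion's public TaskContext API
(the ids and empty function maps are placeholders):

    use std::collections::HashMap;
    use std::sync::Arc;

    use datafusion::error::Result;
    use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
    use datafusion::execution::TaskContext;
    use datafusion::prelude::SessionConfig;

    fn task_context_for(session_config: SessionConfig) -> Result<Arc<TaskContext>> {
        // In the executor this RuntimeEnv comes from the RuntimeProducer,
        // keyed off the task's session config.
        let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new())?);
        Ok(Arc::new(TaskContext::new(
            Some("task-0".to_string()), // task identity
            "session-0".to_string(),    // session id
            session_config,
            HashMap::new(), // scalar functions
            HashMap::new(), // aggregate functions
            HashMap::new(), // window functions
            runtime,
        )))
    }
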
diff --git a/ballista/executor/src/lib.rs b/ballista/executor/src/lib.rs
index beb9faac..b7219225 100644
--- a/ballista/executor/src/lib.rs
+++ b/ballista/executor/src/lib.rs
@@ -32,6 +32,7 @@ mod cpu_bound_executor;
mod standalone;
pub use standalone::new_standalone_executor;
+pub use standalone::new_standalone_executor_from_state;
use log::info;
diff --git a/ballista/executor/src/standalone.rs b/ballista/executor/src/standalone.rs
index 38e27713..628de96f 100644
--- a/ballista/executor/src/standalone.rs
+++ b/ballista/executor/src/standalone.rs
@@ -18,6 +18,8 @@
use crate::metrics::LoggingMetricsCollector;
use crate::{execution_loop, executor::Executor, flight_service::BallistaFlightService};
use arrow_flight::flight_service_server::FlightServiceServer;
+use ballista_core::config::BallistaConfig;
+use ballista_core::utils::SessionConfigExt;
use ballista_core::{
error::Result,
object_store_registry::with_object_store_registry,
@@ -28,7 +30,10 @@ use ballista_core::{
utils::create_grpc_server,
BALLISTA_VERSION,
};
+use ballista_core::{ConfigProducer, RuntimeProducer};
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
+use datafusion::execution::SessionState;
+use datafusion::prelude::SessionConfig;
use datafusion_proto::logical_plan::AsLogicalPlan;
use datafusion_proto::physical_plan::AsExecutionPlan;
use log::info;
@@ -38,14 +43,26 @@ use tokio::net::TcpListener;
use tonic::transport::Channel;
use uuid::Uuid;
-pub async fn new_standalone_executor<
+/// Creates a new standalone executor based on the
+/// provided `session_state`.
+///
+/// This provides a flexible way of configuring the
+/// underlying components.
+pub async fn new_standalone_executor_from_state<
T: 'static + AsLogicalPlan,
U: 'static + AsExecutionPlan,
>(
scheduler: SchedulerGrpcClient<Channel>,
concurrent_tasks: usize,
- codec: BallistaCodec<T, U>,
+ session_state: &SessionState,
) -> Result<()> {
+ let logical = session_state.config().ballista_logical_extension_codec();
+ let physical = session_state.config().ballista_physical_extension_codec();
+ let codec: BallistaCodec<
+ datafusion_proto::protobuf::LogicalPlanNode,
+ datafusion_proto::protobuf::PhysicalPlanNode,
+ > = BallistaCodec::new(logical, physical);
+
// Let the OS assign a random, free port
let listener = TcpListener::bind("localhost:0").await?;
let addr = listener.local_addr()?;
@@ -74,14 +91,21 @@ pub async fn new_standalone_executor<
.unwrap();
info!("work_dir: {}", work_dir);
- let config = with_object_store_registry(
- RuntimeConfig::new().with_temp_file_path(work_dir.clone()),
- );
+ let config = session_state
+ .config()
+ .clone()
+ .with_option_extension(BallistaConfig::new().unwrap());
+ let runtime = session_state.runtime_env().clone();
+
+ let config_producer: ConfigProducer = Arc::new(move || config.clone());
+    let runtime_producer: RuntimeProducer = Arc::new(move |_| Ok(runtime.clone()));
let executor = Arc::new(Executor::new(
executor_meta,
&work_dir,
- Arc::new(RuntimeEnv::new(config).unwrap()),
+ runtime_producer,
+ config_producer,
+ Arc::new(session_state.into()),
Arc::new(LoggingMetricsCollector::default()),
concurrent_tasks,
None,
@@ -100,3 +124,74 @@ pub async fn new_standalone_executor<
tokio::spawn(execution_loop::poll_loop(scheduler, executor, codec));
Ok(())
}
+
+/// Creates a standalone executor with most values
+/// set to their defaults.
+pub async fn new_standalone_executor<
+ T: 'static + AsLogicalPlan,
+ U: 'static + AsExecutionPlan,
+>(
+ scheduler: SchedulerGrpcClient<Channel>,
+ concurrent_tasks: usize,
+ codec: BallistaCodec<T, U>,
+) -> Result<()> {
+ // Let the OS assign a random, free port
+ let listener = TcpListener::bind("localhost:0").await?;
+ let addr = listener.local_addr()?;
+ info!(
+ "Ballista v{} Rust Executor listening on {:?}",
+ BALLISTA_VERSION, addr
+ );
+
+ let executor_meta = ExecutorRegistration {
+ id: Uuid::new_v4().to_string(), // assign this executor a unique ID
+ optional_host: Some(OptionalHost::Host("localhost".to_string())),
+ port: addr.port() as u32,
+ // TODO Make it configurable
+ grpc_port: 50020,
+ specification: Some(
+ ExecutorSpecification {
+ task_slots: concurrent_tasks as u32,
+ }
+ .into(),
+ ),
+ };
+ let work_dir = TempDir::new()?
+ .into_path()
+ .into_os_string()
+ .into_string()
+ .unwrap();
+ info!("work_dir: {}", work_dir);
+
+ let config_producer = Arc::new(|| {
+        SessionConfig::new().with_option_extension(BallistaConfig::new().unwrap())
+ });
+ let wd = work_dir.clone();
+ let runtime_producer: RuntimeProducer = Arc::new(move |_: &SessionConfig| {
+ let config = with_object_store_registry(
+ RuntimeConfig::new().with_temp_file_path(wd.clone()),
+ );
+ Ok(Arc::new(RuntimeEnv::new(config)?))
+ });
+
+ let executor = Arc::new(Executor::new_basic(
+ executor_meta,
+ &work_dir,
+ runtime_producer,
+ config_producer,
+ concurrent_tasks,
+ ));
+
+ let service = BallistaFlightService::new();
+ let server = FlightServiceServer::new(service);
+ tokio::spawn(
+ create_grpc_server()
+ .add_service(server)
+            .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(
+ listener,
+ )),
+ );
+
+ tokio::spawn(execution_loop::poll_loop(scheduler, executor, codec));
+ Ok(())
+}
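
A hypothetical caller of the two executor entry points. The SchedulerGrpcClient
import path and the turbofish on the from_state variant are assumptions; the
codec annotation mirrors the one removed earlier in this commit:

    use ballista_core::error::Result;
    use ballista_core::serde::protobuf::scheduler_grpc_client::SchedulerGrpcClient;
    use ballista_core::serde::BallistaCodec;
    use ballista_executor::{new_standalone_executor, new_standalone_executor_from_state};
    use datafusion::execution::SessionStateBuilder;
    use datafusion_proto::protobuf::{LogicalPlanNode, PhysicalPlanNode};
    use tonic::transport::Channel;

    async fn start_executors(scheduler: SchedulerGrpcClient<Channel>) -> Result<()> {
        // Default path: stock codecs, runtime, and session config.
        let codec: BallistaCodec<LogicalPlanNode, PhysicalPlanNode> =
            BallistaCodec::default();
        new_standalone_executor(scheduler.clone(), 4, codec).await?;

        // Customized path: codecs, runtime, and config come from the state.
        let state = SessionStateBuilder::new().with_default_features().build();
        new_standalone_executor_from_state::<LogicalPlanNode, PhysicalPlanNode>(
            scheduler, 4, &state,
        )
        .await
    }
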
diff --git a/ballista/scheduler/Cargo.toml b/ballista/scheduler/Cargo.toml
index a1d59735..642e63d4 100644
--- a/ballista/scheduler/Cargo.toml
+++ b/ballista/scheduler/Cargo.toml
@@ -45,7 +45,7 @@ anyhow = "1"
arrow-flight = { workspace = true }
async-trait = { workspace = true }
axum = "0.7.7"
-ballista-core = { path = "../core", version = "0.12.0", features = ["s3"] }
+ballista-core = { path = "../core", version = "0.12.0" }
base64 = { version = "0.22" }
clap = { workspace = true }
configure_me = { workspace = true }
diff --git a/ballista/scheduler/src/cluster/memory.rs b/ballista/scheduler/src/cluster/memory.rs
index f2fe589a..861b8657 100644
--- a/ballista/scheduler/src/cluster/memory.rs
+++ b/ballista/scheduler/src/cluster/memory.rs
@@ -401,7 +401,7 @@ impl JobState for InMemoryJobState {
&self,
config: &BallistaConfig,
) -> Result<Arc<SessionContext>> {
- let session = create_datafusion_context(config, self.session_builder);
+        let session = create_datafusion_context(config, self.session_builder.clone());
self.sessions.insert(session.session_id(), session.clone());
Ok(session)
@@ -412,7 +412,7 @@ impl JobState for InMemoryJobState {
session_id: &str,
config: &BallistaConfig,
) -> Result<Arc<SessionContext>> {
- let session = create_datafusion_context(config, self.session_builder);
+        let session = create_datafusion_context(config, self.session_builder.clone());
self.sessions
.insert(session_id.to_string(), session.clone());
@@ -486,6 +486,8 @@ impl JobState for InMemoryJobState {
#[cfg(test)]
mod test {
+ use std::sync::Arc;
+
use crate::cluster::memory::InMemoryJobState;
    use crate::cluster::test_util::{test_job_lifecycle, test_job_planning_failure};
use crate::test_utils::{
@@ -497,17 +499,17 @@ mod test {
#[tokio::test]
async fn test_in_memory_job_lifecycle() -> Result<()> {
test_job_lifecycle(
- InMemoryJobState::new("", default_session_builder),
+ InMemoryJobState::new("", Arc::new(default_session_builder)),
test_aggregation_plan(4).await,
)
.await?;
test_job_lifecycle(
- InMemoryJobState::new("", default_session_builder),
+ InMemoryJobState::new("", Arc::new(default_session_builder)),
test_two_aggregations_plan(4).await,
)
.await?;
test_job_lifecycle(
- InMemoryJobState::new("", default_session_builder),
+ InMemoryJobState::new("", Arc::new(default_session_builder)),
test_join_plan(4).await,
)
.await?;
@@ -518,17 +520,17 @@ mod test {
#[tokio::test]
async fn test_in_memory_job_planning_failure() -> Result<()> {
test_job_planning_failure(
- InMemoryJobState::new("", default_session_builder),
+ InMemoryJobState::new("", Arc::new(default_session_builder)),
test_aggregation_plan(4).await,
)
.await?;
test_job_planning_failure(
- InMemoryJobState::new("", default_session_builder),
+ InMemoryJobState::new("", Arc::new(default_session_builder)),
test_two_aggregations_plan(4).await,
)
.await?;
test_job_planning_failure(
- InMemoryJobState::new("", default_session_builder),
+ InMemoryJobState::new("", Arc::new(default_session_builder)),
test_join_plan(4).await,
)
.await?;
diff --git a/ballista/scheduler/src/cluster/mod.rs b/ballista/scheduler/src/cluster/mod.rs
index 81432056..450c8018 100644
--- a/ballista/scheduler/src/cluster/mod.rs
+++ b/ballista/scheduler/src/cluster/mod.rs
@@ -109,7 +109,7 @@ impl BallistaCluster {
match &config.cluster_storage {
ClusterStorageConfig::Memory => Ok(BallistaCluster::new_memory(
scheduler,
- default_session_builder,
+ Arc::new(default_session_builder),
)),
}
}
diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs
index 653bda83..e475e438 100644
--- a/ballista/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/scheduler/src/scheduler_server/grpc.rs
@@ -424,6 +424,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
} = query_params
{
let mut query_settings = HashMap::new();
+ log::trace!("received query settings: {:?}", settings);
for kv_pair in settings {
query_settings.insert(kv_pair.key, kv_pair.value);
}
@@ -523,6 +524,7 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
.cloned()
.unwrap_or_else(|| "None".to_string());
+ log::trace!("setting job name: {}", job_name);
self.submit_job(&job_id, &job_name, session_ctx, &plan)
.await
.map_err(|e| {
diff --git a/ballista/scheduler/src/scheduler_server/mod.rs b/ballista/scheduler/src/scheduler_server/mod.rs
index 3e2da13b..7ec0e63e 100644
--- a/ballista/scheduler/src/scheduler_server/mod.rs
+++ b/ballista/scheduler/src/scheduler_server/mod.rs
@@ -56,7 +56,7 @@ mod external_scaler;
mod grpc;
pub(crate) mod query_stage_scheduler;
-pub(crate) type SessionBuilder = fn(SessionConfig) -> SessionState;
+pub(crate) type SessionBuilder = Arc<dyn Fn(SessionConfig) -> SessionState + Send + Sync>;
#[derive(Clone)]
pub struct SchedulerServer<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> {
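
Because SessionBuilder is now an Arc<dyn Fn ...> rather than a plain fn
pointer, a builder can capture state. A minimal sketch (spelled out with the
full trait object type, since the alias itself is pub(crate)):

    use std::sync::Arc;

    use datafusion::execution::{SessionState, SessionStateBuilder};
    use datafusion::prelude::SessionConfig;

    fn make_builder(
        template: SessionState,
    ) -> Arc<dyn Fn(SessionConfig) -> SessionState + Send + Sync> {
        // Capturing a template state like this was impossible with the old
        // `fn(SessionConfig) -> SessionState` pointer type.
        Arc::new(move |c: SessionConfig| {
            SessionStateBuilder::new_from_existing(template.clone())
                .with_config(c)
                .build()
        })
    }
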
diff --git a/ballista/scheduler/src/standalone.rs b/ballista/scheduler/src/standalone.rs
index bb6d7006..5ff4d611 100644
--- a/ballista/scheduler/src/standalone.rs
+++ b/ballista/scheduler/src/standalone.rs
@@ -20,11 +20,15 @@ use crate::config::SchedulerConfig;
use crate::metrics::default_metrics_collector;
use crate::scheduler_server::SchedulerServer;
use ballista_core::serde::BallistaCodec;
-use ballista_core::utils::{create_grpc_server, default_session_builder};
+use ballista_core::utils::{
+ create_grpc_server, default_session_builder, SessionConfigExt,
+};
use ballista_core::{
error::Result, serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer,
BALLISTA_VERSION,
};
+use datafusion::execution::{SessionState, SessionStateBuilder};
+use datafusion::prelude::SessionConfig;
use datafusion_proto::protobuf::LogicalPlanNode;
use datafusion_proto::protobuf::PhysicalPlanNode;
use log::info;
@@ -33,15 +37,39 @@ use std::sync::Arc;
use tokio::net::TcpListener;
pub async fn new_standalone_scheduler() -> Result<SocketAddr> {
- let metrics_collector = default_metrics_collector()?;
+ let codec = BallistaCodec::default();
+    new_standalone_scheduler_with_builder(Arc::new(default_session_builder), codec).await
+}
+
+pub async fn new_standalone_scheduler_from_state(
+ session_state: &SessionState,
+) -> Result<SocketAddr> {
+ let logical = session_state.config().ballista_logical_extension_codec();
+ let physical = session_state.config().ballista_physical_extension_codec();
+ let codec = BallistaCodec::new(logical, physical);
-    let cluster = BallistaCluster::new_memory("localhost:50050", default_session_builder);
+ let session_state = session_state.clone();
+ let session_builder = Arc::new(move |c: SessionConfig| {
+ SessionStateBuilder::new_from_existing(session_state.clone())
+ .with_config(c)
+ .build()
+ });
+
+ new_standalone_scheduler_with_builder(session_builder, codec).await
+}
+
+async fn new_standalone_scheduler_with_builder(
+ session_builder: crate::scheduler_server::SessionBuilder,
+ codec: BallistaCodec,
+) -> Result<SocketAddr> {
+    let cluster = BallistaCluster::new_memory("localhost:50050", session_builder);
+ let metrics_collector = default_metrics_collector()?;
    let mut scheduler_server: SchedulerServer<LogicalPlanNode, PhysicalPlanNode> =
SchedulerServer::new(
"localhost:50050".to_owned(),
cluster,
- BallistaCodec::default(),
+ codec,
Arc::new(SchedulerConfig::default()),
metrics_collector,
);
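
A hypothetical usage of the stock and state-based scheduler constructors; the
ballista_scheduler::standalone module path and the with_default_features call
are assumptions:

    use std::net::SocketAddr;

    use ballista_core::error::Result;
    use ballista_scheduler::standalone::{
        new_standalone_scheduler, new_standalone_scheduler_from_state,
    };
    use datafusion::execution::SessionStateBuilder;

    async fn start_schedulers() -> Result<(SocketAddr, SocketAddr)> {
        // Stock scheduler: default session builder and codecs.
        let default_addr = new_standalone_scheduler().await?;

        // Scheduler whose sessions derive from a caller-provided state,
        // including any Ballista codecs registered on its config.
        let state = SessionStateBuilder::new().with_default_features().build();
        let custom_addr = new_standalone_scheduler_from_state(&state).await?;

        Ok((default_addr, custom_addr))
    }
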
diff --git a/ballista/scheduler/src/test_utils.rs b/ballista/scheduler/src/test_utils.rs
index 27bc0ec8..f9eae315 100644
--- a/ballista/scheduler/src/test_utils.rs
+++ b/ballista/scheduler/src/test_utils.rs
@@ -124,7 +124,7 @@ pub async fn await_condition<Fut: Future<Output = Result<bool>>, F: Fn() -> Fut>
}
pub fn test_cluster_context() -> BallistaCluster {
- BallistaCluster::new_memory(TEST_SCHEDULER_NAME, default_session_builder)
+    BallistaCluster::new_memory(TEST_SCHEDULER_NAME, Arc::new(default_session_builder))
}
pub async fn datafusion_test_context(path: &str) -> Result<SessionContext> {