This is an automated email from the ASF dual-hosted git repository.
milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git
The following commit(s) were added to refs/heads/main by this push:
new 771e28091 feat: add `Dataframe.cache()` factory (no planner handling) (#1420)
771e28091 is described below
commit 771e280916aa519d3a89f0ba1a8ec71b3df86ea1
Author: jgrim <[email protected]>
AuthorDate: Wed Feb 18 14:30:36 2026 +0100
feat: add `Dataframe.cache()` factory (no planner handling) (#1420)
* feat: disable DataFrame.cache() for ballista
* add failing test for cache collect
* feat: Add logical plan cache node extension
* fix: update after review
* fix: update after review (cleanup tests, update ballista cache noop)
* fix: add missing default_codec usage
---
ballista/client/tests/context_unsupported.rs | 36 +++++++++
ballista/core/proto/ballista.proto | 14 ++++
ballista/core/src/config.rs | 27 +++++--
ballista/core/src/extension.rs | 110 +++++++++++++++++++++++++-
ballista/core/src/serde/generated/ballista.rs | 25 ++++++
ballista/core/src/serde/mod.rs | 51 +++++++++++-
6 files changed, 250 insertions(+), 13 deletions(-)
diff --git a/ballista/client/tests/context_unsupported.rs b/ballista/client/tests/context_unsupported.rs
index 8cda528ee..3c6b25a98 100644
--- a/ballista/client/tests/context_unsupported.rs
+++ b/ballista/client/tests/context_unsupported.rs
@@ -59,6 +59,7 @@ mod unsupported {
.await
.unwrap();
}
+
#[rstest]
#[case::standalone(standalone_context())]
#[case::remote(remote_context())]
@@ -70,6 +71,12 @@ mod unsupported {
ctx: SessionContext,
test_data: String,
) {
+ ctx.sql("SET ballista.cache.noop = false")
+ .await
+ .unwrap()
+ .show()
+ .await
+ .unwrap();
let df = ctx
.read_parquet(
&format!("{test_data}/alltypes_plain.parquet"),
@@ -97,6 +104,35 @@ mod unsupported {
assert_batches_eq!(expected, &result);
}
+ #[rstest]
+ #[case::standalone(standalone_context())]
+ #[case::remote(remote_context())]
+ #[tokio::test]
+ async fn should_support_on_cache_collect(
+ #[future(awt)]
+ #[case]
+ ctx: SessionContext,
+ ) -> datafusion::error::Result<()> {
+ // opt out case, should fail
+ ctx.sql("SET ballista.cache.noop = false")
+ .await?
+ .show()
+ .await?;
+ let cached_df = ctx.sql("SELECT 1").await?.cache().await?;
+
+ // Collect fails because extension node is not handled for now by default query planner
+ let result = cached_df.collect().await;
+
+ assert!(result.is_err());
+ let err_msg = result.unwrap_err().to_string();
+ assert!(
+ err_msg.contains("No installed planner was able to convert the
custom node to an execution plan: BallistaCacheNode"),
+ "Expected planner error, got: {err_msg}"
+ );
+
+ Ok(())
+ }
+
#[rstest]
#[case::standalone(standalone_context())]
#[case::remote(remote_context())]
diff --git a/ballista/core/proto/ballista.proto b/ballista/core/proto/ballista.proto
index 8660b6c72..97812bb0d 100644
--- a/ballista/core/proto/ballista.proto
+++ b/ballista/core/proto/ballista.proto
@@ -27,6 +27,20 @@ option java_outer_classname = "BallistaProto";
import "datafusion.proto";
import "datafusion_common.proto";
+///////////////////////////////////////////////////////////////////////////////////////////////////
+// Ballista Logical Plan
+///////////////////////////////////////////////////////////////////////////////////////////////////
+message BallistaLogicalPlanNode {
+ oneof LogicalPlanType {
+ LogicalPlanCacheNode cache_node = 1;
+ }
+}
+
+message LogicalPlanCacheNode {
+ string cache_id = 1;
+ string session_id = 2;
+}
+
///////////////////////////////////////////////////////////////////////////////////////////////////
// Ballista Physical Plan
///////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs
index b9a78647f..688f613c3 100644
--- a/ballista/core/src/config.rs
+++ b/ballista/core/src/config.rs
@@ -31,9 +31,10 @@ use datafusion::{
pub const BALLISTA_JOB_NAME: &str = "ballista.job.name";
/// Configuration key for standalone processing parallelism.
pub const BALLISTA_STANDALONE_PARALLELISM: &str =
"ballista.standalone.parallelism";
-/// max message size for gRPC clients
-pub const BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE: &str =
- "ballista.grpc_client_max_message_size";
+
+/// Configuration key for disabling default cache extension node.
+pub const BALLISTA_CACHE_NOOP: &str = "ballista.cache.noop";
+
/// Configuration key for maximum concurrent shuffle read requests.
pub const BALLISTA_SHUFFLE_READER_MAX_REQUESTS: &str =
"ballista.shuffle.max_concurrent_read_requests";
@@ -44,6 +45,9 @@ pub const BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ: &str =
pub const BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT: &str =
"ballista.shuffle.remote_read_prefer_flight";
+/// max message size for gRPC clients
+pub const BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE: &str =
+ "ballista.grpc_client_max_message_size";
/// Configuration key for gRPC client connection timeout in seconds.
pub const BALLISTA_GRPC_CLIENT_CONNECT_TIMEOUT_SECONDS: &str =
"ballista.grpc.client.connect_timeout_seconds";
@@ -85,10 +89,10 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String,
ConfigEntry>> = LazyLock::new(||
ConfigEntry::new(BALLISTA_STANDALONE_PARALLELISM.to_string(),
"Standalone processing parallelism ".to_string(),
DataType::UInt16,
Some(std::thread::available_parallelism().map(|v|
v.get()).unwrap_or(1).to_string())),
- ConfigEntry::new(BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE.to_string(),
- "Configuration for max message size in gRPC
clients".to_string(),
- DataType::UInt64,
- Some((16 * 1024 * 1024).to_string())),
+ ConfigEntry::new(BALLISTA_CACHE_NOOP.to_string(),
+ "Disable default cache node extension".to_string(),
+ DataType::Boolean,
+ Some((true).to_string())),
ConfigEntry::new(BALLISTA_SHUFFLE_READER_MAX_REQUESTS.to_string(),
"Maximum concurrent requests shuffle reader can
process".to_string(),
DataType::UInt64,
@@ -101,6 +105,10 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String,
ConfigEntry>> = LazyLock::new(||
"Forces the shuffle reader to use flight reader
instead of block reader for remote read. Block reader usually has better
performance and resource utilization".to_string(),
DataType::Boolean,
Some((false).to_string())),
+ ConfigEntry::new(BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE.to_string(),
+ "Configuration for max message size in gRPC
clients".to_string(),
+ DataType::UInt64,
+ Some((16 * 1024 * 1024).to_string())),
ConfigEntry::new(BALLISTA_GRPC_CLIENT_CONNECT_TIMEOUT_SECONDS.to_string(),
"Connection timeout for gRPC client in
seconds".to_string(),
DataType::UInt64,
@@ -283,6 +291,11 @@ impl BallistaConfig {
self.get_usize_setting(BALLISTA_GRPC_CLIENT_HTTP2_KEEPALIVE_INTERVAL_SECONDS)
}
+ /// Returns whether the default cache node extension is disabled.
+ pub fn cache_noop(&self) -> bool {
+ self.get_bool_setting(BALLISTA_CACHE_NOOP)
+ }
+
/// Forces the shuffle reader to always read partitions via the Arrow Flight client,
/// even when partitions are local to the node.
///
diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs
index 5341020d1..7b82a1a2c 100644
--- a/ballista/core/src/extension.rs
+++ b/ballista/core/src/extension.rs
@@ -24,14 +24,17 @@ use crate::config::{
use crate::planner::BallistaQueryPlanner;
use crate::serde::protobuf::KeyValuePair;
use crate::serde::{BallistaLogicalExtensionCodec,
BallistaPhysicalExtensionCodec};
+use datafusion::common::DFSchemaRef;
use datafusion::execution::context::{QueryPlanner, SessionConfig,
SessionState};
use datafusion::execution::runtime_env::RuntimeEnvBuilder;
-use datafusion::execution::session_state::SessionStateBuilder;
+use datafusion::execution::session_state::{CacheFactory, SessionStateBuilder};
use datafusion::functions::all_default_functions;
use datafusion::functions_aggregate::all_default_aggregate_functions;
use datafusion::functions_nested::all_default_nested_functions;
use datafusion::functions_window::all_default_window_functions;
use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
+use datafusion::logical_expr::{Extension, LogicalPlan,
UserDefinedLogicalNodeCore};
+use datafusion::prelude::Expr;
use datafusion_proto::logical_plan::LogicalExtensionCodec;
use datafusion_proto::physical_plan::PhysicalExtensionCodec;
use datafusion_proto::protobuf::LogicalPlanNode;
@@ -43,6 +46,7 @@ use tonic::metadata::MetadataMap;
use tonic::service::Interceptor;
use tonic::transport::Endpoint;
use tonic::{Request, Status};
+use uuid::Uuid;
/// Type alias for the endpoint override function used in gRPC client configuration
pub type EndpointOverrideFn =
@@ -253,6 +257,7 @@ impl SessionStateExt for SessionState {
let session_state = SessionStateBuilder::new()
.with_default_features()
.with_config(session_config)
+ .with_cache_factory(Some(Arc::new(BallistaCacheFactory::new())))
.with_runtime_env(Arc::new(runtime_env))
.with_query_planner(Arc::new(planner))
.with_scalar_functions(ballista_scalar_functions())
@@ -274,8 +279,9 @@ impl SessionStateExt for SessionState {
let ballista_config = session_config.ballista_config();
- let builder =
-
SessionStateBuilder::new_from_existing(self).with_config(session_config);
+ let builder = SessionStateBuilder::new_from_existing(self)
+ .with_config(session_config)
+ .with_cache_factory(Some(Arc::new(BallistaCacheFactory::new())));
let builder = match planner_override {
Some(planner) => builder.with_query_planner(planner),
@@ -720,6 +726,104 @@ impl BallistaConfigGrpcEndpoint {
#[derive(Clone, Copy)]
pub struct BallistaUseTls(pub bool);
+#[derive(Debug)]
+struct BallistaCacheFactory;
+
+impl BallistaCacheFactory {
+ fn new() -> Self {
+ Self {}
+ }
+}
+
+impl CacheFactory for BallistaCacheFactory {
+ fn create(
+ &self,
+ plan: LogicalPlan,
+ session_state: &SessionState,
+ ) -> datafusion::error::Result<LogicalPlan> {
+ if session_state.config().ballista_config().cache_noop() {
+ Ok(plan)
+ } else {
+ Ok(LogicalPlan::Extension(Extension {
+ node: Arc::new(BallistaCacheNode::new(
+ Uuid::new_v4().to_string(),
+ session_state.session_id().to_string(),
+ plan,
+ )),
+ }))
+ }
+ }
+}
+
+/// Ballista logical Extension for caching.
+#[derive(PartialEq, Eq, PartialOrd, Hash, Debug)]
+pub struct BallistaCacheNode {
+ cache_id: String,
+ session_id: String,
+ input: LogicalPlan,
+ exprs: Vec<Expr>,
+}
+
+impl BallistaCacheNode {
+ /// Create a new cache node from provided logical input plan and cache infos.
+ pub fn new(cache_id: String, session_id: String, input: LogicalPlan) ->
Self {
+ Self {
+ cache_id,
+ session_id,
+ input,
+ exprs: vec![],
+ }
+ }
+
+ /// Returns cache id.
+ pub fn cache_id(&self) -> &str {
+ self.cache_id.as_str()
+ }
+
+ /// Returns session id.
+ pub fn session_id(&self) -> &str {
+ self.session_id.as_str()
+ }
+}
+
+impl UserDefinedLogicalNodeCore for BallistaCacheNode {
+ fn name(&self) -> &str {
+ "BallistaCacheNode"
+ }
+
+ fn inputs(&self) -> Vec<&LogicalPlan> {
+ vec![&self.input]
+ }
+
+ fn schema(&self) -> &DFSchemaRef {
+ self.input.schema()
+ }
+
+ fn expressions(&self) -> Vec<Expr> {
+ self.exprs.clone()
+ }
+
+ fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result
{
+ write!(f, "{}", self.name())
+ }
+
+ fn with_exprs_and_inputs(
+ &self,
+ exprs: Vec<datafusion::prelude::Expr>,
+ inputs: Vec<LogicalPlan>,
+ ) -> datafusion::error::Result<Self> {
+ let [input] = <[LogicalPlan; 1]>::try_from(inputs).map_err(|_| {
+ datafusion::error::DataFusionError::Plan("input size must be
one".to_string())
+ })?;
+
+ Ok(Self {
+ cache_id: self.cache_id.clone(),
+ session_id: self.session_id.clone(),
+ input,
+ exprs,
+ })
+ }
+}
#[cfg(test)]
mod test {
use datafusion::{
diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs
index d0608d024..adeeaa8a5 100644
--- a/ballista/core/src/serde/generated/ballista.rs
+++ b/ballista/core/src/serde/generated/ballista.rs
@@ -1,5 +1,30 @@
// This file is @generated by prost-build.
///
/////////////////////////////////////////////////////////////////////////////////////////////////
+/// Ballista Logical Plan
+///
/////////////////////////////////////////////////////////////////////////////////////////////////
+#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
+pub struct BallistaLogicalPlanNode {
+ #[prost(oneof = "ballista_logical_plan_node::LogicalPlanType", tags = "1")]
+ pub logical_plan_type: ::core::option::Option<
+ ballista_logical_plan_node::LogicalPlanType,
+ >,
+}
+/// Nested message and enum types in `BallistaLogicalPlanNode`.
+pub mod ballista_logical_plan_node {
+ #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)]
+ pub enum LogicalPlanType {
+ #[prost(message, tag = "1")]
+ CacheNode(super::LogicalPlanCacheNode),
+ }
+}
+#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
+pub struct LogicalPlanCacheNode {
+ #[prost(string, tag = "1")]
+ pub cache_id: ::prost::alloc::string::String,
+ #[prost(string, tag = "2")]
+ pub session_id: ::prost::alloc::string::String,
+}
+///
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Ballista Physical Plan
///
/////////////////////////////////////////////////////////////////////////////////////////////////
#[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs
index 7eb4ecefd..1a4433e9e 100644
--- a/ballista/core/src/serde/mod.rs
+++ b/ballista/core/src/serde/mod.rs
@@ -18,12 +18,14 @@
//! This crate contains code generated from the Ballista Protocol Buffer Definition as well
//! as convenience code for interacting with the generated code.
+use crate::extension::BallistaCacheNode;
use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction};
use arrow_flight::sql::ProstMessageExt;
use datafusion::arrow::datatypes::SchemaRef;
use datafusion::common::{DataFusionError, Result};
use datafusion::execution::TaskContext;
+use datafusion::logical_expr::Extension;
use datafusion::physical_plan::{ExecutionPlan, Partitioning};
use datafusion_proto::logical_plan::file_formats::{
ArrowLogicalExtensionCodec, AvroLogicalExtensionCodec,
CsvLogicalExtensionCodec,
@@ -51,7 +53,10 @@ use crate::execution_plans::sort_shuffle::SortShuffleConfig;
use crate::execution_plans::{
ShuffleReaderExec, ShuffleWriterExec, SortShuffleWriterExec,
UnresolvedShuffleExec,
};
-use crate::serde::protobuf::ballista_physical_plan_node::PhysicalPlanType;
+use crate::serde::protobuf::{
+ ballista_logical_plan_node::LogicalPlanType,
+ ballista_physical_plan_node::PhysicalPlanType,
+};
use crate::serde::scheduler::PartitionLocation;
pub use generated::ballista as protobuf;
@@ -188,7 +193,28 @@ impl LogicalExtensionCodec for
BallistaLogicalExtensionCodec {
inputs: &[datafusion::logical_expr::LogicalPlan],
ctx: &TaskContext,
) -> Result<datafusion::logical_expr::Extension> {
- self.default_codec.try_decode(buf, inputs, ctx)
+ let plan = protobuf::BallistaLogicalPlanNode::decode(buf)
+ .ok()
+ .and_then(|node| node.logical_plan_type);
+
+ let Some(plan) = plan else {
+ return self.default_codec.try_decode(buf, inputs, ctx);
+ };
+
+ match plan {
+ LogicalPlanType::CacheNode(plan_cache) => Ok(Extension {
+ node: Arc::new(BallistaCacheNode::new(
+ plan_cache.cache_id,
+ plan_cache.session_id,
+ inputs
+ .first()
+ .ok_or(DataFusionError::Plan(
+ "expected input size of 1".to_string(),
+ ))?
+ .clone(),
+ )),
+ }),
+ }
}
fn try_encode(
@@ -196,7 +222,26 @@ impl LogicalExtensionCodec for
BallistaLogicalExtensionCodec {
node: &datafusion::logical_expr::Extension,
buf: &mut Vec<u8>,
) -> Result<()> {
- self.default_codec.try_encode(node, buf)
+ if let Some(node) =
node.node.as_any().downcast_ref::<BallistaCacheNode>() {
+ let proto = protobuf::BallistaLogicalPlanNode {
+ logical_plan_type: Some(LogicalPlanType::CacheNode(
+ protobuf::LogicalPlanCacheNode {
+ cache_id: node.cache_id().to_owned(),
+ session_id: node.session_id().to_owned(),
+ },
+ )),
+ };
+
+ proto.encode(buf).map_err(|e| {
+ DataFusionError::Internal(format!(
+ "failed to encode cache node logical plan: {e:?}"
+ ))
+ })?;
+
+ Ok(())
+ } else {
+ self.default_codec.try_encode(node, buf)
+ }
}
fn try_decode_table_provider(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]