This is an automated email from the ASF dual-hosted git repository.

milenkovicm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new 771e28091 feat: add `Dataframe.cache()` factory (no planner handling) (#1420)
771e28091 is described below

commit 771e280916aa519d3a89f0ba1a8ec71b3df86ea1
Author: jgrim <[email protected]>
AuthorDate: Wed Feb 18 14:30:36 2026 +0100

    feat: add `Dataframe.cache()` factory (no planner handling) (#1420)
    
    * feat: disable DataFrame.cache() for ballista
    
    * add failing test for cache collect
    
    * feat: Add logical plan cache node extension
    
    * fix: update after review
    
    * fix: update after review (cleanup tests, update ballista cache noop)
    
    * fix: add missing default_codec usage
---
 ballista/client/tests/context_unsupported.rs  |  36 +++++++++
 ballista/core/proto/ballista.proto            |  14 ++++
 ballista/core/src/config.rs                   |  27 +++++--
 ballista/core/src/extension.rs                | 110 +++++++++++++++++++++++++-
 ballista/core/src/serde/generated/ballista.rs |  25 ++++++
 ballista/core/src/serde/mod.rs                |  51 +++++++++++-
 6 files changed, 250 insertions(+), 13 deletions(-)

diff --git a/ballista/client/tests/context_unsupported.rs b/ballista/client/tests/context_unsupported.rs
index 8cda528ee..3c6b25a98 100644
--- a/ballista/client/tests/context_unsupported.rs
+++ b/ballista/client/tests/context_unsupported.rs
@@ -59,6 +59,7 @@ mod unsupported {
             .await
             .unwrap();
     }
+
     #[rstest]
     #[case::standalone(standalone_context())]
     #[case::remote(remote_context())]
@@ -70,6 +71,12 @@ mod unsupported {
         ctx: SessionContext,
         test_data: String,
     ) {
+        ctx.sql("SET ballista.cache.noop = false")
+            .await
+            .unwrap()
+            .show()
+            .await
+            .unwrap();
         let df = ctx
             .read_parquet(
                 &format!("{test_data}/alltypes_plain.parquet"),
@@ -97,6 +104,35 @@ mod unsupported {
         assert_batches_eq!(expected, &result);
     }
 
+    #[rstest]
+    #[case::standalone(standalone_context())]
+    #[case::remote(remote_context())]
+    #[tokio::test]
+    async fn should_support_on_cache_collect(
+        #[future(awt)]
+        #[case]
+        ctx: SessionContext,
+    ) -> datafusion::error::Result<()> {
+        // opt out case, should fail
+        ctx.sql("SET ballista.cache.noop = false")
+            .await?
+            .show()
+            .await?;
+        let cached_df = ctx.sql("SELECT 1").await?.cache().await?;
+
+        // Collect fails because extension node is not handled for now by default query planner
+        let result = cached_df.collect().await;
+
+        assert!(result.is_err());
+        let err_msg = result.unwrap_err().to_string();
+        assert!(
+            err_msg.contains("No installed planner was able to convert the 
custom node to an execution plan: BallistaCacheNode"),
+            "Expected planner error, got: {err_msg}"
+        );
+
+        Ok(())
+    }
+
     #[rstest]
     #[case::standalone(standalone_context())]
     #[case::remote(remote_context())]
diff --git a/ballista/core/proto/ballista.proto b/ballista/core/proto/ballista.proto
index 8660b6c72..97812bb0d 100644
--- a/ballista/core/proto/ballista.proto
+++ b/ballista/core/proto/ballista.proto
@@ -27,6 +27,20 @@ option java_outer_classname = "BallistaProto";
 import "datafusion.proto";
 import "datafusion_common.proto";
 
+///////////////////////////////////////////////////////////////////////////////////////////////////
+// Ballista Logical Plan
+///////////////////////////////////////////////////////////////////////////////////////////////////
+message BallistaLogicalPlanNode {
+  oneof LogicalPlanType {
+    LogicalPlanCacheNode cache_node = 1;
+  }
+}
+
+message LogicalPlanCacheNode {
+  string cache_id = 1;
+  string session_id = 2;
+}
+
 ///////////////////////////////////////////////////////////////////////////////////////////////////
 // Ballista Physical Plan
 ///////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs
index b9a78647f..688f613c3 100644
--- a/ballista/core/src/config.rs
+++ b/ballista/core/src/config.rs
@@ -31,9 +31,10 @@ use datafusion::{
 pub const BALLISTA_JOB_NAME: &str = "ballista.job.name";
 /// Configuration key for standalone processing parallelism.
 pub const BALLISTA_STANDALONE_PARALLELISM: &str = 
"ballista.standalone.parallelism";
-/// max message size for gRPC clients
-pub const BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE: &str =
-    "ballista.grpc_client_max_message_size";
+
+/// Configuration key for disabling default cache extension node.
+pub const BALLISTA_CACHE_NOOP: &str = "ballista.cache.noop";
+
 /// Configuration key for maximum concurrent shuffle read requests.
 pub const BALLISTA_SHUFFLE_READER_MAX_REQUESTS: &str =
     "ballista.shuffle.max_concurrent_read_requests";
@@ -44,6 +45,9 @@ pub const BALLISTA_SHUFFLE_READER_FORCE_REMOTE_READ: &str =
 pub const BALLISTA_SHUFFLE_READER_REMOTE_PREFER_FLIGHT: &str =
     "ballista.shuffle.remote_read_prefer_flight";
 
+/// max message size for gRPC clients
+pub const BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE: &str =
+    "ballista.grpc_client_max_message_size";
 /// Configuration key for gRPC client connection timeout in seconds.
 pub const BALLISTA_GRPC_CLIENT_CONNECT_TIMEOUT_SECONDS: &str =
     "ballista.grpc.client.connect_timeout_seconds";
@@ -85,10 +89,10 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String, ConfigEntry>> = LazyLock::new(||
         ConfigEntry::new(BALLISTA_STANDALONE_PARALLELISM.to_string(),
                          "Standalone processing parallelism ".to_string(),
                          DataType::UInt16, 
Some(std::thread::available_parallelism().map(|v| 
v.get()).unwrap_or(1).to_string())),
-        ConfigEntry::new(BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE.to_string(),
-                         "Configuration for max message size in gRPC 
clients".to_string(),
-                         DataType::UInt64,
-                         Some((16 * 1024 * 1024).to_string())),
+        ConfigEntry::new(BALLISTA_CACHE_NOOP.to_string(),
+                         "Disable default cache node extension".to_string(),
+                         DataType::Boolean,
+                         Some((true).to_string())),
         ConfigEntry::new(BALLISTA_SHUFFLE_READER_MAX_REQUESTS.to_string(),
                          "Maximum concurrent requests shuffle reader can 
process".to_string(),
                          DataType::UInt64,
@@ -101,6 +105,10 @@ static CONFIG_ENTRIES: LazyLock<HashMap<String, ConfigEntry>> = LazyLock::new(||
                          "Forces the shuffle reader to use flight reader 
instead of block reader for remote read. Block reader usually has better 
performance and resource utilization".to_string(),
                          DataType::Boolean,
                          Some((false).to_string())),
+        ConfigEntry::new(BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE.to_string(),
+                         "Configuration for max message size in gRPC 
clients".to_string(),
+                         DataType::UInt64,
+                         Some((16 * 1024 * 1024).to_string())),
         
ConfigEntry::new(BALLISTA_GRPC_CLIENT_CONNECT_TIMEOUT_SECONDS.to_string(),
                          "Connection timeout for gRPC client in 
seconds".to_string(),
                          DataType::UInt64,
@@ -283,6 +291,11 @@ impl BallistaConfig {
         
self.get_usize_setting(BALLISTA_GRPC_CLIENT_HTTP2_KEEPALIVE_INTERVAL_SECONDS)
     }
 
+    /// Returns whether the default cache node extension is disabled.
+    pub fn cache_noop(&self) -> bool {
+        self.get_bool_setting(BALLISTA_CACHE_NOOP)
+    }
+
     /// Forces the shuffle reader to always read partitions via the Arrow Flight client,
     /// even when partitions are local to the node.
     ///
diff --git a/ballista/core/src/extension.rs b/ballista/core/src/extension.rs
index 5341020d1..7b82a1a2c 100644
--- a/ballista/core/src/extension.rs
+++ b/ballista/core/src/extension.rs
@@ -24,14 +24,17 @@ use crate::config::{
 use crate::planner::BallistaQueryPlanner;
 use crate::serde::protobuf::KeyValuePair;
 use crate::serde::{BallistaLogicalExtensionCodec, 
BallistaPhysicalExtensionCodec};
+use datafusion::common::DFSchemaRef;
 use datafusion::execution::context::{QueryPlanner, SessionConfig, 
SessionState};
 use datafusion::execution::runtime_env::RuntimeEnvBuilder;
-use datafusion::execution::session_state::SessionStateBuilder;
+use datafusion::execution::session_state::{CacheFactory, SessionStateBuilder};
 use datafusion::functions::all_default_functions;
 use datafusion::functions_aggregate::all_default_aggregate_functions;
 use datafusion::functions_nested::all_default_nested_functions;
 use datafusion::functions_window::all_default_window_functions;
 use datafusion::logical_expr::{AggregateUDF, ScalarUDF, WindowUDF};
+use datafusion::logical_expr::{Extension, LogicalPlan, 
UserDefinedLogicalNodeCore};
+use datafusion::prelude::Expr;
 use datafusion_proto::logical_plan::LogicalExtensionCodec;
 use datafusion_proto::physical_plan::PhysicalExtensionCodec;
 use datafusion_proto::protobuf::LogicalPlanNode;
@@ -43,6 +46,7 @@ use tonic::metadata::MetadataMap;
 use tonic::service::Interceptor;
 use tonic::transport::Endpoint;
 use tonic::{Request, Status};
+use uuid::Uuid;
 
 /// Type alias for the endpoint override function used in gRPC client configuration
 pub type EndpointOverrideFn =
@@ -253,6 +257,7 @@ impl SessionStateExt for SessionState {
         let session_state = SessionStateBuilder::new()
             .with_default_features()
             .with_config(session_config)
+            .with_cache_factory(Some(Arc::new(BallistaCacheFactory::new())))
             .with_runtime_env(Arc::new(runtime_env))
             .with_query_planner(Arc::new(planner))
             .with_scalar_functions(ballista_scalar_functions())
@@ -274,8 +279,9 @@ impl SessionStateExt for SessionState {
 
         let ballista_config = session_config.ballista_config();
 
-        let builder =
-            
SessionStateBuilder::new_from_existing(self).with_config(session_config);
+        let builder = SessionStateBuilder::new_from_existing(self)
+            .with_config(session_config)
+            .with_cache_factory(Some(Arc::new(BallistaCacheFactory::new())));
 
         let builder = match planner_override {
             Some(planner) => builder.with_query_planner(planner),
@@ -720,6 +726,104 @@ impl BallistaConfigGrpcEndpoint {
 #[derive(Clone, Copy)]
 pub struct BallistaUseTls(pub bool);
 
+#[derive(Debug)]
+struct BallistaCacheFactory;
+
+impl BallistaCacheFactory {
+    fn new() -> Self {
+        Self {}
+    }
+}
+
+impl CacheFactory for BallistaCacheFactory {
+    fn create(
+        &self,
+        plan: LogicalPlan,
+        session_state: &SessionState,
+    ) -> datafusion::error::Result<LogicalPlan> {
+        if session_state.config().ballista_config().cache_noop() {
+            Ok(plan)
+        } else {
+            Ok(LogicalPlan::Extension(Extension {
+                node: Arc::new(BallistaCacheNode::new(
+                    Uuid::new_v4().to_string(),
+                    session_state.session_id().to_string(),
+                    plan,
+                )),
+            }))
+        }
+    }
+}
+
+/// Ballista logical Extension for caching.
+#[derive(PartialEq, Eq, PartialOrd, Hash, Debug)]
+pub struct BallistaCacheNode {
+    cache_id: String,
+    session_id: String,
+    input: LogicalPlan,
+    exprs: Vec<Expr>,
+}
+
+impl BallistaCacheNode {
+    /// Create a new cache node from provided logical input plan and cache infos.
+    pub fn new(cache_id: String, session_id: String, input: LogicalPlan) -> 
Self {
+        Self {
+            cache_id,
+            session_id,
+            input,
+            exprs: vec![],
+        }
+    }
+
+    /// Returns cache id.
+    pub fn cache_id(&self) -> &str {
+        self.cache_id.as_str()
+    }
+
+    /// Returns session id.
+    pub fn session_id(&self) -> &str {
+        self.session_id.as_str()
+    }
+}
+
+impl UserDefinedLogicalNodeCore for BallistaCacheNode {
+    fn name(&self) -> &str {
+        "BallistaCacheNode"
+    }
+
+    fn inputs(&self) -> Vec<&LogicalPlan> {
+        vec![&self.input]
+    }
+
+    fn schema(&self) -> &DFSchemaRef {
+        self.input.schema()
+    }
+
+    fn expressions(&self) -> Vec<Expr> {
+        self.exprs.clone()
+    }
+
+    fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result 
{
+        write!(f, "{}", self.name())
+    }
+
+    fn with_exprs_and_inputs(
+        &self,
+        exprs: Vec<datafusion::prelude::Expr>,
+        inputs: Vec<LogicalPlan>,
+    ) -> datafusion::error::Result<Self> {
+        let [input] = <[LogicalPlan; 1]>::try_from(inputs).map_err(|_| {
+            datafusion::error::DataFusionError::Plan("input size must be 
one".to_string())
+        })?;
+
+        Ok(Self {
+            cache_id: self.cache_id.clone(),
+            session_id: self.session_id.clone(),
+            input,
+            exprs,
+        })
+    }
+}
 #[cfg(test)]
 mod test {
     use datafusion::{
diff --git a/ballista/core/src/serde/generated/ballista.rs b/ballista/core/src/serde/generated/ballista.rs
index d0608d024..adeeaa8a5 100644
--- a/ballista/core/src/serde/generated/ballista.rs
+++ b/ballista/core/src/serde/generated/ballista.rs
@@ -1,5 +1,30 @@
 // This file is @generated by prost-build.
 /// /////////////////////////////////////////////////////////////////////////////////////////////////
+/// Ballista Logical Plan
+/// /////////////////////////////////////////////////////////////////////////////////////////////////
+#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
+pub struct BallistaLogicalPlanNode {
+    #[prost(oneof = "ballista_logical_plan_node::LogicalPlanType", tags = "1")]
+    pub logical_plan_type: ::core::option::Option<
+        ballista_logical_plan_node::LogicalPlanType,
+    >,
+}
+/// Nested message and enum types in `BallistaLogicalPlanNode`.
+pub mod ballista_logical_plan_node {
+    #[derive(Clone, PartialEq, Eq, Hash, ::prost::Oneof)]
+    pub enum LogicalPlanType {
+        #[prost(message, tag = "1")]
+        CacheNode(super::LogicalPlanCacheNode),
+    }
+}
+#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)]
+pub struct LogicalPlanCacheNode {
+    #[prost(string, tag = "1")]
+    pub cache_id: ::prost::alloc::string::String,
+    #[prost(string, tag = "2")]
+    pub session_id: ::prost::alloc::string::String,
+}
+/// /////////////////////////////////////////////////////////////////////////////////////////////////
 /// Ballista Physical Plan
 /// /////////////////////////////////////////////////////////////////////////////////////////////////
 #[derive(Clone, PartialEq, ::prost::Message)]
diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs
index 7eb4ecefd..1a4433e9e 100644
--- a/ballista/core/src/serde/mod.rs
+++ b/ballista/core/src/serde/mod.rs
@@ -18,12 +18,14 @@
 //! This crate contains code generated from the Ballista Protocol Buffer Definition as well
 //! as convenience code for interacting with the generated code.
 
+use crate::extension::BallistaCacheNode;
 use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction};
 
 use arrow_flight::sql::ProstMessageExt;
 use datafusion::arrow::datatypes::SchemaRef;
 use datafusion::common::{DataFusionError, Result};
 use datafusion::execution::TaskContext;
+use datafusion::logical_expr::Extension;
 use datafusion::physical_plan::{ExecutionPlan, Partitioning};
 use datafusion_proto::logical_plan::file_formats::{
     ArrowLogicalExtensionCodec, AvroLogicalExtensionCodec, 
CsvLogicalExtensionCodec,
@@ -51,7 +53,10 @@ use crate::execution_plans::sort_shuffle::SortShuffleConfig;
 use crate::execution_plans::{
     ShuffleReaderExec, ShuffleWriterExec, SortShuffleWriterExec, 
UnresolvedShuffleExec,
 };
-use crate::serde::protobuf::ballista_physical_plan_node::PhysicalPlanType;
+use crate::serde::protobuf::{
+    ballista_logical_plan_node::LogicalPlanType,
+    ballista_physical_plan_node::PhysicalPlanType,
+};
 use crate::serde::scheduler::PartitionLocation;
 pub use generated::ballista as protobuf;
 
@@ -188,7 +193,28 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec {
         inputs: &[datafusion::logical_expr::LogicalPlan],
         ctx: &TaskContext,
     ) -> Result<datafusion::logical_expr::Extension> {
-        self.default_codec.try_decode(buf, inputs, ctx)
+        let plan = protobuf::BallistaLogicalPlanNode::decode(buf)
+            .ok()
+            .and_then(|node| node.logical_plan_type);
+
+        let Some(plan) = plan else {
+            return self.default_codec.try_decode(buf, inputs, ctx);
+        };
+
+        match plan {
+            LogicalPlanType::CacheNode(plan_cache) => Ok(Extension {
+                node: Arc::new(BallistaCacheNode::new(
+                    plan_cache.cache_id,
+                    plan_cache.session_id,
+                    inputs
+                        .first()
+                        .ok_or(DataFusionError::Plan(
+                            "expected input size of 1".to_string(),
+                        ))?
+                        .clone(),
+                )),
+            }),
+        }
     }
 
     fn try_encode(
@@ -196,7 +222,26 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec {
         node: &datafusion::logical_expr::Extension,
         buf: &mut Vec<u8>,
     ) -> Result<()> {
-        self.default_codec.try_encode(node, buf)
+        if let Some(node) = 
node.node.as_any().downcast_ref::<BallistaCacheNode>() {
+            let proto = protobuf::BallistaLogicalPlanNode {
+                logical_plan_type: Some(LogicalPlanType::CacheNode(
+                    protobuf::LogicalPlanCacheNode {
+                        cache_id: node.cache_id().to_owned(),
+                        session_id: node.session_id().to_owned(),
+                    },
+                )),
+            };
+
+            proto.encode(buf).map_err(|e| {
+                DataFusionError::Internal(format!(
+                    "failed to encode cache node logical plan: {e:?}"
+                ))
+            })?;
+
+            Ok(())
+        } else {
+            self.default_codec.try_encode(node, buf)
+        }
     }
 
     fn try_decode_table_provider(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to