This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 6ea305ec4c feat: allow custom caching via logical node (#18688)
6ea305ec4c is described below

commit 6ea305ec4c15c3ad7516c123f4d567cf946177fc
Author: yqrz <[email protected]>
AuthorDate: Tue Nov 25 01:16:53 2025 -0800

    feat: allow custom caching via logical node (#18688)
    
    ## Which issue does this PR close?
    
    <!--
    We generally require a GitHub issue to be filed for all bug fixes and
    enhancements and this helps us generate change logs for our releases.
    You can link an issue to this PR using the GitHub syntax. For example
    `Closes #123` indicates that this PR will close issue #123.
    -->
    
    - Closes https://github.com/apache/datafusion/issues/17297.
    
    ## Rationale for this change
    
    <!--
    Why are you proposing this change? If this is already explained clearly
    in the issue then this section is not needed.
    Explaining clearly why changes are proposed helps reviewers understand
    your changes and offer better suggestions for fixes.
    -->
    
    See https://github.com/apache/datafusion/issues/17297.
    
    ## What changes are included in this PR?
    
    <!--
    There is no need to duplicate the description in the issue here but it
    is sometimes worth providing a summary of the individual changes in this
    PR.
    -->
    
    - Added methods in `SessionState` to enable registering a
    `CacheFactory` that creates a logical node for caching.
    - Added branching in `DataFrame::cache()` to apply a logical node for
    caching on top of the original dataframe plan if a cache factory is
    supplied; otherwise the current implementation is used as is.
    
    ## Are these changes tested?
    
    <!--
    We typically require tests for all PRs in order to:
    1. Prevent the code from being accidentally broken by subsequent changes
    2. Serve as another way to document the expected behavior of the code
    
    If tests are not included in your PR, please explain why (for example,
    are they covered by existing tests)?
    -->
    
    Yes
    
    ## Are there any user-facing changes?
    
    <!--
    If there are user-facing changes then we may require documentation to be
    updated before approving the PR.
    -->
    
    <!--
    If there are any breaking changes to public APIs, please add the `api
    change` label.
    -->
    
    Change in `SessionState` is user-facing.
    
    ---------
    
    Signed-off-by: dependabot[bot] <[email protected]>
    Co-authored-by: Yongting You <[email protected]>
    Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
    Co-authored-by: Daniël Heres <[email protected]>
    Co-authored-by: Kumar Ujjawal <[email protected]>
    Co-authored-by: Aryan Bagade <[email protected]>
    Co-authored-by: Oleks V <[email protected]>
    Co-authored-by: Andrew Lamb <[email protected]>
    Co-authored-by: Jeffrey Vo <[email protected]>
    Co-authored-by: Vrishabh <[email protected]>
    Co-authored-by: Dhanush <[email protected]>
---
 datafusion/core/src/dataframe/mod.rs           | 26 ++++++---
 datafusion/core/src/execution/session_state.rs | 46 +++++++++++++++-
 datafusion/core/src/test_util/mod.rs           | 75 ++++++++++++++++++++++++--
 datafusion/core/tests/dataframe/mod.rs         | 25 ++++++++-
 4 files changed, 159 insertions(+), 13 deletions(-)

diff --git a/datafusion/core/src/dataframe/mod.rs 
b/datafusion/core/src/dataframe/mod.rs
index 08ca1a176e..8e52ca1bf1 100644
--- a/datafusion/core/src/dataframe/mod.rs
+++ b/datafusion/core/src/dataframe/mod.rs
@@ -2330,6 +2330,10 @@ impl DataFrame {
 
     /// Cache DataFrame as a memory table.
     ///
+    /// Default behavior could be changed using
+    /// a [`crate::execution::session_state::CacheFactory`]
+    /// configured via [`SessionState`].
+    ///
     /// ```
     /// # use datafusion::prelude::*;
     /// # use datafusion::error::Result;
@@ -2344,14 +2348,20 @@ impl DataFrame {
     /// # }
     /// ```
     pub async fn cache(self) -> Result<DataFrame> {
-        let context = 
SessionContext::new_with_state((*self.session_state).clone());
-        // The schema is consistent with the output
-        let plan = self.clone().create_physical_plan().await?;
-        let schema = plan.schema();
-        let task_ctx = Arc::new(self.task_ctx());
-        let partitions = collect_partitioned(plan, task_ctx).await?;
-        let mem_table = MemTable::try_new(schema, partitions)?;
-        context.read_table(Arc::new(mem_table))
+        if let Some(cache_factory) = self.session_state.cache_factory() {
+            let new_plan =
+                cache_factory.create(self.plan, self.session_state.as_ref())?;
+            Ok(Self::new(*self.session_state, new_plan))
+        } else {
+            let context = 
SessionContext::new_with_state((*self.session_state).clone());
+            // The schema is consistent with the output
+            let plan = self.clone().create_physical_plan().await?;
+            let schema = plan.schema();
+            let task_ctx = Arc::new(self.task_ctx());
+            let partitions = collect_partitioned(plan, task_ctx).await?;
+            let mem_table = MemTable::try_new(schema, partitions)?;
+            context.read_table(Arc::new(mem_table))
+        }
     }
 
     /// Apply an alias to the DataFrame.
diff --git a/datafusion/core/src/execution/session_state.rs 
b/datafusion/core/src/execution/session_state.rs
index d7a66db28a..8bc526fbfb 100644
--- a/datafusion/core/src/execution/session_state.rs
+++ b/datafusion/core/src/execution/session_state.rs
@@ -185,6 +185,7 @@ pub struct SessionState {
     /// It will be invoked on `CREATE FUNCTION` statements.
     /// thus, changing dialect o PostgreSql is required
     function_factory: Option<Arc<dyn FunctionFactory>>,
+    cache_factory: Option<Arc<dyn CacheFactory>>,
     /// Cache logical plans of prepared statements for later execution.
     /// Key is the prepared statement name.
     prepared_plans: HashMap<String, Arc<PreparedPlan>>,
@@ -206,6 +207,7 @@ impl Debug for SessionState {
             .field("table_options", &self.table_options)
             .field("table_factories", &self.table_factories)
             .field("function_factory", &self.function_factory)
+            .field("cache_factory", &self.cache_factory)
             .field("expr_planners", &self.expr_planners);
 
         #[cfg(feature = "sql")]
@@ -355,6 +357,16 @@ impl SessionState {
         self.function_factory.as_ref()
     }
 
+    /// Register a [`CacheFactory`] for custom caching strategy
+    pub fn set_cache_factory(&mut self, cache_factory: Arc<dyn CacheFactory>) {
+        self.cache_factory = Some(cache_factory);
+    }
+
+    /// Get the cache factory
+    pub fn cache_factory(&self) -> Option<&Arc<dyn CacheFactory>> {
+        self.cache_factory.as_ref()
+    }
+
     /// Get the table factories
     pub fn table_factories(&self) -> &HashMap<String, Arc<dyn 
TableProviderFactory>> {
         &self.table_factories
@@ -941,6 +953,7 @@ pub struct SessionStateBuilder {
     table_factories: Option<HashMap<String, Arc<dyn TableProviderFactory>>>,
     runtime_env: Option<Arc<RuntimeEnv>>,
     function_factory: Option<Arc<dyn FunctionFactory>>,
+    cache_factory: Option<Arc<dyn CacheFactory>>,
     // fields to support convenience functions
     analyzer_rules: Option<Vec<Arc<dyn AnalyzerRule + Send + Sync>>>,
     optimizer_rules: Option<Vec<Arc<dyn OptimizerRule + Send + Sync>>>,
@@ -978,6 +991,7 @@ impl SessionStateBuilder {
             table_factories: None,
             runtime_env: None,
             function_factory: None,
+            cache_factory: None,
             // fields to support convenience functions
             analyzer_rules: None,
             optimizer_rules: None,
@@ -1030,7 +1044,7 @@ impl SessionStateBuilder {
             table_factories: Some(existing.table_factories),
             runtime_env: Some(existing.runtime_env),
             function_factory: existing.function_factory,
-
+            cache_factory: existing.cache_factory,
             // fields to support convenience functions
             analyzer_rules: None,
             optimizer_rules: None,
@@ -1319,6 +1333,15 @@ impl SessionStateBuilder {
         self
     }
 
+    /// Set a [`CacheFactory`] for custom caching strategy
+    pub fn with_cache_factory(
+        mut self,
+        cache_factory: Option<Arc<dyn CacheFactory>>,
+    ) -> Self {
+        self.cache_factory = cache_factory;
+        self
+    }
+
     /// Register an `ObjectStore` to the [`RuntimeEnv`]. See 
[`RuntimeEnv::register_object_store`]
     /// for more details.
     ///
@@ -1382,6 +1405,7 @@ impl SessionStateBuilder {
             table_factories,
             runtime_env,
             function_factory,
+            cache_factory,
             analyzer_rules,
             optimizer_rules,
             physical_optimizer_rules,
@@ -1418,6 +1442,7 @@ impl SessionStateBuilder {
             table_factories: table_factories.unwrap_or_default(),
             runtime_env,
             function_factory,
+            cache_factory,
             prepared_plans: HashMap::new(),
         };
 
@@ -1621,6 +1646,11 @@ impl SessionStateBuilder {
         &mut self.function_factory
     }
 
+    /// Returns the cache factory
+    pub fn cache_factory(&mut self) -> &mut Option<Arc<dyn CacheFactory>> {
+        &mut self.cache_factory
+    }
+
     /// Returns the current analyzer_rules value
     pub fn analyzer_rules(
         &mut self,
@@ -1659,6 +1689,7 @@ impl Debug for SessionStateBuilder {
             .field("table_options", &self.table_options)
             .field("table_factories", &self.table_factories)
             .field("function_factory", &self.function_factory)
+            .field("cache_factory", &self.cache_factory)
             .field("expr_planners", &self.expr_planners);
         #[cfg(feature = "sql")]
         let ret = ret.field("type_planner", &self.type_planner);
@@ -2047,6 +2078,19 @@ pub(crate) struct PreparedPlan {
     pub(crate) plan: Arc<LogicalPlan>,
 }
 
+/// A [`CacheFactory`] can be registered via [`SessionState`]
+/// to create a custom logical plan for [`crate::dataframe::DataFrame::cache`].
+/// Additionally, a custom 
[`crate::physical_planner::ExtensionPlanner`]/[`QueryPlanner`]
+/// may need to be implemented to handle such plans.
+pub trait CacheFactory: Debug + Send + Sync {
+    /// Create a logical plan for caching
+    fn create(
+        &self,
+        plan: LogicalPlan,
+        session_state: &SessionState,
+    ) -> datafusion_common::Result<LogicalPlan>;
+}
+
 #[cfg(test)]
 mod tests {
     use super::{SessionContextProvider, SessionStateBuilder};
diff --git a/datafusion/core/src/test_util/mod.rs 
b/datafusion/core/src/test_util/mod.rs
index 7149c5b0bd..466ee38a42 100644
--- a/datafusion/core/src/test_util/mod.rs
+++ b/datafusion/core/src/test_util/mod.rs
@@ -25,6 +25,7 @@ pub mod csv;
 use futures::Stream;
 use std::any::Any;
 use std::collections::HashMap;
+use std::fmt::Formatter;
 use std::fs::File;
 use std::io::Write;
 use std::path::Path;
@@ -36,16 +37,20 @@ use crate::dataframe::DataFrame;
 use crate::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable};
 use crate::datasource::{empty::EmptyTable, provider_as_source};
 use crate::error::Result;
+use crate::execution::session_state::CacheFactory;
 use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
 use crate::physical_plan::ExecutionPlan;
 use crate::prelude::{CsvReadOptions, SessionContext};
 
-use crate::execution::SendableRecordBatchStream;
+use crate::execution::{SendableRecordBatchStream, SessionState, 
SessionStateBuilder};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
 use datafusion_catalog::Session;
-use datafusion_common::TableReference;
-use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType};
+use datafusion_common::{DFSchemaRef, TableReference};
+use datafusion_expr::{
+    CreateExternalTable, Expr, LogicalPlan, SortExpr, TableType,
+    UserDefinedLogicalNodeCore,
+};
 use std::pin::Pin;
 
 use async_trait::async_trait;
@@ -282,3 +287,67 @@ impl RecordBatchStream for BoundedStream {
         self.record_batch.schema()
     }
 }
+
+#[derive(Hash, Eq, PartialEq, PartialOrd, Debug)]
+struct CacheNode {
+    input: LogicalPlan,
+}
+
+impl UserDefinedLogicalNodeCore for CacheNode {
+    fn name(&self) -> &str {
+        "CacheNode"
+    }
+
+    fn inputs(&self) -> Vec<&LogicalPlan> {
+        vec![&self.input]
+    }
+
+    fn schema(&self) -> &DFSchemaRef {
+        self.input.schema()
+    }
+
+    fn expressions(&self) -> Vec<Expr> {
+        vec![]
+    }
+
+    fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
+        write!(f, "CacheNode")
+    }
+
+    fn with_exprs_and_inputs(
+        &self,
+        _exprs: Vec<Expr>,
+        inputs: Vec<LogicalPlan>,
+    ) -> Result<Self> {
+        assert_eq!(inputs.len(), 1, "input size inconsistent");
+        Ok(Self {
+            input: inputs[0].clone(),
+        })
+    }
+}
+
+#[derive(Debug)]
+struct TestCacheFactory {}
+
+impl CacheFactory for TestCacheFactory {
+    fn create(
+        &self,
+        plan: LogicalPlan,
+        _session_state: &SessionState,
+    ) -> Result<LogicalPlan> {
+        Ok(LogicalPlan::Extension(datafusion_expr::Extension {
+            node: Arc::new(CacheNode { input: plan }),
+        }))
+    }
+}
+
+/// Create a test table registered to a session context with an associated 
cache factory
+pub async fn test_table_with_cache_factory() -> Result<DataFrame> {
+    let session_state = SessionStateBuilder::new()
+        .with_cache_factory(Some(Arc::new(TestCacheFactory {})))
+        .build();
+    let ctx = SessionContext::new_with_state(session_state);
+    let name = "aggregate_test_100";
+    register_aggregate_csv(&ctx, name).await?;
+    ctx.table(name).await
+}
diff --git a/datafusion/core/tests/dataframe/mod.rs 
b/datafusion/core/tests/dataframe/mod.rs
index b856a776c8..fb6dc3bcba 100644
--- a/datafusion/core/tests/dataframe/mod.rs
+++ b/datafusion/core/tests/dataframe/mod.rs
@@ -61,7 +61,7 @@ use datafusion::prelude::{
 };
 use datafusion::test_util::{
     parquet_test_data, populate_csv_partitions, register_aggregate_csv, 
test_table,
-    test_table_with_name,
+    test_table_with_cache_factory, test_table_with_name,
 };
 use datafusion_catalog::TableProvider;
 use datafusion_common::test_util::{batches_to_sort_string, batches_to_string};
@@ -2335,6 +2335,29 @@ async fn cache_test() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn cache_producer_test() -> Result<()> {
+    let df = test_table_with_cache_factory()
+        .await?
+        .select_columns(&["c2", "c3"])?
+        .limit(0, Some(1))?
+        .with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;
+
+    let cached_df = df.clone().cache().await?;
+
+    assert_snapshot!(
+        cached_df.clone().into_optimized_plan().unwrap(),
+        @r###"
+    CacheNode
+      Projection: aggregate_test_100.c2, aggregate_test_100.c3, 
CAST(CAST(aggregate_test_100.c2 AS Int64) + CAST(aggregate_test_100.c3 AS 
Int64) AS Int64) AS sum
+        Projection: aggregate_test_100.c2, aggregate_test_100.c3
+          Limit: skip=0, fetch=1
+            TableScan: aggregate_test_100, fetch=1
+    "###
+    );
+    Ok(())
+}
+
 #[tokio::test]
 async fn partition_aware_union() -> Result<()> {
     let left = test_table().await?.select_columns(&["c1", "c2"])?;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to