This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 108e9db859 Proper resolution for old name in with_column_renamed (#5992)
108e9db859 is described below

commit 108e9db859127cad3a01819a0ab35702e823c07b
Author: Jeffrey <[email protected]>
AuthorDate: Sat Apr 15 02:19:10 2023 +1000

    Proper resolution for old name in with_column_renamed (#5992)
    
    * Proper resolution for old name in with_column_renamed
    
    * clippy
---
 datafusion/core/src/dataframe.rs | 97 +++++++++++++++++++++++++++++-----------
 1 file changed, 71 insertions(+), 26 deletions(-)

diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index d7ce056a33..380f722a42 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -24,7 +24,7 @@ use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
 use arrow::compute::{cast, concat};
 use arrow::datatypes::{DataType, Field};
 use async_trait::async_trait;
-use datafusion_common::DataFusionError;
+use datafusion_common::{DataFusionError, SchemaError};
 use parquet::file::properties::WriterProperties;
 
 use datafusion_common::from_slice::FromSlice;
@@ -1007,28 +1007,36 @@ impl DataFrame {
     /// ```
     pub fn with_column_renamed(
         self,
-        old_name: &str,
+        old_name: impl Into<Column>,
         new_name: &str,
     ) -> Result<DataFrame> {
-        let mut projection = vec![];
-        let mut rename_applied = false;
-        for field in self.plan.schema().fields() {
-            let field_name = field.qualified_name();
-            if old_name == field_name {
-                projection.push(col(&field_name).alias(new_name));
-                rename_applied = true;
-            } else {
-                projection.push(col(&field_name));
+        let old_name: Column = old_name.into();
+
+        let field_to_rename = match self.plan.schema().field_from_column(&old_name) {
+            Ok(field) => field,
+            // no-op if field not found
+            Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { .. })) => {
+                return Ok(self)
             }
-        }
-        if rename_applied {
-            let project_plan = LogicalPlanBuilder::from(self.plan)
-                .project(projection)?
-                .build()?;
-            Ok(DataFrame::new(self.session_state, project_plan))
-        } else {
-            Ok(DataFrame::new(self.session_state, self.plan))
-        }
+            Err(err) => return Err(err),
+        };
+        let projection = self
+            .plan
+            .schema()
+            .fields()
+            .iter()
+            .map(|f| {
+                if f == field_to_rename {
+                    col(f.qualified_column()).alias(new_name)
+                } else {
+                    col(f.qualified_column())
+                }
+            })
+            .collect::<Vec<_>>();
+        let project_plan = LogicalPlanBuilder::from(self.plan)
+            .project(projection)?
+            .build()?;
+        Ok(DataFrame::new(self.session_state, project_plan))
     }
 
     /// Convert a prepare logical plan into its inner logical plan with all params replaced with their corresponding values
@@ -1681,15 +1689,24 @@ mod tests {
             ])?
             .with_column("sum", col("c2") + col("c3"))?;
 
-        let df_sum_renamed = df.with_column_renamed("sum", "total")?.collect().await?;
+        let df_sum_renamed = df
+            .with_column_renamed("sum", "total")?
+            // table qualifier optional
+            .with_column_renamed("c1", "one")?
+            // accepts table qualifier
+            .with_column_renamed("aggregate_test_100.c2", "two")?
+            // no-op for missing column
+            .with_column_renamed("c4", "boom")?
+            .collect()
+            .await?;
 
         assert_batches_sorted_eq!(
             vec![
-                "+----+----+----+-------+",
-                "| c1 | c2 | c3 | total |",
-                "+----+----+----+-------+",
-                "| a  | 3  | 13 | 16    |",
-                "+----+----+----+-------+",
+                "+-----+-----+----+-------+",
+                "| one | two | c3 | total |",
+                "+-----+-----+----+-------+",
+                "| a   | 3   | 13 | 16    |",
+                "+-----+-----+----+-------+",
             ],
             &df_sum_renamed
         );
@@ -1697,6 +1714,34 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn with_column_renamed_ambiguous() -> Result<()> {
+        let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
+        let ctx = SessionContext::new();
+
+        let table = df.into_view();
+        ctx.register_table("t1", table.clone())?;
+        ctx.register_table("t2", table)?;
+
+        let actual_err = ctx
+            .table("t1")
+            .await?
+            .join(
+                ctx.table("t2").await?,
+                JoinType::Inner,
+                &["c1"],
+                &["c1"],
+                None,
+            )?
+            // can be t1.c2 or t2.c2
+            .with_column_renamed("c2", "AAA")
+            .unwrap_err();
+        let expected_err = "Schema error: Ambiguous reference to unqualified field c2";
+        assert_eq!(actual_err.to_string(), expected_err);
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn with_column_renamed_join() -> Result<()> {
         let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;

Reply via email to