This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 108e9db859 Proper resolution for old name in with_column_renamed (#5992)
108e9db859 is described below
commit 108e9db859127cad3a01819a0ab35702e823c07b
Author: Jeffrey <[email protected]>
AuthorDate: Sat Apr 15 02:19:10 2023 +1000
Proper resolution for old name in with_column_renamed (#5992)
* Proper resolution for old name in with_column_renamed
* clippy
---
datafusion/core/src/dataframe.rs | 97 +++++++++++++++++++++++++++++-----------
1 file changed, 71 insertions(+), 26 deletions(-)
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index d7ce056a33..380f722a42 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -24,7 +24,7 @@ use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
use arrow::compute::{cast, concat};
use arrow::datatypes::{DataType, Field};
use async_trait::async_trait;
-use datafusion_common::DataFusionError;
+use datafusion_common::{DataFusionError, SchemaError};
use parquet::file::properties::WriterProperties;
use datafusion_common::from_slice::FromSlice;
@@ -1007,28 +1007,36 @@ impl DataFrame {
/// ```
pub fn with_column_renamed(
self,
- old_name: &str,
+ old_name: impl Into<Column>,
new_name: &str,
) -> Result<DataFrame> {
- let mut projection = vec![];
- let mut rename_applied = false;
- for field in self.plan.schema().fields() {
- let field_name = field.qualified_name();
- if old_name == field_name {
- projection.push(col(&field_name).alias(new_name));
- rename_applied = true;
- } else {
- projection.push(col(&field_name));
+ let old_name: Column = old_name.into();
+
+ let field_to_rename = match
self.plan.schema().field_from_column(&old_name) {
+ Ok(field) => field,
+ // no-op if field not found
+ Err(DataFusionError::SchemaError(SchemaError::FieldNotFound { ..
})) => {
+ return Ok(self)
}
- }
- if rename_applied {
- let project_plan = LogicalPlanBuilder::from(self.plan)
- .project(projection)?
- .build()?;
- Ok(DataFrame::new(self.session_state, project_plan))
- } else {
- Ok(DataFrame::new(self.session_state, self.plan))
- }
+ Err(err) => return Err(err),
+ };
+ let projection = self
+ .plan
+ .schema()
+ .fields()
+ .iter()
+ .map(|f| {
+ if f == field_to_rename {
+ col(f.qualified_column()).alias(new_name)
+ } else {
+ col(f.qualified_column())
+ }
+ })
+ .collect::<Vec<_>>();
+ let project_plan = LogicalPlanBuilder::from(self.plan)
+ .project(projection)?
+ .build()?;
+ Ok(DataFrame::new(self.session_state, project_plan))
}
/// Convert a prepare logical plan into its inner logical plan with all params replaced with their corresponding values
@@ -1681,15 +1689,24 @@ mod tests {
])?
.with_column("sum", col("c2") + col("c3"))?;
- let df_sum_renamed = df.with_column_renamed("sum",
"total")?.collect().await?;
+ let df_sum_renamed = df
+ .with_column_renamed("sum", "total")?
+ // table qualifier optional
+ .with_column_renamed("c1", "one")?
+ // accepts table qualifier
+ .with_column_renamed("aggregate_test_100.c2", "two")?
+ // no-op for missing column
+ .with_column_renamed("c4", "boom")?
+ .collect()
+ .await?;
assert_batches_sorted_eq!(
vec![
- "+----+----+----+-------+",
- "| c1 | c2 | c3 | total |",
- "+----+----+----+-------+",
- "| a | 3 | 13 | 16 |",
- "+----+----+----+-------+",
+ "+-----+-----+----+-------+",
+ "| one | two | c3 | total |",
+ "+-----+-----+----+-------+",
+ "| a | 3 | 13 | 16 |",
+ "+-----+-----+----+-------+",
],
&df_sum_renamed
);
@@ -1697,6 +1714,34 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn with_column_renamed_ambiguous() -> Result<()> {
+ let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
+ let ctx = SessionContext::new();
+
+ let table = df.into_view();
+ ctx.register_table("t1", table.clone())?;
+ ctx.register_table("t2", table)?;
+
+ let actual_err = ctx
+ .table("t1")
+ .await?
+ .join(
+ ctx.table("t2").await?,
+ JoinType::Inner,
+ &["c1"],
+ &["c1"],
+ None,
+ )?
+ // can be t1.c2 or t2.c2
+ .with_column_renamed("c2", "AAA")
+ .unwrap_err();
+ let expected_err = "Schema error: Ambiguous reference to unqualified
field c2";
+ assert_eq!(actual_err.to_string(), expected_err);
+
+ Ok(())
+ }
+
#[tokio::test]
async fn with_column_renamed_join() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;