This is an automated email from the ASF dual-hosted git repository.

xudong963 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 666220075 Adding dataframe with_column function (#2849)
666220075 is described below

commit 666220075770c1f3e879ca903debc88ae2d28f03
Author: comphead <[email protected]>
AuthorDate: Thu Jul 7 21:15:34 2022 -0700

    Adding dataframe with_column function (#2849)
    
    * With Column impl
    
    * Added more tests
---
 datafusion/core/src/dataframe.rs | 123 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 123 insertions(+)

diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 02b9116be..32c9299e9 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -620,6 +620,55 @@ impl DataFrame {
         let state = self.session_state.read().clone();
         plan_to_json(&state, plan, path).await
     }
+
+    /// Add column `name` computed from `expr`, replacing any existing column with that name.
+    ///
+    /// ```
+    /// # use datafusion::prelude::*;
+    /// # use datafusion::error::Result;
+    /// # #[tokio::main]
+    /// # async fn main() -> Result<()> {
+    /// let ctx = SessionContext::new();
+    /// let df = ctx.read_csv("tests/example.csv", 
CsvReadOptions::new()).await?;
+    /// let df = df.with_column("ab_sum", col("a") + col("b"))?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn with_column(&self, name: &str, expr: Expr) -> 
Result<Arc<DataFrame>> {
+        let window_func_exprs = find_window_exprs(&[expr.clone()]);
+        let plan = if window_func_exprs.is_empty() {
+            self.plan.clone()
+        } else {
+            LogicalPlanBuilder::window_plan(self.plan.clone(), 
window_func_exprs)?
+        };
+
+        let new_column = Expr::Alias(Box::new(expr), name.to_string());
+        let mut col_exists = false;
+        let mut fields: Vec<Expr> = plan
+            .schema()
+            .fields()
+            .iter()
+            .map(|f| {
+                if f.name() == name {
+                    col_exists = true;
+                    new_column.clone()
+                } else {
+                    col(f.name())
+                }
+            })
+            .collect();
+
+        if !col_exists {
+            fields.push(new_column);
+        }
+
+        let project_plan = 
LogicalPlanBuilder::from(plan).project(fields)?.build()?;
+
+        Ok(Arc::new(DataFrame::new(
+            self.session_state.clone(),
+            &project_plan,
+        )))
+    }
 }
 
 // TODO: This will introduce a ref cycle (#2659)
@@ -1007,4 +1056,78 @@ mod tests {
         .await?;
         Ok(())
     }
+
+    #[tokio::test]
+    async fn with_column() -> Result<()> {
+        let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
+        let ctx = SessionContext::new();
+        let df_impl = Arc::new(DataFrame::new(ctx.state.clone(), 
&df.plan.clone()));
+
+        let df = &df_impl
+            .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
+            .with_column("sum", col("c2") + col("c3"))?;
+
+        // check that the new column is added
+        let df_results = df.collect().await?;
+
+        assert_batches_sorted_eq!(
+            vec![
+                "+----+----+-----+-----+",
+                "| c1 | c2 | c3  | sum |",
+                "+----+----+-----+-----+",
+                "| a  | 3  | -12 | -9  |",
+                "| a  | 3  | -72 | -69 |",
+                "| a  | 3  | 13  | 16  |",
+                "| a  | 3  | 13  | 16  |",
+                "| a  | 3  | 14  | 17  |",
+                "| a  | 3  | 17  | 20  |",
+                "+----+----+-----+-----+",
+            ],
+            &df_results
+        );
+
+        // check that a column with the same name is overwritten
+        let df_results_overwrite = df
+            .with_column("c1", col("c2") + col("c3"))?
+            .collect()
+            .await?;
+
+        assert_batches_sorted_eq!(
+            vec![
+                "+-----+----+-----+-----+",
+                "| c1  | c2 | c3  | sum |",
+                "+-----+----+-----+-----+",
+                "| -69 | 3  | -72 | -69 |",
+                "| -9  | 3  | -12 | -9  |",
+                "| 16  | 3  | 13  | 16  |",
+                "| 16  | 3  | 13  | 16  |",
+                "| 17  | 3  | 14  | 17  |",
+                "| 20  | 3  | 17  | 20  |",
+                "+-----+----+-----+-----+",
+            ],
+            &df_results_overwrite
+        );
+
+        // check that a column is overwritten when its own name is used as a 
reference
+        let df_results_overwrite_self =
+            df.with_column("c2", col("c2") + lit(1))?.collect().await?;
+
+        assert_batches_sorted_eq!(
+            vec![
+                "+----+----+-----+-----+",
+                "| c1 | c2 | c3  | sum |",
+                "+----+----+-----+-----+",
+                "| a  | 4  | -12 | -9  |",
+                "| a  | 4  | -72 | -69 |",
+                "| a  | 4  | 13  | 16  |",
+                "| a  | 4  | 13  | 16  |",
+                "| a  | 4  | 14  | 17  |",
+                "| a  | 4  | 17  | 20  |",
+                "+----+----+-----+-----+",
+            ],
+            &df_results_overwrite_self
+        );
+
+        Ok(())
+    }
 }

Reply via email to