This is an automated email from the ASF dual-hosted git repository.
xudong963 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 666220075 Adding dataframe with_column function (#2849)
666220075 is described below
commit 666220075770c1f3e879ca903debc88ae2d28f03
Author: comphead <[email protected]>
AuthorDate: Thu Jul 7 21:15:34 2022 -0700
Adding dataframe with_column function (#2849)
* With Column impl
* Added more tests
---
datafusion/core/src/dataframe.rs | 123 +++++++++++++++++++++++++++++++++++++++
1 file changed, 123 insertions(+)
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 02b9116be..32c9299e9 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -620,6 +620,55 @@ impl DataFrame {
let state = self.session_state.read().clone();
plan_to_json(&state, plan, path).await
}
+
+ /// Add a column `name` computed from `expr`, replacing any existing column of that name (otherwise it is appended).
+ ///
+ /// ```
+ /// # use datafusion::prelude::*;
+ /// # use datafusion::error::Result;
+ /// # #[tokio::main]
+ /// # async fn main() -> Result<()> {
+ /// let ctx = SessionContext::new();
+ /// let df = ctx.read_csv("tests/example.csv",
CsvReadOptions::new()).await?;
+ /// let df = df.with_column("ab_sum", col("a") + col("b"))?;
+ /// # Ok(())
+ /// # }
+ /// ```
+ pub fn with_column(&self, name: &str, expr: Expr) ->
Result<Arc<DataFrame>> {
+ let window_func_exprs = find_window_exprs(&[expr.clone()]);
+ let plan = if window_func_exprs.is_empty() {
+ self.plan.clone()
+ } else {
+ LogicalPlanBuilder::window_plan(self.plan.clone(),
window_func_exprs)?
+ };
+
+ let new_column = Expr::Alias(Box::new(expr), name.to_string());
+ let mut col_exists = false;
+ let mut fields: Vec<Expr> = plan
+ .schema()
+ .fields()
+ .iter()
+ .map(|f| {
+ if f.name() == name {
+ col_exists = true;
+ new_column.clone()
+ } else {
+ col(f.name())
+ }
+ })
+ .collect();
+
+ if !col_exists {
+ fields.push(new_column);
+ }
+
+ let project_plan =
LogicalPlanBuilder::from(plan).project(fields)?.build()?;
+
+ Ok(Arc::new(DataFrame::new(
+ self.session_state.clone(),
+ &project_plan,
+ )))
+ }
}
// TODO: This will introduce a ref cycle (#2659)
@@ -1007,4 +1056,78 @@ mod tests {
.await?;
Ok(())
}
+
+ #[tokio::test]
+ async fn with_column() -> Result<()> {
+ let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
+ let ctx = SessionContext::new();
+ let df_impl = Arc::new(DataFrame::new(ctx.state.clone(),
&df.plan.clone()));
+
+ let df = &df_impl
+ .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
+ .with_column("sum", col("c2") + col("c3"))?;
+
+ // check that the new column `sum` is appended after the existing columns
+ let df_results = df.collect().await?;
+
+ assert_batches_sorted_eq!(
+ vec![
+ "+----+----+-----+-----+",
+ "| c1 | c2 | c3 | sum |",
+ "+----+----+-----+-----+",
+ "| a | 3 | -12 | -9 |",
+ "| a | 3 | -72 | -69 |",
+ "| a | 3 | 13 | 16 |",
+ "| a | 3 | 13 | 16 |",
+ "| a | 3 | 14 | 17 |",
+ "| a | 3 | 17 | 20 |",
+ "+----+----+-----+-----+",
+ ],
+ &df_results
+ );
+
+ // check that an existing column with the same name is overwritten in place
+ let df_results_overwrite = df
+ .with_column("c1", col("c2") + col("c3"))?
+ .collect()
+ .await?;
+
+ assert_batches_sorted_eq!(
+ vec![
+ "+-----+----+-----+-----+",
+ "| c1 | c2 | c3 | sum |",
+ "+-----+----+-----+-----+",
+ "| -69 | 3 | -72 | -69 |",
+ "| -9 | 3 | -12 | -9 |",
+ "| 16 | 3 | 13 | 16 |",
+ "| 16 | 3 | 13 | 16 |",
+ "| 17 | 3 | 14 | 17 |",
+ "| 20 | 3 | 17 | 20 |",
+ "+-----+----+-----+-----+",
+ ],
+ &df_results_overwrite
+ );
+
+ // check that a column is overwritten when the new expression uses that same column name as
reference
+ let df_results_overwrite_self =
+ df.with_column("c2", col("c2") + lit(1))?.collect().await?;
+
+ assert_batches_sorted_eq!(
+ vec![
+ "+----+----+-----+-----+",
+ "| c1 | c2 | c3 | sum |",
+ "+----+----+-----+-----+",
+ "| a | 4 | -12 | -9 |",
+ "| a | 4 | -72 | -69 |",
+ "| a | 4 | 13 | 16 |",
+ "| a | 4 | 13 | 16 |",
+ "| a | 4 | 14 | 17 |",
+ "| a | 4 | 17 | 20 |",
+ "+----+----+-----+-----+",
+ ],
+ &df_results_overwrite_self
+ );
+
+ Ok(())
+ }
}