This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/master by this push:
     new adbcae3  [DataFrame] - Add union and union_distinct bindings for DataFrame (#35)
adbcae3 is described below

commit adbcae3119b7cb84a2568a331f7ba352a6d09976
Author: Francis Du <[email protected]>
AuthorDate: Tue Sep 13 21:47:56 2022 +0800

    [DataFrame] - Add union and union_distinct bindings for DataFrame (#35)
    
    * fix: conflicting
    
    * fix: python linter
    
    * fix: flake8 W503 issue
    
    * fix: test error
---
 datafusion/tests/test_dataframe.py | 58 ++++++++++++++++++++++++++++++++++++++
 src/dataframe.rs                   | 20 +++++++++++++
 2 files changed, 78 insertions(+)

diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index 760c376..9880b6d 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -22,6 +22,11 @@ from datafusion import functions as f
 from datafusion import DataFrame, SessionContext, column, literal, udf
 
 
+@pytest.fixture
+def ctx():  # function-scoped (pytest default): one SessionContext per test
+    return SessionContext()
+
+
 @pytest.fixture
 def df():
     ctx = SessionContext()
@@ -323,3 +328,56 @@ def test_collect_partitioned():
     )
 
     assert [[batch]] == ctx.create_dataframe([[batch]]).collect_partitioned()
+
+
+def test_union(ctx):
+    """``union`` keeps duplicate rows (UNION ALL semantics)."""
+
+    def make_df(a_values, b_values):
+        batch = pa.RecordBatch.from_arrays(
+            [pa.array(a_values), pa.array(b_values)],
+            names=["a", "b"],
+        )
+        return ctx.create_dataframe([[batch]])
+
+    def sorted_by_a(frame):
+        return frame.sort(column("a").sort(ascending=True))
+
+    left = make_df([1, 2, 3], [4, 5, 6])
+    right = make_df([3, 4, 5], [6, 7, 8])
+
+    # Row (3, 6) is in both inputs, so it appears twice in the result.
+    expected_a = [1, 2, 3, 3, 4, 5]
+    expected_b = [4, 5, 6, 6, 7, 8]
+    expected = sorted_by_a(make_df(expected_a, expected_b))
+
+    actual = sorted_by_a(left.union(right))
+
+    assert expected.collect() == actual.collect()
+
+
+def test_union_distinct(ctx):
+    """``union(..., True)`` and ``union_distinct`` both drop duplicates."""
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    df_a = ctx.create_dataframe([[batch]])
+
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([3, 4, 5]), pa.array([6, 7, 8])],
+        names=["a", "b"],
+    )
+    df_b = ctx.create_dataframe([[batch]])
+
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3, 4, 5]), pa.array([4, 5, 6, 7, 8])],
+        names=["a", "b"],
+    )
+    by_a = column("a").sort(ascending=True)
+    df_c = ctx.create_dataframe([[batch]]).sort(by_a)
+
+    df_a_u_b = df_a.union(df_b, True).sort(by_a)
+    df_a_ud_b = df_a.union_distinct(df_b).sort(by_a)
+
+    assert df_c.collect() == df_a_u_b.collect() == df_a_ud_b.collect()
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 4ae0160..4d8c0a3 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -204,6 +204,26 @@ impl PyDataFrame {
         Ok(Self::new(new_df))
     }
 
+    /// Calculate the union of two `DataFrame`s. Duplicate rows are kept
+    /// unless `distinct` is true (default: false). The two `DataFrame`s must
+    /// have exactly the same schema
+    #[args(distinct = false)]
+    fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyResult<Self> {
+        let new_df = if distinct {
+            self.df.union_distinct(py_df.df)?
+        } else {
+            self.df.union(py_df.df)?
+        };
+        Ok(Self::new(new_df))
+    }
+
+    /// Calculate the distinct union of two `DataFrame`s, i.e. `union` with
+    /// duplicate rows removed. The two `DataFrame`s must have exactly the
+    /// same schema
+    fn union_distinct(&self, py_df: PyDataFrame) -> PyResult<Self> {
+        Ok(Self::new(self.df.union_distinct(py_df.df)?))
+    }
+
     /// Calculate the intersection of two `DataFrame`s.  The two `DataFrame`s 
must have exactly the same schema
     fn intersect(&self, py_df: PyDataFrame) -> PyResult<Self> {
         let new_df = self.df.intersect(py_df.df)?;

Reply via email to