This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-ballista.git
The following commit(s) were added to refs/heads/main by this push:
new 3d241c47 [Python] Add `execute_logical_plan` to context (#972)
3d241c47 is described below
commit 3d241c47478271417950cc385fe64077aa491936
Author: Andy Grove <[email protected]>
AuthorDate: Sun Feb 11 17:13:36 2024 -0700
[Python] Add `execute_logical_plan` to context (#972)
---
ballista/client/src/context.rs | 12 ++++++++++++
python/pyballista/tests/test_context.py | 10 ++++++++++
python/src/context.rs | 15 +++++++++++++--
3 files changed, 35 insertions(+), 2 deletions(-)
diff --git a/ballista/client/src/context.rs b/ballista/client/src/context.rs
index 9b1e64e6..0b41512c 100644
--- a/ballista/client/src/context.rs
+++ b/ballista/client/src/context.rs
@@ -455,6 +455,18 @@ impl BallistaContext {
_ => ctx.execute_logical_plan(plan).await,
}
}
+
+ /// Execute the [`LogicalPlan`], return a [`DataFrame`]. This API
+ /// is not feature limited (so all SQL such as `CREATE TABLE` and
+ /// `COPY` will be run).
+ ///
+ /// If you wish to limit the type of plan that can be run from
+ /// SQL, see [`Self::sql_with_options`] and
+ /// [`SQLOptions::verify_plan`].
+ pub async fn execute_logical_plan(&self, plan: LogicalPlan) -> Result<DataFrame> {
+ let ctx = self.context.clone();
+ ctx.execute_logical_plan(plan).await
+ }
}
#[cfg(test)]
diff --git a/python/pyballista/tests/test_context.py b/python/pyballista/tests/test_context.py
index 0a395527..b440bb27 100644
--- a/python/pyballista/tests/test_context.py
+++ b/python/pyballista/tests/test_context.py
@@ -65,3 +65,13 @@ def test_read_dataframe_api():
batches = df.collect()
assert len(batches) == 1
assert len(batches[0]) == 1
+
+def test_execute_plan():
+ ctx = SessionContext("localhost", 50050)
+ df = ctx.read_csv("testdata/test.csv", has_header=True) \
+ .select_columns('a', 'b') \
+ .limit(1)
+ df = ctx.execute_logical_plan(df.logical_plan())
+ batches = df.collect()
+ assert len(batches) == 1
+ assert len(batches[0]) == 1
diff --git a/python/src/context.rs b/python/src/context.rs
index b85e0e10..0d0231c6 100644
--- a/python/src/context.rs
+++ b/python/src/context.rs
@@ -15,15 +15,15 @@
// specific language governing permissions and limitations
// under the License.
-use datafusion::prelude::*;
+use crate::utils::to_pyerr;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use std::path::PathBuf;
-use crate::utils::to_pyerr;
use ballista::prelude::*;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::pyarrow::PyArrowType;
+use datafusion::prelude::*;
use datafusion_python::catalog::PyTable;
use datafusion_python::context::{
convert_table_partition_cols, parse_file_compression_type,
@@ -31,6 +31,7 @@ use datafusion_python::context::{
use datafusion_python::dataframe::PyDataFrame;
use datafusion_python::errors::DataFusionError;
use datafusion_python::expr::PyExpr;
+use datafusion_python::sql::logical::PyLogicalPlan;
use datafusion_python::utils::wait_for_future;
/// PyBallista session context. This is largely a duplicate of
@@ -324,4 +325,14 @@ impl PySessionContext {
.map_err(DataFusionError::from)?;
Ok(())
}
+
+ pub fn execute_logical_plan(
+ &mut self,
+ logical_plan: PyLogicalPlan,
+ py: Python,
+ ) -> PyResult<PyDataFrame> {
+ let result = self.ctx.execute_logical_plan(logical_plan.into());
+ let df = wait_for_future(py, result).unwrap();
+ Ok(PyDataFrame::new(df))
+ }
}