This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-ballista.git


The following commit(s) were added to refs/heads/main by this push:
     new b3cf8d13 feat: Expose Ballista Scheduler and Executor in Python  
(#1148)
b3cf8d13 is described below

commit b3cf8d1368108434cca3738a888f5d1368082208
Author: Marko Milenković <[email protected]>
AuthorDate: Mon Dec 9 16:43:12 2024 +0000

    feat: Expose Ballista Scheduler and Executor in Python  (#1148)
    
    * Add PyScheduler and PyExecutor
    
    * fix builder api
    
    * scheduler & executor support __str__ and __repr__
    
    * update readme and requirements
    
    * fix pyproject dependency
    
    * expose additional configuration option
    
    * cleanup examples
    
    * add ability to close/stop scheduler and executor
    
    * clippy cleanup
    
    * concurrent_tasks can be configured
---
 ballista/scheduler/src/config.rs                   |   5 +
 python/Cargo.toml                                  |   4 +-
 python/README.md                                   |  48 +++-
 python/ballista/__init__.py                        |   4 +-
 python/examples/{example.py => client_remote.py}   |  13 +-
 .../examples/{example.py => client_standalone.py}  |  17 +-
 .../{ballista/__init__.py => examples/executor.py} |  33 ++-
 python/examples/{example.py => readme_remote.py}   |  26 +-
 .../examples/{example.py => readme_standalone.py}  |  24 +-
 .../__init__.py => examples/scheduler.py}          |  31 +--
 python/pyproject.toml                              |   2 +-
 python/requirements.txt                            |   7 +-
 python/src/cluster.rs                              | 264 +++++++++++++++++++++
 python/src/codec.rs                                | 253 ++++++++++++++++++++
 python/src/lib.rs                                  |  57 ++---
 python/src/utils.rs                                |  40 +++-
 16 files changed, 711 insertions(+), 117 deletions(-)

diff --git a/ballista/scheduler/src/config.rs b/ballista/scheduler/src/config.rs
index 9ddb8b6e..10c6df1d 100644
--- a/ballista/scheduler/src/config.rs
+++ b/ballista/scheduler/src/config.rs
@@ -38,6 +38,8 @@ pub struct SchedulerConfig {
     pub namespace: String,
     /// The external hostname of the scheduler
     pub external_host: String,
+    /// The bind host for the scheduler's gRPC service
+    pub bind_host: String,
     /// The bind port for the scheduler's gRPC service
     pub bind_port: u16,
     /// The task scheduling policy for the scheduler
@@ -87,6 +89,7 @@ impl std::fmt::Debug for SchedulerConfig {
             .field("namespace", &self.namespace)
             .field("external_host", &self.external_host)
             .field("bind_port", &self.bind_port)
+            .field("bind_host", &self.bind_host)
             .field("scheduling_policy", &self.scheduling_policy)
             .field("event_loop_buffer_size", &self.event_loop_buffer_size)
             .field("task_distribution", &self.task_distribution)
@@ -137,6 +140,7 @@ impl Default for SchedulerConfig {
             namespace: String::default(),
             external_host: "localhost".into(),
             bind_port: 50050,
+            bind_host: "127.0.0.1".into(),
             scheduling_policy: Default::default(),
             event_loop_buffer_size: 10000,
             task_distribution: Default::default(),
@@ -326,6 +330,7 @@ impl TryFrom<Config> for SchedulerConfig {
             namespace: opt.namespace,
             external_host: opt.external_host,
             bind_port: opt.bind_port,
+            bind_host: opt.bind_host,
             scheduling_policy: opt.scheduler_policy,
             event_loop_buffer_size: opt.event_loop_buffer_size,
             task_distribution,
diff --git a/python/Cargo.toml b/python/Cargo.toml
index b03f1e99..747f330a 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -31,8 +31,10 @@ publish = false
 
 [dependencies]
 async-trait = "0.1.77"
-ballista = { path = "../ballista/client", version = "0.12.0", features = 
["standalone"] }
+ballista = { path = "../ballista/client", version = "0.12.0" }
 ballista-core = { path = "../ballista/core", version = "0.12.0" }
+ballista-executor = { path = "../ballista/executor", version = "0.12.0" }
+ballista-scheduler = { path = "../ballista/scheduler", version = "0.12.0" }
 datafusion = { version = "42", features = ["pyarrow", "avro"] }
 datafusion-proto = { version = "42" }
 datafusion-python = { version = "42" }
diff --git a/python/README.md b/python/README.md
index 01b0a7f9..d8ba03f3 100644
--- a/python/README.md
+++ b/python/README.md
@@ -26,6 +26,12 @@ part of the default Cargo workspace so that it doesn't cause 
overhead for mainta
 
 ## Creating a SessionContext
 
+> [!IMPORTANT]
+> The current approach is to support the DataFusion Python API. There are known 
limitations of the current approach,
+> with some cases producing errors.
+> We are trying to come up with the best approach to support the DataFusion Python 
interface.
+> More details can be found at 
[#1142](https://github.com/apache/datafusion-ballista/issues/1142)
+
 Creates a new context and connects to a Ballista scheduler process.
 
 ```python
@@ -33,22 +39,50 @@ from ballista import BallistaBuilder
 >>> ctx = BallistaBuilder().standalone()
 ```
 
-## Example SQL Usage
+### Example SQL Usage
 
 ```python
->>> ctx.sql("create external table t stored as parquet location 
'/mnt/bigdata/tpch/sf10-parquet/lineitem.parquet'")
+>>> ctx.sql("create external table t stored as parquet location 
'./testdata/test.parquet'")
 >>> df = ctx.sql("select * from t limit 5")
 >>> pyarrow_batches = df.collect()
 ```
 
-## Example DataFrame Usage
+### Example DataFrame Usage
 
 ```python
->>> df = 
ctx.read_parquet('/mnt/bigdata/tpch/sf10-parquet/lineitem.parquet').limit(5)
+>>> df = ctx.read_parquet('./testdata/test.parquet').limit(5)
 >>> pyarrow_batches = df.collect()
 ```
 
-## Creating Virtual Environment
+## Scheduler and Executor
+
+Scheduler and executors can be configured and started from python code.
+
+To start scheduler:
+
+```python
+from ballista import BallistaScheduler
+
+scheduler = BallistaScheduler()
+
+scheduler.start()
+scheduler.wait_for_termination()
+```
+
+For executor:
+
+```python
+from ballista import BallistaExecutor
+
+executor = BallistaExecutor()
+
+executor.start()
+executor.wait_for_termination()
+```
+
+## Development Process
+
+### Creating Virtual Environment
 
 ```shell
 python3 -m venv venv
@@ -56,7 +90,7 @@ source venv/bin/activate
 pip3 install -r requirements.txt
 ```
 
-## Building
+### Building
 
 ```shell
 maturin develop
@@ -64,7 +98,7 @@ maturin develop
 
 Note that you can also run `maturin develop --release` to get a release build 
locally.
 
-## Testing
+### Testing
 
 ```shell
 python3 -m pytest
diff --git a/python/ballista/__init__.py b/python/ballista/__init__.py
index a143f17e..4e80422b 100644
--- a/python/ballista/__init__.py
+++ b/python/ballista/__init__.py
@@ -26,11 +26,11 @@ except ImportError:
 import pyarrow as pa
 
 from .ballista_internal import (
-    BallistaBuilder,
+    BallistaBuilder, BallistaScheduler, BallistaExecutor
 )
 
 __version__ = importlib_metadata.version(__name__)
 
 __all__ = [
-    "BallistaBuilder",
+    "BallistaBuilder", "BallistaScheduler", "BallistaExecutor"
 ]
\ No newline at end of file
diff --git a/python/examples/example.py b/python/examples/client_remote.py
similarity index 73%
copy from python/examples/example.py
copy to python/examples/client_remote.py
index 61a9abbd..fd85858a 100644
--- a/python/examples/example.py
+++ b/python/examples/client_remote.py
@@ -15,18 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 
+# %% 
 from ballista import BallistaBuilder
 from datafusion.context import SessionContext
-
-# Ballista will initiate with an empty config
-# set config variables with `config`
-ctx: SessionContext = BallistaBuilder()\
-    .config("ballista.job.name", "example ballista")\
-    .config("ballista.shuffle.partitions", "16")\
-    .standalone()
     
-#ctx_remote: SessionContext = ballista.remote("remote_ip", 50050)
+ctx: SessionContext = BallistaBuilder().remote("df://127.0.0.1:50050")
 
 # Select 1 to verify its working
 ctx.sql("SELECT 1").show()
-#ctx_remote.sql("SELECT 2").show()
\ No newline at end of file
+
+# %%
diff --git a/python/examples/example.py b/python/examples/client_standalone.py
similarity index 79%
copy from python/examples/example.py
copy to python/examples/client_standalone.py
index 61a9abbd..dfe3c372 100644
--- a/python/examples/example.py
+++ b/python/examples/client_standalone.py
@@ -15,18 +15,23 @@
 # specific language governing permissions and limitations
 # under the License.
 
+# %% 
+
 from ballista import BallistaBuilder
 from datafusion.context import SessionContext
 
-# Ballista will initiate with an empty config
-# set config variables with `config`
 ctx: SessionContext = BallistaBuilder()\
+    .config("datafusion.catalog.information_schema","true")\
     .config("ballista.job.name", "example ballista")\
-    .config("ballista.shuffle.partitions", "16")\
     .standalone()
     
-#ctx_remote: SessionContext = ballista.remote("remote_ip", 50050)
 
-# Select 1 to verify its working
 ctx.sql("SELECT 1").show()
-#ctx_remote.sql("SELECT 2").show()
\ No newline at end of file
+
+# %%
+ctx.sql("SHOW TABLES").show()
+# %%
+ctx.sql("select name, value from information_schema.df_settings where name 
like 'ballista.job.name'").show()
+
+
+# %%
diff --git a/python/ballista/__init__.py b/python/examples/executor.py
similarity index 69%
copy from python/ballista/__init__.py
copy to python/examples/executor.py
index a143f17e..bb032f63 100644
--- a/python/ballista/__init__.py
+++ b/python/examples/executor.py
@@ -15,22 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from abc import ABCMeta, abstractmethod
-from typing import List
-
-try:
-    import importlib.metadata as importlib_metadata
-except ImportError:
-    import importlib_metadata
-
-import pyarrow as pa
-
-from .ballista_internal import (
-    BallistaBuilder,
-)
-
-__version__ = importlib_metadata.version(__name__)
-
-__all__ = [
-    "BallistaBuilder",
-]
\ No newline at end of file
+# %%
+from ballista import BallistaExecutor
+# %%
+executor = BallistaExecutor()
+# %%
+executor.start()
+# %%
+executor
+# %%
+executor.wait_for_termination()
+# %%
+# %%
+executor.close()
+# %%
diff --git a/python/examples/example.py b/python/examples/readme_remote.py
similarity index 65%
copy from python/examples/example.py
copy to python/examples/readme_remote.py
index 61a9abbd..7e1c82d8 100644
--- a/python/examples/example.py
+++ b/python/examples/readme_remote.py
@@ -15,18 +15,24 @@
 # specific language governing permissions and limitations
 # under the License.
 
+# %%
+
 from ballista import BallistaBuilder
 from datafusion.context import SessionContext
 
-# Ballista will initiate with an empty config
-# set config variables with `config`
 ctx: SessionContext = BallistaBuilder()\
-    .config("ballista.job.name", "example ballista")\
-    .config("ballista.shuffle.partitions", "16")\
-    .standalone()
-    
-#ctx_remote: SessionContext = ballista.remote("remote_ip", 50050)
+    .config("ballista.job.name", "Readme Example Remote")\
+    .config("datafusion.execution.target_partitions", "4")\
+    .remote("df://127.0.0.1:50050")
+
+ctx.sql("create external table t stored as parquet location 
'../testdata/test.parquet'")
 
-# Select 1 to verify its working
-ctx.sql("SELECT 1").show()
-#ctx_remote.sql("SELECT 2").show()
\ No newline at end of file
+# %%
+df = ctx.sql("select * from t limit 5")
+pyarrow_batches = df.collect()
+pyarrow_batches[0].to_pandas()
+# %%
+df = ctx.read_parquet('../testdata/test.parquet').limit(5)
+pyarrow_batches = df.collect()
+pyarrow_batches[0].to_pandas()
+# %%
\ No newline at end of file
diff --git a/python/examples/example.py b/python/examples/readme_standalone.py
similarity index 67%
rename from python/examples/example.py
rename to python/examples/readme_standalone.py
index 61a9abbd..15404e02 100644
--- a/python/examples/example.py
+++ b/python/examples/readme_standalone.py
@@ -15,18 +15,24 @@
 # specific language governing permissions and limitations
 # under the License.
 
+# %%
+
 from ballista import BallistaBuilder
 from datafusion.context import SessionContext
 
-# Ballista will initiate with an empty config
-# set config variables with `config`
 ctx: SessionContext = BallistaBuilder()\
-    .config("ballista.job.name", "example ballista")\
-    .config("ballista.shuffle.partitions", "16")\
+    .config("ballista.job.name", "Readme Example")\
+    .config("datafusion.execution.target_partitions", "4")\
     .standalone()
-    
-#ctx_remote: SessionContext = ballista.remote("remote_ip", 50050)
 
-# Select 1 to verify its working
-ctx.sql("SELECT 1").show()
-#ctx_remote.sql("SELECT 2").show()
\ No newline at end of file
+ctx.sql("create external table t stored as parquet location 
'../testdata/test.parquet'")
+
+# %%
+df = ctx.sql("select * from t limit 5")
+pyarrow_batches = df.collect()
+pyarrow_batches[0].to_pandas()
+# %%
+df = ctx.read_parquet('../testdata/test.parquet').limit(5)
+pyarrow_batches = df.collect()
+pyarrow_batches[0].to_pandas()
+# %%
\ No newline at end of file
diff --git a/python/ballista/__init__.py b/python/examples/scheduler.py
similarity index 69%
copy from python/ballista/__init__.py
copy to python/examples/scheduler.py
index a143f17e..1c40ce1e 100644
--- a/python/ballista/__init__.py
+++ b/python/examples/scheduler.py
@@ -15,22 +15,15 @@
 # specific language governing permissions and limitations
 # under the License.
 
-from abc import ABCMeta, abstractmethod
-from typing import List
-
-try:
-    import importlib.metadata as importlib_metadata
-except ImportError:
-    import importlib_metadata
-
-import pyarrow as pa
-
-from .ballista_internal import (
-    BallistaBuilder,
-)
-
-__version__ = importlib_metadata.version(__name__)
-
-__all__ = [
-    "BallistaBuilder",
-]
\ No newline at end of file
+# %%
+from ballista import BallistaScheduler
+# %%
+scheduler = BallistaScheduler()
+# %%
+scheduler
+# %%
+scheduler.start()
+# %%
+scheduler.wait_for_termination()
+# %%
+scheduler.close()
\ No newline at end of file
diff --git a/python/pyproject.toml b/python/pyproject.toml
index cce88fd3..d9b6d2bd 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -43,7 +43,7 @@ classifier = [
     "Programming Language :: Rust",
 ]
 dependencies = [
-    "pyarrow>=11.0.0",
+    "pyarrow>=11.0.0", "cloudpickle"
 ]
 
 [project.urls]
diff --git a/python/requirements.txt b/python/requirements.txt
index a03a8f8d..bfc0e03c 100644
--- a/python/requirements.txt
+++ b/python/requirements.txt
@@ -1,3 +1,6 @@
-datafusion==35.0.0
+datafusion==42.0.0
 pyarrow
-pytest
\ No newline at end of file
+pytest
+maturin==1.5.1
+cloudpickle
+pandas
\ No newline at end of file
diff --git a/python/src/cluster.rs b/python/src/cluster.rs
new file mode 100644
index 00000000..aa4260ce
--- /dev/null
+++ b/python/src/cluster.rs
@@ -0,0 +1,264 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::future::IntoFuture;
+use std::sync::Arc;
+
+use crate::codec::{PyLogicalCodec, PyPhysicalCodec};
+use crate::utils::to_pyerr;
+use crate::utils::{spawn_feature, wait_for_future};
+use ballista_executor::executor_process::{
+    start_executor_process, ExecutorProcessConfig,
+};
+use ballista_scheduler::cluster::BallistaCluster;
+use ballista_scheduler::config::SchedulerConfig;
+use ballista_scheduler::scheduler_process::start_server;
+use pyo3::exceptions::PyException;
+use pyo3::{pyclass, pymethods, PyResult, Python};
+use tokio::task::JoinHandle;
+
+#[pyclass(name = "BallistaScheduler", module = "ballista", subclass)]
+pub struct PyScheduler {
+    config: SchedulerConfig,
+    handle: Option<JoinHandle<()>>,
+}
+
+#[pymethods]
+impl PyScheduler {
+    #[pyo3(signature = (bind_host=None, bind_port=None))]
+    #[new]
+    pub fn new(py: Python, bind_host: Option<String>, bind_port: Option<u16>) 
-> Self {
+        let mut config = SchedulerConfig::default();
+
+        if let Some(bind_port) = bind_port {
+            config.bind_port = bind_port;
+        }
+
+        if let Some(host) = bind_host {
+            config.bind_host = host;
+        }
+
+        config.override_logical_codec =
+            Some(Arc::new(PyLogicalCodec::try_new(py).unwrap()));
+        config.override_physical_codec =
+            Some(Arc::new(PyPhysicalCodec::try_new(py).unwrap()));
+
+        Self {
+            config,
+            handle: None,
+        }
+    }
+
+    pub fn start(&mut self, py: Python) -> PyResult<()> {
+        if self.handle.is_some() {
+            return Err(PyException::new_err("Scheduler already started"));
+        }
+        let cluster = wait_for_future(py, 
BallistaCluster::new_from_config(&self.config))
+            .map_err(to_pyerr)?;
+
+        let config = self.config.clone();
+        let address = format!("{}:{}", config.bind_host, config.bind_port);
+        let address = address.parse()?;
+        let handle = spawn_feature(py, async move {
+            start_server(cluster, address, Arc::new(config))
+                .await
+                .unwrap();
+        });
+        self.handle = Some(handle);
+
+        Ok(())
+    }
+
+    pub fn wait_for_termination(&mut self, py: Python) -> PyResult<()> {
+        if self.handle.is_none() {
+            return Err(PyException::new_err("Scheduler not started"));
+        }
+        let mut handle = None;
+        std::mem::swap(&mut self.handle, &mut handle);
+
+        match handle {
+            Some(handle) => wait_for_future(py, handle.into_future())
+                .map_err(|e| PyException::new_err(e.to_string())),
+            None => Ok(()),
+        }
+    }
+
+    pub fn close(&mut self) -> PyResult<()> {
+        let mut handle = None;
+        std::mem::swap(&mut self.handle, &mut handle);
+
+        if let Some(handle) = handle {
+            handle.abort()
+        }
+
+        Ok(())
+    }
+
+    #[classattr]
+    pub fn version() -> &'static str {
+        ballista_core::BALLISTA_VERSION
+    }
+
+    pub fn __str__(&self) -> String {
+        match self.handle {
+            Some(_) => format!(
+                "listening address={}:{}",
+                self.config.bind_host, self.config.bind_port,
+            ),
+            None => format!(
+                "configured address={}:{}",
+                self.config.bind_host, self.config.bind_port,
+            ),
+        }
+    }
+
+    pub fn __repr__(&self) -> String {
+        format!(
+            "BallistaScheduler(config={:?}, listening= {})",
+            self.config,
+            self.handle.is_some()
+        )
+    }
+}
+
+#[pyclass(name = "BallistaExecutor", module = "ballista", subclass)]
+pub struct PyExecutor {
+    config: Arc<ExecutorProcessConfig>,
+    handle: Option<JoinHandle<()>>,
+}
+
+#[pymethods]
+impl PyExecutor {
+    #[pyo3(signature = (bind_port=None, bind_host =None, scheduler_host = 
None, scheduler_port = None, concurrent_tasks = None))]
+    #[new]
+    pub fn new(
+        py: Python,
+        bind_port: Option<u16>,
+        bind_host: Option<String>,
+        scheduler_host: Option<String>,
+        scheduler_port: Option<u16>,
+        concurrent_tasks: Option<u16>,
+    ) -> PyResult<Self> {
+        let mut config = ExecutorProcessConfig::default();
+        if let Some(port) = bind_port {
+            config.port = port;
+        }
+
+        if let Some(host) = bind_host {
+            config.bind_host = host;
+        }
+
+        if let Some(port) = scheduler_port {
+            config.scheduler_port = port;
+        }
+
+        if let Some(host) = scheduler_host {
+            config.scheduler_host = host;
+        }
+
+        if let Some(concurrent_tasks) = concurrent_tasks {
+            config.concurrent_tasks = concurrent_tasks as usize
+        }
+
+        config.override_logical_codec = 
Some(Arc::new(PyLogicalCodec::try_new(py)?));
+        config.override_physical_codec = 
Some(Arc::new(PyPhysicalCodec::try_new(py)?));
+
+        let config = Arc::new(config);
+        Ok(Self {
+            config,
+            handle: None,
+        })
+    }
+
+    pub fn start(&mut self, py: Python) -> PyResult<()> {
+        if self.handle.is_some() {
+            return Err(PyException::new_err("Executor already started"));
+        }
+
+        let config = self.config.clone();
+
+        let handle =
+            spawn_feature(
+                py,
+                async move { start_executor_process(config).await.unwrap() },
+            );
+        self.handle = Some(handle);
+
+        Ok(())
+    }
+
+    pub fn wait_for_termination(&mut self, py: Python) -> PyResult<()> {
+        if self.handle.is_none() {
+            return Err(PyException::new_err("Executor not started"));
+        }
+        let mut handle = None;
+        std::mem::swap(&mut self.handle, &mut handle);
+
+        match handle {
+            Some(handle) => wait_for_future(py, handle.into_future())
+                .map_err(|e| PyException::new_err(e.to_string()))
+                .map(|_| ()),
+            None => Ok(()),
+        }
+    }
+
+    pub fn close(&mut self) -> PyResult<()> {
+        let mut handle = None;
+        std::mem::swap(&mut self.handle, &mut handle);
+
+        if let Some(handle) = handle {
+            handle.abort()
+        }
+
+        Ok(())
+    }
+
+    #[classattr]
+    pub fn version() -> &'static str {
+        ballista_core::BALLISTA_VERSION
+    }
+
+    pub fn __str__(&self) -> String {
+        match self.handle {
+            Some(_) => format!(
+                "listening address={}:{}, scheduler={}:{}",
+                self.config.bind_host,
+                self.config.port,
+                self.config.scheduler_host,
+                self.config.scheduler_port
+            ),
+            None => format!(
+                "configured address={}:{}, scheduler={}:{}",
+                self.config.bind_host,
+                self.config.port,
+                self.config.scheduler_host,
+                self.config.scheduler_port
+            ),
+        }
+    }
+
+    pub fn __repr__(&self) -> String {
+        format!(
+            "BallistaExecutor(address={}:{}, scheduler={}:{}, listening={})",
+            self.config.bind_host,
+            self.config.port,
+            self.config.scheduler_host,
+            self.config.scheduler_port,
+            self.handle.is_some()
+        )
+    }
+}
diff --git a/python/src/codec.rs b/python/src/codec.rs
new file mode 100644
index 00000000..c6b0b7e5
--- /dev/null
+++ b/python/src/codec.rs
@@ -0,0 +1,253 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use ballista_core::serde::{
+    BallistaLogicalExtensionCodec, BallistaPhysicalExtensionCodec,
+};
+use datafusion::logical_expr::ScalarUDF;
+use datafusion_proto::logical_plan::LogicalExtensionCodec;
+use datafusion_proto::physical_plan::PhysicalExtensionCodec;
+use pyo3::types::{PyAnyMethods, PyBytes, PyBytesMethods};
+use pyo3::{PyObject, PyResult, Python};
+use std::fmt::Debug;
+use std::sync::Arc;
+
+static MODULE: &str = "cloudpickle";
+static FUN_LOADS: &str = "loads";
+static FUN_DUMPS: &str = "dumps";
+
+/// Serde protocol for UD(a)F
+#[derive(Debug)]
+struct CloudPickle {
+    loads: PyObject,
+    dumps: PyObject,
+}
+
+impl CloudPickle {
+    pub fn try_new(py: Python<'_>) -> PyResult<Self> {
+        let module = py.import_bound(MODULE)?;
+        let loads = module.getattr(FUN_LOADS)?.unbind();
+        let dumps = module.getattr(FUN_DUMPS)?.unbind();
+
+        Ok(Self { loads, dumps })
+    }
+
+    pub fn pickle(&self, py: Python<'_>, py_any: &PyObject) -> 
PyResult<Vec<u8>> {
+        let b: PyObject = self.dumps.call1(py, (py_any,))?.extract(py)?;
+        let blob = b.downcast_bound::<PyBytes>(py)?.clone();
+
+        Ok(blob.as_bytes().to_owned())
+    }
+
+    pub fn unpickle(&self, py: Python<'_>, blob: &[u8]) -> PyResult<PyObject> {
+        let t: PyObject = self.loads.call1(py, (blob,))?.extract(py)?;
+
+        Ok(t)
+    }
+}
+
+pub struct PyLogicalCodec {
+    inner: BallistaLogicalExtensionCodec,
+    cloudpickle: CloudPickle,
+}
+
+impl PyLogicalCodec {
+    pub fn try_new(py: Python<'_>) -> PyResult<Self> {
+        Ok(Self {
+            inner: BallistaLogicalExtensionCodec::default(),
+            cloudpickle: CloudPickle::try_new(py)?,
+        })
+    }
+}
+
+impl Debug for PyLogicalCodec {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("PyLogicalCodec").finish()
+    }
+}
+
+impl LogicalExtensionCodec for PyLogicalCodec {
+    fn try_decode(
+        &self,
+        buf: &[u8],
+        inputs: &[datafusion::logical_expr::LogicalPlan],
+        ctx: &datafusion::prelude::SessionContext,
+    ) -> datafusion::error::Result<datafusion::logical_expr::Extension> {
+        self.inner.try_decode(buf, inputs, ctx)
+    }
+
+    fn try_encode(
+        &self,
+        node: &datafusion::logical_expr::Extension,
+        buf: &mut Vec<u8>,
+    ) -> datafusion::error::Result<()> {
+        self.inner.try_encode(node, buf)
+    }
+
+    fn try_decode_table_provider(
+        &self,
+        buf: &[u8],
+        table_ref: &datafusion::sql::TableReference,
+        schema: datafusion::arrow::datatypes::SchemaRef,
+        ctx: &datafusion::prelude::SessionContext,
+    ) -> datafusion::error::Result<std::sync::Arc<dyn 
datafusion::catalog::TableProvider>>
+    {
+        self.inner
+            .try_decode_table_provider(buf, table_ref, schema, ctx)
+    }
+
+    fn try_encode_table_provider(
+        &self,
+        table_ref: &datafusion::sql::TableReference,
+        node: std::sync::Arc<dyn datafusion::catalog::TableProvider>,
+        buf: &mut Vec<u8>,
+    ) -> datafusion::error::Result<()> {
+        self.inner.try_encode_table_provider(table_ref, node, buf)
+    }
+
+    fn try_decode_file_format(
+        &self,
+        buf: &[u8],
+        ctx: &datafusion::prelude::SessionContext,
+    ) -> datafusion::error::Result<
+        std::sync::Arc<dyn 
datafusion::datasource::file_format::FileFormatFactory>,
+    > {
+        self.inner.try_decode_file_format(buf, ctx)
+    }
+
+    fn try_encode_file_format(
+        &self,
+        buf: &mut Vec<u8>,
+        node: std::sync::Arc<dyn 
datafusion::datasource::file_format::FileFormatFactory>,
+    ) -> datafusion::error::Result<()> {
+        self.inner.try_encode_file_format(buf, node)
+    }
+
+    fn try_decode_udf(
+        &self,
+        name: &str,
+        buf: &[u8],
+    ) -> 
datafusion::error::Result<std::sync::Arc<datafusion::logical_expr::ScalarUDF>>
+    {
+        // use cloud pickle to decode udf
+        self.inner.try_decode_udf(name, buf)
+    }
+
+    fn try_encode_udf(
+        &self,
+        node: &datafusion::logical_expr::ScalarUDF,
+        buf: &mut Vec<u8>,
+    ) -> datafusion::error::Result<()> {
+        // use cloud pickle to encode udf
+        self.inner.try_encode_udf(node, buf)
+    }
+
+    fn try_decode_udaf(
+        &self,
+        name: &str,
+        buf: &[u8],
+    ) -> 
datafusion::error::Result<std::sync::Arc<datafusion::logical_expr::AggregateUDF>>
+    {
+        self.inner.try_decode_udaf(name, buf)
+    }
+
+    fn try_encode_udaf(
+        &self,
+        node: &datafusion::logical_expr::AggregateUDF,
+        buf: &mut Vec<u8>,
+    ) -> datafusion::error::Result<()> {
+        self.inner.try_encode_udaf(node, buf)
+    }
+
+    fn try_decode_udwf(
+        &self,
+        name: &str,
+        buf: &[u8],
+    ) -> 
datafusion::error::Result<std::sync::Arc<datafusion::logical_expr::WindowUDF>>
+    {
+        self.inner.try_decode_udwf(name, buf)
+    }
+
+    fn try_encode_udwf(
+        &self,
+        node: &datafusion::logical_expr::WindowUDF,
+        buf: &mut Vec<u8>,
+    ) -> datafusion::error::Result<()> {
+        self.inner.try_encode_udwf(node, buf)
+    }
+}
+
+pub struct PyPhysicalCodec {
+    inner: BallistaPhysicalExtensionCodec,
+    cloudpickle: CloudPickle,
+}
+
+impl Debug for PyPhysicalCodec {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("PyPhysicalCodec").finish()
+    }
+}
+
+impl PyPhysicalCodec {
+    pub fn try_new(py: Python<'_>) -> PyResult<Self> {
+        Ok(Self {
+            inner: BallistaPhysicalExtensionCodec::default(),
+            cloudpickle: CloudPickle::try_new(py)?,
+        })
+    }
+}
+
+impl PhysicalExtensionCodec for PyPhysicalCodec {
+    fn try_decode(
+        &self,
+        buf: &[u8],
+        inputs: &[std::sync::Arc<dyn 
datafusion::physical_plan::ExecutionPlan>],
+        registry: &dyn datafusion::execution::FunctionRegistry,
+    ) -> datafusion::error::Result<
+        std::sync::Arc<dyn datafusion::physical_plan::ExecutionPlan>,
+    > {
+        self.inner.try_decode(buf, inputs, registry)
+    }
+
+    fn try_encode(
+        &self,
+        node: std::sync::Arc<dyn datafusion::physical_plan::ExecutionPlan>,
+        buf: &mut Vec<u8>,
+    ) -> datafusion::error::Result<()> {
+        self.inner.try_encode(node, buf)
+    }
+
+    fn try_decode_udf(
+        &self,
+        name: &str,
+        _buf: &[u8],
+    ) -> datafusion::common::Result<Arc<ScalarUDF>> {
+        // use cloudpickle here
+        datafusion::common::not_impl_err!(
+            "PhysicalExtensionCodec is not provided for scalar function {name}"
+        )
+    }
+
+    fn try_encode_udf(
+        &self,
+        _node: &ScalarUDF,
+        _buf: &mut Vec<u8>,
+    ) -> datafusion::common::Result<()> {
+        // use cloudpickle here
+        Ok(())
+    }
+}
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 41b4b6d3..13a6c38b 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -15,32 +15,36 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use crate::utils::wait_for_future;
 use ballista::prelude::*;
+use cluster::{PyExecutor, PyScheduler};
 use datafusion::execution::SessionStateBuilder;
 use datafusion::prelude::*;
 use datafusion_python::context::PySessionContext;
-use datafusion_python::utils::wait_for_future;
-
-use std::collections::HashMap;
-
 use pyo3::prelude::*;
+
+mod cluster;
+#[allow(dead_code)]
+mod codec;
 mod utils;
-use utils::to_pyerr;
+
+pub(crate) struct TokioRuntime(tokio::runtime::Runtime);
 
 #[pymodule]
 fn ballista_internal(_py: Python, m: Bound<'_, PyModule>) -> PyResult<()> {
     pyo3_log::init();
-    // BallistaBuilder struct
+
     m.add_class::<PyBallistaBuilder>()?;
-    // DataFusion struct
     m.add_class::<datafusion_python::dataframe::PyDataFrame>()?;
+    m.add_class::<PyScheduler>()?;
+    m.add_class::<PyExecutor>()?;
+
     Ok(())
 }
 
-// Ballista Builder will take a HasMap/Dict Cionfg
 #[pyclass(name = "BallistaBuilder", module = "ballista", subclass)]
 pub struct PyBallistaBuilder {
-    conf: HashMap<String, String>,
+    session_config: SessionConfig,
 }
 
 #[pymethods]
@@ -48,56 +52,47 @@ impl PyBallistaBuilder {
     #[new]
     pub fn new() -> Self {
         Self {
-            conf: HashMap::new(),
+            session_config: SessionConfig::new_with_ballista(),
         }
     }
 
     pub fn config(
         mut slf: PyRefMut<'_, Self>,
-        k: &str,
-        v: &str,
+        key: &str,
+        value: &str,
         py: Python,
     ) -> PyResult<PyObject> {
-        slf.conf.insert(k.into(), v.into());
+        let _ = slf.session_config.options_mut().set(key, value);
 
         Ok(slf.into_py(py))
     }
 
     /// Construct the standalone instance from the SessionContext
     pub fn standalone(&self, py: Python) -> PyResult<PySessionContext> {
-        // Build the config
-        let config: SessionConfig = SessionConfig::from_string_hash_map(&self.conf)?;
-        // Build the state
         let state = SessionStateBuilder::new()
-            .with_config(config)
+            .with_config(self.session_config.clone())
             .with_default_features()
             .build();
-        // Build the context
-        let standalone_session = SessionContext::standalone_with_state(state);
 
-        // SessionContext is an async function
-        let ctx = wait_for_future(py, standalone_session)?;
+        let ctx = wait_for_future(py, SessionContext::standalone_with_state(state))?;
 
-        // Convert the SessionContext into a Python SessionContext
         Ok(ctx.into())
     }
 
     /// Construct the remote instance from the SessionContext
     pub fn remote(&self, url: &str, py: Python) -> PyResult<PySessionContext> {
-        // Build the config
-        let config: SessionConfig = SessionConfig::from_string_hash_map(&self.conf)?;
-        // Build the state
         let state = SessionStateBuilder::new()
-            .with_config(config)
+            .with_config(self.session_config.clone())
             .with_default_features()
             .build();
-        // Build the context
-        let remote_session = SessionContext::remote_with_state(url, state);
 
-        // SessionContext is an async function
-        let ctx = wait_for_future(py, remote_session)?;
+        let ctx = wait_for_future(py, SessionContext::remote_with_state(url, state))?;
 
-        // Convert the SessionContext into a Python SessionContext
         Ok(ctx.into())
     }
+
+    #[classattr]
+    pub fn version() -> &'static str {
+        ballista_core::BALLISTA_VERSION
+    }
 }
diff --git a/python/src/utils.rs b/python/src/utils.rs
index 10278537..f069475e 100644
--- a/python/src/utils.rs
+++ b/python/src/utils.rs
@@ -15,10 +15,48 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::future::Future;
+use std::sync::OnceLock;
+use tokio::task::JoinHandle;
+
 use ballista_core::error::BallistaError;
 use pyo3::exceptions::PyException;
-use pyo3::PyErr;
+use pyo3::{PyErr, Python};
+use tokio::runtime::Runtime;
+
+use crate::TokioRuntime;
 
 pub(crate) fn to_pyerr(err: BallistaError) -> PyErr {
     PyException::new_err(err.to_string())
 }
+
+/// Returns the process-wide Tokio runtime, creating it on first use.
+///
+/// # Panics
+///
+/// Panics if the runtime cannot be built on the first call.
+pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
+    // NOTE: Other pyo3 python libraries have had issues with using tokio
+    // behind a forking app-server like `gunicorn`.
+    // If we run into that problem, in the future we can look to `delta-rs`,
+    // which adds a check that disallows calls from a forked process:
+    // https://github.com/delta-io/delta-rs/blob/87010461cfe01563d91a4b9cd6fa468e2ad5f283/python/src/utils.rs#L10-L31
+    static RUNTIME: OnceLock<TokioRuntime> = OnceLock::new();
+    RUNTIME.get_or_init(|| {
+        let runtime = tokio::runtime::Runtime::new()
+            .expect("failed to create the shared Tokio runtime");
+        TokioRuntime(runtime)
+    })
+}
+
+/// Drives `f` to completion on the shared Tokio runtime and returns its output.
+///
+/// The blocking wait happens inside `py.allow_threads`, i.e. with the GIL
+/// released for the duration of the call.
+pub(crate) fn wait_for_future<F>(py: Python, f: F) -> F::Output
+where
+    F: Future + Send,
+    F::Output: Send,
+{
+    let rt: &Runtime = &get_tokio_runtime().0;
+    let drive = move || rt.block_on(f);
+    py.allow_threads(drive)
+}
+
+/// Spawns `f` as a task on the shared Tokio runtime and returns its handle.
+///
+/// The returned [`JoinHandle`] yields the future's output.
+pub(crate) fn spawn_feature<F>(py: Python, f: F) -> JoinHandle<F::Output>
+where
+    F: Future + Send + 'static,
+    F::Output: Send,
+{
+    let rt: &Runtime = &get_tokio_runtime().0;
+    // NOTE(review): spawning presumably returns without waiting for the task,
+    // so releasing the GIL here may be unnecessary - confirm before removing.
+    let submit = move || rt.spawn(f);
+    py.allow_threads(submit)
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to