This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 6e570e2  Bind SQLOptions and relative ctx method #567 (#588)
6e570e2 is described below

commit 6e570e211fff70ef69f9fa5a4a3ba2830d4f3e0c
Author: Giacomo Rebecchi <[email protected]>
AuthorDate: Sat Apr 13 23:37:55 2024 +0200

    Bind SQLOptions and relative ctx method #567 (#588)
---
 datafusion/__init__.py           |  2 ++
 datafusion/tests/test_context.py | 46 ++++++++++++++++++++++++++++----
 src/context.rs                   | 57 +++++++++++++++++++++++++++++++++++++++-
 src/lib.rs                       |  1 +
 4 files changed, 100 insertions(+), 6 deletions(-)

diff --git a/datafusion/__init__.py b/datafusion/__init__.py
index cfb6736..c50bf64 100644
--- a/datafusion/__init__.py
+++ b/datafusion/__init__.py
@@ -33,6 +33,7 @@ from ._internal import (
     SessionConfig,
     RuntimeConfig,
     ScalarUDF,
+    SQLOptions,
 )
 
 from .common import (
@@ -96,6 +97,7 @@ __all__ = [
     "DataFrame",
     "SessionContext",
     "SessionConfig",
+    "SQLOptions",
     "RuntimeConfig",
     "Expr",
     "AggregateUDF",
diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py
index d48bdd9..962a7ff 100644
--- a/datafusion/tests/test_context.py
+++ b/datafusion/tests/test_context.py
@@ -19,16 +19,17 @@ import os
 
 import pyarrow as pa
 import pyarrow.dataset as ds
+import pytest
 
 from datafusion import (
+    DataFrame,
+    RuntimeConfig,
+    SessionConfig,
+    SessionContext,
+    SQLOptions,
     column,
     literal,
-    SessionContext,
-    SessionConfig,
-    RuntimeConfig,
-    DataFrame,
 )
-import pytest
 
 
 def test_create_context_no_args():
@@ -389,3 +390,38 @@ def test_read_parquet(ctx):
 def test_read_avro(ctx):
     csv_df = ctx.read_avro(path="testing/data/avro/alltypes_plain.avro")
     csv_df.show()
+
+
+def test_create_sql_options():
+    SQLOptions()
+
+
+def test_sql_with_options_no_ddl(ctx):
+    sql = "CREATE TABLE IF NOT EXISTS valuetable AS VALUES(1,'HELLO'),(12,'DATAFUSION')"
+    ctx.sql(sql)
+    options = SQLOptions().with_allow_ddl(False)
+    with pytest.raises(Exception, match="DDL"):
+        ctx.sql_with_options(sql, options=options)
+
+
+def test_sql_with_options_no_dml(ctx):
+    table_name = "t"
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    dataset = ds.dataset([batch])
+    ctx.register_dataset(table_name, dataset)
+    sql = f'INSERT INTO "{table_name}" VALUES (1, 2), (2, 3);'
+    ctx.sql(sql)
+    options = SQLOptions().with_allow_dml(False)
+    with pytest.raises(Exception, match="DML"):
+        ctx.sql_with_options(sql, options=options)
+
+
+def test_sql_with_options_no_statements(ctx):
+    sql = "SET time zone = 1;"
+    ctx.sql(sql)
+    options = SQLOptions().with_allow_statements(False)
+    with pytest.raises(Exception, match="SetVariable"):
+        ctx.sql_with_options(sql, options=options)
diff --git a/src/context.rs b/src/context.rs
index f34fbce..825d3e3 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -45,7 +45,9 @@ use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
 use datafusion::datasource::MemTable;
 use datafusion::datasource::TableProvider;
-use datafusion::execution::context::{SessionConfig, SessionContext, SessionState, TaskContext};
+use datafusion::execution::context::{
+    SQLOptions, SessionConfig, SessionContext, SessionState, TaskContext,
+};
 use datafusion::execution::disk_manager::DiskManagerConfig;
 use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
@@ -210,6 +212,43 @@ impl PyRuntimeConfig {
     }
 }
 
+/// `PySQLOptions` allows you to specify options for SQL execution.
+#[pyclass(name = "SQLOptions", module = "datafusion", subclass)]
+#[derive(Clone)]
+pub struct PySQLOptions {
+    pub options: SQLOptions,
+}
+
+impl From<SQLOptions> for PySQLOptions {
+    fn from(options: SQLOptions) -> Self {
+        Self { options }
+    }
+}
+
+#[pymethods]
+impl PySQLOptions {
+    #[new]
+    fn new() -> Self {
+        let options = SQLOptions::new();
+        Self { options }
+    }
+
+    /// Should DDL data definition commands (e.g. `CREATE TABLE`) be run? Defaults to `true`.
+    fn with_allow_ddl(&self, allow: bool) -> Self {
+        Self::from(self.options.with_allow_ddl(allow))
+    }
+
+    /// Should DML data modification commands (e.g. `INSERT` and `COPY`) be run? Defaults to `true`.
+    pub fn with_allow_dml(&self, allow: bool) -> Self {
+        Self::from(self.options.with_allow_dml(allow))
+    }
+
+    /// Should statements (e.g. `SET VARIABLE`, `BEGIN TRANSACTION`, ...) be run? Defaults to `true`.
+    pub fn with_allow_statements(&self, allow: bool) -> Self {
+        Self::from(self.options.with_allow_statements(allow))
+    }
+}
+
 /// `PySessionContext` is able to plan and execute DataFusion plans.
 /// It has a powerful optimizer, a physical planner for local execution, and a
 /// multi-threaded execution engine to perform the execution.
@@ -285,6 +324,22 @@ impl PySessionContext {
         Ok(PyDataFrame::new(df))
     }
 
+    pub fn sql_with_options(
+        &mut self,
+        query: &str,
+        options: Option<PySQLOptions>,
+        py: Python,
+    ) -> PyResult<PyDataFrame> {
+        let options = if let Some(options) = options {
+            options.options
+        } else {
+            SQLOptions::new()
+        };
+        let result = self.ctx.sql_with_options(query, options);
+        let df = wait_for_future(py, result).map_err(DataFusionError::from)?;
+        Ok(PyDataFrame::new(df))
+    }
+
     pub fn create_dataframe(
         &mut self,
         partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
diff --git a/src/lib.rs b/src/lib.rs
index 49c325a..a696ebf 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -85,6 +85,7 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
     m.add_class::<context::PyRuntimeConfig>()?;
     m.add_class::<context::PySessionConfig>()?;
     m.add_class::<context::PySessionContext>()?;
+    m.add_class::<context::PySQLOptions>()?;
     m.add_class::<dataframe::PyDataFrame>()?;
     m.add_class::<udf::PyScalarUDF>()?;
     m.add_class::<udaf::PyAggregateUDF>()?;

Reply via email to