This is an automated email from the ASF dual-hosted git repository.

timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 9b6acec0 Support types other than String and Int for partition columns (#1154)
9b6acec0 is described below

commit 9b6acec075f49d551a2b90608b0c7114de84d718
Author: Michele Gregori <micl...@gmail.com>
AuthorDate: Thu Jun 19 19:58:22 2025 +0200

    Support types other than String and Int for partition columns (#1154)
    
    * impl
    
    * fix test
    
    * format rust
    
    * support for old logic
    
    * also on io
    
    * fix formatting
    
    ---------
    
    Co-authored-by: michele gregori <michelegregor...@gmail.com>
---
 python/datafusion/context.py | 66 +++++++++++++++++++++++++++-----
 python/datafusion/io.py      |  8 ++--
 python/tests/test_sql.py     | 26 +++++++------
 src/context.rs               | 89 ++++++++++++++++++++++++++++----------------
 4 files changed, 132 insertions(+), 57 deletions(-)
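
In practical terms, partition column types can now be given as pyarrow DataType values instead of the "string"/"int" literals. A minimal usage sketch (not part of the commit; the table names and data path are placeholders):

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()

    # New style: pass a pyarrow DataType for each partition column.
    ctx.register_parquet(
        "events",
        "data/events/",
        table_partition_cols=[("year", pa.int32()), ("region", pa.string())],
    )

    # Old style: the "string" and "int" literals still work, but now emit a
    # DeprecationWarning and map to pa.string() / pa.int32() internally.
    ctx.register_parquet(
        "events_legacy",
        "data/events/",
        table_partition_cols=[("region", "string")],
    )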

diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index 4ed465c9..5b99b0d2 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -19,8 +19,11 @@
 
 from __future__ import annotations
 
+import warnings
 from typing import TYPE_CHECKING, Any, Protocol
 
+import pyarrow as pa
+
 try:
     from warnings import deprecated  # Python 3.13+
 except ImportError:
@@ -42,7 +45,6 @@ if TYPE_CHECKING:
 
     import pandas as pd
     import polars as pl
-    import pyarrow as pa
 
     from datafusion.plan import ExecutionPlan, LogicalPlan
 
@@ -539,7 +541,7 @@ class SessionContext:
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_extension: str = ".parquet",
         schema: pa.Schema | None = None,
         file_sort_order: list[list[Expr | SortExpr]] | None = None,
@@ -560,6 +562,7 @@ class SessionContext:
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         file_sort_order_raw = (
             [sort_list_to_raw_sort_list(f) for f in file_sort_order]
             if file_sort_order is not None
@@ -778,7 +781,7 @@ class SessionContext:
         self,
         name: str,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
@@ -806,6 +809,7 @@ class SessionContext:
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         self.ctx.register_parquet(
             name,
             str(path),
@@ -869,7 +873,7 @@ class SessionContext:
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> None:
         """Register a JSON file as a table.
@@ -890,6 +894,7 @@ class SessionContext:
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         self.ctx.register_json(
             name,
             str(path),
@@ -906,7 +911,7 @@ class SessionContext:
         path: str | pathlib.Path,
         schema: pa.Schema | None = None,
         file_extension: str = ".avro",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     ) -> None:
         """Register an Avro file as a table.
 
@@ -922,6 +927,7 @@ class SessionContext:
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         self.ctx.register_avro(
             name, str(path), schema, file_extension, table_partition_cols
         )
@@ -981,7 +987,7 @@ class SessionContext:
         schema: pa.Schema | None = None,
         schema_infer_max_records: int = 1000,
         file_extension: str = ".json",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a line-delimited JSON data source.
@@ -1001,6 +1007,7 @@ class SessionContext:
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         return DataFrame(
             self.ctx.read_json(
                 str(path),
@@ -1020,7 +1027,7 @@ class SessionContext:
         delimiter: str = ",",
         schema_infer_max_records: int = 1000,
         file_extension: str = ".csv",
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_compression_type: str | None = None,
     ) -> DataFrame:
         """Read a CSV data source.
@@ -1045,6 +1052,7 @@ class SessionContext:
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
 
         path = [str(p) for p in path] if isinstance(path, list) else str(path)
 
@@ -1064,7 +1072,7 @@ class SessionContext:
     def read_parquet(
         self,
         path: str | pathlib.Path,
-        table_partition_cols: list[tuple[str, str]] | None = None,
+        table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         parquet_pruning: bool = True,
         file_extension: str = ".parquet",
         skip_metadata: bool = True,
@@ -1093,6 +1101,7 @@ class SessionContext:
         """
         if table_partition_cols is None:
             table_partition_cols = []
+        table_partition_cols = self._convert_table_partition_cols(table_partition_cols)
         file_sort_order = (
             [sort_list_to_raw_sort_list(f) for f in file_sort_order]
             if file_sort_order is not None
@@ -1114,7 +1123,7 @@ class SessionContext:
         self,
         path: str | pathlib.Path,
         schema: pa.Schema | None = None,
-        file_partition_cols: list[tuple[str, str]] | None = None,
+        file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
         file_extension: str = ".avro",
     ) -> DataFrame:
         """Create a :py:class:`DataFrame` for reading Avro data source.
@@ -1130,6 +1139,7 @@ class SessionContext:
         """
         if file_partition_cols is None:
             file_partition_cols = []
+        file_partition_cols = self._convert_table_partition_cols(file_partition_cols)
         return DataFrame(
             self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension)
         )
@@ -1146,3 +1156,41 @@ class SessionContext:
     def execute(self, plan: ExecutionPlan, partitions: int) -> RecordBatchStream:
         """Execute the ``plan`` and return the results."""
         return RecordBatchStream(self.ctx.execute(plan._raw_plan, partitions))
+
+    @staticmethod
+    def _convert_table_partition_cols(
+        table_partition_cols: list[tuple[str, str | pa.DataType]],
+    ) -> list[tuple[str, pa.DataType]]:
+        warn = False
+        converted_table_partition_cols = []
+
+        for col, data_type in table_partition_cols:
+            if isinstance(data_type, str):
+                warn = True
+                if data_type == "string":
+                    converted_data_type = pa.string()
+                elif data_type == "int":
+                    converted_data_type = pa.int32()
+                else:
+                    message = (
+                        f"Unsupported literal data type '{data_type}' for 
partition "
+                        "column. Supported types are 'string' and 'int'"
+                    )
+                    raise ValueError(message)
+            else:
+                converted_data_type = data_type
+
+            converted_table_partition_cols.append((col, converted_data_type))
+
+        if warn:
+            message = (
+                "using literals for table_partition_cols data types is 
deprecated,"
+                "use pyarrow types instead"
+            )
+            warnings.warn(
+                message,
+                category=DeprecationWarning,
+                stacklevel=2,
+            )
+
+        return converted_table_partition_cols
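
A short sketch of how the new helper behaves with mixed inputs (illustrative only; `_convert_table_partition_cols` is a private static method and the column names are placeholders):

    import warnings

    import pyarrow as pa
    from datafusion import SessionContext

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        cols = SessionContext._convert_table_partition_cols(
            [("grp", "string"), ("year", pa.int32())]
        )

    # Literals are mapped to pyarrow types, pyarrow types pass through
    # unchanged, and the "string" literal triggers a DeprecationWarning.
    assert cols == [("grp", pa.string()), ("year", pa.int32())]
    assert any(w.category is DeprecationWarning for w in caught)
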
diff --git a/python/datafusion/io.py b/python/datafusion/io.py
index ef5ebf96..551e20a6 100644
--- a/python/datafusion/io.py
+++ b/python/datafusion/io.py
@@ -34,7 +34,7 @@ if TYPE_CHECKING:
 
 def read_parquet(
     path: str | pathlib.Path,
-    table_partition_cols: list[tuple[str, str]] | None = None,
+    table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     parquet_pruning: bool = True,
     file_extension: str = ".parquet",
     skip_metadata: bool = True,
@@ -83,7 +83,7 @@ def read_json(
     schema: pa.Schema | None = None,
     schema_infer_max_records: int = 1000,
     file_extension: str = ".json",
-    table_partition_cols: list[tuple[str, str]] | None = None,
+    table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     file_compression_type: str | None = None,
 ) -> DataFrame:
     """Read a line-delimited JSON data source.
@@ -124,7 +124,7 @@ def read_csv(
     delimiter: str = ",",
     schema_infer_max_records: int = 1000,
     file_extension: str = ".csv",
-    table_partition_cols: list[tuple[str, str]] | None = None,
+    table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     file_compression_type: str | None = None,
 ) -> DataFrame:
     """Read a CSV data source.
@@ -171,7 +171,7 @@ def read_csv(
 def read_avro(
     path: str | pathlib.Path,
     schema: pa.Schema | None = None,
-    file_partition_cols: list[tuple[str, str]] | None = None,
+    file_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
     file_extension: str = ".avro",
 ) -> DataFrame:
     """Create a :py:class:`DataFrame` for reading Avro data source.
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
index b6348e3a..41cee4ef 100644
--- a/python/tests/test_sql.py
+++ b/python/tests/test_sql.py
@@ -157,8 +157,10 @@ def test_register_parquet(ctx, tmp_path):
     assert result.to_pydict() == {"cnt": [100]}
 
 
-@pytest.mark.parametrize("path_to_str", [True, False])
-def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):
+@pytest.mark.parametrize(
+    ("path_to_str", "legacy_data_type"), [(True, False), (False, False), 
(False, True)]
+)
+def test_register_parquet_partitioned(ctx, tmp_path, path_to_str, legacy_data_type):
     dir_root = tmp_path / "dataset_parquet_partitioned"
     dir_root.mkdir(exist_ok=False)
     (dir_root / "grp=a").mkdir(exist_ok=False)
@@ -177,10 +179,12 @@ def test_register_parquet_partitioned(ctx, tmp_path, path_to_str):
 
     dir_root = str(dir_root) if path_to_str else dir_root
 
+    partition_data_type = "string" if legacy_data_type else pa.string()
+
     ctx.register_parquet(
         "datapp",
         dir_root,
-        table_partition_cols=[("grp", "string")],
+        table_partition_cols=[("grp", partition_data_type)],
         parquet_pruning=True,
         file_extension=".parquet",
     )
@@ -488,9 +492,9 @@ def test_register_listing_table(
 ):
     dir_root = tmp_path / "dataset_parquet_partitioned"
     dir_root.mkdir(exist_ok=False)
-    (dir_root / "grp=a/date_id=20201005").mkdir(exist_ok=False, parents=True)
-    (dir_root / "grp=a/date_id=20211005").mkdir(exist_ok=False, parents=True)
-    (dir_root / "grp=b/date_id=20201005").mkdir(exist_ok=False, parents=True)
+    (dir_root / "grp=a/date=2020-10-05").mkdir(exist_ok=False, parents=True)
+    (dir_root / "grp=a/date=2021-10-05").mkdir(exist_ok=False, parents=True)
+    (dir_root / "grp=b/date=2020-10-05").mkdir(exist_ok=False, parents=True)
 
     table = pa.Table.from_arrays(
         [
@@ -501,13 +505,13 @@ def test_register_listing_table(
         names=["int", "str", "float"],
     )
     pa.parquet.write_table(
-        table.slice(0, 3), dir_root / "grp=a/date_id=20201005/file.parquet"
+        table.slice(0, 3), dir_root / "grp=a/date=2020-10-05/file.parquet"
     )
     pa.parquet.write_table(
-        table.slice(3, 2), dir_root / "grp=a/date_id=20211005/file.parquet"
+        table.slice(3, 2), dir_root / "grp=a/date=2021-10-05/file.parquet"
     )
     pa.parquet.write_table(
-        table.slice(5, 10), dir_root / "grp=b/date_id=20201005/file.parquet"
+        table.slice(5, 10), dir_root / "grp=b/date=2020-10-05/file.parquet"
     )
 
     dir_root = f"file://{dir_root}/" if path_to_str else dir_root
@@ -515,7 +519,7 @@ def test_register_listing_table(
     ctx.register_listing_table(
         "my_table",
         dir_root,
-        table_partition_cols=[("grp", "string"), ("date_id", "int")],
+        table_partition_cols=[("grp", pa.string()), ("date", pa.date64())],
         file_extension=".parquet",
         schema=table.schema if pass_schema else None,
         file_sort_order=file_sort_order,
@@ -531,7 +535,7 @@ def test_register_listing_table(
     assert dict(zip(rd["grp"], rd["count"])) == {"a": 5, "b": 2}
 
     result = ctx.sql(
-        "SELECT grp, COUNT(*) AS count FROM my_table WHERE date_id=20201005 
GROUP BY grp"  # noqa: E501
+        "SELECT grp, COUNT(*) AS count FROM my_table WHERE date='2020-10-05' 
GROUP BY grp"  # noqa: E501
     ).collect()
     result = pa.Table.from_batches(result)
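
The updated listing-table test above exercises a typed date partition column; the same pattern outside the test suite would look roughly like this (path, table name, and query are placeholders; the directory layout follows Hive-style grp=<value>/date=<value>/ partitioning):

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()
    ctx.register_listing_table(
        "my_table",
        "file:///path/to/dataset_parquet_partitioned/",
        table_partition_cols=[("grp", pa.string()), ("date", pa.date64())],
        file_extension=".parquet",
    )

    # Partition values are typed, so the date column can be filtered as a date.
    df = ctx.sql(
        "SELECT grp, COUNT(*) AS count FROM my_table "
        "WHERE date='2020-10-05' GROUP BY grp"
    )
    print(df.collect())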
 
diff --git a/src/context.rs b/src/context.rs
index 55c92a8f..6ce1f12b 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -380,7 +380,7 @@ impl PySessionContext {
         &mut self,
         name: &str,
         path: &str,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         file_extension: &str,
         schema: Option<PyArrowType<Schema>>,
         file_sort_order: Option<Vec<Vec<PySortExpr>>>,
@@ -388,7 +388,12 @@ impl PySessionContext {
     ) -> PyDataFusionResult<()> {
         let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
             .with_file_extension(file_extension)
-            .with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .with_table_partition_cols(
+                table_partition_cols
+                    .into_iter()
+                    .map(|(name, ty)| (name, ty.0))
+                    .collect::<Vec<(String, DataType)>>(),
+            )
             .with_file_sort_order(
                 file_sort_order
                     .unwrap_or_default()
@@ -656,7 +661,7 @@ impl PySessionContext {
         &mut self,
         name: &str,
         path: &str,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         parquet_pruning: bool,
         file_extension: &str,
         skip_metadata: bool,
@@ -665,7 +670,12 @@ impl PySessionContext {
         py: Python,
     ) -> PyDataFusionResult<()> {
         let mut options = ParquetReadOptions::default()
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .table_partition_cols(
+                table_partition_cols
+                    .into_iter()
+                    .map(|(name, ty)| (name, ty.0))
+                    .collect::<Vec<(String, DataType)>>(),
+            )
             .parquet_pruning(parquet_pruning)
             .skip_metadata(skip_metadata);
         options.file_extension = file_extension;
@@ -745,7 +755,7 @@ impl PySessionContext {
         schema: Option<PyArrowType<Schema>>,
         schema_infer_max_records: usize,
         file_extension: &str,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         file_compression_type: Option<String>,
         py: Python,
     ) -> PyDataFusionResult<()> {
@@ -755,7 +765,12 @@ impl PySessionContext {
 
         let mut options = NdJsonReadOptions::default()
             .file_compression_type(parse_file_compression_type(file_compression_type)?)
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+            .table_partition_cols(
+                table_partition_cols
+                    .into_iter()
+                    .map(|(name, ty)| (name, ty.0))
+                    .collect::<Vec<(String, DataType)>>(),
+            );
         options.schema_infer_max_records = schema_infer_max_records;
         options.file_extension = file_extension;
         options.schema = schema.as_ref().map(|x| &x.0);
@@ -778,15 +793,19 @@ impl PySessionContext {
         path: PathBuf,
         schema: Option<PyArrowType<Schema>>,
         file_extension: &str,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         py: Python,
     ) -> PyDataFusionResult<()> {
         let path = path
             .to_str()
             .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
 
-        let mut options = AvroReadOptions::default()
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+        let mut options = AvroReadOptions::default().table_partition_cols(
+            table_partition_cols
+                .into_iter()
+                .map(|(name, ty)| (name, ty.0))
+                .collect::<Vec<(String, DataType)>>(),
+        );
         options.file_extension = file_extension;
         options.schema = schema.as_ref().map(|x| &x.0);
 
@@ -887,7 +906,7 @@ impl PySessionContext {
         schema: Option<PyArrowType<Schema>>,
         schema_infer_max_records: usize,
         file_extension: &str,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         file_compression_type: Option<String>,
         py: Python,
     ) -> PyDataFusionResult<PyDataFrame> {
@@ -895,7 +914,12 @@ impl PySessionContext {
             .to_str()
             .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
         let mut options = NdJsonReadOptions::default()
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .table_partition_cols(
+                table_partition_cols
+                    .into_iter()
+                    .map(|(name, ty)| (name, ty.0))
+                    .collect::<Vec<(String, DataType)>>(),
+            )
             .file_compression_type(parse_file_compression_type(file_compression_type)?);
         options.schema_infer_max_records = schema_infer_max_records;
         options.file_extension = file_extension;
@@ -928,7 +952,7 @@ impl PySessionContext {
         delimiter: &str,
         schema_infer_max_records: usize,
         file_extension: &str,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         file_compression_type: Option<String>,
         py: Python,
     ) -> PyDataFusionResult<PyDataFrame> {
@@ -944,7 +968,12 @@ impl PySessionContext {
             .delimiter(delimiter[0])
             .schema_infer_max_records(schema_infer_max_records)
             .file_extension(file_extension)
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .table_partition_cols(
+                table_partition_cols
+                    .into_iter()
+                    .map(|(name, ty)| (name, ty.0))
+                    .collect::<Vec<(String, DataType)>>(),
+            )
             .file_compression_type(parse_file_compression_type(file_compression_type)?);
         options.schema = schema.as_ref().map(|x| &x.0);
 
@@ -974,7 +1003,7 @@ impl PySessionContext {
     pub fn read_parquet(
         &self,
         path: &str,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         parquet_pruning: bool,
         file_extension: &str,
         skip_metadata: bool,
@@ -983,7 +1012,12 @@ impl PySessionContext {
         py: Python,
     ) -> PyDataFusionResult<PyDataFrame> {
         let mut options = ParquetReadOptions::default()
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
+            .table_partition_cols(
+                table_partition_cols
+                    .into_iter()
+                    .map(|(name, ty)| (name, ty.0))
+                    .collect::<Vec<(String, DataType)>>(),
+            )
             .parquet_pruning(parquet_pruning)
             .skip_metadata(skip_metadata);
         options.file_extension = file_extension;
@@ -1005,12 +1039,16 @@ impl PySessionContext {
         &self,
         path: &str,
         schema: Option<PyArrowType<Schema>>,
-        table_partition_cols: Vec<(String, String)>,
+        table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
         file_extension: &str,
         py: Python,
     ) -> PyDataFusionResult<PyDataFrame> {
-        let mut options = AvroReadOptions::default()
-            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+        let mut options = AvroReadOptions::default().table_partition_cols(
+            table_partition_cols
+                .into_iter()
+                .map(|(name, ty)| (name, ty.0))
+                .collect::<Vec<(String, DataType)>>(),
+        );
         options.file_extension = file_extension;
         let df = if let Some(schema) = schema {
             options.schema = Some(&schema.0);
@@ -1109,21 +1147,6 @@ impl PySessionContext {
     }
 }
 
-pub fn convert_table_partition_cols(
-    table_partition_cols: Vec<(String, String)>,
-) -> PyDataFusionResult<Vec<(String, DataType)>> {
-    table_partition_cols
-        .into_iter()
-        .map(|(name, ty)| match ty.as_str() {
-            "string" => Ok((name, DataType::Utf8)),
-            "int" => Ok((name, DataType::Int32)),
-            _ => Err(crate::errors::PyDataFusionError::Common(format!(
-                "Unsupported data type '{ty}' for partition column. Supported 
types are 'string' and 'int'"
-            ))),
-        })
-        .collect::<Result<Vec<_>, _>>()
-}
-
 pub fn parse_file_compression_type(
     file_compression_type: Option<String>,
 ) -> Result<FileCompressionType, PyErr> {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@datafusion.apache.org
For additional commands, e-mail: commits-h...@datafusion.apache.org
