This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/main by this push:
     new 217ede8  feat: add register_json (#458)
217ede8 is described below

commit 217ede86961d3ffc356556ad75cbe2b514373048
Author: Daniel Mesejo <[email protected]>
AuthorDate: Thu Aug 17 19:19:22 2023 +0200

    feat: add register_json (#458)
---
 datafusion/tests/test_sql.py | 53 +++++++++++++++++++++++++++++++++++++++++++-
 src/context.rs               | 36 ++++++++++++++++++++++++++++++
 2 files changed, 88 insertions(+), 1 deletion(-)
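
For reference, the Python-side signature the new binding exposes (a sketch
inferred from the #[pyo3(signature = ...)] attribute in src/context.rs below;
the stub itself is hypothetical, not generated from the package):

    from typing import List, Optional, Tuple

    import pyarrow

    # Hypothetical stub; defaults mirror the pyo3 signature in the diff below.
    # `path` also accepts os.PathLike (it maps to PathBuf on the Rust side).
    def register_json(
        name: str,
        path: str,
        schema: Optional[pyarrow.Schema] = None,
        schema_infer_max_records: int = 1000,
        file_extension: str = ".json",
        table_partition_cols: List[Tuple[str, str]] = [],
        file_compression_type: Optional[str] = None,
    ) -> None:
        ...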

diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py
index 608bb19..9d42a1f 100644
--- a/datafusion/tests/test_sql.py
+++ b/datafusion/tests/test_sql.py
@@ -14,12 +14,13 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import gzip
+import os
 
 import numpy as np
 import pyarrow as pa
 import pyarrow.dataset as ds
 import pytest
-import gzip
 
 from datafusion import udf
 
@@ -154,6 +155,56 @@ def test_register_dataset(ctx, tmp_path):
     assert result.to_pydict() == {"cnt": [100]}
 
 
+def test_register_json(ctx, tmp_path):
+    path = os.path.dirname(os.path.abspath(__file__))
+    test_data_path = os.path.join(path, "data_test_context", "data.json")
+    gzip_path = tmp_path / "data.json.gz"
+
+    with open(test_data_path, "rb") as json_file:
+        with gzip.open(gzip_path, "wb") as gzipped_file:
+            gzipped_file.writelines(json_file)
+
+    ctx.register_json("json", test_data_path)
+    ctx.register_json("json1", str(test_data_path))
+    ctx.register_json(
+        "json2",
+        test_data_path,
+        schema_infer_max_records=10,
+    )
+    ctx.register_json(
+        "json_gzip",
+        gzip_path,
+        file_extension="gz",
+        file_compression_type="gzip",
+    )
+
+    alternative_schema = pa.schema(
+        [
+            ("some_int", pa.int16()),
+            ("some_bytes", pa.string()),
+            ("some_floats", pa.float32()),
+        ]
+    )
+    ctx.register_json("json3", path, schema=alternative_schema)
+
+    assert ctx.tables() == {"json", "json1", "json2", "json3", "json_gzip"}
+
+    for table in ["json", "json1", "json2", "json_gzip"]:
+        result = ctx.sql(f'SELECT COUNT("B") AS cnt FROM {table}').collect()
+        result = pa.Table.from_batches(result)
+        assert result.to_pydict() == {"cnt": [3]}
+
+    result = ctx.sql("SELECT * FROM json3").collect()
+    result = pa.Table.from_batches(result)
+    assert result.schema == alternative_schema
+
+    with pytest.raises(
+        ValueError,
+        match="file_compression_type must one of: gzip, bz2, xz, zstd",
+    ):
+        ctx.register_json("json4", gzip_path, file_compression_type="rar")
+
+
 def test_execute(ctx, tmp_path):
     data = [1, 1, 2, 2, 3, 11, 12]
 
diff --git a/src/context.rs b/src/context.rs
index 1dca8a7..317ab78 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -509,6 +509,42 @@ impl PySessionContext {
         Ok(())
     }
 
+    #[allow(clippy::too_many_arguments)]
+    #[pyo3(signature = (name,
+                        path,
+                        schema=None,
+                        schema_infer_max_records=1000,
+                        file_extension=".json",
+                        table_partition_cols=vec![],
+                        file_compression_type=None))]
+    fn register_json(
+        &mut self,
+        name: &str,
+        path: PathBuf,
+        schema: Option<PyArrowType<Schema>>,
+        schema_infer_max_records: usize,
+        file_extension: &str,
+        table_partition_cols: Vec<(String, String)>,
+        file_compression_type: Option<String>,
+        py: Python,
+    ) -> PyResult<()> {
+        let path = path
+            .to_str()
+            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
+
+        let mut options = NdJsonReadOptions::default()
+            .file_compression_type(parse_file_compression_type(file_compression_type)?)
+            .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+        options.schema_infer_max_records = schema_infer_max_records;
+        options.file_extension = file_extension;
+        options.schema = schema.as_ref().map(|x| &x.0);
+
+        let result = self.ctx.register_json(name, path, options);
+        wait_for_future(py, result).map_err(DataFusionError::from)?;
+
+        Ok(())
+    }
+
     // Registers a PyArrow.Dataset
     fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
         let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
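
A minimal end-to-end usage sketch of the new API, based on the test above
(file names are hypothetical):

    from datafusion import SessionContext

    ctx = SessionContext()

    # Plain newline-delimited JSON.
    ctx.register_json("records", "records.json")

    # Gzip-compressed NDJSON: pass both the extension and the compression
    # type, as exercised by test_register_json above.
    ctx.register_json(
        "records_gz",
        "records.json.gz",
        file_extension="gz",
        file_compression_type="gzip",
    )

    batches = ctx.sql("SELECT COUNT(*) AS cnt FROM records_gz").collect()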
