This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 217ede8 feat: add register_json (#458)
217ede8 is described below
commit 217ede86961d3ffc356556ad75cbe2b514373048
Author: Daniel Mesejo <[email protected]>
AuthorDate: Thu Aug 17 19:19:22 2023 +0200
feat: add register_json (#458)
---
datafusion/tests/test_sql.py | 53 +++++++++++++++++++++++++++++++++++++++++++-
src/context.rs | 36 ++++++++++++++++++++++++++++++
2 files changed, 88 insertions(+), 1 deletion(-)
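For readers skimming the diff, here is a minimal usage sketch of the new Python API, following the pyo3 signature added in src/context.rs below. The file paths and table names are illustrative, not from this commit.

    from datafusion import SessionContext

    ctx = SessionContext()

    # Register a newline-delimited JSON file under the table name "logs",
    # sampling at most 100 records for schema inference.
    ctx.register_json("logs", "logs.json", schema_infer_max_records=100)

    # Register a gzip-compressed variant; the non-default file extension
    # must be passed along with the compression type.
    ctx.register_json(
        "logs_gz",
        "logs.json.gz",
        file_extension="gz",
        file_compression_type="gzip",
    )

    batches = ctx.sql("SELECT COUNT(*) AS cnt FROM logs").collect()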
diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py
index 608bb19..9d42a1f 100644
--- a/datafusion/tests/test_sql.py
+++ b/datafusion/tests/test_sql.py
@@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import gzip
+import os
import numpy as np
import pyarrow as pa
import pyarrow.dataset as ds
import pytest
-import gzip
from datafusion import udf
@@ -154,6 +155,56 @@ def test_register_dataset(ctx, tmp_path):
assert result.to_pydict() == {"cnt": [100]}
+def test_register_json(ctx, tmp_path):
+ path = os.path.dirname(os.path.abspath(__file__))
+ test_data_path = os.path.join(path, "data_test_context", "data.json")
+ gzip_path = tmp_path / "data.json.gz"
+
+ with open(test_data_path, "rb") as json_file:
+ with gzip.open(gzip_path, "wb") as gzipped_file:
+ gzipped_file.writelines(json_file)
+
+ ctx.register_json("json", test_data_path)
+ ctx.register_json("json1", str(test_data_path))
+ ctx.register_json(
+ "json2",
+ test_data_path,
+ schema_infer_max_records=10,
+ )
+ ctx.register_json(
+ "json_gzip",
+ gzip_path,
+ file_extension="gz",
+ file_compression_type="gzip",
+ )
+
+ alternative_schema = pa.schema(
+ [
+ ("some_int", pa.int16()),
+ ("some_bytes", pa.string()),
+ ("some_floats", pa.float32()),
+ ]
+ )
+ ctx.register_json("json3", path, schema=alternative_schema)
+
+ assert ctx.tables() == {"json", "json1", "json2", "json3", "json_gzip"}
+
+ for table in ["json", "json1", "json2", "json_gzip"]:
+ result = ctx.sql(f'SELECT COUNT("B") AS cnt FROM {table}').collect()
+ result = pa.Table.from_batches(result)
+ assert result.to_pydict() == {"cnt": [3]}
+
+ result = ctx.sql("SELECT * FROM json3").collect()
+ result = pa.Table.from_batches(result)
+ assert result.schema == alternative_schema
+
+ with pytest.raises(
+ ValueError,
+ match="file_compression_type must one of: gzip, bz2, xz, zstd",
+ ):
+ ctx.register_json("json4", gzip_path, file_compression_type="rar")
+
+
def test_execute(ctx, tmp_path):
data = [1, 1, 2, 2, 3, 11, 12]
diff --git a/src/context.rs b/src/context.rs
index 1dca8a7..317ab78 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -509,6 +509,42 @@ impl PySessionContext {
Ok(())
}
+ #[allow(clippy::too_many_arguments)]
+ #[pyo3(signature = (name,
+ path,
+ schema=None,
+ schema_infer_max_records=1000,
+ file_extension=".json",
+ table_partition_cols=vec![],
+ file_compression_type=None))]
+ fn register_json(
+ &mut self,
+ name: &str,
+ path: PathBuf,
+ schema: Option<PyArrowType<Schema>>,
+ schema_infer_max_records: usize,
+ file_extension: &str,
+ table_partition_cols: Vec<(String, String)>,
+ file_compression_type: Option<String>,
+ py: Python,
+ ) -> PyResult<()> {
+ let path = path
+ .to_str()
+ .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
+
+ let mut options = NdJsonReadOptions::default()
+ .file_compression_type(parse_file_compression_type(file_compression_type)?)
+ .table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
+ options.schema_infer_max_records = schema_infer_max_records;
+ options.file_extension = file_extension;
+ options.schema = schema.as_ref().map(|x| &x.0);
+
+ let result = self.ctx.register_json(name, path, options);
+ wait_for_future(py, result).map_err(DataFusionError::from)?;
+
+ Ok(())
+ }
+
// Registers a PyArrow.Dataset
fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) -> PyResult<()> {
let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);