This is an automated email from the ASF dual-hosted git repository.
timsaucer pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-python.git
The following commit(s) were added to refs/heads/main by this push:
new 8b51390 feat: remove DataFusion pyarrow feat (#1000)
8b51390 is described below
commit 8b513906315a0749b9f5cd6f34bf259ab4dd1add
Author: Tim Saucer <[email protected]>
AuthorDate: Sat Feb 1 08:29:48 2025 -0500
feat: remove DataFusion pyarrow feat (#1000)
* Add developer instructions to speed up build processes
* Remove the pyarrow dep from datafusion. Add a PyScalarValue wrapper and
rename DataFusionError to PyDataFusionError to be less confusing
* Remove unnecessary cloning of scalar values when going from Rust to
Python. Also remove the Rust unit tests copied over from the upstream repo that
were failing due to #941 in pyo3
* Change return types to PyDataFusionError to simplify code
* Update exception handling to fix build errors in recent Rust toolchains
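
Both the wrapper and the error rename use the standard Rust newtype pattern.
A minimal sketch of the idea (simplified stand-in types, not the exact code
from this commit; the real definitions are in src/common/data_type.rs and
src/errors.rs below):

    // A local tuple struct wraps the upstream type so this crate can
    // implement pyo3 conversion traits on it without datafusion itself
    // needing a "pyarrow" feature.
    struct ScalarValue; // stand-in for datafusion::common::ScalarValue

    struct PyScalarValue(pub ScalarValue);

    impl From<ScalarValue> for PyScalarValue {
        fn from(value: ScalarValue) -> Self {
            Self(value)
        }
    }

    impl From<PyScalarValue> for ScalarValue {
        fn from(value: PyScalarValue) -> Self {
            value.0
        }
    }

    fn main() {
        // Both conversions are moves, not clones, which is what removes
        // the unnecessary scalar-value copying mentioned above.
        let wrapped: PyScalarValue = ScalarValue.into();
        let _unwrapped: ScalarValue = wrapped.into();
    }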
---
Cargo.lock | 145 +++++++++++++++----------
Cargo.toml | 2 +-
docs/source/contributor-guide/introduction.rst | 53 +++++++++
python/tests/test_indexing.py | 3 +-
src/catalog.rs | 8 +-
src/common/data_type.rs | 14 +++
src/config.rs | 11 +-
src/context.rs | 136 +++++++++++------------
src/dataframe.rs | 119 ++++++++++----------
src/dataset_exec.rs | 6 +-
src/errors.rs | 42 +++----
src/expr.rs | 38 +++----
src/expr/conditional_expr.rs | 6 +-
src/expr/literal.rs | 4 +-
src/expr/window.rs | 13 ++-
src/functions.rs | 55 ++++++----
src/lib.rs | 1 +
src/physical_plan.rs | 13 ++-
src/pyarrow_filter_expression.rs | 24 ++--
src/pyarrow_util.rs | 61 +++++++++++
src/record_batch.rs | 3 +-
src/sql/exceptions.rs | 16 +--
src/sql/logical.rs | 14 ++-
src/substrait.rs | 54 +++++----
src/udaf.rs | 21 +++-
src/udwf.rs | 4 +-
src/utils.rs | 6 +-
27 files changed, 524 insertions(+), 348 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 5a74a48..c6590fd 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -79,7 +79,7 @@ checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
dependencies = [
"cfg-if",
"const-random",
- "getrandom",
+ "getrandom 0.2.15",
"once_cell",
"version_check",
"zerocopy",
@@ -449,9 +449,9 @@ dependencies = [
[[package]]
name = "async-trait"
-version = "0.1.85"
+version = "0.1.86"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056"
+checksum = "644dd749086bf3771a2fbc5f256fdb982d53f011c7d5d560304eafeecebce79d"
dependencies = [
"proc-macro2",
"quote",
@@ -576,9 +576,9 @@ dependencies = [
[[package]]
name = "brotli-decompressor"
-version = "4.0.1"
+version = "4.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9a45bd2e4095a8b518033b128020dd4a55aab1c0a381ba4404a472630f4bc362"
+checksum = "74fa05ad7d803d413eb8380983b092cbbaf9a85f151b871360e7b00cd7060b37"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
@@ -586,9 +586,9 @@ dependencies = [
[[package]]
name = "bumpalo"
-version = "3.16.0"
+version = "3.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c"
+checksum = "1628fb46dfa0b37568d12e5edd512553eccf6a22a78e8bde00bb4aed84d5bdbf"
[[package]]
name = "byteorder"
@@ -635,9 +635,9 @@ dependencies = [
[[package]]
name = "cc"
-version = "1.2.10"
+version = "1.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "13208fcbb66eaeffe09b99fffbe1af420f00a7b35aa99ad683dfc1aa76145229"
+checksum = "e4730490333d58093109dc02c23174c3f4d490998c3fed3cc8e82d57afedb9cf"
dependencies = [
"jobserver",
"libc",
@@ -692,9 +692,9 @@ dependencies = [
[[package]]
name = "cmake"
-version = "0.1.52"
+version = "0.1.53"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c682c223677e0e5b6b7f63a64b9351844c3f1b1678a68b7ee617e30fb082620e"
+checksum = "e24a03c8b52922d68a1589ad61032f2c1aa5a8158d2aa0d93c6e9534944bbad6"
dependencies = [
"cc",
]
@@ -725,7 +725,7 @@ version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e"
dependencies = [
- "getrandom",
+ "getrandom 0.2.15",
"once_cell",
"tiny-keccak",
]
@@ -784,9 +784,9 @@ checksum = "69f3b219d28b6e3b4ac87bc1fc522e0803ab22e055da177bff0068c4150c61a6"
[[package]]
name = "cpufeatures"
-version = "0.2.16"
+version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "16b80225097f2e5ae4e7179dd2266824648f3e2f49d9134d584b76389d31c4c3"
+checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280"
dependencies = [
"libc",
]
@@ -817,9 +817,9 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28"
[[package]]
name = "crunchy"
-version = "0.2.2"
+version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
+checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929"
[[package]]
name = "crypto-common"
@@ -961,7 +961,6 @@ dependencies = [
"object_store",
"parquet",
"paste",
- "pyo3",
"recursive",
"sqlparser",
"tokio",
@@ -1411,9 +1410,9 @@ dependencies = [
[[package]]
name = "dyn-clone"
-version = "1.0.17"
+version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125"
+checksum = "feeef44e73baff3a26d371801df019877a9866a8c493d315ab00177843314f35"
[[package]]
name = "either"
@@ -1607,10 +1606,22 @@ dependencies = [
"cfg-if",
"js-sys",
"libc",
- "wasi",
+ "wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
+[[package]]
+name = "getrandom"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "43a49c392881ce6d5c3b8cb70f98717b7c07aabbdff06687b9030dbfbe2725f8"
+dependencies = [
+ "cfg-if",
+ "libc",
+ "wasi 0.13.3+wasi-0.2.2",
+ "windows-targets",
+]
+
[[package]]
name = "gimli"
version = "0.31.1"
@@ -1722,9 +1733,9 @@ dependencies = [
[[package]]
name = "httparse"
-version = "1.9.5"
+version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946"
+checksum = "f2d708df4e7140240a16cd6ab0ab65c972d7433ab77819ea693fde9c43811e2a"
[[package]]
name = "humantime"
@@ -1734,9 +1745,9 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "hyper"
-version = "1.5.2"
+version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "256fb8d4bd6413123cc9d91832d78325c48ff41677595be797d90f42969beae0"
+checksum = "cc2b571658e38e0c01b1fdca3bbbe93c00d3d71693ff2770043f8c29bc7d6f80"
dependencies = [
"bytes",
"futures-channel",
@@ -1953,9 +1964,9 @@ dependencies = [
[[package]]
name = "indexmap"
-version = "2.7.0"
+version = "2.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f"
+checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652"
dependencies = [
"equivalent",
"hashbrown 0.15.2",
@@ -1975,9 +1986,9 @@ checksum = "8bb03732005da905c88227371639bf1ad885cc712789c011c31c5fb3ab3ccf02"
[[package]]
name = "ipnet"
-version = "2.10.1"
+version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708"
+checksum = "469fb0b9cefa57e3ef31275ee7cacb78f2fdca44e4765491884a2b119d4eb130"
[[package]]
name = "itertools"
@@ -2243,7 +2254,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2886843bf800fba2e3377cff24abf6379b4c4d5c6681eaf9ea5b0d15090450bd"
dependencies = [
"libc",
- "wasi",
+ "wasi 0.11.0+wasi-snapshot-preview1",
"windows-sys 0.52.0",
]
@@ -2377,9 +2388,9 @@ checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775"
[[package]]
name = "openssl-probe"
-version = "0.1.5"
+version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
+checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e"
[[package]]
name = "ordered-float"
@@ -2661,9 +2672,9 @@ dependencies = [
[[package]]
name = "protobuf-src"
-version = "2.1.0+27.1"
+version = "2.1.1+27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a7edafa3bcc668fa93efafcbdf58d7821bbda0f4b458ac7fae3d57ec0fec8167"
+checksum = "6217c3504da19b85a3a4b2e9a5183d635822d83507ba0986624b5c05b83bfc40"
dependencies = [
"cmake",
]
@@ -2794,7 +2805,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2fe5ef3495d7d2e377ff17b1a8ce2ee2ec2a18cde8b6ad6619d65d0701c135d"
dependencies = [
"bytes",
- "getrandom",
+ "getrandom 0.2.15",
"rand",
"ring",
"rustc-hash",
@@ -2857,7 +2868,7 @@ version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
dependencies = [
- "getrandom",
+ "getrandom 0.2.15",
]
[[package]]
@@ -2926,9 +2937,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "regress"
-version = "0.10.2"
+version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4f56e622c2378013c6c61e2bd776604c46dc1087b2dc5293275a0c20a44f0771"
+checksum = "78ef7fa9ed0256d64a688a3747d0fef7a88851c18a5e1d57f115f38ec2e09366"
dependencies = [
"hashbrown 0.15.2",
"memchr",
@@ -2997,7 +3008,7 @@ checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
dependencies = [
"cc",
"cfg-if",
- "getrandom",
+ "getrandom 0.2.15",
"libc",
"spin",
"untrusted",
@@ -3033,9 +3044,9 @@ dependencies = [
[[package]]
name = "rustix"
-version = "0.38.43"
+version = "0.38.44"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6"
+checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154"
dependencies = [
"bitflags 2.8.0",
"errno",
@@ -3046,9 +3057,9 @@ dependencies = [
[[package]]
name = "rustls"
-version = "0.23.21"
+version = "0.23.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8f287924602bf649d949c63dc8ac8b235fa5387d394020705b80c4eb597ce5b8"
+checksum = "9fb9263ab4eb695e42321db096e3b8fbd715a59b154d5c88d82db2175b681ba7"
dependencies = [
"once_cell",
"ring",
@@ -3081,9 +3092,9 @@ dependencies = [
[[package]]
name = "rustls-pki-types"
-version = "1.10.1"
+version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d2bf47e6ff922db3825eb750c4e2ff784c6ff8fb9e13046ef6a1d1c5401b0b37"
+checksum = "917ce264624a4b4db1c364dcc35bfca9ded014d0a958cd47ad3e960e988ea51c"
dependencies = [
"web-time",
]
@@ -3107,9 +3118,9 @@ checksum = "f7c45b9784283f1b2e7fb61b42047c2fd678ef0960d4f6f1eba131594cc369d4"
[[package]]
name = "ryu"
-version = "1.0.18"
+version = "1.0.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
+checksum = "6ea1a2d0a644769cc99faa24c3ad26b379b786fe7c36fd3c546254801650e6dd"
[[package]]
name = "same-file"
@@ -3184,9 +3195,9 @@ dependencies = [
[[package]]
name = "semver"
-version = "1.0.24"
+version = "1.0.25"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba"
+checksum = "f79dfe2d285b0488816f30e700a7438c5a73d816b5b7d3ac72fbc48b0d185e03"
dependencies = [
"serde",
]
@@ -3239,9 +3250,9 @@ dependencies = [
[[package]]
name = "serde_json"
-version = "1.0.136"
+version = "1.0.138"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "336a0c23cf42a38d9eaa7cd22c7040d04e1228a19a933890805ffd00a16437d2"
+checksum = "d434192e7da787e94a6ea7e9670b26a036d0ca41e0b7efb2676dd32bae872949"
dependencies = [
"itoa",
"memchr",
@@ -3514,13 +3525,13 @@ checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1"
[[package]]
name = "tempfile"
-version = "3.15.0"
+version = "3.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704"
+checksum = "38c246215d7d24f48ae091a2902398798e05d978b24315d6efbc00ede9a8bb91"
dependencies = [
"cfg-if",
"fastrand",
- "getrandom",
+ "getrandom 0.3.1",
"once_cell",
"rustix",
"windows-sys 0.59.0",
@@ -3831,9 +3842,9 @@ dependencies = [
[[package]]
name = "unicode-ident"
-version = "1.0.14"
+version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83"
+checksum = "a210d160f08b701c8721ba1c726c11662f877ea6b7094007e1ca9a1041945034"
[[package]]
name = "unicode-segmentation"
@@ -3890,11 +3901,11 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "uuid"
-version = "1.12.0"
+version = "1.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "744018581f9a3454a9e15beb8a33b017183f1e7c0cd170232a2d1453b23a51c4"
+checksum = "b3758f5e68192bb96cc8f9b7e2c2cfdabb435499a28499a42f8f984092adad4b"
dependencies = [
- "getrandom",
+ "getrandom 0.2.15",
"serde",
]
@@ -3929,6 +3940,15 @@ version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
+[[package]]
+name = "wasi"
+version = "0.13.3+wasi-0.2.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26816d2e1a4a36a2940b96c5296ce403917633dff8f3440e9b236ed6f6bacad2"
+dependencies = [
+ "wit-bindgen-rt",
+]
+
[[package]]
name = "wasm-bindgen"
version = "0.2.100"
@@ -4185,6 +4205,15 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
+[[package]]
+name = "wit-bindgen-rt"
+version = "0.33.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3268f3d866458b787f390cf61f4bbb563b922d091359f9608842999eaee3943c"
+dependencies = [
+ "bitflags 2.8.0",
+]
+
[[package]]
name = "write16"
version = "1.0.0"
diff --git a/Cargo.toml b/Cargo.toml
index 10cffcc..003ba36 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -38,7 +38,7 @@ tokio = { version = "1.42", features = ["macros", "rt", "rt-multi-thread", "sync
pyo3 = { version = "0.22", features = ["extension-module", "abi3", "abi3-py38"] }
pyo3-async-runtimes = { version = "0.22", features = ["tokio-runtime"]}
arrow = { version = "53", features = ["pyarrow"] }
-datafusion = { version = "44.0.0", features = ["pyarrow", "avro", "unicode_expressions"] }
+datafusion = { version = "44.0.0", features = ["avro", "unicode_expressions"] }
datafusion-substrait = { version = "44.0.0", optional = true }
datafusion-proto = { version = "44.0.0" }
datafusion-ffi = { version = "44.0.0" }
diff --git a/docs/source/contributor-guide/introduction.rst b/docs/source/contributor-guide/introduction.rst
index fb98cfd..25f2c21 100644
--- a/docs/source/contributor-guide/introduction.rst
+++ b/docs/source/contributor-guide/introduction.rst
@@ -95,3 +95,56 @@ To update dependencies, run
.. code-block:: shell
uv sync --dev --no-install-package datafusion
+
+Improving Build Speed
+---------------------
+
+The `pyo3 <https://github.com/PyO3/pyo3>`_ dependency of this project contains a ``build.rs`` file which
+can cause it to rebuild frequently. You can prevent this from happening by defining a ``PYO3_CONFIG_FILE``
+environment variable that points to a file with your build configuration. Whenever your build configuration
+changes, such as during some major version updates, you will need to regenerate this file. This variable
+should point to a fully resolved path on your build machine.
+
+To generate this file, use the following command:
+
+.. code-block:: shell
+
+ PYO3_PRINT_CONFIG=1 cargo build
+
+This will generate some output that looks like the following. You will want to copy these contents into
+a file. If you place this file in your project directory with the filename ``.pyo3_build_config`` it will
+be ignored by ``git``.
+
+.. code-block::
+
+ implementation=CPython
+ version=3.8
+ shared=true
+ abi3=true
+ lib_name=python3.12
+ lib_dir=/opt/homebrew/opt/[email protected]/Frameworks/Python.framework/Versions/3.12/lib
+ executable=/Users/myusername/src/datafusion-python/.venv/bin/python
+ pointer_width=64
+ build_flags=
+ suppress_build_script_link_lines=false
+
+Add the environment variable to your system.
+
+.. code-block:: shell
+
+ export PYO3_CONFIG_FILE="/Users/myusername/src/datafusion-python/.pyo3_build_config"
+
+If you are on a Mac and you use VS Code for your IDE, you will want to add these variables
+to your settings. You can find the appropriate Rust flags by looking in the
+``.cargo/config.toml`` file.
+
+.. code-block::
+
+ "rust-analyzer.cargo.extraEnv": {
+ "RUSTFLAGS": "-C link-arg=-undefined -C link-arg=dynamic_lookup",
+ "PYO3_CONFIG_FILE":
"/Users/myusername/src/datafusion-python/.pyo3_build_config"
+ },
+ "rust-analyzer.runnables.extraEnv": {
+ "RUSTFLAGS": "-C link-arg=-undefined -C link-arg=dynamic_lookup",
+ "PYO3_CONFIG_FILE":
"/Users/myusername/src/personal/datafusion-python/.pyo3_build_config"
+ }
diff --git a/python/tests/test_indexing.py b/python/tests/test_indexing.py
index 5b0d086..327decd 100644
--- a/python/tests/test_indexing.py
+++ b/python/tests/test_indexing.py
@@ -43,7 +43,8 @@ def test_err(df):
with pytest.raises(Exception) as e_info:
df["c"]
- assert "Schema error: No field named c." in e_info.value.args[0]
+ for e in ["SchemaError", "FieldNotFound", 'name: "c"']:
+ assert e in e_info.value.args[0]
with pytest.raises(Exception) as e_info:
df[1]
diff --git a/src/catalog.rs b/src/catalog.rs
index 1ce66a4..1e189a5 100644
--- a/src/catalog.rs
+++ b/src/catalog.rs
@@ -21,7 +21,7 @@ use std::sync::Arc;
use pyo3::exceptions::PyKeyError;
use pyo3::prelude::*;
-use crate::errors::DataFusionError;
+use crate::errors::{PyDataFusionError, PyDataFusionResult};
use crate::utils::wait_for_future;
use datafusion::{
arrow::pyarrow::ToPyArrow,
@@ -96,11 +96,13 @@ impl PyDatabase {
self.database.table_names().into_iter().collect()
}
- fn table(&self, name: &str, py: Python) -> PyResult<PyTable> {
+ fn table(&self, name: &str, py: Python) -> PyDataFusionResult<PyTable> {
if let Some(table) = wait_for_future(py, self.database.table(name))? {
Ok(PyTable::new(table))
} else {
- Err(DataFusionError::Common(format!("Table not found:
{name}")).into())
+ Err(PyDataFusionError::Common(format!(
+ "Table not found: {name}"
+ )))
}
}
diff --git a/src/common/data_type.rs b/src/common/data_type.rs
index 7f9c75b..f5f8a6b 100644
--- a/src/common/data_type.rs
+++ b/src/common/data_type.rs
@@ -23,6 +23,20 @@ use pyo3::{exceptions::PyValueError, prelude::*};
use crate::errors::py_datafusion_err;
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd)]
+pub struct PyScalarValue(pub ScalarValue);
+
+impl From<ScalarValue> for PyScalarValue {
+ fn from(value: ScalarValue) -> Self {
+ Self(value)
+ }
+}
+impl From<PyScalarValue> for ScalarValue {
+ fn from(value: PyScalarValue) -> Self {
+ value.0
+ }
+}
+
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[pyclass(eq, eq_int, name = "RexType", module = "datafusion.common")]
pub enum RexType {
diff --git a/src/config.rs b/src/config.rs
index 3f2a055..cc725b9 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -21,6 +21,8 @@ use pyo3::types::*;
use datafusion::common::ScalarValue;
use datafusion::config::ConfigOptions;
+use crate::errors::PyDataFusionResult;
+
#[pyclass(name = "Config", module = "datafusion", subclass)]
#[derive(Clone)]
pub(crate) struct PyConfig {
@@ -38,7 +40,7 @@ impl PyConfig {
/// Get configurations from environment variables
#[staticmethod]
- pub fn from_env() -> PyResult<Self> {
+ pub fn from_env() -> PyDataFusionResult<Self> {
Ok(Self {
config: ConfigOptions::from_env()?,
})
@@ -56,11 +58,10 @@ impl PyConfig {
}
/// Set a configuration option
- pub fn set(&mut self, key: &str, value: PyObject, py: Python) -> PyResult<()> {
+ pub fn set(&mut self, key: &str, value: PyObject, py: Python) -> PyDataFusionResult<()> {
let scalar_value = py_obj_to_scalar_value(py, value);
- self.config
- .set(key, scalar_value.to_string().as_str())
- .map_err(|e| e.into())
+ self.config.set(key, scalar_value.to_string().as_str())?;
+ Ok(())
}
/// Get all configuration options
diff --git a/src/context.rs b/src/context.rs
index bab7fd4..f53b155 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -28,16 +28,17 @@ use object_store::ObjectStore;
use url::Url;
use uuid::Uuid;
-use pyo3::exceptions::{PyKeyError, PyNotImplementedError, PyTypeError, PyValueError};
+use pyo3::exceptions::{PyKeyError, PyValueError};
use pyo3::prelude::*;
use crate::catalog::{PyCatalog, PyTable};
use crate::dataframe::PyDataFrame;
use crate::dataset::Dataset;
-use crate::errors::{py_datafusion_err, DataFusionError};
+use crate::errors::{py_datafusion_err, PyDataFusionResult};
use crate::expr::sort_expr::PySortExpr;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
+use crate::sql::exceptions::py_value_err;
use crate::sql::logical::PyLogicalPlan;
use crate::store::StorageContexts;
use crate::udaf::PyAggregateUDF;
@@ -277,7 +278,7 @@ impl PySessionContext {
pub fn new(
config: Option<PySessionConfig>,
runtime: Option<PyRuntimeEnvBuilder>,
- ) -> PyResult<Self> {
+ ) -> PyDataFusionResult<Self> {
let config = if let Some(c) = config {
c.config
} else {
@@ -348,7 +349,7 @@ impl PySessionContext {
schema: Option<PyArrowType<Schema>>,
file_sort_order: Option<Vec<Vec<PySortExpr>>>,
py: Python,
- ) -> PyResult<()> {
+ ) -> PyDataFusionResult<()> {
let options = ListingOptions::new(Arc::new(ParquetFormat::new()))
.with_file_extension(file_extension)
.with_table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
@@ -365,7 +366,7 @@ impl PySessionContext {
None => {
let state = self.ctx.state();
let schema = options.infer_schema(&state, &table_path);
- wait_for_future(py, schema).map_err(DataFusionError::from)?
+ wait_for_future(py, schema)?
}
};
let config = ListingTableConfig::new(table_path)
@@ -382,9 +383,9 @@ impl PySessionContext {
}
/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
- pub fn sql(&mut self, query: &str, py: Python) -> PyResult<PyDataFrame> {
+ pub fn sql(&mut self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
let result = self.ctx.sql(query);
- let df = wait_for_future(py, result).map_err(DataFusionError::from)?;
+ let df = wait_for_future(py, result)?;
Ok(PyDataFrame::new(df))
}
@@ -394,14 +395,14 @@ impl PySessionContext {
query: &str,
options: Option<PySQLOptions>,
py: Python,
- ) -> PyResult<PyDataFrame> {
+ ) -> PyDataFusionResult<PyDataFrame> {
let options = if let Some(options) = options {
options.options
} else {
SQLOptions::new()
};
let result = self.ctx.sql_with_options(query, options);
- let df = wait_for_future(py, result).map_err(DataFusionError::from)?;
+ let df = wait_for_future(py, result)?;
Ok(PyDataFrame::new(df))
}
@@ -412,14 +413,14 @@ impl PySessionContext {
name: Option<&str>,
schema: Option<PyArrowType<Schema>>,
py: Python,
- ) -> PyResult<PyDataFrame> {
+ ) -> PyDataFusionResult<PyDataFrame> {
let schema = if let Some(schema) = schema {
SchemaRef::from(schema.0)
} else {
partitions.0[0][0].schema()
};
- let table = MemTable::try_new(schema, partitions.0).map_err(DataFusionError::from)?;
+ let table = MemTable::try_new(schema, partitions.0)?;
// generate a random (unique) name for this table if none is provided
// table name cannot start with numeric digit
@@ -433,11 +434,9 @@ impl PySessionContext {
}
};
- self.ctx
- .register_table(&*table_name, Arc::new(table))
- .map_err(DataFusionError::from)?;
+ self.ctx.register_table(&*table_name, Arc::new(table))?;
- let table = wait_for_future(py, self._table(&table_name)).map_err(DataFusionError::from)?;
+ let table = wait_for_future(py, self._table(&table_name))?;
let df = PyDataFrame::new(table);
Ok(df)
@@ -495,15 +494,14 @@ impl PySessionContext {
data: Bound<'_, PyAny>,
name: Option<&str>,
py: Python,
- ) -> PyResult<PyDataFrame> {
+ ) -> PyDataFusionResult<PyDataFrame> {
let (schema, batches) =
if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
// Works for any object that implements __arrow_c_stream__ in pycapsule.
let schema = stream_reader.schema().as_ref().to_owned();
let batches = stream_reader
- .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()
- .map_err(DataFusionError::from)?;
+ .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?;
(schema, batches)
} else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
@@ -512,8 +510,8 @@ impl PySessionContext {
(array.schema().as_ref().to_owned(), vec![array])
} else {
- return Err(PyTypeError::new_err(
- "Expected either a Arrow Array or Arrow Stream in from_arrow().",
+ return Err(crate::errors::PyDataFusionError::Common(
+ "Expected either a Arrow Array or Arrow Stream in from_arrow().".to_string(),
));
};
@@ -559,17 +557,13 @@ impl PySessionContext {
Ok(df)
}
- pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyResult<()> {
- self.ctx
- .register_table(name, table.table())
- .map_err(DataFusionError::from)?;
+ pub fn register_table(&mut self, name: &str, table: &PyTable) -> PyDataFusionResult<()> {
+ self.ctx.register_table(name, table.table())?;
Ok(())
}
- pub fn deregister_table(&mut self, name: &str) -> PyResult<()> {
- self.ctx
- .deregister_table(name)
- .map_err(DataFusionError::from)?;
+ pub fn deregister_table(&mut self, name: &str) -> PyDataFusionResult<()> {
+ self.ctx.deregister_table(name)?;
Ok(())
}
@@ -578,10 +572,10 @@ impl PySessionContext {
&mut self,
name: &str,
provider: Bound<'_, PyAny>,
- ) -> PyResult<()> {
+ ) -> PyDataFusionResult<()> {
if provider.hasattr("__datafusion_table_provider__")? {
let capsule = provider.getattr("__datafusion_table_provider__")?.call0()?;
- let capsule = capsule.downcast::<PyCapsule>()?;
+ let capsule = capsule.downcast::<PyCapsule>().map_err(py_datafusion_err)?;
validate_pycapsule(capsule, "datafusion_table_provider")?;
let provider = unsafe { capsule.reference::<FFI_TableProvider>() };
@@ -591,8 +585,9 @@ impl PySessionContext {
Ok(())
} else {
- Err(PyNotImplementedError::new_err(
- "__datafusion_table_provider__ does not exist on Table
Provider object.",
+ Err(crate::errors::PyDataFusionError::Common(
+ "__datafusion_table_provider__ does not exist on Table
Provider object."
+ .to_string(),
))
}
}
@@ -601,12 +596,10 @@ impl PySessionContext {
&mut self,
name: &str,
partitions: PyArrowType<Vec<Vec<RecordBatch>>>,
- ) -> PyResult<()> {
+ ) -> PyDataFusionResult<()> {
let schema = partitions.0[0][0].schema();
let table = MemTable::try_new(schema, partitions.0)?;
- self.ctx
- .register_table(name, Arc::new(table))
- .map_err(DataFusionError::from)?;
+ self.ctx.register_table(name, Arc::new(table))?;
Ok(())
}
@@ -628,7 +621,7 @@ impl PySessionContext {
schema: Option<PyArrowType<Schema>>,
file_sort_order: Option<Vec<Vec<PySortExpr>>>,
py: Python,
- ) -> PyResult<()> {
+ ) -> PyDataFusionResult<()> {
let mut options = ParquetReadOptions::default()
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.parquet_pruning(parquet_pruning)
@@ -642,7 +635,7 @@ impl PySessionContext {
.collect();
let result = self.ctx.register_parquet(name, path, options);
- wait_for_future(py, result).map_err(DataFusionError::from)?;
+ wait_for_future(py, result)?;
Ok(())
}
@@ -666,12 +659,12 @@ impl PySessionContext {
file_extension: &str,
file_compression_type: Option<String>,
py: Python,
- ) -> PyResult<()> {
+ ) -> PyDataFusionResult<()> {
let delimiter = delimiter.as_bytes();
if delimiter.len() != 1 {
- return Err(PyValueError::new_err(
+ return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
"Delimiter must be a single character",
- ));
+ )));
}
let mut options = CsvReadOptions::new()
@@ -685,11 +678,11 @@ impl PySessionContext {
if path.is_instance_of::<PyList>() {
let paths = path.extract::<Vec<String>>()?;
let result = self.register_csv_from_multiple_paths(name, paths, options);
- wait_for_future(py, result).map_err(DataFusionError::from)?;
+ wait_for_future(py, result)?;
} else {
let path = path.extract::<String>()?;
let result = self.ctx.register_csv(name, &path, options);
- wait_for_future(py, result).map_err(DataFusionError::from)?;
+ wait_for_future(py, result)?;
}
Ok(())
@@ -713,7 +706,7 @@ impl PySessionContext {
table_partition_cols: Vec<(String, String)>,
file_compression_type: Option<String>,
py: Python,
- ) -> PyResult<()> {
+ ) -> PyDataFusionResult<()> {
let path = path
.to_str()
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a
string"))?;
@@ -726,7 +719,7 @@ impl PySessionContext {
options.schema = schema.as_ref().map(|x| &x.0);
let result = self.ctx.register_json(name, path, options);
- wait_for_future(py, result).map_err(DataFusionError::from)?;
+ wait_for_future(py, result)?;
Ok(())
}
@@ -745,7 +738,7 @@ impl PySessionContext {
file_extension: &str,
table_partition_cols: Vec<(String, String)>,
py: Python,
- ) -> PyResult<()> {
+ ) -> PyDataFusionResult<()> {
let path = path
.to_str()
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a
string"))?;
@@ -756,7 +749,7 @@ impl PySessionContext {
options.schema = schema.as_ref().map(|x| &x.0);
let result = self.ctx.register_avro(name, path, options);
- wait_for_future(py, result).map_err(DataFusionError::from)?;
+ wait_for_future(py, result)?;
Ok(())
}
@@ -767,12 +760,10 @@ impl PySessionContext {
name: &str,
dataset: &Bound<'_, PyAny>,
py: Python,
- ) -> PyResult<()> {
+ ) -> PyDataFusionResult<()> {
let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset, py)?);
- self.ctx
- .register_table(name, table)
- .map_err(DataFusionError::from)?;
+ self.ctx.register_table(name, table)?;
Ok(())
}
@@ -824,11 +815,11 @@ impl PySessionContext {
Ok(PyDataFrame::new(x))
}
- pub fn table_exist(&self, name: &str) -> PyResult<bool> {
+ pub fn table_exist(&self, name: &str) -> PyDataFusionResult<bool> {
Ok(self.ctx.table_exist(name)?)
}
- pub fn empty_table(&self) -> PyResult<PyDataFrame> {
+ pub fn empty_table(&self) -> PyDataFusionResult<PyDataFrame> {
Ok(PyDataFrame::new(self.ctx.read_empty()?))
}
@@ -847,7 +838,7 @@ impl PySessionContext {
table_partition_cols: Vec<(String, String)>,
file_compression_type: Option<String>,
py: Python,
- ) -> PyResult<PyDataFrame> {
+ ) -> PyDataFusionResult<PyDataFrame> {
let path = path
.to_str()
.ok_or_else(|| PyValueError::new_err("Unable to convert path to a
string"))?;
@@ -859,10 +850,10 @@ impl PySessionContext {
let df = if let Some(schema) = schema {
options.schema = Some(&schema.0);
let result = self.ctx.read_json(path, options);
- wait_for_future(py, result).map_err(DataFusionError::from)?
+ wait_for_future(py, result)?
} else {
let result = self.ctx.read_json(path, options);
- wait_for_future(py, result).map_err(DataFusionError::from)?
+ wait_for_future(py, result)?
};
Ok(PyDataFrame::new(df))
}
@@ -888,12 +879,12 @@ impl PySessionContext {
table_partition_cols: Vec<(String, String)>,
file_compression_type: Option<String>,
py: Python,
- ) -> PyResult<PyDataFrame> {
+ ) -> PyDataFusionResult<PyDataFrame> {
let delimiter = delimiter.as_bytes();
if delimiter.len() != 1 {
- return Err(PyValueError::new_err(
+ return Err(crate::errors::PyDataFusionError::PythonError(py_value_err(
"Delimiter must be a single character",
- ));
+ )));
};
let mut options = CsvReadOptions::new()
@@ -909,12 +900,12 @@ impl PySessionContext {
let paths = path.extract::<Vec<String>>()?;
let paths = paths.iter().map(|p| p as &str).collect::<Vec<&str>>();
let result = self.ctx.read_csv(paths, options);
- let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
+ let df = PyDataFrame::new(wait_for_future(py, result)?);
Ok(df)
} else {
let path = path.extract::<String>()?;
let result = self.ctx.read_csv(path, options);
- let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
+ let df = PyDataFrame::new(wait_for_future(py, result)?);
Ok(df)
}
}
@@ -938,7 +929,7 @@ impl PySessionContext {
schema: Option<PyArrowType<Schema>>,
file_sort_order: Option<Vec<Vec<PySortExpr>>>,
py: Python,
- ) -> PyResult<PyDataFrame> {
+ ) -> PyDataFusionResult<PyDataFrame> {
let mut options = ParquetReadOptions::default()
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?)
.parquet_pruning(parquet_pruning)
@@ -952,7 +943,7 @@ impl PySessionContext {
.collect();
let result = self.ctx.read_parquet(path, options);
- let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?);
+ let df = PyDataFrame::new(wait_for_future(py, result)?);
Ok(df)
}
@@ -965,26 +956,23 @@ impl PySessionContext {
table_partition_cols: Vec<(String, String)>,
file_extension: &str,
py: Python,
- ) -> PyResult<PyDataFrame> {
+ ) -> PyDataFusionResult<PyDataFrame> {
let mut options = AvroReadOptions::default()
.table_partition_cols(convert_table_partition_cols(table_partition_cols)?);
options.file_extension = file_extension;
let df = if let Some(schema) = schema {
options.schema = Some(&schema.0);
let read_future = self.ctx.read_avro(path, options);
- wait_for_future(py, read_future).map_err(DataFusionError::from)?
+ wait_for_future(py, read_future)?
} else {
let read_future = self.ctx.read_avro(path, options);
- wait_for_future(py, read_future).map_err(DataFusionError::from)?
+ wait_for_future(py, read_future)?
};
Ok(PyDataFrame::new(df))
}
- pub fn read_table(&self, table: &PyTable) -> PyResult<PyDataFrame> {
- let df = self
- .ctx
- .read_table(table.table())
- .map_err(DataFusionError::from)?;
+ pub fn read_table(&self, table: &PyTable) -> PyDataFusionResult<PyDataFrame> {
+ let df = self.ctx.read_table(table.table())?;
Ok(PyDataFrame::new(df))
}
@@ -1011,7 +999,7 @@ impl PySessionContext {
plan: PyExecutionPlan,
part: usize,
py: Python,
- ) -> PyResult<PyRecordBatchStream> {
+ ) -> PyDataFusionResult<PyRecordBatchStream> {
let ctx: TaskContext = TaskContext::from(&self.ctx.state());
// create a Tokio runtime to run the async code
let rt = &get_tokio_runtime().0;
@@ -1071,13 +1059,13 @@ impl PySessionContext {
pub fn convert_table_partition_cols(
table_partition_cols: Vec<(String, String)>,
-) -> Result<Vec<(String, DataType)>, DataFusionError> {
+) -> PyDataFusionResult<Vec<(String, DataType)>> {
table_partition_cols
.into_iter()
.map(|(name, ty)| match ty.as_str() {
"string" => Ok((name, DataType::Utf8)),
"int" => Ok((name, DataType::Int32)),
- _ => Err(DataFusionError::Common(format!(
+ _ => Err(crate::errors::PyDataFusionError::Common(format!(
"Unsupported data type '{ty}' for partition column. Supported
types are 'string' and 'int'"
))),
})
diff --git a/src/dataframe.rs b/src/dataframe.rs
index b875480..6fb08ba 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -33,20 +33,20 @@ use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
use datafusion::execution::SendableRecordBatchStream;
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
use datafusion::prelude::*;
-use pyo3::exceptions::{PyTypeError, PyValueError};
+use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
use pyo3::types::{PyCapsule, PyTuple, PyTupleMethods};
use tokio::task::JoinHandle;
-use crate::errors::py_datafusion_err;
+use crate::errors::{py_datafusion_err, PyDataFusionError};
use crate::expr::sort_expr::to_sort_expressions;
use crate::physical_plan::PyExecutionPlan;
use crate::record_batch::PyRecordBatchStream;
use crate::sql::logical::PyLogicalPlan;
use crate::utils::{get_tokio_runtime, validate_pycapsule, wait_for_future};
use crate::{
- errors::DataFusionError,
+ errors::PyDataFusionResult,
expr::{sort_expr::PySortExpr, PyExpr},
};
@@ -69,7 +69,7 @@ impl PyDataFrame {
#[pymethods]
impl PyDataFrame {
/// Enable selection for `df[col]`, `df[col1, col2, col3]`, and `df[[col1, col2, col3]]`
- fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyResult<Self> {
+ fn __getitem__(&self, key: Bound<'_, PyAny>) -> PyDataFusionResult<Self> {
if let Ok(key) = key.extract::<PyBackedStr>() {
// df[col]
self.select_columns(vec![key])
@@ -84,12 +84,12 @@ impl PyDataFrame {
// df[[col1, col2, col3]]
self.select_columns(keys)
} else {
- let message = "DataFrame can only be indexed by string index or
indices";
- Err(PyTypeError::new_err(message))
+ let message = "DataFrame can only be indexed by string index or
indices".to_string();
+ Err(PyDataFusionError::Common(message))
}
}
- fn __repr__(&self, py: Python) -> PyResult<String> {
+ fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
let df = self.df.as_ref().clone().limit(0, Some(10))?;
let batches = wait_for_future(py, df.collect())?;
let batches_as_string = pretty::pretty_format_batches(&batches);
@@ -99,7 +99,7 @@ impl PyDataFrame {
}
}
- fn _repr_html_(&self, py: Python) -> PyResult<String> {
+ fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
let mut html_str = "<table border='1'>\n".to_string();
let df = self.df.as_ref().clone().limit(0, Some(10))?;
@@ -145,7 +145,7 @@ impl PyDataFrame {
}
/// Calculate summary statistics for a DataFrame
- fn describe(&self, py: Python) -> PyResult<Self> {
+ fn describe(&self, py: Python) -> PyDataFusionResult<Self> {
let df = self.df.as_ref().clone();
let stat_df = wait_for_future(py, df.describe())?;
Ok(Self::new(stat_df))
@@ -157,37 +157,37 @@ impl PyDataFrame {
}
#[pyo3(signature = (*args))]
- fn select_columns(&self, args: Vec<PyBackedStr>) -> PyResult<Self> {
+ fn select_columns(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
let args = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
let df = self.df.as_ref().clone().select_columns(&args)?;
Ok(Self::new(df))
}
#[pyo3(signature = (*args))]
- fn select(&self, args: Vec<PyExpr>) -> PyResult<Self> {
+ fn select(&self, args: Vec<PyExpr>) -> PyDataFusionResult<Self> {
let expr = args.into_iter().map(|e| e.into()).collect();
let df = self.df.as_ref().clone().select(expr)?;
Ok(Self::new(df))
}
#[pyo3(signature = (*args))]
- fn drop(&self, args: Vec<PyBackedStr>) -> PyResult<Self> {
+ fn drop(&self, args: Vec<PyBackedStr>) -> PyDataFusionResult<Self> {
let cols = args.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
let df = self.df.as_ref().clone().drop_columns(&cols)?;
Ok(Self::new(df))
}
- fn filter(&self, predicate: PyExpr) -> PyResult<Self> {
+ fn filter(&self, predicate: PyExpr) -> PyDataFusionResult<Self> {
let df = self.df.as_ref().clone().filter(predicate.into())?;
Ok(Self::new(df))
}
- fn with_column(&self, name: &str, expr: PyExpr) -> PyResult<Self> {
+ fn with_column(&self, name: &str, expr: PyExpr) -> PyDataFusionResult<Self> {
let df = self.df.as_ref().clone().with_column(name, expr.into())?;
Ok(Self::new(df))
}
- fn with_columns(&self, exprs: Vec<PyExpr>) -> PyResult<Self> {
+ fn with_columns(&self, exprs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
let mut df = self.df.as_ref().clone();
for expr in exprs {
let expr: Expr = expr.into();
@@ -199,7 +199,7 @@ impl PyDataFrame {
/// Rename one column by applying a new projection. This is a no-op if the column to be
/// renamed does not exist.
- fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyResult<Self> {
+ fn with_column_renamed(&self, old_name: &str, new_name: &str) -> PyDataFusionResult<Self> {
let df = self
.df
.as_ref()
@@ -208,7 +208,7 @@ impl PyDataFrame {
Ok(Self::new(df))
}
- fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyResult<Self> {
+ fn aggregate(&self, group_by: Vec<PyExpr>, aggs: Vec<PyExpr>) -> PyDataFusionResult<Self> {
let group_by = group_by.into_iter().map(|e| e.into()).collect();
let aggs = aggs.into_iter().map(|e| e.into()).collect();
let df = self.df.as_ref().clone().aggregate(group_by, aggs)?;
@@ -216,14 +216,14 @@ impl PyDataFrame {
}
#[pyo3(signature = (*exprs))]
- fn sort(&self, exprs: Vec<PySortExpr>) -> PyResult<Self> {
+ fn sort(&self, exprs: Vec<PySortExpr>) -> PyDataFusionResult<Self> {
let exprs = to_sort_expressions(exprs);
let df = self.df.as_ref().clone().sort(exprs)?;
Ok(Self::new(df))
}
#[pyo3(signature = (count, offset=0))]
- fn limit(&self, count: usize, offset: usize) -> PyResult<Self> {
+ fn limit(&self, count: usize, offset: usize) -> PyDataFusionResult<Self> {
let df = self.df.as_ref().clone().limit(offset, Some(count))?;
Ok(Self::new(df))
}
@@ -232,14 +232,15 @@ impl PyDataFrame {
/// Unless some order is specified in the plan, there is no
/// guarantee of the order of the result.
fn collect(&self, py: Python) -> PyResult<Vec<PyObject>> {
- let batches = wait_for_future(py, self.df.as_ref().clone().collect())?;
+ let batches = wait_for_future(py, self.df.as_ref().clone().collect())
+ .map_err(PyDataFusionError::from)?;
// cannot use PyResult<Vec<RecordBatch>> return type due to
// https://github.com/PyO3/pyo3/issues/1813
batches.into_iter().map(|rb| rb.to_pyarrow(py)).collect()
}
/// Cache DataFrame.
- fn cache(&self, py: Python) -> PyResult<Self> {
+ fn cache(&self, py: Python) -> PyDataFusionResult<Self> {
let df = wait_for_future(py, self.df.as_ref().clone().cache())?;
Ok(Self::new(df))
}
@@ -247,7 +248,8 @@ impl PyDataFrame {
/// Executes this DataFrame and collects all results into a vector of vector of RecordBatch
/// maintaining the input partitioning.
fn collect_partitioned(&self, py: Python) -> PyResult<Vec<Vec<PyObject>>> {
- let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())?;
+ let batches = wait_for_future(py, self.df.as_ref().clone().collect_partitioned())
+ .map_err(PyDataFusionError::from)?;
batches
.into_iter()
@@ -257,13 +259,13 @@ impl PyDataFrame {
/// Print the result, 20 lines by default
#[pyo3(signature = (num=20))]
- fn show(&self, py: Python, num: usize) -> PyResult<()> {
+ fn show(&self, py: Python, num: usize) -> PyDataFusionResult<()> {
let df = self.df.as_ref().clone().limit(0, Some(num))?;
print_dataframe(py, df)
}
/// Filter out duplicate rows
- fn distinct(&self) -> PyResult<Self> {
+ fn distinct(&self) -> PyDataFusionResult<Self> {
let df = self.df.as_ref().clone().distinct()?;
Ok(Self::new(df))
}
@@ -274,7 +276,7 @@ impl PyDataFrame {
how: &str,
left_on: Vec<PyBackedStr>,
right_on: Vec<PyBackedStr>,
- ) -> PyResult<Self> {
+ ) -> PyDataFusionResult<Self> {
let join_type = match how {
"inner" => JoinType::Inner,
"left" => JoinType::Left,
@@ -283,10 +285,9 @@ impl PyDataFrame {
"semi" => JoinType::LeftSemi,
"anti" => JoinType::LeftAnti,
how => {
- return Err(DataFusionError::Common(format!(
+ return Err(PyDataFusionError::Common(format!(
"The join type {how} does not exist or is not implemented"
- ))
- .into());
+ )));
}
};
@@ -303,7 +304,12 @@ impl PyDataFrame {
Ok(Self::new(df))
}
- fn join_on(&self, right: PyDataFrame, on_exprs: Vec<PyExpr>, how: &str) -> PyResult<Self> {
+ fn join_on(
+ &self,
+ right: PyDataFrame,
+ on_exprs: Vec<PyExpr>,
+ how: &str,
+ ) -> PyDataFusionResult<Self> {
let join_type = match how {
"inner" => JoinType::Inner,
"left" => JoinType::Left,
@@ -312,10 +318,9 @@ impl PyDataFrame {
"semi" => JoinType::LeftSemi,
"anti" => JoinType::LeftAnti,
how => {
- return Err(DataFusionError::Common(format!(
+ return Err(PyDataFusionError::Common(format!(
"The join type {how} does not exist or is not implemented"
- ))
- .into());
+ )));
}
};
let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect();
@@ -330,7 +335,7 @@ impl PyDataFrame {
/// Print the query plan
#[pyo3(signature = (verbose=false, analyze=false))]
- fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyResult<()> {
+ fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyDataFusionResult<()> {
let df = self.df.as_ref().clone().explain(verbose, analyze)?;
print_dataframe(py, df)
}
@@ -341,18 +346,18 @@ impl PyDataFrame {
}
/// Get the optimized logical plan for this `DataFrame`
- fn optimized_logical_plan(&self) -> PyResult<PyLogicalPlan> {
+ fn optimized_logical_plan(&self) -> PyDataFusionResult<PyLogicalPlan> {
Ok(self.df.as_ref().clone().into_optimized_plan()?.into())
}
/// Get the execution plan for this `DataFrame`
- fn execution_plan(&self, py: Python) -> PyResult<PyExecutionPlan> {
+ fn execution_plan(&self, py: Python) -> PyDataFusionResult<PyExecutionPlan> {
let plan = wait_for_future(py, self.df.as_ref().clone().create_physical_plan())?;
Ok(plan.into())
}
/// Repartition a `DataFrame` based on a logical partitioning scheme.
- fn repartition(&self, num: usize) -> PyResult<Self> {
+ fn repartition(&self, num: usize) -> PyDataFusionResult<Self> {
let new_df = self
.df
.as_ref()
@@ -363,7 +368,7 @@ impl PyDataFrame {
/// Repartition a `DataFrame` based on a logical partitioning scheme.
#[pyo3(signature = (*args, num))]
- fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyResult<Self> {
+ fn repartition_by_hash(&self, args: Vec<PyExpr>, num: usize) -> PyDataFusionResult<Self> {
let expr = args.into_iter().map(|py_expr| py_expr.into()).collect();
let new_df = self
.df
@@ -376,7 +381,7 @@ impl PyDataFrame {
/// Calculate the union of two `DataFrame`s, preserving duplicate rows.The
/// two `DataFrame`s must have exactly the same schema
#[pyo3(signature = (py_df, distinct=false))]
- fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyResult<Self> {
+ fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyDataFusionResult<Self> {
let new_df = if distinct {
self.df
.as_ref()
@@ -391,7 +396,7 @@ impl PyDataFrame {
/// Calculate the distinct union of two `DataFrame`s. The
/// two `DataFrame`s must have exactly the same schema
- fn union_distinct(&self, py_df: PyDataFrame) -> PyResult<Self> {
+ fn union_distinct(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
let new_df = self
.df
.as_ref()
@@ -401,7 +406,7 @@ impl PyDataFrame {
}
#[pyo3(signature = (column, preserve_nulls=true))]
- fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyResult<Self> {
+ fn unnest_column(&self, column: &str, preserve_nulls: bool) -> PyDataFusionResult<Self> {
// TODO: expose RecursionUnnestOptions
// REF: https://github.com/apache/datafusion/pull/11577
let unnest_options =
UnnestOptions::default().with_preserve_nulls(preserve_nulls);
@@ -414,7 +419,11 @@ impl PyDataFrame {
}
#[pyo3(signature = (columns, preserve_nulls=true))]
- fn unnest_columns(&self, columns: Vec<String>, preserve_nulls: bool) -> PyResult<Self> {
+ fn unnest_columns(
+ &self,
+ columns: Vec<String>,
+ preserve_nulls: bool,
+ ) -> PyDataFusionResult<Self> {
// TODO: expose RecursionUnnestOptions
// REF: https://github.com/apache/datafusion/pull/11577
let unnest_options =
UnnestOptions::default().with_preserve_nulls(preserve_nulls);
@@ -428,7 +437,7 @@ impl PyDataFrame {
}
/// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
- fn intersect(&self, py_df: PyDataFrame) -> PyResult<Self> {
+ fn intersect(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
let new_df = self
.df
.as_ref()
@@ -438,13 +447,13 @@ impl PyDataFrame {
}
/// Calculate the exception of two `DataFrame`s. The two `DataFrame`s must have exactly the same schema
- fn except_all(&self, py_df: PyDataFrame) -> PyResult<Self> {
+ fn except_all(&self, py_df: PyDataFrame) -> PyDataFusionResult<Self> {
let new_df = self.df.as_ref().clone().except(py_df.df.as_ref().clone())?;
Ok(Self::new(new_df))
}
/// Write a `DataFrame` to a CSV file.
- fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyResult<()> {
+ fn write_csv(&self, path: &str, with_header: bool, py: Python) -> PyDataFusionResult<()> {
let csv_options = CsvOptions {
has_header: Some(with_header),
..Default::default()
@@ -472,7 +481,7 @@ impl PyDataFrame {
compression: &str,
compression_level: Option<u32>,
py: Python,
- ) -> PyResult<()> {
+ ) -> PyDataFusionResult<()> {
fn verify_compression_level(cl: Option<u32>) -> Result<u32, PyErr> {
cl.ok_or(PyValueError::new_err("compression_level is not defined"))
}
@@ -496,7 +505,7 @@ impl PyDataFrame {
"lz4_raw" => Compression::LZ4_RAW,
"uncompressed" => Compression::UNCOMPRESSED,
_ => {
- return Err(PyValueError::new_err(format!(
+ return Err(PyDataFusionError::Common(format!(
"Unrecognized compression type {compression}"
)));
}
@@ -522,7 +531,7 @@ impl PyDataFrame {
}
/// Executes a query and writes the results to a partitioned JSON file.
- fn write_json(&self, path: &str, py: Python) -> PyResult<()> {
+ fn write_json(&self, path: &str, py: Python) -> PyDataFusionResult<()> {
wait_for_future(
py,
self.df
@@ -551,7 +560,7 @@ impl PyDataFrame {
&'py mut self,
py: Python<'py>,
requested_schema: Option<Bound<'py, PyCapsule>>,
- ) -> PyResult<Bound<'py, PyCapsule>> {
+ ) -> PyDataFusionResult<Bound<'py, PyCapsule>> {
let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())?;
let mut schema: Schema = self.df.schema().to_owned().into();
@@ -559,15 +568,14 @@ impl PyDataFrame {
validate_pycapsule(&schema_capsule, "arrow_schema")?;
let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
- let desired_schema = Schema::try_from(schema_ptr).map_err(DataFusionError::from)?;
+ let desired_schema = Schema::try_from(schema_ptr)?;
- schema = project_schema(schema, desired_schema).map_err(DataFusionError::ArrowError)?;
+ schema = project_schema(schema, desired_schema)?;
batches = batches
.into_iter()
.map(|record_batch| record_batch_into_schema(record_batch, &schema))
- .collect::<Result<Vec<RecordBatch>, ArrowError>>()
- .map_err(DataFusionError::ArrowError)?;
+ .collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
}
let batches_wrapped = batches.into_iter().map(Ok);
@@ -578,9 +586,10 @@ impl PyDataFrame {
let ffi_stream = FFI_ArrowArrayStream::new(reader);
let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
PyCapsule::new_bound(py, ffi_stream, Some(stream_capsule_name))
+ .map_err(PyDataFusionError::from)
}
- fn execute_stream(&self, py: Python) -> PyResult<PyRecordBatchStream> {
+ fn execute_stream(&self, py: Python) -> PyDataFusionResult<PyRecordBatchStream> {
// create a Tokio runtime to run the async code
let rt = &get_tokio_runtime().0;
let df = self.df.as_ref().clone();
@@ -647,13 +656,13 @@ impl PyDataFrame {
}
// Executes this DataFrame to get the total number of rows.
- fn count(&self, py: Python) -> PyResult<usize> {
+ fn count(&self, py: Python) -> PyDataFusionResult<usize> {
Ok(wait_for_future(py, self.df.as_ref().clone().count())?)
}
}
/// Print DataFrame
-fn print_dataframe(py: Python, df: DataFrame) -> PyResult<()> {
+fn print_dataframe(py: Python, df: DataFrame) -> PyDataFusionResult<()> {
// Get string representation of record batches
let batches = wait_for_future(py, df.collect())?;
let batches_as_string = pretty::pretty_format_batches(&batches);
diff --git a/src/dataset_exec.rs b/src/dataset_exec.rs
index 9d25594..ace4211 100644
--- a/src/dataset_exec.rs
+++ b/src/dataset_exec.rs
@@ -42,7 +42,7 @@ use datafusion::physical_plan::{
SendableRecordBatchStream, Statistics,
};
-use crate::errors::DataFusionError;
+use crate::errors::PyDataFusionResult;
use crate::pyarrow_filter_expression::PyArrowFilterExpression;
struct PyArrowBatchesAdapter {
@@ -83,8 +83,8 @@ impl DatasetExec {
dataset: &Bound<'_, PyAny>,
projection: Option<Vec<usize>>,
filters: &[Expr],
- ) -> Result<Self, DataFusionError> {
- let columns: Option<Result<Vec<String>, DataFusionError>> = projection.map(|p| {
+ ) -> PyDataFusionResult<Self> {
+ let columns: Option<PyDataFusionResult<Vec<String>>> = projection.map(|p| {
p.iter()
.map(|index| {
let name: String = dataset
diff --git a/src/errors.rs b/src/errors.rs
index d12b6ad..b02b754 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -24,10 +24,10 @@ use datafusion::error::DataFusionError as InnerDataFusionError;
use prost::EncodeError;
use pyo3::{exceptions::PyException, PyErr};
-pub type Result<T> = std::result::Result<T, DataFusionError>;
+pub type PyDataFusionResult<T> = std::result::Result<T, PyDataFusionError>;
#[derive(Debug)]
-pub enum DataFusionError {
+pub enum PyDataFusionError {
ExecutionError(InnerDataFusionError),
ArrowError(ArrowError),
Common(String),
@@ -35,46 +35,46 @@ pub enum DataFusionError {
EncodeError(EncodeError),
}
-impl fmt::Display for DataFusionError {
+impl fmt::Display for PyDataFusionError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
- DataFusionError::ExecutionError(e) => write!(f, "DataFusion error:
{e:?}"),
- DataFusionError::ArrowError(e) => write!(f, "Arrow error: {e:?}"),
- DataFusionError::PythonError(e) => write!(f, "Python error {e:?}"),
- DataFusionError::Common(e) => write!(f, "{e}"),
- DataFusionError::EncodeError(e) => write!(f, "Failed to encode
substrait plan: {e}"),
+ PyDataFusionError::ExecutionError(e) => write!(f, "DataFusion
error: {e:?}"),
+ PyDataFusionError::ArrowError(e) => write!(f, "Arrow error:
{e:?}"),
+ PyDataFusionError::PythonError(e) => write!(f, "Python error
{e:?}"),
+ PyDataFusionError::Common(e) => write!(f, "{e}"),
+ PyDataFusionError::EncodeError(e) => write!(f, "Failed to encode
substrait plan: {e}"),
}
}
}
-impl From<ArrowError> for DataFusionError {
- fn from(err: ArrowError) -> DataFusionError {
- DataFusionError::ArrowError(err)
+impl From<ArrowError> for PyDataFusionError {
+ fn from(err: ArrowError) -> PyDataFusionError {
+ PyDataFusionError::ArrowError(err)
}
}
-impl From<InnerDataFusionError> for DataFusionError {
- fn from(err: InnerDataFusionError) -> DataFusionError {
- DataFusionError::ExecutionError(err)
+impl From<InnerDataFusionError> for PyDataFusionError {
+ fn from(err: InnerDataFusionError) -> PyDataFusionError {
+ PyDataFusionError::ExecutionError(err)
}
}
-impl From<PyErr> for DataFusionError {
- fn from(err: PyErr) -> DataFusionError {
- DataFusionError::PythonError(err)
+impl From<PyErr> for PyDataFusionError {
+ fn from(err: PyErr) -> PyDataFusionError {
+ PyDataFusionError::PythonError(err)
}
}
-impl From<DataFusionError> for PyErr {
- fn from(err: DataFusionError) -> PyErr {
+impl From<PyDataFusionError> for PyErr {
+ fn from(err: PyDataFusionError) -> PyErr {
match err {
- DataFusionError::PythonError(py_err) => py_err,
+ PyDataFusionError::PythonError(py_err) => py_err,
_ => PyException::new_err(err.to_string()),
}
}
}
-impl Error for DataFusionError {}
+impl Error for PyDataFusionError {}
pub fn py_type_err(e: impl Debug) -> PyErr {
PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!("{e:?}"))
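
The From impls above are what let call sites across context.rs and
dataframe.rs collapse from `.map_err(DataFusionError::from)?` to a bare `?`.
A minimal sketch of that mechanism with hypothetical error types (not the
crate's real ones):

    // `?` implicitly calls From::from on the error, so once a From impl
    // exists no explicit map_err is needed at the call site.
    #[derive(Debug)]
    struct UpstreamError;

    #[derive(Debug)]
    struct WrapperError;

    impl From<UpstreamError> for WrapperError {
        fn from(_: UpstreamError) -> Self {
            WrapperError
        }
    }

    fn upstream_call() -> Result<u32, UpstreamError> {
        Ok(42)
    }

    fn wrapper_call() -> Result<u32, WrapperError> {
        // Before: upstream_call().map_err(WrapperError::from)?
        // After: the `?` operator performs the conversion itself.
        Ok(upstream_call()?)
    }

    fn main() {
        println!("{:?}", wrapper_call());
    }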
diff --git a/src/expr.rs b/src/expr.rs
index bca0cd3..1e9983d 100644
--- a/src/expr.rs
+++ b/src/expr.rs
@@ -24,7 +24,6 @@ use std::convert::{From, Into};
use std::sync::Arc;
use window::PyWindowFrame;
-use arrow::pyarrow::ToPyArrow;
use datafusion::arrow::datatypes::{DataType, Field};
use datafusion::arrow::pyarrow::PyArrowType;
use datafusion::functions::core::expr_ext::FieldAccessor;
@@ -33,15 +32,17 @@ use datafusion::logical_expr::{
expr::{AggregateFunction, InList, InSubquery, ScalarFunction, WindowFunction},
lit, Between, BinaryExpr, Case, Cast, Expr, Like, Operator, TryCast,
};
-use datafusion::scalar::ScalarValue;
-use crate::common::data_type::{DataTypeMap, NullTreatment, RexType};
-use crate::errors::{py_runtime_err, py_type_err, py_unsupported_variant_err, DataFusionError};
+use crate::common::data_type::{DataTypeMap, NullTreatment, PyScalarValue, RexType};
+use crate::errors::{
+ py_runtime_err, py_type_err, py_unsupported_variant_err, PyDataFusionError, PyDataFusionResult,
+};
use crate::expr::aggregate_expr::PyAggregateFunction;
use crate::expr::binary_expr::PyBinaryExpr;
use crate::expr::column::PyColumn;
use crate::expr::literal::PyLiteral;
use crate::functions::add_builder_fns_to_window;
+use crate::pyarrow_util::scalar_to_pyarrow;
use crate::sql::logical::PyLogicalPlan;
use self::alias::PyAlias;
@@ -261,8 +262,8 @@ impl PyExpr {
}
#[staticmethod]
- pub fn literal(value: ScalarValue) -> PyExpr {
- lit(value).into()
+ pub fn literal(value: PyScalarValue) -> PyExpr {
+ lit(value.0).into()
}
#[staticmethod]
@@ -356,7 +357,7 @@ impl PyExpr {
/// Extracts the Expr value into a PyObject that can be shared with Python
pub fn python_value(&self, py: Python) -> PyResult<PyObject> {
match &self.expr {
- Expr::Literal(scalar_value) => Ok(scalar_value.to_pyarrow(py)?),
+ Expr::Literal(scalar_value) => scalar_to_pyarrow(scalar_value, py),
_ => Err(py_type_err(format!(
"Non Expr::Literal encountered in types: {:?}",
&self.expr
@@ -568,7 +569,7 @@ impl PyExpr {
window_frame: Option<PyWindowFrame>,
order_by: Option<Vec<PySortExpr>>,
null_treatment: Option<NullTreatment>,
- ) -> PyResult<PyExpr> {
+ ) -> PyDataFusionResult<PyExpr> {
match &self.expr {
Expr::AggregateFunction(agg_fn) => {
let window_fn = Expr::WindowFunction(WindowFunction::new(
@@ -592,10 +593,9 @@ impl PyExpr {
null_treatment,
),
_ => Err(
- DataFusionError::ExecutionError(datafusion::error::DataFusionError::Plan(
+ PyDataFusionError::ExecutionError(datafusion::error::DataFusionError::Plan(
format!("Using {} with `over` is not allowed. Must use an aggregate or window function.", self.expr.variant_name()),
))
- .into(),
),
}
}
@@ -649,34 +649,26 @@ impl PyExprFuncBuilder {
.into()
}
- pub fn build(&self) -> PyResult<PyExpr> {
- self.builder
- .clone()
- .build()
- .map(|expr| expr.into())
- .map_err(|err| err.into())
+ pub fn build(&self) -> PyDataFusionResult<PyExpr> {
+ Ok(self.builder.clone().build().map(|expr| expr.into())?)
}
}
impl PyExpr {
- pub fn _column_name(&self, plan: &LogicalPlan) -> Result<String, DataFusionError> {
+ pub fn _column_name(&self, plan: &LogicalPlan) -> PyDataFusionResult<String> {
let field = Self::expr_to_field(&self.expr, plan)?;
Ok(field.name().to_owned())
}
/// Create a [Field] representing an [Expr], given an input [LogicalPlan] to resolve against
- pub fn expr_to_field(
- expr: &Expr,
- input_plan: &LogicalPlan,
- ) -> Result<Arc<Field>, DataFusionError> {
+ pub fn expr_to_field(expr: &Expr, input_plan: &LogicalPlan) -> PyDataFusionResult<Arc<Field>> {
match expr {
Expr::Wildcard { .. } => {
// Since * could be any of the valid column names just return the first one
Ok(Arc::new(input_plan.schema().field(0).clone()))
}
_ => {
- let fields =
- exprlist_to_fields(&[expr.clone()], input_plan).map_err(PyErr::from)?;
+ let fields = exprlist_to_fields(&[expr.clone()], input_plan)?;
Ok(fields[0].1.clone())
}
}
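
The `PyScalarValue` appearing throughout this file is a thin newtype over
DataFusion's `ScalarValue`. Because `ScalarValue` is a foreign type, Rust's
orphan rules forbid implementing pyo3 conversion traits for it in this crate;
wrapping it locally is what makes the `FromPyObject`/`FromPyArrow` impls in
src/pyarrow_util.rs (further down) legal. A minimal sketch, assuming the
definition in src/common/data_type.rs looks roughly like this:

    use datafusion::scalar::ScalarValue;

    // Local newtype so this crate may implement foreign traits for ScalarValue.
    #[derive(Debug, Clone)]
    pub struct PyScalarValue(pub ScalarValue);

    // Lets call sites write `value.0` or `value.into()` interchangeably,
    // as the lead/lag hunks below do with `default_value.map(|v| v.into())`.
    impl From<PyScalarValue> for ScalarValue {
        fn from(value: PyScalarValue) -> Self {
            value.0
        }
    }
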
diff --git a/src/expr/conditional_expr.rs b/src/expr/conditional_expr.rs
index a8a885c..fe3af2e 100644
--- a/src/expr/conditional_expr.rs
+++ b/src/expr/conditional_expr.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use crate::expr::PyExpr;
+use crate::{errors::PyDataFusionResult, expr::PyExpr};
use datafusion::logical_expr::conditional_expressions::CaseBuilder;
use pyo3::prelude::*;
@@ -44,11 +44,11 @@ impl PyCaseBuilder {
}
}
- fn otherwise(&mut self, else_expr: PyExpr) -> PyResult<PyExpr> {
+ fn otherwise(&mut self, else_expr: PyExpr) -> PyDataFusionResult<PyExpr> {
Ok(self.case_builder.otherwise(else_expr.expr)?.clone().into())
}
- fn end(&mut self) -> PyResult<PyExpr> {
+ fn end(&mut self) -> PyDataFusionResult<PyExpr> {
Ok(self.case_builder.end()?.clone().into())
}
}
diff --git a/src/expr/literal.rs b/src/expr/literal.rs
index 43084ba..2cb2079 100644
--- a/src/expr/literal.rs
+++ b/src/expr/literal.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use crate::errors::DataFusionError;
+use crate::errors::PyDataFusionError;
use datafusion::common::ScalarValue;
use pyo3::prelude::*;
@@ -154,5 +154,5 @@ impl PyLiteral {
}
fn unexpected_literal_value(value: &ScalarValue) -> PyErr {
- DataFusionError::Common(format!("getValue<T>() - Unexpected value:
{value}")).into()
+ PyDataFusionError::Common(format!("getValue<T>() - Unexpected value:
{value}")).into()
}
diff --git a/src/expr/window.rs b/src/expr/window.rs
index 6486dbb..4dc6cb9 100644
--- a/src/expr/window.rs
+++ b/src/expr/window.rs
@@ -21,8 +21,9 @@ use datafusion::logical_expr::{Expr, Window, WindowFrame, WindowFrameBound, Wind
use pyo3::prelude::*;
use std::fmt::{self, Display, Formatter};
+use crate::common::data_type::PyScalarValue;
use crate::common::df_schema::PyDFSchema;
-use crate::errors::py_type_err;
+use crate::errors::{py_type_err, PyDataFusionResult};
use crate::expr::logical_node::LogicalNode;
use crate::expr::sort_expr::{py_sort_expr_list, PySortExpr};
use crate::expr::PyExpr;
@@ -171,8 +172,8 @@ impl PyWindowFrame {
#[pyo3(signature=(unit, start_bound, end_bound))]
pub fn new(
unit: &str,
- start_bound: Option<ScalarValue>,
- end_bound: Option<ScalarValue>,
+ start_bound: Option<PyScalarValue>,
+ end_bound: Option<PyScalarValue>,
) -> PyResult<Self> {
let units = unit.to_ascii_lowercase();
let units = match units.as_str() {
@@ -187,7 +188,7 @@ impl PyWindowFrame {
}
};
let start_bound = match start_bound {
- Some(start_bound) => WindowFrameBound::Preceding(start_bound),
+ Some(start_bound) => WindowFrameBound::Preceding(start_bound.0),
None => match units {
WindowFrameUnits::Range => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
WindowFrameUnits::Rows => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
@@ -200,7 +201,7 @@ impl PyWindowFrame {
},
};
let end_bound = match end_bound {
- Some(end_bound) => WindowFrameBound::Following(end_bound),
+ Some(end_bound) => WindowFrameBound::Following(end_bound.0),
None => match units {
WindowFrameUnits::Rows => WindowFrameBound::Following(ScalarValue::UInt64(None)),
WindowFrameUnits::Range => WindowFrameBound::Following(ScalarValue::UInt64(None)),
@@ -253,7 +254,7 @@ impl PyWindowFrameBound {
matches!(self.frame_bound, WindowFrameBound::Following(_))
}
/// Returns the offset of the window frame
- pub fn get_offset(&self) -> PyResult<Option<u64>> {
+ pub fn get_offset(&self) -> PyDataFusionResult<Option<u64>> {
match &self.frame_bound {
WindowFrameBound::Preceding(val) | WindowFrameBound::Following(val) => match val {
x if x.is_null() => Ok(None),
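
When a bound is omitted in the hunks above, the frame silently becomes
unbounded. That fallback pattern, pulled out as a standalone illustration
(the function name here is ours, not the crate's):

    use datafusion::logical_expr::WindowFrameBound;
    use datafusion::scalar::ScalarValue;

    // A missing start bound turns into an unbounded preceding edge,
    // mirroring the None arms in PyWindowFrame::new.
    fn start_bound_or_unbounded(bound: Option<ScalarValue>) -> WindowFrameBound {
        match bound {
            Some(v) => WindowFrameBound::Preceding(v),
            None => WindowFrameBound::Preceding(ScalarValue::UInt64(None)),
        }
    }
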
diff --git a/src/functions.rs b/src/functions.rs
index ae032d7..46c748c 100644
--- a/src/functions.rs
+++ b/src/functions.rs
@@ -22,8 +22,10 @@ use datafusion::logical_expr::WindowFrame;
use pyo3::{prelude::*, wrap_pyfunction};
use crate::common::data_type::NullTreatment;
+use crate::common::data_type::PyScalarValue;
use crate::context::PySessionContext;
-use crate::errors::DataFusionError;
+use crate::errors::PyDataFusionError;
+use crate::errors::PyDataFusionResult;
use crate::expr::conditional_expr::PyCaseBuilder;
use crate::expr::sort_expr::to_sort_expressions;
use crate::expr::sort_expr::PySortExpr;
@@ -44,7 +46,7 @@ fn add_builder_fns_to_aggregate(
filter: Option<PyExpr>,
order_by: Option<Vec<PySortExpr>>,
null_treatment: Option<NullTreatment>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
// Since ExprFuncBuilder::new() is private, we can guarantee initializing
// a builder with an `null_treatment` with option None
let mut builder = agg_fn.null_treatment(None);
@@ -228,7 +230,10 @@ fn when(when: PyExpr, then: PyExpr) -> PyResult<PyCaseBuilder> {
/// 1) If no function has been found, search default aggregate functions.
///
/// NOTE: we search the built-ins first because the `UDAF` versions currently do not have the same behavior.
-fn find_window_fn(name: &str, ctx: Option<PySessionContext>) -> PyResult<WindowFunctionDefinition> {
+fn find_window_fn(
+ name: &str,
+ ctx: Option<PySessionContext>,
+) -> PyDataFusionResult<WindowFunctionDefinition> {
if let Some(ctx) = ctx {
// search UDAFs
let udaf = ctx
@@ -284,7 +289,9 @@ fn find_window_fn(name: &str, ctx: Option<PySessionContext>) -> PyResult<WindowF
return Ok(window_fn);
}
- Err(DataFusionError::Common(format!("window function `{name}` not
found")).into())
+ Err(PyDataFusionError::Common(format!(
+ "window function `{name}` not found"
+ )))
}
/// Creates a new Window function expression
@@ -341,7 +348,7 @@ macro_rules! aggregate_function {
filter: Option<PyExpr>,
order_by: Option<Vec<PySortExpr>>,
null_treatment: Option<NullTreatment>
- ) -> PyResult<PyExpr> {
+ ) -> PyDataFusionResult<PyExpr> {
let agg_fn = functions_aggregate::expr_fn::$NAME($($arg.into()),*);
add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
@@ -362,7 +369,7 @@ macro_rules! aggregate_function_vec_args {
filter: Option<PyExpr>,
order_by: Option<Vec<PySortExpr>>,
null_treatment: Option<NullTreatment>
- ) -> PyResult<PyExpr> {
+ ) -> PyDataFusionResult<PyExpr> {
let agg_fn = functions_aggregate::expr_fn::$NAME(vec![$($arg.into()),*]);
add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
@@ -642,7 +649,7 @@ pub fn approx_percentile_cont(
percentile: f64,
num_centroids: Option<i64>, // enforces optional arguments at the end, currently
filter: Option<PyExpr>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
let args = if let Some(num_centroids) = num_centroids {
vec![expression.expr, lit(percentile), lit(num_centroids)]
} else {
@@ -661,7 +668,7 @@ pub fn approx_percentile_cont_with_weight(
weight: PyExpr,
percentile: f64,
filter: Option<PyExpr>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
let agg_fn = functions_aggregate::expr_fn::approx_percentile_cont_with_weight(
expression.expr,
weight.expr,
@@ -683,7 +690,7 @@ pub fn first_value(
filter: Option<PyExpr>,
order_by: Option<Vec<PySortExpr>>,
null_treatment: Option<NullTreatment>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
// If we initialize the UDAF with order_by directly, then it gets over-written by the builder
let agg_fn = functions_aggregate::expr_fn::first_value(expr.expr, None);
@@ -700,7 +707,7 @@ pub fn nth_value(
filter: Option<PyExpr>,
order_by: Option<Vec<PySortExpr>>,
null_treatment: Option<NullTreatment>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
let agg_fn = datafusion::functions_aggregate::nth_value::nth_value(expr.expr, n, vec![]);
add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
}
@@ -715,7 +722,7 @@ pub fn string_agg(
filter: Option<PyExpr>,
order_by: Option<Vec<PySortExpr>>,
null_treatment: Option<NullTreatment>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
let agg_fn = datafusion::functions_aggregate::string_agg::string_agg(expr.expr, lit(delimiter));
add_builder_fns_to_aggregate(agg_fn, distinct, filter, order_by, null_treatment)
}
@@ -726,7 +733,7 @@ pub(crate) fn add_builder_fns_to_window(
window_frame: Option<PyWindowFrame>,
order_by: Option<Vec<PySortExpr>>,
null_treatment: Option<NullTreatment>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
let null_treatment = null_treatment.map(|n| n.into());
let mut builder = window_fn.null_treatment(null_treatment);
@@ -748,7 +755,7 @@ pub(crate) fn add_builder_fns_to_window(
builder = builder.window_frame(window_frame.into());
}
- builder.build().map(|e| e.into()).map_err(|err| err.into())
+ Ok(builder.build().map(|e| e.into())?)
}
#[pyfunction]
@@ -756,10 +763,11 @@ pub(crate) fn add_builder_fns_to_window(
pub fn lead(
arg: PyExpr,
shift_offset: i64,
- default_value: Option<ScalarValue>,
+ default_value: Option<PyScalarValue>,
partition_by: Option<Vec<PyExpr>>,
order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
+ let default_value = default_value.map(|v| v.into());
let window_fn = functions_window::expr_fn::lead(arg.expr, Some(shift_offset), default_value);
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
@@ -770,10 +778,11 @@ pub fn lead(
pub fn lag(
arg: PyExpr,
shift_offset: i64,
- default_value: Option<ScalarValue>,
+ default_value: Option<PyScalarValue>,
partition_by: Option<Vec<PyExpr>>,
order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
+ let default_value = default_value.map(|v| v.into());
let window_fn = functions_window::expr_fn::lag(arg.expr, Some(shift_offset), default_value);
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
@@ -784,7 +793,7 @@ pub fn lag(
pub fn row_number(
partition_by: Option<Vec<PyExpr>>,
order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
let window_fn = functions_window::expr_fn::row_number();
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
@@ -795,7 +804,7 @@ pub fn row_number(
pub fn rank(
partition_by: Option<Vec<PyExpr>>,
order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
let window_fn = functions_window::expr_fn::rank();
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
@@ -806,7 +815,7 @@ pub fn rank(
pub fn dense_rank(
partition_by: Option<Vec<PyExpr>>,
order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
let window_fn = functions_window::expr_fn::dense_rank();
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
@@ -817,7 +826,7 @@ pub fn dense_rank(
pub fn percent_rank(
partition_by: Option<Vec<PyExpr>>,
order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
let window_fn = functions_window::expr_fn::percent_rank();
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
@@ -828,7 +837,7 @@ pub fn percent_rank(
pub fn cume_dist(
partition_by: Option<Vec<PyExpr>>,
order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
let window_fn = functions_window::expr_fn::cume_dist();
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
@@ -840,7 +849,7 @@ pub fn ntile(
arg: PyExpr,
partition_by: Option<Vec<PyExpr>>,
order_by: Option<Vec<PySortExpr>>,
-) -> PyResult<PyExpr> {
+) -> PyDataFusionResult<PyExpr> {
let window_fn = functions_window::expr_fn::ntile(arg.into());
add_builder_fns_to_window(window_fn, partition_by, None, order_by, None)
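
A pattern worth noting in this file: every `.map_err(|err| err.into())`
chain collapses to a bare `?`, because the `From` impls added in
src/errors.rs let `?` perform the conversion itself. In isolation (assumes
the `PyDataFusionResult` alias sketched earlier):

    use datafusion::error::DataFusionError as InnerDataFusionError;

    fn inner_step() -> Result<u64, InnerDataFusionError> {
        Ok(1)
    }

    // `?` converts InnerDataFusionError into PyDataFusionError via From,
    // so no explicit map_err is required.
    fn outer_step() -> PyDataFusionResult<u64> {
        Ok(inner_step()?)
    }
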
diff --git a/src/lib.rs b/src/lib.rs
index 1111d5d..317c3a4 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -48,6 +48,7 @@ pub mod expr;
mod functions;
pub mod physical_plan;
mod pyarrow_filter_expression;
+pub mod pyarrow_util;
mod record_batch;
pub mod sql;
pub mod store;
diff --git a/src/physical_plan.rs b/src/physical_plan.rs
index 9ef2f0e..295908d 100644
--- a/src/physical_plan.rs
+++ b/src/physical_plan.rs
@@ -22,7 +22,7 @@ use std::sync::Arc;
use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyBytes};
-use crate::{context::PySessionContext, errors::DataFusionError};
+use crate::{context::PySessionContext, errors::PyDataFusionResult};
#[pyclass(name = "ExecutionPlan", module = "datafusion", subclass)]
#[derive(Debug, Clone)]
@@ -58,7 +58,7 @@ impl PyExecutionPlan {
format!("{}", d.indent(false))
}
- pub fn to_proto<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
+ pub fn to_proto<'py>(&'py self, py: Python<'py>) -> PyDataFusionResult<Bound<'py, PyBytes>> {
let codec = DefaultPhysicalExtensionCodec {};
let proto = datafusion_proto::protobuf::PhysicalPlanNode::try_from_physical_plan(
self.plan.clone(),
@@ -70,7 +70,10 @@ impl PyExecutionPlan {
}
#[staticmethod]
- pub fn from_proto(ctx: PySessionContext, proto_msg: Bound<'_, PyBytes>) -> PyResult<Self> {
+ pub fn from_proto(
+ ctx: PySessionContext,
+ proto_msg: Bound<'_, PyBytes>,
+ ) -> PyDataFusionResult<Self> {
let bytes: &[u8] = proto_msg.extract()?;
let proto_plan = datafusion_proto::protobuf::PhysicalPlanNode::decode(bytes).map_err(|e| {
@@ -81,9 +84,7 @@ impl PyExecutionPlan {
})?;
let codec = DefaultPhysicalExtensionCodec {};
- let plan = proto_plan
- .try_into_physical_plan(&ctx.ctx, &ctx.ctx.runtime_env(), &codec)
- .map_err(DataFusionError::from)?;
+ let plan = proto_plan.try_into_physical_plan(&ctx.ctx, &ctx.ctx.runtime_env(), &codec)?;
Ok(Self::new(plan))
}
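
`to_proto`/`from_proto` now bubble protobuf and planning errors through
PyDataFusionResult. The underlying bytes round trip is plain prost; a hedged
sketch of the decode/encode pairing (the helper name is illustrative):

    use datafusion_proto::protobuf::PhysicalPlanNode;
    use prost::Message;

    // Decode plan bytes and re-encode them; a successful round trip
    // demonstrates the wire format used by to_proto/from_proto.
    fn roundtrip_plan_bytes(bytes: &[u8]) -> Result<Vec<u8>, prost::DecodeError> {
        let node = PhysicalPlanNode::decode(bytes)?;
        Ok(node.encode_to_vec())
    }
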
diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs
index 0f97ea4..314eebf 100644
--- a/src/pyarrow_filter_expression.rs
+++ b/src/pyarrow_filter_expression.rs
@@ -21,11 +21,11 @@ use pyo3::prelude::*;
use std::convert::TryFrom;
use std::result::Result;
-use arrow::pyarrow::ToPyArrow;
use datafusion::common::{Column, ScalarValue};
use datafusion::logical_expr::{expr::InList, Between, BinaryExpr, Expr, Operator};
-use crate::errors::DataFusionError;
+use crate::errors::{PyDataFusionError, PyDataFusionResult};
+use crate::pyarrow_util::scalar_to_pyarrow;
#[derive(Debug)]
#[repr(transparent)]
@@ -34,7 +34,7 @@ pub(crate) struct PyArrowFilterExpression(PyObject);
fn operator_to_py<'py>(
operator: &Operator,
op: &Bound<'py, PyModule>,
-) -> Result<Bound<'py, PyAny>, DataFusionError> {
+) -> PyDataFusionResult<Bound<'py, PyAny>> {
let py_op: Bound<'_, PyAny> = match operator {
Operator::Eq => op.getattr("eq")?,
Operator::NotEq => op.getattr("ne")?,
@@ -45,7 +45,7 @@ fn operator_to_py<'py>(
Operator::And => op.getattr("and_")?,
Operator::Or => op.getattr("or_")?,
_ => {
- return Err(DataFusionError::Common(format!(
+ return Err(PyDataFusionError::Common(format!(
"Unsupported operator {operator:?}"
)))
}
@@ -53,8 +53,8 @@ fn operator_to_py<'py>(
Ok(py_op)
}
-fn extract_scalar_list(exprs: &[Expr], py: Python) -> Result<Vec<PyObject>, DataFusionError> {
- let ret: Result<Vec<PyObject>, DataFusionError> = exprs
+fn extract_scalar_list(exprs: &[Expr], py: Python) -> PyDataFusionResult<Vec<PyObject>> {
+ let ret = exprs
.iter()
.map(|expr| match expr {
// TODO: should we also leverage `ScalarValue::to_pyarrow` here?
@@ -71,11 +71,11 @@ fn extract_scalar_list(exprs: &[Expr], py: Python) -> Result<Vec<PyObject>, Data
ScalarValue::Float32(Some(f)) => Ok(f.into_py(py)),
ScalarValue::Float64(Some(f)) => Ok(f.into_py(py)),
ScalarValue::Utf8(Some(s)) => Ok(s.into_py(py)),
- _ => Err(DataFusionError::Common(format!(
+ _ => Err(PyDataFusionError::Common(format!(
"PyArrow can't handle ScalarValue: {v:?}"
))),
},
- _ => Err(DataFusionError::Common(format!(
+ _ => Err(PyDataFusionError::Common(format!(
"Only a list of Literals are supported got {expr:?}"
))),
})
@@ -90,7 +90,7 @@ impl PyArrowFilterExpression {
}
impl TryFrom<&Expr> for PyArrowFilterExpression {
- type Error = DataFusionError;
+ type Error = PyDataFusionError;
// Converts a Datafusion filter Expr into an expression string that can be evaluated by Python
// Note that pyarrow.compute.{field,scalar} are put into Python globals() when evaluated
@@ -100,9 +100,9 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
Python::with_gil(|py| {
let pc = Python::import_bound(py, "pyarrow.compute")?;
let op_module = Python::import_bound(py, "operator")?;
- let pc_expr: Result<Bound<'_, PyAny>, DataFusionError> = match expr {
+ let pc_expr: PyDataFusionResult<Bound<'_, PyAny>> = match expr {
Expr::Column(Column { name, .. }) => Ok(pc.getattr("field")?.call1((name,))?),
- Expr::Literal(scalar) => Ok(scalar.to_pyarrow(py)?.into_bound(py)),
+ Expr::Literal(scalar) => Ok(scalar_to_pyarrow(scalar, py)?.into_bound(py)),
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let operator = operator_to_py(op, &op_module)?;
let left = PyArrowFilterExpression::try_from(left.as_ref())?.0;
@@ -167,7 +167,7 @@ impl TryFrom<&Expr> for PyArrowFilterExpression {
Ok(if *negated { invert.call1((ret,))? } else { ret })
}
- _ => Err(DataFusionError::Common(format!(
+ _ => Err(PyDataFusionError::Common(format!(
"Unsupported Datafusion expression {expr:?}"
))),
};
diff --git a/src/pyarrow_util.rs b/src/pyarrow_util.rs
new file mode 100644
index 0000000..2b31467
--- /dev/null
+++ b/src/pyarrow_util.rs
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Conversions between PyArrow and DataFusion types
+
+use arrow::array::{Array, ArrayData};
+use arrow::pyarrow::{FromPyArrow, ToPyArrow};
+use datafusion::scalar::ScalarValue;
+use pyo3::types::{PyAnyMethods, PyList};
+use pyo3::{Bound, FromPyObject, PyAny, PyObject, PyResult, Python};
+
+use crate::common::data_type::PyScalarValue;
+use crate::errors::PyDataFusionError;
+
+impl FromPyArrow for PyScalarValue {
+ fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
+ let py = value.py();
+ let typ = value.getattr("type")?;
+ let val = value.call_method0("as_py")?;
+
+ // construct pyarrow array from the python value and pyarrow type
+ let factory = py.import_bound("pyarrow")?.getattr("array")?;
+ let args = PyList::new_bound(py, [val]);
+ let array = factory.call1((args, typ))?;
+
+ // convert the pyarrow array to rust array using C data interface
+ let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?);
+ let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
+
+ Ok(PyScalarValue(scalar))
+ }
+}
+
+impl<'source> FromPyObject<'source> for PyScalarValue {
+ fn extract_bound(value: &Bound<'source, PyAny>) -> PyResult<Self> {
+ Self::from_pyarrow_bound(value)
+ }
+}
+
+pub fn scalar_to_pyarrow(scalar: &ScalarValue, py: Python) -> PyResult<PyObject> {
+ let array = scalar.to_array().map_err(PyDataFusionError::from)?;
+ // convert to pyarrow array using C data interface
+ let pyarray = array.to_data().to_pyarrow(py)?;
+ let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?;
+
+ Ok(pyscalar)
+}
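
The two halves of this new module compose into a ScalarValue-to-pyarrow-and-back
round trip, which is what replaces the removed upstream `pyarrow` feature. A
sketch of the composition (assumes pyarrow is importable in the embedded
interpreter and that this runs inside the crate):

    use arrow::pyarrow::FromPyArrow;
    use datafusion::scalar::ScalarValue;
    use pyo3::{PyResult, Python};

    use crate::common::data_type::PyScalarValue;
    use crate::pyarrow_util::scalar_to_pyarrow;

    // Rust scalar -> pyarrow scalar -> Rust scalar.
    fn roundtrip_scalar(py: Python<'_>) -> PyResult<ScalarValue> {
        let original = ScalarValue::Int64(Some(42));
        let py_scalar = scalar_to_pyarrow(&original, py)?;
        let back = PyScalarValue::from_pyarrow_bound(py_scalar.bind(py))?;
        Ok(back.0)
    }
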
diff --git a/src/record_batch.rs b/src/record_batch.rs
index eacdb58..ec61c26 100644
--- a/src/record_batch.rs
+++ b/src/record_batch.rs
@@ -17,6 +17,7 @@
use std::sync::Arc;
+use crate::errors::PyDataFusionError;
use crate::utils::wait_for_future;
use datafusion::arrow::pyarrow::ToPyArrow;
use datafusion::arrow::record_batch::RecordBatch;
@@ -90,7 +91,7 @@ async fn next_stream(
let mut stream = stream.lock().await;
match stream.next().await {
Some(Ok(batch)) => Ok(batch.into()),
- Some(Err(e)) => Err(e.into()),
+ Some(Err(e)) => Err(PyDataFusionError::from(e))?,
None => {
// Depending on whether the iteration is sync or not, we raise either a
// StopIteration or a StopAsyncIteration
diff --git a/src/sql/exceptions.rs b/src/sql/exceptions.rs
index c458402..cfb0227 100644
--- a/src/sql/exceptions.rs
+++ b/src/sql/exceptions.rs
@@ -17,13 +17,7 @@
use std::fmt::{Debug, Display};
-use pyo3::{create_exception, PyErr};
-
-// Identifies exceptions that occur while attempting to generate a `LogicalPlan` from a SQL string
-create_exception!(rust, ParsingException, pyo3::exceptions::PyException);
-
-// Identifies exceptions that occur during attempts to optimization an existing `LogicalPlan`
-create_exception!(rust, OptimizationException, pyo3::exceptions::PyException);
+use pyo3::PyErr;
pub fn py_type_err(e: impl Debug + Display) -> PyErr {
PyErr::new::<pyo3::exceptions::PyTypeError, _>(format!("{e}"))
@@ -33,10 +27,6 @@ pub fn py_runtime_err(e: impl Debug + Display) -> PyErr {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("{e}"))
}
-pub fn py_parsing_exp(e: impl Debug + Display) -> PyErr {
- PyErr::new::<ParsingException, _>(format!("{e}"))
-}
-
-pub fn py_optimization_exp(e: impl Debug + Display) -> PyErr {
- PyErr::new::<OptimizationException, _>(format!("{e}"))
+pub fn py_value_err(e: impl Debug + Display) -> PyErr {
+ PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("{e}"))
}
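
Callers that previously raised the bespoke ParsingException and
OptimizationException now surface errors through the generic helpers instead;
py_value_err maps anything Display-able onto Python's built-in ValueError. An
illustrative caller (ours, not from the commit):

    use pyo3::PyErr;

    use crate::sql::exceptions::py_value_err;

    // Reject bad input with a Python ValueError via the new helper.
    fn require_non_negative(n: i64) -> Result<i64, PyErr> {
        if n < 0 {
            return Err(py_value_err(format!("expected non-negative value, got {n}")));
        }
        Ok(n)
    }
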
diff --git a/src/sql/logical.rs b/src/sql/logical.rs
index a541889..1be33b7 100644
--- a/src/sql/logical.rs
+++ b/src/sql/logical.rs
@@ -17,6 +17,7 @@
use std::sync::Arc;
+use crate::errors::PyDataFusionResult;
use crate::expr::aggregate::PyAggregate;
use crate::expr::analyze::PyAnalyze;
use crate::expr::distinct::PyDistinct;
@@ -34,7 +35,7 @@ use crate::expr::table_scan::PyTableScan;
use crate::expr::unnest::PyUnnest;
use crate::expr::window::PyWindowExpr;
use crate::{context::PySessionContext, errors::py_unsupported_variant_err};
-use datafusion::{error::DataFusionError, logical_expr::LogicalPlan};
+use datafusion::logical_expr::LogicalPlan;
use datafusion_proto::logical_plan::{AsLogicalPlan, DefaultLogicalExtensionCodec};
use prost::Message;
use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyBytes};
@@ -125,7 +126,7 @@ impl PyLogicalPlan {
format!("{}", self.plan.display_graphviz())
}
- pub fn to_proto<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyBytes>> {
+ pub fn to_proto<'py>(&'py self, py: Python<'py>) -> PyDataFusionResult<Bound<'py, PyBytes>> {
let codec = DefaultLogicalExtensionCodec {};
let proto =
datafusion_proto::protobuf::LogicalPlanNode::try_from_logical_plan(&self.plan, &codec)?;
@@ -135,7 +136,10 @@ impl PyLogicalPlan {
}
#[staticmethod]
- pub fn from_proto(ctx: PySessionContext, proto_msg: Bound<'_, PyBytes>) -> PyResult<Self> {
+ pub fn from_proto(
+ ctx: PySessionContext,
+ proto_msg: Bound<'_, PyBytes>,
+ ) -> PyDataFusionResult<Self> {
let bytes: &[u8] = proto_msg.extract()?;
let proto_plan = datafusion_proto::protobuf::LogicalPlanNode::decode(bytes).map_err(|e| {
@@ -146,9 +150,7 @@ impl PyLogicalPlan {
})?;
let codec = DefaultLogicalExtensionCodec {};
- let plan = proto_plan
- .try_into_logical_plan(&ctx.ctx, &codec)
- .map_err(DataFusionError::from)?;
+ let plan = proto_plan.try_into_logical_plan(&ctx.ctx, &codec)?;
Ok(Self::new(plan))
}
}
diff --git a/src/substrait.rs b/src/substrait.rs
index 16e8c95..8dcf3e8 100644
--- a/src/substrait.rs
+++ b/src/substrait.rs
@@ -18,7 +18,7 @@
use pyo3::{prelude::*, types::PyBytes};
use crate::context::PySessionContext;
-use crate::errors::{py_datafusion_err, DataFusionError};
+use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
use crate::sql::logical::PyLogicalPlan;
use crate::utils::wait_for_future;
@@ -39,7 +39,7 @@ impl PyPlan {
let mut proto_bytes = Vec::<u8>::new();
self.plan
.encode(&mut proto_bytes)
- .map_err(DataFusionError::EncodeError)?;
+ .map_err(PyDataFusionError::EncodeError)?;
Ok(PyBytes::new_bound(py, &proto_bytes).unbind().into())
}
}
@@ -66,41 +66,47 @@ pub struct PySubstraitSerializer;
#[pymethods]
impl PySubstraitSerializer {
#[staticmethod]
- pub fn serialize(sql: &str, ctx: PySessionContext, path: &str, py: Python) -> PyResult<()> {
- wait_for_future(py, serializer::serialize(sql, &ctx.ctx, path))
- .map_err(DataFusionError::from)?;
+ pub fn serialize(
+ sql: &str,
+ ctx: PySessionContext,
+ path: &str,
+ py: Python,
+ ) -> PyDataFusionResult<()> {
+ wait_for_future(py, serializer::serialize(sql, &ctx.ctx, path))?;
Ok(())
}
#[staticmethod]
- pub fn serialize_to_plan(sql: &str, ctx: PySessionContext, py: Python) -> PyResult<PyPlan> {
- match PySubstraitSerializer::serialize_bytes(sql, ctx, py) {
- Ok(proto_bytes) => {
- let proto_bytes = proto_bytes.bind(py).downcast::<PyBytes>().unwrap();
- PySubstraitSerializer::deserialize_bytes(proto_bytes.as_bytes().to_vec(), py)
- }
- Err(e) => Err(py_datafusion_err(e)),
- }
+ pub fn serialize_to_plan(
+ sql: &str,
+ ctx: PySessionContext,
+ py: Python,
+ ) -> PyDataFusionResult<PyPlan> {
+ PySubstraitSerializer::serialize_bytes(sql, ctx, py).and_then(|proto_bytes| {
+ let proto_bytes = proto_bytes.bind(py).downcast::<PyBytes>().unwrap();
+ PySubstraitSerializer::deserialize_bytes(proto_bytes.as_bytes().to_vec(), py)
+ })
}
#[staticmethod]
- pub fn serialize_bytes(sql: &str, ctx: PySessionContext, py: Python) -> PyResult<PyObject> {
- let proto_bytes: Vec<u8> = wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))
- .map_err(DataFusionError::from)?;
+ pub fn serialize_bytes(
+ sql: &str,
+ ctx: PySessionContext,
+ py: Python,
+ ) -> PyDataFusionResult<PyObject> {
+ let proto_bytes: Vec<u8> = wait_for_future(py, serializer::serialize_bytes(sql, &ctx.ctx))?;
Ok(PyBytes::new_bound(py, &proto_bytes).unbind().into())
}
#[staticmethod]
- pub fn deserialize(path: &str, py: Python) -> PyResult<PyPlan> {
- let plan =
- wait_for_future(py, serializer::deserialize(path)).map_err(DataFusionError::from)?;
+ pub fn deserialize(path: &str, py: Python) -> PyDataFusionResult<PyPlan> {
+ let plan = wait_for_future(py, serializer::deserialize(path))?;
Ok(PyPlan { plan: *plan })
}
#[staticmethod]
- pub fn deserialize_bytes(proto_bytes: Vec<u8>, py: Python) -> PyResult<PyPlan> {
- let plan = wait_for_future(py, serializer::deserialize_bytes(proto_bytes))
- .map_err(DataFusionError::from)?;
+ pub fn deserialize_bytes(proto_bytes: Vec<u8>, py: Python) -> PyDataFusionResult<PyPlan> {
+ let plan = wait_for_future(py, serializer::deserialize_bytes(proto_bytes))?;
Ok(PyPlan { plan: *plan })
}
}
@@ -134,10 +140,10 @@ impl PySubstraitConsumer {
ctx: &mut PySessionContext,
plan: PyPlan,
py: Python,
- ) -> PyResult<PyLogicalPlan> {
+ ) -> PyDataFusionResult<PyLogicalPlan> {
let session_state = ctx.ctx.state();
let result = consumer::from_substrait_plan(&session_state, &plan.plan);
- let logical_plan = wait_for_future(py, result).map_err(DataFusionError::from)?;
+ let logical_plan = wait_for_future(py, result)?;
Ok(PyLogicalPlan::new(logical_plan))
}
}
diff --git a/src/udaf.rs b/src/udaf.rs
index a6aa59a..5f21533 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -28,6 +28,7 @@ use datafusion::logical_expr::{
create_udaf, Accumulator, AccumulatorFactoryFunction, AggregateUDF,
};
+use crate::common::data_type::PyScalarValue;
use crate::expr::PyExpr;
use crate::utils::parse_volatility;
@@ -44,13 +45,25 @@ impl RustAccumulator {
impl Accumulator for RustAccumulator {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
- Python::with_gil(|py| self.accum.bind(py).call_method0("state")?.extract())
- .map_err(|e| DataFusionError::Execution(format!("{e}")))
+ Python::with_gil(|py| {
+ self.accum
+ .bind(py)
+ .call_method0("state")?
+ .extract::<Vec<PyScalarValue>>()
+ })
+ .map(|v| v.into_iter().map(|x| x.0).collect())
+ .map_err(|e| DataFusionError::Execution(format!("{e}")))
}
fn evaluate(&mut self) -> Result<ScalarValue> {
- Python::with_gil(|py| self.accum.bind(py).call_method0("evaluate")?.extract())
- .map_err(|e| DataFusionError::Execution(format!("{e}")))
+ Python::with_gil(|py| {
+ self.accum
+ .bind(py)
+ .call_method0("evaluate")?
+ .extract::<PyScalarValue>()
+ })
+ .map(|v| v.0)
+ .map_err(|e| DataFusionError::Execution(format!("{e}")))
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
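
Both accumulator methods above follow the same recipe: extract the Python
result into the PyScalarValue wrapper, then peel the wrapper off. The
unwrapping step in isolation (the helper name is ours):

    use datafusion::scalar::ScalarValue;

    use crate::common::data_type::PyScalarValue;

    // Strip the newtype from a batch of values extracted from Python,
    // as the state() hunk above does with `v.into_iter().map(|x| x.0)`.
    fn unwrap_scalars(values: Vec<PyScalarValue>) -> Vec<ScalarValue> {
        values.into_iter().map(|v| v.0).collect()
    }
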
diff --git a/src/udwf.rs b/src/udwf.rs
index 689eb79..04a4a16 100644
--- a/src/udwf.rs
+++ b/src/udwf.rs
@@ -26,6 +26,7 @@ use datafusion::scalar::ScalarValue;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
+use crate::common::data_type::PyScalarValue;
use crate::expr::PyExpr;
use crate::utils::parse_volatility;
use datafusion::arrow::datatypes::DataType;
@@ -133,7 +134,8 @@ impl PartitionEvaluator for RustPartitionEvaluator {
self.evaluator
.bind(py)
.call_method1("evaluate", py_args)
- .and_then(|v| v.extract())
+ .and_then(|v| v.extract::<PyScalarValue>())
+ .map(|v| v.0)
.map_err(|e| DataFusionError::Execution(format!("{e}")))
})
}
diff --git a/src/utils.rs b/src/utils.rs
index 7955897..ed224b3 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-use crate::errors::DataFusionError;
+use crate::errors::{PyDataFusionError, PyDataFusionResult};
use crate::TokioRuntime;
use datafusion::logical_expr::Volatility;
use pyo3::exceptions::PyValueError;
@@ -47,13 +47,13 @@ where
py.allow_threads(|| runtime.block_on(f))
}
-pub(crate) fn parse_volatility(value: &str) -> Result<Volatility, DataFusionError> {
+pub(crate) fn parse_volatility(value: &str) -> PyDataFusionResult<Volatility> {
Ok(match value {
"immutable" => Volatility::Immutable,
"stable" => Volatility::Stable,
"volatile" => Volatility::Volatile,
value => {
- return Err(DataFusionError::Common(format!(
+ return Err(PyDataFusionError::Common(format!(
"Unsupportad volatility type: `{value}`, supported \
values are: immutable, stable and volatile."
)))
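
For completeness, the rewritten parse_volatility is called like any other
PyDataFusionResult-returning helper; an in-crate sketch:

    use datafusion::logical_expr::Volatility;

    use crate::errors::PyDataFusionResult;
    use crate::utils::parse_volatility;

    // "stable" parses to Volatility::Stable; unknown strings yield
    // PyDataFusionError::Common with the message shown above.
    fn volatility_example() -> PyDataFusionResult<Volatility> {
        parse_volatility("stable")
    }
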
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]