This is an automated email from the ASF dual-hosted git repository.
amolina pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/main by this push:
new e0ab40d379 GH-41692: [Python] Improve substrait extended expressions support (#41693)
e0ab40d379 is described below
commit e0ab40d3793b1f795ff8cdfa61b96c727bd88df0
Author: Alessandro Molina <[email protected]>
AuthorDate: Wed Oct 16 14:58:37 2024 +0200
GH-41692: [Python] Improve substrait extended expressions support (#41693)
Addresses some missing features and usability issues when using PyArrow
with Substrait ExtendedExpressions.
* GitHub Issue: #41692
- [x] Allow passing `BoundExpressions` for `Scanner(columns=X)` instead
of a dict of expressions.
- [x] Allow passing `BoundExpressions` for `Scanner(filter=X)` so that
users don't have to distinguish between `Expression` and
`BoundExpressions` and can always just use
`pyarrow.substrait.deserialize_expressions`
- [x] Allow decoding `pyarrow.substrait.BoundExpressions` directly from
`protobuf.Message`, thus allowing substrait-python objects to be used.
- [x] Return `memoryview` from methods that encode Substrait, so that the
results can be passed directly to substrait-python (or, more generally,
other Python libraries) without a copy being involved.
- [x] Allow decoding messages from `memoryview` so that the output of
encoding functions can be passed back to decoding functions.
- [x] Allow encoding and decoding schemas to and from Substrait
- [x] When encoding schemas, return the extension types required for a
Substrait consumer to decode the schema
- [x] Handle arrow extension types when decoding a schema
- [x] Update docstrings and documentation
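The zero-copy point above relies on Python's buffer protocol: a `memoryview` exposes the memory of the object it wraps, while `bytes(...)` makes an independent copy. The stdlib-only sketch below illustrates that behaviour; a `bytearray` stands in for the Arrow buffer that the encoding functions wrap, which is an assumption made so the example runs without PyArrow.

```python
# A bytearray stands in (assumption) for the Arrow buffer holding serialized Substrait data.
backing = bytearray(b"\n\x01x\n\x01y")

# A memoryview exposes the same memory without copying it.
view = memoryview(backing)

# bytes() materializes an independent copy of the data.
copied = bytes(view)

# Mutating the backing store is visible through the view, but not in the copy.
backing[2] = ord("z")

print(view.tobytes())  # b'\n\x01z\n\x01y'
print(copied)          # b'\n\x01x\n\x01y'
```

This is also why code that needs real `bytes` (for example to print the payload) calls `bytes(...)` on the returned `memoryview` explicitly.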
---------
Co-authored-by: Raúl Cumplido <[email protected]>
---
cpp/src/arrow/engine/substrait/serde.cc | 2 +-
docs/source/python/api/substrait.rst | 3 +
docs/source/python/integration.rst | 1 +
docs/source/python/integration/substrait.rst | 249 +++++++++++++++++++++++++
python/pyarrow/_compute.pyx | 6 +-
python/pyarrow/_dataset.pyx | 19 +-
python/pyarrow/_substrait.pyx | 138 +++++++++++++-
python/pyarrow/includes/common.pxd | 9 +
python/pyarrow/includes/libarrow_substrait.pxd | 23 +++
python/pyarrow/substrait.py | 5 +-
python/pyarrow/tests/test_dataset.py | 34 ++++
python/pyarrow/tests/test_substrait.py | 43 ++++-
12 files changed, 521 insertions(+), 11 deletions(-)
diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc
index 16d2ace4ac..6b4c05a3b1 100644
--- a/cpp/src/arrow/engine/substrait/serde.cc
+++ b/cpp/src/arrow/engine/substrait/serde.cc
@@ -56,7 +56,7 @@ Status ParseFromBufferImpl(const Buffer& buf, const std::string& full_name,
if (message->ParseFromZeroCopyStream(&buf_stream)) {
return Status::OK();
}
- return Status::IOError("ParseFromZeroCopyStream failed for ", full_name);
+ return Status::Invalid("ParseFromZeroCopyStream failed for ", full_name);
}
template <typename Message>
diff --git a/docs/source/python/api/substrait.rst b/docs/source/python/api/substrait.rst
index 1556be9dbd..26c70216a8 100644
--- a/docs/source/python/api/substrait.rst
+++ b/docs/source/python/api/substrait.rst
@@ -43,6 +43,9 @@ compute expressions.
BoundExpressions
deserialize_expressions
serialize_expressions
+ serialize_schema
+ deserialize_schema
+ SubstraitSchema
Utility
-------
diff --git a/docs/source/python/integration.rst b/docs/source/python/integration.rst
index 1cafc3dbde..95c912c187 100644
--- a/docs/source/python/integration.rst
+++ b/docs/source/python/integration.rst
@@ -34,6 +34,7 @@ This allows to easily integrate PyArrow with other languages and technologies.
.. toctree::
:maxdepth: 2
+ integration/substrait
integration/python_r
integration/python_java
integration/extending
diff --git a/docs/source/python/integration/substrait.rst b/docs/source/python/integration/substrait.rst
new file mode 100644
index 0000000000..eaa6151e4d
--- /dev/null
+++ b/docs/source/python/integration/substrait.rst
@@ -0,0 +1,249 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+.. or more contributor license agreements. See the NOTICE file
+.. distributed with this work for additional information
+.. regarding copyright ownership. The ASF licenses this file
+.. to you under the Apache License, Version 2.0 (the
+.. "License"); you may not use this file except in compliance
+.. with the License. You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+.. software distributed under the License is distributed on an
+.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+.. KIND, either express or implied. See the License for the
+.. specific language governing permissions and limitations
+.. under the License.
+
+=========
+Substrait
+=========
+
+The ``arrow-substrait`` module implements support for the Substrait_ format,
+enabling conversion to and from Arrow objects.
+
+The ``arrow-dataset`` module can execute Substrait_ plans via the
+:doc:`Acero <../cpp/streaming_execution>` query engine.
+
+.. contents::
+
+Working with Schemas
+====================
+
+Arrow schemas can be encoded and decoded using the :meth:`pyarrow.substrait.serialize_schema` and
+:meth:`pyarrow.substrait.deserialize_schema` functions.
+
+.. code-block:: python
+
+ import pyarrow as pa
+ import pyarrow.substrait as pa_substrait
+
+ arrow_schema = pa.schema([
+ pa.field("x", pa.int32()),
+ pa.field("y", pa.string())
+ ])
+ substrait_schema = pa_substrait.serialize_schema(arrow_schema)
+
+The schema marshalled as a Substrait ``NamedStruct`` is directly
+available as ``substrait_schema.schema`` (a ``memoryview``)::
+
+ >>> print(bytes(substrait_schema.schema))
+ b'\n\x01x\n\x01y\x12\x0c\n\x04*\x02\x10\x01\n\x04b\x02\x10\x01'
+
+If Arrow extension types are used, the schema requires extensions for
+those types to be usable. For this reason, the schema is also available
+as an `Extended Expression`_ that includes all the required extension types::
+
+ >>> print(bytes(substrait_schema.expression))
+ b'"\x14\n\x01x\n\x01y\x12\x0c\n\x04*\x02\x10\x01\n\x04b\x02\x10\x01:\x19\x10,*\x15Acero 17.0.0'
+
+If `Substrait Python`_ is installed, the schema can also be converted to
+a ``substrait-python`` object::
+
+ >>> print(substrait_schema.to_pysubstrait())
+ version {
+ minor_number: 44
+ producer: "Acero 17.0.0"
+ }
+ base_schema {
+ names: "x"
+ names: "y"
+ struct {
+ types {
+ i32 {
+ nullability: NULLABILITY_NULLABLE
+ }
+ }
+ types {
+ string {
+ nullability: NULLABILITY_NULLABLE
+ }
+ }
+ }
+ }
+
+Working with Expressions
+========================
+
+Arrow compute expressions can be encoded and decoded using the
+:meth:`pyarrow.substrait.serialize_expressions` and
+:meth:`pyarrow.substrait.deserialize_expressions` functions.
+
+.. code-block:: python
+
+ import pyarrow as pa
+ import pyarrow.compute as pc
+ import pyarrow.substrait as pa_substrait
+
+ arrow_schema = pa.schema([
+ pa.field("x", pa.int32()),
+ pa.field("y", pa.int32())
+ ])
+
+ substrait_expr = pa_substrait.serialize_expressions(
+ exprs=[pc.field("x") + pc.field("y")],
+ names=["total"],
+ schema=arrow_schema
+ )
+
+Encoding an expression to Substrait returns the protobuf ``ExtendedExpression``
+message data itself, as a ``memoryview``::
+
+ >>> print(bytes(substrait_expr))
+ b'\nZ\x12Xhttps://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml\x12\x07\x1a\x05\x1a\x03add\x1a>\n5\x1a3\x1a\x04*\x02\x10\x01"\n\x1a\x08\x12\x06\n\x02\x12\x00"\x00"\x0c\x1a\n\x12\x08\n\x04\x12\x02\x08\x01"\x00*\x11\n\x08overflow\x12\x05ERROR\x1a\x05total"\x14\n\x01x\n\x01y\x12\x0c\n\x04*\x02\x10\x01\n\x04*\x02\x10\x01:\x19\x10,*\x15Acero 17.0.0'
+
+If a `Substrait Python`_ object is required, the expression has to be
+decoded with ``substrait-python`` itself::
+
+ >>> import substrait
+ >>> pysubstrait_expr = substrait.proto.ExtendedExpression.FromString(substrait_expr)
+ >>> print(pysubstrait_expr)
+ version {
+ minor_number: 44
+ producer: "Acero 17.0.0"
+ }
+ extension_uris {
+ uri: "https://github.com/substrait-io/substrait/blob/main/extensions/functions_arithmetic.yaml"
+ }
+ extensions {
+ extension_function {
+ name: "add"
+ }
+ }
+ referred_expr {
+ expression {
+ scalar_function {
+ arguments {
+ value {
+ selection {
+ direct_reference {
+ struct_field {
+ }
+ }
+ root_reference {
+ }
+ }
+ }
+ }
+ arguments {
+ value {
+ selection {
+ direct_reference {
+ struct_field {
+ field: 1
+ }
+ }
+ root_reference {
+ }
+ }
+ }
+ }
+ options {
+ name: "overflow"
+ preference: "ERROR"
+ }
+ output_type {
+ i32 {
+ nullability: NULLABILITY_NULLABLE
+ }
+ }
+ }
+ }
+ output_names: "total"
+ }
+ base_schema {
+ names: "x"
+ names: "y"
+ struct {
+ types {
+ i32 {
+ nullability: NULLABILITY_NULLABLE
+ }
+ }
+ types {
+ i32 {
+ nullability: NULLABILITY_NULLABLE
+ }
+ }
+ }
+ }
+
+Executing Queries Using Substrait Extended Expressions
+======================================================
+
+Datasets support executing queries using Substrait's `Extended Expression`_:
+the expressions can be passed to the dataset scanner in the form of
+:class:`pyarrow.substrait.BoundExpressions` objects.
+
+.. code-block:: python
+
+ import pyarrow.dataset as ds
+ import pyarrow.substrait as pa_substrait
+
+ # Use substrait-python to create the queries
+ from substrait import proto
+
+ dataset = ds.dataset("./data/index-0.parquet")
+ substrait_schema = pa_substrait.serialize_schema(dataset.schema).to_pysubstrait()
+
+ # SELECT project_name FROM dataset WHERE project_name = 'pyarrow'
+
+ projection = proto.ExtendedExpression(referred_expr=[
+ {"expression": {"selection": {"direct_reference": {"struct_field": {"field": 0}}}},
+ "output_names": ["project_name"]}
+ ])
+ projection.MergeFrom(substrait_schema)
+
+ filtering = proto.ExtendedExpression(
+ extension_uris=[{"extension_uri_anchor": 99, "uri": "/functions_comparison.yaml"}],
+ extensions=[{"extension_function": {"extension_uri_reference": 99, "function_anchor": 199, "name": "equal:any1_any1"}}],
+ referred_expr=[
+ {"expression": {"scalar_function": {"function_reference": 199, "arguments": [
+ {"value": {"selection": {"direct_reference": {"struct_field": {"field": 0}}}}},
+ {"value": {"literal": {"string": "pyarrow"}}}
+ ], "output_type": {"bool": {"nullability": False}}}}}
+ ]
+ )
+ filtering.MergeFrom(substrait_schema)
+
+ results = dataset.scanner(
+ columns=pa_substrait.BoundExpressions.from_substrait(projection),
+ filter=pa_substrait.BoundExpressions.from_substrait(filtering)
+ ).head(5)
+ print(results.to_pandas())
+
+
+.. code-block:: text
+
+ project_name
+ 0 pyarrow
+ 1 pyarrow
+ 2 pyarrow
+ 3 pyarrow
+ 4 pyarrow
+
+
+.. _`Substrait`: https://substrait.io/
+.. _`Substrait Python`: https://github.com/substrait-io/substrait-python
+.. _`Acero`: https://arrow.apache.org/docs/cpp/streaming_execution.html
+.. _`Extended Expression`: https://github.com/substrait-io/substrait/blob/main/site/docs/expressions/extended_expression.md
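The ``NamedStruct`` bytes shown in the new Substrait docs (and in ``test_serializing_schema``) can be inspected with nothing more than the protobuf wire format. The sketch below is a minimal stdlib-only decoder, not part of PyArrow; it assumes the field numbers from Substrait's protobuf definitions (``NamedStruct.names`` = 1, ``NamedStruct.struct`` = 2, ``Type.Struct.types`` = 1) and only handles the varint and length-delimited wire types.

```python
def read_varint(buf, pos):
    """Decode a protobuf base-128 varint starting at pos; return (value, new_pos)."""
    result = 0
    shift = 0
    while True:
        byte = buf[pos]
        pos += 1
        result |= (byte & 0x7F) << shift
        if not byte & 0x80:  # high bit clear: this was the last byte
            return result, pos
        shift += 7


def iter_fields(buf):
    """Yield (field_number, wire_type, value) for each top-level field.

    Only varint (0) and length-delimited (2) wire types are handled,
    which is all the NamedStruct example needs.
    """
    pos = 0
    while pos < len(buf):
        key, pos = read_varint(buf, pos)
        field_number, wire_type = key >> 3, key & 0x07
        if wire_type == 0:
            value, pos = read_varint(buf, pos)
        elif wire_type == 2:
            length, pos = read_varint(buf, pos)
            value = bytes(buf[pos:pos + length])
            pos += length
        else:
            raise ValueError(f"unsupported wire type {wire_type}")
        yield field_number, wire_type, value


def decode_named_struct(buf):
    """Return (names, type_field_numbers) from a serialized NamedStruct."""
    names, type_fields = [], []
    for field_number, _, value in iter_fields(buf):
        if field_number == 1:  # NamedStruct.names (repeated string)
            names.append(value.decode("utf-8"))
        elif field_number == 2:  # NamedStruct.struct (Type.Struct)
            for struct_field, _, type_msg in iter_fields(value):
                if struct_field == 1:  # Type.Struct.types (repeated Type)
                    # Type is a oneof, so its single field number identifies the kind.
                    kind, _, _ = next(iter_fields(type_msg))
                    type_fields.append(kind)
    return names, type_fields


schema_bytes = b'\n\x01x\n\x01y\x12\x0c\n\x04*\x02\x10\x01\n\x04b\x02\x10\x01'
print(decode_named_struct(memoryview(schema_bytes)))  # (['x', 'y'], [5, 12])
```

Under the same assumption, field numbers 5 and 12 are the ``i32`` and ``string`` variants of Substrait's ``Type`` oneof, which matches the ``int32``/``string`` Arrow schema used in the docs example. Passing a ``memoryview`` works because indexing it yields ints and slicing it is converted with ``bytes()``.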
diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx
index d39120934d..658f6b6cac 100644
--- a/python/pyarrow/_compute.pyx
+++ b/python/pyarrow/_compute.pyx
@@ -2441,7 +2441,7 @@ cdef class Expression(_Weakrefable):
)
@staticmethod
- def from_substrait(object buffer not None):
+ def from_substrait(object message not None):
"""
Deserialize an expression from Substrait
@@ -2453,7 +2453,7 @@ cdef class Expression(_Weakrefable):
Parameters
----------
- buffer : bytes or Buffer
+ message : bytes or Buffer or a protobuf Message
The Substrait message to deserialize
Returns
@@ -2461,7 +2461,7 @@ cdef class Expression(_Weakrefable):
Expression
The deserialized expression
"""
- expressions = _pas().deserialize_expressions(buffer).expressions
+ expressions = _pas().BoundExpressions.from_substrait(message).expressions
if len(expressions) == 0:
raise ValueError("Substrait message did not contain any expressions")
if len(expressions) > 1:
diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx
index 6b5259f499..39e3f4d665 100644
--- a/python/pyarrow/_dataset.pyx
+++ b/python/pyarrow/_dataset.pyx
@@ -39,6 +39,11 @@ from pyarrow.util import _is_iterable, _is_path_like, _stringify_path
from pyarrow._json cimport ParseOptions as JsonParseOptions
from pyarrow._json cimport ReadOptions as JsonReadOptions
+try:
+ import pyarrow.substrait as pa_substrait
+except ImportError:
+ pa_substrait = None
+
_DEFAULT_BATCH_SIZE = 2**17
_DEFAULT_BATCH_READAHEAD = 16
@@ -272,6 +277,13 @@ cdef class Dataset(_Weakrefable):
# at the moment only support filter
requested_filter = options.get("filter")
+ if pa_substrait and isinstance(requested_filter, pa_substrait.BoundExpressions):
+ expressions = list(requested_filter.expressions.values())
+ if len(expressions) != 1:
+ raise ValueError(
+ "Only one BoundExpressions with a single expression are supported")
+ new_options["filter"] = requested_filter = expressions[0]
+
current_filter = self._scan_options.get("filter")
if requested_filter is not None and current_filter is not None:
new_options["filter"] = current_filter & requested_filter
@@ -282,7 +294,7 @@ cdef class Dataset(_Weakrefable):
def scanner(self,
object columns=None,
- Expression filter=None,
+ object filter=None,
int batch_size=_DEFAULT_BATCH_SIZE,
int batch_readahead=_DEFAULT_BATCH_READAHEAD,
int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD,
@@ -3447,6 +3459,9 @@ cdef void _populate_builder(const shared_ptr[CScannerBuilder]& ptr,
filter, pyarrow_wrap_schema(builder.schema()))))
if columns is not None:
+ if pa_substrait and isinstance(columns, pa_substrait.BoundExpressions):
+ columns = columns.expressions
+
if isinstance(columns, dict):
for expr in columns.values():
if not isinstance(expr, Expression):
@@ -3527,7 +3542,7 @@ cdef class Scanner(_Weakrefable):
@staticmethod
def from_dataset(Dataset dataset not None, *,
object columns=None,
- Expression filter=None,
+ object filter=None,
int batch_size=_DEFAULT_BATCH_SIZE,
int batch_readahead=_DEFAULT_BATCH_READAHEAD,
int fragment_readahead=_DEFAULT_FRAGMENT_READAHEAD,
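The filter and projection handling added to ``_dataset.pyx`` above boils down to unwrapping a ``BoundExpressions`` object. The following is a plain-Python sketch of that logic; ``FakeBoundExpressions``, ``normalize_filter`` and ``normalize_columns`` are hypothetical names, the stand-in class only models the ``expressions`` mapping, and strings stand in for ``pyarrow.compute.Expression`` objects so it runs without PyArrow.

```python
class FakeBoundExpressions:
    """Hypothetical stand-in for pyarrow.substrait.BoundExpressions.

    Only the ``expressions`` mapping (output name -> expression) is modelled.
    """

    def __init__(self, expressions):
        self.expressions = dict(expressions)


def normalize_filter(requested_filter):
    """Unwrap a single-expression BoundExpressions into one filter expression."""
    if isinstance(requested_filter, FakeBoundExpressions):
        expressions = list(requested_filter.expressions.values())
        if len(expressions) != 1:
            raise ValueError(
                "Only one BoundExpressions with a single expression are supported")
        return expressions[0]
    # Anything else (a plain expression or None) passes through unchanged.
    return requested_filter


def normalize_columns(columns):
    """Turn a BoundExpressions projection into its name -> expression dict."""
    if isinstance(columns, FakeBoundExpressions):
        return columns.expressions
    return columns


print(normalize_filter(FakeBoundExpressions({"expr": "i64 == 4"})))  # i64 == 4
print(normalize_columns(FakeBoundExpressions({"str": "field(str)"})))  # {'str': 'field(str)'}
```

Reducing a projection to the existing ``dict`` form means the rest of ``_populate_builder`` needs no Substrait-specific code path.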
diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx
index 067cb5f916..d9359c8e77 100644
--- a/python/pyarrow/_substrait.pyx
+++ b/python/pyarrow/_substrait.pyx
@@ -26,6 +26,13 @@ from pyarrow.lib cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_substrait cimport *
+try:
+ import substrait as py_substrait
+except ImportError:
+ py_substrait = None
+else:
+ import substrait.proto # no-cython-lint
+
# TODO GH-37235: Fix exception handling
cdef CDeclaration _create_named_table_provider(
@@ -133,7 +140,7 @@ def run_query(plan, *, table_provider=None, use_threads=True):
c_bool c_use_threads
c_use_threads = use_threads
- if isinstance(plan, bytes):
+ if isinstance(plan, (bytes, memoryview)):
c_buf_plan = pyarrow_unwrap_buffer(py_buffer(plan))
elif isinstance(plan, Buffer):
c_buf_plan = pyarrow_unwrap_buffer(plan)
@@ -187,6 +194,105 @@ def _parse_json_plan(plan):
return pyarrow_wrap_buffer(c_buf_plan)
+class SubstraitSchema:
+ """A Schema encoded for Substrait usage.
+
+ The SubstraitSchema contains a schema represented
+ both as a substrait ``NamedStruct`` and as an
+ ``ExtendedExpression``.
+
+ The ``ExtendedExpression`` is available for cases where types
+ used by the schema require extensions to decode them.
+ In such cases the schema is the ``base_schema`` of the
+ ``ExtendedExpression`` and all extensions will be provided.
+ """
+
+ def __init__(self, schema, expression):
+ self.schema = schema
+ self.expression = expression
+
+ def to_pysubstrait(self):
+ """Convert the schema to a substrait-python ExtendedExpression object."""
+ if py_substrait is None:
+ raise ImportError("The 'substrait' package is required.")
+ return py_substrait.proto.ExtendedExpression.FromString(self.expression)
+
+
+def serialize_schema(schema):
+ """
+ Serialize a schema into a SubstraitSchema object.
+
+ Parameters
+ ----------
+ schema : Schema
+ The schema to serialize
+
+ Returns
+ -------
+ SubstraitSchema
+ The schema stored in a SubstraitSchema object.
+ """
+ return SubstraitSchema(
+ schema=_serialize_namedstruct_schema(schema),
+ expression=serialize_expressions([], [], schema, allow_arrow_extensions=True)
+ )
+
+
+def _serialize_namedstruct_schema(schema):
+ cdef:
+ CResult[shared_ptr[CBuffer]] c_res_buffer
+ shared_ptr[CBuffer] c_buffer
+ CConversionOptions c_conversion_options
+ CExtensionSet c_extensions
+
+ with nogil:
+ c_res_buffer = SerializeSchema(deref((<Schema> schema).sp_schema), &c_extensions, c_conversion_options)
+ c_buffer = GetResultValue(c_res_buffer)
+
+ return memoryview(pyarrow_wrap_buffer(c_buffer))
+
+
+def deserialize_schema(buf):
+ """
+ Deserialize a ``NamedStruct`` Substrait message
+ or a SubstraitSchema object into an Arrow Schema object
+
+ Parameters
+ ----------
+ buf : Buffer or bytes or SubstraitSchema
+ The message to deserialize
+
+ Returns
+ -------
+ Schema
+ The deserialized schema
+ """
+ cdef:
+ shared_ptr[CBuffer] c_buffer
+ CResult[shared_ptr[CSchema]] c_res_schema
+ shared_ptr[CSchema] c_schema
+ CConversionOptions c_conversion_options
+ CExtensionSet c_extensions
+
+ if isinstance(buf, SubstraitSchema):
+ return deserialize_expressions(buf.expression).schema
+
+ if isinstance(buf, (bytes, memoryview)):
+ c_buffer = pyarrow_unwrap_buffer(py_buffer(buf))
+ elif isinstance(buf, Buffer):
+ c_buffer = pyarrow_unwrap_buffer(buf)
+ else:
+ raise TypeError(
+ f"Expected 'pyarrow.Buffer' or bytes, got '{type(buf)}'")
+
+ with nogil:
+ c_res_schema = DeserializeSchema(
+ deref(c_buffer), c_extensions, c_conversion_options)
+ c_schema = GetResultValue(c_res_schema)
+
+ return pyarrow_wrap_schema(c_schema)
+
+
def serialize_expressions(exprs, names, schema, *, allow_arrow_extensions=False):
"""
Serialize a collection of expressions into Substrait
@@ -245,7 +351,7 @@ def serialize_expressions(exprs, names, schema, *, allow_arrow_extensions=False)
with nogil:
c_res_buffer = SerializeExpressions(c_bound_exprs,
c_conversion_options)
c_buffer = GetResultValue(c_res_buffer)
- return pyarrow_wrap_buffer(c_buffer)
+ return memoryview(pyarrow_wrap_buffer(c_buffer))
cdef class BoundExpressions(_Weakrefable):
@@ -290,6 +396,32 @@ cdef class BoundExpressions(_Weakrefable):
self.init(bound_expressions)
return self
+ @classmethod
+ def from_substrait(cls, message):
+ """
+ Convert a Substrait message into a BoundExpressions object
+
+ Parameters
+ ----------
+ message : Buffer or bytes or protobuf Message
+ The message to convert to a BoundExpressions object
+
+ Returns
+ -------
+ BoundExpressions
+ The converted expressions, their names, and the bound schema
+ """
+ if isinstance(message, (bytes, memoryview)):
+ return deserialize_expressions(message)
+ elif isinstance(message, Buffer):
+ return deserialize_expressions(message)
+ else:
+ try:
+ return deserialize_expressions(message.SerializeToString())
+ except AttributeError:
+ raise TypeError(
+ f"Expected 'pyarrow.Buffer' or bytes or protobuf Message, got '{type(message)}'")
+
def deserialize_expressions(buf):
"""
@@ -310,7 +442,7 @@ def deserialize_expressions(buf):
CResult[CBoundExpressions] c_res_bound_exprs
CBoundExpressions c_bound_exprs
- if isinstance(buf, bytes):
+ if isinstance(buf, (bytes, memoryview)):
c_buffer = pyarrow_unwrap_buffer(py_buffer(buf))
elif isinstance(buf, Buffer):
c_buffer = pyarrow_unwrap_buffer(buf)
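``BoundExpressions.from_substrait`` accepts bytes-like input directly and otherwise duck-types its argument as a protobuf ``Message`` by calling ``SerializeToString()``. A minimal stdlib-only sketch of that dispatch follows; ``to_substrait_bytes`` and ``FakeMessage`` are hypothetical names (``FakeMessage`` mirrors the one in ``test_bound_expression_from_Message``), and the ``pyarrow.Buffer`` branch is omitted so the example runs without PyArrow.

```python
def to_substrait_bytes(message):
    """Return a bytes-like Substrait payload for ``message``.

    Bytes-like inputs are returned unchanged; any other object is treated
    as a protobuf Message and serialized with SerializeToString().
    """
    if isinstance(message, (bytes, memoryview)):
        return message
    try:
        serialize = message.SerializeToString
    except AttributeError:
        raise TypeError(
            f"Expected bytes or protobuf Message, got '{type(message)}'"
        ) from None
    return serialize()


class FakeMessage:
    """Hypothetical stand-in exposing the protobuf SerializeToString() method."""

    def __init__(self, payload):
        self.payload = payload

    def SerializeToString(self):
        return self.payload


print(to_substrait_bytes(FakeMessage(b"\x1a\x00")))  # b'\x1a\x00'
```

Unlike the code in the diff, the sketch wraps only the attribute lookup in ``try``, so an ``AttributeError`` raised while serializing or deserializing is not misreported as a ``TypeError``.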
diff --git a/python/pyarrow/includes/common.pxd b/python/pyarrow/includes/common.pxd
index 044dd0333f..9297436c1c 100644
--- a/python/pyarrow/includes/common.pxd
+++ b/python/pyarrow/includes/common.pxd
@@ -173,3 +173,12 @@ cdef inline object PyObject_to_object(PyObject* o):
cdef object result = <object> o
cpython.Py_DECREF(result)
return result
+
+
+cdef extern from "<string_view>" namespace "std" nogil:
+ cdef cppclass cpp_string_view "std::string_view":
+ string_view()
+ string_view(const char*)
+ size_t size()
+ bint empty()
+ const char* data()
diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd
index c41f4c05d3..865568e2ba 100644
--- a/python/pyarrow/includes/libarrow_substrait.pxd
+++ b/python/pyarrow/includes/libarrow_substrait.pxd
@@ -45,6 +45,20 @@ cdef extern from "arrow/engine/substrait/options.h" namespace "arrow::engine" no
cdef extern from "arrow/engine/substrait/extension_set.h" \
namespace "arrow::engine" nogil:
+ cdef struct CSubstraitId "arrow::engine::Id":
+ cpp_string_view uri
+ cpp_string_view name
+
+ cdef struct CExtensionSetTypeRecord "arrow::engine::ExtensionSet::TypeRecord":
+ CSubstraitId id
+ shared_ptr[CDataType] type
+
+ cdef cppclass CExtensionSet "arrow::engine::ExtensionSet":
+ CExtensionSet()
+ unordered_map[uint32_t, cpp_string_view]& uris()
+ CResult[uint32_t] EncodeType(const CDataType&)
+ CResult[CExtensionSetTypeRecord] DecodeType(uint32_t)
+
cdef cppclass ExtensionIdRegistry:
std_vector[c_string] GetSupportedSubstraitFunctions()
@@ -68,6 +82,15 @@ cdef extern from "arrow/engine/substrait/serde.h" namespace "arrow::engine" nogi
CResult[CBoundExpressions] DeserializeExpressions(
const CBuffer& serialized_expressions)
+ CResult[shared_ptr[CBuffer]] SerializeSchema(
+ const CSchema &schema, CExtensionSet* extension_set,
+ const CConversionOptions& conversion_options)
+
+ CResult[shared_ptr[CSchema]] DeserializeSchema(
+ const CBuffer& serialized_schema, const CExtensionSet& extension_set,
+ const CConversionOptions& conversion_options)
+
+
cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" nogil:
CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(
const CBuffer& substrait_buffer, const ExtensionIdRegistry* registry,
diff --git a/python/pyarrow/substrait.py b/python/pyarrow/substrait.py
index a2b217f493..db2c3a96a1 100644
--- a/python/pyarrow/substrait.py
+++ b/python/pyarrow/substrait.py
@@ -21,7 +21,10 @@ try:
get_supported_functions,
run_query,
deserialize_expressions,
- serialize_expressions
+ serialize_expressions,
+ deserialize_schema,
+ serialize_schema,
+ SubstraitSchema
)
except ImportError as exc:
raise ImportError(
diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py
index 0d3a0fbd3b..772670ad79 100644
--- a/python/pyarrow/tests/test_dataset.py
+++ b/python/pyarrow/tests/test_dataset.py
@@ -5730,3 +5730,37 @@ def test_make_write_options_error():
msg = "make_write_options\\(\\) takes exactly 0 positional arguments"
with pytest.raises(TypeError, match=msg):
pformat.make_write_options(43)
+
+
+def test_scanner_from_substrait(dataset):
+ try:
+ import pyarrow.substrait as ps
+ except ImportError:
+ pytest.skip("substrait NOT enabled")
+
+ # SELECT str WHERE i64 = 4
+ projection = (b'\nS\x08\x0c\x12Ohttps://github.com/apache/arrow/blob/main/format'
+ b'/substrait/extension_types.yaml\x12\t\n\x07\x08\x0c\x1a\x03u64'
+ b'\x12\x0b\n\t\x08\x0c\x10\x01\x1a\x03u32\x1a\x0f\n\x08\x12\x06'
+ b'\n\x04\x12\x02\x08\x02\x1a\x03str"i\n\x03i64\n\x03f64\n\x03str'
+ b'\n\x05const\n\x06struct\n\x01a\n\x01b\n\x05group\n\x03key'
+ b'\x127\n\x04:\x02\x10\x01\n\x04Z\x02\x10\x01\n\x04b\x02\x10'
+ b'\x01\n\x04:\x02\x10\x01\n\x11\xca\x01\x0e\n\x04:\x02\x10\x01'
+ b'\n\x04b\x02\x10\x01\x18\x01\n\x04*\x02\x10\x01\n\x04b\x02\x10\x01')
+ filtering = (b'\n\x1e\x08\x06\x12\x1a/functions_comparison.yaml\nS\x08\x0c\x12'
+ b'Ohttps://github.com/apache/arrow/blob/main/format'
+ b'/substrait/extension_types.yaml\x12\x18\x1a\x16\x08\x06\x10\xc5'
+ b'\x01\x1a\x0fequal:any1_any1\x12\t\n\x07\x08\x0c\x1a\x03u64\x12'
+ b'\x0b\n\t\x08\x0c\x10\x01\x1a\x03u32\x1a\x1f\n\x1d\x1a\x1b\x08'
+ b'\xc5\x01\x1a\x04\n\x02\x10\x02"\x08\x1a\x06\x12\x04\n\x02\x12\x00'
+ b'"\x06\x1a\x04\n\x02(\x04"i\n\x03i64\n\x03f64\n\x03str\n\x05const'
+ b'\n\x06struct\n\x01a\n\x01b\n\x05group\n\x03key\x127\n\x04:\x02'
+ b'\x10\x01\n\x04Z\x02\x10\x01\n\x04b\x02\x10\x01\n\x04:\x02\x10'
+ b'\x01\n\x11\xca\x01\x0e\n\x04:\x02\x10\x01\n\x04b\x02\x10\x01'
+ b'\x18\x01\n\x04*\x02\x10\x01\n\x04b\x02\x10\x01')
+
+ result = dataset.scanner(
+ columns=ps.BoundExpressions.from_substrait(projection),
+ filter=ps.BoundExpressions.from_substrait(filtering)
+ ).to_table()
+ assert result.to_pydict() == {'str': ['4', '4']}
diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py
index 01d468cd9e..fcd1c8d48c 100644
--- a/python/pyarrow/tests/test_substrait.py
+++ b/python/pyarrow/tests/test_substrait.py
@@ -105,7 +105,7 @@ def test_run_query_input_types(tmpdir, query):
# Otherwise error for invalid query
msg = "ParseFromZeroCopyStream failed for substrait.Plan"
- with pytest.raises(OSError, match=msg):
+ with pytest.raises(ArrowInvalid, match=msg):
substrait.run_query(query)
@@ -1077,3 +1077,44 @@ def test_serializing_udfs():
assert schema == returned.schema
assert len(returned.expressions) == 1
assert str(returned.expressions["expr"]) == str(exprs[0])
+
+
+def test_serializing_schema():
+ substrait_schema = b'\n\x01x\n\x01y\x12\x0c\n\x04*\x02\x10\x01\n\x04b\x02\x10\x01'
+ expected_schema = pa.schema([
+ pa.field("x", pa.int32()),
+ pa.field("y", pa.string())
+ ])
+ returned = pa.substrait.deserialize_schema(substrait_schema)
+ assert expected_schema == returned
+
+ arrow_substrait_schema = pa.substrait.serialize_schema(returned)
+ assert arrow_substrait_schema.schema == substrait_schema
+
+ returned = pa.substrait.deserialize_schema(arrow_substrait_schema)
+ assert expected_schema == returned
+
+ returned = pa.substrait.deserialize_schema(arrow_substrait_schema.schema)
+ assert expected_schema == returned
+
+ returned = pa.substrait.deserialize_expressions(arrow_substrait_schema.expression)
+ assert returned.schema == expected_schema
+
+
+def test_bound_expression_from_Message():
+ class FakeMessage:
+ def __init__(self, expr):
+ self.expr = expr
+
+ def SerializeToString(self):
+ return self.expr
+
+ # SELECT project_release, project_version
+ message = (b'\x1a\x1b\n\x08\x12\x06\n\x04\x12\x02\x08\x01\x1a\x0fproject_release'
+ b'\x1a\x19\n\x06\x12\x04\n\x02\x12\x00\x1a\x0fproject_version'
+ b'"0\n\x0fproject_version\n\x0fproject_release'
+ b'\x12\x0c\n\x04:\x02\x10\x01\n\x04b\x02\x10\x01')
+ exprs = pa.substrait.BoundExpressions.from_substrait(FakeMessage(message))
+ assert len(exprs.expressions) == 2
+ assert 'project_release' in exprs.expressions
+ assert 'project_version' in exprs.expressions