jorisvandenbossche commented on code in PR #34834:
URL: https://github.com/apache/arrow/pull/34834#discussion_r1253079809


##########
python/pyarrow/_compute.pyx:
##########
@@ -2350,6 +2367,58 @@ cdef class Expression(_Weakrefable):
             self.__class__.__name__, str(self)
         )
 
+    @staticmethod
+    def from_substrait(object buffer not None):
+        """
+        Deserialize an expression from Substrait
+
+        The serialized message must be an ExtendedExpression message that has
+        only a single expression.  The name of the expression and the schema
+        the expression was bound to will be ignored.  Use
+        pyarrow.substrait.deserialize_expressions if this information is needed
+        or if the message might contain multiple expressions.
+
+        Parameters
+        ----------
+        buffer : bytes or Buffer
+            The Substrait message to deserialize
+
+        Returns
+        -------
+        Expression
+            The deserialized expression
+        """
+        expressions = _pas().deserialize_expressions(buffer).expressions
+        if len(expressions) == 0:
+            raise ValueError("Substrait message did not contain any expressions")
+        if len(expressions) > 1:
+            raise ValueError(
+                "Substrait message contained multiple expressions.  Use pyarrow.substrait.deserialize_expressions instead")
+        return next(iter(expressions.values()))
+
+    def to_substrait(self, Schema schema not None, c_bool allow_udfs=False):
+        """
+        Serialize the expression using Substrait
+
+        The expression will be serialized as an ExtendedExpression message that has a
+        single expression named "expression"
+
+        Parameters
+        ----------
+        schema : Schema
+            The input schema the expression will be bound to
+        allow_udfs : bool, default False
+            If False then only functions that are part of the core Substrait function
+            definitions will be allowed.  Set this to True to allow pyarrow-specific functions
+            but the result may not be accepted by other compute libraries.
+
+        Returns
+        -------
+        Buffer
+            A buffer containing the serialized Protobuf plan.
+        """
+        return _pas().serialize_expressions([self], "expression", schema, allow_udfs=allow_udfs)

Review Comment:
   Should the user be able to override this default name of "expression"? (Although if you want that, you can always use the `serialize_expressions` function instead.)
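A minimal sketch of what such an override could look like. The `name` keyword is hypothetical and not part of the PR. The sketch is written as a free function that receives the serializer as an argument, so it runs without pyarrow. In the PR the serializer would be `_pas().serialize_expressions`, and `allow_udfs` is omitted here for brevity.

```python
def to_substrait(expr, schema, serialize_expressions, name="expression"):
    """Hypothetical variant of Expression.to_substrait with an overridable name.

    ``serialize_expressions`` is injected so this sketch runs without pyarrow.
    """
    # ``names`` takes one entry per expression, so wrap the single name in a list.
    return serialize_expressions([expr], [name], schema)


# Stand-in serializer that only records what it was called with.
calls = []


def fake_serialize(exprs, names, schema):
    calls.append((exprs, names, schema))
    return b"<serialized>"


to_substrait("expr", "schema", fake_serialize)
to_substrait("expr", "schema", fake_serialize, name="my_filter")
print(calls[0][1], calls[1][1])  # ['expression'] ['my_filter']
```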



##########
python/pyarrow/_substrait.pyx:
##########
@@ -185,6 +187,143 @@ def _parse_json_plan(plan):
     return pyarrow_wrap_buffer(c_buf_plan)
 
 
+def serialize_expressions(exprs, names, schema, *, allow_udfs=False):
+    """
+    Serialize a collection of expressions into Substrait
+
+    Substrait expressions must be bound to a schema.  For example,
+    the Substrait expression ``a_i32 + b_i32`` is different from the
+    Substrait expression ``a_i64 + b_i64``.  Pyarrow expressions are

Review Comment:
   ```suggestion
       the Substrait expression ``a:i32 + b:i32`` is different from the
       Substrait expression ``a:i64 + b:i64``.  Pyarrow expressions are
   ```
   
   ? (That might make it clearer that the actual field names are still "a" and "b" in both cases.)
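The binding point can be illustrated with a toy model. This is not how Arrow's Substrait converter is implemented. In the model, a call is resolved from the function name together with the bound argument types. The registry and the `add:i32_i32`-style names are illustrative assumptions, loosely modelled on Substrait's compound function names.

```python
# Hypothetical registry keyed on (function name, bound argument types).
REGISTRY = {
    ("add", ("i32", "i32")): "add:i32_i32",
    ("add", ("i64", "i64")): "add:i64_i64",
}


def bind_call(op, field_names, schema):
    """Resolve an unbound call such as ``a + b`` against a {name: type} schema."""
    arg_types = tuple(schema[name] for name in field_names)
    try:
        return REGISTRY[(op, arg_types)]
    except KeyError:
        # Serialization fails when no function matches the bound argument types.
        raise LookupError(f"no matching function for {op}{arg_types}") from None


# Same field names "a" and "b", different types -> different functions.
print(bind_call("add", ["a", "b"], {"a": "i32", "b": "i32"}))  # add:i32_i32
print(bind_call("add", ["a", "b"], {"a": "i64", "b": "i64"}))  # add:i64_i64
```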



##########
python/pyarrow/_compute.pyx:
##########
@@ -2350,6 +2367,58 @@ cdef class Expression(_Weakrefable):
             self.__class__.__name__, str(self)
         )
 
+    @staticmethod
+    def from_substrait(object buffer not None):
+        """
+        Deserialize an expression from Substrait
+
+        The serialized message must be an ExtendedExpression message that has
+        only a single expression.  The name of the expression and the schema
+        the expression was bound to will be ignored.  Use
+        pyarrow.substrait.deserialize_expressions if this information is needed
+        or if the message might contain multiple expressions.
+
+        Parameters
+        ----------
+        buffer : bytes or Buffer
+            The Substrait message to deserialize
+
+        Returns
+        -------
+        Expression
+            The deserialized expression
+        """
+        expressions = _pas().deserialize_expressions(buffer).expressions
+        if len(expressions) == 0:
+            raise ValueError("Substrait message did not contain any expressions")
+        if len(expressions) > 1:
+            raise ValueError(
+                "Substrait message contained multiple expressions.  Use pyarrow.substrait.deserialize_expressions instead")
+        return next(iter(expressions.values()))
+
+    def to_substrait(self, Schema schema not None, c_bool allow_udfs=False):
+        """
+        Serialize the expression using Substrait
+
+        The expression will be serialized as an ExtendedExpression message that has a
+        single expression named "expression"
+
+        Parameters
+        ----------
+        schema : Schema
+            The input schema the expression will be bound to
+        allow_udfs : bool, default False
+            If False then only functions that are part of the core Substrait function
+            definitions will be allowed.  Set this to True to allow pyarrow-specific functions
+            but the result may not be accepted by other compute libraries.
+
+        Returns
+        -------
+        Buffer
+            A buffer containing the serialized Protobuf plan.
+        """
+        return _pas().serialize_expressions([self], "expression", schema, allow_udfs=allow_udfs)

Review Comment:
   ```suggestion
           return _pas().serialize_expressions([self], ["expression"], schema, allow_udfs=allow_udfs)
   ```
   
   This currently causes a bug where the deserialized expression is named "e". `names` is the string "expression" rather than a list, so `names[0]` is its first character:
   
   ```
   In [16]: expr = pc.field("a") == 1
   
   In [17]: buf = expr.to_substrait(pa.schema([('a', 'int32')]))
   
   In [18]: pyarrow.substrait.deserialize_expressions(buf).expressions
   Out[18]: {'e': <pyarrow.compute.Expression (FieldPath(0) == 1)>}
   ```
   
   (so it might be good to add a test for this)
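A pure-Python mirror of the naming loop in `serialize_expressions` shows where the "e" comes from. A single character is itself a `str`, so the `isinstance(names[i], str)` check does not catch the mistake. `assign_names` is a hypothetical helper written only for illustration.

```python
def assign_names(exprs, names):
    """Pure-Python mirror of the index-based naming loop in serialize_expressions."""
    assigned = []
    for i in range(len(exprs)):
        if i < len(names):
            # For a bare string, names[i] is a single character (still a str).
            assigned.append(names[i])
        else:
            assigned.append("autoname")
    return assigned


print(assign_names(["<expr>"], "expression"))    # ['e']
print(assign_names(["<expr>"], ["expression"]))  # ['expression']
```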



##########
python/pyarrow/_substrait.pyx:
##########
@@ -185,6 +187,143 @@ def _parse_json_plan(plan):
     return pyarrow_wrap_buffer(c_buf_plan)
 
 
+def serialize_expressions(exprs, names, schema, *, allow_udfs=False):
+    """
+    Serialize a collection of expressions into Substrait
+
+    Substrait expressions must be bound to a schema.  For example,
+    the Substrait expression ``a_i32 + b_i32`` is different from the
+    Substrait expression ``a_i64 + b_i64``.  Pyarrow expressions are
+    typically unbound.  For example, both of the above expressions
+    would be represented as ``a + b`` in pyarrow.
+
+    This means a schema must be provided when serializing an expression.
+    It also means that the serialization may fail if a matching function
+    call cannot be found for the expression.
+
+    Parameters
+    ----------
+    exprs : list of Expression
+        The expressions to serialize
+    names : list of str
+        Names for the expressions
+    schema : Schema
+        The schema the expressions will be bound to
+    allow_udfs : bool, default False

Review Comment:
   The "udf" in the keyword name might be a bit confusing. I think pyarrow users will think of actual UDFs that they defined themselves, not of functions defined by Arrow (but not part of Substrait). From the user's point of view, those are "built-in" functions, not UDFs.
   
   I see that in the C++ code you are using `allow_arrow_extensions` as the keyword. Can we use that here as well? (Or is there a specific reason you went for a different name?)



##########
python/pyarrow/_substrait.pyx:
##########
@@ -145,6 +146,7 @@ def run_query(plan, *, table_provider=None, use_threads=True):
         }
         c_conversion_options.named_table_provider = BindFunction[CNamedTableProvider](
             &_create_named_table_provider, named_table_args)
+        c_conversion_options.allow_arrow_extensions = False

Review Comment:
   This is the default, right? (i.e. this does not change the actual behaviour, it just makes it more explicit)



##########
python/pyarrow/_substrait.pyx:
##########
@@ -185,6 +187,143 @@ def _parse_json_plan(plan):
     return pyarrow_wrap_buffer(c_buf_plan)
 
 
+def serialize_expressions(exprs, names, schema, *, allow_udfs=False):
+    """
+    Serialize a collection of expressions into Substrait
+
+    Substrait expressions must be bound to a schema.  For example,
+    the Substrait expression ``a_i32 + b_i32`` is different from the
+    Substrait expression ``a_i64 + b_i64``.  Pyarrow expressions are
+    typically unbound.  For example, both of the above expressions
+    would be represented as ``a + b`` in pyarrow.
+
+    This means a schema must be provided when serializing an expression.
+    It also means that the serialization may fail if a matching function
+    call cannot be found for the expression.
+
+    Parameters
+    ----------
+    exprs : list of Expression
+        The expressions to serialize
+    names : list of str
+        Names for the expressions
+    schema : Schema
+        The schema the expressions will be bound to
+    allow_udfs : bool, default False
+        If False then only functions that are part of the core Substrait function
+        definitions will be allowed.  Set this to True to allow pyarrow-specific functions
+        but the result may not be accepted by other compute libraries.
+
+    Returns
+    -------
+    Buffer
+        An ExtendedExpression message containing the serialized expressions
+    """
+    cdef:
+        CResult[shared_ptr[CBuffer]] c_res_buffer
+        shared_ptr[CBuffer] c_buffer
+        CNamedExpression c_named_expr
+        CBoundExpressions c_bound_exprs
+        CConversionOptions c_conversion_options
+
+    for i in range(len(exprs)):
+        if not isinstance(exprs[i], Expression):
+            raise TypeError(f"Expected Expression, got '{type(exprs[i])}'")
+        c_named_expr.expression = (<Expression> exprs[i]).unwrap()
+        if i < len(names):
+            if not isinstance(names[i], str):
+                raise TypeError(f"Expected str, got '{type(names[i])}'")
+            c_named_expr.name = tobytes(<str> names[i])
+        else:
+            c_named_expr.name = tobytes("autoname")

Review Comment:
   Is this useful? (Couldn't it give multiple expressions the same name?) Or should it raise an error instead?
   
   I would rather validate that `len(exprs) == len(names)`
   
   (and in that case you can also do `for expr, name in zip(exprs, names): ...` to simplify the code a bit)
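Here is a minimal sketch of the suggested validation as a pure-Python helper. `pair_names` is hypothetical and not in the PR. The Cython version would also unwrap each expression into a `CNamedExpression`. With this check, the `to_substrait` call that currently passes the bare string "expression" would fail loudly (10 names for 1 expression) instead of silently naming the expression "e".

```python
def pair_names(exprs, names):
    """Hypothetical helper: pair each expression with exactly one str name."""
    exprs = list(exprs)
    names = list(names)
    if len(exprs) != len(names):
        raise ValueError(
            f"len(names) ({len(names)}) must equal len(exprs) ({len(exprs)})")
    for name in names:
        if not isinstance(name, str):
            raise TypeError(f"Expected str, got '{type(name)}'")
    # zip() replaces both the index-based loop and the "autoname" fallback.
    return list(zip(names, exprs))


print(pair_names(["<expr0>", "<expr1>"], ["first", "second"]))
# [('first', '<expr0>'), ('second', '<expr1>')]

try:
    pair_names(["<expr0>"], "expression")  # bare string: 10 names, 1 expression
except ValueError as exc:
    print(exc)  # len(names) (10) must equal len(exprs) (1)
```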



##########
python/pyarrow/_substrait.pyx:
##########
@@ -185,6 +187,143 @@ def _parse_json_plan(plan):
     return pyarrow_wrap_buffer(c_buf_plan)
 
 
+def serialize_expressions(exprs, names, schema, *, allow_udfs=False):
+    """
+    Serialize a collection of expressions into Substrait
+
+    Substrait expressions must be bound to a schema.  For example,
+    the Substrait expression ``a_i32 + b_i32`` is different from the
+    Substrait expression ``a_i64 + b_i64``.  Pyarrow expressions are
+    typically unbound.  For example, both of the above expressions
+    would be represented as ``a + b`` in pyarrow.
+
+    This means a schema must be provided when serializing an expression.
+    It also means that the serialization may fail if a matching function
+    call cannot be found for the expression.
+
+    Parameters
+    ----------
+    exprs : list of Expression
+        The expressions to serialize
+    names : list of str
+        Names for the expressions
+    schema : Schema
+        The schema the expressions will be bound to
+    allow_udfs : bool, default False
+        If False then only functions that are part of the core Substrait function
+        definitions will be allowed.  Set this to True to allow pyarrow-specific functions
+        but the result may not be accepted by other compute libraries.
+
+    Returns
+    -------
+    Buffer
+        An ExtendedExpression message containing the serialized expressions
+    """
+    cdef:
+        CResult[shared_ptr[CBuffer]] c_res_buffer
+        shared_ptr[CBuffer] c_buffer
+        CNamedExpression c_named_expr
+        CBoundExpressions c_bound_exprs
+        CConversionOptions c_conversion_options
+
+    for i in range(len(exprs)):
+        if not isinstance(exprs[i], Expression):
+            raise TypeError(f"Expected Expression, got '{type(exprs[i])}'")
+        c_named_expr.expression = (<Expression> exprs[i]).unwrap()
+        if i < len(names):
+            if not isinstance(names[i], str):
+                raise TypeError(f"Expected str, got '{type(names[i])}'")
+            c_named_expr.name = tobytes(<str> names[i])
+        else:
+            c_named_expr.name = tobytes("autoname")
+        c_bound_exprs.named_expressions.push_back(c_named_expr)
+    c_bound_exprs.schema = (<Schema> schema).sp_schema
+
+    c_conversion_options.allow_arrow_extensions = allow_udfs
+
+    c_res_buffer = SerializeExpressions(c_bound_exprs, c_conversion_options)
+    with nogil:
+        c_buffer = GetResultValue(c_res_buffer)

Review Comment:
   Can the `SerializeExpressions` call be moved within the nogil block? (I assume this is the potentially costly part; the `GetResultValue` should always be cheap?)



##########
python/pyarrow/tests/test_substrait.py:
##########
@@ -923,3 +925,58 @@ def table_provider(names, _):
 
     # Ordering of k is deterministic because this is running with serial execution
     assert res_tb == expected_tb
+
+
+@pytest.mark.parametrize("expr", [
+    pc.equal(ds.field("x"), 7),
+    pc.equal(ds.field("x"), ds.field("y")),
+    ds.field("x") > 50
+])
+def test_serializing_expressions(expr):
+    schema = pa.schema([
+        pa.field("x", pa.int32()),
+        pa.field("y", pa.int32())
+    ])
+
+    buf = pa.substrait.serialize_expressions([expr], ["test_expr"], schema)
+    returned = pa.substrait.deserialize_expressions(buf)
+    assert schema == returned.schema
+    assert len(returned.expressions) == 1
+    assert "test_expr" in returned.expressions
+
+
+def test_serializing_multiple_expressions():
+    schema = pa.schema([
+        pa.field("x", pa.int32()),
+        pa.field("y", pa.int32())
+    ])
+    exprs = [pc.equal(ds.field("x"), 7), pc.equal(ds.field("x"), ds.field("y"))]
+    buf = pa.substrait.serialize_expressions(exprs, ["first", "second"], schema)
+    returned = pa.substrait.deserialize_expressions(buf)
+    assert schema == returned.schema
+    assert len(returned.expressions) == 2
+
+    norm_exprs = [pc.equal(ds.field(0), 7), pc.equal(ds.field(0), ds.field(1))]
+    assert str(returned.expressions["first"]) == str(norm_exprs[0])
+    assert str(returned.expressions["second"]) == str(norm_exprs[1])

Review Comment:
   We also have an `equals` method on Expression if you want to avoid comparing string reprs (though I'm not sure what the corner cases are for either option).



##########
python/pyarrow/_substrait.pyx:
##########
@@ -185,6 +187,143 @@ def _parse_json_plan(plan):
     return pyarrow_wrap_buffer(c_buf_plan)
 
 
+def serialize_expressions(exprs, names, schema, *, allow_udfs=False):
+    """
+    Serialize a collection of expressions into Substrait
+
+    Substrait expressions must be bound to a schema.  For example,
+    the Substrait expression ``a_i32 + b_i32`` is different from the
+    Substrait expression ``a_i64 + b_i64``.  Pyarrow expressions are
+    typically unbound.  For example, both of the above expressions
+    would be represented as ``a + b`` in pyarrow.
+
+    This means a schema must be provided when serializing an expression.
+    It also means that the serialization may fail if a matching function
+    call cannot be found for the expression.
+
+    Parameters
+    ----------
+    exprs : list of Expression
+        The expressions to serialize
+    names : list of str
+        Names for the expressions
+    schema : Schema
+        The schema the expressions will be bound to
+    allow_udfs : bool, default False
+        If False then only functions that are part of the core Substrait function
+        definitions will be allowed.  Set this to True to allow pyarrow-specific functions
+        but the result may not be accepted by other compute libraries.
+
+    Returns
+    -------
+    Buffer
+        An ExtendedExpression message containing the serialized expressions
+    """
+    cdef:
+        CResult[shared_ptr[CBuffer]] c_res_buffer
+        shared_ptr[CBuffer] c_buffer
+        CNamedExpression c_named_expr
+        CBoundExpressions c_bound_exprs
+        CConversionOptions c_conversion_options
+
+    for i in range(len(exprs)):
+        if not isinstance(exprs[i], Expression):
+            raise TypeError(f"Expected Expression, got '{type(exprs[i])}'")
+        c_named_expr.expression = (<Expression> exprs[i]).unwrap()
+        if i < len(names):
+            if not isinstance(names[i], str):
+                raise TypeError(f"Expected str, got '{type(names[i])}'")
+            c_named_expr.name = tobytes(<str> names[i])
+        else:
+            c_named_expr.name = tobytes("autoname")
+        c_bound_exprs.named_expressions.push_back(c_named_expr)
+    c_bound_exprs.schema = (<Schema> schema).sp_schema
+
+    c_conversion_options.allow_arrow_extensions = allow_udfs
+
+    c_res_buffer = SerializeExpressions(c_bound_exprs, c_conversion_options)
+    with nogil:
+        c_buffer = GetResultValue(c_res_buffer)
+    return pyarrow_wrap_buffer(c_buffer)
+
+
+cdef class BoundExpressions(_Weakrefable):
+    """
+    A collection of named expressions and the schema they are bound to
+
+    This is equivalent to the Substrait ExtendedExpression message
+    """
+
+    cdef:
+        CBoundExpressions c_bound_exprs
+
+    def __init__(self):
+        msg = 'BoundExpressions is an abstract class thus cannot be initialized.'
+        raise TypeError(msg)
+
+    cdef void init(self, CBoundExpressions bound_expressions):
+        self.c_bound_exprs = bound_expressions
+
+    @property
+    def schema(self):
+        """
+        The common schema that all expressions are bound to
+        """
+        return pyarrow_wrap_schema(self.c_bound_exprs.schema)
+
+    @property
+    def expressions(self):
+        """
+        A dict from expression name to expression
+        """
+        expr_dict = {}
+        for named_expr in self.c_bound_exprs.named_expressions:
+            name = frombytes(named_expr.name)
+            expr = Expression.wrap(named_expr.expression)
+            expr_dict[name] = expr
+        return expr_dict
+
+    @staticmethod
+    cdef wrap(const CBoundExpressions& bound_expressions):
+        cdef BoundExpressions self = BoundExpressions.__new__(BoundExpressions)
+        self.init(bound_expressions)
+        return self
+
+
+def deserialize_expressions(buf):
+    """
+    Deserialize an ExtendedExpression Substrait message into a BoundExpressions object
+
+    Parameters
+    ----------
+    buf : Buffer or bytes
+        The message to deserialize
+
+    Returns
+    -------
+    BoundExpressions
+        The deserialized expressions, their names, and the bound schema
+    """
+    cdef:
+        shared_ptr[CBuffer] c_buffer
+        CResult[CBoundExpressions] c_res_bound_exprs
+        CBoundExpressions c_bound_exprs
+
+    if isinstance(buf, bytes):
+        c_buffer = pyarrow_unwrap_buffer(py_buffer(buf))
+    elif isinstance(buf, Buffer):
+        c_buffer = pyarrow_unwrap_buffer(buf)
+    else:
+        raise TypeError(
+            f"Expected 'pyarrow.Buffer' or bytes, got '{type(buf)}'")
+
+    c_res_bound_exprs = DeserializeExpressions(deref(c_buffer))
+    with nogil:
+        c_bound_exprs = GetResultValue(c_res_bound_exprs)

Review Comment:
   Same question here about nogil



##########
python/pyarrow/tests/test_compute.py:
##########
@@ -3277,19 +3281,106 @@ def test_expression_serialization():
     f = pc.scalar({'a': 1})
     g = pc.scalar(pa.scalar(1))
     h = pc.scalar(np.int64(2))
+    j = pc.scalar(False)
+
+    literal_exprs = [a, b, c, d, e, g, h, j]
+
+    exprs_with_call = [a == b, a != b, a > b, c & j, c | j, ~c, d.is_valid(),
+                       a + b, a - b, a * b, a / b, pc.negate(a),
+                       pc.add(a, b), pc.subtract(a, b), pc.divide(a, b),
+                       pc.multiply(a, b), pc.power(a, a), pc.sqrt(a),
+                       pc.exp(b), pc.cos(b), pc.sin(b), pc.tan(b),
+                       pc.acos(b), pc.atan(b), pc.asin(b), pc.atan2(b, b),
+                       pc.abs(b), pc.sign(a), pc.bit_wise_not(a),
+                       pc.bit_wise_and(a, a), pc.bit_wise_or(a, a),
+                       pc.bit_wise_xor(a, a), pc.is_nan(b), pc.is_finite(b),
+                       pc.coalesce(a, b),
+                       a.cast(pa.int32(), safe=False)]
+
+    exprs_with_ref = [pc.field('i64') > 5, pc.field('i64') == 5,
+                      pc.field('i64') == 7,
+                      pc.field(('foo', 'bar')) == 'value',
+                      pc.field('foo', 'bar') == 'value']
+
+    special_cases = [
+        f,  # Struct literals lose their field names
+        a.isin([1, 2, 3]),  # isin converts to an or list
+        pc.field('i64').is_null()  # pyarrow always specifies a FunctionOptions
+        # for is_null which, being the default, is
+        # dropped on serialization
+    ]
 
-    all_exprs = [a, b, c, d, e, f, g, h, a == b, a > b, a & b, a | b, ~c,
-                 d.is_valid(), a.cast(pa.int32(), safe=False),
-                 a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),
-                 pc.field('i64') > 5, pc.field('i64') == 5,
-                 pc.field('i64') == 7, pc.field('i64').is_null(),
-                 pc.field(('foo', 'bar')) == 'value',
-                 pc.field('foo', 'bar') == 'value']
+    all_exprs = literal_exprs + exprs_with_call + exprs_with_ref + special_cases
     for expr in all_exprs:
         assert isinstance(expr, pc.Expression)
         restored = pickle.loads(pickle.dumps(expr))
         assert expr.equals(restored)
 
+    if pas is not None:
+
+        test_schema = pa.schema([pa.field("i64", pa.int64()), pa.field(
+            "foo", pa.struct([pa.field("bar", pa.string())]))])
+
+        # Basic literals don't change on binding and so they will round
+        # trip without any change
+        for expr in literal_exprs:
+            serialized = expr.to_substrait(test_schema)
+            deserialized = pc.Expression.from_substrait(serialized)
+            assert expr.equals(deserialized)
+
+        # Expressions are bound when they get serialized.  Since bound
+        # expressions are not equal to their unbound variants we cannot
+        # compare the round tripped with the original
+        for expr in exprs_with_call:
+            serialized = expr.to_substrait(test_schema)
+            deserialized = pc.Expression.from_substrait(serialized)
+            # We can't compare the expressions themselves because of the bound
+            # unbound difference. But we can compare the string representation
+            assert str(deserialized) == str(expr)
+            serialized_again = deserialized.to_substrait(test_schema)
+            deserialized_again = pc.Expression.from_substrait(serialized_again)
+            assert deserialized.equals(deserialized_again)
+
+        # Expressions that have references will be normalized, on serialization,
+        # to numeric referneces

Review Comment:
   ```suggestion
           # to numeric references
   ```



##########
python/pyarrow/tests/test_compute.py:
##########
@@ -3277,19 +3281,106 @@ def test_expression_serialization():
     f = pc.scalar({'a': 1})
     g = pc.scalar(pa.scalar(1))
     h = pc.scalar(np.int64(2))
+    j = pc.scalar(False)
+
+    literal_exprs = [a, b, c, d, e, g, h, j]
+
+    exprs_with_call = [a == b, a != b, a > b, c & j, c | j, ~c, d.is_valid(),
+                       a + b, a - b, a * b, a / b, pc.negate(a),
+                       pc.add(a, b), pc.subtract(a, b), pc.divide(a, b),
+                       pc.multiply(a, b), pc.power(a, a), pc.sqrt(a),
+                       pc.exp(b), pc.cos(b), pc.sin(b), pc.tan(b),
+                       pc.acos(b), pc.atan(b), pc.asin(b), pc.atan2(b, b),
+                       pc.abs(b), pc.sign(a), pc.bit_wise_not(a),
+                       pc.bit_wise_and(a, a), pc.bit_wise_or(a, a),
+                       pc.bit_wise_xor(a, a), pc.is_nan(b), pc.is_finite(b),
+                       pc.coalesce(a, b),
+                       a.cast(pa.int32(), safe=False)]
+
+    exprs_with_ref = [pc.field('i64') > 5, pc.field('i64') == 5,
+                      pc.field('i64') == 7,
+                      pc.field(('foo', 'bar')) == 'value',
+                      pc.field('foo', 'bar') == 'value']
+
+    special_cases = [
+        f,  # Struct literals lose their field names
+        a.isin([1, 2, 3]),  # isin converts to an or list
+        pc.field('i64').is_null()  # pyarrow always specifies a FunctionOptions
+        # for is_null which, being the default, is
+        # dropped on serialization
+    ]
 
-    all_exprs = [a, b, c, d, e, f, g, h, a == b, a > b, a & b, a | b, ~c,
-                 d.is_valid(), a.cast(pa.int32(), safe=False),
-                 a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),
-                 pc.field('i64') > 5, pc.field('i64') == 5,
-                 pc.field('i64') == 7, pc.field('i64').is_null(),
-                 pc.field(('foo', 'bar')) == 'value',
-                 pc.field('foo', 'bar') == 'value']
+    all_exprs = literal_exprs + exprs_with_call + exprs_with_ref + special_cases
     for expr in all_exprs:
         assert isinstance(expr, pc.Expression)
         restored = pickle.loads(pickle.dumps(expr))
         assert expr.equals(restored)
 
+    if pas is not None:
+
+        test_schema = pa.schema([pa.field("i64", pa.int64()), pa.field(
+            "foo", pa.struct([pa.field("bar", pa.string())]))])
+
+        # Basic literals don't change on binding and so they will round
+        # trip without any change
+        for expr in literal_exprs:
+            serialized = expr.to_substrait(test_schema)
+            deserialized = pc.Expression.from_substrait(serialized)
+            assert expr.equals(deserialized)
+
+        # Expressions are bound when they get serialized.  Since bound
+        # expressions are not equal to their unbound variants we cannot
+        # compare the round tripped with the original
+        for expr in exprs_with_call:
+            serialized = expr.to_substrait(test_schema)
+            deserialized = pc.Expression.from_substrait(serialized)
+            # We can't compare the expressions themselves because of the bound
+            # unbound difference. But we can compare the string representation

Review Comment:
   This one I don't fully understand: these are expressions with calls but without any field reference, only with scalars that already have a type. So why is bound/unbound relevant in this case? (I would have expected that to be relevant only for references to fields in the schema.)



##########
python/pyarrow/tests/test_compute.py:
##########
@@ -3277,19 +3281,106 @@ def test_expression_serialization():
     f = pc.scalar({'a': 1})
     g = pc.scalar(pa.scalar(1))
     h = pc.scalar(np.int64(2))
+    j = pc.scalar(False)
+
+    literal_exprs = [a, b, c, d, e, g, h, j]
+
+    exprs_with_call = [a == b, a != b, a > b, c & j, c | j, ~c, d.is_valid(),
+                       a + b, a - b, a * b, a / b, pc.negate(a),
+                       pc.add(a, b), pc.subtract(a, b), pc.divide(a, b),
+                       pc.multiply(a, b), pc.power(a, a), pc.sqrt(a),
+                       pc.exp(b), pc.cos(b), pc.sin(b), pc.tan(b),
+                       pc.acos(b), pc.atan(b), pc.asin(b), pc.atan2(b, b),
+                       pc.abs(b), pc.sign(a), pc.bit_wise_not(a),
+                       pc.bit_wise_and(a, a), pc.bit_wise_or(a, a),
+                       pc.bit_wise_xor(a, a), pc.is_nan(b), pc.is_finite(b),
+                       pc.coalesce(a, b),
+                       a.cast(pa.int32(), safe=False)]
+
+    exprs_with_ref = [pc.field('i64') > 5, pc.field('i64') == 5,
+                      pc.field('i64') == 7,
+                      pc.field(('foo', 'bar')) == 'value',
+                      pc.field('foo', 'bar') == 'value']
+
+    special_cases = [
+        f,  # Struct literals lose their field names
+        a.isin([1, 2, 3]),  # isin converts to an or list
+        pc.field('i64').is_null()  # pyarrow always specifies a FunctionOptions
+        # for is_null which, being the default, is
+        # dropped on serialization
+    ]
 
-    all_exprs = [a, b, c, d, e, f, g, h, a == b, a > b, a & b, a | b, ~c,
-                 d.is_valid(), a.cast(pa.int32(), safe=False),
-                 a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),
-                 pc.field('i64') > 5, pc.field('i64') == 5,
-                 pc.field('i64') == 7, pc.field('i64').is_null(),
-                 pc.field(('foo', 'bar')) == 'value',
-                 pc.field('foo', 'bar') == 'value']
+    all_exprs = literal_exprs + exprs_with_call + exprs_with_ref + special_cases
     for expr in all_exprs:
         assert isinstance(expr, pc.Expression)
         restored = pickle.loads(pickle.dumps(expr))
         assert expr.equals(restored)
 
+    if pas is not None:

Review Comment:
   It _might_ be nicer to split the Substrait-based serialization out into a separate test, but then we would need to factor the expression creation out into a helper function? (It's certainly OK to just leave it as is.)



##########
python/pyarrow/tests/test_compute.py:
##########
@@ -3277,19 +3281,106 @@ def test_expression_serialization():
     f = pc.scalar({'a': 1})
     g = pc.scalar(pa.scalar(1))
     h = pc.scalar(np.int64(2))
+    j = pc.scalar(False)
+
+    literal_exprs = [a, b, c, d, e, g, h, j]
+
+    exprs_with_call = [a == b, a != b, a > b, c & j, c | j, ~c, d.is_valid(),
+                       a + b, a - b, a * b, a / b, pc.negate(a),
+                       pc.add(a, b), pc.subtract(a, b), pc.divide(a, b),
+                       pc.multiply(a, b), pc.power(a, a), pc.sqrt(a),
+                       pc.exp(b), pc.cos(b), pc.sin(b), pc.tan(b),
+                       pc.acos(b), pc.atan(b), pc.asin(b), pc.atan2(b, b),
+                       pc.abs(b), pc.sign(a), pc.bit_wise_not(a),
+                       pc.bit_wise_and(a, a), pc.bit_wise_or(a, a),
+                       pc.bit_wise_xor(a, a), pc.is_nan(b), pc.is_finite(b),
+                       pc.coalesce(a, b),
+                       a.cast(pa.int32(), safe=False)]
+
+    exprs_with_ref = [pc.field('i64') > 5, pc.field('i64') == 5,
+                      pc.field('i64') == 7,
+                      pc.field(('foo', 'bar')) == 'value',
+                      pc.field('foo', 'bar') == 'value']
+
+    special_cases = [
+        f,  # Struct literals lose their field names
+        a.isin([1, 2, 3]),  # isin converts to an or list
+        pc.field('i64').is_null()  # pyarrow always specifies a FunctionOptions
+        # for is_null which, being the default, is
+        # dropped on serialization
+    ]
 
-    all_exprs = [a, b, c, d, e, f, g, h, a == b, a > b, a & b, a | b, ~c,
-                 d.is_valid(), a.cast(pa.int32(), safe=False),
-                 a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),
-                 pc.field('i64') > 5, pc.field('i64') == 5,
-                 pc.field('i64') == 7, pc.field('i64').is_null(),
-                 pc.field(('foo', 'bar')) == 'value',
-                 pc.field('foo', 'bar') == 'value']
+    all_exprs = literal_exprs + exprs_with_call + exprs_with_ref + special_cases
     for expr in all_exprs:
         assert isinstance(expr, pc.Expression)
         restored = pickle.loads(pickle.dumps(expr))
         assert expr.equals(restored)
 
+    if pas is not None:
+
+        test_schema = pa.schema([pa.field("i64", pa.int64()), pa.field(
+            "foo", pa.struct([pa.field("bar", pa.string())]))])
+
+        # Basic literals don't change on binding and so they will round
+        # trip without any change
+        for expr in literal_exprs:
+            serialized = expr.to_substrait(test_schema)
+            deserialized = pc.Expression.from_substrait(serialized)
+            assert expr.equals(deserialized)
+
+        # Expressions are bound when they get serialized.  Since bound
+        # expressions are not equal to their unbound variants we cannot
+        # compare the round tripped with the original
+        for expr in exprs_with_call:
+            serialized = expr.to_substrait(test_schema)
+            deserialized = pc.Expression.from_substrait(serialized)
+            # We can't compare the expressions themselves because of the bound
+            # unbound difference. But we can compare the string representation
+            assert str(deserialized) == str(expr)
+            serialized_again = deserialized.to_substrait(test_schema)
+            deserialized_again = pc.Expression.from_substrait(serialized_again)
+            assert deserialized.equals(deserialized_again)
+
+        # Expressions that have references will be normalized, on serialization,
+        # to numeric referneces
+        exprs_with_ref_norm = [pc.field(0) > 5, pc.field(0) == 5,
+                               pc.field(0) == 7,
+                               pc.field((1, 0)) == 'value',
+                               pc.field(1, 0) == 'value']
+
+        for expr, expr_norm in zip(exprs_with_ref, exprs_with_ref_norm):
+            serialized = expr.to_substrait(test_schema)
+            deserialized = pc.Expression.from_substrait(serialized)
+            assert str(deserialized) == str(expr_norm)
+            serialized_again = deserialized.to_substrait(test_schema)
+            deserialized_again = pc.Expression.from_substrait(serialized_again)
+            assert deserialized.equals(deserialized_again)
+
+        # For the special cases we get various wrinkles in serialization but we
+        # should always get the same thing from round tripping twice
+        for expr in special_cases:
+            serialized = expr.to_substrait(test_schema)
+            deserialized = pc.Expression.from_substrait(serialized)
+            serialized_again = deserialized.to_substrait(test_schema)
+            deserialized_again = pc.Expression.from_substrait(serialized_again)
+            assert deserialized.equals(deserialized_again)
+
+        # Special case, we lose the field names of struct literals
+        serialized = f.to_substrait(test_schema)
+        deserialized = pc.Expression.from_substrait(serialized)
+        assert deserialized.equals(pc.scalar({'': 1}))

Review Comment:
   Should we use something other than an empty string (e.g. "field0", "field1", ...)?
   (Although this is probably enough of a corner case to not put too much effort into it.)
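Substrait struct literals carry only the field values, so any names must be invented on deserialization. Below is a hypothetical helper along the lines of the "field0", "field1" suggestion. It is followed by a plain-Python illustration of why identical empty names are awkward once anything is keyed by name. `name_struct_fields` is illustrative and not part of pyarrow.

```python
def name_struct_fields(values, prefix="field"):
    """Hypothetical helper: give positional names to a struct literal's values.

    Substrait struct literals carry only field values, so names are invented.
    """
    return {f"{prefix}{i}": value for i, value in enumerate(values)}


values = [1, "x"]
# If every field were named "", a name-keyed view keeps only one of them:
print({"": v for v in values})     # {'': 'x'}
# Positional names keep every field distinct:
print(name_struct_fields(values))  # {'field0': 1, 'field1': 'x'}
```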



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
