jorisvandenbossche commented on a change in pull request #7026:
URL: https://github.com/apache/arrow/pull/7026#discussion_r416567863



##########
File path: python/pyarrow/_dataset.pyx
##########
@@ -41,6 +42,167 @@ def _forbid_instantiation(klass, subclasses_instead=True):
     raise TypeError(msg)
 
 
+cdef class Expression:
+
+    cdef:
+        shared_ptr[CExpression] wrapped
+        CExpression* expr
+
+    def __init__(self, Buffer buffer=None):
+        if buffer is not None:
+            c_buffer = pyarrow_unwrap_buffer(buffer)
+            expr = GetResultValue(CExpression.Deserialize(deref(c_buffer)))
+            self.init(expr)

Review comment:
       Maybe we can put this in an `Expression._deserialize` or so, to avoid 
putting this in the main constructor? 
   (and keep this one using `_forbid_instantiation`)

##########
File path: python/pyarrow/_dataset.pyx
##########
@@ -41,6 +42,167 @@ def _forbid_instantiation(klass, subclasses_instead=True):
     raise TypeError(msg)
 
 
+cdef class Expression:
+
+    cdef:
+        shared_ptr[CExpression] wrapped
+        CExpression* expr
+
+    def __init__(self, Buffer buffer=None):
+        if buffer is not None:
+            c_buffer = pyarrow_unwrap_buffer(buffer)
+            expr = GetResultValue(CExpression.Deserialize(deref(c_buffer)))
+            self.init(expr)

Review comment:
       Another good reason is that right now calling this without `buffer` 
results in an expression that segfaults in certain conditions (like printing 
the repr)

##########
File path: python/pyarrow/_dataset.pyx
##########
@@ -41,6 +42,167 @@ def _forbid_instantiation(klass, subclasses_instead=True):
     raise TypeError(msg)
 
 
+cdef class Expression:
+
+    cdef:
+        shared_ptr[CExpression] wrapped
+        CExpression* expr
+
+    def __init__(self, Buffer buffer=None):
+        if buffer is not None:
+            c_buffer = pyarrow_unwrap_buffer(buffer)
+            expr = GetResultValue(CExpression.Deserialize(deref(c_buffer)))
+            self.init(expr)
+
+    cdef void init(self, const shared_ptr[CExpression]& sp):
+        self.wrapped = sp
+        self.expr = sp.get()
+
+    @staticmethod
+    cdef wrap(const shared_ptr[CExpression]& sp):
+        self = Expression()
+        self.init(sp)
+        return self
+
+    cdef inline shared_ptr[CExpression] unwrap(self):
+        return self.wrapped
+
+    def equals(self, Expression other):
+        return self.expr.Equals(other.unwrap())
+
+    def __str__(self):
+        return frombytes(self.expr.ToString())
+
+    def __repr__(self):
+        return "<pyarrow.dataset.{0} {1}>".format(
+            self.__class__.__name__, str(self)
+        )
+
+    def __reduce__(self):
+        buffer = pyarrow_wrap_buffer(GetResultValue(self.expr.Serialize()))
+        return Expression, (buffer,)
+
+    def validate(self, Schema schema not None):
+        """Validate this expression for execution against a schema.
+
+        This will check that all reference fields are present (fields not in
+        the schema will be replaced with null) and all subexpressions are
+        executable. Returns the type to which this expression will evaluate.
+
+        Parameters
+        ----------
+        schema : Schema
+            Schema to execute the expression on.
+
+        Returns
+        -------
+        type : DataType
+        """
+        cdef:
+            shared_ptr[CSchema] sp_schema
+            CResult[shared_ptr[CDataType]] result
+        sp_schema = pyarrow_unwrap_schema(schema)
+        result = self.expr.Validate(deref(sp_schema))
+        return pyarrow_wrap_data_type(GetResultValue(result))
+
+    def assume(self, Expression given):
+        """Simplify to an equivalent Expression given assumed constraints."""
+        return Expression.wrap(self.expr.Assume(given.unwrap()))
+
+    def __invert__(self):
+        return Expression.wrap(CMakeNotExpression(self.unwrap()))
+
+    @staticmethod
+    cdef shared_ptr[CExpression] _expr_or_scalar(object expr) except *:
+        if isinstance(expr, Expression):
+            return (<Expression> expr).unwrap()
+        return (<Expression> Expression.scalar(expr)).unwrap()
+
+    @staticmethod
+    def wtf():
+        return Expression.wrap(Expression._expr_or_scalar([]))

Review comment:
       Leftover?

##########
File path: python/pyarrow/_dataset.pyx
##########
@@ -41,6 +42,167 @@ def _forbid_instantiation(klass, subclasses_instead=True):
     raise TypeError(msg)
 
 
+cdef class Expression:
+
+    cdef:
+        shared_ptr[CExpression] wrapped
+        CExpression* expr
+
+    def __init__(self, Buffer buffer=None):
+        if buffer is not None:
+            c_buffer = pyarrow_unwrap_buffer(buffer)
+            expr = GetResultValue(CExpression.Deserialize(deref(c_buffer)))
+            self.init(expr)
+
+    cdef void init(self, const shared_ptr[CExpression]& sp):
+        self.wrapped = sp
+        self.expr = sp.get()
+
+    @staticmethod
+    cdef wrap(const shared_ptr[CExpression]& sp):
+        self = Expression()

Review comment:
       ```suggestion
           cdef Expression self = Expression.__new__(Expression)
   ```
   
   (in case the init gets forbidden)

##########
File path: python/pyarrow/_dataset.pyx
##########
@@ -118,21 +280,15 @@ cdef class Dataset:
         -------
         fragments : iterator of Fragment
         """
-        cdef:
-            CFragmentIterator iterator
-            shared_ptr[CFragment] fragment
+        cdef CFragmentIterator c_fragments
 
         if filter is None or filter.expr == nullptr:
-            iterator = self.dataset.GetFragments()
+            c_fragments = self.dataset.GetFragments()
         else:
-            iterator = self.dataset.GetFragments(filter.unwrap())
+            c_fragments = self.dataset.GetFragments(filter.unwrap())
 
-        while True:
-            fragment = GetResultValue(iterator.Next())
-            if fragment.get() == nullptr:
-                raise StopIteration()
-            else:
-                yield Fragment.wrap(fragment)
+        for maybe_fragment in c_fragments:

Review comment:
       Why "maybe" fragment? (all elements should be fragments, no?)

##########
File path: python/pyarrow/_dataset.pyx
##########
@@ -41,6 +42,167 @@ def _forbid_instantiation(klass, subclasses_instead=True):
     raise TypeError(msg)
 
 
+cdef class Expression:
+
+    cdef:
+        shared_ptr[CExpression] wrapped
+        CExpression* expr
+
+    def __init__(self, Buffer buffer=None):
+        if buffer is not None:
+            c_buffer = pyarrow_unwrap_buffer(buffer)
+            expr = GetResultValue(CExpression.Deserialize(deref(c_buffer)))
+            self.init(expr)
+
+    cdef void init(self, const shared_ptr[CExpression]& sp):
+        self.wrapped = sp
+        self.expr = sp.get()
+
+    @staticmethod
+    cdef wrap(const shared_ptr[CExpression]& sp):
+        self = Expression()
+        self.init(sp)
+        return self
+
+    cdef inline shared_ptr[CExpression] unwrap(self):
+        return self.wrapped
+
+    def equals(self, Expression other):
+        return self.expr.Equals(other.unwrap())
+
+    def __str__(self):
+        return frombytes(self.expr.ToString())
+
+    def __repr__(self):
+        return "<pyarrow.dataset.{0} {1}>".format(
+            self.__class__.__name__, str(self)
+        )
+
+    def __reduce__(self):
+        buffer = pyarrow_wrap_buffer(GetResultValue(self.expr.Serialize()))
+        return Expression, (buffer,)
+
+    def validate(self, Schema schema not None):
+        """Validate this expression for execution against a schema.
+
+        This will check that all reference fields are present (fields not in
+        the schema will be replaced with null) and all subexpressions are
+        executable. Returns the type to which this expression will evaluate.
+
+        Parameters
+        ----------
+        schema : Schema
+            Schema to execute the expression on.
+
+        Returns
+        -------
+        type : DataType
+        """
+        cdef:
+            shared_ptr[CSchema] sp_schema
+            CResult[shared_ptr[CDataType]] result
+        sp_schema = pyarrow_unwrap_schema(schema)
+        result = self.expr.Validate(deref(sp_schema))
+        return pyarrow_wrap_data_type(GetResultValue(result))
+
+    def assume(self, Expression given):
+        """Simplify to an equivalent Expression given assumed constraints."""
+        return Expression.wrap(self.expr.Assume(given.unwrap()))
+
+    def __invert__(self):
+        return Expression.wrap(CMakeNotExpression(self.unwrap()))
+
+    @staticmethod
+    cdef shared_ptr[CExpression] _expr_or_scalar(object expr) except *:
+        if isinstance(expr, Expression):
+            return (<Expression> expr).unwrap()
+        return (<Expression> Expression.scalar(expr)).unwrap()
+
+    @staticmethod
+    def wtf():
+        return Expression.wrap(Expression._expr_or_scalar([]))
+
+    def __richcmp__(self, other, int op):
+        cdef:
+            shared_ptr[CExpression] c_expr
+            shared_ptr[CExpression] c_left
+            shared_ptr[CExpression] c_right
+
+        c_left = self.unwrap()
+        c_right = Expression._expr_or_scalar(other)
+
+        if op == Py_EQ:
+            c_expr = CMakeEqualExpression(move(c_left), move(c_right))
+        elif op == Py_NE:
+            c_expr = CMakeNotEqualExpression(move(c_left), move(c_right))
+        elif op == Py_GT:
+            c_expr = CMakeGreaterExpression(move(c_left), move(c_right))
+        elif op == Py_GE:
+            c_expr = CMakeGreaterEqualExpression(move(c_left), move(c_right))
+        elif op == Py_LT:
+            c_expr = CMakeLessExpression(move(c_left), move(c_right))
+        elif op == Py_LE:
+            c_expr = CMakeLessEqualExpression(move(c_left), move(c_right))
+
+        return Expression.wrap(c_expr)
+
+    def __and__(Expression self, other):
+        c_other = Expression._expr_or_scalar(other)
+        return Expression.wrap(CMakeAndExpression(self.wrapped,
+                                                  move(c_other)))
+
+    def __or__(Expression self, other):
+        c_other = Expression._expr_or_scalar(other)
+        return Expression.wrap(CMakeOrExpression(self.wrapped,
+                                                 move(c_other)))
+
+    def is_valid(self):
+        """Checks whether the expression is not-null (valid)"""
+        return Expression.wrap(self.expr.IsValid().Copy())
+
+    def cast(self, type, bint safe=True):
+        """Explicitly change the expression's data type"""
+        cdef CastOptions options
+        options = CastOptions.safe() if safe else CastOptions.unsafe()
+        c_type = pyarrow_unwrap_data_type(ensure_type(type))
+        return Expression.wrap(self.expr.CastTo(c_type,
+                                                options.unwrap()).Copy())
+
+    def isin(self, values):
+        """Checks whether the expression is contained in values"""
+        if not isinstance(values, pa.Array):
+            values = pa.array(values)
+        c_values = pyarrow_unwrap_array(values)
+        return Expression.wrap(self.expr.In(c_values).Copy())
+
+    @staticmethod
+    def field(str name not None):
+        return Expression.wrap(CMakeFieldExpression(tobytes(name)))
+
+    @staticmethod
+    def scalar(value):

Review comment:
       ```suggestion
       def _field(str name not None):
           return Expression.wrap(CMakeFieldExpression(tobytes(name)))
   
       @staticmethod
       def _scalar(value):
   ```
   
   To avoid public API exposure (that we might want/need to preserve later 
on), I would make those private. For the end-user, there are already the 
`ds.field(..)` and `ds.scalar(..)` functions

##########
File path: python/pyarrow/parquet.py
##########
@@ -154,27 +154,19 @@ def convert_single_predicate(col, op, val):
                 '"{0}" is not a valid operator in predicates.'.format(
                     (col, op, val)))
 
-    or_exprs = []
+    disjunction_members = []
 
     for conjunction in filters:
-        and_exprs = []
-        for col, op, val in conjunction:
-            and_exprs.append(convert_single_predicate(col, op, val))
+        conjunction_members = [
+            convert_single_predicate(col, op, val)
+            for col, op, val in conjunction
+        ]
 
-        expr = and_exprs[0]
-        if len(and_exprs) > 1:
-            for and_expr in and_exprs[1:]:
-                expr = ds.AndExpression(expr, and_expr)
+        conjunction = reduce(lambda acc, one: acc & one, conjunction_members)

Review comment:
       Nice :)
   
   Further change could be to use `operator.and_` / `or_` instead of `lambda 
acc, one: acc & one`

##########
File path: python/pyarrow/_dataset.pyx
##########
@@ -27,6 +27,7 @@ from pyarrow.lib cimport *
 from pyarrow.includes.libarrow_dataset cimport *
 from pyarrow.compat import frombytes, tobytes
 from pyarrow._fs cimport FileSystem, FileInfo, FileSelector
+import functools

Review comment:
       I think this import is not used anymore in this file now?

##########
File path: python/pyarrow/tests/test_dataset.py
##########
@@ -218,33 +216,21 @@ def test_filesystem_dataset(mockfs):
 
     # validation of required arguments
     with pytest.raises(TypeError, match="incorrect type"):
-        ds.FileSystemDataset(paths, format=file_format, filesystem=mockfs)
+        ds.FileSystemDataset(paths, schema=None, format=file_format,
+                             filesystem=mockfs)

Review comment:
       Can you leave those as they were as well? The goal here was to test that 
you get a decent error message when leaving out one of the required arguments 
(maybe should have added a better comment about that), which was the reason I 
used the manual checking in FileSystemDataset constructor instead of using 
actual required positional arguments.

##########
File path: python/pyarrow/tests/test_dataset.py
##########
@@ -373,141 +357,70 @@ def test_partitioning():
     assert expr.equals(expected)
 
 
-def test_expression():
-    a = ds.ScalarExpression(1)
-    b = ds.ScalarExpression(1.1)
-    c = ds.ScalarExpression(True)
-    d = ds.ScalarExpression("string")
-    e = ds.ScalarExpression(None)
-
-    equal = ds.ComparisonExpression(ds.CompareOperator.Equal, a, b)
-    greater = a > b
-    assert equal.op == ds.CompareOperator.Equal
-
-    and_ = ds.AndExpression(a, b)
-    assert and_.left_operand.equals(a)
-    assert and_.right_operand.equals(b)
-    assert and_.equals(ds.AndExpression(a, b))
-    assert and_.equals(and_)
-
-    or_ = ds.OrExpression(a, b)
-    not_ = ds.NotExpression(ds.OrExpression(a, b))
-    is_valid = ds.IsValidExpression(a)
-    cast_safe = ds.CastExpression(a, pa.int32())
-    cast_unsafe = ds.CastExpression(a, pa.int32(), safe=False)
-    in_ = ds.InExpression(a, pa.array([1, 2, 3]))
-
-    assert is_valid.operand == a
-    assert in_.set_.equals(pa.array([1, 2, 3]))
-    assert cast_unsafe.to == pa.int32()
-    assert cast_unsafe.safe is False
-    assert cast_safe.safe is True
-
-    condition = ds.ComparisonExpression(
-        ds.CompareOperator.Greater,
-        ds.FieldExpression('i64'),
-        ds.ScalarExpression(5)
-    )
+def test_expression_serialization():
+    a = ds.Expression.scalar(1)
+    b = ds.Expression.scalar(1.1)
+    c = ds.Expression.scalar(True)
+    d = ds.Expression.scalar("string")
+    e = ds.Expression.scalar(None)

Review comment:
       ```suggestion
       a = ds.scalar(1)
       b = ds.scalar(1.1)
       c = ds.scalar(True)
       d = ds.scalar("string")
       e = ds.scalar(None)
   ```

##########
File path: python/pyarrow/tests/test_dataset.py
##########
@@ -198,11 +197,10 @@ def test_filesystem_dataset(mockfs):
     file_format = ds.ParquetFileFormat()
 
     paths = ['subdir/1/xxx/file0.parquet', 'subdir/2/yyy/file1.parquet']
-    partitions = [ds.ScalarExpression(True), ds.ScalarExpression(True)]
+    partitions = [ds.Expression.scalar(True), ds.Expression.scalar(True)]

Review comment:
       ```suggestion
       partitions = [ds.scalar(True), ds.scalar(True)]
   ```

##########
File path: python/pyarrow/tests/test_dataset.py
##########
@@ -373,141 +357,70 @@ def test_partitioning():
     assert expr.equals(expected)
 
 
-def test_expression():
-    a = ds.ScalarExpression(1)
-    b = ds.ScalarExpression(1.1)
-    c = ds.ScalarExpression(True)
-    d = ds.ScalarExpression("string")
-    e = ds.ScalarExpression(None)
-
-    equal = ds.ComparisonExpression(ds.CompareOperator.Equal, a, b)
-    greater = a > b
-    assert equal.op == ds.CompareOperator.Equal
-
-    and_ = ds.AndExpression(a, b)
-    assert and_.left_operand.equals(a)
-    assert and_.right_operand.equals(b)
-    assert and_.equals(ds.AndExpression(a, b))
-    assert and_.equals(and_)
-
-    or_ = ds.OrExpression(a, b)
-    not_ = ds.NotExpression(ds.OrExpression(a, b))
-    is_valid = ds.IsValidExpression(a)
-    cast_safe = ds.CastExpression(a, pa.int32())
-    cast_unsafe = ds.CastExpression(a, pa.int32(), safe=False)
-    in_ = ds.InExpression(a, pa.array([1, 2, 3]))
-
-    assert is_valid.operand == a
-    assert in_.set_.equals(pa.array([1, 2, 3]))
-    assert cast_unsafe.to == pa.int32()
-    assert cast_unsafe.safe is False
-    assert cast_safe.safe is True
-
-    condition = ds.ComparisonExpression(
-        ds.CompareOperator.Greater,
-        ds.FieldExpression('i64'),
-        ds.ScalarExpression(5)
-    )
+def test_expression_serialization():
+    a = ds.Expression.scalar(1)
+    b = ds.Expression.scalar(1.1)
+    c = ds.Expression.scalar(True)
+    d = ds.Expression.scalar("string")
+    e = ds.Expression.scalar(None)
+
+    condition = ds.field('i64') > 5
     schema = pa.schema([
         pa.field('i64', pa.int64()),
         pa.field('f64', pa.float64())
     ])
     assert condition.validate(schema) == pa.bool_()
 
-    i64_is_5 = ds.ComparisonExpression(
-        ds.CompareOperator.Equal,
-        ds.FieldExpression('i64'),
-        ds.ScalarExpression(5)
-    )
-    i64_is_7 = ds.ComparisonExpression(
-        ds.CompareOperator.Equal,
-        ds.FieldExpression('i64'),
-        ds.ScalarExpression(7)
-    )
-    assert condition.assume(i64_is_5).equals(ds.ScalarExpression(False))
-    assert condition.assume(i64_is_7).equals(ds.ScalarExpression(True))
-    assert str(condition) == "(i64 > 5:int64)"
-    assert "(i64 > 5:int64)" in repr(condition)
+    assert condition.assume(ds.field('i64') == 5).equals(
+        ds.Expression.scalar(False))
 
-    all_exprs = [a, b, c, d, e, equal, greater, and_, or_, not_, is_valid,
-                 cast_unsafe, cast_safe, in_, condition, i64_is_5, i64_is_7]
+    assert condition.assume(ds.field('i64') == 7).equals(
+        ds.Expression.scalar(True))

Review comment:
       ```suggestion
           ds.scalar(True))
   ```

##########
File path: python/pyarrow/tests/test_dataset.py
##########
@@ -218,33 +216,21 @@ def test_filesystem_dataset(mockfs):
 
     # validation of required arguments
     with pytest.raises(TypeError, match="incorrect type"):
-        ds.FileSystemDataset(paths, format=file_format, filesystem=mockfs)
+        ds.FileSystemDataset(paths, schema=None, format=file_format,
+                             filesystem=mockfs)
     with pytest.raises(TypeError, match="incorrect type"):
-        ds.FileSystemDataset(paths, schema=schema, filesystem=mockfs)
+        ds.FileSystemDataset(paths, schema=schema, format=None,
+                             filesystem=mockfs)
     with pytest.raises(TypeError, match="incorrect type"):
-        ds.FileSystemDataset(paths, schema=schema, format=file_format)
+        ds.FileSystemDataset(paths, schema=schema, format=file_format,
+                             filesystem=None)
     # validation of root_partition
     with pytest.raises(TypeError, match="incorrect type"):
         ds.FileSystemDataset(paths, schema=schema, format=file_format,
                              filesystem=mockfs, root_partition=1)
 
-    root_partition = ds.ComparisonExpression(
-        ds.CompareOperator.Equal,
-        ds.FieldExpression('level'),
-        ds.ScalarExpression(1337)
-    )
-    partitions = [
-        ds.ComparisonExpression(
-            ds.CompareOperator.Equal,
-            ds.FieldExpression('part'),
-            ds.ScalarExpression(1)
-        ),
-        ds.ComparisonExpression(
-            ds.CompareOperator.Equal,
-            ds.FieldExpression('part'),
-            ds.ScalarExpression(2)
-        )
-    ]
+    root_partition = ds.field('level') == ds.Expression.scalar(1337)

Review comment:
       ```suggestion
       root_partition = ds.field('level') == ds.scalar(1337)
   ```

##########
File path: python/pyarrow/tests/test_dataset.py
##########
@@ -373,141 +357,70 @@ def test_partitioning():
     assert expr.equals(expected)
 
 
-def test_expression():
-    a = ds.ScalarExpression(1)
-    b = ds.ScalarExpression(1.1)
-    c = ds.ScalarExpression(True)
-    d = ds.ScalarExpression("string")
-    e = ds.ScalarExpression(None)
-
-    equal = ds.ComparisonExpression(ds.CompareOperator.Equal, a, b)
-    greater = a > b
-    assert equal.op == ds.CompareOperator.Equal
-
-    and_ = ds.AndExpression(a, b)
-    assert and_.left_operand.equals(a)
-    assert and_.right_operand.equals(b)
-    assert and_.equals(ds.AndExpression(a, b))
-    assert and_.equals(and_)
-
-    or_ = ds.OrExpression(a, b)
-    not_ = ds.NotExpression(ds.OrExpression(a, b))
-    is_valid = ds.IsValidExpression(a)
-    cast_safe = ds.CastExpression(a, pa.int32())
-    cast_unsafe = ds.CastExpression(a, pa.int32(), safe=False)
-    in_ = ds.InExpression(a, pa.array([1, 2, 3]))
-
-    assert is_valid.operand == a
-    assert in_.set_.equals(pa.array([1, 2, 3]))
-    assert cast_unsafe.to == pa.int32()
-    assert cast_unsafe.safe is False
-    assert cast_safe.safe is True
-
-    condition = ds.ComparisonExpression(
-        ds.CompareOperator.Greater,
-        ds.FieldExpression('i64'),
-        ds.ScalarExpression(5)
-    )
+def test_expression_serialization():
+    a = ds.Expression.scalar(1)
+    b = ds.Expression.scalar(1.1)
+    c = ds.Expression.scalar(True)
+    d = ds.Expression.scalar("string")
+    e = ds.Expression.scalar(None)
+
+    condition = ds.field('i64') > 5
     schema = pa.schema([
         pa.field('i64', pa.int64()),
         pa.field('f64', pa.float64())
     ])
     assert condition.validate(schema) == pa.bool_()
 
-    i64_is_5 = ds.ComparisonExpression(
-        ds.CompareOperator.Equal,
-        ds.FieldExpression('i64'),
-        ds.ScalarExpression(5)
-    )
-    i64_is_7 = ds.ComparisonExpression(
-        ds.CompareOperator.Equal,
-        ds.FieldExpression('i64'),
-        ds.ScalarExpression(7)
-    )
-    assert condition.assume(i64_is_5).equals(ds.ScalarExpression(False))
-    assert condition.assume(i64_is_7).equals(ds.ScalarExpression(True))
-    assert str(condition) == "(i64 > 5:int64)"
-    assert "(i64 > 5:int64)" in repr(condition)
+    assert condition.assume(ds.field('i64') == 5).equals(
+        ds.Expression.scalar(False))

Review comment:
       ```suggestion
           ds.scalar(False))
   ```

##########
File path: python/pyarrow/tests/test_dataset.py
##########
@@ -373,141 +357,70 @@ def test_partitioning():
     assert expr.equals(expected)
 
 
-def test_expression():
-    a = ds.ScalarExpression(1)
-    b = ds.ScalarExpression(1.1)
-    c = ds.ScalarExpression(True)
-    d = ds.ScalarExpression("string")
-    e = ds.ScalarExpression(None)
-
-    equal = ds.ComparisonExpression(ds.CompareOperator.Equal, a, b)
-    greater = a > b
-    assert equal.op == ds.CompareOperator.Equal
-
-    and_ = ds.AndExpression(a, b)
-    assert and_.left_operand.equals(a)
-    assert and_.right_operand.equals(b)
-    assert and_.equals(ds.AndExpression(a, b))
-    assert and_.equals(and_)
-
-    or_ = ds.OrExpression(a, b)
-    not_ = ds.NotExpression(ds.OrExpression(a, b))
-    is_valid = ds.IsValidExpression(a)
-    cast_safe = ds.CastExpression(a, pa.int32())
-    cast_unsafe = ds.CastExpression(a, pa.int32(), safe=False)
-    in_ = ds.InExpression(a, pa.array([1, 2, 3]))
-
-    assert is_valid.operand == a
-    assert in_.set_.equals(pa.array([1, 2, 3]))
-    assert cast_unsafe.to == pa.int32()
-    assert cast_unsafe.safe is False
-    assert cast_safe.safe is True
-
-    condition = ds.ComparisonExpression(
-        ds.CompareOperator.Greater,
-        ds.FieldExpression('i64'),
-        ds.ScalarExpression(5)
-    )
+def test_expression_serialization():
+    a = ds.Expression.scalar(1)
+    b = ds.Expression.scalar(1.1)
+    c = ds.Expression.scalar(True)
+    d = ds.Expression.scalar("string")
+    e = ds.Expression.scalar(None)
+
+    condition = ds.field('i64') > 5
     schema = pa.schema([
         pa.field('i64', pa.int64()),
         pa.field('f64', pa.float64())
     ])
     assert condition.validate(schema) == pa.bool_()
 
-    i64_is_5 = ds.ComparisonExpression(
-        ds.CompareOperator.Equal,
-        ds.FieldExpression('i64'),
-        ds.ScalarExpression(5)
-    )
-    i64_is_7 = ds.ComparisonExpression(
-        ds.CompareOperator.Equal,
-        ds.FieldExpression('i64'),
-        ds.ScalarExpression(7)
-    )
-    assert condition.assume(i64_is_5).equals(ds.ScalarExpression(False))
-    assert condition.assume(i64_is_7).equals(ds.ScalarExpression(True))
-    assert str(condition) == "(i64 > 5:int64)"
-    assert "(i64 > 5:int64)" in repr(condition)
+    assert condition.assume(ds.field('i64') == 5).equals(
+        ds.Expression.scalar(False))
 
-    all_exprs = [a, b, c, d, e, equal, greater, and_, or_, not_, is_valid,
-                 cast_unsafe, cast_safe, in_, condition, i64_is_5, i64_is_7]
+    assert condition.assume(ds.field('i64') == 7).equals(
+        ds.Expression.scalar(True))
+
+    all_exprs = [a, b, c, d, e, a == b, a > b, a & b, a | b, ~c,
+                 d.is_valid(), a.cast(pa.int32(), safe=False),
+                 a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),
+                 ds.field('i64') > 5, ds.field('i64') == 5,
+                 ds.field('i64') == 7]
     for expr in all_exprs:
+        print(str(expr))

Review comment:
       leftover

##########
File path: python/pyarrow/tests/test_dataset.py
##########
@@ -592,7 +505,7 @@ def test_filesystem_factory(mockfs, paths_or_selector):
     assert isinstance(factory.inspect_schemas(), list)
     assert isinstance(factory.finish(inspected_schema),
                       ds.FileSystemDataset)
-    assert factory.root_partition.equals(ds.ScalarExpression(True))
+    assert factory.root_partition.equals(ds.Expression.scalar(True))

Review comment:
       ```suggestion
       assert factory.root_partition.equals(ds.scalar(True))
   ```

##########
File path: python/pyarrow/tests/test_dataset.py
##########
@@ -373,141 +357,70 @@ def test_partitioning():
     assert expr.equals(expected)
 
 
-def test_expression():
-    a = ds.ScalarExpression(1)
-    b = ds.ScalarExpression(1.1)
-    c = ds.ScalarExpression(True)
-    d = ds.ScalarExpression("string")
-    e = ds.ScalarExpression(None)
-
-    equal = ds.ComparisonExpression(ds.CompareOperator.Equal, a, b)
-    greater = a > b
-    assert equal.op == ds.CompareOperator.Equal
-
-    and_ = ds.AndExpression(a, b)
-    assert and_.left_operand.equals(a)
-    assert and_.right_operand.equals(b)
-    assert and_.equals(ds.AndExpression(a, b))
-    assert and_.equals(and_)
-
-    or_ = ds.OrExpression(a, b)
-    not_ = ds.NotExpression(ds.OrExpression(a, b))
-    is_valid = ds.IsValidExpression(a)
-    cast_safe = ds.CastExpression(a, pa.int32())
-    cast_unsafe = ds.CastExpression(a, pa.int32(), safe=False)
-    in_ = ds.InExpression(a, pa.array([1, 2, 3]))
-
-    assert is_valid.operand == a
-    assert in_.set_.equals(pa.array([1, 2, 3]))
-    assert cast_unsafe.to == pa.int32()
-    assert cast_unsafe.safe is False
-    assert cast_safe.safe is True
-
-    condition = ds.ComparisonExpression(
-        ds.CompareOperator.Greater,
-        ds.FieldExpression('i64'),
-        ds.ScalarExpression(5)
-    )
+def test_expression_serialization():
+    a = ds.Expression.scalar(1)
+    b = ds.Expression.scalar(1.1)
+    c = ds.Expression.scalar(True)
+    d = ds.Expression.scalar("string")
+    e = ds.Expression.scalar(None)
+
+    condition = ds.field('i64') > 5
     schema = pa.schema([
         pa.field('i64', pa.int64()),
         pa.field('f64', pa.float64())
     ])
     assert condition.validate(schema) == pa.bool_()
 
-    i64_is_5 = ds.ComparisonExpression(
-        ds.CompareOperator.Equal,
-        ds.FieldExpression('i64'),
-        ds.ScalarExpression(5)
-    )
-    i64_is_7 = ds.ComparisonExpression(
-        ds.CompareOperator.Equal,
-        ds.FieldExpression('i64'),
-        ds.ScalarExpression(7)
-    )
-    assert condition.assume(i64_is_5).equals(ds.ScalarExpression(False))
-    assert condition.assume(i64_is_7).equals(ds.ScalarExpression(True))
-    assert str(condition) == "(i64 > 5:int64)"
-    assert "(i64 > 5:int64)" in repr(condition)
+    assert condition.assume(ds.field('i64') == 5).equals(
+        ds.Expression.scalar(False))
 
-    all_exprs = [a, b, c, d, e, equal, greater, and_, or_, not_, is_valid,
-                 cast_unsafe, cast_safe, in_, condition, i64_is_5, i64_is_7]
+    assert condition.assume(ds.field('i64') == 7).equals(
+        ds.Expression.scalar(True))
+
+    all_exprs = [a, b, c, d, e, a == b, a > b, a & b, a | b, ~c,
+                 d.is_valid(), a.cast(pa.int32(), safe=False),
+                 a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]),
+                 ds.field('i64') > 5, ds.field('i64') == 5,
+                 ds.field('i64') == 7]
     for expr in all_exprs:
+        print(str(expr))
+        assert isinstance(expr, ds.Expression)
         restored = pickle.loads(pickle.dumps(expr))
         assert expr.equals(restored)
 
 
-def test_expression_ergonomics():
+def test_expression_construction():
     zero = ds.scalar(0)
     one = ds.scalar(1)
     true = ds.scalar(True)
     false = ds.scalar(False)
     string = ds.scalar("string")
     field = ds.field("field")
 
-    assert one.equals(ds.ScalarExpression(1))
-    assert zero.equals(ds.ScalarExpression(0))
-    assert true.equals(ds.ScalarExpression(True))
-    assert false.equals(ds.ScalarExpression(False))
-    assert string.equals(ds.ScalarExpression("string"))
-    assert field.equals(ds.FieldExpression("field"))
-
-    expected = ds.AndExpression(ds.ScalarExpression(1), ds.ScalarExpression(0))
-    for expr in [one & zero, 1 & zero, one & 0]:
-        assert expr.equals(expected)
-
-    expected = ds.OrExpression(ds.ScalarExpression(1), ds.ScalarExpression(0))
-    for expr in [one | zero, 1 | zero, one | 0]:
-        assert expr.equals(expected)
-
-    comparison_ops = [
-        (operator.eq, ds.CompareOperator.Equal),
-        (operator.ne, ds.CompareOperator.NotEqual),
-        (operator.ge, ds.CompareOperator.GreaterEqual),
-        (operator.le, ds.CompareOperator.LessEqual),
-        (operator.lt, ds.CompareOperator.Less),
-        (operator.gt, ds.CompareOperator.Greater),
-    ]
-    for op, compare_op in comparison_ops:
-        expr = op(zero, one)
-        expected = ds.ComparisonExpression(compare_op, zero, one)
-        assert expr.equals(expected)
-
+    expr = zero | one == string
     expr = ~true == false
-    expected = ds.ComparisonExpression(
-        ds.CompareOperator.Equal,
-        ds.NotExpression(ds.ScalarExpression(True)),
-        ds.ScalarExpression(False)
-    )
-    assert expr.equals(expected)
-
     for typ in ("bool", pa.bool_()):
         expr = field.cast(typ) == true
-        expected = ds.ComparisonExpression(
-            ds.CompareOperator.Equal,
-            ds.CastExpression(ds.FieldExpression("field"), pa.bool_()),
-            ds.ScalarExpression(True)
-        )
-        assert expr.equals(expected)
 
     expr = field.isin([1, 2])
-    expected = ds.InExpression(ds.FieldExpression("field"), pa.array([1, 2]))
-    assert expr.equals(expected)
 
     with pytest.raises(TypeError):
-        field.isin(1)
+        expr = field.isin(1)
 
     # operations with non-scalar values
     with pytest.raises(TypeError):
-        field == [1]
+        expr = field == [1]
 
     with pytest.raises(TypeError):
-        field != {1}
+        expr = field != {1}
 
     with pytest.raises(TypeError):
-        field & [1]
+        expr = field & [1]
 
     with pytest.raises(TypeError):
-        field | [1]
+        expr = field | [1]
+
+    assert expr is not None  # silence flake8

Review comment:
       if you don't assign them to `expr` in the lines above (as it was 
before), you shouldn't need this to silence flake8. 
   Or did you get another linter error when not assigning them to a variable?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to