jorisvandenbossche commented on a change in pull request #7026: URL: https://github.com/apache/arrow/pull/7026#discussion_r416567863
########## File path: python/pyarrow/_dataset.pyx ########## @@ -41,6 +42,167 @@ def _forbid_instantiation(klass, subclasses_instead=True): raise TypeError(msg) +cdef class Expression: + + cdef: + shared_ptr[CExpression] wrapped + CExpression* expr + + def __init__(self, Buffer buffer=None): + if buffer is not None: + c_buffer = pyarrow_unwrap_buffer(buffer) + expr = GetResultValue(CExpression.Deserialize(deref(c_buffer))) + self.init(expr) Review comment: Maybe we can put this in an `Expression._deserialize` or so, to avoid putting this in the main constructor? (and keep this one using `_forbid_instantiation`) ########## File path: python/pyarrow/_dataset.pyx ########## @@ -41,6 +42,167 @@ def _forbid_instantiation(klass, subclasses_instead=True): raise TypeError(msg) +cdef class Expression: + + cdef: + shared_ptr[CExpression] wrapped + CExpression* expr + + def __init__(self, Buffer buffer=None): + if buffer is not None: + c_buffer = pyarrow_unwrap_buffer(buffer) + expr = GetResultValue(CExpression.Deserialize(deref(c_buffer))) + self.init(expr) Review comment: Another good reason is that right now calling this without `buffer` results in an expression that segfaults in certain conditions (like printing the repr) ########## File path: python/pyarrow/_dataset.pyx ########## @@ -41,6 +42,167 @@ def _forbid_instantiation(klass, subclasses_instead=True): raise TypeError(msg) +cdef class Expression: + + cdef: + shared_ptr[CExpression] wrapped + CExpression* expr + + def __init__(self, Buffer buffer=None): + if buffer is not None: + c_buffer = pyarrow_unwrap_buffer(buffer) + expr = GetResultValue(CExpression.Deserialize(deref(c_buffer))) + self.init(expr) + + cdef void init(self, const shared_ptr[CExpression]& sp): + self.wrapped = sp + self.expr = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CExpression]& sp): + self = Expression() + self.init(sp) + return self + + cdef inline shared_ptr[CExpression] unwrap(self): + return self.wrapped + + def 
equals(self, Expression other): + return self.expr.Equals(other.unwrap()) + + def __str__(self): + return frombytes(self.expr.ToString()) + + def __repr__(self): + return "<pyarrow.dataset.{0} {1}>".format( + self.__class__.__name__, str(self) + ) + + def __reduce__(self): + buffer = pyarrow_wrap_buffer(GetResultValue(self.expr.Serialize())) + return Expression, (buffer,) + + def validate(self, Schema schema not None): + """Validate this expression for execution against a schema. + + This will check that all reference fields are present (fields not in + the schema will be replaced with null) and all subexpressions are + executable. Returns the type to which this expression will evaluate. + + Parameters + ---------- + schema : Schema + Schema to execute the expression on. + + Returns + ------- + type : DataType + """ + cdef: + shared_ptr[CSchema] sp_schema + CResult[shared_ptr[CDataType]] result + sp_schema = pyarrow_unwrap_schema(schema) + result = self.expr.Validate(deref(sp_schema)) + return pyarrow_wrap_data_type(GetResultValue(result)) + + def assume(self, Expression given): + """Simplify to an equivalent Expression given assumed constraints.""" + return Expression.wrap(self.expr.Assume(given.unwrap())) + + def __invert__(self): + return Expression.wrap(CMakeNotExpression(self.unwrap())) + + @staticmethod + cdef shared_ptr[CExpression] _expr_or_scalar(object expr) except *: + if isinstance(expr, Expression): + return (<Expression> expr).unwrap() + return (<Expression> Expression.scalar(expr)).unwrap() + + @staticmethod + def wtf(): + return Expression.wrap(Expression._expr_or_scalar([])) Review comment: Leftover? 
########## File path: python/pyarrow/_dataset.pyx ########## @@ -41,6 +42,167 @@ def _forbid_instantiation(klass, subclasses_instead=True): raise TypeError(msg) +cdef class Expression: + + cdef: + shared_ptr[CExpression] wrapped + CExpression* expr + + def __init__(self, Buffer buffer=None): + if buffer is not None: + c_buffer = pyarrow_unwrap_buffer(buffer) + expr = GetResultValue(CExpression.Deserialize(deref(c_buffer))) + self.init(expr) + + cdef void init(self, const shared_ptr[CExpression]& sp): + self.wrapped = sp + self.expr = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CExpression]& sp): + self = Expression() Review comment: ```suggestion cdef Expression self = Expression.__new__(Expression) ``` (in case the init gets forbidden) ########## File path: python/pyarrow/_dataset.pyx ########## @@ -118,21 +280,15 @@ cdef class Dataset: ------- fragments : iterator of Fragment """ - cdef: - CFragmentIterator iterator - shared_ptr[CFragment] fragment + cdef CFragmentIterator c_fragments if filter is None or filter.expr == nullptr: - iterator = self.dataset.GetFragments() + c_fragments = self.dataset.GetFragments() else: - iterator = self.dataset.GetFragments(filter.unwrap()) + c_fragments = self.dataset.GetFragments(filter.unwrap()) - while True: - fragment = GetResultValue(iterator.Next()) - if fragment.get() == nullptr: - raise StopIteration() - else: - yield Fragment.wrap(fragment) + for maybe_fragment in c_fragments: Review comment: Why "maybe" fragment? (all elements should be fragments, no?) 
########## File path: python/pyarrow/_dataset.pyx ########## @@ -41,6 +42,167 @@ def _forbid_instantiation(klass, subclasses_instead=True): raise TypeError(msg) +cdef class Expression: + + cdef: + shared_ptr[CExpression] wrapped + CExpression* expr + + def __init__(self, Buffer buffer=None): + if buffer is not None: + c_buffer = pyarrow_unwrap_buffer(buffer) + expr = GetResultValue(CExpression.Deserialize(deref(c_buffer))) + self.init(expr) + + cdef void init(self, const shared_ptr[CExpression]& sp): + self.wrapped = sp + self.expr = sp.get() + + @staticmethod + cdef wrap(const shared_ptr[CExpression]& sp): + self = Expression() + self.init(sp) + return self + + cdef inline shared_ptr[CExpression] unwrap(self): + return self.wrapped + + def equals(self, Expression other): + return self.expr.Equals(other.unwrap()) + + def __str__(self): + return frombytes(self.expr.ToString()) + + def __repr__(self): + return "<pyarrow.dataset.{0} {1}>".format( + self.__class__.__name__, str(self) + ) + + def __reduce__(self): + buffer = pyarrow_wrap_buffer(GetResultValue(self.expr.Serialize())) + return Expression, (buffer,) + + def validate(self, Schema schema not None): + """Validate this expression for execution against a schema. + + This will check that all reference fields are present (fields not in + the schema will be replaced with null) and all subexpressions are + executable. Returns the type to which this expression will evaluate. + + Parameters + ---------- + schema : Schema + Schema to execute the expression on. 
+ + Returns + ------- + type : DataType + """ + cdef: + shared_ptr[CSchema] sp_schema + CResult[shared_ptr[CDataType]] result + sp_schema = pyarrow_unwrap_schema(schema) + result = self.expr.Validate(deref(sp_schema)) + return pyarrow_wrap_data_type(GetResultValue(result)) + + def assume(self, Expression given): + """Simplify to an equivalent Expression given assumed constraints.""" + return Expression.wrap(self.expr.Assume(given.unwrap())) + + def __invert__(self): + return Expression.wrap(CMakeNotExpression(self.unwrap())) + + @staticmethod + cdef shared_ptr[CExpression] _expr_or_scalar(object expr) except *: + if isinstance(expr, Expression): + return (<Expression> expr).unwrap() + return (<Expression> Expression.scalar(expr)).unwrap() + + @staticmethod + def wtf(): + return Expression.wrap(Expression._expr_or_scalar([])) + + def __richcmp__(self, other, int op): + cdef: + shared_ptr[CExpression] c_expr + shared_ptr[CExpression] c_left + shared_ptr[CExpression] c_right + + c_left = self.unwrap() + c_right = Expression._expr_or_scalar(other) + + if op == Py_EQ: + c_expr = CMakeEqualExpression(move(c_left), move(c_right)) + elif op == Py_NE: + c_expr = CMakeNotEqualExpression(move(c_left), move(c_right)) + elif op == Py_GT: + c_expr = CMakeGreaterExpression(move(c_left), move(c_right)) + elif op == Py_GE: + c_expr = CMakeGreaterEqualExpression(move(c_left), move(c_right)) + elif op == Py_LT: + c_expr = CMakeLessExpression(move(c_left), move(c_right)) + elif op == Py_LE: + c_expr = CMakeLessEqualExpression(move(c_left), move(c_right)) + + return Expression.wrap(c_expr) + + def __and__(Expression self, other): + c_other = Expression._expr_or_scalar(other) + return Expression.wrap(CMakeAndExpression(self.wrapped, + move(c_other))) + + def __or__(Expression self, other): + c_other = Expression._expr_or_scalar(other) + return Expression.wrap(CMakeOrExpression(self.wrapped, + move(c_other))) + + def is_valid(self): + """Checks whether the expression is not-null 
(valid)""" + return Expression.wrap(self.expr.IsValid().Copy()) + + def cast(self, type, bint safe=True): + """Explicitly change the expression's data type""" + cdef CastOptions options + options = CastOptions.safe() if safe else CastOptions.unsafe() + c_type = pyarrow_unwrap_data_type(ensure_type(type)) + return Expression.wrap(self.expr.CastTo(c_type, + options.unwrap()).Copy()) + + def isin(self, values): + """Checks whether the expression is contained in values""" + if not isinstance(values, pa.Array): + values = pa.array(values) + c_values = pyarrow_unwrap_array(values) + return Expression.wrap(self.expr.In(c_values).Copy()) + + @staticmethod + def field(str name not None): + return Expression.wrap(CMakeFieldExpression(tobytes(name))) + + @staticmethod + def scalar(value): Review comment: ```suggestion def _field(str name not None): return Expression.wrap(CMakeFieldExpression(tobytes(name))) @staticmethod def _scalar(value): ``` To avoid public API exposure (that we might want/need to preserve later on), I would make those private. 
For the end-user, there are already the `ds.field(..)` and `ds.scalar(..)` functions ########## File path: python/pyarrow/parquet.py ########## @@ -154,27 +154,19 @@ def convert_single_predicate(col, op, val): '"{0}" is not a valid operator in predicates.'.format( (col, op, val))) - or_exprs = [] + disjunction_members = [] for conjunction in filters: - and_exprs = [] - for col, op, val in conjunction: - and_exprs.append(convert_single_predicate(col, op, val)) + conjunction_members = [ + convert_single_predicate(col, op, val) + for col, op, val in conjunction + ] - expr = and_exprs[0] - if len(and_exprs) > 1: - for and_expr in and_exprs[1:]: - expr = ds.AndExpression(expr, and_expr) + conjunction = reduce(lambda acc, one: acc & one, conjunction_members) Review comment: Nice :) Further change could be to use `operator.and_` / `or_` instead of `lambda acc, one: acc & one` ########## File path: python/pyarrow/_dataset.pyx ########## @@ -27,6 +27,7 @@ from pyarrow.lib cimport * from pyarrow.includes.libarrow_dataset cimport * from pyarrow.compat import frombytes, tobytes from pyarrow._fs cimport FileSystem, FileInfo, FileSelector +import functools Review comment: I think this import is not used anymore in this file now? ########## File path: python/pyarrow/tests/test_dataset.py ########## @@ -218,33 +216,21 @@ def test_filesystem_dataset(mockfs): # validation of required arguments with pytest.raises(TypeError, match="incorrect type"): - ds.FileSystemDataset(paths, format=file_format, filesystem=mockfs) + ds.FileSystemDataset(paths, schema=None, format=file_format, + filesystem=mockfs) Review comment: Can you leave those as they were as well? The goal here was to test that you get a decent error message when leaving out one of the required arguments (maybe should have added a better comment about that), which was the reason I used the manual checking in FileSystemDataset constructor instead of using actual required positional arguments. 
########## File path: python/pyarrow/tests/test_dataset.py ########## @@ -373,141 +357,70 @@ def test_partitioning(): assert expr.equals(expected) -def test_expression(): - a = ds.ScalarExpression(1) - b = ds.ScalarExpression(1.1) - c = ds.ScalarExpression(True) - d = ds.ScalarExpression("string") - e = ds.ScalarExpression(None) - - equal = ds.ComparisonExpression(ds.CompareOperator.Equal, a, b) - greater = a > b - assert equal.op == ds.CompareOperator.Equal - - and_ = ds.AndExpression(a, b) - assert and_.left_operand.equals(a) - assert and_.right_operand.equals(b) - assert and_.equals(ds.AndExpression(a, b)) - assert and_.equals(and_) - - or_ = ds.OrExpression(a, b) - not_ = ds.NotExpression(ds.OrExpression(a, b)) - is_valid = ds.IsValidExpression(a) - cast_safe = ds.CastExpression(a, pa.int32()) - cast_unsafe = ds.CastExpression(a, pa.int32(), safe=False) - in_ = ds.InExpression(a, pa.array([1, 2, 3])) - - assert is_valid.operand == a - assert in_.set_.equals(pa.array([1, 2, 3])) - assert cast_unsafe.to == pa.int32() - assert cast_unsafe.safe is False - assert cast_safe.safe is True - - condition = ds.ComparisonExpression( - ds.CompareOperator.Greater, - ds.FieldExpression('i64'), - ds.ScalarExpression(5) - ) +def test_expression_serialization(): + a = ds.Expression.scalar(1) + b = ds.Expression.scalar(1.1) + c = ds.Expression.scalar(True) + d = ds.Expression.scalar("string") + e = ds.Expression.scalar(None) Review comment: ```suggestion a = ds.scalar(1) b = ds.scalar(1.1) c = ds.scalar(True) d = ds.scalar("string") e = ds.scalar(None) ``` ########## File path: python/pyarrow/tests/test_dataset.py ########## @@ -198,11 +197,10 @@ def test_filesystem_dataset(mockfs): file_format = ds.ParquetFileFormat() paths = ['subdir/1/xxx/file0.parquet', 'subdir/2/yyy/file1.parquet'] - partitions = [ds.ScalarExpression(True), ds.ScalarExpression(True)] + partitions = [ds.Expression.scalar(True), ds.Expression.scalar(True)] Review comment: ```suggestion partitions = 
[ds.scalar(True), ds.scalar(True)] ``` ########## File path: python/pyarrow/tests/test_dataset.py ########## @@ -373,141 +357,70 @@ def test_partitioning(): assert expr.equals(expected) -def test_expression(): - a = ds.ScalarExpression(1) - b = ds.ScalarExpression(1.1) - c = ds.ScalarExpression(True) - d = ds.ScalarExpression("string") - e = ds.ScalarExpression(None) - - equal = ds.ComparisonExpression(ds.CompareOperator.Equal, a, b) - greater = a > b - assert equal.op == ds.CompareOperator.Equal - - and_ = ds.AndExpression(a, b) - assert and_.left_operand.equals(a) - assert and_.right_operand.equals(b) - assert and_.equals(ds.AndExpression(a, b)) - assert and_.equals(and_) - - or_ = ds.OrExpression(a, b) - not_ = ds.NotExpression(ds.OrExpression(a, b)) - is_valid = ds.IsValidExpression(a) - cast_safe = ds.CastExpression(a, pa.int32()) - cast_unsafe = ds.CastExpression(a, pa.int32(), safe=False) - in_ = ds.InExpression(a, pa.array([1, 2, 3])) - - assert is_valid.operand == a - assert in_.set_.equals(pa.array([1, 2, 3])) - assert cast_unsafe.to == pa.int32() - assert cast_unsafe.safe is False - assert cast_safe.safe is True - - condition = ds.ComparisonExpression( - ds.CompareOperator.Greater, - ds.FieldExpression('i64'), - ds.ScalarExpression(5) - ) +def test_expression_serialization(): + a = ds.Expression.scalar(1) + b = ds.Expression.scalar(1.1) + c = ds.Expression.scalar(True) + d = ds.Expression.scalar("string") + e = ds.Expression.scalar(None) + + condition = ds.field('i64') > 5 schema = pa.schema([ pa.field('i64', pa.int64()), pa.field('f64', pa.float64()) ]) assert condition.validate(schema) == pa.bool_() - i64_is_5 = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.FieldExpression('i64'), - ds.ScalarExpression(5) - ) - i64_is_7 = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.FieldExpression('i64'), - ds.ScalarExpression(7) - ) - assert condition.assume(i64_is_5).equals(ds.ScalarExpression(False)) - assert 
condition.assume(i64_is_7).equals(ds.ScalarExpression(True)) - assert str(condition) == "(i64 > 5:int64)" - assert "(i64 > 5:int64)" in repr(condition) + assert condition.assume(ds.field('i64') == 5).equals( + ds.Expression.scalar(False)) - all_exprs = [a, b, c, d, e, equal, greater, and_, or_, not_, is_valid, - cast_unsafe, cast_safe, in_, condition, i64_is_5, i64_is_7] + assert condition.assume(ds.field('i64') == 7).equals( + ds.Expression.scalar(True)) Review comment: ```suggestion ds.scalar(True)) ``` ########## File path: python/pyarrow/tests/test_dataset.py ########## @@ -218,33 +216,21 @@ def test_filesystem_dataset(mockfs): # validation of required arguments with pytest.raises(TypeError, match="incorrect type"): - ds.FileSystemDataset(paths, format=file_format, filesystem=mockfs) + ds.FileSystemDataset(paths, schema=None, format=file_format, + filesystem=mockfs) with pytest.raises(TypeError, match="incorrect type"): - ds.FileSystemDataset(paths, schema=schema, filesystem=mockfs) + ds.FileSystemDataset(paths, schema=schema, format=None, + filesystem=mockfs) with pytest.raises(TypeError, match="incorrect type"): - ds.FileSystemDataset(paths, schema=schema, format=file_format) + ds.FileSystemDataset(paths, schema=schema, format=file_format, + filesystem=None) # validation of root_partition with pytest.raises(TypeError, match="incorrect type"): ds.FileSystemDataset(paths, schema=schema, format=file_format, filesystem=mockfs, root_partition=1) - root_partition = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.FieldExpression('level'), - ds.ScalarExpression(1337) - ) - partitions = [ - ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.FieldExpression('part'), - ds.ScalarExpression(1) - ), - ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.FieldExpression('part'), - ds.ScalarExpression(2) - ) - ] + root_partition = ds.field('level') == ds.Expression.scalar(1337) Review comment: ```suggestion root_partition = ds.field('level') == 
ds.scalar(1337) ``` ########## File path: python/pyarrow/tests/test_dataset.py ########## @@ -373,141 +357,70 @@ def test_partitioning(): assert expr.equals(expected) -def test_expression(): - a = ds.ScalarExpression(1) - b = ds.ScalarExpression(1.1) - c = ds.ScalarExpression(True) - d = ds.ScalarExpression("string") - e = ds.ScalarExpression(None) - - equal = ds.ComparisonExpression(ds.CompareOperator.Equal, a, b) - greater = a > b - assert equal.op == ds.CompareOperator.Equal - - and_ = ds.AndExpression(a, b) - assert and_.left_operand.equals(a) - assert and_.right_operand.equals(b) - assert and_.equals(ds.AndExpression(a, b)) - assert and_.equals(and_) - - or_ = ds.OrExpression(a, b) - not_ = ds.NotExpression(ds.OrExpression(a, b)) - is_valid = ds.IsValidExpression(a) - cast_safe = ds.CastExpression(a, pa.int32()) - cast_unsafe = ds.CastExpression(a, pa.int32(), safe=False) - in_ = ds.InExpression(a, pa.array([1, 2, 3])) - - assert is_valid.operand == a - assert in_.set_.equals(pa.array([1, 2, 3])) - assert cast_unsafe.to == pa.int32() - assert cast_unsafe.safe is False - assert cast_safe.safe is True - - condition = ds.ComparisonExpression( - ds.CompareOperator.Greater, - ds.FieldExpression('i64'), - ds.ScalarExpression(5) - ) +def test_expression_serialization(): + a = ds.Expression.scalar(1) + b = ds.Expression.scalar(1.1) + c = ds.Expression.scalar(True) + d = ds.Expression.scalar("string") + e = ds.Expression.scalar(None) + + condition = ds.field('i64') > 5 schema = pa.schema([ pa.field('i64', pa.int64()), pa.field('f64', pa.float64()) ]) assert condition.validate(schema) == pa.bool_() - i64_is_5 = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.FieldExpression('i64'), - ds.ScalarExpression(5) - ) - i64_is_7 = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.FieldExpression('i64'), - ds.ScalarExpression(7) - ) - assert condition.assume(i64_is_5).equals(ds.ScalarExpression(False)) - assert 
condition.assume(i64_is_7).equals(ds.ScalarExpression(True)) - assert str(condition) == "(i64 > 5:int64)" - assert "(i64 > 5:int64)" in repr(condition) + assert condition.assume(ds.field('i64') == 5).equals( + ds.Expression.scalar(False)) Review comment: ```suggestion ds.scalar(False)) ``` ########## File path: python/pyarrow/tests/test_dataset.py ########## @@ -373,141 +357,70 @@ def test_partitioning(): assert expr.equals(expected) -def test_expression(): - a = ds.ScalarExpression(1) - b = ds.ScalarExpression(1.1) - c = ds.ScalarExpression(True) - d = ds.ScalarExpression("string") - e = ds.ScalarExpression(None) - - equal = ds.ComparisonExpression(ds.CompareOperator.Equal, a, b) - greater = a > b - assert equal.op == ds.CompareOperator.Equal - - and_ = ds.AndExpression(a, b) - assert and_.left_operand.equals(a) - assert and_.right_operand.equals(b) - assert and_.equals(ds.AndExpression(a, b)) - assert and_.equals(and_) - - or_ = ds.OrExpression(a, b) - not_ = ds.NotExpression(ds.OrExpression(a, b)) - is_valid = ds.IsValidExpression(a) - cast_safe = ds.CastExpression(a, pa.int32()) - cast_unsafe = ds.CastExpression(a, pa.int32(), safe=False) - in_ = ds.InExpression(a, pa.array([1, 2, 3])) - - assert is_valid.operand == a - assert in_.set_.equals(pa.array([1, 2, 3])) - assert cast_unsafe.to == pa.int32() - assert cast_unsafe.safe is False - assert cast_safe.safe is True - - condition = ds.ComparisonExpression( - ds.CompareOperator.Greater, - ds.FieldExpression('i64'), - ds.ScalarExpression(5) - ) +def test_expression_serialization(): + a = ds.Expression.scalar(1) + b = ds.Expression.scalar(1.1) + c = ds.Expression.scalar(True) + d = ds.Expression.scalar("string") + e = ds.Expression.scalar(None) + + condition = ds.field('i64') > 5 schema = pa.schema([ pa.field('i64', pa.int64()), pa.field('f64', pa.float64()) ]) assert condition.validate(schema) == pa.bool_() - i64_is_5 = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.FieldExpression('i64'), - 
ds.ScalarExpression(5) - ) - i64_is_7 = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.FieldExpression('i64'), - ds.ScalarExpression(7) - ) - assert condition.assume(i64_is_5).equals(ds.ScalarExpression(False)) - assert condition.assume(i64_is_7).equals(ds.ScalarExpression(True)) - assert str(condition) == "(i64 > 5:int64)" - assert "(i64 > 5:int64)" in repr(condition) + assert condition.assume(ds.field('i64') == 5).equals( + ds.Expression.scalar(False)) - all_exprs = [a, b, c, d, e, equal, greater, and_, or_, not_, is_valid, - cast_unsafe, cast_safe, in_, condition, i64_is_5, i64_is_7] + assert condition.assume(ds.field('i64') == 7).equals( + ds.Expression.scalar(True)) + + all_exprs = [a, b, c, d, e, a == b, a > b, a & b, a | b, ~c, + d.is_valid(), a.cast(pa.int32(), safe=False), + a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]), + ds.field('i64') > 5, ds.field('i64') == 5, + ds.field('i64') == 7] for expr in all_exprs: + print(str(expr)) Review comment: leftover ########## File path: python/pyarrow/tests/test_dataset.py ########## @@ -592,7 +505,7 @@ def test_filesystem_factory(mockfs, paths_or_selector): assert isinstance(factory.inspect_schemas(), list) assert isinstance(factory.finish(inspected_schema), ds.FileSystemDataset) - assert factory.root_partition.equals(ds.ScalarExpression(True)) + assert factory.root_partition.equals(ds.Expression.scalar(True)) Review comment: ```suggestion assert factory.root_partition.equals(ds.scalar(True)) ``` ########## File path: python/pyarrow/tests/test_dataset.py ########## @@ -373,141 +357,70 @@ def test_partitioning(): assert expr.equals(expected) -def test_expression(): - a = ds.ScalarExpression(1) - b = ds.ScalarExpression(1.1) - c = ds.ScalarExpression(True) - d = ds.ScalarExpression("string") - e = ds.ScalarExpression(None) - - equal = ds.ComparisonExpression(ds.CompareOperator.Equal, a, b) - greater = a > b - assert equal.op == ds.CompareOperator.Equal - - and_ = ds.AndExpression(a, b) - assert 
and_.left_operand.equals(a) - assert and_.right_operand.equals(b) - assert and_.equals(ds.AndExpression(a, b)) - assert and_.equals(and_) - - or_ = ds.OrExpression(a, b) - not_ = ds.NotExpression(ds.OrExpression(a, b)) - is_valid = ds.IsValidExpression(a) - cast_safe = ds.CastExpression(a, pa.int32()) - cast_unsafe = ds.CastExpression(a, pa.int32(), safe=False) - in_ = ds.InExpression(a, pa.array([1, 2, 3])) - - assert is_valid.operand == a - assert in_.set_.equals(pa.array([1, 2, 3])) - assert cast_unsafe.to == pa.int32() - assert cast_unsafe.safe is False - assert cast_safe.safe is True - - condition = ds.ComparisonExpression( - ds.CompareOperator.Greater, - ds.FieldExpression('i64'), - ds.ScalarExpression(5) - ) +def test_expression_serialization(): + a = ds.Expression.scalar(1) + b = ds.Expression.scalar(1.1) + c = ds.Expression.scalar(True) + d = ds.Expression.scalar("string") + e = ds.Expression.scalar(None) + + condition = ds.field('i64') > 5 schema = pa.schema([ pa.field('i64', pa.int64()), pa.field('f64', pa.float64()) ]) assert condition.validate(schema) == pa.bool_() - i64_is_5 = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.FieldExpression('i64'), - ds.ScalarExpression(5) - ) - i64_is_7 = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.FieldExpression('i64'), - ds.ScalarExpression(7) - ) - assert condition.assume(i64_is_5).equals(ds.ScalarExpression(False)) - assert condition.assume(i64_is_7).equals(ds.ScalarExpression(True)) - assert str(condition) == "(i64 > 5:int64)" - assert "(i64 > 5:int64)" in repr(condition) + assert condition.assume(ds.field('i64') == 5).equals( + ds.Expression.scalar(False)) - all_exprs = [a, b, c, d, e, equal, greater, and_, or_, not_, is_valid, - cast_unsafe, cast_safe, in_, condition, i64_is_5, i64_is_7] + assert condition.assume(ds.field('i64') == 7).equals( + ds.Expression.scalar(True)) + + all_exprs = [a, b, c, d, e, a == b, a > b, a & b, a | b, ~c, + d.is_valid(), a.cast(pa.int32(), safe=False), + 
a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]), + ds.field('i64') > 5, ds.field('i64') == 5, + ds.field('i64') == 7] for expr in all_exprs: + print(str(expr)) + assert isinstance(expr, ds.Expression) restored = pickle.loads(pickle.dumps(expr)) assert expr.equals(restored) -def test_expression_ergonomics(): +def test_expression_construction(): zero = ds.scalar(0) one = ds.scalar(1) true = ds.scalar(True) false = ds.scalar(False) string = ds.scalar("string") field = ds.field("field") - assert one.equals(ds.ScalarExpression(1)) - assert zero.equals(ds.ScalarExpression(0)) - assert true.equals(ds.ScalarExpression(True)) - assert false.equals(ds.ScalarExpression(False)) - assert string.equals(ds.ScalarExpression("string")) - assert field.equals(ds.FieldExpression("field")) - - expected = ds.AndExpression(ds.ScalarExpression(1), ds.ScalarExpression(0)) - for expr in [one & zero, 1 & zero, one & 0]: - assert expr.equals(expected) - - expected = ds.OrExpression(ds.ScalarExpression(1), ds.ScalarExpression(0)) - for expr in [one | zero, 1 | zero, one | 0]: - assert expr.equals(expected) - - comparison_ops = [ - (operator.eq, ds.CompareOperator.Equal), - (operator.ne, ds.CompareOperator.NotEqual), - (operator.ge, ds.CompareOperator.GreaterEqual), - (operator.le, ds.CompareOperator.LessEqual), - (operator.lt, ds.CompareOperator.Less), - (operator.gt, ds.CompareOperator.Greater), - ] - for op, compare_op in comparison_ops: - expr = op(zero, one) - expected = ds.ComparisonExpression(compare_op, zero, one) - assert expr.equals(expected) - + expr = zero | one == string expr = ~true == false - expected = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.NotExpression(ds.ScalarExpression(True)), - ds.ScalarExpression(False) - ) - assert expr.equals(expected) - for typ in ("bool", pa.bool_()): expr = field.cast(typ) == true - expected = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.CastExpression(ds.FieldExpression("field"), pa.bool_()), - 
ds.ScalarExpression(True) - ) - assert expr.equals(expected) expr = field.isin([1, 2]) - expected = ds.InExpression(ds.FieldExpression("field"), pa.array([1, 2])) - assert expr.equals(expected) with pytest.raises(TypeError): - field.isin(1) + expr = field.isin(1) # operations with non-scalar values with pytest.raises(TypeError): - field == [1] + expr = field == [1] with pytest.raises(TypeError): - field != {1} + expr = field != {1} with pytest.raises(TypeError): - field & [1] + expr = field & [1] with pytest.raises(TypeError): - field | [1] + expr = field | [1] + + assert expr is not None # silence flake8 Review comment: if you don't assign them to `expr` in the lines above (how it was before), you shouldn't need this to silence flake8. Or did you get another linter error when not assigning them to a variable? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org