bkietz commented on a change in pull request #7026:
URL: https://github.com/apache/arrow/pull/7026#discussion_r416761581
########## File path: python/pyarrow/tests/test_dataset.py ########## @@ -373,141 +357,70 @@ def test_partitioning(): assert expr.equals(expected) -def test_expression(): - a = ds.ScalarExpression(1) - b = ds.ScalarExpression(1.1) - c = ds.ScalarExpression(True) - d = ds.ScalarExpression("string") - e = ds.ScalarExpression(None) - - equal = ds.ComparisonExpression(ds.CompareOperator.Equal, a, b) - greater = a > b - assert equal.op == ds.CompareOperator.Equal - - and_ = ds.AndExpression(a, b) - assert and_.left_operand.equals(a) - assert and_.right_operand.equals(b) - assert and_.equals(ds.AndExpression(a, b)) - assert and_.equals(and_) - - or_ = ds.OrExpression(a, b) - not_ = ds.NotExpression(ds.OrExpression(a, b)) - is_valid = ds.IsValidExpression(a) - cast_safe = ds.CastExpression(a, pa.int32()) - cast_unsafe = ds.CastExpression(a, pa.int32(), safe=False) - in_ = ds.InExpression(a, pa.array([1, 2, 3])) - - assert is_valid.operand == a - assert in_.set_.equals(pa.array([1, 2, 3])) - assert cast_unsafe.to == pa.int32() - assert cast_unsafe.safe is False - assert cast_safe.safe is True - - condition = ds.ComparisonExpression( - ds.CompareOperator.Greater, - ds.FieldExpression('i64'), - ds.ScalarExpression(5) - ) +def test_expression_serialization(): + a = ds.Expression.scalar(1) + b = ds.Expression.scalar(1.1) + c = ds.Expression.scalar(True) + d = ds.Expression.scalar("string") + e = ds.Expression.scalar(None) + + condition = ds.field('i64') > 5 schema = pa.schema([ pa.field('i64', pa.int64()), pa.field('f64', pa.float64()) ]) assert condition.validate(schema) == pa.bool_() - i64_is_5 = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.FieldExpression('i64'), - ds.ScalarExpression(5) - ) - i64_is_7 = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.FieldExpression('i64'), - ds.ScalarExpression(7) - ) - assert condition.assume(i64_is_5).equals(ds.ScalarExpression(False)) - assert condition.assume(i64_is_7).equals(ds.ScalarExpression(True)) - assert 
str(condition) == "(i64 > 5:int64)" - assert "(i64 > 5:int64)" in repr(condition) + assert condition.assume(ds.field('i64') == 5).equals( + ds.Expression.scalar(False)) - all_exprs = [a, b, c, d, e, equal, greater, and_, or_, not_, is_valid, - cast_unsafe, cast_safe, in_, condition, i64_is_5, i64_is_7] + assert condition.assume(ds.field('i64') == 7).equals( + ds.Expression.scalar(True)) + + all_exprs = [a, b, c, d, e, a == b, a > b, a & b, a | b, ~c, + d.is_valid(), a.cast(pa.int32(), safe=False), + a.cast(pa.int32(), safe=False), a.isin([1, 2, 3]), + ds.field('i64') > 5, ds.field('i64') == 5, + ds.field('i64') == 7] for expr in all_exprs: + print(str(expr)) + assert isinstance(expr, ds.Expression) restored = pickle.loads(pickle.dumps(expr)) assert expr.equals(restored) -def test_expression_ergonomics(): +def test_expression_construction(): zero = ds.scalar(0) one = ds.scalar(1) true = ds.scalar(True) false = ds.scalar(False) string = ds.scalar("string") field = ds.field("field") - assert one.equals(ds.ScalarExpression(1)) - assert zero.equals(ds.ScalarExpression(0)) - assert true.equals(ds.ScalarExpression(True)) - assert false.equals(ds.ScalarExpression(False)) - assert string.equals(ds.ScalarExpression("string")) - assert field.equals(ds.FieldExpression("field")) - - expected = ds.AndExpression(ds.ScalarExpression(1), ds.ScalarExpression(0)) - for expr in [one & zero, 1 & zero, one & 0]: - assert expr.equals(expected) - - expected = ds.OrExpression(ds.ScalarExpression(1), ds.ScalarExpression(0)) - for expr in [one | zero, 1 | zero, one | 0]: - assert expr.equals(expected) - - comparison_ops = [ - (operator.eq, ds.CompareOperator.Equal), - (operator.ne, ds.CompareOperator.NotEqual), - (operator.ge, ds.CompareOperator.GreaterEqual), - (operator.le, ds.CompareOperator.LessEqual), - (operator.lt, ds.CompareOperator.Less), - (operator.gt, ds.CompareOperator.Greater), - ] - for op, compare_op in comparison_ops: - expr = op(zero, one) - expected = 
ds.ComparisonExpression(compare_op, zero, one) - assert expr.equals(expected) - + expr = zero | one == string expr = ~true == false - expected = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.NotExpression(ds.ScalarExpression(True)), - ds.ScalarExpression(False) - ) - assert expr.equals(expected) - for typ in ("bool", pa.bool_()): expr = field.cast(typ) == true - expected = ds.ComparisonExpression( - ds.CompareOperator.Equal, - ds.CastExpression(ds.FieldExpression("field"), pa.bool_()), - ds.ScalarExpression(True) - ) - assert expr.equals(expected) expr = field.isin([1, 2]) - expected = ds.InExpression(ds.FieldExpression("field"), pa.array([1, 2])) - assert expr.equals(expected) with pytest.raises(TypeError): - field.isin(1) + expr = field.isin(1) # operations with non-scalar values with pytest.raises(TypeError): - field == [1] + expr = field == [1] with pytest.raises(TypeError): - field != {1} + expr = field != {1} with pytest.raises(TypeError): - field & [1] + expr = field & [1] with pytest.raises(TypeError): - field | [1] + expr = field | [1] + + assert expr is not None # silence flake8 Review comment: I didn't try that ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org