cloud-fan commented on a change in pull request #35768:
URL: https://github.com/apache/spark/pull/35768#discussion_r829092334
##########
File path:
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Expression.java
##########
@@ -26,8 +26,15 @@
*/
@Evolving
public interface Expression {
+ NamedReference[] EMPTY_REFERENCE = new NamedReference[0];
+
/**
* Format the expression as a human readable SQL-like string.
*/
default String describe() { return this.toString(); }
+
+ /**
+ * List of fields or columns that are referenced by this expression.
+ */
+ default NamedReference[] references() { return EMPTY_REFERENCE; }
Review comment:
Hmm, the base `Expression` interface does not define `children`, so we
can't provide a reasonable default implementation of `references` here.
##########
File path:
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
##########
@@ -190,6 +113,30 @@ public GeneralScalarExpression(String name, Expression[]
children) {
public String name() { return name; }
public Expression[] children() { return children; }
+ @Override
+ public NamedReference[] references() {
+ return Arrays.stream(children()).map(e -> {
+ if (e instanceof NamedReference) {
Review comment:
Why do we need this check? Following the catalyst `Expression`, we just
need to combine the references from the children and deduplicate them.
##########
File path:
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/GeneralScalarExpression.java
##########
@@ -190,6 +113,30 @@ public GeneralScalarExpression(String name, Expression[]
children) {
public String name() { return name; }
public Expression[] children() { return children; }
+ @Override
+ public NamedReference[] references() {
+ return Arrays.stream(children()).map(e -> {
+ if (e instanceof NamedReference) {
Review comment:
`Arrays.stream(children()).map(e -> e.references())`
##########
File path:
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/Transform.java
##########
@@ -34,11 +34,6 @@
*/
String name();
- /**
- * Returns all field references in the transform arguments.
- */
- NamedReference[] references();
Review comment:
We need to provide a default implementation here, which is exactly the
same as the one in `GeneralScalarExpression`.
##########
File path:
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
##########
@@ -58,6 +58,9 @@ public GeneralAggregateFunc(String name, boolean isDistinct,
NamedReference[] in
this.inputs = inputs;
}
+ @Override
+ public NamedReference[] references() { return inputs; }
Review comment:
Not related to this PR: should `GeneralAggregateFunc.inputs` be
`Expression[]`?
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala
##########
@@ -64,6 +69,13 @@ sealed abstract class Filter {
private[sql] def containsNestedColumn: Boolean = {
this.v2references.exists(_.length > 1)
}
+
+ /**
+ * Converts V1 filter to V2 filter
+ *
+ * @since 3.3.0
Review comment:
We should remove the `@since` tag, as this is not a public API.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
##########
@@ -609,20 +545,39 @@ private[sql] object DataSourceV2Strategy {
}
protected[sql] def rebuildExpressionFromFilter(
- filter: V2Filter,
- translatedFilterToExpr: mutable.HashMap[V2Filter, Expression]):
Expression = {
- filter match {
+ predicate: Predicate,
+ translatedFilterToExpr: mutable.HashMap[Predicate, Expression]):
Expression = {
+ predicate match {
case and: V2And =>
- expressions.And(rebuildExpressionFromFilter(and.left,
translatedFilterToExpr),
- rebuildExpressionFromFilter(and.right, translatedFilterToExpr))
+ expressions.And(
+ rebuildExpressionFromFilter(and.left(), translatedFilterToExpr),
+ rebuildExpressionFromFilter(and.right(), translatedFilterToExpr))
case or: V2Or =>
- expressions.Or(rebuildExpressionFromFilter(or.left,
translatedFilterToExpr),
- rebuildExpressionFromFilter(or.right, translatedFilterToExpr))
+ expressions.Or(
+ rebuildExpressionFromFilter(or.left(), translatedFilterToExpr),
+ rebuildExpressionFromFilter(or.right(), translatedFilterToExpr))
case not: V2Not =>
- expressions.Not(rebuildExpressionFromFilter(not.child,
translatedFilterToExpr))
- case other =>
- translatedFilterToExpr.getOrElse(other,
- throw new IllegalStateException("Failed to rebuild Expression for
filter: " + filter))
+ expressions.Not(rebuildExpressionFromFilter(not.child(),
translatedFilterToExpr))
+ case _ =>
+ translatedFilterToExpr.getOrElse(predicate,
+ throw new IllegalStateException("Failed to rebuild Expression for
filter: " + predicate))
}
}
}
+
+/**
+ * Get the expression of DS V2 to represent catalyst predicate that can be
pushed down.
+ */
+case class PushablePredicate(nestedPredicatePushdownEnabled: Boolean) {
+ private val pushableColumn: PushableColumnBase =
PushableColumn(nestedPredicatePushdownEnabled)
+
+ def unapply(e: Expression): Option[Predicate] = e match {
+ case col @ pushableColumn(name) if col.dataType.isInstanceOf[BooleanType]
=>
Review comment:
Are you sure we will hit this case? If yes, shall we also support
`AND(col_a, col_b)`?
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
##########
@@ -220,12 +221,19 @@ abstract class JdbcDialect extends Serializable with
Logging{
}
class JDBCSQLBuilder extends V2ExpressionSQLBuilder {
- override def visitFieldReference(fieldRef: FieldReference): String = {
- if (fieldRef.fieldNames().length != 1) {
+ override def visitLiteral(literal: Literal[_]): String = {
+ val value =
+ compileValue(CatalystTypeConverters.convertToScala(literal.value(),
literal.dataType()))
+ s"$value"
Review comment:
why do we need 3 lines of code here?
```
compileValue(CatalystTypeConverters.convertToScala(literal.value(),
literal.dataType()))
```
##########
File path: sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala
##########
@@ -237,11 +244,7 @@ abstract class JdbcDialect extends Serializable with
Logging{
@Since("3.3.0")
def compileExpression(expr: Expression): Option[String] = {
val jdbcSQLBuilder = new JDBCSQLBuilder()
- try {
- Some(jdbcSQLBuilder.build(expr))
- } catch {
- case _: IllegalArgumentException => None
Review comment:
ping @beliefer
##########
File path:
sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2FiltersSuite.scala
##########
@@ -1,204 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources.v2
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.connector.expressions.{FieldReference, Literal,
LiteralValue}
-import org.apache.spark.sql.connector.expressions.filter._
-import org.apache.spark.sql.execution.datasources.v2.FiltersV2Suite.ref
-import org.apache.spark.sql.types.IntegerType
-import org.apache.spark.unsafe.types.UTF8String
-
-class FiltersV2Suite extends SparkFunSuite {
Review comment:
shall we rename it to `V2PredicateSuite`?
##########
File path:
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/aggregate/GeneralAggregateFunc.java
##########
@@ -46,18 +45,21 @@
public final class GeneralAggregateFunc implements AggregateFunc {
private final String name;
private final boolean isDistinct;
- private final NamedReference[] inputs;
+ private final Expression[] inputs;
public String name() { return name; }
public boolean isDistinct() { return isDistinct; }
- public NamedReference[] inputs() { return inputs; }
+ public Expression[] inputs() { return inputs; }
Review comment:
This API is newly added in 3.3.0. We can directly rename `inputs` to
`children`.
##########
File path:
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java
##########
@@ -17,34 +17,32 @@
package org.apache.spark.sql.connector.expressions.filter;
-import java.util.Objects;
-
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Literal;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
/**
- * A filter that always evaluates to {@code false}.
+ * A predicate that always evaluates to {@code false}.
*
* @since 3.3.0
*/
@Evolving
-public final class AlwaysFalse extends Filter {
+public final class AlwaysFalse extends Predicate implements Literal<Boolean> {
+
+ private DataType dataType = DataTypes.BooleanType;
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
- return true;
+ public AlwaysFalse() {
+ super("ALWAYS_FALSE", new Predicate[]{});
}
- @Override
- public int hashCode() {
- return Objects.hash();
+ public Boolean value() {
+ return false;
}
- @Override
- public String toString() { return "FALSE"; }
+ public DataType dataType() {
+ return dataType;
Review comment:
We can remove `private DataType dataType = DataTypes.BooleanType;` and
simply do `return DataTypes.BooleanType;` here.
##########
File path:
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysFalse.java
##########
@@ -17,34 +17,32 @@
package org.apache.spark.sql.connector.expressions.filter;
-import java.util.Objects;
-
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Literal;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
/**
- * A filter that always evaluates to {@code false}.
+ * A predicate that always evaluates to {@code false}.
*
* @since 3.3.0
*/
@Evolving
-public final class AlwaysFalse extends Filter {
+public final class AlwaysFalse extends Predicate implements Literal<Boolean> {
+
+ private DataType dataType = DataTypes.BooleanType;
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
- return true;
+ public AlwaysFalse() {
+ super("ALWAYS_FALSE", new Predicate[]{});
}
- @Override
- public int hashCode() {
- return Objects.hash();
+ public Boolean value() {
+ return false;
}
- @Override
- public String toString() { return "FALSE"; }
+ public DataType dataType() {
+ return dataType;
Review comment:
We can remove `private DataType dataType = DataTypes.BooleanType;` and
simply do `return DataTypes.BooleanType;` here.
##########
File path:
sql/catalyst/src/main/java/org/apache/spark/sql/connector/expressions/filter/AlwaysTrue.java
##########
@@ -17,34 +17,32 @@
package org.apache.spark.sql.connector.expressions.filter;
-import java.util.Objects;
-
import org.apache.spark.annotation.Evolving;
-import org.apache.spark.sql.connector.expressions.NamedReference;
+import org.apache.spark.sql.connector.expressions.Literal;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
/**
- * A filter that always evaluates to {@code true}.
+ * A predicate that always evaluates to {@code true}.
*
* @since 3.3.0
*/
@Evolving
-public final class AlwaysTrue extends Filter {
+public final class AlwaysTrue extends Predicate implements Literal<Boolean> {
+
+ private DataType dataType = DataTypes.BooleanType;
- @Override
- public boolean equals(Object o) {
- if (this == o) return true;
- if (o == null || getClass() != o.getClass()) return false;
+ public AlwaysTrue() {
+ super("ALWAYS_TRUE", new Predicate[]{});
+ }
+
+ public Boolean value() {
return true;
}
- @Override
- public int hashCode() {
- return Objects.hash();
+ public DataType dataType() {
+ return dataType;
Review comment:
ditto
##########
File path:
sql/catalyst/src/main/java/org/apache/spark/sql/connector/util/V2ExpressionSQLBuilder.java
##########
@@ -103,12 +123,53 @@ protected String visitIsNotNull(String v) {
return v + " IS NOT NULL";
}
- protected String visitBinaryComparison(String name, String l, String r) {
- return "(" + l + ") " + name + " (" + r + ")";
+ protected String visitStringPredicate(String name, GeneralScalarExpression
strPredicate) {
+ String l = build(strPredicate.children()[0]);
+ String r = build(strPredicate.children()[1]);
+ String value = r.replaceAll("'", "");
+ switch (name) {
+ case "STARTS_WITH":
+ return l + " LIKE '" + value + "%'";
+ case "ENDS_WITH":
+ return l + " LIKE '%" + value + "'";
+ case "CONTAINS":
+ return l + " LIKE '%" + value + "%'";
+ default:
+ return visitUnexpectedExpr(strPredicate);
+ }
}
- protected String visitBinaryArithmetic(String name, String l, String r) {
- return "(" + l + ") " + name + " (" + r + ")";
+ protected String visitBinaryComparison(
+ String name, String l, int lChildNum, String r, int rChildNum) {
+ switch (name) {
+ case "<=>":
+ return "(" + l + " = " + r + ") OR (" + l + " IS NULL AND " + r + " IS
NULL)";
+ default:
+ String left = l, right = r;
+ if (lChildNum > 1) {
+ left = "(" + l + ")";
+ }
+ if (rChildNum > 1) {
+ right = "(" + r + ")";
+ }
+ return left + " " + name + " " + right;
+ }
+ }
+
+ private int childNum(Expression expr) {
+ return expr.children().length;
+ }
+
+ protected String visitBinaryArithmetic(
+ String name, String l, int lChildNum, String r, int rChildNum) {
+ String left = l, right = r;
+ if (lChildNum > 1) {
Review comment:
Remember: we expect users to override the `visitXXX` methods, so we should
try our best not to push extra work into them. We can add the parentheses
at the caller side instead:
```
private String inputToSQL(Expression input) {
if (input.children().length > 1) {
return "(" + build(input) + ")";
} else {
return build(input);
}
}
...
return visitBinaryArithmetic(name, inputToSQL(e.children()[0]),
inputToSQL(e.children()[1]));
```
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/connector/expressions/expressions.scala
##########
@@ -377,7 +377,7 @@ private[sql] final case class SortValue(
expression: Expression,
direction: SortDirection,
nullOrdering: NullOrdering) extends SortOrder {
-
+ override def references: Array[NamedReference] = Expression.EMPTY_REFERENCE
Review comment:
why?
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
##########
@@ -39,48 +45,102 @@ class V2ExpressionBuilder(e: Expression) {
case _ => false
}
- private def generateExpression(expr: Expression): Option[V2Expression] =
expr match {
+ private def generateExpression(
+ expr: Expression, isPredicate: Boolean = false): Option[V2Expression] =
expr match {
Review comment:
This is a private method, let's not provide the default parameter value
to make the code clearer.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
##########
@@ -39,48 +45,102 @@ class V2ExpressionBuilder(e: Expression) {
case _ => false
}
- private def generateExpression(expr: Expression): Option[V2Expression] =
expr match {
+ private def generateExpression(
+ expr: Expression, isPredicate: Boolean = false): Option[V2Expression] =
expr match {
+ case Literal(true, BooleanType) => Some(new AlwaysTrue())
+ case Literal(false, BooleanType) => Some(new AlwaysFalse())
Review comment:
We only need to do this when `isPredicate` is true.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
##########
@@ -39,48 +45,102 @@ class V2ExpressionBuilder(e: Expression) {
case _ => false
}
- private def generateExpression(expr: Expression): Option[V2Expression] =
expr match {
+ private def generateExpression(
+ expr: Expression, isPredicate: Boolean = false): Option[V2Expression] =
expr match {
+ case Literal(true, BooleanType) => Some(new AlwaysTrue())
+ case Literal(false, BooleanType) => Some(new AlwaysFalse())
case Literal(value, dataType) => Some(LiteralValue(value, dataType))
- case attr: Attribute => Some(FieldReference.column(attr.name))
+ case col @ pushableColumn(name) if nestedPredicatePushdownEnabled =>
+ if (isPredicate && col.dataType.isInstanceOf[BooleanType]) {
+ Some(new V2Predicate("=", Array(FieldReference(name),
LiteralValue(true, BooleanType))))
+ } else {
+ Some(FieldReference(name))
+ }
+ case pushableColumn(name) if !nestedPredicatePushdownEnabled =>
+ Some(FieldReference.column(name))
+ case in @ InSet(child, hset) =>
+ generateExpression(child).map { v =>
+ val children =
+ (v +: hset.toSeq.map(elem => LiteralValue(elem,
in.dataType))).toArray[V2Expression]
+ new V2Predicate("IN", children)
+ }
+ // Because we only convert In to InSet in Optimizer when there are more
than certain
+ // items. So it is possible we still get an In expression here that needs
to be pushed
+ // down.
+ case In(value, list) =>
+ val v = generateExpression(value)
+ val listExpressions = list.flatMap(generateExpression(_))
+ if (v.isDefined && list.length == listExpressions.length) {
+ val children = (v.get +: listExpressions).toArray[V2Expression]
+ // The children looks like [expr, value1, ..., valueN]
+ Some(new V2Predicate("IN", children))
+ } else {
+ None
+ }
case IsNull(col) => generateExpression(col)
- .map(c => new GeneralScalarExpression("IS_NULL", Array[V2Expression](c)))
+ .map(c => new V2Predicate("IS_NULL", Array[V2Expression](c)))
case IsNotNull(col) => generateExpression(col)
- .map(c => new GeneralScalarExpression("IS_NOT_NULL",
Array[V2Expression](c)))
+ .map(c => new V2Predicate("IS_NOT_NULL", Array[V2Expression](c)))
+ case p: StringPredicate =>
+ val left = generateExpression(p.left)
+ val right = generateExpression(p.right)
+ if (left.isDefined && right.isDefined) {
+ val name = p match {
+ case _: StartsWith => "STARTS_WITH"
+ case _: EndsWith => "ENDS_WITH"
+ case _: Contains => "CONTAINS"
+ }
+ Some(new V2Predicate(name, Array[V2Expression](left.get, right.get)))
+ } else {
+ None
+ }
case b: BinaryOperator if canTranslate(b) =>
- val left = generateExpression(b.left)
- val right = generateExpression(b.right)
+ // AND/OR expect predicate
Review comment:
can we add a new `case` for `And` and `Or`?
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
##########
@@ -39,48 +45,102 @@ class V2ExpressionBuilder(e: Expression) {
case _ => false
}
- private def generateExpression(expr: Expression): Option[V2Expression] =
expr match {
+ private def generateExpression(
+ expr: Expression, isPredicate: Boolean = false): Option[V2Expression] =
expr match {
+ case Literal(true, BooleanType) => Some(new AlwaysTrue())
+ case Literal(false, BooleanType) => Some(new AlwaysFalse())
case Literal(value, dataType) => Some(LiteralValue(value, dataType))
- case attr: Attribute => Some(FieldReference.column(attr.name))
+ case col @ pushableColumn(name) if nestedPredicatePushdownEnabled =>
+ if (isPredicate && col.dataType.isInstanceOf[BooleanType]) {
+ Some(new V2Predicate("=", Array(FieldReference(name),
LiteralValue(true, BooleanType))))
+ } else {
+ Some(FieldReference(name))
+ }
+ case pushableColumn(name) if !nestedPredicatePushdownEnabled =>
+ Some(FieldReference.column(name))
+ case in @ InSet(child, hset) =>
+ generateExpression(child).map { v =>
+ val children =
+ (v +: hset.toSeq.map(elem => LiteralValue(elem,
in.dataType))).toArray[V2Expression]
+ new V2Predicate("IN", children)
+ }
+ // Because we only convert In to InSet in Optimizer when there are more
than certain
+ // items. So it is possible we still get an In expression here that needs
to be pushed
+ // down.
+ case In(value, list) =>
+ val v = generateExpression(value)
+ val listExpressions = list.flatMap(generateExpression(_))
+ if (v.isDefined && list.length == listExpressions.length) {
+ val children = (v.get +: listExpressions).toArray[V2Expression]
+ // The children looks like [expr, value1, ..., valueN]
+ Some(new V2Predicate("IN", children))
+ } else {
+ None
+ }
case IsNull(col) => generateExpression(col)
- .map(c => new GeneralScalarExpression("IS_NULL", Array[V2Expression](c)))
+ .map(c => new V2Predicate("IS_NULL", Array[V2Expression](c)))
case IsNotNull(col) => generateExpression(col)
- .map(c => new GeneralScalarExpression("IS_NOT_NULL",
Array[V2Expression](c)))
+ .map(c => new V2Predicate("IS_NOT_NULL", Array[V2Expression](c)))
+ case p: StringPredicate =>
+ val left = generateExpression(p.left)
+ val right = generateExpression(p.right)
+ if (left.isDefined && right.isDefined) {
+ val name = p match {
+ case _: StartsWith => "STARTS_WITH"
+ case _: EndsWith => "ENDS_WITH"
+ case _: Contains => "CONTAINS"
+ }
+ Some(new V2Predicate(name, Array[V2Expression](left.get, right.get)))
+ } else {
+ None
+ }
case b: BinaryOperator if canTranslate(b) =>
- val left = generateExpression(b.left)
- val right = generateExpression(b.right)
+ // AND/OR expect predicate
+ val left = generateExpression(b.left, b.isInstanceOf[And] ||
b.isInstanceOf[Or])
+ val right = generateExpression(b.right, b.isInstanceOf[And] ||
b.isInstanceOf[Or])
if (left.isDefined && right.isDefined) {
- Some(new GeneralScalarExpression(b.sqlOperator,
Array[V2Expression](left.get, right.get)))
+ if (b.isInstanceOf[Predicate]) {
+ Some(new V2Predicate(b.sqlOperator, Array[V2Expression](left.get,
right.get)))
+ } else {
+ Some(new GeneralScalarExpression(b.sqlOperator,
Array[V2Expression](left.get, right.get)))
+ }
} else {
None
}
case Not(eq: EqualTo) =>
val left = generateExpression(eq.left)
val right = generateExpression(eq.right)
if (left.isDefined && right.isDefined) {
- Some(new GeneralScalarExpression("!=", Array[V2Expression](left.get,
right.get)))
+ Some(new V2Predicate("<>", Array[V2Expression](left.get, right.get)))
} else {
None
}
- case Not(child) => generateExpression(child)
- .map(v => new GeneralScalarExpression("NOT", Array[V2Expression](v)))
+ case Not(child) => generateExpression(child, true) // NOT expects predicate
+ .map(v => new V2Predicate("NOT", Array[V2Expression](v)))
case UnaryMinus(child, true) => generateExpression(child)
.map(v => new GeneralScalarExpression("-", Array[V2Expression](v)))
case BitwiseNot(child) => generateExpression(child)
.map(v => new GeneralScalarExpression("~", Array[V2Expression](v)))
case CaseWhen(branches, elseValue) =>
- val conditions = branches.map(_._1).flatMap(generateExpression)
- val values = branches.map(_._2).flatMap(generateExpression)
+ val conditions = branches.map(_._1).flatMap(generateExpression(_, true))
+ val values = branches.map(_._2).flatMap(generateExpression(_, true))
if (conditions.length == branches.length && values.length ==
branches.length) {
val branchExpressions = conditions.zip(values).flatMap { case (c, v) =>
Seq[V2Expression](c, v)
}
if (elseValue.isDefined) {
- elseValue.flatMap(generateExpression).map { v =>
+ elseValue.flatMap(generateExpression(_)).map { v =>
val children = (branchExpressions :+ v).toArray[V2Expression]
// The children looks like [condition1, value1, ..., conditionN,
valueN, elseValue]
- new GeneralScalarExpression("CASE_WHEN", children)
+ if (isPredicate) {
+ new V2Predicate("CASE_WHEN", children)
Review comment:
To simplify the code, we can always return `V2Predicate` for case when.
It extends `GeneralScalarExpression` anyway.
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
##########
@@ -609,20 +545,39 @@ private[sql] object DataSourceV2Strategy {
}
protected[sql] def rebuildExpressionFromFilter(
- filter: V2Filter,
- translatedFilterToExpr: mutable.HashMap[V2Filter, Expression]):
Expression = {
- filter match {
+ predicate: Predicate,
+ translatedFilterToExpr: mutable.HashMap[Predicate, Expression]):
Expression = {
+ predicate match {
case and: V2And =>
- expressions.And(rebuildExpressionFromFilter(and.left,
translatedFilterToExpr),
- rebuildExpressionFromFilter(and.right, translatedFilterToExpr))
+ expressions.And(
+ rebuildExpressionFromFilter(and.left(), translatedFilterToExpr),
+ rebuildExpressionFromFilter(and.right(), translatedFilterToExpr))
case or: V2Or =>
- expressions.Or(rebuildExpressionFromFilter(or.left,
translatedFilterToExpr),
- rebuildExpressionFromFilter(or.right, translatedFilterToExpr))
+ expressions.Or(
+ rebuildExpressionFromFilter(or.left(), translatedFilterToExpr),
+ rebuildExpressionFromFilter(or.right(), translatedFilterToExpr))
case not: V2Not =>
- expressions.Not(rebuildExpressionFromFilter(not.child,
translatedFilterToExpr))
- case other =>
- translatedFilterToExpr.getOrElse(other,
- throw new IllegalStateException("Failed to rebuild Expression for
filter: " + filter))
+ expressions.Not(rebuildExpressionFromFilter(not.child(),
translatedFilterToExpr))
+ case _ =>
+ translatedFilterToExpr.getOrElse(predicate,
+ throw new IllegalStateException("Failed to rebuild Expression for
filter: " + predicate))
}
}
}
+
+/**
+ * Get the expression of DS V2 to represent catalyst predicate that can be
pushed down.
+ */
+case class PushablePredicate(nestedPredicatePushdownEnabled: Boolean) {
+ private val pushableColumn: PushableColumnBase =
PushableColumn(nestedPredicatePushdownEnabled)
+
+ def unapply(e: Expression): Option[Predicate] = e match {
+ case col @ pushableColumn(name) if col.dataType.isInstanceOf[BooleanType]
=>
Review comment:
Can we remove this now?
##########
File path:
sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
##########
@@ -17,15 +17,21 @@
package org.apache.spark.sql.catalyst.util
-import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute,
BinaryComparison, BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr,
BitwiseXor, CaseWhen, Divide, EqualTo, Expression, IsNotNull, IsNull, Literal,
Multiply, Not, Or, Remainder, Subtract, UnaryMinus}
+import org.apache.spark.sql.catalyst.expressions.{Add, And, BinaryComparison,
BinaryOperator, BitwiseAnd, BitwiseNot, BitwiseOr, BitwiseXor, CaseWhen,
Contains, Divide, EndsWith, EqualTo, Expression, In, InSet, IsNotNull, IsNull,
Literal, Multiply, Not, Or, Predicate, Remainder, StartsWith, StringPredicate,
Subtract, UnaryMinus}
import org.apache.spark.sql.connector.expressions.{Expression => V2Expression,
FieldReference, GeneralScalarExpression, LiteralValue}
+import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse,
AlwaysTrue, Predicate => V2Predicate}
+import org.apache.spark.sql.execution.datasources.PushableColumn
+import org.apache.spark.sql.types.BooleanType
/**
* The builder to generate V2 expressions from catalyst expressions.
*/
-class V2ExpressionBuilder(e: Expression) {
+class V2ExpressionBuilder(e: Expression, nestedPredicatePushdownEnabled:
Boolean = false) {
- def build(): Option[V2Expression] = generateExpression(e)
+ val pushableColumn = PushableColumn(nestedPredicatePushdownEnabled)
+
+ def build(isPredicate: Boolean = false): Option[V2Expression] =
Review comment:
should `isPredicate` be a class constructor parameter instead?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]