This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6b2492628c6 [SPARK-39857][SQL] V2ExpressionBuilder uses the wrong
LiteralValue data type for In predicate
6b2492628c6 is described below
commit 6b2492628c60fc1c4f70889c71cc3a9403a0adbc
Author: huaxingao <[email protected]>
AuthorDate: Mon Jul 25 08:11:19 2022 -0700
[SPARK-39857][SQL] V2ExpressionBuilder uses the wrong LiteralValue data
type for In predicate
### What changes were proposed in this pull request?
When building the V2 `In` predicate in `V2ExpressionBuilder`, `InSet.dataType`
(which is `BooleanType`) is used to build the `LiteralValue`, but
`InSet.child.dataType` should be used instead.
### Why are the changes needed?
bug fix
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
new test
Closes #37271 from huaxingao/inset.
Authored-by: huaxingao <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../sql/catalyst/util/V2ExpressionBuilder.scala | 4 +-
.../datasources/v2/DataSourceV2StrategySuite.scala | 229 ++++++++++++++++++++-
2 files changed, 228 insertions(+), 5 deletions(-)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
index 70c85def45d..07d681a6616 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala
@@ -52,10 +52,10 @@ class V2ExpressionBuilder(e: Expression, isPredicate:
Boolean = false) {
} else {
Some(ref)
}
- case in @ InSet(child, hset) =>
+ case InSet(child, hset) =>
generateExpression(child).map { v =>
val children =
- (v +: hset.toSeq.map(elem => LiteralValue(elem,
in.dataType))).toArray[V2Expression]
+ (v +: hset.toSeq.map(elem => LiteralValue(elem,
child.dataType))).toArray[V2Expression]
new V2Predicate("IN", children)
}
// Because we only convert In to InSet in Optimizer when there are more
than certain
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
index c3f51bed269..5fefcadca3e 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StrategySuite.scala
@@ -21,9 +21,10 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.connector.expressions.{FieldReference,
LiteralValue}
-import org.apache.spark.sql.connector.expressions.filter.Predicate
+import org.apache.spark.sql.connector.expressions.filter.{And => V2And, Not =>
V2Not, Or => V2Or, Predicate}
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType,
StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
class DataSourceV2StrategySuite extends PlanTest with SharedSparkSession {
val attrInts = Seq(
@@ -55,8 +56,37 @@ class DataSourceV2StrategySuite extends PlanTest with
SharedSparkSession {
"a.b.cint" // three level nested field
))
- test("SPARK-39784: translate binary expression") { attrInts
- .foreach { case (attrInt, intColName) =>
+ val attrStrs = Seq(
+ $"cstr".string,
+ $"c.str".string,
+ GetStructField($"a".struct(StructType(
+ StructField("cint", IntegerType, nullable = true) ::
+ StructField("cstr", StringType, nullable = true) :: Nil)), 1, None),
+ GetStructField($"a".struct(StructType(
+ StructField("c.str", StringType, nullable = true) ::
+ StructField("cint", IntegerType, nullable = true) :: Nil)), 0, None),
+ GetStructField($"a.b".struct(StructType(
+ StructField("cint1", IntegerType, nullable = true) ::
+ StructField("cint2", IntegerType, nullable = true) ::
+ StructField("cstr", StringType, nullable = true) :: Nil)), 2, None),
+ GetStructField($"a.b".struct(StructType(
+ StructField("c.str", StringType, nullable = true) :: Nil)), 0, None),
+ GetStructField(GetStructField($"a".struct(StructType(
+ StructField("cint1", IntegerType, nullable = true) ::
+ StructField("b", StructType(StructField("cstr", StringType, nullable =
true) ::
+ StructField("cint2", IntegerType, nullable = true) :: Nil)) ::
Nil)), 1, None), 0, None)
+ ).zip(Seq(
+ "cstr",
+ "`c.str`", // single level field that contains `dot` in name
+ "a.cstr", // two level nested field
+ "a.`c.str`", // two level nested field, and nested level contains `dot`
+ "`a.b`.cstr", // two level nested field, and top level contains `dot`
+ "`a.b`.`c.str`", // two level nested field, and both levels contain `dot`
+ "a.b.cstr" // three level nested field
+ ))
+
+ test("translate simple expression") { attrInts.zip(attrStrs)
+ .foreach { case ((attrInt, intColName), (attrStr, strColName)) =>
testTranslateFilter(EqualTo(attrInt, 1),
Some(new Predicate("=", Array(FieldReference(intColName),
LiteralValue(1, IntegerType)))))
testTranslateFilter(EqualTo(1, attrInt),
@@ -86,6 +116,199 @@ class DataSourceV2StrategySuite extends PlanTest with
SharedSparkSession {
Some(new Predicate("<=", Array(FieldReference(intColName),
LiteralValue(1, IntegerType)))))
testTranslateFilter(LessThanOrEqual(1, attrInt),
Some(new Predicate(">=", Array(FieldReference(intColName),
LiteralValue(1, IntegerType)))))
+
+ testTranslateFilter(IsNull(attrInt),
+ Some(new Predicate("IS_NULL", Array(FieldReference(intColName)))))
+ testTranslateFilter(IsNotNull(attrInt),
+ Some(new Predicate("IS_NOT_NULL", Array(FieldReference(intColName)))))
+
+ testTranslateFilter(InSet(attrInt, Set(1, 2, 3)),
+ Some(new Predicate("IN", Array(FieldReference(intColName),
+ LiteralValue(1, IntegerType), LiteralValue(2, IntegerType),
+ LiteralValue(3, IntegerType)))))
+
+ testTranslateFilter(In(attrInt, Seq(1, 2, 3)),
+ Some(new Predicate("IN", Array(FieldReference(intColName),
+ LiteralValue(1, IntegerType), LiteralValue(2, IntegerType),
+ LiteralValue(3, IntegerType)))))
+
+ // cint > 1 AND cint < 10
+ testTranslateFilter(And(
+ GreaterThan(attrInt, 1),
+ LessThan(attrInt, 10)),
+ Some(new V2And(
+ new Predicate(">", Array(FieldReference(intColName), LiteralValue(1,
IntegerType))),
+ new Predicate("<", Array(FieldReference(intColName),
LiteralValue(10, IntegerType))))))
+
+ // cint >= 8 OR cint <= 2
+ testTranslateFilter(Or(
+ GreaterThanOrEqual(attrInt, 8),
+ LessThanOrEqual(attrInt, 2)),
+ Some(new V2Or(
+ new Predicate(">=", Array(FieldReference(intColName),
LiteralValue(8, IntegerType))),
+ new Predicate("<=", Array(FieldReference(intColName),
LiteralValue(2, IntegerType))))))
+
+ testTranslateFilter(Not(GreaterThanOrEqual(attrInt, 8)),
+ Some(new V2Not(new Predicate(">=", Array(FieldReference(intColName),
+ LiteralValue(8, IntegerType))))))
+
+ testTranslateFilter(StartsWith(attrStr, "a"),
+ Some(new Predicate("STARTS_WITH", Array(FieldReference(strColName),
+ LiteralValue(UTF8String.fromString("a"), StringType)))))
+
+ testTranslateFilter(EndsWith(attrStr, "a"),
+ Some(new Predicate("ENDS_WITH", Array(FieldReference(strColName),
+ LiteralValue(UTF8String.fromString("a"), StringType)))))
+
+ testTranslateFilter(Contains(attrStr, "a"),
+ Some(new Predicate("CONTAINS", Array(FieldReference(strColName),
+ LiteralValue(UTF8String.fromString("a"), StringType)))))
+ }
+ }
+
+ test("translate complex expression") {
+ attrInts.foreach { case (attrInt, intColName) =>
+
+ // ABS(cint) - 2 <= 1
+ testTranslateFilter(LessThanOrEqual(
+ // Expressions are not supported
+ // Functions such as 'Abs' are not supported
+ Subtract(Abs(attrInt), 2), 1), None)
+
+ // (cint > 1 AND cint < 10) OR (cint > 50 AND cint < 100)
+ testTranslateFilter(Or(
+ And(
+ GreaterThan(attrInt, 1),
+ LessThan(attrInt, 10)
+ ),
+ And(
+ GreaterThan(attrInt, 50),
+ LessThan(attrInt, 100))),
+ Some(new V2Or(
+ new V2And(
+ new Predicate(">", Array(FieldReference(intColName),
LiteralValue(1, IntegerType))),
+ new Predicate("<", Array(FieldReference(intColName),
LiteralValue(10, IntegerType)))),
+ new V2And(
+ new Predicate(">", Array(FieldReference(intColName),
LiteralValue(50, IntegerType))),
+ new Predicate("<", Array(FieldReference(intColName),
+ LiteralValue(100, IntegerType)))))
+ )
+ )
+
+ // (cint > 1 AND ABS(cint) < 10) OR (cint > 50 AND cint < 100)
+ testTranslateFilter(Or(
+ And(
+ GreaterThan(attrInt, 1),
+ // Functions such as 'Abs' are not supported
+ LessThan(Abs(attrInt), 10)
+ ),
+ And(
+ GreaterThan(attrInt, 50),
+ LessThan(attrInt, 100))), None)
+
+ // NOT ((cint <= 1 OR ABS(cint) >= 10) AND (cint <= 50 OR cint >= 100))
+ testTranslateFilter(Not(And(
+ Or(
+ LessThanOrEqual(attrInt, 1),
+ // Functions such as 'Abs' are not supported
+ GreaterThanOrEqual(Abs(attrInt), 10)
+ ),
+ Or(
+ LessThanOrEqual(attrInt, 50),
+ GreaterThanOrEqual(attrInt, 100)))), None)
+
+ // (cint = 1 OR cint = 10) OR (cint > 0 OR cint < -10)
+ testTranslateFilter(Or(
+ Or(
+ EqualTo(attrInt, 1),
+ EqualTo(attrInt, 10)
+ ),
+ Or(
+ GreaterThan(attrInt, 0),
+ LessThan(attrInt, -10))),
+ Some(new V2Or(
+ new V2Or(
+ new Predicate("=", Array(FieldReference(intColName),
LiteralValue(1, IntegerType))),
+ new Predicate("=", Array(FieldReference(intColName),
LiteralValue(10, IntegerType)))),
+ new V2Or(
+ new Predicate(">", Array(FieldReference(intColName),
LiteralValue(0, IntegerType))),
+ new Predicate("<", Array(FieldReference(intColName),
LiteralValue(-10, IntegerType)))))
+ )
+ )
+
+ // (cint = 1 OR ABS(cint) = 10) OR (cint > 0 OR cint < -10)
+ testTranslateFilter(Or(
+ Or(
+ EqualTo(attrInt, 1),
+ // Functions such as 'Abs' are not supported
+ EqualTo(Abs(attrInt), 10)
+ ),
+ Or(
+ GreaterThan(attrInt, 0),
+ LessThan(attrInt, -10))), None)
+
+ // In end-to-end testing, conjunctive predicates should have been split
+ // before reaching DataSourceStrategy.translateFilter.
+ // This is for UT purpose to test each [[case]].
+ // (cint > 1 AND cint < 10) AND (cint = 6 AND cint IS NOT NULL)
+ testTranslateFilter(And(
+ And(
+ GreaterThan(attrInt, 1),
+ LessThan(attrInt, 10)
+ ),
+ And(
+ EqualTo(attrInt, 6),
+ IsNotNull(attrInt))),
+ Some(new V2And(
+ new V2And(
+ new Predicate(">", Array(FieldReference(intColName),
LiteralValue(1, IntegerType))),
+ new Predicate("<", Array(FieldReference(intColName),
LiteralValue(10, IntegerType)))),
+ new V2And(
+ new Predicate("=", Array(FieldReference(intColName),
LiteralValue(6, IntegerType))),
+ new Predicate("IS_NOT_NULL", Array(FieldReference(intColName)))))
+ )
+ )
+
+ // (cint > 1 AND cint < 10) AND (ABS(cint) = 6 AND cint IS NOT NULL)
+ testTranslateFilter(And(
+ And(
+ GreaterThan(attrInt, 1),
+ LessThan(attrInt, 10)
+ ),
+ And(
+ // Functions such as 'Abs' are not supported
+ EqualTo(Abs(attrInt), 6),
+ IsNotNull(attrInt))), None)
+
+ // (cint > 1 OR cint < 10) AND (cint = 6 OR cint IS NOT NULL)
+ testTranslateFilter(And(
+ Or(
+ GreaterThan(attrInt, 1),
+ LessThan(attrInt, 10)
+ ),
+ Or(
+ EqualTo(attrInt, 6),
+ IsNotNull(attrInt))),
+ Some(new V2And(
+ new V2Or(
+ new Predicate(">", Array(FieldReference(intColName),
LiteralValue(1, IntegerType))),
+ new Predicate("<", Array(FieldReference(intColName),
LiteralValue(10, IntegerType)))),
+ new V2Or(
+ new Predicate("=", Array(FieldReference(intColName),
LiteralValue(6, IntegerType))),
+ new Predicate("IS_NOT_NULL", Array(FieldReference(intColName)))))
+ )
+ )
+
+ // (cint > 1 OR cint < 10) AND (ABS(cint) = 6 OR cint IS NOT NULL)
+ testTranslateFilter(And(
+ Or(
+ GreaterThan(attrInt, 1),
+ LessThan(attrInt, 10)
+ ),
+ Or(
+ // Functions such as 'Abs' are not supported
+ EqualTo(Abs(attrInt), 6),
+ IsNotNull(attrInt))), None)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]