Github user liancheng commented on a diff in the pull request:
https://github.com/apache/spark/pull/13846#discussion_r68753763
--- Diff:
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
---
@@ -23,54 +23,111 @@ import
org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.{encoderFor,
ExpressionEncoder}
+import org.apache.spark.sql.catalyst.expressions.{BoundReference,
ReferenceToExpressions}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation,
LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.BooleanType
+import org.apache.spark.sql.types.{BooleanType, ObjectType}
class TypedFilterOptimizationSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("EliminateSerialization", FixedPoint(50),
EliminateSerialization) ::
- Batch("EmbedSerializerInFilter", FixedPoint(50),
- EmbedSerializerInFilter) :: Nil
+ Batch("CombineTypedFilters", FixedPoint(50),
+ CombineTypedFilters) :: Nil
}
implicit private def productEncoder[T <: Product : TypeTag] =
ExpressionEncoder[T]()
- test("back to back filter") {
+ test("filter after serialize") {
val input = LocalRelation('_1.int, '_2.int)
- val f1 = (i: (Int, Int)) => i._1 > 0
- val f2 = (i: (Int, Int)) => i._2 > 0
+ val f = (i: (Int, Int)) => i._1 > 0
- val query = input.filter(f1).filter(f2).analyze
+ val query = input
+ .deserialize[(Int, Int)]
+ .serialize[(Int, Int)]
+ .filter(f).analyze
val optimized = Optimize.execute(query)
- val expected = input.deserialize[(Int, Int)]
- .where(callFunction(f1, BooleanType, 'obj))
- .select('obj.as("obj"))
- .where(callFunction(f2, BooleanType, 'obj))
+ val expected = input
+ .deserialize[(Int, Int)]
+ .where(callFunction(f, BooleanType, 'obj))
.serialize[(Int, Int)].analyze
comparePlans(optimized, expected)
}
- // TODO: Remove this after we completely fix SPARK-15632 by adding
optimization rules
- // for typed filters.
- ignore("embed deserializer in typed filter condition if there is only
one filter") {
+ test("filter after serialize with object change") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val f = (i: OtherTuple) => i._1 > 0
+
+ val query = input
+ .deserialize[(Int, Int)]
+ .serialize[(Int, Int)]
+ .filter(f).analyze
+ val optimized = Optimize.execute(query)
+ comparePlans(optimized, query)
+ }
+
+ test("filter before deserialize") {
val input = LocalRelation('_1.int, '_2.int)
val f = (i: (Int, Int)) => i._1 > 0
- val query = input.filter(f).analyze
+ val query = input
+ .filter(f)
+ .deserialize[(Int, Int)]
+ .serialize[(Int, Int)].analyze
+
+ val optimized = Optimize.execute(query)
+
+ val expected = input
+ .deserialize[(Int, Int)]
+ .where(callFunction(f, BooleanType, 'obj))
+ .serialize[(Int, Int)].analyze
+
+ comparePlans(optimized, expected)
+ }
+
+ test("filter before deserialize with object change") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val f = (i: OtherTuple) => i._1 > 0
+
+ val query = input
+ .filter(f)
+ .deserialize[(Int, Int)]
+ .serialize[(Int, Int)].analyze
+ val optimized = Optimize.execute(query)
+ comparePlans(optimized, query)
+ }
+
+ test("back to back filter") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val f1 = (i: (Int, Int)) => i._1 > 0
+ val f2 = (i: (Int, Int)) => i._2 > 0
+
+ val query = input.filter(f1).filter(f2).analyze
val optimized = Optimize.execute(query)
val deserializer = UnresolvedDeserializer(encoderFor[(Int,
Int)].deserializer)
- val condition = callFunction(f, BooleanType, deserializer)
- val expected = input.where(condition).select('_1.as("_1"),
'_2.as("_2")).analyze
+ val boundReference = BoundReference(0, ObjectType(classOf[(Int,
Int)]), nullable = false)
+ val callFunc1 = callFunction(f1, BooleanType, boundReference)
+ val callFunc2 = callFunction(f2, BooleanType, boundReference)
+ val condition = ReferenceToExpressions(callFunc2 && callFunc1,
deserializer :: Nil)
+ val expected = input.where(condition).analyze
comparePlans(optimized, expected)
}
+
+ test("back to back filter with object change") {
--- End diff --
Nit: consider renaming this test to "back to back filters with different object types"
---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and would like it to be, or if the feature is enabled but not working,
please contact infrastructure at [email protected] or file a JIRA ticket
with INFRA.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]