This is an automated email from the ASF dual-hosted git repository.

yamamuro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 93ff690  [SPARK-27288][SQL] Pruning nested field in complex map key from object serializers
93ff690 is described below

commit 93ff69003b228abcf08da4488593f552e3a61665
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Wed Mar 27 19:40:14 2019 +0900

    [SPARK-27288][SQL] Pruning nested field in complex map key from object serializers
    
    ## What changes were proposed in this pull request?
    
    In the original PR #24158, pruning nested fields in complex map keys was not supported, because some methods in schema pruning didn't support it at that time. This is a followup to add it.
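    
    For example, with a Dataset column that is a map keyed by a struct, a query that reads only one field of the key can now serialize just that field instead of the whole key struct. A minimal sketch mirroring the test added below, assuming an active SparkSession named `spark`:
    
        import org.apache.spark.sql.functions.expr
        import spark.implicits._
    
        // Map keys are (String, Int) structs; only `_1` of each key is read,
        // so the key serializer can be pruned down to that single field.
        val mapDs = Seq((Map(("1", 1) -> "a_1"), 1)).toDS()
        mapDs.select(expr("map_keys(_1)._1[0]"))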
    
    ## How was this patch tested?
    
    Added tests.
    
    Closes #24220 from viirya/SPARK-26847-followup.
    
    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: Takeshi Yamamuro <yamam...@apache.org>
---
 .../apache/spark/sql/catalyst/optimizer/objects.scala | 13 ++++++++++---
 .../optimizer/ObjectSerializerPruningSuite.scala      |  5 +++--
 .../apache/spark/sql/DatasetOptimizationSuite.scala   | 19 ++++++++++++++++++-
 3 files changed, 31 insertions(+), 6 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
index 8e92421..c48bd8f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala
@@ -131,8 +131,8 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
         fields.map(f => collectStructType(f.dataType, structs))
       case ArrayType(elementType, _) =>
         collectStructType(elementType, structs)
-      case MapType(_, valueType, _) =>
-        // Because we can't select a field from struct in key, so we skip key type.
+      case MapType(keyType, valueType, _) =>
+        collectStructType(keyType, structs)
         collectStructType(valueType, structs)
       // We don't use UserDefinedType in those serializers.
       case _: UserDefinedType[_] =>
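
The traversal change above is easier to follow outside diff form. Below is a condensed, self-contained sketch of the patched collectStructType, using only the public org.apache.spark.sql.types API (the rule's UserDefinedType case is folded into the fallthrough here):

import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.types._

// Collect every StructType reachable from `dt`. After this commit, the
// MapType case recurses into the key type as well as the value type.
def collectStructType(dt: DataType, structs: ArrayBuffer[StructType]): ArrayBuffer[StructType] = dt match {
  case s @ StructType(fields) =>
    structs += s
    fields.foreach(f => collectStructType(f.dataType, structs))
    structs
  case ArrayType(elementType, _) =>
    collectStructType(elementType, structs)
  case MapType(keyType, valueType, _) =>
    collectStructType(keyType, structs) // new: key-side structs are now collected
    collectStructType(valueType, structs)
  case _ =>
    structs
}

For MapType(StructType.fromDDL("a int, b int"), StringType, true), the buffer now contains the key struct, which is exactly what the updated expectations in ObjectSerializerPruningSuite below verify.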
@@ -179,13 +179,20 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {
 
     val transformedSerializer = serializer.transformDown {
       case m: ExternalMapToCatalyst =>
+        val prunedKeyConverter = m.keyConverter.transformDown {
+          case s: CreateNamedStruct if structTypeIndex < prunedStructTypes.size =>
+            val prunedType = prunedStructTypes(structTypeIndex)
+            structTypeIndex += 1
+            pruneNamedStruct(s, prunedType)
+        }
         val prunedValueConverter = m.valueConverter.transformDown {
           case s: CreateNamedStruct if structTypeIndex < prunedStructTypes.size =>
             val prunedType = prunedStructTypes(structTypeIndex)
             structTypeIndex += 1
             pruneNamedStruct(s, prunedType)
         }
-        m.copy(valueConverter = alignNullTypeInIf(prunedValueConverter))
+        m.copy(keyConverter = alignNullTypeInIf(prunedKeyConverter),
+          valueConverter = alignNullTypeInIf(prunedValueConverter))
       case s: CreateNamedStruct if structTypeIndex < prunedStructTypes.size =>
         val prunedType = prunedStructTypes(structTypeIndex)
         structTypeIndex += 1
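
pruneNamedStruct, reused above for the key converter, is an existing private helper of this rule and is not shown in the diff. As a rough guide to what it does, here is a hypothetical, simplified stand-in that only shows the field-filtering idea (the real helper also resolves names via the session resolver; pruneNamedStructSketch and its exact matching are assumptions, not the actual implementation):

import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Literal}
import org.apache.spark.sql.types.StructType

// Hypothetical sketch: keep only the name/value expression pairs of a
// CreateNamedStruct whose field names survive in the pruned schema.
def pruneNamedStructSketch(struct: CreateNamedStruct, prunedType: StructType): CreateNamedStruct = {
  val kept = prunedType.fieldNames.toSet
  val prunedPairs = struct.nameExprs.zip(struct.valExprs).filter {
    case (Literal(name, _), _) => kept.contains(name.toString)
    case _ => false
  }
  CreateNamedStruct(prunedPairs.flatMap { case (n, v) => Seq(n, v) })
}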
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala
index fb0f3a3..0dd4d6a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ObjectSerializerPruningSuite.scala
@@ -60,8 +60,9 @@ class ObjectSerializerPruningSuite extends PlanTest {
       Seq(StructType.fromDDL("a struct<a:int, b:int>, b int"),
         StructType.fromDDL("a int, b int")),
       Seq(StructType.fromDDL("a int, b int, c string")),
-      Seq.empty[StructType],
-      Seq(StructType.fromDDL("c long, d string"))
+      Seq(StructType.fromDDL("a struct<a:int, b:int>, b int"),
+        StructType.fromDDL("a int, b int")),
+      Seq(StructType.fromDDL("a int, b int"), StructType.fromDDL("c long, d 
string"))
     )
 
     dataTypes.zipWithIndex.foreach { case (dt, idx) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala
index 69634f8..cfbb343 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetOptimizationSuite.scala
@@ -51,7 +51,9 @@ class DatasetOptimizationSuite extends QueryTest with SharedSQLContext {
       val structs = serializer.collect {
         case c: CreateNamedStruct => Seq(c)
         case m: ExternalMapToCatalyst =>
-          m.valueConverter.collect {
+          m.keyConverter.collect {
+            case c: CreateNamedStruct => c
+          } ++ m.valueConverter.collect {
             case c: CreateNamedStruct => c
           }
       }.flatten
@@ -123,6 +125,21 @@ class DatasetOptimizationSuite extends QueryTest with SharedSQLContext {
       val df2 = mapDs.select("_1.k._2")
       testSerializer(df2, Seq(Seq("_2")))
       checkAnswer(df2, Seq(Row(11), Row(22), Row(33)))
+
+      val df3 = mapDs.select(expr("map_values(_1)._2[0]"))
+      testSerializer(df3, Seq(Seq("_2")))
+      checkAnswer(df3, Seq(Row(11), Row(22), Row(33)))
+    }
+  }
+
+  test("Pruned nested serializers: map of complex key") {
+    withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") {
+      val mapData = Seq((Map((("1", 1), "a_1")), 1), (Map((("2", 2), "b_1")), 2),
+        (Map((("3", 3), "c_1")), 3))
+      val mapDs = mapData.toDS().map(t => (t._1, t._2 + 1))
+      val df1 = mapDs.select(expr("map_keys(_1)._1[0]"))
+      testSerializer(df1, Seq(Seq("_1")))
+      checkAnswer(df1, Seq(Row("1"), Row("2"), Row("3")))
     }
   }
 }
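
To see the new key-side pruning outside the suite, a quick end-to-end sketch (the conf key string below is an assumption for SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED, the same flag the test flips; the data mirrors the test above):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.expr

val spark = SparkSession.builder().master("local[2]").appName("map-key-pruning").getOrCreate()
import spark.implicits._

// Assumed key for SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.
spark.conf.set("spark.sql.optimizer.serializer.nestedSchemaPruning.enabled", "true")

// Mirrors the new test: map keys are (String, Int) structs.
val mapDs = Seq((Map(("1", 1) -> "a_1"), 1), (Map(("2", 2) -> "b_1"), 2))
  .toDS()
  .map(t => (t._1, t._2 + 1))

// Only `_1` of each key is read; explain(true) should show the pruned
// key serializer in the optimized plan.
mapDs.select(expr("map_keys(_1)._1[0]")).explain(true)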


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
