Repository: spark
Updated Branches:
  refs/heads/master e54581134 -> ff4bb836a


[SPARK-25817][SQL] Dataset encoder should support combination of map and product type

## What changes were proposed in this pull request?

After https://github.com/apache/spark/pull/22745, the Dataset encoder supports the 
combination of Java bean and map types. This PR fixes the Scala side.

The reason it didn't work before is that `CatalystToExternalMap` tries to get the 
data type of the input map expression, but that expression can still be unresolved, 
in which case its data type is unknown. To fix it, we follow `UnresolvedMapObjects`: 
create an `UnresolvedCatalystToExternalMap` placeholder first, and only create 
`CatalystToExternalMap` once the input map expression is resolved and its data type 
is known.
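
A minimal, self-contained sketch of this placeholder pattern, with illustrative names only (the real `UnresolvedCatalystToExternalMap` and the analyzer rule are in the `objects.scala` and `Analyzer.scala` hunks below). The point is that the placeholder is cheap to build before analysis and never touches the child's data type; the type-dependent construction is deferred to a resolution rule:

```scala
object PlaceholderSketch {

  sealed trait Expr {
    def resolved: Boolean
    def dataType: Option[String]
  }

  // Stand-in for the input map expression; its data type may not be known yet.
  case class MapInput(dataType: Option[String]) extends Expr {
    def resolved: Boolean = dataType.isDefined
  }

  // Placeholder: built eagerly and never reads child.dataType, so it is safe while unresolved.
  case class UnresolvedToExternalMap(child: Expr) extends Expr {
    val resolved: Boolean = false
    def dataType: Option[String] = None
  }

  // "Real" node: constructed only once the child's data type is known.
  case class ToExternalMap(child: Expr, mapType: String) extends Expr {
    val resolved: Boolean = true
    def dataType: Option[String] = Some(mapType)
  }

  // Analyzer-style rule: swap the placeholder for the real node once the child is resolved.
  def resolvePlaceholder(e: Expr): Expr = e match {
    case UnresolvedToExternalMap(child) if child.resolved =>
      ToExternalMap(child, child.dataType.get)
    case other => other
  }

  def main(args: Array[String]): Unit = {
    // Unresolved child: the placeholder stays and no data type is ever accessed.
    println(resolvePlaceholder(UnresolvedToExternalMap(MapInput(None))))
    // Resolved child: the placeholder is replaced and the map type is read safely.
    println(resolvePlaceholder(UnresolvedToExternalMap(MapInput(Some("map<int,string>")))))
  }
}
```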

## How was this patch tested?

Enable an old test case in `DatasetPrimitiveSuite`.
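
As a rough usage sketch (assuming a local `SparkSession` and a `MapClass` case class like the one used in `DatasetPrimitiveSuite`), combinations of maps and case classes like the following now round-trip through the Scala Dataset encoder:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical top-level case class, mirroring MapClass in DatasetPrimitiveSuite.
case class MapClass(map: Map[Int, Int])

object MapProductDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("map-product-demo").getOrCreate()
    import spark.implicits._

    // A case class that contains a map.
    println(Seq(MapClass(Map(1 -> 2))).toDS().collect().toSeq)
    // A map whose values are case classes that themselves contain maps.
    println(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS().collect().toSeq)

    spark.stop()
  }
}
```

Before this fix, combinations like the second one were not supported by the Scala encoder, and the corresponding test case was ignored.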

Closes #22812 from cloud-fan/map.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ff4bb836
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ff4bb836
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ff4bb836

Branch: refs/heads/master
Commit: ff4bb836aa768082df9227628dfd5a837f8e4f4e
Parents: e545811
Author: Wenchen Fan <[email protected]>
Authored: Sun Oct 28 13:33:26 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Sun Oct 28 13:33:26 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 15 +++---
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 13 ++++-
 .../catalyst/encoders/ExpressionEncoder.scala   |  8 ++-
 .../catalyst/expressions/objects/objects.scala  | 56 ++++++++++----------
 .../spark/sql/DatasetPrimitiveSuite.scala       |  2 +-
 .../org/apache/spark/sql/DatasetSuite.scala     |  9 ++++
 6 files changed, 59 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ff4bb836/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 40074b3..912744e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -143,8 +143,7 @@ object ScalaReflection extends ScalaReflection {
       walkedTypePath: Seq[String]): Expression = expected match {
     case _: StructType => expr
     case _: ArrayType => expr
-    // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and
-    // it's not trivial to support by-name resolution for StructType inside MapType.
+    case _: MapType => expr
     case _ => UpCast(expr, expected, walkedTypePath)
   }
 
@@ -163,8 +162,8 @@ object ScalaReflection extends ScalaReflection {
     val Schema(dataType, nullable) = schemaFor(tpe)
 
     // Assumes we are deserializing the first column of a row.
-    val input = upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType,
-      walkedTypePath)
+    val input = upCastToExpectedType(
+      GetColumnByOrdinal(0, dataType), dataType, walkedTypePath)
 
     val expr = deserializerFor(tpe, input, walkedTypePath)
     if (nullable) {
@@ -350,10 +349,10 @@ object ScalaReflection extends ScalaReflection {
         // TODO: add walked type path for map
         val TypeRef(_, _, Seq(keyType, valueType)) = t
 
-        CatalystToExternalMap(
+        UnresolvedCatalystToExternalMap(
+          path,
           p => deserializerFor(keyType, p, walkedTypePath),
           p => deserializerFor(valueType, p, walkedTypePath),
-          path,
           mirror.runtimeClass(t.typeSymbol.asClass)
         )
 
@@ -431,8 +430,8 @@ object ScalaReflection extends ScalaReflection {
     val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
 
     // The input object to `ExpressionEncoder` is located at first column of an row.
-    val inputObject = BoundReference(0, dataTypeFor(tpe),
-      nullable = !tpe.typeSymbol.asClass.isPrimitive)
+    val isPrimitive = tpe.typeSymbol.asClass.isPrimitive
+    val inputObject = BoundReference(0, dataTypeFor(tpe), nullable = !isPrimitive)
 
     serializerFor(inputObject, tpe, walkedTypePath)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ff4bb836/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 63a07e3..c2d22c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2384,14 +2384,23 @@ class Analyzer(
             case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
               inputData.dataType match {
                 case ArrayType(et, cn) =>
-                  val expr = MapObjects(func, inputData, et, cn, cls) transformUp {
+                  MapObjects(func, inputData, et, cn, cls) transformUp {
                     case UnresolvedExtractValue(child, fieldName) if child.resolved =>
                       ExtractValue(child, fieldName, resolver)
                   }
-                  expr
                 case other =>
                   throw new AnalysisException("need an array field but got " + other.catalogString)
               }
+            case u: UnresolvedCatalystToExternalMap if u.child.resolved =>
+              u.child.dataType match {
+                case _: MapType =>
+                  CatalystToExternalMap(u) transformUp {
+                    case UnresolvedExtractValue(child, fieldName) if child.resolved =>
+                      ExtractValue(child, fieldName, resolver)
+                  }
+                case other =>
+                  throw new AnalysisException("need a map field but got " + other.catalogString)
+              }
           }
           validateNestedTupleFields(result)
           result

http://git-wip-us.apache.org/repos/asf/spark/blob/ff4bb836/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 29f6136..2c8e81e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -119,10 +119,9 @@ object ExpressionEncoder {
     }
 
     val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
-      val getColumnsByOrdinals = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }
-        .distinct
-      assert(getColumnsByOrdinals.size == 1, "object deserializer should have only one " +
-        s"`GetColumnByOrdinal`, but there are ${getColumnsByOrdinals.size}")
+      val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct
+      assert(getColExprs.size == 1, "object deserializer should have only one " +
+        s"`GetColumnByOrdinal`, but there are ${getColExprs.size}")
 
       val input = GetStructField(GetColumnByOrdinal(0, schema), index)
       val newDeserializer = enc.objDeserializer.transformUp {
@@ -216,7 +215,6 @@ case class ExpressionEncoder[T](
       }
       nullSafeSerializer match {
         case If(_: IsNull, _, s: CreateNamedStruct) => s
-        case s: CreateNamedStruct => s
         case _ =>
           throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer")
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/ff4bb836/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index b6f9b47..4fd36a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -30,14 +30,13 @@ import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.util.Utils
 
 /**
@@ -963,25 +962,32 @@ case class MapObjects private(
   }
 }
 
+/**
+ * Similar to [[UnresolvedMapObjects]], this is a placeholder of [[CatalystToExternalMap]].
+ *
+ * @param child An expression that when evaluated returns a map object.
+ * @param keyFunction The function applied on the key collection elements.
+ * @param valueFunction The function applied on the value collection elements.
+ * @param collClass The type of the resulting collection.
+ */
+case class UnresolvedCatalystToExternalMap(
+    child: Expression,
+    @transient keyFunction: Expression => Expression,
+    @transient valueFunction: Expression => Expression,
+    collClass: Class[_]) extends UnaryExpression with Unevaluable {
+
+  override lazy val resolved = false
+
+  override def dataType: DataType = ObjectType(collClass)
+}
+
 object CatalystToExternalMap {
   private val curId = new java.util.concurrent.atomic.AtomicInteger()
 
-  /**
-   * Construct an instance of CatalystToExternalMap case class.
-   *
-   * @param keyFunction The function applied on the key collection elements.
-   * @param valueFunction The function applied on the value collection elements.
-   * @param inputData An expression that when evaluated returns a map object.
-   * @param collClass The type of the resulting collection.
-   */
-  def apply(
-      keyFunction: Expression => Expression,
-      valueFunction: Expression => Expression,
-      inputData: Expression,
-      collClass: Class[_]): CatalystToExternalMap = {
+  def apply(u: UnresolvedCatalystToExternalMap): CatalystToExternalMap = {
     val id = curId.getAndIncrement()
     val keyLoopValue = s"CatalystToExternalMap_keyLoopValue$id"
-    val mapType = inputData.dataType.asInstanceOf[MapType]
+    val mapType = u.child.dataType.asInstanceOf[MapType]
     val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false)
     val valueLoopValue = s"CatalystToExternalMap_valueLoopValue$id"
     val valueLoopIsNull = if (mapType.valueContainsNull) {
@@ -991,9 +997,9 @@ object CatalystToExternalMap {
     }
     val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType)
     CatalystToExternalMap(
-      keyLoopValue, keyFunction(keyLoopVar),
-      valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar),
-      inputData, collClass)
+      keyLoopValue, u.keyFunction(keyLoopVar),
+      valueLoopValue, valueLoopIsNull, u.valueFunction(valueLoopVar),
+      u.child, u.collClass)
   }
 }
 
@@ -1090,15 +1096,9 @@ case class CatalystToExternalMap private(
     val tupleLoopValue = ctx.freshName("tupleLoopValue")
     val builderValue = ctx.freshName("builderValue")
 
-    val getLength = s"${genInputData.value}.numElements()"
-
     val keyArray = ctx.freshName("keyArray")
     val valueArray = ctx.freshName("valueArray")
-    val getKeyArray =
-      s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();"
     val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
-    val getValueArray =
-      s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();"
     val getValueLoopVar = CodeGenerator.getValue(
       valueArray, inputDataType(mapType.valueType), loopIndex)
 
@@ -1147,10 +1147,10 @@ case class CatalystToExternalMap private(
       ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
 
       if (!${genInputData.isNull}) {
-        int $dataLength = $getLength;
+        int $dataLength = ${genInputData.value}.numElements();
         $constructBuilder
-        $getKeyArray
-        $getValueArray
+        ArrayData $keyArray = ${genInputData.value}.keyArray();
+        ArrayData $valueArray = ${genInputData.value}.valueArray();
 
         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {

http://git-wip-us.apache.org/repos/asf/spark/blob/ff4bb836/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index edcdd77..96a6792 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -295,7 +295,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong))
   }
 
-  ignore("SPARK-19104: map and product combinations") {
+  test("SPARK-25817: map and product combinations") {
     // Case classes
     checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2)))
     checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3))))

http://git-wip-us.apache.org/repos/asf/spark/blob/ff4bb836/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 27b3b3d..82d3b22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -164,6 +164,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       Seq(ClassData("a", 2))))
   }
 
+  test("as map of case class - reorder fields by name") {
+    val df = spark.range(3).select(map(lit(1), struct($"id".cast("int").as("b"), lit("a").as("a"))))
+    val ds = df.as[Map[Int, ClassData]]
+    assert(ds.collect() === Array(
+      Map(1 -> ClassData("a", 0)),
+      Map(1 -> ClassData("a", 1)),
+      Map(1 -> ClassData("a", 2))))
+  }
+
   test("map") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
     checkDataset(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
