Repository: spark
Updated Branches:
  refs/heads/master e54581134 -> ff4bb836a
[SPARK-25817][SQL] Dataset encoder should support combination of map and product type

## What changes were proposed in this pull request?

After https://github.com/apache/spark/pull/22745, the Dataset encoder supports the combination of Java bean and map type. This PR fixes the Scala side.

The reason it didn't work before: `CatalystToExternalMap` tries to get the data type of the input map expression, but that expression can still be unresolved, in which case its data type is not yet known. To fix it, we follow `UnresolvedMapObjects` and introduce an `UnresolvedCatalystToExternalMap` placeholder, creating `CatalystToExternalMap` only once the input map expression is resolved and its data type is known.

## How was this patch tested?

Enabled an old test case.

Closes #22812 from cloud-fan/map.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ff4bb836
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ff4bb836
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ff4bb836

Branch: refs/heads/master
Commit: ff4bb836aa768082df9227628dfd5a837f8e4f4e
Parents: e545811
Author: Wenchen Fan <[email protected]>
Authored: Sun Oct 28 13:33:26 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Sun Oct 28 13:33:26 2018 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/ScalaReflection.scala    | 15 +++---
 .../spark/sql/catalyst/analysis/Analyzer.scala  | 13 ++++-
 .../catalyst/encoders/ExpressionEncoder.scala   |  8 ++-
 .../catalyst/expressions/objects/objects.scala  | 56 ++++++++++----------
 .../spark/sql/DatasetPrimitiveSuite.scala       |  2 +-
 .../org/apache/spark/sql/DatasetSuite.scala     |  9 ++++
 6 files changed, 59 insertions(+), 44 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/ff4bb836/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 40074b3..912744e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -143,8 +143,7 @@ object ScalaReflection extends ScalaReflection {
       walkedTypePath: Seq[String]): Expression = expected match {
     case _: StructType => expr
     case _: ArrayType => expr
-    // TODO: ideally we should also skip MapType, but nested StructType inside MapType is rare and
-    // it's not trivial to support by-name resolution for StructType inside MapType.
+    case _: MapType => expr
     case _ => UpCast(expr, expected, walkedTypePath)
   }
 
@@ -163,8 +162,8 @@ object ScalaReflection extends ScalaReflection {
     val Schema(dataType, nullable) = schemaFor(tpe)
 
     // Assumes we are deserializing the first column of a row.
-    val input = upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType,
-      walkedTypePath)
+    val input = upCastToExpectedType(
+      GetColumnByOrdinal(0, dataType), dataType, walkedTypePath)
     val expr = deserializerFor(tpe, input, walkedTypePath)
 
     if (nullable) {
@@ -350,10 +349,10 @@ object ScalaReflection extends ScalaReflection {
         // TODO: add walked type path for map
         val TypeRef(_, _, Seq(keyType, valueType)) = t
 
-        CatalystToExternalMap(
+        UnresolvedCatalystToExternalMap(
+          path,
           p => deserializerFor(keyType, p, walkedTypePath),
           p => deserializerFor(valueType, p, walkedTypePath),
-          path,
           mirror.runtimeClass(t.typeSymbol.asClass)
         )
 
@@ -431,8 +430,8 @@ object ScalaReflection extends ScalaReflection {
     val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
 
     // The input object to `ExpressionEncoder` is located at first column of an row.
-    val inputObject = BoundReference(0, dataTypeFor(tpe),
-      nullable = !tpe.typeSymbol.asClass.isPrimitive)
+    val isPrimitive = tpe.typeSymbol.asClass.isPrimitive
+    val inputObject = BoundReference(0, dataTypeFor(tpe), nullable = !isPrimitive)
 
     serializerFor(inputObject, tpe, walkedTypePath)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/ff4bb836/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 63a07e3..c2d22c5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2384,14 +2384,23 @@ class Analyzer(
       case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved =>
         inputData.dataType match {
           case ArrayType(et, cn) =>
-            val expr = MapObjects(func, inputData, et, cn, cls) transformUp {
+            MapObjects(func, inputData, et, cn, cls) transformUp {
               case UnresolvedExtractValue(child, fieldName) if child.resolved =>
                 ExtractValue(child, fieldName, resolver)
             }
-            expr
           case other =>
             throw new AnalysisException("need an array field but got " + other.catalogString)
         }
+      case u: UnresolvedCatalystToExternalMap if u.child.resolved =>
+        u.child.dataType match {
+          case _: MapType =>
+            CatalystToExternalMap(u) transformUp {
+              case UnresolvedExtractValue(child, fieldName) if child.resolved =>
+                ExtractValue(child, fieldName, resolver)
+            }
+          case other =>
+            throw new AnalysisException("need a map field but got " + other.catalogString)
+        }
     }
     validateNestedTupleFields(result)
     result

http://git-wip-us.apache.org/repos/asf/spark/blob/ff4bb836/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 29f6136..2c8e81e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -119,10 +119,9 @@
     }
 
     val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
-      val getColumnsByOrdinals = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }
-        .distinct
-      assert(getColumnsByOrdinals.size == 1, "object deserializer should have only one " +
-        s"`GetColumnByOrdinal`, but there are ${getColumnsByOrdinals.size}")
+      val getColExprs = enc.objDeserializer.collect { case c: GetColumnByOrdinal => c }.distinct
+      assert(getColExprs.size == 1, "object deserializer should have only one " +
+        s"`GetColumnByOrdinal`, but there are ${getColExprs.size}")
       val input = GetStructField(GetColumnByOrdinal(0, schema), index)
 
       val newDeserializer = enc.objDeserializer.transformUp {
@@ -216,7 +215,6 @@ case class ExpressionEncoder[T](
     }
     nullSafeSerializer match {
       case If(_: IsNull, _, s: CreateNamedStruct) => s
-      case s: CreateNamedStruct => s
       case _ =>
         throw new RuntimeException(s"class $clsName has unexpected serializer: $objSerializer")
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/ff4bb836/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index b6f9b47..4fd36a4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -30,14 +30,13 @@ import org.apache.spark.serializer._
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
 import org.apache.spark.sql.catalyst.ScalaReflection.universe.TermName
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedException}
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
 import org.apache.spark.util.Utils
 
 /**
@@ -963,25 +962,32 @@ case class MapObjects private(
   }
 }
 
+/**
+ * Similar to [[UnresolvedMapObjects]], this is a placeholder of [[CatalystToExternalMap]].
+ *
+ * @param child An expression that when evaluated returns a map object.
+ * @param keyFunction The function applied on the key collection elements.
+ * @param valueFunction The function applied on the value collection elements.
+ * @param collClass The type of the resulting collection.
+ */
+case class UnresolvedCatalystToExternalMap(
+    child: Expression,
+    @transient keyFunction: Expression => Expression,
+    @transient valueFunction: Expression => Expression,
+    collClass: Class[_]) extends UnaryExpression with Unevaluable {
+
+  override lazy val resolved = false
+
+  override def dataType: DataType = ObjectType(collClass)
+}
+
 object CatalystToExternalMap {
   private val curId = new java.util.concurrent.atomic.AtomicInteger()
 
-  /**
-   * Construct an instance of CatalystToExternalMap case class.
-   *
-   * @param keyFunction The function applied on the key collection elements.
-   * @param valueFunction The function applied on the value collection elements.
-   * @param inputData An expression that when evaluated returns a map object.
-   * @param collClass The type of the resulting collection.
-   */
-  def apply(
-      keyFunction: Expression => Expression,
-      valueFunction: Expression => Expression,
-      inputData: Expression,
-      collClass: Class[_]): CatalystToExternalMap = {
+  def apply(u: UnresolvedCatalystToExternalMap): CatalystToExternalMap = {
     val id = curId.getAndIncrement()
     val keyLoopValue = s"CatalystToExternalMap_keyLoopValue$id"
-    val mapType = inputData.dataType.asInstanceOf[MapType]
+    val mapType = u.child.dataType.asInstanceOf[MapType]
     val keyLoopVar = LambdaVariable(keyLoopValue, "", mapType.keyType, nullable = false)
     val valueLoopValue = s"CatalystToExternalMap_valueLoopValue$id"
     val valueLoopIsNull = if (mapType.valueContainsNull) {
@@ -991,9 +997,9 @@ object CatalystToExternalMap {
     }
     val valueLoopVar = LambdaVariable(valueLoopValue, valueLoopIsNull, mapType.valueType)
     CatalystToExternalMap(
-      keyLoopValue, keyFunction(keyLoopVar),
-      valueLoopValue, valueLoopIsNull, valueFunction(valueLoopVar),
-      inputData, collClass)
+      keyLoopValue, u.keyFunction(keyLoopVar),
+      valueLoopValue, valueLoopIsNull, u.valueFunction(valueLoopVar),
+      u.child, u.collClass)
   }
 }
 
@@ -1090,15 +1096,9 @@ case class CatalystToExternalMap private(
     val tupleLoopValue = ctx.freshName("tupleLoopValue")
     val builderValue = ctx.freshName("builderValue")
 
-    val getLength = s"${genInputData.value}.numElements()"
-
     val keyArray = ctx.freshName("keyArray")
     val valueArray = ctx.freshName("valueArray")
-    val getKeyArray =
-      s"${classOf[ArrayData].getName} $keyArray = ${genInputData.value}.keyArray();"
     val getKeyLoopVar = CodeGenerator.getValue(keyArray, inputDataType(mapType.keyType), loopIndex)
-    val getValueArray =
-      s"${classOf[ArrayData].getName} $valueArray = ${genInputData.value}.valueArray();"
     val getValueLoopVar = CodeGenerator.getValue(
       valueArray, inputDataType(mapType.valueType), loopIndex)
 
@@ -1147,10 +1147,10 @@ case class CatalystToExternalMap private(
      ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
 
      if (!${genInputData.isNull}) {
-        int $dataLength = $getLength;
+        int $dataLength = ${genInputData.value}.numElements();
         $constructBuilder
-        $getKeyArray
-        $getValueArray
+        ArrayData $keyArray = ${genInputData.value}.keyArray();
+        ArrayData $valueArray = ${genInputData.value}.valueArray();
 
         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {

http://git-wip-us.apache.org/repos/asf/spark/blob/ff4bb836/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
index edcdd77..96a6792 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala
@@ -295,7 +295,7 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext {
     checkDataset(Seq(LHMap("test" -> 2.toLong)).toDS(), LHMap("test" -> 2.toLong))
   }
 
-  ignore("SPARK-19104: map and product combinations") {
+  test("SPARK-25817: map and product combinations") {
     // Case classes
     checkDataset(Seq(MapClass(Map(1 -> 2))).toDS(), MapClass(Map(1 -> 2)))
     checkDataset(Seq(Map(1 -> MapClass(Map(2 -> 3)))).toDS(), Map(1 -> MapClass(Map(2 -> 3))))

http://git-wip-us.apache.org/repos/asf/spark/blob/ff4bb836/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 27b3b3d..82d3b22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -164,6 +164,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
         Seq(ClassData("a", 2))))
   }
 
+  test("as map of case class - reorder fields by name") {
+    val df = spark.range(3).select(map(lit(1), struct($"id".cast("int").as("b"), lit("a").as("a"))))
+    val ds = df.as[Map[Int, ClassData]]
+    assert(ds.collect() === Array(
+      Map(1 -> ClassData("a", 0)),
+      Map(1 -> ClassData("a", 1)),
+      Map(1 -> ClassData("a", 2))))
+  }
+
   test("map") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
     checkDataset(
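
For readers who want to try the newly supported pattern end to end, here is a minimal, self-contained sketch of the user-facing usage that this change enables. It mirrors what the re-enabled DatasetPrimitiveSuite test exercises; the `Event` case class, the object name, the app name and the `local[*]` master are illustrative choices for this note, not part of the patch.

import org.apache.spark.sql.SparkSession

// Top-level case class so that spark.implicits can derive an encoder for it.
case class Event(name: String, count: Int)

object MapOfCaseClassExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("map-of-case-class")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // The element type combines a map with a product type (case class), which the
    // Scala-side encoder could not deserialize before this change.
    val ds = Seq(Map(1 -> Event("a", 0)), Map(2 -> Event("b", 1))).toDS()

    // Deserialization now goes through UnresolvedCatalystToExternalMap, which the
    // analyzer turns into CatalystToExternalMap once the input expression is resolved.
    ds.collect().foreach(println)

    spark.stop()
  }
}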

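For intuition about why the placeholder is needed, here is a minimal, Spark-independent sketch of the pattern the patch follows: construct a concrete node only after the child's data type is known. All types below (Expression, MapType, ObjectType, the two ToExternalMap nodes and the resolve rule) are simplified stand-ins written for this note, not Spark's actual classes.

object PlaceholderPatternSketch {

  sealed trait DataType
  case class MapType(keyType: DataType, valueType: DataType) extends DataType
  case class ObjectType(cls: Class[_]) extends DataType

  trait Expression {
    def resolved: Boolean
    def dataType: DataType // only safe to call once resolved == true
  }

  // Placeholder: never asks the child for its data type, so it can be created
  // during reflection even while the child expression is still unresolved.
  case class UnresolvedToExternalMap(
      child: Expression,
      keyFunction: Expression => Expression,   // applied later, once resolved
      valueFunction: Expression => Expression, // applied later, once resolved
      collClass: Class[_]) extends Expression {
    override def resolved: Boolean = false
    override def dataType: DataType = ObjectType(collClass)
  }

  // Concrete node: only constructed once the child's MapType is known.
  case class ToExternalMap(child: Expression, mapType: MapType, collClass: Class[_])
      extends Expression {
    override def resolved: Boolean = true
    override def dataType: DataType = ObjectType(collClass)
  }

  // Analyzer-style rule: swap the placeholder for the concrete node only when the
  // child is resolved, mirroring how the real rule builds CatalystToExternalMap
  // from UnresolvedCatalystToExternalMap.
  def resolve(e: Expression): Expression = e match {
    case u: UnresolvedToExternalMap if u.child.resolved =>
      u.child.dataType match {
        case m: MapType => ToExternalMap(u.child, m, u.collClass)
        case other => throw new IllegalArgumentException(s"need a map field but got $other")
      }
    case other => other
  }
}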