Repository: spark
Updated Branches:
refs/heads/master 94145786a -> 6339c8c2c
[SPARK-24762][SQL] Enable Option of Product encoders
## What changes were proposed in this pull request?
Spark SQL currently doesn't support encoding `Option[Product]` as a top-level row,
because in Spark SQL an entire top-level row can't be null.
However, for use cases like `Aggregator`, it is reasonable to use
`Option[Product]` as the buffer and output column types. The limitation above
has prevented that so far.
This patch proposes to encode a top-level `Option[Product]` as a single struct
column, which works around the restriction that an entire top-level row can't be null.
To summarize the encoding of `Product` and `Option[Product]`:
For `Product`: 1. At the root level, all fields are flattened into
multiple columns. The `Product` can't be null, otherwise an
exception is thrown.
```scala
val df = Seq((1 -> "a"), (2 -> "b")).toDF()
df.printSchema()
root
|-- _1: integer (nullable = false)
|-- _2: string (nullable = true)
```
2. At a non-root level, `Product` is a struct type column.
```scala
val df = Seq((1, (1 -> "a")), (2, (2 -> "b")), (3, null)).toDF()
df.printSchema()
root
|-- _1: integer (nullable = false)
|-- _2: struct (nullable = true)
| |-- _1: integer (nullable = false)
| |-- _2: string (nullable = true)
```
For `Option[Product]`: 1. It was not supported at the root level. After this
change, it is encoded as a struct type column.
```scala
val df = Seq(Some(1 -> "a"), Some(2 -> "b"), None).toDF()
df.printSchema
root
|-- value: struct (nullable = true)
| |-- _1: integer (nullable = false)
| |-- _2: string (nullable = true)
```
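Collecting the data back yields the original `Option` values, including `None` for the null struct (a minimal sketch mirroring the new `DatasetSuite` test below):
```scala
// None round-trips through the nullable struct column
// (mirrors the "Enable top-level Option of Product encoders" test below).
val ds = Seq(Some(1 -> "a"), Some(2 -> "b"), None).toDS()
ds.collect()
// Array(Some((1,a)), Some((2,b)), None)
```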
2. At a non-root level, it is also a struct type column.
```scala
val df = Seq((1, Some(1 -> "a")), (2, Some(2 -> "b")), (3, None)).toDF()
df.printSchema
root
|-- _1: integer (nullable = false)
|-- _2: struct (nullable = true)
| |-- _1: integer (nullable = false)
| |-- _2: string (nullable = true)
```
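A nullable struct column can likewise be resolved back into an `Option[Product]` field, with null structs deserializing to `None` (a sketch based on the "Resolving Option[Product] field" test added below):
```scala
// A null struct value deserializes to None when read back as Option[Product]
// (based on the "Resolving Option[Product] field" test below).
val ds = Seq((1, ("a", 1.0)), (2, ("b", 2.0)), (3, null)).toDS()
  .as[(Int, Option[(String, Double)])]
ds.collect()
// Array((1,Some((a,1.0))), (2,Some((b,2.0))), (3,None))
```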
3. Use cases like `Aggregator` were not supported either. After this change,
`Option[Product]` can be used as the buffer/output column type.
```scala
val df = Seq(
OptionBooleanIntData("bob", Some((true, 1))),
OptionBooleanIntData("bob", Some((false, 2))),
OptionBooleanIntData("bob", None)).toDF()
val group = df
.groupBy("name")
.agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood"))
group.printSchema
root
|-- name: string (nullable = true)
|-- isGood: struct (nullable = true)
| |-- _1: boolean (nullable = false)
| |-- _2: integer (nullable = false)
```
The buffer and output types of `OptionBooleanIntAggregator` are both
`Option[(Boolean, Int)]`.
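For reference, this is a condensed sketch of the `OptionBooleanIntAggregator` added in `DatasetAggregatorSuite` (see the diff below); the `merge` logic is rewritten with `Option` combinators but is otherwise equivalent:
```scala
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator

// Buffer and output are both Option[(Boolean, Int)]; the encoders below could
// not be created before this change.
case class OptionBooleanIntAggregator(colName: String)
  extends Aggregator[Row, Option[(Boolean, Int)], Option[(Boolean, Int)]] {

  override def zero: Option[(Boolean, Int)] = None

  override def reduce(buffer: Option[(Boolean, Int)], row: Row): Option[(Boolean, Int)] = {
    val index = row.fieldIndex(colName)
    val value = if (row.isNullAt(index)) {
      None
    } else {
      // The Option[(Boolean, Int)] column arrives as a nested struct.
      val nested = row.getStruct(index)
      Some((nested.getBoolean(0), nested.getInt(1)))
    }
    merge(buffer, value)
  }

  override def merge(
      b1: Option[(Boolean, Int)],
      b2: Option[(Boolean, Int)]): Option[(Boolean, Int)] = {
    if (b1.exists(_._1) || b2.exists(_._1)) {
      Some((true, b1.map(_._2).getOrElse(0) + b2.map(_._2).getOrElse(0)))
    } else {
      b1.orElse(b2)
    }
  }

  override def finish(reduction: Option[(Boolean, Int)]): Option[(Boolean, Int)] = reduction

  override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder()
  override def outputEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder()
}
```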
## How was this patch tested?
Added tests.
Closes #21732 from viirya/SPARK-24762.
Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6339c8c2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6339c8c2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6339c8c2
Branch: refs/heads/master
Commit: 6339c8c2c6b80a85e4ad6a7fa7595cf567a1113e
Parents: 9414578
Author: Liang-Chi Hsieh <[email protected]>
Authored: Mon Nov 26 11:13:28 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Mon Nov 26 11:13:28 2018 +0800
----------------------------------------------------------------------
.../catalyst/encoders/ExpressionEncoder.scala | 32 +++++---
.../scala/org/apache/spark/sql/Dataset.scala | 10 +--
.../spark/sql/KeyValueGroupedDataset.scala | 2 +-
.../aggregate/TypedAggregateExpression.scala | 18 ++---
.../spark/sql/DatasetAggregatorSuite.scala | 64 +++++++++++++++-
.../org/apache/spark/sql/DatasetSuite.scala | 77 +++++++++++++++++---
6 files changed, 163 insertions(+), 40 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/6339c8c2/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 592520c..d019924 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -49,15 +49,6 @@ object ExpressionEncoder {
val mirror = ScalaReflection.mirror
val tpe = typeTag[T].in(mirror).tpe
-  if (ScalaReflection.optionOfProductType(tpe)) {
-    throw new UnsupportedOperationException(
-      "Cannot create encoder for Option of Product type, because Product type is represented " +
-        "as a row, and the entire row can not be null in Spark SQL like normal databases. " +
-        "You can wrap your type with Tuple1 if you do want top level null Product objects, " +
-        "e.g. instead of creating `Dataset[Option[MyClass]]`, you can do something like " +
-        "`val ds: Dataset[Tuple1[MyClass]] = Seq(Tuple1(MyClass(...)), Tuple1(null)).toDS`")
-  }
-
val cls = mirror.runtimeClass(tpe)
val serializer = ScalaReflection.serializerForType(tpe)
val deserializer = ScalaReflection.deserializerForType(tpe)
@@ -198,7 +189,7 @@ case class ExpressionEncoder[T](
val serializer: Seq[NamedExpression] = {
val clsName = Utils.getSimpleName(clsTag.runtimeClass)
- if (isSerializedAsStruct) {
+ if (isSerializedAsStructForTopLevel) {
val nullSafeSerializer = objSerializer.transformUp {
case r: BoundReference =>
// For input object of Product type, we can't encode it to row if it's null, as Spark SQL
@@ -213,6 +204,9 @@ case class ExpressionEncoder[T](
} else {
// For other input objects like primitive, array, map, etc., we construct a struct to wrap
// the serializer which is a column of a row.
+ //
+ // Note: Because Spark SQL doesn't allow a top-level row to be null, to encode a
+ // top-level Option[Product] type, we make it a top-level struct column.
CreateNamedStruct(Literal("value") :: objSerializer :: Nil)
}
}.flatten
@@ -226,7 +220,7 @@ case class ExpressionEncoder[T](
* `GetColumnByOrdinal` with corresponding ordinal.
*/
val deserializer: Expression = {
- if (isSerializedAsStruct) {
+ if (isSerializedAsStructForTopLevel) {
// We serialized this kind of objects to root-level row. The input of general deserializer
// is a `GetColumnByOrdinal(0)` expression to extract first column of a row. We need to
// transform attributes accessors.
@@ -253,10 +247,24 @@ case class ExpressionEncoder[T](
})
/**
- * Returns true if the type `T` is serialized as a struct.
+ * Returns true if the type `T` is serialized as a struct by `objSerializer`.
*/
def isSerializedAsStruct: Boolean = objSerializer.dataType.isInstanceOf[StructType]
+ /**
+ * Returns true if the type `T` is an `Option` type.
+ */
+ def isOptionType: Boolean = classOf[Option[_]].isAssignableFrom(clsTag.runtimeClass)
+
+ /**
+ * If the type `T` is serialized as a struct, when it is encoded to a Spark SQL row, fields in
+ * the struct are naturally mapped to top-level columns in a row. In other words, the serialized
+ * struct is flattened to a row. But in case `T` is also an `Option` type, it can't be
+ * flattened to a top-level row, because in Spark SQL a top-level row can't be null. This method
+ * returns true if `T` is serialized as a struct and is not an `Option` type.
+ */
+ def isSerializedAsStructForTopLevel: Boolean = isSerializedAsStruct && !isOptionType
+
// serializer expressions are used to encode an object to a row, while the object is usually an
// intermediate value produced inside an operator, not from the output of the child operator. This
// is quite different from normal expressions, and `AttributeReference` doesn't work here
http://git-wip-us.apache.org/repos/asf/spark/blob/6339c8c2/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f361bde..b10d66d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1093,7 +1093,7 @@ class Dataset[T] private[sql](
// Note that we do this before joining them, to enable the join operator to return null for one
// side, in cases like outer-join.
val left = {
- val combined = if (!this.exprEnc.isSerializedAsStruct) {
+ val combined = if (!this.exprEnc.isSerializedAsStructForTopLevel) {
assert(joined.left.output.length == 1)
Alias(joined.left.output.head, "_1")()
} else {
@@ -1103,7 +1103,7 @@ class Dataset[T] private[sql](
}
val right = {
- val combined = if (!other.exprEnc.isSerializedAsStruct) {
+ val combined = if (!other.exprEnc.isSerializedAsStructForTopLevel) {
assert(joined.right.output.length == 1)
Alias(joined.right.output.head, "_2")()
} else {
@@ -1116,14 +1116,14 @@ class Dataset[T] private[sql](
// combine the outputs of each join side.
val conditionExpr = joined.condition.get transformUp {
case a: Attribute if joined.left.outputSet.contains(a) =>
- if (!this.exprEnc.isSerializedAsStruct) {
+ if (!this.exprEnc.isSerializedAsStructForTopLevel) {
left.output.head
} else {
val index = joined.left.output.indexWhere(_.exprId == a.exprId)
GetStructField(left.output.head, index)
}
case a: Attribute if joined.right.outputSet.contains(a) =>
- if (!other.exprEnc.isSerializedAsStruct) {
+ if (!other.exprEnc.isSerializedAsStructForTopLevel) {
right.output.head
} else {
val index = joined.right.output.indexWhere(_.exprId == a.exprId)
@@ -1396,7 +1396,7 @@ class Dataset[T] private[sql](
implicit val encoder = c1.encoder
val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil, logicalPlan)
- if (!encoder.isSerializedAsStruct) {
+ if (!encoder.isSerializedAsStructForTopLevel) {
new Dataset[U1](sparkSession, project, encoder)
} else {
// Flattens inner fields of U1
http://git-wip-us.apache.org/repos/asf/spark/blob/6339c8c2/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 2d849c6..a3cbea9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -458,7 +458,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
val encoders = columns.map(_.encoder)
val namedColumns =
columns.map(_.withInputType(vExprEnc, dataAttributes).named)
- val keyColumn = if (!kExprEnc.isSerializedAsStruct) {
+ val keyColumn = if (!kExprEnc.isSerializedAsStructForTopLevel) {
assert(groupingAttributes.length == 1)
if (SQLConf.get.nameNonStructGroupingKeyAsValue) {
groupingAttributes.head
http://git-wip-us.apache.org/repos/asf/spark/blob/6339c8c2/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 39200ec..b757529 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -40,9 +40,9 @@ object TypedAggregateExpression {
val outputEncoder = encoderFor[OUT]
val outputType = outputEncoder.objSerializer.dataType
- // Checks if the buffer object is simple, i.e. the buffer encoder is flat and the serializer
- // expression is an alias of `BoundReference`, which means the buffer object doesn't need
- // serialization.
+ // Checks if the buffer object is simple, i.e. the `BUF` type is not serialized as struct
+ // and the serializer expression is an alias of `BoundReference`, which means the buffer
+ // object doesn't need serialization.
val isSimpleBuffer = {
bufferSerializer.head match {
case Alias(_: BoundReference, _) if !bufferEncoder.isSerializedAsStruct => true
@@ -76,7 +76,7 @@ object TypedAggregateExpression {
None,
bufferSerializer,
bufferEncoder.resolveAndBind().deserializer,
- outputEncoder.serializer,
+ outputEncoder.objSerializer,
outputType,
outputEncoder.objSerializer.nullable)
}
@@ -213,7 +213,7 @@ case class ComplexTypedAggregateExpression(
inputSchema: Option[StructType],
bufferSerializer: Seq[NamedExpression],
bufferDeserializer: Expression,
- outputSerializer: Seq[Expression],
+ outputSerializer: Expression,
dataType: DataType,
nullable: Boolean,
mutableAggBufferOffset: Int = 0,
@@ -245,13 +245,7 @@ case class ComplexTypedAggregateExpression(
aggregator.merge(buffer, input)
}
- private lazy val resultObjToRow = dataType match {
- case _: StructType =>
- UnsafeProjection.create(CreateStruct(outputSerializer))
- case _ =>
- assert(outputSerializer.length == 1)
- UnsafeProjection.create(outputSerializer.head)
- }
+ private lazy val resultObjToRow = UnsafeProjection.create(outputSerializer)
override def eval(buffer: Any): Any = {
val resultObj = aggregator.finish(buffer)
http://git-wip-us.apache.org/repos/asf/spark/blob/6339c8c2/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 538ea3c..97c3f35 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.types.{BooleanType, IntegerType, StringType, StructType}
object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
@@ -149,6 +149,7 @@ object VeryComplexResultAgg extends Aggregator[Row, String, ComplexAggData] {
case class OptionBooleanData(name: String, isGood: Option[Boolean])
+case class OptionBooleanIntData(name: String, isGood: Option[(Boolean, Int)])
case class OptionBooleanAggregator(colName: String)
extends Aggregator[Row, Option[Boolean], Option[Boolean]] {
@@ -183,6 +184,43 @@ case class OptionBooleanAggregator(colName: String)
def OptionalBoolEncoder: Encoder[Option[Boolean]] = ExpressionEncoder()
}
+case class OptionBooleanIntAggregator(colName: String)
+ extends Aggregator[Row, Option[(Boolean, Int)], Option[(Boolean, Int)]] {
+
+ override def zero: Option[(Boolean, Int)] = None
+
+ override def reduce(buffer: Option[(Boolean, Int)], row: Row): Option[(Boolean, Int)] = {
+ val index = row.fieldIndex(colName)
+ val value = if (row.isNullAt(index)) {
+ Option.empty[(Boolean, Int)]
+ } else {
+ val nestedRow = row.getStruct(index)
+ Some((nestedRow.getBoolean(0), nestedRow.getInt(1)))
+ }
+ merge(buffer, value)
+ }
+
+ override def merge(
+ b1: Option[(Boolean, Int)],
+ b2: Option[(Boolean, Int)]): Option[(Boolean, Int)] = {
+ if ((b1.isDefined && b1.get._1) || (b2.isDefined && b2.get._1)) {
+ val newInt = b1.map(_._2).getOrElse(0) + b2.map(_._2).getOrElse(0)
+ Some((true, newInt))
+ } else if (b1.isDefined) {
+ b1
+ } else {
+ b2
+ }
+ }
+
+ override def finish(reduction: Option[(Boolean, Int)]): Option[(Boolean, Int)] = reduction
+
+ override def bufferEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder
+ override def outputEncoder: Encoder[Option[(Boolean, Int)]] = OptionalBoolIntEncoder
+
+ def OptionalBoolIntEncoder: Encoder[Option[(Boolean, Int)]] = ExpressionEncoder()
+}
+
class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -393,4 +431,28 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
assert(grouped.schema == df.schema)
checkDataset(grouped.as[OptionBooleanData], OptionBooleanData("bob", Some(true)))
}
+
+ test("SPARK-24762: Aggregator should be able to use Option of Product
encoder") {
+ val df = Seq(
+ OptionBooleanIntData("bob", Some((true, 1))),
+ OptionBooleanIntData("bob", Some((false, 2))),
+ OptionBooleanIntData("bob", None)).toDF()
+
+ val group = df
+ .groupBy("name")
+ .agg(OptionBooleanIntAggregator("isGood").toColumn.alias("isGood"))
+
+ val expectedSchema = new StructType()
+ .add("name", StringType, nullable = true)
+ .add("isGood",
+ new StructType()
+ .add("_1", BooleanType, nullable = false)
+ .add("_2", IntegerType, nullable = false),
+ nullable = true)
+
+ assert(df.schema == expectedSchema)
+ assert(group.schema == expectedSchema)
+ checkAnswer(group, Row("bob", Row(true, 3)) :: Nil)
+ checkDataset(group.as[OptionBooleanIntData], OptionBooleanIntData("bob", Some((true, 3))))
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/6339c8c2/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index baece2d..0f90083 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -1312,15 +1312,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkDataset(dsString, arrayString)
}
- test("SPARK-18251: the type of Dataset can't be Option of Product type") {
- checkDataset(Seq(Some(1), None).toDS(), Some(1), None)
-
- val e = intercept[UnsupportedOperationException] {
- Seq(Some(1 -> "a"), None).toDS()
- }
- assert(e.getMessage.contains("Cannot create encoder for Option of Product type"))
- }
-
test ("SPARK-17460: the sizeInBytes in Statistics shouldn't overflow to a
negative number") {
// Since the sizeInBytes in Statistics could exceed the limit of an Int,
we should use BigInt
// instead of Int for avoiding possible overflow.
@@ -1558,6 +1549,74 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
Seq(Row("Amsterdam")))
}
+ test("SPARK-24762: Enable top-level Option of Product encoders") {
+ val data = Seq(Some((1, "a")), Some((2, "b")), None)
+ val ds = data.toDS()
+
+ checkDataset(
+ ds,
+ data: _*)
+
+ val schema = new StructType().add(
+ "value",
+ new StructType()
+ .add("_1", IntegerType, nullable = false)
+ .add("_2", StringType, nullable = true),
+ nullable = true)
+
+ assert(ds.schema == schema)
+
+ val nestedOptData = Seq(Some((Some((1, "a")), 2.0)), Some((Some((2, "b")), 3.0)))
+ val nestedDs = nestedOptData.toDS()
+
+ checkDataset(
+ nestedDs,
+ nestedOptData: _*)
+
+ val nestedSchema = StructType(Seq(
+ StructField("value", StructType(Seq(
+ StructField("_1", StructType(Seq(
+ StructField("_1", IntegerType, nullable = false),
+ StructField("_2", StringType, nullable = true)))),
+ StructField("_2", DoubleType, nullable = false)
+ )), nullable = true)
+ ))
+ assert(nestedDs.schema == nestedSchema)
+ }
+
+ test("SPARK-24762: Resolving Option[Product] field") {
+ val ds = Seq((1, ("a", 1.0)), (2, ("b", 2.0)), (3, null)).toDS()
+ .as[(Int, Option[(String, Double)])]
+ checkDataset(ds,
+ (1, Some(("a", 1.0))), (2, Some(("b", 2.0))), (3, None))
+ }
+
+ test("SPARK-24762: select Option[Product] field") {
+ val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
+ val ds1 = ds.select(expr("struct(_2, _2 + 1)").as[Option[(Int, Int)]])
+ checkDataset(ds1,
+ Some((1, 2)), Some((2, 3)), Some((3, 4)))
+
+ val ds2 = ds.select(expr("if(_2 > 2, struct(_2, _2 + 1), null)").as[Option[(Int, Int)]])
+ checkDataset(ds2,
+ None, None, Some((3, 4)))
+ }
+
+ test("SPARK-24762: joinWith on Option[Product]") {
+ val ds1 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("a")
+ val ds2 = Seq(Some((1, 2)), Some((2, 3)), None).toDS().as("b")
+ val joined = ds1.joinWith(ds2, $"a.value._1" === $"b.value._2", "inner")
+ checkDataset(joined, (Some((2, 3)), Some((1, 2))))
+ }
+
+ test("SPARK-24762: typed agg on Option[Product] type") {
+ val ds = Seq(Some((1, 2)), Some((2, 3)), Some((1, 3))).toDS()
+ assert(ds.groupByKey(_.get._1).count().collect() === Seq((1, 2), (2, 1)))
+
+ assert(ds.groupByKey(x => x).count().collect() ===
+ Seq((Some((1, 2)), 1), (Some((2, 3)), 1), (Some((1, 3)), 1)))
+ }
+
test("SPARK-25942: typed aggregation on primitive type") {
val ds = Seq(1, 2, 3).toDS()