Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/21732#discussion_r212520685
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
---
@@ -19,25 +19,85 @@ package org.apache.spark.sql.execution.aggregate
import scala.language.existentials
-import org.apache.spark.sql.Encoder
+import org.apache.spark.sql.{AnalysisException, Encoder}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
-import org.apache.spark.sql.catalyst.encoders.encoderFor
+import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedDeserializer}
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, DeclarativeAggregate, TypedImperativeAggregate}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
-import org.apache.spark.sql.catalyst.expressions.objects.Invoke
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance, WrapOption}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
object TypedAggregateExpression {
+
+ // Checks if given encoder is for `Option[Product]`.
+ def isOptProductEncoder(encoder: ExpressionEncoder[_]): Boolean = {
+ // Only Option[Product] is non-flat.
+ encoder.clsTag.runtimeClass == classOf[Option[_]] && !encoder.flat
+ }
+
+ /**
+ * Flattens serializers and deserializer of given encoder. We only flatten encoder
+ * of `Option[Product]` class.
+ */
+ def flattenOptProductEncoder(encoder: ExpressionEncoder[_]): ExpressionEncoder[_] = {
--- End diff --
I will add some tests for this.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]