Github user viirya commented on a diff in the pull request:
https://github.com/apache/spark/pull/22944#discussion_r231001655
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
---
@@ -262,25 +262,39 @@ object AppendColumns {
def apply[T : Encoder, U : Encoder](
func: T => U,
child: LogicalPlan): AppendColumns = {
+ val outputEncoder = encoderFor[U]
+ val namedExpressions = if (!outputEncoder.isSerializedAsStruct) {
+ assert(outputEncoder.namedExpressions.length == 1)
+ outputEncoder.namedExpressions.map(Alias(_, "key")())
+ } else {
+ outputEncoder.namedExpressions
+ }
new AppendColumns(
func.asInstanceOf[Any => Any],
implicitly[Encoder[T]].clsTag.runtimeClass,
implicitly[Encoder[T]].schema,
UnresolvedDeserializer(encoderFor[T].deserializer),
- encoderFor[U].namedExpressions,
+ namedExpressions,
child)
}
def apply[T : Encoder, U : Encoder](
func: T => U,
inputAttributes: Seq[Attribute],
child: LogicalPlan): AppendColumns = {
+ val outputEncoder = encoderFor[U]
+ val namedExpressions = if (!outputEncoder.isSerializedAsStruct) {
+ assert(outputEncoder.namedExpressions.length == 1)
+ outputEncoder.namedExpressions.map(Alias(_, "key")())
+ } else {
+ outputEncoder.namedExpressions
--- End diff --
For both primitive and product types, it looks like this works:
```scala
test("typed aggregation on primitive data") {
val ds = Seq(1, 2, 3).toDS()
val agg = ds.select(expr("value").as("data").as[Int])
.groupByKey(_ >= 2)
.agg(sum("data").as[Long], sum($"data" + 1).as[Long])
agg.show()
}
```
```
+-----+---------+---------------+
|value|sum(data)|sum((data + 1))|
+-----+---------+---------------+
|false| 1| 2|
| true| 5| 7|
+-----+---------+---------------+
```
```scala
test("typed aggregation on product data") {
val ds = Seq((1, 2), (2, 3), (3, 4)).toDS()
val agg = ds.select(expr("_1").as("a").as[Int],
expr("_2").as("b").as[Int])
.groupByKey(_._1).agg(sum("a").as[Int], sum($"b" + 1).as[Int])
agg.show
}
```
```
[info] - typed aggregation on primitive data (192 milliseconds)
+-----+------+------------+
|value|sum(a)|sum((b + 1))|
+-----+------+------------+
| 3| 3| 5|
| 1| 1| 3|
| 2| 2| 4|
+-----+------+------------+
```
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]