[
https://issues.apache.org/jira/browse/SPARK-26353?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel=16720997#comment-16720997
]
ASF GitHub Bot commented on SPARK-26353:
10110346 closed pull request #23304: [SPARK-26353][SQL]Add typed aggregate
functions: max & min
URL: https://github.com/apache/spark/pull/23304
This is a PR merged from a forked repository.
Because GitHub hides the original diff of a foreign (fork-based) pull
request once it is merged, the diff is reproduced below for the sake of
provenance:
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
index b6550bf3e4aac..2d08ea3fce6fb 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/typedaggregators.scala
@@ -99,3 +99,71 @@ class TypedAverage[IN](val f: IN => Double) extends
Aggregator[IN, (Double, Long
toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
}
}
+
+class TypedMaxDouble[IN](val f: IN => Double) extends Aggregator[IN, Double,
Double] {
+ override def zero: Double = Double.MinValue
+ override def reduce(b: Double, a: IN): Double = if (b > f(a)) b else f(a)
+ override def merge(b1: Double, b2: Double): Double = if (b1 > b2) b1 else b2
+ override def finish(reduction: Double): Double = reduction
+
+ override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+ override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+
+ // Java api support
+ def this(f: MapFunction[IN, java.lang.Double]) = this((x: IN) =>
f.call(x).asInstanceOf[Double])
+
+ def toColumnJava: TypedColumn[IN, java.lang.Double] = {
+toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
+ }
+}
+
+class TypedMaxLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] {
+ override def zero: Long = Long.MinValue
+ override def reduce(b: Long, a: IN): Long = if (b > f(a)) b else f(a)
+ override def merge(b1: Long, b2: Long): Long = if (b1 > b2) b1 else b2
+ override def finish(reduction: Long): Long = reduction
+
+ override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+ override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+
+ // Java api support
+ def this(f: MapFunction[IN, java.lang.Long]) = this((x: IN) =>
f.call(x).asInstanceOf[Long])
+
+ def toColumnJava: TypedColumn[IN, java.lang.Long] = {
+toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
+ }
+}
+
+class TypedMinDouble[IN](val f: IN => Double) extends Aggregator[IN, Double,
Double] {
+ override def zero: Double = Double.MaxValue
+ override def reduce(b: Double, a: IN): Double = if (b < f(a)) b else f(a)
+ override def merge(b1: Double, b2: Double): Double = if (b1 < b2) b1 else b2
+ override def finish(reduction: Double): Double = reduction
+
+ override def bufferEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+ override def outputEncoder: Encoder[Double] = ExpressionEncoder[Double]()
+
+ // Java api support
+ def this(f: MapFunction[IN, java.lang.Double]) = this((x: IN) =>
f.call(x).asInstanceOf[Double])
+
+ def toColumnJava: TypedColumn[IN, java.lang.Double] = {
+toColumn.asInstanceOf[TypedColumn[IN, java.lang.Double]]
+ }
+}
+
+class TypedMinLong[IN](val f: IN => Long) extends Aggregator[IN, Long, Long] {
+ override def zero: Long = Long.MaxValue
+ override def reduce(b: Long, a: IN): Long = if (b < f(a)) b else f(a)
+ override def merge(b1: Long, b2: Long): Long = if (b1 < b2) b1 else b2
+ override def finish(reduction: Long): Long = reduction
+
+ override def bufferEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+ override def outputEncoder: Encoder[Long] = ExpressionEncoder[Long]()
+
+ // Java api support
+ def this(f: MapFunction[IN, java.lang.Long]) = this((x: IN) =>
f.call(x).asInstanceOf[Long])
+
+ def toColumnJava: TypedColumn[IN, java.lang.Long] = {
+toColumn.asInstanceOf[TypedColumn[IN, java.lang.Long]]
+ }
+}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala
index 1cb579c4faa76..6a8336e01d6f6 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/expressions/scalalang/typed.scala
@@ -77,14 +77,31 @@ object typed {
*/
def sumLong[IN](f: IN => Long): TypedColumn[IN, Long] = new
TypedSumLong[IN](f).toColumn
+ /**
+ * Max aggregate function for floating point (double) type.
+ */
+ def max[IN](f: IN => Double): TypedColumn[IN, Double] = new
TypedMaxDouble[IN](f).toColumn
+
+ /**
+ * Max aggregate function for