Repository: spark
Updated Branches:
  refs/heads/master fd990a908 -> 14844a62c


[SPARK-23918][SQL] Add array_min function

## What changes were proposed in this pull request?

This PR adds the SQL function `array_min`. It takes an array as its argument and 
returns the minimum value in it; NULL elements are skipped.
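
For context, here is a minimal usage sketch of the new API (the local session and sample data below are illustrative, not part of the patch):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.array_min

// Illustrative local session; any Spark 2.4+ session works the same way.
val spark = SparkSession.builder().master("local[*]").appName("array_min-demo").getOrCreate()
import spark.implicits._

val df = Seq(Seq(2, 1, 3), Seq(10, -1)).toDF("data")

// DataFrame API: returns the smallest element of each array.
df.select(array_min($"data")).show()

// SQL form: NULL elements are skipped, so this returns 1.
spark.sql("SELECT array_min(array(1, 20, null, 3))").show()
```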

## How was this patch tested?

Added unit tests covering both the Catalyst expression (`CollectionExpressionsSuite`) and the DataFrame/SQL APIs (`DataFrameFunctionsSuite`).

Author: Marco Gaido <marcogaid...@gmail.com>

Closes #21025 from mgaido91/SPARK-23918.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/14844a62
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/14844a62
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/14844a62

Branch: refs/heads/master
Commit: 14844a62c025e7299029d7452b8c4003bc221ac8
Parents: fd990a9
Author: Marco Gaido <marcogaid...@gmail.com>
Authored: Tue Apr 17 17:55:35 2018 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Tue Apr 17 17:55:35 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 | 17 +++++-
 .../catalyst/analysis/FunctionRegistry.scala    |  1 +
 .../sql/catalyst/expressions/arithmetic.scala   |  6 +-
 .../expressions/codegen/CodeGenerator.scala     | 17 ++++++
 .../expressions/collectionOperations.scala      | 64 ++++++++++++++++++++
 .../CollectionExpressionsSuite.scala            | 10 +++
 .../scala/org/apache/spark/sql/functions.scala  |  8 +++
 .../spark/sql/DataFrameFunctionsSuite.scala     | 14 +++++
 8 files changed, 131 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/14844a62/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index f3492ae..6ca22b6 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2081,6 +2081,21 @@ def size(col):
 
 
 @since(2.4)
+def array_min(col):
+    """
+    Collection function: returns the minimum value of the array.
+
+    :param col: name of column or expression
+
+    >>> df = spark.createDataFrame([([2, 1, 3],), ([None, 10, -1],)], ['data'])
+    >>> df.select(array_min(df.data).alias('min')).collect()
+    [Row(min=1), Row(min=-1)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.array_min(_to_java_column(col)))
+
+
+@since(2.4)
 def array_max(col):
     """
     Collection function: returns the maximum value of the array.
@@ -2108,7 +2123,7 @@ def sort_array(col, asc=True):
     [Row(r=[1, 2, 3]), Row(r=[1]), Row(r=[])]
     >>> df.select(sort_array(df.data, asc=False).alias('r')).collect()
     [Row(r=[3, 2, 1]), Row(r=[1]), Row(r=[])]
-     """
+    """
     sc = SparkContext._active_spark_context
     return Column(sc._jvm.functions.sort_array(_to_java_column(col), asc))
 

http://git-wip-us.apache.org/repos/asf/spark/blob/14844a62/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 05bfa2d..4dd1ca5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -409,6 +409,7 @@ object FunctionRegistry {
     expression[MapValues]("map_values"),
     expression[Size]("size"),
     expression[SortArray]("sort_array"),
+    expression[ArrayMin]("array_min"),
     expression[ArrayMax]("array_max"),
     CreateStruct.registryEntry,
 

http://git-wip-us.apache.org/repos/asf/spark/blob/14844a62/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 942dfd4..d4e322d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -595,11 +595,7 @@ case class Least(children: Seq[Expression]) extends Expression {
     val evals = evalChildren.map(eval =>
       s"""
          |${eval.code}
-         |if (!${eval.isNull} && (${ev.isNull} ||
-         |  ${ctx.genGreater(dataType, ev.value, eval.value)})) {
-         |  ${ev.isNull} = false;
-         |  ${ev.value} = ${eval.value};
-         |}
+         |${ctx.reassignIfSmaller(dataType, ev, eval)}
       """.stripMargin
     )
 

http://git-wip-us.apache.org/repos/asf/spark/blob/14844a62/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index c86c5be..d97611c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -700,6 +700,23 @@ class CodegenContext {
   }
 
   /**
+   * Generates code for updating `partialResult` if `item` is smaller than it.
+   *
+   * @param dataType data type of the expressions
+   * @param partialResult `ExprCode` representing the partial result which has to be updated
+   * @param item `ExprCode` representing the new expression to evaluate for the result
+   */
+  def reassignIfSmaller(dataType: DataType, partialResult: ExprCode, item: ExprCode): String = {
+    s"""
+       |if (!${item.isNull} && (${partialResult.isNull} ||
+       |  ${genGreater(dataType, partialResult.value, item.value)})) {
+       |  ${partialResult.isNull} = false;
+       |  ${partialResult.value} = ${item.value};
+       |}
+      """.stripMargin
+  }
+
+  /**
    * Generates code for updating `partialResult` if `item` is greater than it.
    *
    * @param dataType data type of the expressions

http://git-wip-us.apache.org/repos/asf/spark/blob/14844a62/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index e2614a1..7c87777 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -288,6 +288,70 @@ case class ArrayContains(left: Expression, right: Expression)
   override def prettyName: String = "array_contains"
 }
 
+/**
+ * Returns the minimum value in the array.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array) - Returns the minimum value in the array. NULL elements are skipped.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 20, null, 3));
+       1
+  """, since = "2.4.0")
+case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+  override def nullable: Boolean = true
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
+
+  private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val typeCheckResult = super.checkInputDataTypes()
+    if (typeCheckResult.isSuccess) {
+      TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
+    } else {
+      typeCheckResult
+    }
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val childGen = child.genCode(ctx)
+    val javaType = CodeGenerator.javaType(dataType)
+    val i = ctx.freshName("i")
+    val item = ExprCode("",
+      isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($i)"),
+      value = JavaCode.expression(CodeGenerator.getValue(childGen.value, dataType, i), dataType))
+    ev.copy(code =
+      s"""
+         |${childGen.code}
+         |boolean ${ev.isNull} = true;
+         |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
+         |if (!${childGen.isNull}) {
+         |  for (int $i = 0; $i < ${childGen.value}.numElements(); $i ++) {
+         |    ${ctx.reassignIfSmaller(dataType, ev, item)}
+         |  }
+         |}
+      """.stripMargin)
+  }
+
+  override protected def nullSafeEval(input: Any): Any = {
+    var min: Any = null
+    input.asInstanceOf[ArrayData].foreach(dataType, (_, item) =>
+      if (item != null && (min == null || ordering.lt(item, min))) {
+        min = item
+      }
+    )
+    min
+  }
+
+  override def dataType: DataType = child.dataType match {
+    case ArrayType(dt, _) => dt
+    case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
+  }
+
+  override def prettyName: String = "array_min"
+}
 
 /**
  * Returns the maximum value in the array.

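For readers skimming the diff: the interpreted path (`nullSafeEval`) above boils down to the following plain-Scala fold. This is a restatement for clarity, not Spark API; `None` stands in for a SQL NULL element.

```scala
// Plain-Scala sketch of ArrayMin's null semantics: NULL elements are
// skipped, and an empty or all-null array yields NULL (None).
def arrayMin[T](items: Seq[Option[T]])(implicit ord: Ordering[T]): Option[T] =
  items.flatten.reduceOption(ord.min)

arrayMin(Seq(Some(2), None, Some(-1)))   // Some(-1)
arrayMin(Seq.empty[Option[Int]])         // None
```
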
http://git-wip-us.apache.org/repos/asf/spark/blob/14844a62/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index a238401..5a31e3a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -106,6 +106,16 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
   }
 
+  test("Array Min") {
+    checkEvaluation(ArrayMin(Literal.create(Seq(-11, 10, 2), ArrayType(IntegerType))), -11)
+    checkEvaluation(
+      ArrayMin(Literal.create(Seq[String](null, "abc", ""), ArrayType(StringType))), "")
+    checkEvaluation(ArrayMin(Literal.create(Seq(null), ArrayType(LongType))), null)
+    checkEvaluation(ArrayMin(Literal.create(null, ArrayType(StringType))), null)
+    checkEvaluation(
+      ArrayMin(Literal.create(Seq(1.123, 0.1234, 1.121), ArrayType(DoubleType))), 0.1234)
+  }
+
   test("Array max") {
     checkEvaluation(ArrayMax(Literal.create(Seq(1, 10, 2), ArrayType(IntegerType))), 10)
     checkEvaluation(

http://git-wip-us.apache.org/repos/asf/spark/blob/14844a62/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index daf4079..642ac05 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3301,6 +3301,14 @@ object functions {
   def sort_array(e: Column, asc: Boolean): Column = withExpr { SortArray(e.expr, lit(asc).expr) }
 
   /**
+   * Returns the minimum value in the array.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def array_min(e: Column): Column = withExpr { ArrayMin(e.expr) }
+
+  /**
    * Returns the maximum value in the array.
    *
    * @group collection_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/14844a62/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 5d5d92c..636e86b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -413,6 +413,20 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("array_min function") {
+    val df = Seq(
+      Seq[Option[Int]](Some(1), Some(3), Some(2)),
+      Seq.empty[Option[Int]],
+      Seq[Option[Int]](None),
+      Seq[Option[Int]](None, Some(1), Some(-100))
+    ).toDF("a")
+
+    val answer = Seq(Row(1), Row(null), Row(null), Row(-100))
+
+    checkAnswer(df.select(array_min(df("a"))), answer)
+    checkAnswer(df.selectExpr("array_min(a)"), answer)
+  }
+
   test("array_max function") {
     val df = Seq(
       Seq[Option[Int]](Some(1), Some(3), Some(2)),

