Repository: spark
Updated Branches:
  refs/heads/branch-2.1 2c88e1dc3 -> 72bf51997


[SPARK-18637][SQL] Stateful UDF should be considered as nondeterministic

Mark stateful Hive UDFs as nondeterministic.

Add new test cases covering both stateful and stateless UDFs.
Without the patch, the new test fails with the following exception:

1 did not equal 10
ScalaTestFailureLocation: 
org.apache.spark.sql.hive.execution.HiveUDFSuite$$anonfun$21 at 
(HiveUDFSuite.scala:501)
org.scalatest.exceptions.TestFailedException: 1 did not equal 10
        at 
org.scalatest.Assertions$class.newAssertionFailedException(Assertions.scala:500)
        at 
org.scalatest.FunSuite.newAssertionFailedException(FunSuite.scala:1555)
        ...

Author: Zhan Zhang <[email protected]>

Closes #16068 from zhzhan/state.

(cherry picked from commit 67587d961d5f94a8639c20cb80127c86bf79d5a8)
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/72bf5199
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/72bf5199
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/72bf5199

Branch: refs/heads/branch-2.1
Commit: 72bf5199738c7ab0361b2b55eb4f4299048a21fa
Parents: 2c88e1d
Author: Zhan Zhang <[email protected]>
Authored: Fri Dec 9 16:35:06 2016 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Fri Dec 9 16:36:35 2016 +0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/hive/hiveUDFs.scala    |  4 +-
 .../spark/sql/hive/execution/HiveUDFSuite.scala | 45 +++++++++++++++++++-
 2 files changed, 45 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/72bf5199/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index e30e0f9..37414ad 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -59,7 +59,7 @@ private[hive] case class HiveSimpleUDF(
   @transient
   private lazy val isUDFDeterministic = {
     val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
-    udfType != null && udfType.deterministic()
+    udfType != null && udfType.deterministic() && !udfType.stateful()
   }
 
   override def foldable: Boolean = isUDFDeterministic && 
children.forall(_.foldable)
@@ -142,7 +142,7 @@ private[hive] case class HiveGenericUDF(
   @transient
   private lazy val isUDFDeterministic = {
     val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
-    udfType != null && udfType.deterministic()
+    udfType != null && udfType.deterministic() && !udfType.stateful()
   }
 
   @transient

http://git-wip-us.apache.org/repos/asf/spark/blob/72bf5199/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 48adc83..4098bb5 100644
--- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -21,15 +21,17 @@ import java.io.{DataInput, DataOutput, File, PrintWriter}
 import java.util.{ArrayList, Arrays, Properties}
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.ql.udf.UDAFPercentile
+import org.apache.hadoop.hive.ql.exec.UDF
+import org.apache.hadoop.hive.ql.udf.{UDAFPercentile, UDFType}
 import org.apache.hadoop.hive.ql.udf.generic._
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
 import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats}
 import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, 
ObjectInspectorFactory}
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
-import org.apache.hadoop.io.Writable
+import org.apache.hadoop.io.{LongWritable, Writable}
 
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.functions.max
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.util.Utils
@@ -487,6 +489,26 @@ class HiveUDFSuite extends QueryTest with 
TestHiveSingleton with SQLTestUtils {
     assert(count4 == 1)
     sql("DROP TABLE parquet_tmp")
   }
+
+  test("Hive Stateful UDF") {
+    withUserDefinedFunction("statefulUDF" -> true, "statelessUDF" -> true) {
+      sql(s"CREATE TEMPORARY FUNCTION statefulUDF AS 
'${classOf[StatefulUDF].getName}'")
+      sql(s"CREATE TEMPORARY FUNCTION statelessUDF AS 
'${classOf[StatelessUDF].getName}'")
+      val testData = spark.range(10).repartition(1)
+
+      // Expected Max(s) is 10 as statefulUDF returns the sequence number 
starting from 1.
+      checkAnswer(testData.selectExpr("statefulUDF() as s").agg(max($"s")), 
Row(10))
+
+      // Expected Max(s) is 5 as statefulUDF returns the sequence number 
starting from 1,
+      // and the data is evenly distributed into 2 partitions.
+      checkAnswer(testData.repartition(2)
+        .selectExpr("statefulUDF() as s").agg(max($"s")), Row(5))
+
+      // Expected Max(s) is 1, as stateless UDF is deterministic and foldable 
and replaced
+      // by constant 1 by ConstantFolding optimizer.
+      checkAnswer(testData.selectExpr("statelessUDF() as s").agg(max($"s")), 
Row(1))
+    }
+  }
 }
 
 class TestPair(x: Int, y: Int) extends Writable with Serializable {
@@ -551,3 +573,22 @@ class PairUDF extends GenericUDF {
 
   override def getDisplayString(p1: Array[String]): String = ""
 }
+
+@UDFType(stateful = true)
+class StatefulUDF extends UDF {
+  private val result = new LongWritable(0)
+
+  def evaluate(): LongWritable = {
+    result.set(result.get() + 1)
+    result
+  }
+}
+
+class StatelessUDF extends UDF {
+  private val result = new LongWritable(0)
+
+  def evaluate(): LongWritable = {
+    result.set(result.get() + 1)
+    result
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to