Repository: spark
Updated Branches:
  refs/heads/master c074c96dc -> 67587d961


[SPARK-18637][SQL] Stateful UDF should be considered as nondeterministic

## What changes were proposed in this pull request?

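Mark stateful Hive UDFs as nondeterministic. A Hive UDF annotated with `@UDFType(stateful = true)` is now treated as nondeterministic, and therefore not foldable, even when its `deterministic` flag is true.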

## How was this patch tested?
Added new test cases covering both stateful and stateless UDFs.
Without this patch, the new test fails with the following error:

1 did not equal 10
ScalaTestFailureLocation: org.apache.spark.sql.hive.execution.HiveUDFSuite$$anonfun$21 at (HiveUDFSuite.scala:501)
org.scalatest.exceptions.TestFailedException: 1 did not equal 10
        at org.scalatest.Assertions$class.newAssertionFailedException(Assertions.scala:500)
        at org.scalatest.FunSuite.newAssertionFailedException(FunSuite.scala:1555)
        ...

Author: Zhan Zhang <zhanzh...@fb.com>

Closes #16068 from zhzhan/state.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/67587d96
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/67587d96
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/67587d96

Branch: refs/heads/master
Commit: 67587d961d5f94a8639c20cb80127c86bf79d5a8
Parents: c074c96
Author: Zhan Zhang <zhanzh...@fb.com>
Authored: Fri Dec 9 16:35:06 2016 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Dec 9 16:35:06 2016 +0800

----------------------------------------------------------------------
 .../org/apache/spark/sql/hive/hiveUDFs.scala    |  4 +-
 .../spark/sql/hive/execution/HiveUDFSuite.scala | 45 +++++++++++++++++++-
 2 files changed, 45 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/67587d96/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 349faae..26dc372 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -61,7 +61,7 @@ private[hive] case class HiveSimpleUDF(
   @transient
   private lazy val isUDFDeterministic = {
     val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
-    udfType != null && udfType.deterministic()
+    udfType != null && udfType.deterministic() && !udfType.stateful()
   }
 
   override def foldable: Boolean = isUDFDeterministic && 
children.forall(_.foldable)
@@ -144,7 +144,7 @@ private[hive] case class HiveGenericUDF(
   @transient
   private lazy val isUDFDeterministic = {
     val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
-    udfType != null && udfType.deterministic()
+    udfType != null && udfType.deterministic() && !udfType.stateful()
   }
 
   @transient

http://git-wip-us.apache.org/repos/asf/spark/blob/67587d96/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 48adc83..4098bb5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -21,15 +21,17 @@ import java.io.{DataInput, DataOutput, File, PrintWriter}
 import java.util.{ArrayList, Arrays, Properties}
 
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.ql.udf.UDAFPercentile
+import org.apache.hadoop.hive.ql.exec.UDF
+import org.apache.hadoop.hive.ql.udf.{UDAFPercentile, UDFType}
 import org.apache.hadoop.hive.ql.udf.generic._
 import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
 import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats}
 import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, 
ObjectInspectorFactory}
 import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
-import org.apache.hadoop.io.Writable
+import org.apache.hadoop.io.{LongWritable, Writable}
 
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.functions.max
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.util.Utils
@@ -487,6 +489,26 @@ class HiveUDFSuite extends QueryTest with 
TestHiveSingleton with SQLTestUtils {
     assert(count4 == 1)
     sql("DROP TABLE parquet_tmp")
   }
+
+  test("Hive Stateful UDF") {
+    withUserDefinedFunction("statefulUDF" -> true, "statelessUDF" -> true) {
+      sql(s"CREATE TEMPORARY FUNCTION statefulUDF AS 
'${classOf[StatefulUDF].getName}'")
+      sql(s"CREATE TEMPORARY FUNCTION statelessUDF AS 
'${classOf[StatelessUDF].getName}'")
+      val testData = spark.range(10).repartition(1)
+
+      // Expected Max(s) is 10 as statefulUDF returns the sequence number 
starting from 1.
+      checkAnswer(testData.selectExpr("statefulUDF() as s").agg(max($"s")), 
Row(10))
+
+      // Expected Max(s) is 5 as statefulUDF returns the sequence number 
starting from 1,
+      // and the data is evenly distributed into 2 partitions.
+      checkAnswer(testData.repartition(2)
+        .selectExpr("statefulUDF() as s").agg(max($"s")), Row(5))
+
+      // Expected Max(s) is 1, as stateless UDF is deterministic and foldable 
and replaced
+      // by constant 1 by ConstantFolding optimizer.
+      checkAnswer(testData.selectExpr("statelessUDF() as s").agg(max($"s")), 
Row(1))
+    }
+  }
 }
 
 class TestPair(x: Int, y: Int) extends Writable with Serializable {
@@ -551,3 +573,22 @@ class PairUDF extends GenericUDF {
 
   override def getDisplayString(p1: Array[String]): String = ""
 }
+
+@UDFType(stateful = true)
+class StatefulUDF extends UDF {
+  private val result = new LongWritable(0)
+
+  def evaluate(): LongWritable = {
+    result.set(result.get() + 1)
+    result
+  }
+}
+
+class StatelessUDF extends UDF {
+  private val result = new LongWritable(0)
+
+  def evaluate(): LongWritable = {
+    result.set(result.get() + 1)
+    result
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to