Repository: spark
Updated Branches:
refs/heads/master eb386be1e -> ff48b1b33
[SPARK-22901][PYTHON] Add deterministic flag to pyspark UDF
## What changes were proposed in this pull request?
In SPARK-20586 the `deterministic` flag was added to Scala UDFs, but it is not
available for Python UDFs. The flag is useful when a UDF's code can return
different results for the same input: due to optimization, duplicate
invocations may be eliminated, or the function may be invoked more times than
it appears in the query, either of which can lead to unexpected behavior.
This PR adds the deterministic flag, via the `asNondeterministic` method, to
let the user mark the function as non-deterministic and therefore avoid the
optimizations which might lead to strange behaviors.
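For example (a hypothetical sketch, not part of the patch, assuming an active `spark` session), a UDF wrapping `random.random()` can silently misbehave if it stays marked deterministic:
```
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import IntegerType
>>> import random
>>> # left deterministic, the optimizer may collapse several references to
>>> # the column into one invocation, or re-invoke the lambda per reference,
>>> # so columns derived from it may not match it
>>> rand_udf = udf(lambda: int(100 * random.random()), IntegerType())
>>> # marking it nondeterministic tells the optimizer not to eliminate or
>>> # duplicate its invocations
>>> rand_udf = rand_udf.asNondeterministic()
```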
## How was this patch tested?
Manual tests:
```
>>> from pyspark.sql.functions import *
>>> from pyspark.sql.types import *
>>> df_br = spark.createDataFrame([{'name': 'hello'}])
>>> import random
>>> udf_random_col = udf(lambda: int(100*random.random()), IntegerType()).asNondeterministic()
>>> df_br = df_br.withColumn('RAND', udf_random_col())
>>> random.seed(1234)
>>> udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
>>> df_br.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).show()
+-----+----+-------------+
| name|RAND|RAND_PLUS_TEN|
+-----+----+-------------+
|hello| 3| 13|
+-----+----+-------------+
```
Author: Marco Gaido <[email protected]>
Closes #19929 from mgaido91/SPARK-22629.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ff48b1b3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ff48b1b3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ff48b1b3
Branch: refs/heads/master
Commit: ff48b1b338241039a7189e7a3c04333b1256fdb3
Parents: eb386be
Author: Marco Gaido <[email protected]>
Authored: Tue Dec 26 06:39:40 2017 -0800
Committer: gatorsmile <[email protected]>
Committed: Tue Dec 26 06:39:40 2017 -0800
----------------------------------------------------------------------
.../org/apache/spark/api/python/PythonRunner.scala | 7 +++++++
python/pyspark/sql/functions.py | 11 ++++++++---
python/pyspark/sql/tests.py | 9 +++++++++
python/pyspark/sql/udf.py | 13 ++++++++++++-
.../scala/org/apache/spark/sql/UDFRegistration.scala | 5 +++--
.../apache/spark/sql/execution/python/PythonUDF.scala | 5 ++++-
.../execution/python/UserDefinedPythonFunction.scala | 5 +++--
.../execution/python/BatchEvalPythonExecSuite.scala | 3 ++-
8 files changed, 48 insertions(+), 10 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/ff48b1b3/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 93d508c..1ec0e71 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -39,6 +39,13 @@ private[spark] object PythonEvalType {
val SQL_PANDAS_SCALAR_UDF = 200
val SQL_PANDAS_GROUP_MAP_UDF = 201
+
+ def toString(pythonEvalType: Int): String = pythonEvalType match {
+ case NON_UDF => "NON_UDF"
+ case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
+ case SQL_PANDAS_SCALAR_UDF => "SQL_PANDAS_SCALAR_UDF"
+ case SQL_PANDAS_GROUP_MAP_UDF => "SQL_PANDAS_GROUP_MAP_UDF"
+ }
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/ff48b1b3/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index ddd8df3..66ee033 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2093,9 +2093,14 @@ class PandasUDFType(object):
def udf(f=None, returnType=StringType()):
"""Creates a user defined function (UDF).
- .. note:: The user-defined functions must be deterministic. Due to optimization,
- duplicate invocations may be eliminated or the function may even be invoked more times than
- it is present in the query.
+ .. note:: The user-defined functions are considered deterministic by default. Due to
+ optimization, duplicate invocations may be eliminated or the function may even be invoked
+ more times than it is present in the query. If your function is not deterministic, call
+ `asNondeterministic` on the user defined function. E.g.:
+
+ >>> from pyspark.sql.types import IntegerType
+ >>> import random
+ >>> random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic()
.. note:: The user-defined functions do not support conditional expressions or short circuiting
in boolean expressions and it ends up with being executed all internally. If the functions
http://git-wip-us.apache.org/repos/asf/spark/blob/ff48b1b3/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index b811a0f..3ef1522 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -435,6 +435,15 @@ class SQLTests(ReusedSQLTestCase):
self.assertEqual(list(range(3)), l1)
self.assertEqual(1, l2)
+ def test_nondeterministic_udf(self):
+ from pyspark.sql.functions import udf
+ import random
+ udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic()
+ df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND'))
+ udf_add_ten = udf(lambda rand: rand + 10, IntegerType())
+ [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect()
+ self.assertEqual(row[0] + 10, row[1])
+
def test_broadcast_in_udf(self):
bar = {"a": "aa", "b": "bb", "c": "abc"}
foo = self.sc.broadcast(bar)
http://git-wip-us.apache.org/repos/asf/spark/blob/ff48b1b3/python/pyspark/sql/udf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index 1231381..54b5a865 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -92,6 +92,7 @@ class UserDefinedFunction(object):
func.__name__ if hasattr(func, '__name__')
else func.__class__.__name__)
self.evalType = evalType
+ self._deterministic = True
@property
def returnType(self):
@@ -129,7 +130,7 @@ class UserDefinedFunction(object):
wrapped_func = _wrap_function(sc, self.func, self.returnType)
jdt = spark._jsparkSession.parseDataType(self.returnType.json())
judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
- self._name, wrapped_func, jdt, self.evalType)
+ self._name, wrapped_func, jdt, self.evalType, self._deterministic)
return judf
def __call__(self, *cols):
@@ -161,5 +162,15 @@ class UserDefinedFunction(object):
wrapper.func = self.func
wrapper.returnType = self.returnType
wrapper.evalType = self.evalType
+ wrapper.asNondeterministic = self.asNondeterministic
return wrapper
+
+ def asNondeterministic(self):
+ """
+ Updates UserDefinedFunction to nondeterministic.
+
+ .. versionadded:: 2.3
+ """
+ self._deterministic = False
+ return self
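A usage note on the hunk above (a minimal sketch with hypothetical names): the `_deterministic` flag is forwarded to the JVM-side `UserDefinedPythonFunction` when the `judf` is built, so `asNondeterministic` should be called before the UDF is first used in a query; since it returns `self`, it chains directly onto `udf()`:
```
>>> from pyspark.sql.functions import udf
>>> from pyspark.sql.types import LongType
>>> import time
>>> # the flag travels with the UDF into the JVM when the query is compiled
>>> now_udf = udf(lambda: int(time.time()), LongType()).asNondeterministic()
```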
http://git-wip-us.apache.org/repos/asf/spark/blob/ff48b1b3/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index 3ff4761..dc2468a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
import org.apache.spark.annotation.InterfaceStability
+import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.internal.Logging
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
@@ -41,8 +42,6 @@ import org.apache.spark.util.Utils
* spark.udf
* }}}
*
- * @note The user-defined functions must be deterministic.
- *
* @since 1.3.0
*/
@InterfaceStability.Stable
@@ -58,6 +57,8 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
| pythonIncludes: ${udf.func.pythonIncludes}
| pythonExec: ${udf.func.pythonExec}
| dataType: ${udf.dataType}
+ | pythonEvalType: ${PythonEvalType.toString(udf.pythonEvalType)}
+ | udfDeterministic: ${udf.udfDeterministic}
""".stripMargin)
functionRegistry.createOrReplaceTempFunction(name, udf.builder)
http://git-wip-us.apache.org/repos/asf/spark/blob/ff48b1b3/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index ef27fbc..d3f743d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -29,9 +29,12 @@ case class PythonUDF(
func: PythonFunction,
dataType: DataType,
children: Seq[Expression],
- evalType: Int)
+ evalType: Int,
+ udfDeterministic: Boolean)
extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression {
+ override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic)
+
override def toString: String = s"$name(${children.mkString(", ")})"
override def nullable: Boolean = true
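With this override, a `PythonUDF` is deterministic only when the flag is set and every child expression is deterministic. A sketch of the second condition (illustrative only, assuming an active `spark` session): even a UDF left deterministic becomes a nondeterministic expression when fed a nondeterministic input such as `rand()`:
```
>>> from pyspark.sql.functions import rand, udf
>>> from pyspark.sql.types import DoubleType
>>> double_it = udf(lambda x: x * 2.0, DoubleType())  # deterministic UDF
>>> # the resulting expression is nondeterministic because its child rand()
>>> # is nondeterministic, so the optimizer will not collapse or duplicate it
>>> df = spark.range(3).select(double_it(rand()).alias('r'))
```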
http://git-wip-us.apache.org/repos/asf/spark/blob/ff48b1b3/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
index 348e49e..50dca32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala
@@ -29,10 +29,11 @@ case class UserDefinedPythonFunction(
name: String,
func: PythonFunction,
dataType: DataType,
- pythonEvalType: Int) {
+ pythonEvalType: Int,
+ udfDeterministic: Boolean) {
def builder(e: Seq[Expression]): PythonUDF = {
- PythonUDF(name, func, dataType, e, pythonEvalType)
+ PythonUDF(name, func, dataType, e, pythonEvalType, udfDeterministic)
}
/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
http://git-wip-us.apache.org/repos/asf/spark/blob/ff48b1b3/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
index 53d3f34..9e4a2e8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -109,4 +109,5 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction(
name = "dummyUDF",
func = new DummyUDF,
dataType = BooleanType,
- pythonEvalType = PythonEvalType.SQL_BATCHED_UDF)
+ pythonEvalType = PythonEvalType.SQL_BATCHED_UDF,
+ udfDeterministic = true)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]