This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 44d762a [SPARK-35389][SQL] V2 ScalarFunction should support magic
method with null arguments
44d762a is described below
commit 44d762abc6395570f1f493a145fd5d1cbdf0b49e
Author: Chao Sun <[email protected]>
AuthorDate: Tue May 18 08:45:55 2021 +0000
[SPARK-35389][SQL] V2 ScalarFunction should support magic method with null
arguments
### What changes were proposed in this pull request?
When creating `Invoke` and `StaticInvoke` for `ScalarFunction`'s magic
method, set `propagateNull` to false.
### Why are the changes needed?
When `propagateNull` is true (which is the default value), `Invoke` and
`StaticInvoke` will return null if any of the arguments is null. For a scalar
function this is incorrect, as we should leave that logic to the function
implementation instead.
### Does this PR introduce _any_ user-facing change?
Yes. Null arguments are now properly handled by the magic method.
### How was this patch tested?
Added new tests.
Closes #32553 from sunchao/SPARK-35389.
Authored-by: Chao Sun <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../catalog/functions/ScalarFunction.java | 19 +++++++++++
.../spark/sql/catalyst/analysis/Analyzer.scala | 5 +--
.../sql/catalyst/expressions/objects/objects.scala | 26 +++++++++++----
.../connector/catalog/functions/JavaStrLen.java | 19 +++++++++++
.../sql/connector/DataSourceV2FunctionSuite.scala | 37 +++++++++++++++++++++-
5 files changed, 96 insertions(+), 10 deletions(-)
diff --git
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
index 858ab92..d261a24 100644
---
a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
+++
b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/functions/ScalarFunction.java
@@ -31,6 +31,7 @@ import org.apache.spark.sql.types.DataType;
* InternalRow API for the {@link DataType SQL data type} returned by {@link
#resultType()}.
* The mapping between {@link DataType} and the corresponding JVM type is
defined below.
* <p>
+ * <h2> Magic method </h2>
* <b>IMPORTANT</b>: the default implementation of {@link #produceResult}
throws
* {@link UnsupportedOperationException}. Users must choose to either override
this method, or
* implement a magic method with name {@link #MAGIC_METHOD_NAME}, which takes
individual parameters
@@ -82,6 +83,24 @@ import org.apache.spark.sql.types.DataType;
* following the mapping defined below, and then checking if there is a
matching method from all the
* declared methods in the UDF class, using method name and the Java types.
* <p>
+ * <h2> Handling of nullable primitive arguments </h2>
+ * The handling of null primitive arguments is different between the magic
method approach and
+ * the {@link #produceResult} approach. With the former, whenever any of the
method arguments meet
+ * the following conditions:
+ * <ol>
+ * <li>the argument is of primitive type</li>
+ * <li>the argument is nullable</li>
+ * <li>the value of the argument is null</li>
+ * </ol>
+ * Spark will return null directly instead of calling the magic method. On the
other hand, Spark
+ * will pass null primitive arguments to {@link #produceResult} and it is
user's responsibility to
+ * handle them in the function implementation.
+ * <p>
+ * Because of the difference, if Spark users want to implement special
handling of nulls for
+ * nullable primitive arguments, they should override the {@link
#produceResult} method instead
+ * of using the magic method approach.
+ * <p>
+ * <h2> Spark data type to Java type mapping </h2>
* The following are the mapping from {@link DataType SQL data type} to Java
type which is used
* by Spark to infer parameter types for the magic methods as well as return
value type for
* {@link #produceResult}:
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9954ca0..3f2e93a 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2204,11 +2204,12 @@ class Analyzer(override val catalogManager:
CatalogManager)
findMethod(scalarFunc, MAGIC_METHOD_NAME, argClasses) match {
case Some(m) if Modifier.isStatic(m.getModifiers) =>
StaticInvoke(scalarFunc.getClass, scalarFunc.resultType(),
- MAGIC_METHOD_NAME, arguments, returnNullable =
scalarFunc.isResultNullable)
+ MAGIC_METHOD_NAME, arguments, propagateNull = false,
+ returnNullable = scalarFunc.isResultNullable)
case Some(_) =>
val caller = Literal.create(scalarFunc,
ObjectType(scalarFunc.getClass))
Invoke(caller, MAGIC_METHOD_NAME, scalarFunc.resultType(),
- arguments, returnNullable = scalarFunc.isResultNullable)
+ arguments, propagateNull = false, returnNullable =
scalarFunc.isResultNullable)
case _ =>
// TODO: handle functions defined in Scala too - in Scala, even if
a
// subclass do not override the default method in parent interface
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index e871c30..c88f785 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -50,7 +50,10 @@ trait InvokeLike extends Expression with NonSQLExpression {
def propagateNull: Boolean
- protected lazy val needNullCheck: Boolean = propagateNull &&
arguments.exists(_.nullable)
+ protected lazy val needNullCheck: Boolean =
needNullCheckForIndex.contains(true)
+ protected lazy val needNullCheckForIndex: Array[Boolean] =
+ arguments.map(a => a.nullable && (propagateNull ||
+ ScalaReflection.dataTypeJavaClass(a.dataType).isPrimitive)).toArray
protected lazy val evaluatedArgs: Array[Object] = new
Array[Object](arguments.length)
private lazy val boxingFn: Any => Any =
ScalaReflection.typeBoxedJavaMapping
@@ -89,7 +92,7 @@ trait InvokeLike extends Expression with NonSQLExpression {
val reset = s"$resultIsNull = false;"
val argCodes = arguments.zipWithIndex.map { case (e, i) =>
val expr = e.genCode(ctx)
- val updateResultIsNull = if (e.nullable) {
+ val updateResultIsNull = if (needNullCheckForIndex(i)) {
s"$resultIsNull = ${expr.isNull};"
} else {
""
@@ -131,11 +134,14 @@ trait InvokeLike extends Expression with NonSQLExpression
{
def invoke(obj: Any, method: Method, input: InternalRow): Any = {
var i = 0
val len = arguments.length
+ var resultNull = false
while (i < len) {
- evaluatedArgs(i) = arguments(i).eval(input).asInstanceOf[Object]
+ val result = arguments(i).eval(input).asInstanceOf[Object]
+ evaluatedArgs(i) = result
+ resultNull = resultNull || (result == null && needNullCheckForIndex(i))
i += 1
}
- if (needNullCheck && evaluatedArgs.contains(null)) {
+ if (needNullCheck && resultNull) {
// return null if one of arguments is null
null
} else {
@@ -226,7 +232,9 @@ object SerializerSupport {
* @param functionName The name of the method to call.
* @param arguments An optional list of expressions to pass as arguments to
the function.
* @param propagateNull When true, and any of the arguments is null, null will
be returned instead
- * of calling the function.
+ * of calling the function. Also note: when this is false
but any of the
+ * arguments is of primitive type and is null, null also
will be returned
+ * without invoking the function.
* @param returnNullable When false, indicating the invoked method will always
return
* non-null value.
*/
@@ -318,7 +326,9 @@ case class StaticInvoke(
* @param arguments An optional list of expressions, whose evaluation will be
passed to the
* function.
* @param propagateNull When true, and any of the arguments is null, null will
be returned instead
- * of calling the function.
+ * of calling the function. Also note: when this is false
but any of the
+ * arguments is of primitive type and is null, null also
will be returned
+ * without invoking the function.
* @param returnNullable When false, indicating the invoked method will always
return
* non-null value.
*/
@@ -452,7 +462,9 @@ object NewInstance {
* @param cls The class to construct.
* @param arguments A list of expression to use as arguments to the
constructor.
* @param propagateNull When true, if any of the arguments is null, then null
will be returned
- * instead of trying to construct the object.
+ * instead of trying to construct the object. Also note:
when this is false
+ * but any of the arguments is of primitive type and is
null, null also will
+ * be returned without constructing the object.
* @param dataType The type of object being constructed, as a Spark SQL
datatype. This allows you
* to manually specify the type when the object in question is
a valid internal
* representation (i.e. ArrayData) instead of an object.
diff --git
a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
index 7cd010b..1b16896 100644
---
a/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
+++
b/sql/core/src/test/java/test/org/apache/spark/sql/connector/catalog/functions/JavaStrLen.java
@@ -120,5 +120,24 @@ public class JavaStrLen implements UnboundFunction {
public static class JavaStrLenNoImpl extends JavaStrLenBase {
}
+
+ // a null-safe version which returns 0 for null arguments
+ public static class JavaStrLenMagicNullSafe extends JavaStrLenBase {
+ public int invoke(UTF8String str) {
+ if (str == null) {
+ return 0;
+ }
+ return str.toString().length();
+ }
+ }
+
+ public static class JavaStrLenStaticMagicNullSafe extends JavaStrLenBase {
+ public static int invoke(UTF8String str) {
+ if (str == null) {
+ return 0;
+ }
+ return str.toString().length();
+ }
+ }
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
index bd4dfe4..801aee5 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2FunctionSuite.scala
@@ -20,12 +20,14 @@ package org.apache.spark.sql.connector
import java.util
import java.util.Collections
-import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage,
JavaStrLen}
+import test.org.apache.spark.sql.connector.catalog.functions.{JavaAverage,
JavaLongAdd, JavaStrLen}
+import
test.org.apache.spark.sql.connector.catalog.functions.JavaLongAdd.JavaLongAddMagic
import test.org.apache.spark.sql.connector.catalog.functions.JavaStrLen._
import org.apache.spark.SparkException
import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.InternalRow
+import
org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode.{FALLBACK,
NO_CODEGEN}
import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog,
Identifier, InMemoryCatalog, SupportsNamespaces}
import org.apache.spark.sql.connector.catalog.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -213,6 +215,39 @@ class DataSourceV2FunctionSuite extends
DatasourceV2SQLBase {
.getMessage.contains("neither implement magic method nor override
'produceResult'"))
}
+ test("SPARK-35389: magic function should handle null arguments") {
+
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"),
emptyProps)
+ addFunction(Identifier.of(Array("ns"), "strlen"), new JavaStrLen(new
JavaStrLenMagicNullSafe))
+ addFunction(Identifier.of(Array("ns"), "strlen2"),
+ new JavaStrLen(new JavaStrLenStaticMagicNullSafe))
+ Seq("strlen", "strlen2").foreach { name =>
+ checkAnswer(sql(s"SELECT testcat.ns.$name(CAST(NULL as STRING))"),
Row(0) :: Nil)
+ }
+ }
+
+ test("SPARK-35389: magic function should handle null primitive arguments") {
+
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"),
emptyProps)
+ addFunction(Identifier.of(Array("ns"), "add"), new JavaLongAdd(new
JavaLongAddMagic(false)))
+ addFunction(Identifier.of(Array("ns"), "static_add"),
+ new JavaLongAdd(new JavaLongAddMagic(false)))
+
+ Seq("add", "static_add").foreach { name =>
+ Seq(true, false).foreach { codegenEnabled =>
+ val codeGenFactoryMode = if (codegenEnabled) FALLBACK else NO_CODEGEN
+
+ withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key ->
codegenEnabled.toString,
+ SQLConf.CODEGEN_FACTORY_MODE.key -> codeGenFactoryMode.toString) {
+
+ checkAnswer(sql(s"SELECT testcat.ns.$name(CAST(NULL as BIGINT),
42L)"), Row(null) :: Nil)
+ checkAnswer(sql(s"SELECT testcat.ns.$name(42L, CAST(NULL as
BIGINT))"), Row(null) :: Nil)
+ checkAnswer(sql(s"SELECT testcat.ns.$name(42L, 58L)"), Row(100) ::
Nil)
+ checkAnswer(sql(s"SELECT testcat.ns.$name(CAST(NULL as BIGINT),
CAST(NULL as BIGINT))"),
+ Row(null) :: Nil)
+ }
+ }
+ }
+ }
+
test("bad bound function (neither scalar nor aggregate)") {
catalog("testcat").asInstanceOf[SupportsNamespaces].createNamespace(Array("ns"),
emptyProps)
addFunction(Identifier.of(Array("ns"), "strlen"), StrLen(BadBoundFunction))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]