This is an automated email from the ASF dual-hosted git repository.

maxgekk pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 656ece1adb0f [SPARK-50081][SQL] Codegen Support for `XPath*`(by Invoke 
& RuntimeReplaceable)
656ece1adb0f is described below

commit 656ece1adb0fb6b4aea108ed92e5939e7b2dc7e9
Author: panbingkun <[email protected]>
AuthorDate: Sat Nov 23 11:59:38 2024 +0100

    [SPARK-50081][SQL] Codegen Support for `XPath*`(by Invoke & 
RuntimeReplaceable)
    
    ### What changes were proposed in this pull request?
    This PR aims to add `Codegen` support for the `xpath*` functions, including:
    - `xpath_boolean`
    - `xpath_short`
    - `xpath_int`
    - `xpath_long`
    - `xpath_float`
    - `xpath_double`
    - `xpath_string`
    - `xpath`
    
    ### Why are the changes needed?
    - improve codegen coverage.
    - simplify the code.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Pass GA & existing UTs (e.g. `XPathFunctionsSuite`, `XPathExpressionSuite`, 
`CollationSQLExpressionsSuite`#`*XPath*`, `CollationExpressionWalkerSuite`).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #48610 from panbingkun/xpath_codegen.
    
    Lead-authored-by: panbingkun <[email protected]>
    Co-authored-by: panbingkun <[email protected]>
    Signed-off-by: Max Gekk <[email protected]>
---
 .../expressions/xml/XmlExpressionEvalUtils.scala   | 82 ++++++++++++++++++-
 .../spark/sql/catalyst/expressions/xml/xpath.scala | 92 +++++++++-------------
 .../explain-results/function_xpath.explain         |  2 +-
 .../explain-results/function_xpath_boolean.explain |  2 +-
 .../explain-results/function_xpath_double.explain  |  2 +-
 .../explain-results/function_xpath_float.explain   |  2 +-
 .../explain-results/function_xpath_int.explain     |  2 +-
 .../explain-results/function_xpath_long.explain    |  2 +-
 .../explain-results/function_xpath_number.explain  |  2 +-
 .../explain-results/function_xpath_short.explain   |  2 +-
 .../explain-results/function_xpath_string.explain  |  2 +-
 .../org/apache/spark/sql/XPathFunctionsSuite.scala | 36 +++++++++
 12 files changed, 162 insertions(+), 66 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala
index dff88475327a..44b98026d62d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XmlExpressionEvalUtils.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions.xml
 
+import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.catalyst.xml.XmlInferSchema
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, DataType, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 object XmlExpressionEvalUtils {
@@ -40,3 +41,82 @@ object XmlExpressionEvalUtils {
     UTF8String.fromString(dataType.sql)
   }
 }
+
+trait XPathEvaluator {
+
+  protected val path: UTF8String
+
+  @transient protected lazy val xpathUtil: UDFXPathUtil = new UDFXPathUtil
+
+  final def evaluate(xml: UTF8String): Any = {
+    if (xml == null || xml.toString.isEmpty || path == null || 
path.toString.isEmpty) return null
+    doEvaluate(xml)
+  }
+
+  def doEvaluate(xml: UTF8String): Any
+}
+
+case class XPathBooleanEvaluator(path: UTF8String) extends XPathEvaluator {
+  override def doEvaluate(xml: UTF8String): Any = {
+    xpathUtil.evalBoolean(xml.toString, path.toString)
+  }
+}
+
+case class XPathShortEvaluator(path: UTF8String) extends XPathEvaluator {
+  override def doEvaluate(xml: UTF8String): Any = {
+    val ret = xpathUtil.evalNumber(xml.toString, path.toString)
+    if (ret eq null) null.asInstanceOf[Short] else ret.shortValue()
+  }
+}
+
+case class XPathIntEvaluator(path: UTF8String) extends XPathEvaluator {
+  override def doEvaluate(xml: UTF8String): Any = {
+    val ret = xpathUtil.evalNumber(xml.toString, path.toString)
+    if (ret eq null) null.asInstanceOf[Int] else ret.intValue()
+  }
+}
+
+case class XPathLongEvaluator(path: UTF8String) extends XPathEvaluator {
+  override def doEvaluate(xml: UTF8String): Any = {
+    val ret = xpathUtil.evalNumber(xml.toString, path.toString)
+    if (ret eq null) null.asInstanceOf[Long] else ret.longValue()
+  }
+}
+
+case class XPathFloatEvaluator(path: UTF8String) extends XPathEvaluator {
+  override def doEvaluate(xml: UTF8String): Any = {
+    val ret = xpathUtil.evalNumber(xml.toString, path.toString)
+    if (ret eq null) null.asInstanceOf[Float] else ret.floatValue()
+  }
+}
+
+case class XPathDoubleEvaluator(path: UTF8String) extends XPathEvaluator {
+  override def doEvaluate(xml: UTF8String): Any = {
+    val ret = xpathUtil.evalNumber(xml.toString, path.toString)
+    if (ret eq null) null.asInstanceOf[Double] else ret.doubleValue()
+  }
+}
+
+case class XPathStringEvaluator(path: UTF8String) extends XPathEvaluator {
+  override def doEvaluate(xml: UTF8String): Any = {
+    val ret = xpathUtil.evalString(xml.toString, path.toString)
+    UTF8String.fromString(ret)
+  }
+}
+
+case class XPathListEvaluator(path: UTF8String) extends XPathEvaluator {
+  override def doEvaluate(xml: UTF8String): Any = {
+    val nodeList = xpathUtil.evalNodeList(xml.toString, path.toString)
+    if (nodeList ne null) {
+      val ret = new Array[AnyRef](nodeList.getLength)
+      var i = 0
+      while (i < nodeList.getLength) {
+        ret(i) = UTF8String.fromString(nodeList.item(i).getNodeValue)
+        i += 1
+      }
+      new GenericArrayData(ret)
+    } else {
+      null
+    }
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
index 9848e062a08f..2c18ffa2abec 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/xpath.scala
@@ -21,8 +21,7 @@ import 
org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.Cast._
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.types.StringTypeWithCollation
 import org.apache.spark.sql.types._
@@ -34,10 +33,9 @@ import org.apache.spark.unsafe.types.UTF8String
  * This is not the world's most efficient implementation due to type 
conversion, but works.
  */
 abstract class XPathExtract
-  extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
+  extends BinaryExpression with RuntimeReplaceable with ExpectsInputTypes {
   override def left: Expression = xml
   override def right: Expression = path
-  override def nullIntolerant: Boolean = true
 
   /** XPath expressions are always nullable, e.g. if the xml string is empty. 
*/
   override def nullable: Boolean = true
@@ -60,12 +58,20 @@ abstract class XPathExtract
     }
   }
 
-  @transient protected lazy val xpathUtil = new UDFXPathUtil
-  @transient protected lazy val pathString: String = 
path.eval().asInstanceOf[UTF8String].toString
-
   /** Concrete implementations need to override the following three methods. */
   def xml: Expression
   def path: Expression
+
+  @transient protected lazy val pathUTF8String: UTF8String = 
path.eval().asInstanceOf[UTF8String]
+
+  protected def evaluator: XPathEvaluator
+
+  override def replacement: Expression = Invoke(
+    Literal.create(evaluator, ObjectType(classOf[XPathEvaluator])),
+    "evaluate",
+    dataType,
+    Seq(xml),
+    Seq(xml.dataType))
 }
 
 // scalastyle:off line.size.limit
@@ -81,11 +87,9 @@ abstract class XPathExtract
 // scalastyle:on line.size.limit
 case class XPathBoolean(xml: Expression, path: Expression) extends 
XPathExtract with Predicate {
 
-  override def prettyName: String = "xpath_boolean"
+  @transient override lazy val evaluator: XPathEvaluator = 
XPathBooleanEvaluator(pathUTF8String)
 
-  override def nullSafeEval(xml: Any, path: Any): Any = {
-    xpathUtil.evalBoolean(xml.asInstanceOf[UTF8String].toString, pathString)
-  }
+  override def prettyName: String = "xpath_boolean"
 
   override protected def withNewChildrenInternal(
     newLeft: Expression, newRight: Expression): XPathBoolean = copy(xml = 
newLeft, path = newRight)
@@ -103,14 +107,12 @@ case class XPathBoolean(xml: Expression, path: 
Expression) extends XPathExtract
   group = "xml_funcs")
 // scalastyle:on line.size.limit
 case class XPathShort(xml: Expression, path: Expression) extends XPathExtract {
+
+  @transient override lazy val evaluator: XPathEvaluator = 
XPathShortEvaluator(pathUTF8String)
+
   override def prettyName: String = "xpath_short"
   override def dataType: DataType = ShortType
 
-  override def nullSafeEval(xml: Any, path: Any): Any = {
-    val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, 
pathString)
-    if (ret eq null) null else ret.shortValue()
-  }
-
   override protected def withNewChildrenInternal(
     newLeft: Expression, newRight: Expression): XPathShort = copy(xml = 
newLeft, path = newRight)
 }
@@ -127,14 +129,12 @@ case class XPathShort(xml: Expression, path: Expression) 
extends XPathExtract {
   group = "xml_funcs")
 // scalastyle:on line.size.limit
 case class XPathInt(xml: Expression, path: Expression) extends XPathExtract {
+
+  @transient override lazy val evaluator: XPathEvaluator = 
XPathIntEvaluator(pathUTF8String)
+
   override def prettyName: String = "xpath_int"
   override def dataType: DataType = IntegerType
 
-  override def nullSafeEval(xml: Any, path: Any): Any = {
-    val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, 
pathString)
-    if (ret eq null) null else ret.intValue()
-  }
-
   override protected def withNewChildrenInternal(
     newLeft: Expression, newRight: Expression): Expression = copy(xml = 
newLeft, path = newRight)
 }
@@ -151,14 +151,12 @@ case class XPathInt(xml: Expression, path: Expression) 
extends XPathExtract {
   group = "xml_funcs")
 // scalastyle:on line.size.limit
 case class XPathLong(xml: Expression, path: Expression) extends XPathExtract {
+
+  @transient override lazy val evaluator: XPathEvaluator = 
XPathLongEvaluator(pathUTF8String)
+
   override def prettyName: String = "xpath_long"
   override def dataType: DataType = LongType
 
-  override def nullSafeEval(xml: Any, path: Any): Any = {
-    val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, 
pathString)
-    if (ret eq null) null else ret.longValue()
-  }
-
   override protected def withNewChildrenInternal(
     newLeft: Expression, newRight: Expression): XPathLong = copy(xml = 
newLeft, path = newRight)
 }
@@ -175,14 +173,12 @@ case class XPathLong(xml: Expression, path: Expression) 
extends XPathExtract {
   group = "xml_funcs")
 // scalastyle:on line.size.limit
 case class XPathFloat(xml: Expression, path: Expression) extends XPathExtract {
+
+  @transient override lazy val evaluator: XPathEvaluator = 
XPathFloatEvaluator(pathUTF8String)
+
   override def prettyName: String = "xpath_float"
   override def dataType: DataType = FloatType
 
-  override def nullSafeEval(xml: Any, path: Any): Any = {
-    val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, 
pathString)
-    if (ret eq null) null else ret.floatValue()
-  }
-
   override protected def withNewChildrenInternal(
     newLeft: Expression, newRight: Expression): XPathFloat = copy(xml = 
newLeft, path = newRight)
 }
@@ -199,15 +195,13 @@ case class XPathFloat(xml: Expression, path: Expression) 
extends XPathExtract {
   group = "xml_funcs")
 // scalastyle:on line.size.limit
 case class XPathDouble(xml: Expression, path: Expression) extends XPathExtract 
{
+
+  @transient override lazy val evaluator: XPathEvaluator = 
XPathDoubleEvaluator(pathUTF8String)
+
   override def prettyName: String =
     getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse("xpath_double")
   override def dataType: DataType = DoubleType
 
-  override def nullSafeEval(xml: Any, path: Any): Any = {
-    val ret = xpathUtil.evalNumber(xml.asInstanceOf[UTF8String].toString, 
pathString)
-    if (ret eq null) null else ret.doubleValue()
-  }
-
   override protected def withNewChildrenInternal(
     newLeft: Expression, newRight: Expression): XPathDouble = copy(xml = 
newLeft, path = newRight)
 }
@@ -224,14 +218,12 @@ case class XPathDouble(xml: Expression, path: Expression) 
extends XPathExtract {
   group = "xml_funcs")
 // scalastyle:on line.size.limit
 case class XPathString(xml: Expression, path: Expression) extends XPathExtract 
{
+
+  @transient override lazy val evaluator: XPathEvaluator = 
XPathStringEvaluator(pathUTF8String)
+
   override def prettyName: String = "xpath_string"
   override def dataType: DataType = SQLConf.get.defaultStringType
 
-  override def nullSafeEval(xml: Any, path: Any): Any = {
-    val ret = xpathUtil.evalString(xml.asInstanceOf[UTF8String].toString, 
pathString)
-    UTF8String.fromString(ret)
-  }
-
   override protected def withNewChildrenInternal(
     newLeft: Expression, newRight: Expression): Expression = copy(xml = 
newLeft, path = newRight)
 }
@@ -250,24 +242,12 @@ case class XPathString(xml: Expression, path: Expression) 
extends XPathExtract {
   group = "xml_funcs")
 // scalastyle:on line.size.limit
 case class XPathList(xml: Expression, path: Expression) extends XPathExtract {
+
+  @transient override lazy val evaluator: XPathEvaluator = 
XPathListEvaluator(pathUTF8String)
+
   override def prettyName: String = "xpath"
   override def dataType: DataType = ArrayType(SQLConf.get.defaultStringType)
 
-  override def nullSafeEval(xml: Any, path: Any): Any = {
-    val nodeList = 
xpathUtil.evalNodeList(xml.asInstanceOf[UTF8String].toString, pathString)
-    if (nodeList ne null) {
-      val ret = new Array[AnyRef](nodeList.getLength)
-      var i = 0
-      while (i < nodeList.getLength) {
-        ret(i) = UTF8String.fromString(nodeList.item(i).getNodeValue)
-        i += 1
-      }
-      new GenericArrayData(ret)
-    } else {
-      null
-    }
-  }
-
   override protected def withNewChildrenInternal(
     newLeft: Expression, newRight: Expression): XPathList = copy(xml = 
newLeft, path = newRight)
 }
diff --git 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath.explain
 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath.explain
index d9e2e55d9b12..4752e5218bb1 100644
--- 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath.explain
+++ 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath.explain
@@ -1,2 +1,2 @@
-Project [xpath(s#0, a/b/text()) AS xpath(s, a/b/text())#0]
+Project [invoke(XPathListEvaluator(a/b/text()).evaluate(s#0)) AS xpath(s, 
a/b/text())#0]
 +- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
diff --git 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_boolean.explain
 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_boolean.explain
index 9b75f8180246..b537366736d2 100644
--- 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_boolean.explain
+++ 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_boolean.explain
@@ -1,2 +1,2 @@
-Project [xpath_boolean(s#0, a/b) AS xpath_boolean(s, a/b)#0]
+Project [invoke(XPathBooleanEvaluator(a/b).evaluate(s#0)) AS xpath_boolean(s, 
a/b)#0]
 +- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
diff --git 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_double.explain
 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_double.explain
index 9ce47136df24..76e0b0172184 100644
--- 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_double.explain
+++ 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_double.explain
@@ -1,2 +1,2 @@
-Project [xpath_double(s#0, a/b) AS xpath_double(s, a/b)#0]
+Project [invoke(XPathDoubleEvaluator(a/b).evaluate(s#0)) AS xpath_double(s, 
a/b)#0]
 +- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
diff --git 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_float.explain
 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_float.explain
index 02b29ec4afa9..21aebb357928 100644
--- 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_float.explain
+++ 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_float.explain
@@ -1,2 +1,2 @@
-Project [xpath_float(s#0, a/b) AS xpath_float(s, a/b)#0]
+Project [invoke(XPathFloatEvaluator(a/b).evaluate(s#0)) AS xpath_float(s, 
a/b)#0]
 +- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
diff --git 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_int.explain
 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_int.explain
index cdd56eaa7319..eee74472b1cf 100644
--- 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_int.explain
+++ 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_int.explain
@@ -1,2 +1,2 @@
-Project [xpath_int(s#0, a/b) AS xpath_int(s, a/b)#0]
+Project [invoke(XPathIntEvaluator(a/b).evaluate(s#0)) AS xpath_int(s, a/b)#0]
 +- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
diff --git 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_long.explain
 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_long.explain
index 3acefb13d0f8..8356c2c8e18c 100644
--- 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_long.explain
+++ 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_long.explain
@@ -1,2 +1,2 @@
-Project [xpath_long(s#0, a/b) AS xpath_long(s, a/b)#0L]
+Project [invoke(XPathLongEvaluator(a/b).evaluate(s#0)) AS xpath_long(s, 
a/b)#0L]
 +- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
diff --git 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_number.explain
 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_number.explain
index 0a30685f0c6d..bc32d4fefffb 100644
--- 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_number.explain
+++ 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_number.explain
@@ -1,2 +1,2 @@
-Project [xpath_number(s#0, a/b) AS xpath_number(s, a/b)#0]
+Project [invoke(XPathDoubleEvaluator(a/b).evaluate(s#0)) AS xpath_number(s, 
a/b)#0]
 +- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
diff --git 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_short.explain
 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_short.explain
index ed440972bf49..e0ba76b3acd0 100644
--- 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_short.explain
+++ 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_short.explain
@@ -1,2 +1,2 @@
-Project [xpath_short(s#0, a/b) AS xpath_short(s, a/b)#0]
+Project [invoke(XPathShortEvaluator(a/b).evaluate(s#0)) AS xpath_short(s, 
a/b)#0]
 +- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
diff --git 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_string.explain
 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_string.explain
index f4103f68c3bc..80f2600e6cdd 100644
--- 
a/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_string.explain
+++ 
b/sql/connect/common/src/test/resources/query-tests/explain-results/function_xpath_string.explain
@@ -1,2 +1,2 @@
-Project [xpath_string(s#0, a/b) AS xpath_string(s, a/b)#0]
+Project [invoke(XPathStringEvaluator(a/b).evaluate(s#0)) AS xpath_string(s, 
a/b)#0]
 +- LocalRelation <empty>, [d#0, t#0, s#0, x#0L, wt#0]
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala
index f08466e8f8d9..f2a86cbf5415 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/XPathFunctionsSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.expressions.IsNotNull
+import org.apache.spark.sql.execution.FilterExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSparkSession
 
@@ -76,4 +78,38 @@ class XPathFunctionsSuite extends QueryTest with 
SharedSparkSession {
     checkAnswer(df.select(xpath(col("xml"), lit("a/*/text()"))),
       Row(Seq("b1", "b2", "b3", "c1", "c2")))
   }
+
+  test("The replacement of `xpath*` functions should be NullIntolerant") {
+    def check(df: DataFrame, expected: Seq[Row]): Unit = {
+      val filter = df.queryExecution
+        .sparkPlan
+        .find(_.isInstanceOf[FilterExec])
+        .get.asInstanceOf[FilterExec]
+      assert(filter.condition.find(_.isInstanceOf[IsNotNull]).nonEmpty)
+      checkAnswer(df, expected)
+    }
+    withTable("t") {
+      sql("CREATE TABLE t AS SELECT * FROM VALUES ('<a><b>1</b></a>'), (NULL) 
T(xml)")
+      check(sql("SELECT * FROM t WHERE xpath_boolean(xml, 'a/b') = true"),
+        Seq(Row("<a><b>1</b></a>")))
+      check(sql("SELECT * FROM t WHERE xpath_short(xml, 'a/b') = 1"),
+        Seq(Row("<a><b>1</b></a>")))
+      check(sql("SELECT * FROM t WHERE xpath_int(xml, 'a/b') = 1"),
+        Seq(Row("<a><b>1</b></a>")))
+      check(sql("SELECT * FROM t WHERE xpath_long(xml, 'a/b') = 1"),
+        Seq(Row("<a><b>1</b></a>")))
+      check(sql("SELECT * FROM t WHERE xpath_float(xml, 'a/b') = 1"),
+        Seq(Row("<a><b>1</b></a>")))
+      check(sql("SELECT * FROM t WHERE xpath_double(xml, 'a/b') = 1"),
+        Seq(Row("<a><b>1</b></a>")))
+      check(sql("SELECT * FROM t WHERE xpath_string(xml, 'a/b') = '1'"),
+        Seq(Row("<a><b>1</b></a>")))
+    }
+    withTable("t") {
+      sql("CREATE TABLE t AS SELECT * FROM VALUES " +
+        "('<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>'), (NULL) 
T(xml)")
+      check(sql("SELECT * FROM t WHERE xpath(xml, 'a/b/text()') = array('b1', 
'b2', 'b3')"),
+        Seq(Row("<a><b>b1</b><b>b2</b><b>b3</b><c>c1</c><c>c2</c></a>")))
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to