This is an automated email from the ASF dual-hosted git repository.
marong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new d4d724143 [VL] Support create temporary function for native hive udf
(#6829)
d4d724143 is described below
commit d4d724143286eec63605e10585a6133ce3c52b9d
Author: Rong Ma <[email protected]>
AuthorDate: Sun Aug 25 13:11:33 2024 +0800
[VL] Support create temporary function for native hive udf (#6829)
---
backends-velox/pom.xml | 7 ++
.../gluten/backendsapi/velox/VeloxRuleApi.scala | 2 +-
.../backendsapi/velox/VeloxSparkPlanExecApi.scala | 7 ++
.../apache/spark/sql/expression/UDFResolver.scala | 30 ++++---
.../spark/sql/hive/VeloxHiveUDFTransformer.scala | 40 ++++-----
.../apache/gluten/expression/VeloxUdfSuite.scala | 99 ++++++++++++++++++++++
cpp/velox/udf/examples/MyUDF.cc | 39 +++++++++
.../gluten/backendsapi/SparkPlanExecApi.scala | 8 +-
.../gluten/expression/ExpressionConverter.scala | 4 +-
.../apache/spark/sql/hive/HiveUDFTransformer.scala | 6 ++
10 files changed, 201 insertions(+), 41 deletions(-)
diff --git a/backends-velox/pom.xml b/backends-velox/pom.xml
index 0fe8f5f6f..417f64999 100755
--- a/backends-velox/pom.xml
+++ b/backends-velox/pom.xml
@@ -140,6 +140,13 @@
<artifactId>spark-core_${scala.binary.version}</artifactId>
<type>test-jar</type>
</dependency>
+ <dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-hive_${scala.binary.version}</artifactId>
+ <version>${spark.version}</version>
+ <type>test-jar</type>
+ <scope>test</scope>
+ </dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index abb39c5bb..438895b25 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -47,7 +47,7 @@ private object VeloxRuleApi {
// Regular Spark rules.
injector.injectOptimizerRule(CollectRewriteRule.apply)
injector.injectOptimizerRule(HLLRewriteRule.apply)
- UDFResolver.getFunctionSignatures.foreach(injector.injectFunction)
+ UDFResolver.getFunctionSignatures().foreach(injector.injectFunction)
injector.injectPostHocResolutionRule(ArrowConvertorRule.apply)
}
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index bd390004f..554b3791d 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -50,6 +50,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.utils.ExecUtil
import org.apache.spark.sql.expression.{UDFExpression,
UserDefinedAggregateFunction}
+import org.apache.spark.sql.hive.VeloxHiveUDFTransformer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -819,4 +820,10 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
case other => other
}
}
+
+ override def genHiveUDFTransformer(
+ expr: Expression,
+ attributeSeq: Seq[Attribute]): ExpressionTransformer = {
+ VeloxHiveUDFTransformer.replaceWithExpressionTransformer(expr,
attributeSeq)
+ }
}
diff --git
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
index ab83c55ee..39032e46f 100644
---
a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.expression
import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
-import org.apache.gluten.exception.GlutenException
+import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}
import org.apache.gluten.expression.{ConverterUtils, ExpressionTransformer,
ExpressionType, GenericExpressionTransformer, Transformable}
import org.apache.gluten.udf.UdfJniWrapper
import org.apache.gluten.vectorized.JniWorkspace
@@ -95,11 +95,14 @@ case class UDAFSignature(
case class UDFExpression(
name: String,
+ alias: String,
dataType: DataType,
nullable: Boolean,
children: Seq[Expression])
extends Unevaluable
with Transformable {
+ override def nodeName: String = alias
+
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): Expression = {
this.copy(children = newChildren)
@@ -118,11 +121,11 @@ case class UDFExpression(
}
object UDFResolver extends Logging {
- private val UDFNames = mutable.HashSet[String]()
+ val UDFNames = mutable.HashSet[String]()
// (udf_name, arg1, arg2, ...) => return type
private val UDFMap = mutable.HashMap[String,
mutable.ListBuffer[UDFSignature]]()
- private val UDAFNames = mutable.HashSet[String]()
+ val UDAFNames = mutable.HashSet[String]()
// (udaf_name, arg1, arg2, ...) => return type, intermediate attributes
private val UDAFMap =
mutable.HashMap[String, mutable.ListBuffer[UDAFSignature]]()
@@ -331,7 +334,7 @@ object UDFResolver extends Logging {
.mkString(",")
}
- def getFunctionSignatures: Seq[(FunctionIdentifier, ExpressionInfo,
FunctionBuilder)] = {
+ def getFunctionSignatures(): Seq[(FunctionIdentifier, ExpressionInfo,
FunctionBuilder)] = {
val sparkContext = SparkContext.getActive.get
val sparkConf = sparkContext.conf
val udfLibPaths =
sparkConf.getOption(VeloxBackendSettings.GLUTEN_VELOX_UDF_LIB_PATHS)
@@ -341,13 +344,12 @@ object UDFResolver extends Logging {
Seq.empty
case Some(_) =>
UdfJniWrapper.getFunctionSignatures()
-
UDFNames.map {
name =>
(
new FunctionIdentifier(name),
new ExpressionInfo(classOf[UDFExpression].getName, name),
- (e: Seq[Expression]) => getUdfExpression(name)(e))
+ (e: Seq[Expression]) => getUdfExpression(name, name)(e))
}.toSeq ++ UDAFNames.map {
name =>
(
@@ -364,27 +366,29 @@ object UDFResolver extends Logging {
.toBoolean
}
- private def getUdfExpression(name: String)(children: Seq[Expression]) = {
+ def getUdfExpression(name: String, alias: String)(children:
Seq[Expression]): UDFExpression = {
def errorMessage: String =
s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", ")}
is not registered."
val allowTypeConversion = checkAllowTypeConversion
val signatures =
- UDFMap.getOrElse(name, throw new
UnsupportedOperationException(errorMessage));
+ UDFMap.getOrElse(name, throw new
GlutenNotSupportException(errorMessage));
signatures.find(sig => tryBind(sig, children.map(_.dataType),
allowTypeConversion)) match {
case Some(sig) =>
UDFExpression(
name,
+ alias,
sig.expressionType.dataType,
sig.expressionType.nullable,
if (!allowTypeConversion && !sig.allowTypeConversion) children
- else applyCast(children, sig))
+ else applyCast(children, sig)
+ )
case None =>
- throw new UnsupportedOperationException(errorMessage)
+ throw new GlutenNotSupportException(errorMessage)
}
}
- private def getUdafExpression(name: String)(children: Seq[Expression]) = {
+ def getUdafExpression(name: String)(children: Seq[Expression]):
UserDefinedAggregateFunction = {
def errorMessage: String =
s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", ")}
is not registered."
@@ -392,7 +396,7 @@ object UDFResolver extends Logging {
val signatures =
UDAFMap.getOrElse(
name,
- throw new UnsupportedOperationException(errorMessage)
+ throw new GlutenNotSupportException(errorMessage)
)
signatures.find(sig => tryBind(sig, children.map(_.dataType),
allowTypeConversion)) match {
case Some(sig) =>
@@ -405,7 +409,7 @@ object UDFResolver extends Logging {
sig.intermediateAttrs
)
case None =>
- throw new UnsupportedOperationException(errorMessage)
+ throw new GlutenNotSupportException(errorMessage)
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
b/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
similarity index 63%
copy from
gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
copy to
backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
index 5cd64cc21..d895faa31 100644
---
a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
+++
b/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
@@ -17,43 +17,33 @@
package org.apache.spark.sql.hive
import org.apache.gluten.exception.GlutenNotSupportException
-import org.apache.gluten.expression.{ExpressionConverter,
ExpressionTransformer, GenericExpressionTransformer, UDFMappings}
+import org.apache.gluten.expression.{ExpressionConverter,
ExpressionTransformer}
-import org.apache.spark.sql.catalyst.expressions._
-
-import java.util.Locale
-
-object HiveUDFTransformer {
- def isHiveUDF(expr: Expression): Boolean = {
- expr match {
- case _: HiveSimpleUDF | _: HiveGenericUDF => true
- case _ => false
- }
- }
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.expression.UDFResolver
+object VeloxHiveUDFTransformer {
def replaceWithExpressionTransformer(
expr: Expression,
attributeSeq: Seq[Attribute]): ExpressionTransformer = {
- val udfName = expr match {
+ val (udfName, udfClassName) = expr match {
case s: HiveSimpleUDF =>
- s.name.stripPrefix("default.")
+ (s.name.stripPrefix("default."), s.funcWrapper.functionClassName)
case g: HiveGenericUDF =>
- g.name.stripPrefix("default.")
+ (g.name.stripPrefix("default."), g.funcWrapper.functionClassName)
case _ =>
throw new GlutenNotSupportException(
s"Expression $expr is not a HiveSimpleUDF or HiveGenericUDF")
}
- UDFMappings.hiveUDFMap.get(udfName.toLowerCase(Locale.ROOT)) match {
- case Some(name) =>
- GenericExpressionTransformer(
- name,
- ExpressionConverter.replaceWithExpressionTransformer(expr.children,
attributeSeq),
- expr)
- case _ =>
- throw new GlutenNotSupportException(
- s"Not supported hive udf:$expr"
- + s" name:$udfName hiveUDFMap:${UDFMappings.hiveUDFMap}")
+ if (UDFResolver.UDFNames.contains(udfClassName)) {
+ UDFResolver
+ .getUdfExpression(udfClassName, udfName)(expr.children)
+ .getTransformer(
+ ExpressionConverter.replaceWithExpressionTransformer(expr.children,
attributeSeq)
+ )
+ } else {
+ HiveUDFTransformer.genTransformerFromUDFMappings(udfName, expr,
attributeSeq)
}
}
}
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
index 008337b94..596757df3 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
@@ -22,6 +22,7 @@ import org.apache.gluten.tags.{SkipTestTags, UDFTest}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{GlutenQueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.expression.UDFResolver
import java.nio.file.Paths
import java.sql.Date
@@ -56,12 +57,31 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with
SQLHelper {
.builder()
.master(master)
.config(sparkConf)
+ .enableHiveSupport()
.getOrCreate()
}
_spark.sparkContext.setLogLevel("info")
}
+ override def afterAll(): Unit = {
+ try {
+ super.afterAll()
+ if (_spark != null) {
+ try {
+ _spark.sessionState.catalog.reset()
+ } finally {
+ _spark.stop()
+ _spark = null
+ }
+ }
+ } finally {
+ SparkSession.clearActiveSession()
+ SparkSession.clearDefaultSession()
+ doThreadPostAudit()
+ }
+ }
+
override protected def spark = _spark
protected def sparkConf: SparkConf = {
@@ -128,6 +148,85 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with
SQLHelper {
.sameElements(Array(Row(1.0, 1.0, 1L))))
}
}
+
+ test("test hive udf replacement") {
+ val tbl = "test_hive_udf_replacement"
+ withTempPath {
+ dir =>
+ try {
+ spark.sql(s"""
+ |CREATE EXTERNAL TABLE $tbl
+ |LOCATION 'file://$dir'
+ |AS select * from values (1, '1'), (2, '2'), (3, '3')
+ |""".stripMargin)
+
+ // Check native hive udf has been registered.
+ assert(
+
UDFResolver.UDFNames.contains("org.apache.spark.sql.hive.execution.UDFStringString"))
+
+ spark.sql("""
+ |CREATE TEMPORARY FUNCTION hive_string_string
+ |AS 'org.apache.spark.sql.hive.execution.UDFStringString'
+ |""".stripMargin)
+
+ val nativeResult =
+ spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM
$tbl""").collect()
+ // Unregister native hive udf to fallback.
+
UDFResolver.UDFNames.remove("org.apache.spark.sql.hive.execution.UDFStringString")
+ val fallbackResult =
+ spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM
$tbl""").collect()
+ assert(nativeResult.sameElements(fallbackResult))
+
+ // Add an unimplemented udf to the map to test fallback of
registered native hive udf.
+
UDFResolver.UDFNames.add("org.apache.spark.sql.hive.execution.UDFIntegerToString")
+ spark.sql("""
+ |CREATE TEMPORARY FUNCTION hive_int_to_string
+ |AS
'org.apache.spark.sql.hive.execution.UDFIntegerToString'
+ |""".stripMargin)
+ val df = spark.sql(s"""select hive_int_to_string(col1) from $tbl""")
+ checkAnswer(df, Seq(Row("1"), Row("2"), Row("3")))
+ } finally {
+ spark.sql(s"DROP TABLE IF EXISTS $tbl")
+ spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_string_string")
+ spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_int_to_string")
+ }
+ }
+ }
+
+  test("test udf fallback in partition filter") {
+    withTempPath {
+      dir =>
+        try {
+          spark.sql("""
+                      |CREATE TEMPORARY FUNCTION hive_int_to_string
+                      |AS 'org.apache.spark.sql.hive.execution.UDFIntegerToString'
+                      |""".stripMargin)
+
+          spark.sql(s"""
+                       |CREATE EXTERNAL TABLE t(i INT, p INT)
+                       |LOCATION 'file://$dir'
+                       |PARTITIONED BY (p)""".stripMargin)
+
+          spark
+            .range(0, 10, 1)
+            .selectExpr("id as col")
+            .createOrReplaceTempView("temp")
+
+          for (part <- Seq(1, 2, 3, 4)) {
+            spark.sql(s"""
+                         |INSERT OVERWRITE TABLE t PARTITION (p=$part)
+                         |SELECT col FROM temp""".stripMargin)
+          }
+          // The hive UDF is applied to partition column p; the 10 rows written to p=4 are expected.
+          val df = spark.sql("SELECT i FROM t WHERE hive_int_to_string(p) = '4'")
+          checkAnswer(df, (0 until 10).map(Row(_)))
+        } finally {
+          spark.sql("DROP TABLE IF EXISTS t")
+          spark.sql("DROP VIEW IF EXISTS temp")
+          spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_int_to_string") // created by this test
+        }
+    }
+  }
}
@UDFTest
diff --git a/cpp/velox/udf/examples/MyUDF.cc b/cpp/velox/udf/examples/MyUDF.cc
index db1c5d770..75e68413a 100644
--- a/cpp/velox/udf/examples/MyUDF.cc
+++ b/cpp/velox/udf/examples/MyUDF.cc
@@ -30,6 +30,7 @@ namespace {
static const char* kInteger = "int";
static const char* kBigInt = "bigint";
static const char* kDate = "date";
+static const char* kVarChar = "varchar";
namespace myudf {
@@ -248,6 +249,43 @@ class MyDate2Registerer final : public
gluten::UdfRegisterer {
};
} // namespace mydate
+namespace hivestringstring {
+template <typename T>
+struct HiveStringStringFunction {
+  VELOX_DEFINE_FUNCTION_TYPES(T);
+  // Writes `a`, a single space, then `b` into `result`.
+  FOLLY_ALWAYS_INLINE void call(out_type<Varchar>& result, const arg_type<Varchar>& a, const arg_type<Varchar>& b) {
+    result.append(a); // Append by size: a.data() is not guaranteed to be NUL-terminated.
+    result.append(" ");
+    result.append(b); // Same as above; avoids scanning past the end of b's bytes.
+  }
+};
+
+// name: org.apache.spark.sql.hive.execution.UDFStringString
+// signatures:
+// varchar, varchar -> varchar
+// type: SimpleFunction
+class HiveStringStringRegisterer final : public gluten::UdfRegisterer {
+ public:
+ int getNumUdf() override {
+ return 1;
+ }
+
+ void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+ // Set `allowTypeConversion` for hive udf.
+ udfEntries[index++] = {name_.c_str(), kVarChar, 2, arg_, false, true};
+ }
+
+ void registerSignatures() override {
+ facebook::velox::registerFunction<HiveStringStringFunction, Varchar,
Varchar, Varchar>({name_});
+ }
+
+ private:
+ const std::string name_ =
"org.apache.spark.sql.hive.execution.UDFStringString";
+ const char* arg_[2] = {kVarChar, kVarChar};
+};
+} // namespace hivestringstring
+
std::vector<std::shared_ptr<gluten::UdfRegisterer>>& globalRegisters() {
static std::vector<std::shared_ptr<gluten::UdfRegisterer>> registerers;
return registerers;
@@ -264,6 +302,7 @@ void setupRegisterers() {
registerers.push_back(std::make_shared<myudf::MyUdf3Registerer>());
registerers.push_back(std::make_shared<mydate::MyDateRegisterer>());
registerers.push_back(std::make_shared<mydate::MyDate2Registerer>());
+
registerers.push_back(std::make_shared<hivestringstring::HiveStringStringRegisterer>());
inited = true;
}
} // namespace
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 0227ed5da..fb87a9ac9 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -41,7 +41,7 @@ import
org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.BuildSideRelation
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
-import org.apache.spark.sql.hive.HiveTableScanExecTransformer
+import org.apache.spark.sql.hive.{HiveTableScanExecTransformer,
HiveUDFTransformer}
import org.apache.spark.sql.types.{DecimalType, LongType, NullType, StructType}
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -670,4 +670,10 @@ trait SparkPlanExecApi {
DecimalType(math.min(integralLeastNumDigits + newScale, 38), newScale)
}
}
+
+ def genHiveUDFTransformer(
+ expr: Expression,
+ attributeSeq: Seq[Attribute]): ExpressionTransformer = {
+ HiveUDFTransformer.replaceWithExpressionTransformer(expr, attributeSeq)
+ }
}
diff --git
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 8bca5dbf8..d5ca31bb5 100644
---
a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++
b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -128,7 +128,9 @@ object ExpressionConverter extends SQLConfHelper with
Logging {
case s: ScalaUDF =>
return replaceScalaUDFWithExpressionTransformer(s, attributeSeq,
expressionsMap)
case _ if HiveUDFTransformer.isHiveUDF(expr) =>
- return HiveUDFTransformer.replaceWithExpressionTransformer(expr,
attributeSeq)
+ return
BackendsApiManager.getSparkPlanExecApiInstance.genHiveUDFTransformer(
+ expr,
+ attributeSeq)
case i: StaticInvoke =>
val objectName = i.staticObject.getName.stripSuffix("$")
if (objectName.endsWith("UrlCodec")) {
diff --git
a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
index 5cd64cc21..52739aaca 100644
---
a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
+++
b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
@@ -43,7 +43,13 @@ object HiveUDFTransformer {
throw new GlutenNotSupportException(
s"Expression $expr is not a HiveSimpleUDF or HiveGenericUDF")
}
+ genTransformerFromUDFMappings(udfName, expr, attributeSeq)
+ }
+ def genTransformerFromUDFMappings(
+ udfName: String,
+ expr: Expression,
+ attributeSeq: Seq[Attribute]): GenericExpressionTransformer = {
UDFMappings.hiveUDFMap.get(udfName.toLowerCase(Locale.ROOT)) match {
case Some(name) =>
GenericExpressionTransformer(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]