Repository: spark
Updated Branches:
refs/heads/master c8f7691c6 -> 6e0fc8b0f
[SPARK-25560][SQL] Allow FunctionInjection in SparkExtensions
This allows an implementer of Spark Session Extensions to use a
method "injectFunction" which adds a new function to the function
registry of every Spark session built with those extensions.
## What changes were proposed in this pull request?
Adds a new function to SparkSessionExtensions
def injectFunction(functionDescription: FunctionDescription)
where `FunctionDescription` is a new type:
type FunctionDescription = (FunctionIdentifier, ExpressionInfo, FunctionBuilder)
The functions are registered in BaseSessionStateBuilder when the session state
has no parent state whose function registry could be cloned instead.
## How was this patch tested?
New unit tests are added for the extension in SparkSessionExtensionSuite
Closes #22576 from RussellSpitzer/SPARK-25560.
Authored-by: Russell Spitzer <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6e0fc8b0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6e0fc8b0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6e0fc8b0
Branch: refs/heads/master
Commit: 6e0fc8b0fc2798b6372d1101f7996f57bae8fea4
Parents: c8f7691
Author: Russell Spitzer <[email protected]>
Authored: Fri Oct 19 10:40:56 2018 +0200
Committer: Herman van Hovell <[email protected]>
Committed: Fri Oct 19 10:40:56 2018 +0200
----------------------------------------------------------------------
.../spark/sql/SparkSessionExtensions.scala | 22 ++++++++++++++++++
.../sql/internal/BaseSessionStateBuilder.scala | 3 ++-
.../spark/sql/SparkSessionExtensionSuite.scala | 24 ++++++++++++++++++--
3 files changed, 46 insertions(+), 3 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/6e0fc8b0/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index 6b02ac2..a486434 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -20,6 +20,10 @@ package org.apache.spark.sql
import scala.collection.mutable
import org.apache.spark.annotation.{DeveloperApi, Experimental,
InterfaceStability}
+import org.apache.spark.sql.catalyst.FunctionIdentifier
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
@@ -68,6 +72,7 @@ class SparkSessionExtensions {
type CheckRuleBuilder = SparkSession => LogicalPlan => Unit
type StrategyBuilder = SparkSession => Strategy
type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface
+ type FunctionDescription = (FunctionIdentifier, ExpressionInfo,
FunctionBuilder)
private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder]
@@ -171,4 +176,21 @@ class SparkSessionExtensions {
def injectParser(builder: ParserBuilder): Unit = {
parserBuilders += builder
}
+
+ private[this] val injectedFunctions =
mutable.Buffer.empty[FunctionDescription]
+
+ private[sql] def registerFunctions(functionRegistry: FunctionRegistry) = {
+ for ((name, expressionInfo, function) <- injectedFunctions) {
+ functionRegistry.registerFunction(name, expressionInfo, function)
+ }
+ functionRegistry
+ }
+
+ /**
+ * Injects a custom function into the
[[org.apache.spark.sql.catalyst.analysis.FunctionRegistry]]
+ * at runtime for all sessions.
+ */
+ def injectFunction(functionDescription: FunctionDescription): Unit = {
+ injectedFunctions += functionDescription
+ }
}
http://git-wip-us.apache.org/repos/asf/spark/blob/6e0fc8b0/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 60bba5e..f67cc32 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -95,7 +95,8 @@ abstract class BaseSessionStateBuilder(
* This either gets cloned from a pre-existing version or cloned from the
built-in registry.
*/
protected lazy val functionRegistry: FunctionRegistry = {
-
parentState.map(_.functionRegistry).getOrElse(FunctionRegistry.builtin).clone()
+ parentState.map(_.functionRegistry.clone())
+
.getOrElse(extensions.registerFunctions(FunctionRegistry.builtin.clone()))
}
/**
http://git-wip-us.apache.org/repos/asf/spark/blob/6e0fc8b0/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 43db796..234711e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -18,12 +18,12 @@ package org.apache.spark.sql
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo,
Literal}
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser,
ParserInterface}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy}
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
/**
* Test cases for the [[SparkSessionExtensions]].
@@ -90,6 +90,16 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
}
}
+ test("inject function") {
+ val extensions = create { extensions =>
+ extensions.injectFunction(MyExtensions.myFunction)
+ }
+ withSession(extensions) { session =>
+ assert(session.sessionState.functionRegistry
+ .lookupFunction(MyExtensions.myFunction._1).isDefined)
+ }
+ }
+
test("use custom class for extensions") {
val session = SparkSession.builder()
.master("local[1]")
@@ -98,6 +108,8 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
try {
assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+ assert(session.sessionState.functionRegistry
+ .lookupFunction(MyExtensions.myFunction._1).isDefined)
} finally {
stop(session)
}
@@ -136,9 +148,17 @@ case class MyParser(spark: SparkSession, delegate:
ParserInterface) extends Pars
delegate.parseDataType(sqlText)
}
+object MyExtensions {
+
+ val myFunction = (FunctionIdentifier("myFunction"),
+ new ExpressionInfo("noClass", "myDb", "myFunction", "usage", "extended
usage" ),
+ (myArgs: Seq[Expression]) => Literal(5, IntegerType))
+}
+
class MyExtensions extends (SparkSessionExtensions => Unit) {
def apply(e: SparkSessionExtensions): Unit = {
e.injectPlannerStrategy(MySparkStrategy)
e.injectResolutionRule(MyRule)
+ e.injectFunction(MyExtensions.myFunction)
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]