This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1a47233  [SPARK-26493][SQL] Allow multiple spark.sql.extensions
1a47233 is described below

commit 1a47233f998cd26bac06fa5529a1755a3758d198
Author: Jamison Bennett <jamison.benn...@gmail.com>
AuthorDate: Thu Jan 10 10:23:03 2019 +0800

    [SPARK-26493][SQL] Allow multiple spark.sql.extensions
    
    ## What changes were proposed in this pull request?
    
    Allow multiple extension classes to be specified in the spark.sql.extensions
    configuration as a comma-separated list. Extensions are applied in the order
    they are listed: rules and planner strategies are injected in that order,
    parsers are stacked so each can delegate to its predecessor, and when
    function names conflict the last registered function wins.
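
    A minimal sketch of the new usage (the class and package names below are
    hypothetical; each listed class must implement
    Function1[SparkSessionExtensions, Unit] and have a no-args constructor):

        import org.apache.spark.sql.{SparkSession, SparkSessionExtensions}
        import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
        import org.apache.spark.sql.catalyst.rules.Rule

        // Hypothetical extension class; in practice it would live in the
        // org.example package referenced below. It injects a no-op optimizer rule.
        class MyExtensions extends (SparkSessionExtensions => Unit) {
          override def apply(extensions: SparkSessionExtensions): Unit = {
            extensions.injectOptimizerRule { _ =>
              new Rule[LogicalPlan] {
                override def apply(plan: LogicalPlan): LogicalPlan = plan // no-op
              }
            }
          }
        }

        // The listed classes must be on the application classpath; they are
        // applied in the order given.
        val spark = SparkSession.builder()
          .master("local[1]")
          .config("spark.sql.extensions",
            "org.example.MyExtensions,org.example.MyOtherExtensions")
          .getOrCreate()

    The same comma-separated list can also be passed on the command line with
    spark-submit, e.g. --conf spark.sql.extensions=org.example.MyExtensions,org.example.MyOtherExtensions.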
    
    ## How was this patch tested?
    
    New unit tests were added to SparkSessionExtensionSuite.
    
    Closes #23398 from jamisonbennett/SPARK-26493.
    
    Authored-by: Jamison Bennett <jamison.benn...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/catalyst/analysis/FunctionRegistry.scala   |  11 +-
 .../apache/spark/sql/internal/StaticSQLConf.scala  |  10 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |  11 +-
 .../apache/spark/sql/SparkSessionExtensions.scala  |  24 ++-
 .../spark/sql/SparkSessionExtensionSuite.scala     | 167 +++++++++++++++++++--
 5 files changed, 198 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index c79f990..befc02f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -25,6 +25,7 @@ import scala.language.existentials
 import scala.reflect.ClassTag
 import scala.util.{Failure, Success, Try}
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
@@ -87,7 +88,7 @@ trait FunctionRegistry {
   override def clone(): FunctionRegistry = throw new CloneNotSupportedException()
 }
 
-class SimpleFunctionRegistry extends FunctionRegistry {
+class SimpleFunctionRegistry extends FunctionRegistry with Logging {
 
   @GuardedBy("this")
   private val functionBuilders =
@@ -103,7 +104,13 @@ class SimpleFunctionRegistry extends FunctionRegistry {
       name: FunctionIdentifier,
       info: ExpressionInfo,
       builder: FunctionBuilder): Unit = synchronized {
-    functionBuilders.put(normalizeFuncName(name), (info, builder))
+    val normalizedName = normalizeFuncName(name)
+    val newFunction = (info, builder)
+    functionBuilders.put(normalizedName, newFunction) match {
+      case Some(previousFunction) if previousFunction != newFunction =>
+        logWarning(s"The function $normalizedName replaced a previously registered function.")
+      case _ =>
+    }
   }
 
   override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
index d9c354b..0a8dc28 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -99,9 +99,15 @@ object StaticSQLConf {
       .createWithDefault(false)
 
   val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions")
-    .doc("Name of the class used to configure Spark Session extensions. The 
class should " +
-      "implement Function1[SparkSessionExtension, Unit], and must have a 
no-args constructor.")
+    .doc("A comma-separated list of classes that implement " +
+      "Function1[SparkSessionExtension, Unit] used to configure Spark Session 
extensions. The " +
+      "classes must have a no-args constructor. If multiple extensions are 
specified, they are " +
+      "applied in the specified order. For the case of rules and planner 
strategies, they are " +
+      "applied in the specified order. For the case of parsers, the last 
parser is used and each " +
+      "parser can delegate to its predecessor. For the case of function name 
conflicts, the last " +
+      "registered function name is used.")
     .stringConf
+    .toSequence
     .createOptional
 
   val QUERY_EXECUTION_LISTENERS = buildStaticConf("spark.sql.queryExecutionListeners")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 26272c3..1c13a68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -93,7 +93,7 @@ class SparkSession private(
   private[sql] def this(sc: SparkContext) {
     this(sc, None, None,
       SparkSession.applyExtensions(
-        sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+        sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
         new SparkSessionExtensions))
   }
 
@@ -950,7 +950,7 @@ object SparkSession extends Logging {
         }
 
         applyExtensions(
-          sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+          sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
           extensions)
 
         session = new SparkSession(sparkContext, None, None, extensions)
@@ -1138,14 +1138,13 @@ object SparkSession extends Logging {
   }
 
   /**
-   * Initialize extensions for given extension classname. This class will be applied to the
+   * Initialize extensions for given extension classnames. The classes will be applied to the
    * extensions passed into this function.
    */
   private def applyExtensions(
-      extensionOption: Option[String],
+      extensionConfClassNames: Seq[String],
       extensions: SparkSessionExtensions): SparkSessionExtensions = {
-    if (extensionOption.isDefined) {
-      val extensionConfClassName = extensionOption.get
+    extensionConfClassNames.foreach { extensionConfClassName =>
       try {
         val extensionConfClass = Utils.classForName(extensionConfClassName)
         val extensionConf = extensionConfClass.getConstructor().newInstance()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index 5ed7678..66becf3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -44,12 +44,12 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * <li>(External) Catalog listeners.</li>
  * </ul>
  *
- * The extensions can be used by calling withExtension on the [[SparkSession.Builder]], for
+ * The extensions can be used by calling `withExtensions` on the [[SparkSession.Builder]], for
  * example:
  * {{{
  *   SparkSession.builder()
  *     .master("...")
- *     .conf("...", true)
+ *     .config("...", true)
  *     .withExtensions { extensions =>
  *       extensions.injectResolutionRule { session =>
  *         ...
@@ -61,6 +61,26 @@ import org.apache.spark.sql.catalyst.rules.Rule
  *     .getOrCreate()
  * }}}
  *
+ * The extensions can also be used by setting the Spark SQL configuration property
+ * `spark.sql.extensions`. Multiple extensions can be set using a comma-separated list. For example:
+ * {{{
+ *   SparkSession.builder()
+ *     .master("...")
+ *     .config("spark.sql.extensions", "org.example.MyExtensions")
+ *     .getOrCreate()
+ *
+ *   class MyExtensions extends Function1[SparkSessionExtensions, Unit] {
+ *     override def apply(extensions: SparkSessionExtensions): Unit = {
+ *       extensions.injectResolutionRule { session =>
+ *         ...
+ *       }
+ *       extensions.injectParser { (session, parser) =>
+ *         ...
+ *       }
+ *     }
+ *   }
+ * }}}
+ *
  * Note that none of the injected builders should assume that the [[SparkSession]] is fully
  * initialized and should not touch the session's internals (e.g. the SessionState).
  */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 234711e..9f33feb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
  */
 class SparkSessionExtensionSuite extends SparkFunSuite {
   type ExtensionsBuilder = SparkSessionExtensions => Unit
-  private def create(builder: ExtensionsBuilder): ExtensionsBuilder = builder
+  private def create(builder: ExtensionsBuilder): Seq[ExtensionsBuilder] = Seq(builder)
 
   private def stop(spark: SparkSession): Unit = {
     spark.stop()
@@ -38,55 +38,71 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     SparkSession.clearDefaultSession()
   }
 
-  private def withSession(builder: ExtensionsBuilder)(f: SparkSession => Unit): Unit = {
-    val spark = SparkSession.builder().master("local[1]").withExtensions(builder).getOrCreate()
+  private def withSession(builders: Seq[ExtensionsBuilder])(f: SparkSession => Unit): Unit = {
+    val builder = SparkSession.builder().master("local[1]")
+    builders.foreach(builder.withExtensions)
+    val spark = builder.getOrCreate()
     try f(spark) finally {
       stop(spark)
     }
   }
 
   test("inject analyzer rule") {
-    withSession(_.injectResolutionRule(MyRule)) { session =>
+    withSession(Seq(_.injectResolutionRule(MyRule))) { session =>
       assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
     }
   }
 
+  test("inject post hoc resolution analyzer rule") {
+    withSession(Seq(_.injectPostHocResolutionRule(MyRule))) { session =>
+      assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session)))
+    }
+  }
+
   test("inject check analysis rule") {
-    withSession(_.injectCheckRule(MyCheckRule)) { session =>
+    withSession(Seq(_.injectCheckRule(MyCheckRule))) { session =>
       assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
     }
   }
 
   test("inject optimizer rule") {
-    withSession(_.injectOptimizerRule(MyRule)) { session =>
+    withSession(Seq(_.injectOptimizerRule(MyRule))) { session =>
       assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
     }
   }
 
   test("inject spark planner strategy") {
-    withSession(_.injectPlannerStrategy(MySparkStrategy)) { session =>
+    withSession(Seq(_.injectPlannerStrategy(MySparkStrategy))) { session =>
       assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
     }
   }
 
   test("inject parser") {
     val extension = create { extensions =>
-      extensions.injectParser((_, _) => CatalystSqlParser)
+      extensions.injectParser((_: SparkSession, _: ParserInterface) => CatalystSqlParser)
     }
     withSession(extension) { session =>
-      assert(session.sessionState.sqlParser == CatalystSqlParser)
+      assert(session.sessionState.sqlParser === CatalystSqlParser)
+    }
+  }
+
+  test("inject multiple rules") {
+    withSession(Seq(_.injectOptimizerRule(MyRule),
+        _.injectPlannerStrategy(MySparkStrategy))) { session =>
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
+      assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
     }
   }
 
   test("inject stacked parsers") {
     val extension = create { extensions =>
-      extensions.injectParser((_, _) => CatalystSqlParser)
+      extensions.injectParser((_: SparkSession, _: ParserInterface) => CatalystSqlParser)
       extensions.injectParser(MyParser)
       extensions.injectParser(MyParser)
     }
     withSession(extension) { session =>
       val parser = MyParser(session, MyParser(session, CatalystSqlParser))
-      assert(session.sessionState.sqlParser == parser)
+      assert(session.sessionState.sqlParser === parser)
     }
   }
 
@@ -108,12 +124,89 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     try {
       assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
       assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+      assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session)))
+      assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
+      assert(session.sessionState.sqlParser.isInstanceOf[MyParser])
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions.myFunction._1).isDefined)
+    } finally {
+      stop(session)
+    }
+  }
+
+  test("use multiple custom class for extensions in the specified order") {
+    val session = SparkSession.builder()
+      .master("local[1]")
+      .config("spark.sql.extensions", Seq(
+        classOf[MyExtensions2].getCanonicalName,
+        classOf[MyExtensions].getCanonicalName).mkString(","))
+      .getOrCreate()
+    try {
+      assert(session.sessionState.planner.strategies.containsSlice(
+        Seq(MySparkStrategy2(session), MySparkStrategy(session))))
+      val orderedRules = Seq(MyRule2(session), MyRule(session))
+      val orderedCheckRules = Seq(MyCheckRule2(session), MyCheckRule(session))
+      val parser = MyParser(session, CatalystSqlParser)
+      assert(session.sessionState.analyzer.extendedResolutionRules.containsSlice(orderedRules))
+      assert(session.sessionState.analyzer.postHocResolutionRules.containsSlice(orderedRules))
+      assert(session.sessionState.analyzer.extendedCheckRules.containsSlice(orderedCheckRules))
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules).filter(orderedRules.contains)
+        .containsSlice(orderedRules ++ orderedRules)) // The optimizer rules are duplicated
+      assert(session.sessionState.sqlParser === parser)
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions.myFunction._1).isDefined)
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions2.myFunction._1).isDefined)
+    } finally {
+      stop(session)
+    }
+  }
+
+  test("allow an extension to be duplicated") {
+    val session = SparkSession.builder()
+      .master("local[1]")
+      .config("spark.sql.extensions", Seq(
+        classOf[MyExtensions].getCanonicalName,
+        classOf[MyExtensions].getCanonicalName).mkString(","))
+      .getOrCreate()
+    try {
+      assert(session.sessionState.planner.strategies.count(_ === MySparkStrategy(session)) === 2)
+      assert(session.sessionState.analyzer.extendedResolutionRules.count(_ === MyRule(session)) ===
+        2)
+      assert(session.sessionState.analyzer.postHocResolutionRules.count(_ === MyRule(session)) ===
+        2)
+      assert(session.sessionState.analyzer.extendedCheckRules.count(_ === MyCheckRule(session)) ===
+        2)
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules)
+        .count(_ === MyRule(session)) === 4) // The optimizer rules are duplicated
+      val outerParser = session.sessionState.sqlParser
+      assert(outerParser.isInstanceOf[MyParser])
+      assert(outerParser.asInstanceOf[MyParser].delegate.isInstanceOf[MyParser])
       assert(session.sessionState.functionRegistry
         .lookupFunction(MyExtensions.myFunction._1).isDefined)
     } finally {
       stop(session)
     }
   }
+
+  test("use the last registered function name when there are duplicates") {
+    val session = SparkSession.builder()
+      .master("local[1]")
+      .config("spark.sql.extensions", Seq(
+        classOf[MyExtensions2].getCanonicalName,
+        classOf[MyExtensions2Duplicate].getCanonicalName).mkString(","))
+      .getOrCreate()
+    try {
+      val lastRegistered = session.sessionState.functionRegistry
+        .lookupFunction(FunctionIdentifier("myFunction2"))
+      assert(lastRegistered.isDefined)
+      assert(lastRegistered.get !== MyExtensions2.myFunction._2)
+      assert(lastRegistered.get === MyExtensions2Duplicate.myFunction._2)
+    } finally {
+      stop(session)
+    }
+  }
 }
 
 case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -151,14 +244,62 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends Pars
 object MyExtensions {
 
   val myFunction = (FunctionIdentifier("myFunction"),
-    new ExpressionInfo("noClass", "myDb", "myFunction", "usage", "extended usage" ),
-    (myArgs: Seq[Expression]) => Literal(5, IntegerType))
+    new ExpressionInfo("noClass", "myDb", "myFunction", "usage", "extended usage"),
+    (_: Seq[Expression]) => Literal(5, IntegerType))
 }
 
 class MyExtensions extends (SparkSessionExtensions => Unit) {
   def apply(e: SparkSessionExtensions): Unit = {
     e.injectPlannerStrategy(MySparkStrategy)
     e.injectResolutionRule(MyRule)
+    e.injectPostHocResolutionRule(MyRule)
+    e.injectCheckRule(MyCheckRule)
+    e.injectOptimizerRule(MyRule)
+    e.injectParser(MyParser)
     e.injectFunction(MyExtensions.myFunction)
   }
 }
+
+case class MyRule2(spark: SparkSession) extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan
+}
+
+case class MyCheckRule2(spark: SparkSession) extends (LogicalPlan => Unit) {
+  override def apply(plan: LogicalPlan): Unit = { }
+}
+
+case class MySparkStrategy2(spark: SparkSession) extends SparkStrategy {
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
+}
+
+object MyExtensions2 {
+
+  val myFunction = (FunctionIdentifier("myFunction2"),
+    new ExpressionInfo("noClass", "myDb", "myFunction2", "usage", "extended usage"),
+    (_: Seq[Expression]) => Literal(5, IntegerType))
+}
+
+class MyExtensions2 extends (SparkSessionExtensions => Unit) {
+  def apply(e: SparkSessionExtensions): Unit = {
+    e.injectPlannerStrategy(MySparkStrategy2)
+    e.injectResolutionRule(MyRule2)
+    e.injectPostHocResolutionRule(MyRule2)
+    e.injectCheckRule(MyCheckRule2)
+    e.injectOptimizerRule(MyRule2)
+    e.injectParser((_: SparkSession, _: ParserInterface) => CatalystSqlParser)
+    e.injectFunction(MyExtensions2.myFunction)
+  }
+}
+
+object MyExtensions2Duplicate {
+
+  val myFunction = (FunctionIdentifier("myFunction2"),
+    new ExpressionInfo("noClass", "myDb", "myFunction2", "usage", "extended usage"),
+    (_: Seq[Expression]) => Literal(5, IntegerType))
+}
+
+class MyExtensions2Duplicate extends (SparkSessionExtensions => Unit) {
+  def apply(e: SparkSessionExtensions): Unit = {
+    e.injectFunction(MyExtensions2Duplicate.myFunction)
+  }
+}

