This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 1a47233  [SPARK-26493][SQL] Allow multiple spark.sql.extensions
1a47233 is described below

commit 1a47233f998cd26bac06fa5529a1755a3758d198
Author: Jamison Bennett <jamison.benn...@gmail.com>
AuthorDate: Thu Jan 10 10:23:03 2019 +0800

    [SPARK-26493][SQL] Allow multiple spark.sql.extensions

    ## What changes were proposed in this pull request?

    Allow multiple spark.sql.extensions to be specified in the configuration.

    ## How was this patch tested?

    New tests are added.

    Closes #23398 from jamisonbennett/SPARK-26493.

    Authored-by: Jamison Bennett <jamison.benn...@gmail.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../sql/catalyst/analysis/FunctionRegistry.scala   |  11 +-
 .../apache/spark/sql/internal/StaticSQLConf.scala  |  10 +-
 .../scala/org/apache/spark/sql/SparkSession.scala  |  11 +-
 .../apache/spark/sql/SparkSessionExtensions.scala  |  24 ++-
 .../spark/sql/SparkSessionExtensionSuite.scala     | 167 +++++++++++++++++++--
 5 files changed, 198 insertions(+), 25 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index c79f990..befc02f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -25,6 +25,7 @@ import scala.language.existentials
 import scala.reflect.ClassTag
 import scala.util.{Failure, Success, Try}
 
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
@@ -87,7 +88,7 @@ trait FunctionRegistry {
   override def clone(): FunctionRegistry = throw new CloneNotSupportedException()
 }
 
-class SimpleFunctionRegistry extends FunctionRegistry {
+class SimpleFunctionRegistry extends FunctionRegistry with Logging {
 
   @GuardedBy("this")
   private val functionBuilders =
@@ -103,7 +104,13 @@ class SimpleFunctionRegistry extends FunctionRegistry {
       name: FunctionIdentifier,
       info: ExpressionInfo,
       builder: FunctionBuilder): Unit = synchronized {
-    functionBuilders.put(normalizeFuncName(name), (info, builder))
+    val normalizedName = normalizeFuncName(name)
+    val newFunction = (info, builder)
+    functionBuilders.put(normalizedName, newFunction) match {
+      case Some(previousFunction) if previousFunction != newFunction =>
+        logWarning(s"The function $normalizedName replaced a previously registered function.")
+      case _ =>
+    }
   }
 
   override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
index d9c354b..0a8dc28 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala
@@ -99,9 +99,15 @@ object StaticSQLConf {
     .createWithDefault(false)
 
   val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions")
-    .doc("Name of the class used to configure Spark Session extensions. The class should " +
-      "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.")
+    .doc("A comma-separated list of classes that implement " +
+      "Function1[SparkSessionExtension, Unit] used to configure Spark Session extensions. The " +
+      "classes must have a no-args constructor. If multiple extensions are specified, they are " +
+      "applied in the specified order. For the case of rules and planner strategies, they are " +
+      "applied in the specified order. For the case of parsers, the last parser is used and each " +
+      "parser can delegate to its predecessor. For the case of function name conflicts, the last " +
+      "registered function name is used.")
     .stringConf
+    .toSequence
    .createOptional
 
   val QUERY_EXECUTION_LISTENERS = buildStaticConf("spark.sql.queryExecutionListeners")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 26272c3..1c13a68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -93,7 +93,7 @@ class SparkSession private(
   private[sql] def this(sc: SparkContext) {
     this(sc, None, None,
       SparkSession.applyExtensions(
-        sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+        sc.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
         new SparkSessionExtensions))
   }
 
@@ -950,7 +950,7 @@ object SparkSession extends Logging {
       }
 
       applyExtensions(
-        sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS),
+        sparkContext.getConf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS).getOrElse(Seq.empty),
         extensions)
 
       session = new SparkSession(sparkContext, None, None, extensions)
@@ -1138,14 +1138,13 @@ object SparkSession extends Logging {
   }
 
   /**
-   * Initialize extensions for given extension classname. This class will be applied to the
+   * Initialize extensions for given extension classnames. The classes will be applied to the
    * extensions passed into this function.
    */
   private def applyExtensions(
-      extensionOption: Option[String],
+      extensionConfClassNames: Seq[String],
       extensions: SparkSessionExtensions): SparkSessionExtensions = {
-    if (extensionOption.isDefined) {
-      val extensionConfClassName = extensionOption.get
+    extensionConfClassNames.foreach { extensionConfClassName =>
       try {
         val extensionConfClass = Utils.classForName(extensionConfClassName)
         val extensionConf = extensionConfClass.getConstructor().newInstance()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
index 5ed7678..66becf3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala
@@ -44,12 +44,12 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * <li>(External) Catalog listeners.</li>
  * </ul>
  *
- * The extensions can be used by calling withExtension on the [[SparkSession.Builder]], for
+ * The extensions can be used by calling `withExtensions` on the [[SparkSession.Builder]], for
  * example:
  * {{{
  *   SparkSession.builder()
  *     .master("...")
- *     .conf("...", true)
+ *     .config("...", true)
  *     .withExtensions { extensions =>
  *       extensions.injectResolutionRule { session =>
  *         ...
@@ -61,6 +61,26 @@ import org.apache.spark.sql.catalyst.rules.Rule
  *     .getOrCreate()
  * }}}
  *
+ * The extensions can also be used by setting the Spark SQL configuration property
+ * `spark.sql.extensions`. Multiple extensions can be set using a comma-separated list. For example:
+ * {{{
+ *   SparkSession.builder()
+ *     .master("...")
+ *     .config("spark.sql.extensions", "org.example.MyExtensions")
+ *     .getOrCreate()
+ *
+ *   class MyExtensions extends Function1[SparkSessionExtensions, Unit] {
+ *     override def apply(extensions: SparkSessionExtensions): Unit = {
+ *       extensions.injectResolutionRule { session =>
+ *         ...
+ *       }
+ *       extensions.injectParser { (session, parser) =>
+ *         ...
+ *       }
+ *     }
+ *   }
+ * }}}
+ *
  * Note that none of the injected builders should assume that the [[SparkSession]] is fully
  * initialized and should not touch the session's internals (e.g. the SessionState).
  */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
index 234711e..9f33feb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{DataType, IntegerType, StructType}
  */
 class SparkSessionExtensionSuite extends SparkFunSuite {
   type ExtensionsBuilder = SparkSessionExtensions => Unit
-  private def create(builder: ExtensionsBuilder): ExtensionsBuilder = builder
+  private def create(builder: ExtensionsBuilder): Seq[ExtensionsBuilder] = Seq(builder)
 
   private def stop(spark: SparkSession): Unit = {
     spark.stop()
@@ -38,55 +38,71 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     SparkSession.clearDefaultSession()
   }
 
-  private def withSession(builder: ExtensionsBuilder)(f: SparkSession => Unit): Unit = {
-    val spark = SparkSession.builder().master("local[1]").withExtensions(builder).getOrCreate()
+  private def withSession(builders: Seq[ExtensionsBuilder])(f: SparkSession => Unit): Unit = {
+    val builder = SparkSession.builder().master("local[1]")
+    builders.foreach(builder.withExtensions)
+    val spark = builder.getOrCreate()
     try f(spark) finally {
       stop(spark)
     }
   }
 
   test("inject analyzer rule") {
-    withSession(_.injectResolutionRule(MyRule)) { session =>
+    withSession(Seq(_.injectResolutionRule(MyRule))) { session =>
       assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
     }
   }
 
+  test("inject post hoc resolution analyzer rule") {
+    withSession(Seq(_.injectPostHocResolutionRule(MyRule))) { session =>
+      assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session)))
+    }
+  }
+
   test("inject check analysis rule") {
-    withSession(_.injectCheckRule(MyCheckRule)) { session =>
+    withSession(Seq(_.injectCheckRule(MyCheckRule))) { session =>
       assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
     }
   }
 
   test("inject optimizer rule") {
-    withSession(_.injectOptimizerRule(MyRule)) { session =>
+    withSession(Seq(_.injectOptimizerRule(MyRule))) { session =>
       assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
     }
   }
 
   test("inject spark planner strategy") {
-    withSession(_.injectPlannerStrategy(MySparkStrategy)) { session =>
+    withSession(Seq(_.injectPlannerStrategy(MySparkStrategy))) { session =>
       assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
     }
   }
 
   test("inject parser") {
     val extension = create { extensions =>
-      extensions.injectParser((_, _) => CatalystSqlParser)
+      extensions.injectParser((_: SparkSession, _: ParserInterface) => CatalystSqlParser)
     }
     withSession(extension) { session =>
-      assert(session.sessionState.sqlParser == CatalystSqlParser)
+      assert(session.sessionState.sqlParser === CatalystSqlParser)
+    }
+  }
+
+  test("inject multiple rules") {
+    withSession(Seq(_.injectOptimizerRule(MyRule),
+        _.injectPlannerStrategy(MySparkStrategy))) { session =>
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
+      assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
     }
   }
 
   test("inject stacked parsers") {
     val extension = create { extensions =>
-      extensions.injectParser((_, _) => CatalystSqlParser)
+      extensions.injectParser((_: SparkSession, _: ParserInterface) => CatalystSqlParser)
       extensions.injectParser(MyParser)
       extensions.injectParser(MyParser)
     }
     withSession(extension) { session =>
       val parser = MyParser(session, MyParser(session, CatalystSqlParser))
-      assert(session.sessionState.sqlParser == parser)
+      assert(session.sessionState.sqlParser === parser)
     }
   }
 
@@ -108,12 +124,89 @@ class SparkSessionExtensionSuite extends SparkFunSuite {
     try {
       assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session)))
       assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session)))
+      assert(session.sessionState.analyzer.postHocResolutionRules.contains(MyRule(session)))
+      assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session)))
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session)))
+      assert(session.sessionState.sqlParser.isInstanceOf[MyParser])
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions.myFunction._1).isDefined)
+    } finally {
+      stop(session)
+    }
+  }
+
+  test("use multiple custom class for extensions in the specified order") {
+    val session = SparkSession.builder()
+      .master("local[1]")
+      .config("spark.sql.extensions", Seq(
+        classOf[MyExtensions2].getCanonicalName,
+        classOf[MyExtensions].getCanonicalName).mkString(","))
+      .getOrCreate()
+    try {
+      assert(session.sessionState.planner.strategies.containsSlice(
+        Seq(MySparkStrategy2(session), MySparkStrategy(session))))
+      val orderedRules = Seq(MyRule2(session), MyRule(session))
+      val orderedCheckRules = Seq(MyCheckRule2(session), MyCheckRule(session))
+      val parser = MyParser(session, CatalystSqlParser)
+      assert(session.sessionState.analyzer.extendedResolutionRules.containsSlice(orderedRules))
+      assert(session.sessionState.analyzer.postHocResolutionRules.containsSlice(orderedRules))
+      assert(session.sessionState.analyzer.extendedCheckRules.containsSlice(orderedCheckRules))
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules).filter(orderedRules.contains)
+        .containsSlice(orderedRules ++ orderedRules)) // The optimizer rules are duplicated
+      assert(session.sessionState.sqlParser === parser)
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions.myFunction._1).isDefined)
+      assert(session.sessionState.functionRegistry
+        .lookupFunction(MyExtensions2.myFunction._1).isDefined)
+    } finally {
+      stop(session)
+    }
+  }
+
+  test("allow an extension to be duplicated") {
+    val session = SparkSession.builder()
+      .master("local[1]")
+      .config("spark.sql.extensions", Seq(
+        classOf[MyExtensions].getCanonicalName,
+        classOf[MyExtensions].getCanonicalName).mkString(","))
+      .getOrCreate()
+    try {
+      assert(session.sessionState.planner.strategies.count(_ === MySparkStrategy(session)) === 2)
+      assert(session.sessionState.analyzer.extendedResolutionRules.count(_ === MyRule(session)) ===
+        2)
+      assert(session.sessionState.analyzer.postHocResolutionRules.count(_ === MyRule(session)) ===
+        2)
+      assert(session.sessionState.analyzer.extendedCheckRules.count(_ === MyCheckRule(session)) ===
+        2)
+      assert(session.sessionState.optimizer.batches.flatMap(_.rules)
+        .count(_ === MyRule(session)) === 4) // The optimizer rules are duplicated
+      val outerParser = session.sessionState.sqlParser
+      assert(outerParser.isInstanceOf[MyParser])
+      assert(outerParser.asInstanceOf[MyParser].delegate.isInstanceOf[MyParser])
       assert(session.sessionState.functionRegistry
         .lookupFunction(MyExtensions.myFunction._1).isDefined)
     } finally {
       stop(session)
     }
   }
+
+  test("use the last registered function name when there are duplicates") {
+    val session = SparkSession.builder()
+      .master("local[1]")
+      .config("spark.sql.extensions", Seq(
+        classOf[MyExtensions2].getCanonicalName,
+        classOf[MyExtensions2Duplicate].getCanonicalName).mkString(","))
+      .getOrCreate()
+    try {
+      val lastRegistered = session.sessionState.functionRegistry
+        .lookupFunction(FunctionIdentifier("myFunction2"))
+      assert(lastRegistered.isDefined)
+      assert(lastRegistered.get !== MyExtensions2.myFunction._2)
+      assert(lastRegistered.get === MyExtensions2Duplicate.myFunction._2)
+    } finally {
+      stop(session)
+    }
+  }
 }
 
 case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] {
@@ -151,14 +244,62 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends Pars
 object MyExtensions {
 
   val myFunction = (FunctionIdentifier("myFunction"),
-    new ExpressionInfo("noClass", "myDb", "myFunction", "usage", "extended usage" ),
-    (myArgs: Seq[Expression]) => Literal(5, IntegerType))
+    new ExpressionInfo("noClass", "myDb", "myFunction", "usage", "extended usage"),
+    (_: Seq[Expression]) => Literal(5, IntegerType))
 }
 
 class MyExtensions extends (SparkSessionExtensions => Unit) {
   def apply(e: SparkSessionExtensions): Unit = {
     e.injectPlannerStrategy(MySparkStrategy)
     e.injectResolutionRule(MyRule)
+    e.injectPostHocResolutionRule(MyRule)
+    e.injectCheckRule(MyCheckRule)
+    e.injectOptimizerRule(MyRule)
+    e.injectParser(MyParser)
     e.injectFunction(MyExtensions.myFunction)
   }
 }
+
+case class MyRule2(spark: SparkSession) extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan
+}
+
+case class MyCheckRule2(spark: SparkSession) extends (LogicalPlan => Unit) {
+  override def apply(plan: LogicalPlan): Unit = { }
+}
+
+case class MySparkStrategy2(spark: SparkSession) extends SparkStrategy {
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty
+}
+
+object MyExtensions2 {
+
+  val myFunction = (FunctionIdentifier("myFunction2"),
+    new ExpressionInfo("noClass", "myDb", "myFunction2", "usage", "extended usage"),
+    (_: Seq[Expression]) => Literal(5, IntegerType))
+}
+
+class MyExtensions2 extends (SparkSessionExtensions => Unit) {
+  def apply(e: SparkSessionExtensions): Unit = {
+    e.injectPlannerStrategy(MySparkStrategy2)
+    e.injectResolutionRule(MyRule2)
+    e.injectPostHocResolutionRule(MyRule2)
+    e.injectCheckRule(MyCheckRule2)
+    e.injectOptimizerRule(MyRule2)
+    e.injectParser((_: SparkSession, _: ParserInterface) => CatalystSqlParser)
+    e.injectFunction(MyExtensions2.myFunction)
+  }
+}
+
+object MyExtensions2Duplicate {
+
+  val myFunction = (FunctionIdentifier("myFunction2"),
+    new ExpressionInfo("noClass", "myDb", "myFunction2", "usage", "extended usage"),
+    (_: Seq[Expression]) => Literal(5, IntegerType))
+}
+
+class MyExtensions2Duplicate extends (SparkSessionExtensions => Unit) {
+  def apply(e: SparkSessionExtensions): Unit = {
+    e.injectFunction(MyExtensions2Duplicate.myFunction)
+  }
+}