This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 492d1b14c0d1 [SPARK-48782][SQL] Add support for executing procedures in catalogs
492d1b14c0d1 is described below
commit 492d1b14c0d19fa89b9ce9c0e48fc0e4c120b70c
Author: Anton Okolnychyi <[email protected]>
AuthorDate: Thu Sep 19 11:09:40 2024 +0200
[SPARK-48782][SQL] Add support for executing procedures in catalogs
### What changes were proposed in this pull request?
This PR adds support for executing procedures in catalogs.
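To add procedure support, a connector implements `ProcedureCatalog` and exposes procedures via `UnboundProcedure`/`BoundProcedure`. A minimal sketch of that contract (the `UnboundSum`/`Sum` objects are hypothetical, modeled on the new interfaces and the in-memory test catalog in this diff):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure}
import org.apache.spark.sql.connector.read.Scan
import org.apache.spark.sql.types.{DataTypes, StructType}

// Returned by ProcedureCatalog.loadProcedure(ident); Spark calls bind()
// with a StructType describing the actual CALL arguments.
object UnboundSum extends UnboundProcedure {
  override def name: String = "sum"
  override def description: String = "sums two integers"
  override def bind(inputType: StructType): BoundProcedure = Sum
}

object Sum extends BoundProcedure {
  override def name: String = "sum"
  override def description: String = "sums two integers"
  override def isDeterministic: Boolean = true
  // Spark validates, coerces, and rearranges CALL arguments against these parameters.
  override def parameters: Array[ProcedureParameter] = Array(
    ProcedureParameter.in("in1", DataTypes.IntegerType).build(),
    ProcedureParameter.in("in2", DataTypes.IntegerType).build())
  // Arguments arrive as a single InternalRow; each returned Scan is one result set.
  override def call(input: InternalRow): java.util.Iterator[Scan] = {
    val sum = input.getInt(0) + input.getInt(1) // a real procedure would surface `sum` via a LocalScan
    java.util.Collections.emptyIterator[Scan]()
  }
}
```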
### Why are the changes needed?
These changes are needed per the [discussed and voted](https://lists.apache.org/thread/w586jr53fxwk4pt9m94b413xyjr1v25m) SPIP tracked in [SPARK-44167](https://issues.apache.org/jira/browse/SPARK-44167).
### Does this PR introduce _any_ user-facing change?
Yes. This PR adds CALL commands.
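For example (the `sum` procedure comes from the new `ProcedureSuite`; the catalog behind `cat` must implement `ProcedureCatalog`):

```sql
CALL cat.ns.sum(5, 5);                -- positional arguments
CALL cat.ns.sum(in2 => 3, in1 => 5);  -- named arguments
CALL cat.ns.sum(3, in2 => 1);         -- mixed positional and named
```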
### How was this patch tested?
This PR comes with tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #47943 from aokolnychyi/spark-48782.
Authored-by: Anton Okolnychyi <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 6 +
docs/sql-ref-ansi-compliance.md | 1 +
.../spark/sql/catalyst/parser/SqlBaseLexer.g4 | 1 +
.../spark/sql/catalyst/parser/SqlBaseParser.g4 | 5 +
.../catalog/procedures/ProcedureParameter.java | 5 +
.../catalog/procedures/UnboundProcedure.java | 6 +
.../spark/sql/catalyst/analysis/Analyzer.scala | 65 +-
.../sql/catalyst/analysis/AnsiTypeCoercion.scala | 1 +
.../sql/catalyst/analysis/CheckAnalysis.scala | 8 +
.../spark/sql/catalyst/analysis/TypeCoercion.scala | 16 +
.../spark/sql/catalyst/analysis/package.scala | 6 +-
.../sql/catalyst/analysis/v2ResolutionPlans.scala | 17 +-
.../spark/sql/catalyst/parser/AstBuilder.scala | 22 +
.../plans/logical/ExecutableDuringAnalysis.scala | 28 +
.../plans/logical/FunctionBuilderBase.scala | 36 +-
.../sql/catalyst/plans/logical/MultiResult.scala | 30 +
.../sql/catalyst/plans/logical/v2Commands.scala | 67 ++-
.../sql/catalyst/rules/RuleIdCollection.scala | 1 +
.../spark/sql/catalyst/trees/TreePatterns.scala | 1 +
.../sql/connector/catalog/CatalogV2Implicits.scala | 7 +
.../spark/sql/errors/QueryCompilationErrors.scala | 7 +
.../sql/connector/catalog/InMemoryCatalog.scala | 19 +-
.../sql/catalyst/analysis/InvokeProcedures.scala | 71 +++
.../spark/sql/execution/MultiResultExec.scala | 36 ++
.../spark/sql/execution/SparkStrategies.scala | 2 +
.../spark/sql/execution/command/commands.scala | 11 +-
.../datasources/v2/DataSourceV2Strategy.scala | 6 +-
.../datasources/v2/ExplainOnlySparkPlan.scala | 38 ++
.../sql/internal/BaseSessionStateBuilder.scala | 3 +-
.../sql-tests/results/ansi/keywords.sql.out | 2 +
.../resources/sql-tests/results/keywords.sql.out | 1 +
.../spark/sql/connector/ProcedureSuite.scala | 654 +++++++++++++++++++++
.../ThriftServerWithSparkContextSuite.scala | 2 +-
.../spark/sql/hive/HiveSessionStateBuilder.scala | 3 +-
34 files changed, 1162 insertions(+), 22 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 6463cc2c12da..72985de6631f 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -1456,6 +1456,12 @@
],
"sqlState" : "2203G"
},
+ "FAILED_TO_LOAD_ROUTINE" : {
+ "message" : [
+ "Failed to load routine <routineName>."
+ ],
+ "sqlState" : "38000"
+ },
"FAILED_TO_PARSE_TOO_COMPLEX" : {
"message" : [
"The statement, including potential SQL functions and referenced views,
was too complex to parse.",
diff --git a/docs/sql-ref-ansi-compliance.md b/docs/sql-ref-ansi-compliance.md
index fff6906457f7..12dff1e325c4 100644
--- a/docs/sql-ref-ansi-compliance.md
+++ b/docs/sql-ref-ansi-compliance.md
@@ -426,6 +426,7 @@ Below is a list of all the keywords in Spark SQL.
|BY|non-reserved|non-reserved|reserved|
|BYTE|non-reserved|non-reserved|non-reserved|
|CACHE|non-reserved|non-reserved|non-reserved|
+|CALL|reserved|non-reserved|reserved|
|CALLED|non-reserved|non-reserved|non-reserved|
|CASCADE|non-reserved|non-reserved|non-reserved|
|CASE|reserved|non-reserved|reserved|
diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
index e704f9f58b96..de28041acd41 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseLexer.g4
@@ -146,6 +146,7 @@ BUCKETS: 'BUCKETS';
BY: 'BY';
BYTE: 'BYTE';
CACHE: 'CACHE';
+CALL: 'CALL';
CALLED: 'CALLED';
CASCADE: 'CASCADE';
CASE: 'CASE';
diff --git a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4 b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
index f13dde773496..e591a43b84d1 100644
--- a/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
+++ b/sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4
@@ -298,6 +298,10 @@ statement
        LEFT_PAREN columns=multipartIdentifierPropertyList RIGHT_PAREN
        (OPTIONS options=propertyList)?                                   #createIndex
    | DROP INDEX (IF EXISTS)? identifier ON TABLE? identifierReference    #dropIndex
+   | CALL identifierReference
+       LEFT_PAREN
+       (functionArgument (COMMA functionArgument)*)?
+       RIGHT_PAREN                                                       #call
    | unsupportedHiveNativeCommands .*?                                   #failNativeCommand
    ;
@@ -1851,6 +1855,7 @@ nonReserved
| BY
| BYTE
| CACHE
+ | CALL
| CALLED
| CASCADE
| CASE
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java
index 90d531ae2189..18c76833c587 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/ProcedureParameter.java
@@ -32,6 +32,11 @@ import static org.apache.spark.sql.connector.catalog.procedures.ProcedureParamet
*/
@Evolving
public interface ProcedureParameter {
+ /**
+ * A field metadata key that indicates whether an argument is passed by name.
+ */
+ String BY_NAME_METADATA_KEY = "BY_NAME";
+
/**
* Creates a builder for an IN procedure parameter.
*
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java
index ee9a09055243..1a91fd21bf07 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/procedures/UnboundProcedure.java
@@ -35,6 +35,12 @@ public interface UnboundProcedure extends Procedure {
  * validate if the input types are compatible while binding or delegate that to Spark. Regardless,
  * Spark will always perform the final validation of the arguments and rearrange them as needed
  * based on {@link BoundProcedure#parameters() reported parameters}.
+ * <p>
+ * The provided {@code inputType} is based on the procedure arguments. If an argument is passed
+ * by name, its metadata will indicate this with {@link ProcedureParameter#BY_NAME_METADATA_KEY}
+ * set to {@code true}. In such cases, the field name will match the name of the target procedure
+ * parameter. If the argument is not named, {@link ProcedureParameter#BY_NAME_METADATA_KEY} will
+ * not be set and the name will be assigned randomly.
  *
  * @param inputType the input types to bind to
  * @return the bound procedure that is most suitable for the given input types
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 0164af945ca2..9e5b1d1254c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -26,7 +26,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Random, Success, Try}
-import org.apache.spark.{SparkException, SparkUnsupportedOperationException}
+import org.apache.spark.{SparkException, SparkThrowable, SparkUnsupportedOperationException}
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
@@ -50,6 +50,7 @@ import org.apache.spark.sql.connector.catalog.{View => _, _}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition}
import org.apache.spark.sql.connector.catalog.functions.{AggregateFunction => V2AggregateFunction, ScalarFunction, UnboundFunction}
+import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure}
import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
@@ -310,6 +311,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
ExtractGenerator ::
ResolveGenerate ::
ResolveFunctions ::
+ ResolveProcedures ::
+ BindProcedures ::
ResolveTableSpec ::
ResolveAliases ::
ResolveSubquery ::
@@ -2611,6 +2614,66 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
}
}
+  /**
+   * A rule that resolves procedures.
+   */
+  object ResolveProcedures extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+      _.containsPattern(UNRESOLVED_PROCEDURE), ruleId) {
+      case Call(UnresolvedProcedure(CatalogAndIdentifier(catalog, ident)), args, execute) =>
+        val procedureCatalog = catalog.asProcedureCatalog
+        val procedure = load(procedureCatalog, ident)
+        Call(ResolvedProcedure(procedureCatalog, ident, procedure), args, execute)
+    }
+
+    private def load(catalog: ProcedureCatalog, ident: Identifier): UnboundProcedure = {
+      try {
+        catalog.loadProcedure(ident)
+      } catch {
+        case e: Exception if !e.isInstanceOf[SparkThrowable] =>
+          val nameParts = catalog.name +: ident.asMultipartIdentifier
+          throw QueryCompilationErrors.failedToLoadRoutineError(nameParts, e)
+      }
+    }
+  }
+
+  /**
+   * A rule that binds procedures to the input types and rearranges arguments as needed.
+   */
+  object BindProcedures extends Rule[LogicalPlan] {
+    def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case Call(ResolvedProcedure(catalog, ident, unbound: UnboundProcedure), args, execute)
+          if args.forall(_.resolved) =>
+        val inputType = extractInputType(args)
+        val bound = unbound.bind(inputType)
+        validateParameterModes(bound)
+        val rearrangedArgs = NamedParametersSupport.defaultRearrange(bound, args)
+        Call(ResolvedProcedure(catalog, ident, bound), rearrangedArgs, execute)
+    }
+
+    private def extractInputType(args: Seq[Expression]): StructType = {
+      val fields = args.zipWithIndex.map {
+        case (NamedArgumentExpression(name, value), _) =>
+          StructField(name, value.dataType, value.nullable, byNameMetadata)
+        case (arg, index) =>
+          StructField(s"param$index", arg.dataType, arg.nullable)
+      }
+      StructType(fields)
+    }
+
+    private def byNameMetadata: Metadata = {
+      new MetadataBuilder()
+        .putBoolean(ProcedureParameter.BY_NAME_METADATA_KEY, value = true)
+        .build()
+    }
+
+    private def validateParameterModes(procedure: BoundProcedure): Unit = {
+      procedure.parameters.find(_.mode != ProcedureParameter.Mode.IN).foreach { param =>
+        throw SparkException.internalError(s"Unsupported parameter mode: ${param.mode}")
+      }
+    }
+  }
+
/**
* This rule resolves and rewrites subqueries inside expressions.
*
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
index 17b1c4e249f5..3afe0ec8e9a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AnsiTypeCoercion.scala
@@ -77,6 +77,7 @@ object AnsiTypeCoercion extends TypeCoercionBase {
override def typeCoercionRules: List[Rule[LogicalPlan]] =
UnpivotCoercion ::
WidenSetOperationTypes ::
+ ProcedureArgumentCoercion ::
new AnsiCombinedTypeCoercionRule(
CollationTypeCasts ::
InConversion ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 752ff49e1f90..5a9d5cd87ecc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -676,6 +676,14 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog with QueryErrorsB
            varName,
            c.defaultExpr.originalSQL)
+        case c: Call if c.resolved && c.bound && c.checkArgTypes().isFailure =>
+          c.checkArgTypes() match {
+            case mismatch: TypeCheckResult.DataTypeMismatch =>
+              c.dataTypeMismatch("CALL", mismatch)
+            case _ =>
+              throw SparkException.internalError("Invalid input for procedure")
+          }
+
        case _ => // Falls back to the following checks
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 08c5b3531b4c..5983346ff1e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.AlwaysProcess
import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.types.{AbstractArrayType, AbstractMapType, AbstractStringType, StringTypeAnyCollation}
@@ -202,6 +203,20 @@ abstract class TypeCoercionBase {
}
}
+  /**
+   * A type coercion rule that implicitly casts procedure arguments to expected types.
+   */
+  object ProcedureArgumentCoercion extends Rule[LogicalPlan] {
+    override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+      case c @ Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args, _) if c.resolved =>
+        val expectedDataTypes = procedure.parameters.map(_.dataType)
+        val coercedArgs = args.zip(expectedDataTypes).map {
+          case (arg, expectedType) => implicitCast(arg, expectedType).getOrElse(arg)
+        }
+        c.copy(args = coercedArgs)
+    }
+  }
+
/**
* Widens the data types of the [[Unpivot]] values.
*/
@@ -838,6 +853,7 @@ object TypeCoercion extends TypeCoercionBase {
override def typeCoercionRules: List[Rule[LogicalPlan]] =
UnpivotCoercion ::
WidenSetOperationTypes ::
+ ProcedureArgumentCoercion ::
new CombinedTypeCoercionRule(
CollationTypeCasts ::
InConversion ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
index c0689eb12167..daab9e4d78bf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
@@ -67,9 +67,13 @@ package object analysis {
}
    def dataTypeMismatch(expr: Expression, mismatch: DataTypeMismatch): Nothing = {
+     dataTypeMismatch(toSQLExpr(expr), mismatch)
+   }
+
+   def dataTypeMismatch(sqlExpr: String, mismatch: DataTypeMismatch): Nothing = {
      throw new AnalysisException(
        errorClass = s"DATATYPE_MISMATCH.${mismatch.errorSubClass}",
-       messageParameters = mismatch.messageParameters + ("sqlExpr" -> toSQLExpr(expr)),
+       messageParameters = mismatch.messageParameters + ("sqlExpr" -> sqlExpr),
        origin = t.origin)
    }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala
index ecdf40e87a89..dee78b8f03af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala
@@ -23,13 +23,14 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.{Attribute, LeafExpression, Unevaluable}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics}
-import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_FUNC, UNRESOLVED_PROCEDURE}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
-import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, Table, TableCatalog}
+import org.apache.spark.sql.connector.catalog.{CatalogPlugin, FunctionCatalog, Identifier, ProcedureCatalog, Table, TableCatalog}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
+import org.apache.spark.sql.connector.catalog.procedures.Procedure
import org.apache.spark.sql.types.{DataType, StructField}
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._
@@ -135,6 +136,12 @@ case class UnresolvedFunctionName(
case class UnresolvedIdentifier(nameParts: Seq[String], allowTemp: Boolean = false)
  extends UnresolvedLeafNode
+/**
+ * A procedure identifier that should be resolved into [[ResolvedProcedure]].
+ */
+case class UnresolvedProcedure(nameParts: Seq[String]) extends UnresolvedLeafNode {
+  final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_PROCEDURE)
+}
/**
* A resolved leaf node whose statistics has no meaning.
@@ -192,6 +199,12 @@ case class ResolvedFieldName(path: Seq[String], field: StructField) extends Fiel
case class ResolvedFieldPosition(position: ColumnPosition) extends FieldPosition
+case class ResolvedProcedure(
+ catalog: ProcedureCatalog,
+ ident: Identifier,
+ procedure: Procedure) extends LeafNodeWithoutStats {
+ override def output: Seq[Attribute] = Nil
+}
/**
* A plan containing resolved persistent views.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index cb0e0e35c370..52529bb4b789 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -5697,6 +5697,28 @@ class AstBuilder extends DataTypeAstBuilder
ctx.EXISTS != null)
}
+ /**
+ * Creates a plan for invoking a procedure.
+ *
+ * For example:
+ * {{{
+ * CALL multi_part_name(v1, v2, ...);
+ * CALL multi_part_name(v1, param2 => v2, ...);
+ * CALL multi_part_name(param1 => v1, param2 => v2, ...);
+ * }}}
+ */
+  override def visitCall(ctx: CallContext): LogicalPlan = withOrigin(ctx) {
+    val procedure = withIdentClause(ctx.identifierReference, UnresolvedProcedure)
+    val args = ctx.functionArgument.asScala.map {
+      case expr if expr.namedArgumentExpression != null =>
+        val namedExpr = expr.namedArgumentExpression
+        NamedArgumentExpression(namedExpr.key.getText, expression(namedExpr.value))
+      case expr =>
+        expression(expr)
+    }.toSeq
+    Call(procedure, args)
+  }
+
/**
* Create a TimestampAdd expression.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala
new file mode 100644
index 000000000000..dc8dbf701f6a
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ExecutableDuringAnalysis.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+/**
+ * A logical plan node that requires execution during analysis.
+ */
+trait ExecutableDuringAnalysis extends LogicalPlan {
+ /**
+ * Returns the logical plan node that should be used for EXPLAIN.
+ */
+ def stageForExplain(): LogicalPlan
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
index 4701f4ea1e17..75b2fcd3a5f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/FunctionBuilderBase.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions.{Expression, NamedArgumentExpression}
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns
+import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.util.ArrayImplicits._
@@ -122,12 +124,32 @@ object NamedParametersSupport {
      functionSignature: FunctionSignature,
      args: Seq[Expression],
      functionName: String): Seq[Expression] = {
-    val parameters: Seq[InputParameter] = functionSignature.parameters
+    defaultRearrange(functionName, functionSignature.parameters, args)
+  }
+
+  final def defaultRearrange(procedure: BoundProcedure, args: Seq[Expression]): Seq[Expression] = {
+    defaultRearrange(
+      procedure.name,
+      procedure.parameters.map(toInputParameter).toSeq,
+      args)
+  }
+
+  private def toInputParameter(param: ProcedureParameter): InputParameter = {
+    val defaultValue = Option(param.defaultValueExpression).map { expr =>
+      ResolveDefaultColumns.analyze(param.name, param.dataType, expr, "CALL")
+    }
+    InputParameter(param.name, defaultValue)
+  }
+
+  private def defaultRearrange(
+      routineName: String,
+      parameters: Seq[InputParameter],
+      args: Seq[Expression]): Seq[Expression] = {
    if (parameters.dropWhile(_.default.isEmpty).exists(_.default.isEmpty)) {
-      throw QueryCompilationErrors.unexpectedRequiredParameter(functionName, parameters)
+      throw QueryCompilationErrors.unexpectedRequiredParameter(routineName, parameters)
    }
-    val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, functionName)
+    val (positionalArgs, namedArgs) = splitAndCheckNamedArguments(args, routineName)
    val namedParameters: Seq[InputParameter] = parameters.drop(positionalArgs.size)
    // The following loop checks for the following:
@@ -140,12 +162,12 @@ object NamedParametersSupport {
    namedArgs.foreach { namedArg =>
      val parameterName = namedArg.key
      if (!parameterNamesSet.contains(parameterName)) {
-        throw QueryCompilationErrors.unrecognizedParameterName(functionName, namedArg.key,
+        throw QueryCompilationErrors.unrecognizedParameterName(routineName, namedArg.key,
          parameterNamesSet.toSeq)
      }
      if (positionalParametersSet.contains(parameterName)) {
        throw QueryCompilationErrors.positionalAndNamedArgumentDoubleReference(
-          functionName, namedArg.key)
+          routineName, namedArg.key)
      }
    }
@@ -154,7 +176,7 @@ object NamedParametersSupport {
    val validParameterSizes =
      Array.range(parameters.count(_.default.isEmpty), parameters.size + 1).toImmutableArraySeq
    throw QueryCompilationErrors.wrongNumArgsError(
-      functionName, validParameterSizes, args.length)
+      routineName, validParameterSizes, args.length)
  }
  // This constructs a map from argument name to value for argument rearrangement.
@@ -168,7 +190,7 @@ object NamedParametersSupport {
    namedArgMap.getOrElse(
      param.name,
      if (param.default.isEmpty) {
-        throw QueryCompilationErrors.requiredParameterNotFound(functionName, param.name, index)
+        throw QueryCompilationErrors.requiredParameterNotFound(routineName, param.name, index)
      } else {
        param.default.get
      }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MultiResult.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MultiResult.scala
new file mode 100644
index 000000000000..f249e5c87eba
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/MultiResult.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+case class MultiResult(children: Seq[LogicalPlan]) extends LogicalPlan {
+
+  override def output: Seq[Attribute] = children.lastOption.map(_.output).getOrElse(Nil)
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[LogicalPlan]): MultiResult = {
+ copy(children = newChildren)
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
index fdd43404e1d9..b465e0e11612 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala
@@ -19,17 +19,22 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.{SparkIllegalArgumentException, SparkUnsupportedOperationException}
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, UnresolvedException, ViewSchemaMode}
+import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, TypeCheckResult, UnresolvedException, UnresolvedProcedure, ViewSchemaMode}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.catalog.FunctionResource
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, MetadataAttribute, NamedExpression, UnaryExpression, Unevaluable, V2ExpressionUtils}
import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
import org.apache.spark.sql.catalyst.trees.BinaryLike
-import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, RowDeltaUtils, WriteDeltaProjections}
+import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, truncatedString, CharVarcharUtils, RowDeltaUtils, WriteDeltaProjections}
+import org.apache.spark.sql.catalyst.util.TypeUtils.{ordinalNumber, toSQLExpr}
import org.apache.spark.sql.connector.catalog._
+import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper}
+import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.write.{DeltaWrite, RowLevelOperation, RowLevelOperationTable, SupportsDelta, Write}
+import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation
import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType}
import org.apache.spark.util.ArrayImplicits._
@@ -1571,3 +1576,61 @@ case class SetVariable(
  override protected def withNewChildInternal(newChild: LogicalPlan): SetVariable =
    copy(sourceQuery = newChild)
}
+
+/**
+ * The logical plan of the CALL statement.
+ */
+case class Call(
+ procedure: LogicalPlan,
+ args: Seq[Expression],
+ execute: Boolean = true)
+ extends UnaryNode with ExecutableDuringAnalysis {
+
+ override def output: Seq[Attribute] = Nil
+
+ override def child: LogicalPlan = procedure
+
+ def bound: Boolean = procedure match {
+ case ResolvedProcedure(_, _, _: BoundProcedure) => true
+ case _ => false
+ }
+
+ def checkArgTypes(): TypeCheckResult = {
+ require(resolved && bound, "can check arg types only after resolution and binding")
+
+ val params = procedure match {
+ case ResolvedProcedure(_, _, bound: BoundProcedure) => bound.parameters
+ }
+ require(args.length == params.length, "number of args and params must match after binding")
+
+ args.zip(params).zipWithIndex.collectFirst {
+ case ((arg, param), idx)
+ if !DataType.equalsIgnoreCompatibleNullability(arg.dataType, param.dataType) =>
+ DataTypeMismatch(
+ errorSubClass = "UNEXPECTED_INPUT_TYPE",
+ messageParameters = Map(
+ "paramIndex" -> ordinalNumber(idx),
+ "requiredType" -> toSQLType(param.dataType),
+ "inputSql" -> toSQLExpr(arg),
+ "inputType" -> toSQLType(arg.dataType)))
+ }.getOrElse(TypeCheckSuccess)
+ }
+
+ override def simpleString(maxFields: Int): String = {
+ val name = procedure match {
+ case ResolvedProcedure(catalog, ident, _) =>
+ s"${quoteIfNeeded(catalog.name)}.${ident.quoted}"
+ case UnresolvedProcedure(nameParts) =>
+ nameParts.quoted
+ }
+ val argsString = truncatedString(args, ", ", maxFields)
+ s"Call $name($argsString)"
+ }
+
+ override def stageForExplain(): Call = {
+ copy(execute = false)
+ }
+
+ override protected def withNewChildInternal(newChild: LogicalPlan): Call =
+ copy(procedure = newChild)
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index c70b43f0db17..b5556cbae7cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -54,6 +54,7 @@ object RuleIdCollection {
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveDeserializer" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions" ::
+ "org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveProcedures" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGenerate" ::
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveGroupingAnalytics" ::
"org.apache.spark.sql.catalyst.analysis.ResolveHigherOrderFunctions" ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 826ac52c2b81..0f1c98b53e0b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -157,6 +157,7 @@ object TreePattern extends Enumeration {
// Unresolved Plan patterns (Alphabetically ordered)
val UNRESOLVED_FUNC: Value = Value
+ val UNRESOLVED_PROCEDURE: Value = Value
val UNRESOLVED_SUBQUERY_COLUMN_ALIAS: Value = Value
val UNRESOLVED_TABLE_VALUED_FUNCTION: Value = Value
val UNRESOLVED_TRANSPOSE: Value = Value
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
index 65bdae85be12..282350dda67d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Implicits.scala
@@ -126,6 +126,13 @@ private[sql] object CatalogV2Implicits {
      case _ =>
        throw QueryCompilationErrors.missingCatalogAbilityError(plugin, "functions")
    }
+
+    def asProcedureCatalog: ProcedureCatalog = plugin match {
+      case procedureCatalog: ProcedureCatalog =>
+        procedureCatalog
+      case _ =>
+        throw QueryCompilationErrors.missingCatalogAbilityError(plugin, "procedures")
+    }
  }
implicit class NamespaceHelper(namespace: Array[String]) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index ad0e1d07bf93..0b5255e95f07 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -853,6 +853,13 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
      origin = origin)
  }
+  def failedToLoadRoutineError(nameParts: Seq[String], e: Exception): Throwable = {
+    new AnalysisException(
+      errorClass = "FAILED_TO_LOAD_ROUTINE",
+      messageParameters = Map("routineName" -> toSQLId(nameParts)),
+      cause = Some(e))
+  }
+
  def unresolvedRoutineError(name: FunctionIdentifier, searchPath: Seq[String]): Throwable = {
new AnalysisException(
errorClass = "UNRESOLVED_ROUTINE",
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala
index 8d8d2317f098..411a88b8765f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryCatalog.scala
@@ -24,10 +24,13 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.sql.catalyst.analysis.{NoSuchFunctionException, NoSuchNamespaceException}
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction
+import org.apache.spark.sql.connector.catalog.procedures.UnboundProcedure
-class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog {
+class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog with ProcedureCatalog {
  protected val functions: util.Map[Identifier, UnboundFunction] =
    new ConcurrentHashMap[Identifier, UnboundFunction]()
+  protected val procedures: util.Map[Identifier, UnboundProcedure] =
+    new ConcurrentHashMap[Identifier, UnboundProcedure]()
override protected def allNamespaces: Seq[Seq[String]] = {
(tables.keySet.asScala.map(_.namespace.toSeq) ++
@@ -63,4 +66,18 @@ class InMemoryCatalog extends InMemoryTableCatalog with FunctionCatalog {
  def clearFunctions(): Unit = {
    functions.clear()
  }
+
+  override def loadProcedure(ident: Identifier): UnboundProcedure = {
+    val procedure = procedures.get(ident)
+    if (procedure == null) throw new RuntimeException("Procedure not found: " + ident)
+    procedure
+  }
+
+  def createProcedure(ident: Identifier, procedure: UnboundProcedure): UnboundProcedure = {
+    procedures.put(ident, procedure)
+  }
+
+  def clearProcedures(): Unit = {
+    procedures.clear()
+  }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala
new file mode 100644
index 000000000000..c7320d350a7f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/InvokeProcedures.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import scala.jdk.CollectionConverters.IteratorHasAsScala
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Expression, GenericInternalRow}
+import org.apache.spark.sql.catalyst.plans.logical.{Call, LocalRelation, LogicalPlan, MultiResult}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.types.DataTypeUtils
+import org.apache.spark.sql.connector.catalog.procedures.BoundProcedure
+import org.apache.spark.sql.connector.read.{LocalScan, Scan}
+import org.apache.spark.util.ArrayImplicits._
+
+class InvokeProcedures(session: SparkSession) extends Rule[LogicalPlan] {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
+    case c: Call if c.resolved && c.bound && c.execute && c.checkArgTypes().isSuccess =>
+      session.sessionState.optimizer.execute(c) match {
+        case Call(ResolvedProcedure(_, _, procedure: BoundProcedure), args, _) =>
+          invoke(procedure, args)
+        case _ =>
+          throw SparkException.internalError("Unexpected plan for optimized CALL statement")
+      }
+  }
+
+  private def invoke(procedure: BoundProcedure, args: Seq[Expression]): LogicalPlan = {
+    val input = toInternalRow(args)
+    val scanIterator = procedure.call(input)
+    val relations = scanIterator.asScala.map(toRelation).toSeq
+    relations match {
+      case Nil => LocalRelation(Nil)
+      case Seq(relation) => relation
+      case _ => MultiResult(relations)
+    }
+  }
+
+  private def toRelation(scan: Scan): LogicalPlan = scan match {
+    case s: LocalScan =>
+      val attrs = DataTypeUtils.toAttributes(s.readSchema)
+      val data = s.rows.toImmutableArraySeq
+      LocalRelation(attrs, data)
+    case _ =>
+      throw SparkException.internalError(
+        s"Only local scans are temporarily supported as procedure output: ${scan.getClass.getName}")
+  }
+
+ private def toInternalRow(args: Seq[Expression]): InternalRow = {
+ require(args.forall(_.foldable), "args must be foldable")
+ val values = args.map(_.eval()).toArray
+ new GenericInternalRow(values)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala
new file mode 100644
index 000000000000..c2b12b053c92
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/MultiResultExec.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+case class MultiResultExec(children: Seq[SparkPlan]) extends SparkPlan {
+
+  override def output: Seq[Attribute] = children.lastOption.map(_.output).getOrElse(Nil)
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ children.lastOption.map(_.execute()).getOrElse(sparkContext.emptyRDD)
+ }
+
+ override protected def withNewChildrenInternal(
+ newChildren: IndexedSeq[SparkPlan]): MultiResultExec = {
+ copy(children = newChildren)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6d940a30619f..aee735e48fc5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -1041,6 +1041,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
      case WriteFiles(child, fileFormat, partitionColumns, bucket, options, staticPartitions) =>
        WriteFilesExec(planLater(child), fileFormat, partitionColumns, bucket, options,
          staticPartitions) :: Nil
+ case MultiResult(children) =>
+ MultiResultExec(children.map(planLater)) :: Nil
case _ => Nil
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
index ea2736b2c126..ea9d53190546 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, SupervisingCommand}
+import org.apache.spark.sql.catalyst.plans.logical.{Command, ExecutableDuringAnalysis, LogicalPlan, SupervisingCommand}
import org.apache.spark.sql.catalyst.trees.{LeafLike, UnaryLike}
import org.apache.spark.sql.connector.ExternalCommandRunner
import org.apache.spark.sql.execution.{CommandExecutionMode, ExplainMode, LeafExecNode, SparkPlan, UnaryExecNode}
@@ -165,14 +165,19 @@ case class ExplainCommand(
  // Run through the optimizer to generate the physical plan.
  override def run(sparkSession: SparkSession): Seq[Row] = try {
-    val outputString = sparkSession.sessionState.executePlan(logicalPlan, CommandExecutionMode.SKIP)
-      .explainString(mode)
+    val stagedLogicalPlan = stageForAnalysis(logicalPlan)
+    val qe = sparkSession.sessionState.executePlan(stagedLogicalPlan, CommandExecutionMode.SKIP)
+    val outputString = qe.explainString(mode)
    Seq(Row(outputString))
  } catch { case NonFatal(cause) =>
    ("Error occurred during query planning: \n" + cause.getMessage).split("\n")
      .map(Row(_)).toImmutableArraySeq
  }
+  private def stageForAnalysis(plan: LogicalPlan): LogicalPlan = plan transform {
+    case p: ExecutableDuringAnalysis => p.stageForExplain()
+  }
+
  def withTransformedSupervisedPlan(transformer: LogicalPlan => LogicalPlan): LogicalPlan =
    copy(logicalPlan = transformer(logicalPlan))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index d7f46c32f99a..76cd33b815ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -32,8 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, DynamicPruning
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn,
-  IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder}
+import org.apache.spark.sql.catalyst.util.{toPrettySQL, GeneratedColumn, IdentityColumn, ResolveDefaultColumns, V2ExpressionBuilder}
import org.apache.spark.sql.connector.catalog.{Identifier, StagingTableCatalog, SupportsDeleteV2, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, Table, TableCapability, TableCatalog, TruncatableTable}
import org.apache.spark.sql.connector.catalog.index.SupportsIndex
import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue}
@@ -554,6 +553,9 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
systemScope,
pattern) :: Nil
+ case c: Call =>
+ ExplainOnlySparkPlan(c) :: Nil
+
case _ => Nil
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala
new file mode 100644
index 000000000000..bbf56eaa7118
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ExplainOnlySparkPlan.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.LeafLike
+import org.apache.spark.sql.execution.SparkPlan
+
+case class ExplainOnlySparkPlan(toExplain: LogicalPlan) extends SparkPlan with LeafLike[SparkPlan] {
+
+ override def output: Seq[Attribute] = Nil
+
+ override def simpleString(maxFields: Int): String = {
+ toExplain.simpleString(maxFields)
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException()
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index a2539828733f..0d0258f11efb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.internal
import org.apache.spark.annotation.Unstable
import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _}
import org.apache.spark.sql.artifact.ArtifactManager
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, FunctionRegistry, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose, TableFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.optimizer.Optimizer
@@ -206,6 +206,7 @@ abstract class BaseSessionStateBuilder(
ResolveWriteToStream +:
new EvalSubqueriesForTimeTravel +:
new ResolveTranspose(session) +:
+ new InvokeProcedures(session) +:
customResolutionRules
override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =
diff --git a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
index 6497a46c68cc..7c694503056a 100644
--- a/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/ansi/keywords.sql.out
@@ -32,6 +32,7 @@ BUCKETS false
BY false
BYTE false
CACHE false
+CALL true
CALLED false
CASCADE false
CASE true
@@ -378,6 +379,7 @@ ANY
AS
AUTHORIZATION
BOTH
+CALL
CASE
CAST
CHECK
diff --git a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
index 0dfd62599afa..2c16d961b131 100644
--- a/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/keywords.sql.out
@@ -32,6 +32,7 @@ BUCKETS false
BY false
BYTE false
CACHE false
+CALL false
CALLED false
CASCADE false
CASE false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala
new file mode 100644
index 000000000000..e39a1b7ea340
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/ProcedureSuite.scala
@@ -0,0 +1,654 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connector
+
+import java.util.Collections
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.{SPARK_DOC_ROOT, SparkException, SparkNumberFormatException}
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
+import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, Identifier, InMemoryCatalog}
+import org.apache.spark.sql.connector.catalog.procedures.{BoundProcedure, ProcedureParameter, UnboundProcedure}
+import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter.Mode
+import org.apache.spark.sql.connector.catalog.procedures.ProcedureParameter.Mode.{IN, INOUT, OUT}
+import org.apache.spark.sql.connector.read.{LocalScan, Scan}
+import org.apache.spark.sql.errors.DataTypeErrors.{toSQLType, toSQLValue}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types.{DataType, DataTypes, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+class ProcedureSuite extends QueryTest with SharedSparkSession with BeforeAndAfter {
+
+ before {
+ spark.conf.set(s"spark.sql.catalog.cat", classOf[InMemoryCatalog].getName)
+ }
+
+ after {
+ spark.sessionState.catalogManager.reset()
+ spark.sessionState.conf.unsetConf(s"spark.sql.catalog.cat")
+ }
+
+ private def catalog: InMemoryCatalog = {
+ val catalog = spark.sessionState.catalogManager.catalog("cat")
+ catalog.asInstanceOf[InMemoryCatalog]
+ }
+
+ test("position arguments") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ checkAnswer(sql("CALL cat.ns.sum(5, 5)"), Row(10) :: Nil)
+ }
+
+ test("named arguments") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ checkAnswer(sql("CALL cat.ns.sum(in2 => 3, in1 => 5)"), Row(8) :: Nil)
+ }
+
+ test("position and named arguments") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ checkAnswer(sql("CALL cat.ns.sum(3, in2 => 1)"), Row(4) :: Nil)
+ }
+
+ test("foldable expressions") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ checkAnswer(sql("CALL cat.ns.sum(1 + 1, in2 => 2)"), Row(4) :: Nil)
+ checkAnswer(sql("CALL cat.ns.sum(in2 => 1, in1 => 2 + 1)"), Row(4) :: Nil)
+ checkAnswer(sql("CALL cat.ns.sum((1 + 1) * 2, in2 => (2 + 1) / 3)"),
Row(5) :: Nil)
+ }
+
+ test("type coercion") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundLongSum)
+ checkAnswer(sql("CALL cat.ns.sum(1, 2)"), Row(3) :: Nil)
+ checkAnswer(sql("CALL cat.ns.sum(1L, 2)"), Row(3) :: Nil)
+ checkAnswer(sql("CALL cat.ns.sum(1, 2L)"), Row(3) :: Nil)
+ }
+
+ test("multiple output rows") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "complex"),
UnboundComplexProcedure)
+ checkAnswer(
+ sql("CALL cat.ns.complex('X', 'Y', 3)"),
+ Row(1, "X1", "Y1") :: Row(2, "X2", "Y2") :: Row(3, "X3", "Y3") :: Nil)
+ }
+
+ test("parameters with default values") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "complex"),
UnboundComplexProcedure)
+ checkAnswer(sql("CALL cat.ns.complex()"), Row(1, "A1", "B1") :: Nil)
+ checkAnswer(sql("CALL cat.ns.complex('X', 'Y')"), Row(1, "X1", "Y1") ::
Nil)
+ }
+
+ test("parameters with invalid default values") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"),
UnboundInvalidDefaultProcedure)
+ checkError(
+ exception = intercept[AnalysisException](
+ sql("CALL cat.ns.sum()")
+ ),
+ condition = "INVALID_DEFAULT_VALUE.DATA_TYPE",
+ parameters = Map(
+ "statement" -> "CALL",
+ "colName" -> toSQLId("in2"),
+ "defaultValue" -> toSQLValue("B"),
+ "expectedType" -> toSQLType("INT"),
+ "actualType" -> toSQLType("STRING")))
+ }
+
+ test("IDENTIFIER") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ checkAnswer(
+ spark.sql("CALL IDENTIFIER(:p1)(1, 2)", Map("p1" -> "cat.ns.sum")),
+ Row(3) :: Nil)
+ }
+
+ test("parameterized statements") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ checkAnswer(
+ spark.sql("CALL cat.ns.sum(?, ?)", Array(2, 3)),
+ Row(5) :: Nil)
+ }
+
+ test("undefined procedure") {
+ checkError(
+ exception = intercept[AnalysisException](
+ sql("CALL cat.non_exist(1, 2)")
+ ),
+ sqlState = Some("38000"),
+ condition = "FAILED_TO_LOAD_ROUTINE",
+ parameters = Map("routineName" -> "`cat`.`non_exist`")
+ )
+ }
+
+ test("non-procedure catalog") {
+ withSQLConf("spark.sql.catalog.testcat" ->
classOf[BasicInMemoryTableCatalog].getName) {
+ checkError(
+ exception = intercept[AnalysisException](
+ sql("CALL testcat.procedure(1, 2)")
+ ),
+ condition = "_LEGACY_ERROR_TEMP_1184",
+ parameters = Map("plugin" -> "testcat", "ability" -> "procedures")
+ )
+ }
+ }
+
+ test("too many arguments") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ checkError(
+ exception = intercept[AnalysisException](
+ sql("CALL cat.ns.sum(1, 2, 3)")
+ ),
+ condition = "WRONG_NUM_ARGS.WITHOUT_SUGGESTION",
+ parameters = Map(
+ "functionName" -> toSQLId("sum"),
+ "expectedNum" -> "2",
+ "actualNum" -> "3",
+ "docroot" -> SPARK_DOC_ROOT))
+ }
+
+ test("custom default catalog") {
+ withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "cat") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ val df = sql("CALL ns.sum(1, 2)")
+ checkAnswer(df, Row(3) :: Nil)
+ }
+ }
+
+ test("custom default catalog and namespace") {
+ withSQLConf(SQLConf.DEFAULT_CATALOG.key -> "cat") {
+ catalog.createNamespace(Array("ns"), Collections.emptyMap)
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ sql("USE ns")
+ val df = sql("CALL sum(1, 2)")
+ checkAnswer(df, Row(3) :: Nil)
+ }
+ }
+
+ test("required parameter not found") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("CALL cat.ns.sum()")
+ },
+ condition = "REQUIRED_PARAMETER_NOT_FOUND",
+ parameters = Map(
+ "routineName" -> toSQLId("sum"),
+ "parameterName" -> toSQLId("in1"),
+ "index" -> "0"))
+ }
+
+ test("conflicting position and named parameter assignments") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("CALL cat.ns.sum(1, in1 => 2)")
+ },
+ condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.BOTH_POSITIONAL_AND_NAMED",
+ parameters = Map(
+ "routineName" -> toSQLId("sum"),
+ "parameterName" -> toSQLId("in1")))
+ }
+
+ test("duplicate named parameter assignments") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("CALL cat.ns.sum(in1 => 1, in1 => 2)")
+ },
+ condition = "DUPLICATE_ROUTINE_PARAMETER_ASSIGNMENT.DOUBLE_NAMED_ARGUMENT_REFERENCE",
+ parameters = Map(
+ "routineName" -> toSQLId("sum"),
+ "parameterName" -> toSQLId("in1")))
+ }
+
+ test("unknown parameter name") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("CALL cat.ns.sum(in1 => 1, in5 => 2)")
+ },
+ condition = "UNRECOGNIZED_PARAMETER_NAME",
+ parameters = Map(
+ "routineName" -> toSQLId("sum"),
+ "argumentName" -> toSQLId("in5"),
+ "proposal" -> (toSQLId("in1") + " " + toSQLId("in2"))))
+ }
+
+ test("position parameter after named parameter") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql("CALL cat.ns.sum(in1 => 1, 2)")
+ },
+ condition = "UNEXPECTED_POSITIONAL_ARGUMENT",
+ parameters = Map(
+ "routineName" -> toSQLId("sum"),
+ "parameterName" -> toSQLId("in1")))
+ }
+
+ test("invalid argument type") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ val call = "CALL cat.ns.sum(1, TIMESTAMP '2016-11-15 20:54:00.000')"
+ checkError(
+ exception = intercept[AnalysisException] {
+ sql(call)
+ },
+ condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ parameters = Map(
+ "sqlExpr" -> "CALL",
+ "paramIndex" -> "second",
+ "inputSql" -> "\"TIMESTAMP '2016-11-15 20:54:00'\"",
+ "inputType" -> toSQLType("TIMESTAMP"),
+ "requiredType" -> toSQLType("INT")),
+ context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1))
+ }
+
+ test("malformed input to implicit cast") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ val call = "CALL cat.ns.sum('A', 2)"
+ checkError(
+ exception = intercept[SparkNumberFormatException](
+ sql(call)
+ ),
+ condition = "CAST_INVALID_INPUT",
+ parameters = Map(
+ "expression" -> toSQLValue("A"),
+ "sourceType" -> toSQLType("STRING"),
+ "targetType" -> toSQLType("INT")),
+ context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1))
+ }
+
+ test("required parameters after optional") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"),
UnboundInvalidSum)
+ val e = intercept[SparkException] {
+ sql("CALL cat.ns.sum(in2 => 1)")
+ }
+ assert(e.getMessage.contains("required arguments should come before
optional arguments"))
+ }
+
+ test("INOUT parameters are not supported") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "procedure"),
UnboundInoutProcedure)
+ val e = intercept[SparkException] {
+ sql("CALL cat.ns.procedure(1)")
+ }
+ assert(e.getMessage.contains(" Unsupported parameter mode: INOUT"))
+ }
+
+ test("OUT parameters are not supported") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "procedure"),
UnboundOutProcedure)
+ val e = intercept[SparkException] {
+ sql("CALL cat.ns.procedure(1)")
+ }
+ assert(e.getMessage.contains("Unsupported parameter mode: OUT"))
+ }
+
+ test("EXPLAIN") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"),
UnboundNonExecutableSum)
+ val explain1 = sql("EXPLAIN CALL cat.ns.sum(5, 5)").head().get(0)
+ assert(explain1.toString.contains("cat.ns.sum(5, 5)"))
+ val explain2 = sql("EXPLAIN EXTENDED CALL cat.ns.sum(10,
10)").head().get(0)
+ assert(explain2.toString.contains("cat.ns.sum(10, 10)"))
+ }
+
+ test("void procedure") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "proc"),
UnboundVoidProcedure)
+ checkAnswer(sql("CALL cat.ns.proc('A', 'B')"), Nil)
+ }
+
+ test("multi-result procedure") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "proc"),
UnboundMultiResultProcedure)
+ checkAnswer(sql("CALL cat.ns.proc()"), Row("last") :: Nil)
+ }
+
+ test("invalid input to struct procedure") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "proc"),
UnboundStructProcedure)
+ val actualType =
+ StructType(Seq(
+ StructField("X", DataTypes.DateType, nullable = false),
+ StructField("Y", DataTypes.IntegerType, nullable = false)))
+ val expectedType = StructProcedure.parameters.head.dataType
+ val call = "CALL cat.ns.proc(named_struct('X', DATE '2011-11-11', 'Y', 2),
'VALUE')"
+ checkError(
+ exception = intercept[AnalysisException](sql(call)),
+ condition = "DATATYPE_MISMATCH.UNEXPECTED_INPUT_TYPE",
+ parameters = Map(
+ "sqlExpr" -> "CALL",
+ "paramIndex" -> "first",
+ "inputSql" -> "\"named_struct(X, DATE '2011-11-11', Y, 2)\"",
+ "inputType" -> toSQLType(actualType),
+ "requiredType" -> toSQLType(expectedType)),
+ context = ExpectedContext(fragment = call, start = 0, stop = call.length - 1))
+ }
+
+ test("save execution summary") {
+ withTable("summary") {
+ catalog.createProcedure(Identifier.of(Array("ns"), "sum"), UnboundSum)
+ val result = sql("CALL cat.ns.sum(1, 2)")
+ result.write.saveAsTable("summary")
+ checkAnswer(spark.table("summary"), Row(3) :: Nil)
+ }
+ }
+
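+ // Test procedures used by the cases above. Each UnboundProcedure is
+ // registered in the catalog; the analyzer obtains a BoundProcedure via
+ // bind(), and call() returns an iterator of Scans holding the result rows.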
+ object UnboundVoidProcedure extends UnboundProcedure {
+ override def name: String = "void"
+ override def description: String = "void procedure"
+ override def bind(inputType: StructType): BoundProcedure = VoidProcedure
+ }
+
+ object VoidProcedure extends BoundProcedure {
+ override def name: String = "void"
+
+ override def description: String = "void procedure"
+
+ override def isDeterministic: Boolean = true
+
+ override def parameters: Array[ProcedureParameter] = Array(
+ ProcedureParameter.in("in1", DataTypes.StringType).build(),
+ ProcedureParameter.in("in2", DataTypes.StringType).build()
+ )
+
+ override def call(input: InternalRow): java.util.Iterator[Scan] = {
+ Collections.emptyIterator
+ }
+ }
+
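+ // Produces two result scans; as the "multi-result procedure" test above
+ // asserts, CALL surfaces only the rows of the last scan ("last").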
+ object UnboundMultiResultProcedure extends UnboundProcedure {
+ override def name: String = "multi"
+ override def description: String = "multi-result procedure"
+ override def bind(inputType: StructType): BoundProcedure = MultiResultProcedure
+ }
+
+ object MultiResultProcedure extends BoundProcedure {
+ override def name: String = "multi"
+
+ override def description: String = "multi-result procedure"
+
+ override def isDeterministic: Boolean = true
+
+ override def parameters: Array[ProcedureParameter] = Array()
+
+ override def call(input: InternalRow): java.util.Iterator[Scan] = {
+ val scans = java.util.Arrays.asList[Scan](
+ Result(
+ new StructType().add("out", DataTypes.IntegerType),
+ Array(InternalRow(1))),
+ Result(
+ new StructType().add("out", DataTypes.StringType),
+ Array(InternalRow(UTF8String.fromString("last"))))
+ )
+ scans.iterator()
+ }
+ }
+
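+ // Backs the EXPLAIN test: it binds to a procedure whose call() throws, so
+ // the test would fail if EXPLAIN actually executed the procedure.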
+ object UnboundNonExecutableSum extends UnboundProcedure {
+ override def name: String = "sum"
+ override def description: String = "sum integers"
+ override def bind(inputType: StructType): BoundProcedure = NonExecutableSum
+ }
+
+ object NonExecutableSum extends BoundProcedure {
+ override def name: String = "sum"
+
+ override def description: String = "sum integers"
+
+ override def isDeterministic: Boolean = true
+
+ override def parameters: Array[ProcedureParameter] = Array(
+ ProcedureParameter.in("in1", DataTypes.IntegerType).build(),
+ ProcedureParameter.in("in2", DataTypes.IntegerType).build()
+ )
+
+ override def call(input: InternalRow): java.util.Iterator[Scan] = {
+ throw new UnsupportedOperationException()
+ }
+ }
+
+ object UnboundSum extends UnboundProcedure {
+ override def name: String = "sum"
+ override def description: String = "sum integers"
+ override def bind(inputType: StructType): BoundProcedure = Sum
+ }
+
+ object Sum extends BoundProcedure {
+ override def name: String = "sum"
+
+ override def description: String = "sum integers"
+
+ override def isDeterministic: Boolean = true
+
+ override def parameters: Array[ProcedureParameter] = Array(
+ ProcedureParameter.in("in1", DataTypes.IntegerType).build(),
+ ProcedureParameter.in("in2", DataTypes.IntegerType).build()
+ )
+
+ def outputType: StructType = new StructType().add("out",
DataTypes.IntegerType)
+
+ override def call(input: InternalRow): java.util.Iterator[Scan] = {
+ val in1 = input.getInt(0)
+ val in2 = input.getInt(1)
+ val result = Result(outputType, Array(InternalRow(in1 + in2)))
+ Collections.singleton[Scan](result).iterator()
+ }
+ }
+
+ object UnboundLongSum extends UnboundProcedure {
+ override def name: String = "long_sum"
+ override def description: String = "sum longs"
+ override def bind(inputType: StructType): BoundProcedure = LongSum
+ }
+
+ object LongSum extends BoundProcedure {
+ override def name: String = "long_sum"
+
+ override def description: String = "sum longs"
+
+ override def isDeterministic: Boolean = true
+
+ override def parameters: Array[ProcedureParameter] = Array(
+ ProcedureParameter.in("in1", DataTypes.LongType).build(),
+ ProcedureParameter.in("in2", DataTypes.LongType).build()
+ )
+
+ def outputType: StructType = new StructType().add("out",
DataTypes.LongType)
+
+ override def call(input: InternalRow): java.util.Iterator[Scan] = {
+ val in1 = input.getLong(0)
+ val in2 = input.getLong(1)
+ val result = Result(outputType, Array(InternalRow(in1 + in2)))
+ Collections.singleton[Scan](result).iterator()
+ }
+ }
+
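+ // Invalid by construction: the defaulted parameter "in1" precedes the
+ // required parameter "in2", which fails at bind time.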
+ object UnboundInvalidSum extends UnboundProcedure {
+ override def name: String = "invalid"
+ override def description: String = "sum integers"
+ override def bind(inputType: StructType): BoundProcedure = InvalidSum
+ }
+
+ object InvalidSum extends BoundProcedure {
+ override def name: String = "invalid"
+
+ override def description: String = "sum integers"
+
+ override def isDeterministic: Boolean = false
+
+ override def parameters: Array[ProcedureParameter] = Array(
+ ProcedureParameter.in("in1",
DataTypes.IntegerType).defaultValue("1").build(),
+ ProcedureParameter.in("in2", DataTypes.IntegerType).build()
+ )
+
+ def outputType: StructType = new StructType().add("out",
DataTypes.IntegerType)
+
+ override def call(input: InternalRow): java.util.Iterator[Scan] = {
+ throw new UnsupportedOperationException()
+ }
+ }
+
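+ // Invalid by construction: the default value 'B' of "in2" is a STRING and
+ // cannot satisfy the declared INT type (INVALID_DEFAULT_VALUE.DATA_TYPE).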
+ object UnboundInvalidDefaultProcedure extends UnboundProcedure {
+ override def name: String = "sum"
+ override def description: String = "invalid default value procedure"
+ override def bind(inputType: StructType): BoundProcedure = InvalidDefaultProcedure
+ }
+
+ object InvalidDefaultProcedure extends BoundProcedure {
+ override def name: String = "sum"
+
+ override def description: String = "invalid default value procedure"
+
+ override def isDeterministic: Boolean = true
+
+ override def parameters: Array[ProcedureParameter] = Array(
+ ProcedureParameter.in("in1",
DataTypes.IntegerType).defaultValue("10").build(),
+ ProcedureParameter.in("in2",
DataTypes.IntegerType).defaultValue("'B'").build()
+ )
+
+ def outputType: StructType = new StructType().add("out",
DataTypes.IntegerType)
+
+ override def call(input: InternalRow): java.util.Iterator[Scan] = {
+ throw new UnsupportedOperationException()
+ }
+ }
+
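+ // Exercises default values, including one defined by a foldable SQL
+ // expression ("1 + 1 - 1"), and returns in3 output rows per call.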
+ object UnboundComplexProcedure extends UnboundProcedure {
+ override def name: String = "complex"
+ override def description: String = "complex procedure"
+ override def bind(inputType: StructType): BoundProcedure = ComplexProcedure
+ }
+
+ object ComplexProcedure extends BoundProcedure {
+ override def name: String = "complex"
+
+ override def description: String = "complex procedure"
+
+ override def isDeterministic: Boolean = true
+
+ override def parameters: Array[ProcedureParameter] = Array(
+ ProcedureParameter.in("in1",
DataTypes.StringType).defaultValue("'A'").build(),
+ ProcedureParameter.in("in2",
DataTypes.StringType).defaultValue("'B'").build(),
+ ProcedureParameter.in("in3", DataTypes.IntegerType).defaultValue("1 + 1
- 1").build()
+ )
+
+ def outputType: StructType = new StructType()
+ .add("out1", DataTypes.IntegerType)
+ .add("out2", DataTypes.StringType)
+ .add("out3", DataTypes.StringType)
+
+ override def call(input: InternalRow): java.util.Iterator[Scan] = {
+ val in1 = input.getString(0)
+ val in2 = input.getString(1)
+ val in3 = input.getInt(2)
+
+ val rows = (1 to in3).map { index =>
+ val v1 = UTF8String.fromString(s"$in1$index")
+ val v2 = UTF8String.fromString(s"$in2$index")
+ InternalRow(index, v1, v2)
+ }.toArray
+
+ val result = Result(outputType, rows)
+ Collections.singleton[Scan](result).iterator()
+ }
+ }
+
+ object UnboundStructProcedure extends UnboundProcedure {
+ override def name: String = "struct_input"
+ override def description: String = "struct procedure"
+ override def bind(inputType: StructType): BoundProcedure = StructProcedure
+ }
+
+ object StructProcedure extends BoundProcedure {
+ override def name: String = "struct_input"
+
+ override def description: String = "struct procedure"
+
+ override def isDeterministic: Boolean = true
+
+ override def parameters: Array[ProcedureParameter] = Array(
+ ProcedureParameter
+ .in(
+ "in1",
+ StructType(Seq(
+ StructField("nested1", DataTypes.IntegerType),
+ StructField("nested2", DataTypes.StringType))))
+ .build(),
+ ProcedureParameter.in("in2", DataTypes.StringType).build()
+ )
+
+ override def call(input: InternalRow): java.util.Iterator[Scan] = {
+ Collections.emptyIterator
+ }
+ }
+
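+ // Only IN parameters are supported; the procedures below declare INOUT and
+ // OUT parameters to exercise the corresponding error paths.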
+ object UnboundInoutProcedure extends UnboundProcedure {
+ override def name: String = "procedure"
+ override def description: String = "inout procedure"
+ override def bind(inputType: StructType): BoundProcedure = InoutProcedure
+ }
+
+ object InoutProcedure extends BoundProcedure {
+ override def name: String = "procedure"
+
+ override def description: String = "inout procedure"
+
+ override def isDeterministic: Boolean = true
+
+ override def parameters: Array[ProcedureParameter] = Array(
+ CustomParameterImpl(INOUT, "in1", DataTypes.IntegerType)
+ )
+
+ def outputType: StructType = new StructType().add("out",
DataTypes.IntegerType)
+
+ override def call(input: InternalRow): java.util.Iterator[Scan] = {
+ throw new UnsupportedOperationException()
+ }
+ }
+
+ object UnboundOutProcedure extends UnboundProcedure {
+ override def name: String = "procedure"
+ override def description: String = "out procedure"
+ override def bind(inputType: StructType): BoundProcedure = OutProcedure
+ }
+
+ object OutProcedure extends BoundProcedure {
+ override def name: String = "procedure"
+
+ override def description: String = "out procedure"
+
+ override def isDeterministic: Boolean = true
+
+ override def parameters: Array[ProcedureParameter] = Array(
+ CustomParameterImpl(IN, "in1", DataTypes.IntegerType),
+ CustomParameterImpl(OUT, "out1", DataTypes.IntegerType)
+ )
+
+ def outputType: StructType = new StructType().add("out",
DataTypes.IntegerType)
+
+ override def call(input: InternalRow): java.util.Iterator[Scan] = {
+ throw new UnsupportedOperationException()
+ }
+ }
+
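+ // Minimal LocalScan that serves the given rows directly from memory.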
+ case class Result(readSchema: StructType, rows: Array[InternalRow]) extends LocalScan
+
+ case class CustomParameterImpl(
+ mode: Mode,
+ name: String,
+ dataType: DataType) extends ProcedureParameter {
+ override def defaultValueExpression: String = null
+ override def comment: String = null
+ }
+}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
index 4bc4116a23da..dcf3bd8c7173 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/ThriftServerWithSparkContextSuite.scala
@@ -214,7 +214,7 @@ trait ThriftServerWithSparkContextSuite extends SharedThriftServer {
val sessionHandle = client.openSession(user, "")
val infoValue = client.getInfo(sessionHandle, GetInfoType.CLI_ODBC_KEYWORDS)
// scalastyle:off line.size.limit
- assert(infoValue.getStringValue == "ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURRENT_DAT [...]
+ assert(infoValue.getStringValue == "ADD,AFTER,ALL,ALTER,ALWAYS,ANALYZE,AND,ANTI,ANY,ANY_VALUE,ARCHIVE,ARRAY,AS,ASC,AT,AUTHORIZATION,BEGIN,BETWEEN,BIGINT,BINARY,BINDING,BOOLEAN,BOTH,BUCKET,BUCKETS,BY,BYTE,CACHE,CALL,CALLED,CASCADE,CASE,CAST,CATALOG,CATALOGS,CHANGE,CHAR,CHARACTER,CHECK,CLEAR,CLUSTER,CLUSTERED,CODEGEN,COLLATE,COLLATION,COLLECTION,COLUMN,COLUMNS,COMMENT,COMMIT,COMPACT,COMPACTIONS,COMPENSATION,COMPUTE,CONCATENATE,CONSTRAINT,CONTAINS,COST,CREATE,CROSS,CUBE,CURRENT,CURREN [...]
// scalastyle:on line.size.limit
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 44c1ecd6902c..dbeb8607facc 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF}
import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF}
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EvalSubqueriesForTimeTravel, InvokeProcedures, ReplaceCharWithVarchar, ResolveSessionCatalog, ResolveTranspose}
import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -95,6 +95,7 @@ class HiveSessionStateBuilder(
new EvalSubqueriesForTimeTravel +:
new DetermineTableStats(session) +:
new ResolveTranspose(session) +:
+ new InvokeProcedures(session) +:
customResolutionRules
override val postHocResolutionRules: Seq[Rule[LogicalPlan]] =