Repository: spark
Updated Branches:
  refs/heads/master b4c99f436 -> 8c67aa7f0


[SPARK-20311][SQL] Support aliases for table value functions

## What changes were proposed in this pull request?
This PR adds parsing rules to support aliases for table-valued functions.
The previous PR (#17666) was reverted because of a regression; this new PR fixes the regression and adds tests in `SQLQueryTestSuite`.
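
For illustration, a minimal usage sketch of the syntax this enables. It assumes an existing `SparkSession` named `spark`; the queries and the error message mirror the tests added below.

```scala
// Minimal sketch, assuming an existing SparkSession named `spark`.
// A table-valued function can now carry a table alias and optional column aliases.
spark.sql("SELECT t.id FROM range(10) AS t").show()   // table alias only
spark.sql("SELECT t.a FROM range(10) t(a)").show()    // table alias plus a column alias
// A mismatched number of column aliases is rejected during analysis, e.g.
// "SELECT * FROM range(10) t(a, b)" fails with
// "expected 1 columns but found 2 columns".
```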

## How was this patch tested?
Added tests in `PlanParserSuite` and `SQLQueryTestSuite`.
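
For reference, a self-contained sketch of the plan shape the parser now produces for a TVF with column aliases; it mirrors the new `PlanParserSuite` case below and borrows that suite's DSL imports, so treat it as a sketch rather than an additional test:

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedTableValuedFunction
import org.apache.spark.sql.catalyst.dsl.expressions._   // provides star()
import org.apache.spark.sql.catalyst.dsl.plans._         // provides .select(...) on LogicalPlan
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias

// Parse a TVF with a table alias and one column alias.
val parsed = CatalystSqlParser.parsePlan("SELECT * FROM range(7) AS t(a)")

// Expected shape: the column aliases ride on the unresolved TVF node,
// while the table alias becomes a SubqueryAlias wrapper.
val expected =
  SubqueryAlias("t",
    UnresolvedTableValuedFunction("range", Literal(7) :: Nil, "a" :: Nil))
    .select(star())
```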

Author: Takeshi Yamamuro <yamam...@apache.org>

Closes #17928 from maropu/SPARK-20311-3.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/8c67aa7f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/8c67aa7f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/8c67aa7f

Branch: refs/heads/master
Commit: 8c67aa7f00e0186abe05a1628faf2232b364a61f
Parents: b4c99f4
Author: Takeshi Yamamuro <yamam...@apache.org>
Authored: Thu May 11 18:09:31 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Thu May 11 18:09:31 2017 +0800

----------------------------------------------------------------------
 .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 20 ++++++++----
 .../analysis/ResolveTableValuedFunctions.scala  | 22 ++++++++++++--
 .../sql/catalyst/analysis/unresolved.scala      | 10 ++++--
 .../spark/sql/catalyst/parser/AstBuilder.scala  | 17 ++++++++---
 .../sql/catalyst/analysis/AnalysisSuite.scala   | 14 ++++++++-
 .../sql/catalyst/parser/PlanParserSuite.scala   | 13 +++++++-
 .../resources/sql-tests/inputs/inline-table.sql |  3 ++
 .../sql-tests/inputs/table-valued-functions.sql |  3 ++
 .../sql-tests/results/inline-table.sql.out      | 32 +++++++++++++++++++-
 .../results/table-valued-functions.sql.out      | 32 +++++++++++++++++++-
 10 files changed, 147 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/8c67aa7f/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
index 14c511f..ed5450b 100644
--- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
+++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4
@@ -472,15 +472,23 @@ identifierComment
     ;
 
 relationPrimary
-    : tableIdentifier sample? (AS? strictIdentifier)?               #tableName
-    | '(' queryNoWith ')' sample? (AS? strictIdentifier)?           #aliasedQuery
-    | '(' relation ')' sample? (AS? strictIdentifier)?              #aliasedRelation
-    | inlineTable                                                   #inlineTableDefault2
-    | identifier '(' (expression (',' expression)*)? ')'            #tableValuedFunction
+    : tableIdentifier sample? (AS? strictIdentifier)?      #tableName
+    | '(' queryNoWith ')' sample? (AS? strictIdentifier)?  #aliasedQuery
+    | '(' relation ')' sample? (AS? strictIdentifier)?     #aliasedRelation
+    | inlineTable                                          #inlineTableDefault2
+    | functionTable                                        #tableValuedFunction
     ;
 
 inlineTable
-    : VALUES expression (',' expression)*  (AS? identifier identifierList?)?
+    : VALUES expression (',' expression)* tableAlias
+    ;
+
+functionTable
+    : identifier '(' (expression (',' expression)*)? ')' tableAlias
+    ;
+
+tableAlias
+    : (AS? strictIdentifier identifierList?)?
     ;
 
 rowFormat

http://git-wip-us.apache.org/repos/asf/spark/blob/8c67aa7f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
index de6de24..dad1340 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis
 
 import java.util.Locale
 
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Range}
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.types.{DataType, IntegerType, LongType}
 
@@ -105,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
     case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) =>
-      builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
+      val resolvedFunc = builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match {
         case Some(tvf) =>
           val resolved = tvf.flatMap { case (argList, resolver) =>
             argList.implicitCast(u.functionArgs) match {
@@ -125,5 +125,21 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] {
         case _ =>
           u.failAnalysis(s"could not resolve `${u.functionName}` to a table-valued function")
       }
+
+      // If alias names assigned, add `Project` with the aliases
+      if (u.outputNames.nonEmpty) {
+        val outputAttrs = resolvedFunc.output
+        // Checks if the number of the aliases is equal to expected one
+        if (u.outputNames.size != outputAttrs.size) {
+          u.failAnalysis(s"expected ${outputAttrs.size} columns but " +
+            s"found ${u.outputNames.size} columns")
+        }
+        val aliases = outputAttrs.zip(u.outputNames).map {
+          case (attr, name) => Alias(attr, name)()
+        }
+        Project(aliases, resolvedFunc)
+      } else {
+        resolvedFunc
+      }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8c67aa7f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index 262b894..51bef6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -66,10 +66,16 @@ case class UnresolvedInlineTable(
 /**
  * A table-valued function, e.g.
  * {{{
- *   select * from range(10);
+ *   select id from range(10);
+ *
+ *   // Assign alias names
+ *   select t.a from range(10) t(a);
  * }}}
  */
-case class UnresolvedTableValuedFunction(functionName: String, functionArgs: Seq[Expression])
+case class UnresolvedTableValuedFunction(
+    functionName: String,
+    functionArgs: Seq[Expression],
+    outputNames: Seq[String])
   extends LeafNode {
 
   override def output: Seq[Attribute] = Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/8c67aa7f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index d2a9b4a..046ea65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -687,7 +687,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
    */
   override def visitTableValuedFunction(ctx: TableValuedFunctionContext)
       : LogicalPlan = withOrigin(ctx) {
-    UnresolvedTableValuedFunction(ctx.identifier.getText, ctx.expression.asScala.map(expression))
+    val func = ctx.functionTable
+    val aliases = if (func.tableAlias.identifierList != null) {
+      visitIdentifierList(func.tableAlias.identifierList)
+    } else {
+      Seq.empty
+    }
+
+    val tvf = UnresolvedTableValuedFunction(
+      func.identifier.getText, func.expression.asScala.map(expression), aliases)
+    tvf.optionalMap(func.tableAlias.strictIdentifier)(aliasPlan)
   }
 
   /**
@@ -705,14 +714,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
       }
     }
 
-    val aliases = if (ctx.identifierList != null) {
-      visitIdentifierList(ctx.identifierList)
+    val aliases = if (ctx.tableAlias.identifierList != null) {
+      visitIdentifierList(ctx.tableAlias.identifierList)
     } else {
       Seq.tabulate(rows.head.size)(i => s"col${i + 1}")
     }
 
     val table = UnresolvedInlineTable(aliases, rows)
-    table.optionalMap(ctx.identifier)(aliasPlan)
+    table.optionalMap(ctx.tableAlias.strictIdentifier)(aliasPlan)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/8c67aa7f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 893bb1b..31047f6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
 import org.apache.spark.sql.catalyst.plans.Cross
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.types._
@@ -441,4 +440,17 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers {
 
     checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation)
   }
+
+  test("SPARK-20311 range(N) as alias") {
+    def rangeWithAliases(args: Seq[Int], outputNames: Seq[String]): LogicalPlan = {
+      SubqueryAlias("t", UnresolvedTableValuedFunction("range", args.map(Literal(_)), outputNames))
+        .select(star())
+    }
+    assertAnalysisSuccess(rangeWithAliases(3 :: Nil, "a" :: Nil))
+    assertAnalysisSuccess(rangeWithAliases(1 :: 4 :: Nil, "b" :: Nil))
+    assertAnalysisSuccess(rangeWithAliases(2 :: 6 :: 2 :: Nil, "c" :: Nil))
+    assertAnalysisError(
+      rangeWithAliases(3 :: Nil, "a" :: "b" :: Nil),
+      Seq("expected 1 columns but found 2 columns"))
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/8c67aa7f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
index 411777d..cf137cf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala
@@ -468,7 +468,18 @@ class PlanParserSuite extends PlanTest {
   test("table valued function") {
     assertEqual(
       "select * from range(2)",
-      UnresolvedTableValuedFunction("range", Literal(2) :: Nil).select(star()))
+      UnresolvedTableValuedFunction("range", Literal(2) :: Nil, Seq.empty).select(star()))
+  }
+
+  test("SPARK-20311 range(N) as alias") {
+    assertEqual(
+      "SELECT * FROM range(10) AS t",
+      SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(10) :: Nil, Seq.empty))
+        .select(star()))
+    assertEqual(
+      "SELECT * FROM range(7) AS t(a)",
+      SubqueryAlias("t", UnresolvedTableValuedFunction("range", Literal(7) :: Nil, "a" :: Nil))
+        .select(star()))
   }
 
   test("inline table") {

http://git-wip-us.apache.org/repos/asf/spark/blob/8c67aa7f/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql
index b3ec956..41d3164 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql
@@ -49,3 +49,6 @@ select * from values ("one", count(1)), ("two", 2) as data(a, b);
 
 -- string to timestamp
 select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b);
+
+-- cross-join inline tables
+EXPLAIN EXTENDED SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null);

http://git-wip-us.apache.org/repos/asf/spark/blob/8c67aa7f/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
index d0d2df7..72cd8ca 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/table-valued-functions.sql
@@ -24,3 +24,6 @@ select * from RaNgE(2);
 
 -- Explain
 EXPLAIN select * from RaNgE(2);
+
+-- cross-join table valued functions
+EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3);

http://git-wip-us.apache.org/repos/asf/spark/blob/8c67aa7f/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
index 4e80f0b..c065ce5 100644
--- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 17
+-- Number of queries: 18
 
 
 -- !query 0
@@ -151,3 +151,33 @@ select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-
 struct<a:timestamp,b:array<timestamp>>
 -- !query 16 output
 1991-12-06 00:00:00    [1991-12-06 01:00:00.0,1991-12-06 12:00:00.0]
+
+
+-- !query 17
+EXPLAIN EXTENDED SELECT * FROM VALUES ('one', 1), ('three', null) CROSS JOIN VALUES ('one', 1), ('three', null)
+-- !query 17 schema
+struct<plan:string>
+-- !query 17 output
+== Parsed Logical Plan ==
+'Project [*]
++- 'Join Cross
+   :- 'UnresolvedInlineTable [col1, col2], [List(one, 1), List(three, null)]
+   +- 'UnresolvedInlineTable [col1, col2], [List(one, 1), List(three, null)]
+
+== Analyzed Logical Plan ==
+col1: string, col2: int, col1: string, col2: int
+Project [col1#x, col2#x, col1#x, col2#x]
++- Join Cross
+   :- LocalRelation [col1#x, col2#x]
+   +- LocalRelation [col1#x, col2#x]
+
+== Optimized Logical Plan ==
+Join Cross
+:- LocalRelation [col1#x, col2#x]
++- LocalRelation [col1#x, col2#x]
+
+== Physical Plan ==
+BroadcastNestedLoopJoin BuildRight, Cross
+:- LocalTableScan [col1#x, col2#x]
++- BroadcastExchange IdentityBroadcastMode
+   +- LocalTableScan [col1#x, col2#x]

http://git-wip-us.apache.org/repos/asf/spark/blob/8c67aa7f/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
index e2ee970..a8bc6fa 100644
--- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 9
+-- Number of queries: 10
 
 
 -- !query 0
@@ -103,3 +103,33 @@ struct<plan:string>
 -- !query 8 output
 == Physical Plan ==
 *Range (0, 2, step=1, splits=2)
+
+
+-- !query 9
+EXPLAIN EXTENDED SELECT * FROM range(3) CROSS JOIN range(3)
+-- !query 9 schema
+struct<plan:string>
+-- !query 9 output
+== Parsed Logical Plan ==
+'Project [*]
++- 'Join Cross
+   :- 'UnresolvedTableValuedFunction range, [3]
+   +- 'UnresolvedTableValuedFunction range, [3]
+
+== Analyzed Logical Plan ==
+id: bigint, id: bigint
+Project [id#xL, id#xL]
++- Join Cross
+   :- Range (0, 3, step=1, splits=None)
+   +- Range (0, 3, step=1, splits=None)
+
+== Optimized Logical Plan ==
+Join Cross
+:- Range (0, 3, step=1, splits=None)
++- Range (0, 3, step=1, splits=None)
+
+== Physical Plan ==
+BroadcastNestedLoopJoin BuildRight, Cross
+:- *Range (0, 3, step=1, splits=2)
++- BroadcastExchange IdentityBroadcastMode
+   +- *Range (0, 3, step=1, splits=2)

