Repository: spark Updated Branches: refs/heads/master 1528ff4c9 -> 600c0b69c
http://git-wip-us.apache.org/repos/asf/spark/blob/600c0b69/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala new file mode 100644 index 0000000..32311a5 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala @@ -0,0 +1,497 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Test basic expression parsing. If a type of expression is supported it should be tested here. 
+ * + * Please note that some of the tested expressions don't have to be sound; only their + * structure needs to be valid. Unsound expressions should be caught by the Analyzer or + * CheckAnalysis classes. + */ +class ExpressionParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, e: Expression): Unit = { + compareExpressions(parseExpression(sqlCommand), e) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parseExpression(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("star expressions") { + // Global Star + assertEqual("*", UnresolvedStar(None)) + + // Targeted Star + assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b")))) + } + + // NamedExpression (Alias/Multialias) + test("named expressions") { + // No Alias + val r0 = 'a + assertEqual("a", r0) + + // Single Alias. + val r1 = 'a as "b" + assertEqual("a as b", r1) + assertEqual("a b", r1) + + // Multi-Alias + assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c"))) + assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c"))) + + // Numeric literals without a space between the literal qualifier and the alias should not be + // interpreted as such. An unresolved reference should be returned instead. + // TODO add the JIRA-ticket number. + assertEqual("1SL", Symbol("1SL")) + + // Aliased star is allowed. 
+ assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b) + } + + test("binary logical expressions") { + // And + assertEqual("a and b", 'a && 'b) + + // Or + assertEqual("a or b", 'a || 'b) + + // Combination And/Or check precedence + assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd)) + assertEqual("a or b or c and d", 'a || 'b || ('c && 'd)) + + // Multiple AND/OR get converted into a balanced tree + assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f)) + assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f)) + } + + test("long binary logical expressions") { + def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = { + val sql = (1 to 1000).map(x => s"$x == $x").mkString(op) + val e = parseExpression(sql) + assert(e.collect { case _: EqualTo => true }.size === 1000) + assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999) + } + testVeryBinaryExpression(" AND ", classOf[And]) + testVeryBinaryExpression(" OR ", classOf[Or]) + } + + test("not expressions") { + assertEqual("not a", !'a) + assertEqual("!a", !'a) + assertEqual("not true > true", Not(GreaterThan(true, true))) + } + + test("exists expression") { + intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported") + } + + test("comparison expressions") { + assertEqual("a = b", 'a === 'b) + assertEqual("a == b", 'a === 'b) + assertEqual("a <=> b", 'a <=> 'b) + assertEqual("a <> b", 'a =!= 'b) + assertEqual("a != b", 'a =!= 'b) + assertEqual("a < b", 'a < 'b) + assertEqual("a <= b", 'a <= 'b) + assertEqual("a > b", 'a > 'b) + assertEqual("a >= b", 'a >= 'b) + } + + test("between expressions") { + assertEqual("a between b and c", 'a >= 'b && 'a <= 'c) + assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c)) + } + + test("in expressions") { + assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd)) + assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd))) + } + + test("in 
sub-query") { + intercept("a in (select b from c)", "IN with a Sub-query is currently not supported") + } + + test("like expressions") { + assertEqual("a like 'pattern%'", 'a like "pattern%") + assertEqual("a not like 'pattern%'", !('a like "pattern%")) + assertEqual("a rlike 'pattern%'", 'a rlike "pattern%") + assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%")) + assertEqual("a regexp 'pattern%'", 'a rlike "pattern%") + assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%")) + } + + test("is null expressions") { + assertEqual("a is null", 'a.isNull) + assertEqual("a is not null", 'a.isNotNull) + assertEqual("a = b is null", ('a === 'b).isNull) + assertEqual("a = b is not null", ('a === 'b).isNotNull) + } + + test("binary arithmetic expressions") { + // Simple operations + assertEqual("a * b", 'a * 'b) + assertEqual("a / b", 'a / 'b) + assertEqual("a DIV b", ('a / 'b).cast(LongType)) + assertEqual("a % b", 'a % 'b) + assertEqual("a + b", 'a + 'b) + assertEqual("a - b", 'a - 'b) + assertEqual("a & b", 'a & 'b) + assertEqual("a ^ b", 'a ^ 'b) + assertEqual("a | b", 'a | 'b) + + // Check precedences + assertEqual( + "a * t | b ^ c & d - e + f % g DIV h / i * k", + 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k))))) + } + + test("unary arithmetic expressions") { + assertEqual("+a", 'a) + assertEqual("-a", -'a) + assertEqual("~a", ~'a) + assertEqual("-+~~a", -(~(~'a))) + } + + test("cast expressions") { + // Note that DataType parsing is tested elsewhere. 
+ assertEqual("cast(a as int)", 'a.cast(IntegerType)) + assertEqual("cast(a as timestamp)", 'a.cast(TimestampType)) + assertEqual("cast(a as array<int>)", 'a.cast(ArrayType(IntegerType))) + assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType)) + } + + test("function expressions") { + assertEqual("foo()", 'foo.function()) + assertEqual("foo.bar()", Symbol("foo.bar").function()) + assertEqual("foo(*)", 'foo.function(star())) + assertEqual("count(*)", 'count.function(1)) + assertEqual("foo(a, b)", 'foo.function('a, 'b)) + assertEqual("foo(all a, b)", 'foo.function('a, 'b)) + assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) + assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) + assertEqual("`select`(all a, b)", 'select.function('a, 'b)) + } + + test("window function expressions") { + val func = 'foo.function(star()) + def windowed( + partitioning: Seq[Expression] = Seq.empty, + ordering: Seq[SortOrder] = Seq.empty, + frame: WindowFrame = UnspecifiedFrame): Expression = { + WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) + } + + // Basic window testing. + assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1"))) + assertEqual("foo(*) over ()", windowed()) + assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) + assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) + + // Test use of expressions in window functions. 
+ assertEqual( + "sum(product + 1) over (partition by ((product) + (1)) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + assertEqual( + "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + + // Range/Row + val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) + val boundaries = Seq( + ("10 preceding", ValuePreceding(10), CurrentRow), + ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis + ("unbounded preceding", UnboundedPreceding, CurrentRow), + ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis + ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), + ("between unbounded preceding and unbounded following", + UnboundedPreceding, UnboundedFollowing), + ("between 10 preceding and current row", ValuePreceding(10), CurrentRow), + ("between current row and 5 following", CurrentRow, ValueFollowing(5)), + ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5)) + ) + frameTypes.foreach { + case (frameTypeSql, frameType) => + boundaries.foreach { + case (boundarySql, begin, end) => + val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)" + val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end)) + assertEqual(query, expr) + } + } + + // We cannot use non-integer constants. + intercept("foo(*) over (partition by a order by b rows 10.0 preceding)", + "Frame bound value must be a constant integer.") + + // We cannot use an arbitrary expression. 
+ intercept("foo(*) over (partition by a order by b rows exp(b) preceding)", + "Frame bound value must be a constant integer.") + } + + test("row constructor") { + // Note that '(a)' will be interpreted as a nested expression. + assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) + assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) + } + + test("scalar sub-query") { + assertEqual( + "(select max(val) from tbl) > current", + ScalarSubquery(table("tbl").select('max.function('val))) > 'current) + assertEqual( + "a = (select b from s)", + 'a === ScalarSubquery(table("s").select('b))) + } + + test("case when") { + assertEqual("case a when 1 then b when 2 then c else d end", + CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) + assertEqual("case when a = 1 then b when a = 2 then c else d end", + CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) + } + + test("dereference") { + assertEqual("a.b", UnresolvedAttribute("a.b")) + assertEqual("`select`.b", UnresolvedAttribute("select.b")) + assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. + assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b")) + } + + test("reference") { + // Regular + assertEqual("a", 'a) + + // Starting with a digit. + assertEqual("1a", Symbol("1a")) + + // Quoted using a keyword. + assertEqual("`select`", 'select) + + // Unquoted using an unreserved keyword. + assertEqual("columns", 'columns) + } + + test("subscript") { + assertEqual("a[b]", 'a.getItem('b)) + assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1)) + assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b)) + } + + test("parenthesis") { + assertEqual("(a)", 'a) + assertEqual("r * (a + b)", 'r * ('a + 'b)) + } + + test("type constructors") { + // Dates. + assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11"))) + intercept[IllegalArgumentException] { + parseExpression("DAtE 'mar 11 2016'") + } + + // Timestamps. 
+ assertEqual("tImEstAmp '2016-03-11 20:54:00.000'", + Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) + intercept[IllegalArgumentException] { + parseExpression("timestamP '2016-33-11 20:54:00.000'") + } + + // Unsupported datatype. + intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.") + } + + test("literals") { + // NULL + assertEqual("null", Literal(null)) + + // Boolean + assertEqual("trUe", Literal(true)) + assertEqual("False", Literal(false)) + + // Integral should have the narrowest possible type + assertEqual("787324", Literal(787324)) + assertEqual("7873247234798249234", Literal(7873247234798249234L)) + assertEqual("78732472347982492793712334", + Literal(BigDecimal("78732472347982492793712334").underlying())) + + // Decimal + assertEqual("7873247234798249279371.2334", + Literal(BigDecimal("7873247234798249279371.2334").underlying())) + + // Scientific Decimal + assertEqual("9.0e1", 90d) + assertEqual(".9e+2", 90d) + assertEqual("0.9e+2", 90d) + assertEqual("900e-1", 90d) + assertEqual("900.0E-1", 90d) + assertEqual("9.e+1", 90d) + intercept(".e3") + + // Tiny Int Literal + assertEqual("10Y", Literal(10.toByte)) + intercept("-1000Y") + + // Small Int Literal + assertEqual("10S", Literal(10.toShort)) + intercept("40000S") + + // Long Int Literal + assertEqual("10L", Literal(10L)) + intercept("78732472347982492793712334L") + + // Double Literal + assertEqual("10.0D", Literal(10.0D)) + // TODO we need to figure out if we should throw an exception here! + assertEqual("1E309", Literal(Double.PositiveInfinity)) + } + + test("strings") { + // Single Strings. + assertEqual("\"hello\"", "hello") + assertEqual("'hello'", "hello") + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld") + assertEqual("'hello' \" \" 'world'", "hello world") + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. 
+ // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%") + assertEqual("'no-pattern\\%'", "no-pattern\\%") + assertEqual("'pattern\\\\%'", "pattern\\%") + assertEqual("'pattern\\\\\\%'", "pattern\\\\%") + + // Escaped characters. + // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html + assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00') + assertEqual("'\\''", "\'") // Single quote + assertEqual("'\\\"'", "\"") // Double quote + assertEqual("'\\b'", "\b") // Backspace + assertEqual("'\\n'", "\n") // Newline + assertEqual("'\\r'", "\r") // Carriage return + assertEqual("'\\t'", "\t") // Tab character + assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows) + + // Octals + assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") + + // Unicode + assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)") + } + + test("intervals") { + def intervalLiteral(u: String, s: String): Literal = { + Literal(CalendarInterval.fromSingleUnitString(u, s)) + } + + // Empty interval statement + intercept("interval", "at least one time unit should be given for interval literal") + + // Single Intervals. + val units = Seq( + "year", + "month", + "week", + "day", + "hour", + "minute", + "second", + "millisecond", + "microsecond") + val forms = Seq("", "s") + val values = Seq("0", "10", "-7", "21") + units.foreach { unit => + forms.foreach { form => + values.foreach { value => + val expected = intervalLiteral(unit, value) + assertEqual(s"interval $value $unit$form", expected) + assertEqual(s"interval '$value' $unit$form", expected) + } + } + } + + // Hive nanosecond notation. 
+ assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789")) + assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789")) + + // Non Existing unit + intercept("interval 10 nanoseconds", "No interval can be constructed") + + // Year-Month intervals. + val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0") + yearMonthValues.foreach { value => + val result = Literal(CalendarInterval.fromYearMonthString(value)) + assertEqual(s"interval '$value' year to month", result) + } + + // Day-Time intervals. + val datTimeValues = Seq( + "99 11:22:33.123456789", + "-99 11:22:33.123456789", + "10 9:8:7.123456789", + "1 0:0:0", + "-1 0:0:0", + "1 0:0:1") + datTimeValues.foreach { value => + val result = Literal(CalendarInterval.fromDayTimeString(value)) + assertEqual(s"interval '$value' day to second", result) + } + + // Unknown FROM TO intervals + intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.") + + // Composed intervals. 
+ assertEqual( + "interval 3 months 22 seconds 1 millisecond", + Literal(new CalendarInterval(3, 22001000L))) + assertEqual( + "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second", + Literal(new CalendarInterval(14, + 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND))) + } + + test("composed expressions") { + assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q")) + assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) + intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/600c0b69/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala new file mode 100644 index 0000000..4206d22 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class PlanParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { + comparePlans(parsePlan(sqlCommand), plan) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parsePlan(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("case insensitive") { + val plan = table("a").select(star()) + assertEqual("sELEct * FroM a", plan) + assertEqual("select * fRoM a", plan) + assertEqual("SELECT * FROM a", plan) + } + + test("show functions") { + assertEqual("show functions", ShowFunctions(None, None)) + assertEqual("show functions foo", ShowFunctions(None, Some("foo"))) + assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar"))) + assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*"))) + intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name") + } + + test("describe function") { + assertEqual("describe function bar", DescribeFunction("bar", isExtended = false)) + assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true)) + assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false)) + assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true)) + } + + test("set operations") { + val a = table("a").select(star()) + val b = table("b").select(star()) + + assertEqual("select * from a union select * from b", Distinct(a.union(b))) + assertEqual("select 
* from a union distinct select * from b", Distinct(a.union(b))) + assertEqual("select * from a union all select * from b", a.union(b)) + assertEqual("select * from a except select * from b", a.except(b)) + intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") + assertEqual("select * from a except distinct select * from b", a.except(b)) + assertEqual("select * from a intersect select * from b", a.intersect(b)) + intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") + assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) + } + + test("common table expressions") { + def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { + val ctes = namedPlans.map { + case (name, cte) => + name -> SubqueryAlias(name, cte) + }.toMap + With(plan, ctes) + } + assertEqual( + "with cte1 as (select * from a) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> table("a").select(star()))) + assertEqual( + "with cte1 (select 1) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1))) + assertEqual( + "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2", + cte(table("cte2").select(star()), + "cte1" -> OneRowRelation.select(1), + "cte2" -> table("cte1").select(star()))) + intercept( + "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1", + "Name 'cte1' is used for multiple common table expressions") + } + + test("simple select query") { + assertEqual("select 1", OneRowRelation.select(1)) + assertEqual("select a, b", OneRowRelation.select('a, 'b)) + assertEqual("select a, b from db.c", table("db", "c").select('a, 'b)) + assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) + assertEqual( + "select a, b from db.c having x < 1", + table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType))) + assertEqual("select distinct a, b from db.c", 
Distinct(table("db", "c").select('a, 'b))) + assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) + } + + test("reverse select query") { + assertEqual("from a", table("a")) + assertEqual("from a select b, c", table("a").select('b, 'c)) + assertEqual( + "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c)) + assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c))) + assertEqual( + "from (from a union all from b) c select *", + table("a").union(table("b")).as("c").select(star())) + } + + test("transform query spec") { + val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null) + assertEqual("select transform(a, b) using 'func' from e where f < 10", + p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + assertEqual("map a, b using 'func' as c, d from e", + p.copy(output = Seq('c.string, 'd.string))) + assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e", + p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + } + + test("multi select query") { + assertEqual( + "from a select * select * where s < 10", + table("a").select(star()).union(table("a").where('s < 10).select(star()))) + intercept( + "from a select * select * from x where a.s < 10", + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") + assertEqual( + "from a insert into tbl1 select * insert into tbl2 select * where s < 10", + table("a").select(star()).insertInto("tbl1").union( + table("a").where('s < 10).select(star()).insertInto("tbl2"))) + } + + test("query organization") { + // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows + val baseSql = "select * from t" + val basePlan = table("t").select(star()) + + val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame)) + val limitWindowClauses = Seq( + ("", (p: LogicalPlan) => p), + (" limit 10", (p: LogicalPlan) => 
p.limit(10)), + (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)), + (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10)) + ) + + val orderSortDistrClusterClauses = Seq( + ("", basePlan), + (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), + (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)), + (" distribute by a, b", basePlan.distribute('a, 'b)), + (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)), + (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc)) + ) + + orderSortDistrClusterClauses.foreach { + case (s1, p1) => + limitWindowClauses.foreach { + case (s2, pf2) => + assertEqual(baseSql + s1 + s2, pf2(p1)) + } + } + + val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported" + intercept(s"$baseSql order by a sort by a", msg) + intercept(s"$baseSql cluster by a distribute by a", msg) + intercept(s"$baseSql order by a cluster by a", msg) + intercept(s"$baseSql order by a distribute by a", msg) + } + + test("insert into") { + val sql = "select * from t" + val plan = table("t").select(star()) + def insert( + partition: Map[String, Option[String]], + overwrite: Boolean = false, + ifNotExists: Boolean = false): LogicalPlan = + InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + + // Single inserts + assertEqual(s"insert overwrite table s $sql", + insert(Map.empty, overwrite = true)) + assertEqual(s"insert overwrite table s if not exists $sql", + insert(Map.empty, overwrite = true, ifNotExists = true)) + assertEqual(s"insert into s $sql", + insert(Map.empty)) + assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql", + insert(Map("c" -> Option("d"), "e" -> Option("1")))) + assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql", + insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true)) + + // Multi insert + val plan2 = table("t").where('x > 
5).select(star()) + assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", + InsertIntoTable( + table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( + InsertIntoTable( + table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) + } + + test("aggregation") { + val sql = "select a, b, sum(c) as c from d group by a, b" + + // Normal + assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c"))) + + // Cube + assertEqual(s"$sql with cube", + table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Rollup + assertEqual(s"$sql with rollup", + table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Grouping Sets + assertEqual(s"$sql grouping sets((a, b), (a), ())", + GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) + intercept(s"$sql grouping sets((a, b), (c), ())", + "c doesn't show up in the GROUP BY list") + } + + test("limit") { + val sql = "select * from t" + val plan = table("t").select(star()) + assertEqual(s"$sql limit 10", plan.limit(10)) + assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType))) + } + + test("window spec") { + // Note that WindowSpecs are tested in the ExpressionParserSuite + val sql = "select * from t" + val plan = table("t").select(star()) + val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), + SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1))) + + // Test window resolution. + val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec) + assertEqual( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w1""".stripMargin, + WithWindowDefinition(ws1, plan)) + + // Fail with no reference. + intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'") + + // Fail when resolved reference is not a window spec. 
+ intercept( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w2""".stripMargin, + "Window reference 'w2' is not a window specification" + ) + } + + test("lateral view") { + // Single lateral view + assertEqual( + "select * from t lateral view explode(x) expl as x", + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + .select(star())) + + // Multiple lateral views + assertEqual( + """select * + |from t + |lateral view explode(x) expl + |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin, + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty) + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z")) + .select(star())) + + // Multi-Insert lateral views. + val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + assertEqual( + """from t1 + |lateral view explode(x) expl as x + |insert into t2 + |select * + |lateral view json_tuple(x, y) jtup q, z + |insert into t3 + |select * + |where s < 10 + """.stripMargin, + Union(from + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z")) + .select(star()) + .insertInto("t2"), + from.where('s < 10).select(star()).insertInto("t3"))) + + // Unsupported generator. + intercept( + "select * from t lateral view posexplode(x) posexpl as x, y", + "Generator function 'posexplode' is not supported") + } + + test("joins") { + // Test single joins. 
+ val testUnconditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t as tt $sql u", + table("t").as("tt").join(table("u"), jt, None).select(star())) + } + val testConditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u as uu on a = b", + table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star())) + } + val testNaturalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t tt natural $sql u as uu", + table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star())) + } + val testUsingJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u using(a, b)", + table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star())) + } + val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin) + + def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { + tests.foreach(_(sql, jt)) + } + test("cross join", Inner, Seq(testUnconditionalJoin)) + test(",", Inner, Seq(testUnconditionalJoin)) + test("join", Inner, testAll) + test("inner join", Inner, testAll) + test("left join", LeftOuter, testAll) + test("left outer join", LeftOuter, testAll) + test("right join", RightOuter, testAll) + test("right outer join", RightOuter, testAll) + test("full join", FullOuter, testAll) + test("full outer join", FullOuter, testAll) + + // Test multiple consecutive joins + assertEqual( + "select * from a join b join c right join d", + table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star())) + } + + test("sampled relations") { + val sql = "select * from t" + assertEqual(s"$sql tablesample(100 rows)", + table("t").limit(100).select(star())) + assertEqual(s"$sql tablesample(43 percent) as x", + Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + assertEqual(s"$sql tablesample(bucket 4 out of 10) as x", + 
Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", + "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported") + intercept(s"$sql tablesample(bucket 11 out of 10) as x", + s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]") + } + + test("sub-query") { + val plan = table("t0").select('id) + assertEqual("select id from (t0)", plan) + assertEqual("select id from ((((((t0))))))", plan) + assertEqual( + "(select * from t1) union distinct (select * from t2)", + Distinct(table("t1").select(star()).union(table("t2").select(star())))) + assertEqual( + "select * from ((select * from t1) union (select * from t2)) t", + Distinct( + table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star())) + assertEqual( + """select id + |from (((select id from t0) + | union all + | (select id from t0)) + | union all + | (select id from t0)) as u_1 + """.stripMargin, + plan.union(plan).union(plan).as("u_1").select('id)) + } + + test("scalar sub-query") { + assertEqual( + "select (select max(b) from s) ss from t", + table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss"))) + assertEqual( + "select * from t where a = (select b from s)", + table("t").where('a === ScalarSubquery(table("s").select('b))).select(star())) + assertEqual( + "select g from t group by g having a > (select b from s)", + table("t") + .groupBy('g)('g) + .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType))) + } + + test("table reference") { + assertEqual("table t", table("t")) + assertEqual("table d.t", table("d", "t")) + } + + test("inline table") { + assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( + Seq('col1.int), + Seq(1, 2, 3, 4).map(x => Row(x)))) + assertEqual( + "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)", + LocalRelation.fromExternalRows( + Seq('a.int, 'b.string), + Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, 
x._2))).as("tbl")) + intercept("values (a, 'a'), (b, 'b')", + "All expressions in an inline table must be constants.") + intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)", + "Number of aliases must match the number of fields in an inline table.") + intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)")) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/600c0b69/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala new file mode 100644 index 0000000..0874322 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.TableIdentifier + +class TableIdentifierParserSuite extends SparkFunSuite { + import CatalystSqlParser._ + + test("table identifier") { + // Regular names. + assert(TableIdentifier("q") === parseTableIdentifier("q")) + assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q")) + + // Illegal names. + intercept[ParseException](parseTableIdentifier("")) + intercept[ParseException](parseTableIdentifier("d.q.g")) + + // SQL Keywords. + val keywords = Seq("select", "from", "where", "left", "right") + keywords.foreach { keyword => + intercept[ParseException](parseTableIdentifier(keyword)) + assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) + assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/600c0b69/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 0541844..aa5d433 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample} import org.apache.spark.sql.catalyst.util._ /** @@ -32,6 +32,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { */ protected def 
normalizeExprIds(plan: LogicalPlan) = { plan transformAllExpressions { + case s: ScalarSubquery => + ScalarSubquery(s.query, ExprId(0)) case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => @@ -40,21 +42,25 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { } /** - * Normalizes the filter conditions that appear in the plan. For instance, - * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) - * etc., will all now be equivalent. + * Normalizes plans: + * - Filter: the filter conditions that appear in a plan are normalized. For instance, + * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2)) + * etc., will all now be equivalent. + * - Sample: the seed will be replaced by 0L. */ - private def normalizeFilters(plan: LogicalPlan) = { + private def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { case filter @ Filter(condition: Expression, child: LogicalPlan) => Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child) + case sample: Sample => + sample.copy(seed = 0L)(true) } } /** Fails the test if the two plans do not match */ protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { - val normalized1 = normalizeFilters(normalizeExprIds(plan1)) - val normalized2 = normalizeFilters(normalizeExprIds(plan2)) + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) if (normalized1 != normalized2) { fail( s""" http://git-wip-us.apache.org/repos/asf/spark/blob/600c0b69/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala new file mode 100644 index
0000000..c098fa9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.parser.ng.{AbstractSqlParser, AstBuilder} +import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.execution.command.{DescribeCommand => _, _} +import org.apache.spark.sql.execution.datasources._ + +/** + * Concrete parser for Spark SQL statements. + */ +object SparkSqlParser extends AbstractSqlParser{ + val astBuilder = new SparkSqlAstBuilder +} + +/** + * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. + */ +class SparkSqlAstBuilder extends AstBuilder { + import org.apache.spark.sql.catalyst.parser.ng.ParserUtils._ + + /** + * Create a [[SetCommand]] logical plan. 
+ * + * Everything after the SET keyword is assumed to be a part of the + * key-value pair. The split between key and value is made by searching for the first `=` + * character in the raw string. + */ + override def visitSetConfiguration(ctx: SetConfigurationContext): LogicalPlan = withOrigin(ctx) { + // Construct the command. + val raw = remainder(ctx.SET.getSymbol) + val keyValueSeparatorIndex = raw.indexOf('=') + if (keyValueSeparatorIndex >= 0) { + val key = raw.substring(0, keyValueSeparatorIndex).trim + val value = raw.substring(keyValueSeparatorIndex + 1).trim + SetCommand(Some(key -> Option(value))) + } else if (raw.nonEmpty) { + SetCommand(Some(raw.trim -> None)) + } else { + SetCommand(None) + } + } + + /** + * Create a [[SetDatabaseCommand]] logical plan. + */ + override def visitUse(ctx: UseContext): LogicalPlan = withOrigin(ctx) { + SetDatabaseCommand(ctx.db.getText) + } + + /** + * Create a [[ShowTablesCommand]] logical plan. + */ + override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) { + if (ctx.LIKE != null) { + logWarning("SHOW TABLES LIKE option is ignored.") + } + ShowTablesCommand(Option(ctx.db).map(_.getText)) + } + + /** + * Create a [[RefreshTable]] logical plan. + */ + override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) { + RefreshTable(visitTableIdentifier(ctx.tableIdentifier)) + } + + /** + * Create a [[CacheTableCommand]] logical plan. + */ + override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) { + val query = Option(ctx.query).map(plan) + CacheTableCommand(ctx.identifier.getText, query, ctx.LAZY != null) + } + + /** + * Create an [[UncacheTableCommand]] logical plan. + */ + override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) { + UncacheTableCommand(ctx.identifier.getText) + } + + /** + * Create a [[ClearCacheCommand]] logical plan.
+ */ + override def visitClearCache(ctx: ClearCacheContext): LogicalPlan = withOrigin(ctx) { + ClearCacheCommand + } + + /** + * Create an [[ExplainCommand]] logical plan. + */ + override def visitExplain(ctx: ExplainContext): LogicalPlan = withOrigin(ctx) { + val options = ctx.explainOption.asScala + if (options.exists(_.FORMATTED != null)) { + logWarning("EXPLAIN FORMATTED option is ignored.") + } + if (options.exists(_.LOGICAL != null)) { + logWarning("EXPLAIN LOGICAL option is ignored.") + } + + // Create the explain command. + val statement = plan(ctx.statement) + if (isExplainableStatement(statement)) { + ExplainCommand(statement, extended = options.exists(_.EXTENDED != null)) + } else { + ExplainCommand(OneRowRelation) + } + } + + /** + * Determine if a plan should be explained at all. + */ + protected def isExplainableStatement(plan: LogicalPlan): Boolean = plan match { + case _: datasources.DescribeCommand => false + case _ => true + } + + /** + * Create a [[DescribeCommand]] logical plan. + */ + override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) { + // FORMATTED, columns and partition specs are not supported. Return null and let the parser + // decide what to do with this (create an exception or pass it on to a different system). + if (ctx.describeColName != null || ctx.FORMATTED != null || ctx.partitionSpec != null) { + null + } else { + datasources.DescribeCommand( + visitTableIdentifier(ctx.tableIdentifier), + ctx.EXTENDED != null) + } + } + + /** Type to keep track of a table header. */ + type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean) + + /** + * Validate a create table statement and return the [[TableHeader]] (the [[TableIdentifier]] + * plus the temporary, if-not-exists and external flags).
+ */ + override def visitCreateTableHeader( + ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { + val temporary = ctx.TEMPORARY != null + val ifNotExists = ctx.EXISTS != null + assert(!temporary || !ifNotExists, + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.", + ctx) + (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null) + } + + /** + * Create a [[CreateTableUsing]] or a [[CreateTableUsingAsSelect]] logical plan. + * + * TODO add bucketing and partitioning. + */ + override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) { + val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + if (external) { + logWarning("EXTERNAL option is not supported.") + } + val options = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) + val provider = ctx.tableProvider.qualifiedName.getText + + if (ctx.query != null) { + // Get the backing query. + val query = plan(ctx.query) + + // Determine the storage mode. + val mode = if (ifNotExists) { + SaveMode.Ignore + } else if (temp) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + CreateTableUsingAsSelect(table, provider, temp, Array.empty, None, mode, options, query) + } else { + val struct = Option(ctx.colTypeList).map(createStructType) + CreateTableUsing(table, struct, provider, temp, options, ifNotExists, managedIfNoPath = false) + } + } + + /** + * Convert a table property list into a key-value map. + */ + override def visitTablePropertyList( + ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { + ctx.tableProperty.asScala.map { property => + // A key can either be a String or a collection of dot separated elements. We need to treat + // these differently. 
+ val key = if (property.key.STRING != null) { + string(property.key.STRING) + } else { + property.key.getText + } + val value = Option(property.value).map(string).orNull + key -> value + }.toMap + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/600c0b69/sql/core/src/main/scala/org/apache/spark/sql/functions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8abb9d7..7ce15e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.parser.CatalystQl import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint +import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -1172,8 +1172,7 @@ object functions { * @group normal_funcs */ def expr(expr: String): Column = { - val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl()) - Column(parser.parseExpression(expr)) + Column(SparkSqlParser.parseExpression(expr)) } ////////////////////////////////////////////////////////////////////////////////////////////// http://git-wip-us.apache.org/repos/asf/spark/blob/600c0b69/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index e5f02ca..9bc6407 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -81,7 +81,7 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. */ - lazy val sqlParser: ParserInterface = new SparkQl(conf) + lazy val sqlParser: ParserInterface = SparkSqlParser /** * Planner that converts optimized logical plans to physical plans. http://git-wip-us.apache.org/repos/asf/spark/blob/600c0b69/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 5af1a4f..a5a4ff1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -329,8 +329,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("full outer join") { - upperCaseData.where('N <= 4).registerTempTable("left") - upperCaseData.where('N >= 3).registerTempTable("right") + upperCaseData.where('N <= 4).registerTempTable("`left`") + upperCaseData.where('N >= 3).registerTempTable("`right`") val left = UnresolvedRelation(TableIdentifier("left"), None) val right = UnresolvedRelation(TableIdentifier("right"), None) http://git-wip-us.apache.org/repos/asf/spark/blob/600c0b69/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c958eac..b727e88 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1656,7 +1656,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val e2 = intercept[AnalysisException] { sql("select interval 23 nanosecond") } - assert(e2.message.contains("cannot recognize input near")) + assert(e2.message.contains("No interval can be constructed")) } test("SPARK-8945: add and subtract expressions for interval type") { --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
