This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d1af925f6e1 [SPARK-44282][CONNECT] Prepare DataType parsing for use in Spark Connect Scala Client
d1af925f6e1 is described below
commit d1af925f6e12ce5ff62c13ffa9ed32e92d548863
Author: Herman van Hovell <[email protected]>
AuthorDate: Wed Jul 5 16:22:49 2023 -0400
[SPARK-44282][CONNECT] Prepare DataType parsing for use in Spark Connect Scala Client
### What changes were proposed in this pull request?
This PR prepares the move of DataType parsing to sql/api. It puts all DataType parsing functionality into a superclass of the regular parser. We cannot move the parser itself yet, because that needs to happen at the same time as the move of DataType.
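For orientation, here is a simplified sketch of the class hierarchy after this change (signatures condensed from the diff below; not verbatim source):

    // DataTypeAstBuilder carries only the DataType-related visitor methods.
    class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper

    // AstBuilder keeps the full Catalyst functionality by extending it.
    class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging

    // AbstractParser can parse only DataTypes and table schemas...
    abstract class AbstractParser extends DataTypeParserInterface with SQLConfHelper with Logging

    // ...while AbstractSqlParser layers plan/expression/identifier parsing on top.
    abstract class AbstractSqlParser extends AbstractParser with ParserInterface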
### Why are the changes needed?
We want the Spark Connect Scala Client to use a restricted class path. DataType will be one of the shared classes, and moving DataType requires moving DataType parsing along with it.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Modified existing tests.
Closes #41836 from hvanhovell/SPARK-44282.
Authored-by: Herman van Hovell <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
.../sql/catalyst/parser/AbstractSqlParser.scala | 79 ++++++++
.../spark/sql/catalyst/parser/AstBuilder.scala | 168 +----------------
.../sql/catalyst/parser/CatalystSqlParser.scala | 27 +++
.../sql/catalyst/parser/DataTypeAstBuilder.scala | 208 +++++++++++++++++++++
.../catalyst/parser/DataTypeParserInterface.scala | 37 ++++
.../sql/catalyst/parser/ParserInterface.scala | 16 +-
.../parser/{ParseDriver.scala => parsers.scala} | 71 +------
.../org/apache/spark/sql/types/DataType.scala | 6 +-
.../org/apache/spark/sql/types/StructType.scala | 12 +-
.../sql/catalyst/parser/DataTypeParserSuite.scala | 2 +-
10 files changed, 369 insertions(+), 257 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSqlParser.scala
new file mode 100644
index 00000000000..2d6fabaaef6
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSqlParser.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.ParserUtils.withOrigin
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.trees.Origin
+import org.apache.spark.sql.errors.QueryParsingErrors
+
+/**
+ * Base class for all ANTLR4 [[ParserInterface]] implementations.
+ */
+abstract class AbstractSqlParser extends AbstractParser with ParserInterface {
+ override def astBuilder: AstBuilder
+
+ /** Creates Expression for a given SQL string. */
+ override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
+ val ctx = parser.singleExpression()
+ withOrigin(ctx, Some(sqlText)) {
+ astBuilder.visitSingleExpression(ctx)
+ }
+ }
+
+ /** Creates TableIdentifier for a given SQL string. */
+ override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser =>
+ astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier())
+ }
+
+ /** Creates FunctionIdentifier for a given SQL string. */
+ override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
+ parse(sqlText) { parser =>
+ astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
+ }
+ }
+
+ /** Creates a multi-part identifier for a given SQL string */
+ override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
+ parse(sqlText) { parser =>
+ astBuilder.visitSingleMultipartIdentifier(parser.singleMultipartIdentifier())
+ }
+ }
+
+ /** Creates LogicalPlan for a given SQL string of query. */
+ override def parseQuery(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
+ val ctx = parser.query()
+ withOrigin(ctx, Some(sqlText)) {
+ astBuilder.visitQuery(ctx)
+ }
+ }
+
+ /** Creates LogicalPlan for a given SQL string. */
+ override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
+ val ctx = parser.singleStatement()
+ withOrigin(ctx, Some(sqlText)) {
+ astBuilder.visitSingleStatement(ctx) match {
+ case plan: LogicalPlan => plan
+ case _ =>
+ val position = Origin(None, None)
+ throw QueryParsingErrors.sqlStatementUnsupportedError(sqlText, position)
+ }
+ }
+ }
+}
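As an illustrative sketch (not part of this change), these entry points are used like this:

    import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

    // Both calls throw ParseException on invalid input.
    val plan = CatalystSqlParser.parsePlan("SELECT a, b FROM t WHERE a > 1")
    val expr = CatalystSqlParser.parseExpression("a + 1")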
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index 488b4e46735..99fa0bf9809 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -55,14 +55,10 @@ import org.apache.spark.util.random.RandomSampler
* The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
* TableIdentifier.
*/
-class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper with Logging {
+class AstBuilder extends DataTypeAstBuilder with SQLConfHelper with Logging {
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
import ParserUtils._
- protected def typedVisit[T](ctx: ParseTree): T = {
- ctx.accept(this).asInstanceOf[T]
- }
-
protected def withIdentClause(
ctx: IdentifierReferenceContext,
builder: Seq[String] => LogicalPlan): LogicalPlan = {
@@ -3025,96 +3021,6 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
* DataType parsing
*
********************************************************************************************
*/
- /**
- * Resolve/create a primitive type.
- */
- override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
- val typeName = ctx.`type`.start.getType
- (typeName, ctx.INTEGER_VALUE().asScala.toList) match {
- case (BOOLEAN, Nil) => BooleanType
- case (TINYINT | BYTE, Nil) => ByteType
- case (SMALLINT | SHORT, Nil) => ShortType
- case (INT | INTEGER, Nil) => IntegerType
- case (BIGINT | LONG, Nil) => LongType
- case (FLOAT | REAL, Nil) => FloatType
- case (DOUBLE, Nil) => DoubleType
- case (DATE, Nil) => DateType
- case (TIMESTAMP, Nil) => SQLConf.get.timestampType
- case (TIMESTAMP_NTZ, Nil) => TimestampNTZType
- case (TIMESTAMP_LTZ, Nil) => TimestampType
- case (STRING, Nil) => StringType
- case (CHARACTER | CHAR, length :: Nil) => CharType(length.getText.toInt)
- case (VARCHAR, length :: Nil) => VarcharType(length.getText.toInt)
- case (BINARY, Nil) => BinaryType
- case (DECIMAL | DEC | NUMERIC, Nil) => DecimalType.USER_DEFAULT
- case (DECIMAL | DEC | NUMERIC, precision :: Nil) =>
- DecimalType(precision.getText.toInt, 0)
- case (DECIMAL | DEC | NUMERIC, precision :: scale :: Nil) =>
- DecimalType(precision.getText.toInt, scale.getText.toInt)
- case (VOID, Nil) => NullType
- case (INTERVAL, Nil) => CalendarIntervalType
- case (CHARACTER | CHAR | VARCHAR, Nil) =>
- throw QueryParsingErrors.charTypeMissingLengthError(ctx.`type`.getText, ctx)
- case (ARRAY | STRUCT | MAP, Nil) =>
- throw QueryParsingErrors.nestedTypeMissingElementTypeError(ctx.`type`.getText, ctx)
- case (_, params) =>
- val badType = ctx.`type`.getText
- val dtStr = if (params.nonEmpty) s"$badType(${params.mkString(",")})" else badType
- throw QueryParsingErrors.dataTypeUnsupportedError(dtStr, ctx)
- }
- }
-
- override def visitYearMonthIntervalDataType(ctx: YearMonthIntervalDataTypeContext): DataType = {
- val startStr = ctx.from.getText.toLowerCase(Locale.ROOT)
- val start = YearMonthIntervalType.stringToField(startStr)
- if (ctx.to != null) {
- val endStr = ctx.to.getText.toLowerCase(Locale.ROOT)
- val end = YearMonthIntervalType.stringToField(endStr)
- if (end <= start) {
- throw QueryParsingErrors.fromToIntervalUnsupportedError(startStr, endStr, ctx)
- }
- YearMonthIntervalType(start, end)
- } else {
- YearMonthIntervalType(start)
- }
- }
-
- override def visitDayTimeIntervalDataType(ctx: DayTimeIntervalDataTypeContext): DataType = {
- val startStr = ctx.from.getText.toLowerCase(Locale.ROOT)
- val start = DayTimeIntervalType.stringToField(startStr)
- if (ctx.to != null ) {
- val endStr = ctx.to.getText.toLowerCase(Locale.ROOT)
- val end = DayTimeIntervalType.stringToField(endStr)
- if (end <= start) {
- throw QueryParsingErrors.fromToIntervalUnsupportedError(startStr, endStr, ctx)
- }
- DayTimeIntervalType(start, end)
- } else {
- DayTimeIntervalType(start)
- }
- }
-
- /**
- * Create a complex DataType. Arrays, Maps and Structures are supported.
- */
- override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) {
- ctx.complex.getType match {
- case SqlBaseParser.ARRAY =>
- ArrayType(typedVisit(ctx.dataType(0)))
- case SqlBaseParser.MAP =>
- MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1)))
- case SqlBaseParser.STRUCT =>
- StructType(Option(ctx.complexColTypeList).toArray.flatMap(visitComplexColTypeList))
- }
- }
-
- /**
- * Create top level table schema.
- */
- protected def createSchema(ctx: ColTypeListContext): StructType = {
- StructType(Option(ctx).toArray.flatMap(visitColTypeList))
- }
-
/**
* Create top level table schema.
*/
@@ -3122,32 +3028,6 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
StructType(Option(ctx).toArray.flatMap(visitCreateOrReplaceTableColTypeList))
}
- /**
- * Create a [[StructType]] from a number of column definitions.
- */
- override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) {
- ctx.colType().asScala.map(visitColType).toSeq
- }
-
- /**
- * Create a top level [[StructField]] from a column definition.
- */
- override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) {
- import ctx._
-
- val builder = new MetadataBuilder
- // Add comment to metadata
- Option(commentSpec()).map(visitCommentSpec).foreach {
- builder.putString("comment", _)
- }
-
- StructField(
- name = colName.getText,
- dataType = typedVisit[DataType](ctx.dataType),
- nullable = NULL == null,
- metadata = builder.build())
- }
-
/**
* Create a [[StructType]] from a number of CREATE TABLE column definitions.
*/
@@ -3229,33 +3109,6 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
metadata = builder.build())
}
- /**
- * Create a [[StructType]] from a sequence of [[StructField]]s.
- */
- protected def createStructType(ctx: ComplexColTypeListContext): StructType = {
- StructType(Option(ctx).toArray.flatMap(visitComplexColTypeList))
- }
-
- /**
- * Create a [[StructType]] from a number of column definitions.
- */
- override def visitComplexColTypeList(
- ctx: ComplexColTypeListContext): Seq[StructField] = withOrigin(ctx) {
- ctx.complexColType().asScala.map(visitComplexColType).toSeq
- }
-
- /**
- * Create a [[StructField]] from a column definition.
- */
- override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) {
- import ctx._
- val structField = StructField(
- name = identifier.getText,
- dataType = typedVisit(dataType()),
- nullable = NULL == null)
- Option(commentSpec).map(visitCommentSpec).map(structField.withComment).getOrElse(structField)
- }
-
/**
* Create a location string.
*/
@@ -3270,13 +3123,6 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
ctx.asScala.headOption.map(visitLocationSpec)
}
- /**
- * Create a comment string.
- */
- override def visitCommentSpec(ctx: CommentSpecContext): String = withOrigin(ctx) {
- string(visitStringLit(ctx.stringLit))
- }
-
private def verifyAndGetExpression(exprCtx: ExpressionContext): String = {
// Make sure it can be converted to Catalyst expressions.
expression(exprCtx)
@@ -3418,18 +3264,6 @@ class AstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper wit
OptionList(options)
}
- override def visitStringLit(ctx: StringLitContext): Token = {
- if (ctx != null) {
- if (ctx.STRING_LITERAL != null) {
- ctx.STRING_LITERAL.getSymbol
- } else {
- ctx.DOUBLEQUOTED_STRING.getSymbol
- }
- } else {
- null
- }
- }
-
/**
* Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal).
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystSqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystSqlParser.scala
new file mode 100644
index 00000000000..8601111d9d4
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystSqlParser.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+/**
+ * Concrete SQL parser for Catalyst-only SQL statements.
+ */
+class CatalystSqlParser extends AbstractSqlParser {
+ override val astBuilder: AstBuilder = new AstBuilder
+}
+
+/** For test-only. */
+object CatalystSqlParser extends CatalystSqlParser
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
new file mode 100644
index 00000000000..84a8bc71b3f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import java.util.Locale
+
+import scala.collection.JavaConverters._
+
+import org.antlr.v4.runtime.Token
+import org.antlr.v4.runtime.tree.ParseTree
+
+import org.apache.spark.sql.catalyst.SQLConfHelper
+import org.apache.spark.sql.catalyst.parser.ParserUtils.{string, withOrigin}
+import org.apache.spark.sql.catalyst.parser.SqlBaseParser._
+import org.apache.spark.sql.errors.QueryParsingErrors
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, CalendarIntervalType, CharType, DataType, DateType, DayTimeIntervalType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MapType, MetadataBuilder, NullType, ShortType, StringType, StructField, StructType, TimestampNTZType, TimestampType, VarcharType, YearMonthIntervalType}
+
+class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] with SQLConfHelper {
+ protected def typedVisit[T](ctx: ParseTree): T = {
+ ctx.accept(this).asInstanceOf[T]
+ }
+
+ override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) {
+ typedVisit[DataType](ctx.dataType)
+ }
+
+ override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = {
+ withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList)))
+ }
+
+ override def visitStringLit(ctx: StringLitContext): Token = {
+ if (ctx != null) {
+ if (ctx.STRING_LITERAL != null) {
+ ctx.STRING_LITERAL.getSymbol
+ } else {
+ ctx.DOUBLEQUOTED_STRING.getSymbol
+ }
+ } else {
+ null
+ }
+ }
+
+ /**
+ * Resolve/create a primitive type.
+ */
+ override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) {
+ val typeName = ctx.`type`.start.getType
+ (typeName, ctx.INTEGER_VALUE().asScala.toList) match {
+ case (BOOLEAN, Nil) => BooleanType
+ case (TINYINT | BYTE, Nil) => ByteType
+ case (SMALLINT | SHORT, Nil) => ShortType
+ case (INT | INTEGER, Nil) => IntegerType
+ case (BIGINT | LONG, Nil) => LongType
+ case (FLOAT | REAL, Nil) => FloatType
+ case (DOUBLE, Nil) => DoubleType
+ case (DATE, Nil) => DateType
+ case (TIMESTAMP, Nil) => SQLConf.get.timestampType
+ case (TIMESTAMP_NTZ, Nil) => TimestampNTZType
+ case (TIMESTAMP_LTZ, Nil) => TimestampType
+ case (STRING, Nil) => StringType
+ case (CHARACTER | CHAR, length :: Nil) => CharType(length.getText.toInt)
+ case (VARCHAR, length :: Nil) => VarcharType(length.getText.toInt)
+ case (BINARY, Nil) => BinaryType
+ case (DECIMAL | DEC | NUMERIC, Nil) => DecimalType.USER_DEFAULT
+ case (DECIMAL | DEC | NUMERIC, precision :: Nil) =>
+ DecimalType(precision.getText.toInt, 0)
+ case (DECIMAL | DEC | NUMERIC, precision :: scale :: Nil) =>
+ DecimalType(precision.getText.toInt, scale.getText.toInt)
+ case (VOID, Nil) => NullType
+ case (INTERVAL, Nil) => CalendarIntervalType
+ case (CHARACTER | CHAR | VARCHAR, Nil) =>
+ throw QueryParsingErrors.charTypeMissingLengthError(ctx.`type`.getText, ctx)
+ case (ARRAY | STRUCT | MAP, Nil) =>
+ throw QueryParsingErrors.nestedTypeMissingElementTypeError(ctx.`type`.getText, ctx)
+ case (_, params) =>
+ val badType = ctx.`type`.getText
+ val dtStr = if (params.nonEmpty) s"$badType(${params.mkString(",")})" else badType
+ throw QueryParsingErrors.dataTypeUnsupportedError(dtStr, ctx)
+ }
+ }
+
+ override def visitYearMonthIntervalDataType(ctx: YearMonthIntervalDataTypeContext): DataType = {
+ val startStr = ctx.from.getText.toLowerCase(Locale.ROOT)
+ val start = YearMonthIntervalType.stringToField(startStr)
+ if (ctx.to != null) {
+ val endStr = ctx.to.getText.toLowerCase(Locale.ROOT)
+ val end = YearMonthIntervalType.stringToField(endStr)
+ if (end <= start) {
+ throw QueryParsingErrors.fromToIntervalUnsupportedError(startStr, endStr, ctx)
+ }
+ YearMonthIntervalType(start, end)
+ } else {
+ YearMonthIntervalType(start)
+ }
+ }
+
+ override def visitDayTimeIntervalDataType(ctx: DayTimeIntervalDataTypeContext): DataType = {
+ val startStr = ctx.from.getText.toLowerCase(Locale.ROOT)
+ val start = DayTimeIntervalType.stringToField(startStr)
+ if (ctx.to != null ) {
+ val endStr = ctx.to.getText.toLowerCase(Locale.ROOT)
+ val end = DayTimeIntervalType.stringToField(endStr)
+ if (end <= start) {
+ throw QueryParsingErrors.fromToIntervalUnsupportedError(startStr, endStr, ctx)
+ }
+ DayTimeIntervalType(start, end)
+ } else {
+ DayTimeIntervalType(start)
+ }
+ }
+
+ /**
+ * Create a complex DataType. Arrays, Maps and Structures are supported.
+ */
+ override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) {
+ ctx.complex.getType match {
+ case SqlBaseParser.ARRAY =>
+ ArrayType(typedVisit(ctx.dataType(0)))
+ case SqlBaseParser.MAP =>
+ MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1)))
+ case SqlBaseParser.STRUCT =>
+ StructType(Option(ctx.complexColTypeList).toArray.flatMap(visitComplexColTypeList))
+ }
+ }
+
+ /**
+ * Create top level table schema.
+ */
+ protected def createSchema(ctx: ColTypeListContext): StructType = {
+ StructType(Option(ctx).toArray.flatMap(visitColTypeList))
+ }
+
+ /**
+ * Create a [[StructType]] from a number of column definitions.
+ */
+ override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) {
+ ctx.colType().asScala.map(visitColType).toSeq
+ }
+
+ /**
+ * Create a top level [[StructField]] from a column definition.
+ */
+ override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) {
+ import ctx._
+
+ val builder = new MetadataBuilder
+ // Add comment to metadata
+ Option(commentSpec()).map(visitCommentSpec).foreach {
+ builder.putString("comment", _)
+ }
+
+ StructField(
+ name = colName.getText,
+ dataType = typedVisit[DataType](ctx.dataType),
+ nullable = NULL == null,
+ metadata = builder.build())
+ }
+
+ /**
+ * Create a [[StructType]] from a sequence of [[StructField]]s.
+ */
+ protected def createStructType(ctx: ComplexColTypeListContext): StructType = {
+ StructType(Option(ctx).toArray.flatMap(visitComplexColTypeList))
+ }
+
+ /**
+ * Create a [[StructType]] from a number of column definitions.
+ */
+ override def visitComplexColTypeList(ctx: ComplexColTypeListContext): Seq[StructField] = {
+ withOrigin(ctx) {
+ ctx.complexColType().asScala.map(visitComplexColType).toSeq
+ }
+ }
+
+ /**
+ * Create a [[StructField]] from a column definition.
+ */
+ override def visitComplexColType(ctx: ComplexColTypeContext): StructField = withOrigin(ctx) {
+ import ctx._
+ val structField = StructField(
+ name = identifier.getText,
+ dataType = typedVisit(dataType()),
+ nullable = NULL == null)
+ Option(commentSpec).map(visitCommentSpec).map(structField.withComment).getOrElse(structField)
+ }
+
+ /**
+ * Create a comment string.
+ */
+ override def visitCommentSpec(ctx: CommentSpecContext): String = withOrigin(ctx) {
+ string(visitStringLit(ctx.stringLit))
+ }
+}
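To make the visitor's type mapping concrete, a few hedged examples of inputs and the DataTypes they resolve to (through the DataTypeParser object defined in parsers.scala below):

    DataTypeParser.parseDataType("tinyint")      // ByteType
    DataTypeParser.parseDataType("char(5)")      // CharType(5)
    DataTypeParser.parseDataType("decimal(10)")  // DecimalType(10, 0)
    DataTypeParser.parseDataType("void")         // NullType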
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala
new file mode 100644
index 00000000000..ab665f360b0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserInterface.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.parser
+
+import org.apache.spark.sql.types.{DataType, StructType}
+
+/**
+ * Interface for [[DataType]] parsing functionality.
+ */
+trait DataTypeParserInterface {
+ /**
+ * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list
+ * of field definitions which will preserve the correct Hive metadata.
+ */
+ @throws[ParseException]("Text cannot be parsed to a schema")
+ def parseTableSchema(sqlText: String): StructType
+
+ /**
+ * Parse a string to a [[DataType]].
+ */
+ @throws[ParseException]("Text cannot be parsed to a DataType")
+ def parseDataType(sqlText: String): DataType
+}
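A hedged usage sketch of this interface, via the DataTypeParser object introduced in parsers.scala below:

    import org.apache.spark.sql.catalyst.parser.DataTypeParser
    import org.apache.spark.sql.types.{DataType, StructType}

    val dt: DataType = DataTypeParser.parseDataType("decimal(10, 2)")
    // parseTableSchema accepts a comma separated list of field definitions.
    val schema: StructType = DataTypeParser.parseTableSchema("a INT, b STRING")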
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
index 46dfbf24778..3aec1dd4311 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala
@@ -21,13 +21,12 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.{DataType, StructType}
/**
* Interface for a parser.
*/
@DeveloperApi
-trait ParserInterface {
+trait ParserInterface extends DataTypeParserInterface {
/**
* Parse a string to a [[LogicalPlan]].
*/
@@ -58,19 +57,6 @@ trait ParserInterface {
@throws[ParseException]("Text cannot be parsed to a multi-part identifier")
def parseMultipartIdentifier(sqlText: String): Seq[String]
- /**
- * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list
- * of field definitions which will preserve the correct Hive metadata.
- */
- @throws[ParseException]("Text cannot be parsed to a schema")
- def parseTableSchema(sqlText: String): StructType
-
- /**
- * Parse a string to a [[DataType]].
- */
- @throws[ParseException]("Text cannot be parsed to a DataType")
- def parseDataType(sqlText: String): DataType
-
/**
* Parse a query string to a [[LogicalPlan]].
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
similarity index 85%
rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
index 1fb23c4a71e..27670544e1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/parsers.scala
@@ -24,10 +24,7 @@ import org.antlr.v4.runtime.tree.TerminalNodeImpl
import org.apache.spark.{QueryContext, SparkThrowableHelper}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, SQLConfHelper, TableIdentifier}
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.parser.ParserUtils.withOrigin
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
import org.apache.spark.sql.errors.QueryParsingErrors
import org.apache.spark.sql.types.{DataType, StructType}
@@ -35,40 +32,12 @@ import org.apache.spark.sql.types.{DataType, StructType}
/**
* Base SQL parsing infrastructure.
*/
-abstract class AbstractSqlParser extends ParserInterface with SQLConfHelper with Logging {
-
+abstract class AbstractParser extends DataTypeParserInterface with SQLConfHelper with Logging {
/** Creates/Resolves DataType for a given SQL string. */
override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser =>
astBuilder.visitSingleDataType(parser.singleDataType())
}
- /** Creates Expression for a given SQL string. */
- override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser =>
- val ctx = parser.singleExpression()
- withOrigin(ctx, Some(sqlText)) {
- astBuilder.visitSingleExpression(ctx)
- }
- }
-
- /** Creates TableIdentifier for a given SQL string. */
- override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser =>
- astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier())
- }
-
- /** Creates FunctionIdentifier for a given SQL string. */
- override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = {
- parse(sqlText) { parser =>
- astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier())
- }
- }
-
- /** Creates a multi-part identifier for a given SQL string */
- override def parseMultipartIdentifier(sqlText: String): Seq[String] = {
- parse(sqlText) { parser =>
- astBuilder.visitSingleMultipartIdentifier(parser.singleMultipartIdentifier())
- }
- }
-
/**
* Creates StructType for a given SQL string, which is a comma separated list of field
* definitions which will preserve the correct Hive metadata.
@@ -77,29 +46,8 @@ abstract class AbstractSqlParser extends ParserInterface with SQLConfHelper with
astBuilder.visitSingleTableSchema(parser.singleTableSchema())
}
- /** Creates LogicalPlan for a given SQL string of query. */
- override def parseQuery(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
- val ctx = parser.query()
- withOrigin(ctx, Some(sqlText)) {
- astBuilder.visitQuery(ctx)
- }
- }
-
- /** Creates LogicalPlan for a given SQL string. */
- override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser =>
- val ctx = parser.singleStatement()
- withOrigin(ctx, Some(sqlText)) {
- astBuilder.visitSingleStatement(ctx) match {
- case plan: LogicalPlan => plan
- case _ =>
- val position = Origin(None, None)
- throw QueryParsingErrors.sqlStatementUnsupportedError(sqlText, position)
- }
- }
- }
-
/** Get the builder (visitor) which converts a ParseTree into an AST. */
- protected def astBuilder: AstBuilder
+ protected def astBuilder: DataTypeAstBuilder
protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
logDebug(s"Parsing command: $command")
@@ -153,16 +101,6 @@ abstract class AbstractSqlParser extends ParserInterface with SQLConfHelper with
}
}
-/**
- * Concrete SQL parser for Catalyst-only SQL statements.
- */
-class CatalystSqlParser extends AbstractSqlParser {
- val astBuilder = new AstBuilder
-}
-
-/** For test-only. */
-object CatalystSqlParser extends CatalystSqlParser
-
/**
* This string stream provides the lexer with upper case characters only. This greatly simplifies
* lexing the stream, while we can maintain the original command.
@@ -438,5 +376,8 @@ case class UnclosedCommentProcessor(
stop = Origin(Option(failedToken.getStopIndex)))
}
}
+}
+object DataTypeParser extends AbstractParser {
+ override protected def astBuilder: DataTypeAstBuilder = new DataTypeAstBuilder
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
index 4b701dc2438..29201b053cf 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala
@@ -30,7 +30,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkThrowable
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.analysis.Resolver
-import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.parser.DataTypeParser
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer}
import org.apache.spark.sql.catalyst.util.StringConcat
@@ -111,8 +111,8 @@ object DataType {
def fromDDL(ddl: String): DataType = {
parseTypeWithFallback(
ddl,
- CatalystSqlParser.parseDataType,
- fallbackParser = str => CatalystSqlParser.parseTableSchema(str))
+ DataTypeParser.parseDataType,
+ fallbackParser = str => DataTypeParser.parseTableSchema(str))
}
/**
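For reference, a hedged sketch of the unchanged public behavior: fromDDL first tries the plain DataType parser and falls back to table-schema parsing.

    import org.apache.spark.sql.types.DataType

    val t1 = DataType.fromDDL("map<string, array<int>>")  // parsed directly
    val t2 = DataType.fromDDL("a INT, b STRING")          // parsed via the fallback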
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 5857aaa9530..013d416bc97 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -26,7 +26,7 @@ import org.json4s.JsonDSL._
import org.apache.spark.annotation.Stable
import org.apache.spark.sql.catalyst.analysis.Resolver
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
-import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, LegacyTypeStringParser}
+import org.apache.spark.sql.catalyst.parser.{DataTypeParser, LegacyTypeStringParser}
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{SparkStringUtils, StringConcat}
@@ -219,7 +219,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
* }}}
*/
def add(name: String, dataType: String): StructType = {
- add(name, CatalystSqlParser.parseDataType(dataType), nullable = true, Metadata.empty)
+ add(name, DataTypeParser.parseDataType(dataType), nullable = true, Metadata.empty)
}
/**
@@ -234,7 +234,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
* }}}
*/
def add(name: String, dataType: String, nullable: Boolean): StructType = {
- add(name, CatalystSqlParser.parseDataType(dataType), nullable, Metadata.empty)
+ add(name, DataTypeParser.parseDataType(dataType), nullable, Metadata.empty)
}
/**
@@ -252,7 +252,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
dataType: String,
nullable: Boolean,
metadata: Metadata): StructType = {
- add(name, CatalystSqlParser.parseDataType(dataType), nullable, metadata)
+ add(name, DataTypeParser.parseDataType(dataType), nullable, metadata)
}
/**
@@ -270,7 +270,7 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
dataType: String,
nullable: Boolean,
comment: String): StructType = {
- add(name, CatalystSqlParser.parseDataType(dataType), nullable, comment)
+ add(name, DataTypeParser.parseDataType(dataType), nullable, comment)
}
/**
@@ -534,7 +534,7 @@ object StructType extends AbstractDataType {
*
* @since 2.2.0
*/
- def fromDDL(ddl: String): StructType = CatalystSqlParser.parseTableSchema(ddl)
+ def fromDDL(ddl: String): StructType = DataTypeParser.parseTableSchema(ddl)
def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray)
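The add overloads above keep working unchanged; a hedged example:

    import org.apache.spark.sql.types.StructType

    val schema = new StructType()
      .add("a", "int")  // the string form is now parsed by DataTypeParser
      .add("b", "string", nullable = false)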
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
index 562502ade43..f11e920e4c0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.types._
class DataTypeParserSuite extends SparkFunSuite with SQLHelper {
- def parse(sql: String): DataType = CatalystSqlParser.parseDataType(sql)
+ def parse(sql: String): DataType = DataTypeParser.parseDataType(sql)
def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = {
test(s"parse ${dataTypeString.replace("\n", "")}") {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]