Repository: spark
Updated Branches:
  refs/heads/master 5585c5765 -> 776f299fc
[SPARK-24709][SQL] schema_of_json() - schema inference from an example

## What changes were proposed in this pull request?

In the PR, I propose to add a new function - *schema_of_json()* - which infers the schema of a JSON string literal. The result of the function is a string containing a schema in DDL format.

One of the use cases is using *schema_of_json()* in combination with *from_json()*. Currently, _from_json()_ requires a schema as a mandatory argument. The *schema_of_json()* function allows pointing to a JSON string as an example that has the same schema as the first argument of _from_json()_. For instance:

```sql
select from_json(json_column, schema_of_json('{"c1": [0], "c2": [{"c3":0}]}'))
from json_table;
```

## How was this patch tested?

Added new tests to `JsonFunctionsSuite` and `JsonExpressionsSuite`, and SQL tests to `json-functions.sql`.

Author: Maxim Gekk <maxim.g...@databricks.com>

Closes #21686 from MaxGekk/infer_schema_json.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/776f299f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/776f299f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/776f299f

Branch: refs/heads/master
Commit: 776f299fc8146b400e97185b1577b0fc8f06e14b
Parents: 5585c57
Author: Maxim Gekk <maxim.g...@databricks.com>
Authored: Wed Jul 4 09:38:18 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Wed Jul 4 09:38:18 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  27 ++
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../catalyst/expressions/jsonExpressions.scala  |  52 ++-
 .../sql/catalyst/json/JsonInferSchema.scala     | 348 ++++++++++++++++++
 .../expressions/JsonExpressionsSuite.scala      |   7 +
 .../datasources/json/JsonDataSource.scala       |   2 +-
 .../datasources/json/JsonInferSchema.scala      | 349 -------------------
 .../scala/org/apache/spark/sql/functions.scala  |  42 +++
 .../sql-tests/inputs/json-functions.sql         |   4 +
 .../sql-tests/results/json-functions.sql.out    |  20 +-
 .../apache/spark/sql/JsonFunctionsSuite.scala   |  17 +-
 .../execution/datasources/json/JsonSuite.scala  |   4 +-
 12 files changed, 509 insertions(+), 364 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 9652d3e..4d37197 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2189,11 +2189,16 @@ def from_json(col, schema, options={}):
     >>> df = spark.createDataFrame(data, ("key", "value"))
     >>> df.select(from_json(df.value, schema).alias("json")).collect()
     [Row(json=[Row(a=1)])]
+    >>> schema = schema_of_json(lit('''{"a": 0}'''))
+    >>> df.select(from_json(df.value, schema).alias("json")).collect()
+    [Row(json=Row(a=1))]
     """
 
     sc = SparkContext._active_spark_context
     if isinstance(schema, DataType):
         schema = schema.json()
+    elif isinstance(schema, Column):
+        schema = _to_java_column(schema)
     jc = sc._jvm.functions.from_json(_to_java_column(col), schema, options)
     return Column(jc)
 
@@ -2235,6 +2240,28 @@ def to_json(col, options={}):
     return Column(jc)
 
 
+@ignore_unicode_prefix
+@since(2.4)
+def schema_of_json(col):
+    """
+    Parses a column containing a JSON string and infers its schema in DDL format.
+
+    :param col: string column in json format
+
+    >>> from pyspark.sql.types import *
+    >>> data = [(1, '{"a": 1}')]
+    >>> df = spark.createDataFrame(data, ("key", "value"))
+    >>> df.select(schema_of_json(df.value).alias("json")).collect()
+    [Row(json=u'struct<a:bigint>')]
+    >>> df.select(schema_of_json(lit('{"a": 0}')).alias("json")).collect()
+    [Row(json=u'struct<a:bigint>')]
+    """
+
+    sc = SparkContext._active_spark_context
+    jc = sc._jvm.functions.schema_of_json(_to_java_column(col))
+    return Column(jc)
+
+
 @since(1.5)
 def size(col):
     """
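For reference, the same round trip from the Scala side, as a minimal sketch (not part of the patch; assumes a spark-shell session, so `spark` and `spark.implicits._` are available):

```scala
import org.apache.spark.sql.functions.{from_json, lit, schema_of_json}
import spark.implicits._

val df = Seq("""{"a": 1}""").toDF("value")

// schema_of_json returns the inferred schema as a DDL-formatted string,
// e.g. struct<a:bigint> for the literal below.
df.select(schema_of_json(lit("""{"a": 0}"""))).show(false)

// That column can be passed straight to from_json as the schema argument.
df.select(from_json($"value", schema_of_json(lit("""{"a": 0}""")))).show(false)
```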
http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a574d8a..80a0af6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -505,6 +505,7 @@ object FunctionRegistry {
     // json
     expression[StructsToJson]("to_json"),
     expression[JsonToStructs]("from_json"),
+    expression[SchemaOfJson]("schema_of_json"),
 
     // cast
     expression[Cast]("cast"),

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
index f6d74f5..8cd8605 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter}
+import java.io._
 
 import scala.util.parsing.combinator.RegexParsers
 
@@ -28,7 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.json._
-import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, BadRecordException, FailFastMode, GenericArrayData, MapData}
+import org.apache.spark.sql.catalyst.json.JsonInferSchema.inferField
+import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -525,17 +526,19 @@ case class JsonToStructs(
   override def nullable: Boolean = true
 
   // Used in `FunctionRegistry`
-  def this(child: Expression, schema: Expression) =
+  def this(child: Expression, schema: Expression, options: Map[String, String]) =
     this(
-      schema = JsonExprUtils.validateSchemaLiteral(schema),
-      options = Map.empty[String, String],
+      schema = JsonExprUtils.evalSchemaExpr(schema),
+      options = options,
       child = child,
       timeZoneId = None,
       forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA))
 
+  def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String])
+
   def this(child: Expression, schema: Expression, options: Expression) =
     this(
-      schema = JsonExprUtils.validateSchemaLiteral(schema),
+      schema = JsonExprUtils.evalSchemaExpr(schema),
       options = JsonExprUtils.convertToMapData(options),
       child = child,
       timeZoneId = None,
@@ -744,11 +747,44 @@ case class StructsToJson(
   override def inputTypes: Seq[AbstractDataType] = TypeCollection(ArrayType, StructType) :: Nil
 }
 
+/**
+ * A function infers schema of JSON string.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(json[, options]) - Returns schema in the DDL format of JSON string.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('[{"col":0}]');
+       array<struct<col:int>>
+  """,
+  since = "2.4.0")
+case class SchemaOfJson(child: Expression)
+  extends UnaryExpression with String2StringExpression with CodegenFallback {
+
+  private val jsonOptions = new JSONOptions(Map.empty, "UTC")
+  private val jsonFactory = new JsonFactory()
+  jsonOptions.setJacksonOptions(jsonFactory)
+
+  override def convert(v: UTF8String): UTF8String = {
+    val dt = Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, v)) { parser =>
+      parser.nextToken()
+      inferField(parser, jsonOptions)
+    }
+
+    UTF8String.fromString(dt.catalogString)
+  }
+}
+
 object JsonExprUtils {
 
-  def validateSchemaLiteral(exp: Expression): DataType = exp match {
+  def evalSchemaExpr(exp: Expression): DataType = exp match {
     case Literal(s, StringType) => DataType.fromDDL(s.toString)
-    case e => throw new AnalysisException(s"Expected a string literal instead of $e")
+    case e @ SchemaOfJson(_: Literal) =>
+      val ddlSchema = e.eval().asInstanceOf[UTF8String]
+      DataType.fromDDL(ddlSchema.toString)
+    case e => throw new AnalysisException(
+      "Schema should be specified in DDL format as a string literal" +
+      s" or output of the schema_of_json function instead of ${e.sql}")
   }
 
   def convertToMapData(exp: Expression): Map[String, String] = exp match {
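To make the analyzer path above concrete: `evalSchemaExpr` accepts either a DDL string literal or a foldable `SchemaOfJson` over a literal, which it evaluates at analysis time. A hedged sketch of the two shapes (not part of the patch; uses catalyst internals, which are not stable API):

```scala
import org.apache.spark.sql.catalyst.expressions.{Literal, SchemaOfJson}
import org.apache.spark.sql.types.DataType

// Shape 1: a DDL string literal resolves via DataType.fromDDL.
val fromDdl: DataType = DataType.fromDDL("struct<a:int>")

// Shape 2: SchemaOfJson over a literal is evaluated eagerly and its
// DDL output parsed the same way.
val inferred = SchemaOfJson(Literal.create("""{"a": 0}"""))
val ddl = inferred.eval().toString          // "struct<a:bigint>"
val fromInference: DataType = DataType.fromDDL(ddl)
```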
http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
new file mode 100644
index 0000000..491ca00
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala
@@ -0,0 +1,348 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.json
+
+import java.util.Comparator
+
+import com.fasterxml.jackson.core._
+
+import org.apache.spark.SparkException
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.TypeCoercion
+import org.apache.spark.sql.catalyst.json.JacksonUtils.nextUntil
+import org.apache.spark.sql.catalyst.util.{DropMalformedMode, FailFastMode, ParseMode, PermissiveMode}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
+
+private[sql] object JsonInferSchema {
+
+  /**
+   * Infer the type of a collection of json records in three stages:
+   *   1. Infer the type of each record
+   *   2. Merge types by choosing the lowest type necessary to cover equal keys
+   *   3. Replace any remaining null fields with string, the top type
+   */
+  def infer[T](
+      json: RDD[T],
+      configOptions: JSONOptions,
+      createParser: (JsonFactory, T) => JsonParser): StructType = {
+    val parseMode = configOptions.parseMode
+    val columnNameOfCorruptRecord = configOptions.columnNameOfCorruptRecord
+
+    // In each RDD partition, perform schema inference on each row and merge afterwards.
+    val typeMerger = compatibleRootType(columnNameOfCorruptRecord, parseMode)
+    val mergedTypesFromPartitions = json.mapPartitions { iter =>
+      val factory = new JsonFactory()
+      configOptions.setJacksonOptions(factory)
+      iter.flatMap { row =>
+        try {
+          Utils.tryWithResource(createParser(factory, row)) { parser =>
+            parser.nextToken()
+            Some(inferField(parser, configOptions))
+          }
+        } catch {
+          case e @ (_: RuntimeException | _: JsonProcessingException) => parseMode match {
+            case PermissiveMode =>
+              Some(StructType(Seq(StructField(columnNameOfCorruptRecord, StringType))))
+            case DropMalformedMode =>
+              None
+            case FailFastMode =>
+              throw new SparkException("Malformed records are detected in schema inference. " +
+                s"Parse Mode: ${FailFastMode.name}.", e)
+          }
+        }
+      }.reduceOption(typeMerger).toIterator
+    }
+
+    // Here we get RDD local iterator then fold, instead of calling `RDD.fold` directly, because
+    // `RDD.fold` will run the fold function in DAGScheduler event loop thread, which may not have
+    // active SparkSession and `SQLConf.get` may point to the wrong configs.
+    val rootType = mergedTypesFromPartitions.toLocalIterator.fold(StructType(Nil))(typeMerger)
+
+    canonicalizeType(rootType, configOptions) match {
+      case Some(st: StructType) => st
+      case _ =>
+        // canonicalizeType erases all empty structs, including the only one we want to keep
+        StructType(Nil)
+    }
+  }
+
+  private[this] val structFieldComparator = new Comparator[StructField] {
+    override def compare(o1: StructField, o2: StructField): Int = {
+      o1.name.compareTo(o2.name)
+    }
+  }
+
+  private def isSorted(arr: Array[StructField]): Boolean = {
+    var i: Int = 0
+    while (i < arr.length - 1) {
+      if (structFieldComparator.compare(arr(i), arr(i + 1)) > 0) {
+        return false
+      }
+      i += 1
+    }
+    true
+  }
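The `infer` method above is what the JSON reader path drives during schema inference. A quick public-API sketch of the per-record inference plus cross-record merging it performs (assumes a spark-shell session):

```scala
import spark.implicits._

// Two records with overlapping but unequal key sets.
val ds = Seq("""{"a": 1}""", """{"a": 2, "b": "x"}""").toDS()

// The reader infers each record's type, then merges with compatibleRootType.
spark.read.json(ds).printSchema()
// root
//  |-- a: long (nullable = true)
//  |-- b: string (nullable = true)
```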
+
+  /**
+   * Infer the type of a json document from the parser's token stream
+   */
+  def inferField(parser: JsonParser, configOptions: JSONOptions): DataType = {
+    import com.fasterxml.jackson.core.JsonToken._
+    parser.getCurrentToken match {
+      case null | VALUE_NULL => NullType
+
+      case FIELD_NAME =>
+        parser.nextToken()
+        inferField(parser, configOptions)
+
+      case VALUE_STRING if parser.getTextLength < 1 =>
+        // Zero length strings and nulls have special handling to deal
+        // with JSON generators that do not distinguish between the two.
+        // To accurately infer types for empty strings that are really
+        // meant to represent nulls we assume that the two are isomorphic
+        // but will defer treating null fields as strings until all the
+        // record fields' types have been combined.
+        NullType
+
+      case VALUE_STRING => StringType
+      case START_OBJECT =>
+        val builder = Array.newBuilder[StructField]
+        while (nextUntil(parser, END_OBJECT)) {
+          builder += StructField(
+            parser.getCurrentName,
+            inferField(parser, configOptions),
+            nullable = true)
+        }
+        val fields: Array[StructField] = builder.result()
+        // Note: other code relies on this sorting for correctness, so don't remove it!
+        java.util.Arrays.sort(fields, structFieldComparator)
+        StructType(fields)
+
+      case START_ARRAY =>
+        // If this JSON array is empty, we use NullType as a placeholder.
+        // If this array is not empty in other JSON objects, we can resolve
+        // the type as we pass through all JSON objects.
+        var elementType: DataType = NullType
+        while (nextUntil(parser, END_ARRAY)) {
+          elementType = compatibleType(
+            elementType, inferField(parser, configOptions))
+        }
+
+        ArrayType(elementType)
+
+      case (VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT) if configOptions.primitivesAsString => StringType
+
+      case (VALUE_TRUE | VALUE_FALSE) if configOptions.primitivesAsString => StringType
+
+      case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
+        import JsonParser.NumberType._
+        parser.getNumberType match {
+          // For Integer values, use LongType by default.
+          case INT | LONG => LongType
+          // Since we do not have a data type backed by BigInteger,
+          // when we see a Java BigInteger, we use DecimalType.
+          case BIG_INTEGER | BIG_DECIMAL =>
+            val v = parser.getDecimalValue
+            if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
+              DecimalType(Math.max(v.precision(), v.scale()), v.scale())
+            } else {
+              DoubleType
+            }
+          case FLOAT | DOUBLE if configOptions.prefersDecimal =>
+            val v = parser.getDecimalValue
+            if (Math.max(v.precision(), v.scale()) <= DecimalType.MAX_PRECISION) {
+              DecimalType(Math.max(v.precision(), v.scale()), v.scale())
+            } else {
+              DoubleType
+            }
+          case FLOAT | DOUBLE =>
+            DoubleType
+        }
+
+      case VALUE_TRUE | VALUE_FALSE => BooleanType
+    }
+  }
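The number-typing rules above are easiest to see through the new public function. A small sketch (not part of the patch; assumes a spark-shell session) of what the inference returns for each numeric shape:

```scala
import org.apache.spark.sql.functions.{lit, schema_of_json}

spark.range(1).select(
  // JSON integers map to LongType (bigint in DDL).
  schema_of_json(lit("""{"n": 1}""")) as "ints",
  // Plain floats map to DoubleType unless prefersDecimal is set.
  schema_of_json(lit("""{"n": 1.5}""")) as "floats",
  // A 30-digit BigInteger fits in 38 digits, so it becomes decimal(30,0).
  schema_of_json(lit("""{"n": 123456789012345678901234567890}""")) as "bigints"
).show(false)
// struct<n:bigint> | struct<n:double> | struct<n:decimal(30,0)>
```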
+
+  /**
+   * Recursively canonicalizes inferred types, e.g., removes StructTypes with no fields,
+   * drops NullTypes or converts them to StringType based on provided options.
+   */
+  private def canonicalizeType(tpe: DataType, options: JSONOptions): Option[DataType] = tpe match {
+    case at: ArrayType =>
+      canonicalizeType(at.elementType, options)
+        .map(t => at.copy(elementType = t))
+
+    case StructType(fields) =>
+      val canonicalFields = fields.filter(_.name.nonEmpty).flatMap { f =>
+        canonicalizeType(f.dataType, options)
+          .map(t => f.copy(dataType = t))
+      }
+      // SPARK-8093: empty structs should be deleted
+      if (canonicalFields.isEmpty) {
+        None
+      } else {
+        Some(StructType(canonicalFields))
+      }
+
+    case NullType =>
+      if (options.dropFieldIfAllNull) {
+        None
+      } else {
+        Some(StringType)
+      }
+
+    case other => Some(other)
+  }
+
+  private def withCorruptField(
+      struct: StructType,
+      other: DataType,
+      columnNameOfCorruptRecords: String,
+      parseMode: ParseMode) = parseMode match {
+    case PermissiveMode =>
+      // If we see any other data type at the root level, we get records that cannot be
+      // parsed. So, we use the struct as the data type and add the corrupt field to the schema.
+      if (!struct.fieldNames.contains(columnNameOfCorruptRecords)) {
+        // If this given struct does not have a column used for corrupt records,
+        // add this field.
+        val newFields: Array[StructField] =
+          StructField(columnNameOfCorruptRecords, StringType, nullable = true) +: struct.fields
+        // Note: other code relies on this sorting for correctness, so don't remove it!
+        java.util.Arrays.sort(newFields, structFieldComparator)
+        StructType(newFields)
+      } else {
+        // Otherwise, just return this struct.
+        struct
+      }
+
+    case DropMalformedMode =>
+      // If corrupt record handling is disabled we retain the valid schema and discard the other.
+      struct
+
+    case FailFastMode =>
+      // If `other` is not struct type, consider it as malformed one and throws an exception.
+      throw new SparkException("Malformed records are detected in schema inference. " +
+        s"Parse Mode: ${FailFastMode.name}. Reasons: Failed to infer a common schema. " +
+        s"Struct types are expected, but `${other.catalogString}` was found.")
+  }
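The `withCorruptField` behavior surfaces through the reader's parse modes. A sketch of the PERMISSIVE case (not part of the patch; assumes a spark-shell session; option names are the standard JSON datasource options):

```scala
import spark.implicits._

val mixed = Seq("""{"a": 1}""", """not json""").toDS()

// In PERMISSIVE mode (the default), the malformed record contributes a
// corrupt-record column, which withCorruptField splices into the inferred
// schema in sorted position.
spark.read
  .option("mode", "PERMISSIVE")
  .option("columnNameOfCorruptRecord", "_corrupt_record")
  .json(mixed)
  .printSchema()
// root
//  |-- _corrupt_record: string (nullable = true)
//  |-- a: long (nullable = true)
```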
+
+  /**
+   * Remove top-level ArrayType wrappers and merge the remaining schemas
+   */
+  private def compatibleRootType(
+      columnNameOfCorruptRecords: String,
+      parseMode: ParseMode): (DataType, DataType) => DataType = {
+    // Since we support array of json objects at the top level,
+    // we need to check the element type and find the root level data type.
+    case (ArrayType(ty1, _), ty2) =>
+      compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2)
+    case (ty1, ArrayType(ty2, _)) =>
+      compatibleRootType(columnNameOfCorruptRecords, parseMode)(ty1, ty2)
+    // Discard null/empty documents
+    case (struct: StructType, NullType) => struct
+    case (NullType, struct: StructType) => struct
+    case (struct: StructType, o) if !o.isInstanceOf[StructType] =>
+      withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode)
+    case (o, struct: StructType) if !o.isInstanceOf[StructType] =>
+      withCorruptField(struct, o, columnNameOfCorruptRecords, parseMode)
+    // If we get anything else, we call compatibleType.
+    // Usually, when we reach here, ty1 and ty2 are two StructTypes.
+    case (ty1, ty2) => compatibleType(ty1, ty2)
+  }
+
+  private[this] val emptyStructFieldArray = Array.empty[StructField]
+
+  /**
+   * Returns the most general data type for two given data types.
+   */
+  def compatibleType(t1: DataType, t2: DataType): DataType = {
+    TypeCoercion.findTightestCommonType(t1, t2).getOrElse {
+      // t1 or t2 is a StructType, ArrayType, or an unexpected type.
+      (t1, t2) match {
+        // Double support larger range than fixed decimal, DecimalType.Maximum should be enough
+        // in most case, also have better precision.
+        case (DoubleType, _: DecimalType) | (_: DecimalType, DoubleType) =>
+          DoubleType
+
+        case (t1: DecimalType, t2: DecimalType) =>
+          val scale = math.max(t1.scale, t2.scale)
+          val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+          if (range + scale > 38) {
+            // DecimalType can't support precision > 38
+            DoubleType
+          } else {
+            DecimalType(range + scale, scale)
+          }
+
+        case (StructType(fields1), StructType(fields2)) =>
+          // Both fields1 and fields2 should be sorted by name, since inferField performs sorting.
+          // Therefore, we can take advantage of the fact that we're merging sorted lists and skip
+          // building a hash map or performing additional sorting.
+          assert(isSorted(fields1), s"StructType's fields were not sorted: ${fields1.toSeq}")
+          assert(isSorted(fields2), s"StructType's fields were not sorted: ${fields2.toSeq}")
+
+          val newFields = new java.util.ArrayList[StructField]()
+
+          var f1Idx = 0
+          var f2Idx = 0
+
+          while (f1Idx < fields1.length && f2Idx < fields2.length) {
+            val f1Name = fields1(f1Idx).name
+            val f2Name = fields2(f2Idx).name
+            val comp = f1Name.compareTo(f2Name)
+            if (comp == 0) {
+              val dataType = compatibleType(fields1(f1Idx).dataType, fields2(f2Idx).dataType)
+              newFields.add(StructField(f1Name, dataType, nullable = true))
+              f1Idx += 1
+              f2Idx += 1
+            } else if (comp < 0) { // f1Name < f2Name
+              newFields.add(fields1(f1Idx))
+              f1Idx += 1
+            } else { // f1Name > f2Name
+              newFields.add(fields2(f2Idx))
+              f2Idx += 1
+            }
+          }
+          while (f1Idx < fields1.length) {
+            newFields.add(fields1(f1Idx))
+            f1Idx += 1
+          }
+          while (f2Idx < fields2.length) {
+            newFields.add(fields2(f2Idx))
+            f2Idx += 1
+          }
+          StructType(newFields.toArray(emptyStructFieldArray))
+
+        case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
+          ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)
+
+        // The case that given `DecimalType` is capable of given `IntegralType` is handled in
+        // `findTightestCommonType`. Both cases below will be executed only when the given
+        // `DecimalType` is not capable of the given `IntegralType`.
+        case (t1: IntegralType, t2: DecimalType) =>
+          compatibleType(DecimalType.forType(t1), t2)
+        case (t1: DecimalType, t2: IntegralType) =>
+          compatibleType(t1, DecimalType.forType(t2))
+
+        // strings and every string is a Json object.
+        case (_, _) => StringType
+      }
+    }
+  }
+}
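A sketch of how the merge behaves (not part of the patch; note that `JsonInferSchema` is `private[sql]`, so direct calls like these only compile from code inside Spark's own `org.apache.spark.sql` packages, as in the suites touched by this commit):

```scala
import org.apache.spark.sql.catalyst.json.JsonInferSchema.compatibleType
import org.apache.spark.sql.types._

// Equal field names widen to the most general type: long + double => double.
compatibleType(
  StructType(Seq(StructField("a", LongType))),
  StructType(Seq(StructField("a", DoubleType))))
// struct<a:double>

// Disjoint field names union, preserving the sorted order the asserts rely on.
compatibleType(
  StructType(Seq(StructField("a", LongType))),
  StructType(Seq(StructField("b", StringType))))
// struct<a:bigint,b:string>

// Irreconcilable primitive types fall back to StringType, the top type.
compatibleType(BooleanType, LongType)  // string
```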
http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
index 00e9763..52203b9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala
@@ -706,4 +706,11 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
       assert(schemaToCompare == schema)
     }
   }
+
+  test("SPARK-24709: infer schema of json strings") {
+    checkEvaluation(SchemaOfJson(Literal.create("""{"col":0}""")), "struct<col:bigint>")
+    checkEvaluation(
+      SchemaOfJson(Literal.create("""{"col0":["a"], "col1": {"col2": "b"}}""")),
+      "struct<col0:array<string>,col1:struct<col2:string>>")
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
index 3b6df45..2fee212 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonDataSource.scala
@@ -33,7 +33,7 @@ import org.apache.spark.input.{PortableDataStream, StreamInputFormat}
 import org.apache.spark.rdd.{BinaryFileRDD, RDD}
 import org.apache.spark.sql.{Dataset, Encoders, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions}
 import org.apache.spark.sql.execution.SQLExecution
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.text.TextFileFormat
http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
deleted file mode 100644
index 8e1b430..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonInferSchema.scala
+++ /dev/null
@@ -1,349 +0,0 @@
[349 deleted lines elided: the old sql/core copy of JsonInferSchema.scala, identical to the file added above under sql/catalyst apart from the package name, the sql/core-specific imports, and inferField being private.]
http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index acca957..614f65f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3382,6 +3382,48 @@ object functions {
   }
 
   /**
+   * (Scala-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
+   * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema.
+   * Returns `null`, in the case of an unparseable string.
+   *
+   * @param e a string column containing JSON data.
+   * @param schema the schema to use when parsing the json string
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def from_json(e: Column, schema: Column): Column = {
+    from_json(e, schema, Map.empty[String, String].asJava)
+  }
+
+  /**
+   * (Java-specific) Parses a column containing a JSON string into a `MapType` with `StringType`
+   * as keys type, `StructType` or `ArrayType` of `StructType`s with the specified schema.
+   * Returns `null`, in the case of an unparseable string.
+   *
+   * @param e a string column containing JSON data.
+   * @param schema the schema to use when parsing the json string
+   * @param options options to control how the json is parsed. accepts the same options and the
+   *                json data source.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def from_json(e: Column, schema: Column, options: java.util.Map[String, String]): Column = {
+    withExpr(new JsonToStructs(e.expr, schema.expr, options.asScala.toMap))
+  }
+
+  /**
+   * Parses a column containing a JSON string and infers its schema.
+   *
+   * @param e a string column containing JSON data.
+   *
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def schema_of_json(e: Column): Column = withExpr(new SchemaOfJson(e.expr))
+
+  /**
    * (Scala-specific) Converts a column containing a `StructType`, `ArrayType` of `StructType`s,
    * a `MapType` or `ArrayType` of `MapType`s into a JSON string with the specified schema.
    * Throws an exception, in the case of an unsupported type.
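A small sketch of the new options-taking overload (not part of the patch; assumes a DataFrame `df` with a string column `value` and `spark.implicits._` imported; `allowSingleQuotes` is a standard JSON datasource option):

```scala
import java.util.Collections
import org.apache.spark.sql.functions.{from_json, lit, schema_of_json}

// Pass JSON datasource options alongside an inferred schema; the schema
// column must be foldable, such as schema_of_json over a literal.
val parsed = df.select(
  from_json(
    $"value",
    schema_of_json(lit("""{"a": 0}""")),
    Collections.singletonMap("allowSingleQuotes", "true")))
```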
http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
index dc15d13..79fdd58 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql
@@ -35,3 +35,7 @@ DROP VIEW IF EXISTS jsonTable;
 -- from_json - complex types
 select from_json('{"a":1, "b":2}', 'map<string, int>');
 select from_json('{"a":1, "b":"2"}', 'struct<a:int,b:string>');
+
+-- infer schema of json literal
+select schema_of_json('{"c1":0, "c2":[1]}');
+select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}'));

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
index 2b3288d..3d49323 100644
--- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 28
+-- Number of queries: 30
 
 
 -- !query 0
@@ -183,7 +183,7 @@ select from_json('{"a":1}', 1)
 struct<>
 -- !query 17 output
 org.apache.spark.sql.AnalysisException
-Expected a string literal instead of 1;; line 1 pos 7
+Schema should be specified in DDL format as a string literal or output of the schema_of_json function instead of 1;; line 1 pos 7
 
 
 -- !query 18
@@ -274,3 +274,19 @@ select from_json('{"a":1, "b":"2"}', 'struct<a:int,b:string>')
 struct<jsontostructs({"a":1, "b":"2"}):struct<a:int,b:string>>
 -- !query 27 output
 {"a":1,"b":"2"}
+
+
+-- !query 28
+select schema_of_json('{"c1":0, "c2":[1]}')
+-- !query 28 schema
+struct<schemaofjson({"c1":0, "c2":[1]}):string>
+-- !query 28 output
+struct<c1:bigint,c2:array<bigint>>
+
+
+-- !query 29
+select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}'))
+-- !query 29 schema
+struct<jsontostructs({"c1":[1, 2, 3]}):struct<c1:array<bigint>>>
+-- !query 29 output
+{"c1":[1,2,3]}
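The two new SQL test cases can be reproduced interactively; a sketch (assumes a spark-shell session, and the struct result renders differently in show() than in the golden-file output above):

```scala
spark.sql("""select schema_of_json('{"c1":0, "c2":[1]}')""").show(false)
// struct<c1:bigint,c2:array<bigint>>

// Returns a struct<c1:array<bigint>> value parsed with the inferred schema.
spark.sql("""select from_json('{"c1":[1, 2, 3]}', schema_of_json('{"c1":[0]}'))""").show(false)
```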
http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 7bf17cb..d3b2701 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.functions.{from_json, lit, map, struct, to_json}
+import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._
 
@@ -311,7 +311,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
     val errMsg1 = intercept[AnalysisException] {
       df3.selectExpr("from_json(value, 1)")
     }
-    assert(errMsg1.getMessage.startsWith("Expected a string literal instead of"))
+    assert(errMsg1.getMessage.startsWith("Schema should be specified in DDL format as a string"))
     val errMsg2 = intercept[AnalysisException] {
       df3.selectExpr("""from_json(value, 'time InvalidType')""")
     }
@@ -392,4 +392,17 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
     checkAnswer(Seq("""{"{"f": 1}": "a"}""").toDS().select(from_json($"value", schema)), Row(null))
   }
+
+  test("SPARK-24709: infers schemas of json strings and pass them to from_json") {
+    val in = Seq("""{"a": [1, 2, 3]}""").toDS()
+    val out = in.select(from_json('value, schema_of_json(lit("""{"a": [1]}"""))) as "parsed")
+    val expected = StructType(StructField(
+      "parsed",
+      StructType(StructField(
+        "a",
+        ArrayType(LongType, true), true) :: Nil),
+      true) :: Nil)
+
+    assert(out.schema == expected)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/776f299f/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 897424d..eab15b3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -31,11 +31,11 @@ import org.apache.hadoop.io.compress.GzipCodec
 import org.apache.spark.{SparkException, TestUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{functions => F, _}
-import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions}
+import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JsonInferSchema, JSONOptions}
+import org.apache.spark.sql.catalyst.json.JsonInferSchema.compatibleType
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.ExternalRDD
 import org.apache.spark.sql.execution.datasources.DataSource
-import org.apache.spark.sql.execution.datasources.json.JsonInferSchema.compatibleType
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.types._

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org