chenhao-db commented on code in PR #45806:
URL: https://github.com/apache/spark/pull/45806#discussion_r1555269229
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -403,3 +405,134 @@ object VariantGetExpressionBuilder extends VariantGetExpressionBuilderBase(true)
)
// scalastyle:on line.size.limit
object TryVariantGetExpressionBuilder extends VariantGetExpressionBuilderBase(false)
+
+@ExpressionDescription(
+ usage = "_FUNC_(v) - Returns schema in the SQL format of a variant.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(parse_json('null'));
+ VOID
+ > SELECT _FUNC_(parse_json('[{"b":true,"a":0}]'));
+ ARRAY<STRUCT<a: BIGINT, b: BOOLEAN>>
+ """,
+ since = "4.0.0",
+ group = "variant_funcs"
+)
+case class SchemaOfVariant(child: Expression)
+ extends UnaryExpression
+ with RuntimeReplaceable
+ with ExpectsInputTypes {
+ override lazy val replacement: Expression = StaticInvoke(
+ SchemaOfVariant.getClass,
+ StringType,
+ "schemaOfVariant",
+ Seq(child),
+ inputTypes,
+ returnNullable = false)
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(VariantType)
+
+ override def dataType: DataType = StringType
+
+ override def prettyName: String = "schema_of_variant"
+
+ override protected def withNewChildInternal(newChild: Expression): SchemaOfVariant =
+ copy(child = newChild)
+}
+
+object SchemaOfVariant {
+ /** The actual implementation of the `SchemaOfVariant` expression. */
+ def schemaOfVariant(input: VariantVal): UTF8String = {
+ val v = new Variant(input.getValue, input.getMetadata)
+ UTF8String.fromString(schemaOf(v).sql)
+ }
+
+ /**
+ * Return the schema of a variant. Struct fields are guaranteed to be sorted alphabetically.
+ */
+ def schemaOf(v: Variant): DataType = v.getType match {
+ case Type.OBJECT =>
+ val size = v.objectSize()
+ val fields = new Array[StructField](size)
+ for (i <- 0 until size) {
+ val field = v.getFieldAtIndex(i)
+ fields(i) = StructField(field.key, schemaOf(field.value))
+ }
+ // According to the variant spec, object fields must be sorted alphabetically. So we don't
+ // have to sort, but just need to validate they are sorted.
+ for (i <- 1 until size) {
+ if (fields(i - 1).name >= fields(i).name) {
+ throw new SparkRuntimeException("MALFORMED_VARIANT", Map.empty)
+ }
+ }
+ StructType(fields)
+ case Type.ARRAY =>
+ var elementType: DataType = NullType
+ for (i <- 0 until v.arraySize()) {
+ elementType = mergeSchema(elementType, schemaOf(v.getElementAtIndex(i)))
+ }
+ ArrayType(elementType)
+ case Type.NULL => NullType
+ case Type.BOOLEAN => BooleanType
+ case Type.LONG => LongType
+ case Type.STRING => StringType
+ case Type.DOUBLE => DoubleType
+ case Type.DECIMAL =>
+ val d = v.getDecimal
+ DecimalType(d.precision(), d.scale())
+ }
+
+ /**
+ * Returns the tightest common type for two given data types. Input struct fields are assumed to
+ * be sorted alphabetically.
+ */
+ def mergeSchema(t1: DataType, t2: DataType): DataType = (t1, t2) match {
+ case (t1, t2) if t1 == t2 => t1
+ case (t1, NullType) => t1
+ case (NullType, t2) => t2
+ case (DoubleType, _: NumericType) | (_: NumericType, DoubleType) => DoubleType
+ case (t1: IntegralType, t2: DecimalType) => mergeSchema(DecimalType.forType(t1), t2)
+ case (t1: DecimalType, t2: IntegralType) => mergeSchema(t1, DecimalType.forType(t2))
+ case (t1: DecimalType, t2: DecimalType) =>
+ val scale = math.max(t1.scale, t2.scale)
+ val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
+ if (range + scale > DecimalType.MAX_PRECISION) {
Review Comment:
I think it will be more intuitive if the variant can always be successfully cast into the inferred schema, even at the cost of some precision loss. If we returned a truncated decimal here, that cast would deterministically fail.
Actually, this `mergeSchema` function, including the fallback-to-double logic, is largely adapted from the existing JSON schema inference code:
https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JsonInferSchema.scala#L364.
Since it turned out to be straightforward to reuse that code directly, I changed the implementation to depend on it instead.
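For illustration, here is a minimal standalone sketch of the decimal-merge rule being discussed, assuming it mirrors the `compatibleType` logic in `JsonInferSchema` linked above (the `mergeDecimals` helper is a hypothetical name used only for this example, not part of the PR):

```scala
import org.apache.spark.sql.types._

// Fallback-to-double rule: if the merged decimal would need a precision above
// DecimalType.MAX_PRECISION (38), infer DoubleType so that casting the variant
// into the inferred schema still succeeds (with precision loss) instead of
// deterministically failing on a truncated decimal.
def mergeDecimals(t1: DecimalType, t2: DecimalType): DataType = {
  val scale = math.max(t1.scale, t2.scale)
  val range = math.max(t1.precision - t1.scale, t2.precision - t2.scale)
  if (range + scale > DecimalType.MAX_PRECISION) DoubleType
  else DecimalType(range + scale, scale)
}

// DECIMAL(38, 0) merged with DECIMAL(20, 19) would need precision 57 > 38,
// so the merged type falls back to DOUBLE.
println(mergeDecimals(DecimalType(38, 0), DecimalType(20, 19)))  // DoubleType
```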