chenhao-db commented on code in PR #45708:
URL: https://github.com/apache/spark/pull/45708#discussion_r1541767305
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala:
##########
@@ -63,3 +70,300 @@ case class ParseJson(child: Expression) extends
UnaryExpression
override protected def withNewChildInternal(newChild: Expression): ParseJson
=
copy(child = newChild)
}
+
+// A path segment in the `VariantGet` expression. It represents either an
object key access (when
+// `key` is not null) or an array index access (when `key` is null).
+case class PathSegment(key: String, index: Int)
+
+object VariantPathParser extends RegexParsers {
+ private def root: Parser[Char] = '$'
+
+ // Parse index segment like `[123]`.
+ private def index: Parser[PathSegment] =
+ for {
+ index <- '[' ~> "\\d+".r <~ ']'
+ } yield {
+ PathSegment(null, index.toInt)
+ }
+
+ // Parse key segment like `.name`, `['name']`, or `["name"]`.
+ private def key: Parser[PathSegment] =
+ for {
+ key <- '.' ~> "[^\\.\\[]+".r | "['" ~> "[^\\'\\?]+".r <~ "']" |
+ "[\"" ~> "[^\\\"\\?]+".r <~ "\"]"
+ } yield {
+ PathSegment(key, 0)
+ }
+
+ private val parser: Parser[List[PathSegment]] = phrase(root ~> rep(key |
index))
+
+ def parse(str: String): Option[Array[PathSegment]] = {
+ this.parseAll(parser, str) match {
+ case Success(result, _) => Some(result.toArray)
+ case _ => None
+ }
+ }
+}
+
+/**
+ * The implementation for `variant_get` and `try_variant_get` expressions.
Extracts a sub-variant
+ * value according to a path and cast it into a concrete data type.
+ * @param child The source variant value to extract from.
+ * @param path A literal path expression. It has the same format as the JSON
path.
+ * @param schema The target data type to cast into.
+ * @param failOnError Controls whether the expression should throw an
exception or return null if
+ * the cast fails.
+ * @param timeZoneId A string identifier of a time zone. It is required by
timestamp-related casts.
+ */
+case class VariantGet(
+ child: Expression,
+ path: Expression,
+ schema: DataType,
+ failOnError: Boolean,
+ timeZoneId: Option[String] = None)
+ extends BinaryExpression
+ with TimeZoneAwareExpression
+ with NullIntolerant
+ with ExpectsInputTypes
+ with CodegenFallback
+ with QueryErrorsBase {
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val check = super.checkInputDataTypes()
+ if (check.isFailure) {
+ check
+ } else if (!path.foldable) {
+ DataTypeMismatch(
+ errorSubClass = "NON_FOLDABLE_INPUT",
+ messageParameters = Map(
+ "inputName" -> toSQLId("path"),
+ "inputType" -> toSQLType(path.dataType),
+ "inputExpr" -> toSQLExpr(path)
+ )
+ )
+ } else if (!VariantGet.checkDataType(schema)) {
+ DataTypeMismatch(
+ errorSubClass = "CAST_WITHOUT_SUGGESTION",
+ messageParameters = Map(
+ "srcType" -> toSQLType(VariantType),
+ "targetType" -> toSQLType(schema)
+ )
+ )
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
+ override lazy val dataType: DataType = schema.asNullable
+
+ @transient private lazy val parsedPath = {
+ val pathValue = path.eval().toString
+ VariantPathParser.parse(pathValue).getOrElse {
+ throw QueryExecutionErrors.invalidVariantGetPath(pathValue, prettyName)
+ }
+ }
+
+ final override def nodePatternsInternal(): Seq[TreePattern] =
Seq(VARIANT_GET)
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(VariantType, StringType)
+
+ override def prettyName: String = if (failOnError) "variant_get" else
"try_variant_get"
+
+ override def nullable: Boolean = true
+
+ protected override def nullSafeEval(input: Any, path: Any): Any = {
Review Comment:
I have done some experiments with the `StaticInvoke` approach. Suppose I
have encapsulated the `VariantGet` implementation into the following function:
```
case object VariantGetCodegen {
def variantGet(input: VariantVal, parsedPath: Array[PathSegment],
dataType: DataType, failOnError: Boolean, zoneId:
Option[String]): Any = {...}
}
```
and make `VariantGet` a `RuntimeReplaceable` expression whose replacement is a
`StaticInvoke` that invokes `VariantGetCodegen.variantGet`. This still won't
work directly, because the codegen logic of `StaticInvoke` assumes that the
return type of the invoked method matches the expression's data type, but the
return type of `VariantGetCodegen.variantGet` is `Any`.
To make it work, I would have to create a wrapper method for each return
type, like:
```
case object VariantGetCodegen {
def variantGetByte(input: VariantVal, parsedPath: Array[PathSegment],
dataType: DataType, failOnError: Boolean, zoneId:
Option[String]): Byte =
variantGet(input, parsedPath, dataType, failOnError,
zoneId).asInstanceOf[Byte]
def variantGetShort(input: VariantVal, parsedPath: Array[PathSegment],
dataType: DataType, failOnError: Boolean, zoneId:
Option[String]): Short =
variantGet(input, parsedPath, dataType, failOnError,
zoneId).asInstanceOf[Short]
def variantGetStruct(input: VariantVal, parsedPath: Array[PathSegment],
dataType: DataType, failOnError: Boolean, zoneId:
Option[String]): InternalRow =
variantGet(input, parsedPath, dataType, failOnError,
zoneId).asInstanceOf[InternalRow]
...
}
```
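and then pick one of these methods according to the return type. This is very
cumbersome, and it doesn't actually avoid any boxing/unboxing costs: each
wrapper still calls the generic `variantGet` and unboxes its `Any` result.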
On the other hand, once we have this `VariantGetCodegen.variantGet` method, it
is reasonably easy to write the codegen by hand: I just need to cast the return
value of this method to the desired type. The whole point of using
`StaticInvoke` is to simplify the implementation, but here I think it actually
makes the implementation much more complex.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]