This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3a61eeff340a [SPARK-50619][SQL] Refactor VariantGet.cast to pack the
cast arguments
3a61eeff340a is described below
commit 3a61eeff340a0f979dfb2929aeff128c60f18a2c
Author: Chenhao Li <[email protected]>
AuthorDate: Thu Dec 19 17:04:34 2024 +0800
[SPARK-50619][SQL] Refactor VariantGet.cast to pack the cast arguments
### What changes were proposed in this pull request?
As the title. It refactors the code for simplification.
### Why are the changes needed?
The refactor will make it simpler for users of variant shredding to use
`VariantGet.cast`.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49239 from chenhao-db/VariantCastArgs.
Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/expressions/Cast.scala | 13 +++--
.../expressions/variant/variantExpressions.scala | 61 +++++++++-------------
2 files changed, 34 insertions(+), 40 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index d4ebdf10ef11..abd635e22f26 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -572,6 +572,11 @@ case class Cast(
}
}
+ private lazy val castArgs = variant.VariantCastArgs(
+ evalMode != EvalMode.TRY,
+ timeZoneId,
+ zoneId)
+
def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType)
// [[func]] assumes the input is no longer null because eval already does
the null check.
@@ -1127,7 +1132,7 @@ case class Cast(
_ => throw QueryExecutionErrors.cannotCastFromNullTypeError(to)
} else if (from.isInstanceOf[VariantType]) {
buildCast[VariantVal](_, v => {
- variant.VariantGet.cast(v, to, evalMode != EvalMode.TRY, timeZoneId,
zoneId)
+ variant.VariantGet.cast(v, to, castArgs)
})
} else {
to match {
@@ -1225,12 +1230,10 @@ case class Cast(
case _ if from.isInstanceOf[VariantType] => (c, evPrim, evNull) =>
val tmp = ctx.freshVariable("tmp", classOf[Object])
val dataTypeArg = ctx.addReferenceObj("dataType", to)
- val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
- val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId,
classOf[ZoneId].getName)
- val failOnError = evalMode != EvalMode.TRY
+ val castArgsArg = ctx.addReferenceObj("castArgs", castArgs)
val cls = classOf[variant.VariantGet].getName
code"""
- Object $tmp = $cls.cast($c, $dataTypeArg, $failOnError, $zoneStrArg,
$zoneIdArg);
+ Object $tmp = $cls.cast($c, $dataTypeArg, $castArgsArg);
if ($tmp == null) {
$evNull = true;
} else {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
index 2fa0ce0f570c..c19df82e6576 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/variantExpressions.scala
@@ -278,14 +278,13 @@ case class VariantGet(
override def nullable: Boolean = true
override def nullIntolerant: Boolean = true
+ private lazy val castArgs = VariantCastArgs(
+ failOnError,
+ timeZoneId,
+ zoneId)
+
protected override def nullSafeEval(input: Any, path: Any): Any = {
- VariantGet.variantGet(
- input.asInstanceOf[VariantVal],
- parsedPath,
- dataType,
- failOnError,
- timeZoneId,
- zoneId)
+ VariantGet.variantGet(input.asInstanceOf[VariantVal], parsedPath,
dataType, castArgs)
}
protected override def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode = {
@@ -293,15 +292,14 @@ case class VariantGet(
val tmp = ctx.freshVariable("tmp", classOf[Object])
val parsedPathArg = ctx.addReferenceObj("parsedPath", parsedPath)
val dataTypeArg = ctx.addReferenceObj("dataType", dataType)
- val zoneStrArg = ctx.addReferenceObj("zoneStr", timeZoneId)
- val zoneIdArg = ctx.addReferenceObj("zoneId", zoneId,
classOf[ZoneId].getName)
+ val castArgsArg = ctx.addReferenceObj("castArgs", castArgs)
val code = code"""
${childCode.code}
boolean ${ev.isNull} = ${childCode.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} =
${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) {
Object $tmp =
org.apache.spark.sql.catalyst.expressions.variant.VariantGet.variantGet(
- ${childCode.value}, $parsedPathArg, $dataTypeArg, $failOnError,
$zoneStrArg, $zoneIdArg);
+ ${childCode.value}, $parsedPathArg, $dataTypeArg, $castArgsArg);
if ($tmp == null) {
${ev.isNull} = true;
} else {
@@ -323,6 +321,12 @@ case class VariantGet(
override def withTimeZone(timeZoneId: String): VariantGet = copy(timeZoneId
= Option(timeZoneId))
}
+// Several parameters used by `VariantGet.cast`. Packed together to simplify
parameter passing.
+case class VariantCastArgs(
+ failOnError: Boolean,
+ zoneStr: Option[String],
+ zoneId: ZoneId)
+
case object VariantGet {
/**
* Returns whether a data type can be cast into/from variant. For scalar
types, we allow a subset
@@ -347,9 +351,7 @@ case object VariantGet {
input: VariantVal,
parsedPath: Array[VariantPathParser.PathSegment],
dataType: DataType,
- failOnError: Boolean,
- zoneStr: Option[String],
- zoneId: ZoneId): Any = {
+ castArgs: VariantCastArgs): Any = {
var v = new Variant(input.getValue, input.getMetadata)
for (path <- parsedPath) {
v = path match {
@@ -359,21 +361,16 @@ case object VariantGet {
}
if (v == null) return null
}
- VariantGet.cast(v, dataType, failOnError, zoneStr, zoneId)
+ VariantGet.cast(v, dataType, castArgs)
}
/**
* A simple wrapper of the `cast` function that takes `Variant` rather than
`VariantVal`. The
* `Cast` expression uses it and makes the implementation simpler.
*/
- def cast(
- input: VariantVal,
- dataType: DataType,
- failOnError: Boolean,
- zoneStr: Option[String],
- zoneId: ZoneId): Any = {
+ def cast(input: VariantVal, dataType: DataType, castArgs: VariantCastArgs):
Any = {
val v = new Variant(input.getValue, input.getMetadata)
- VariantGet.cast(v, dataType, failOnError, zoneStr, zoneId)
+ VariantGet.cast(v, dataType, castArgs)
}
/**
@@ -383,15 +380,10 @@ case object VariantGet {
* "hello" to int). If the cast fails, throw an exception when `failOnError`
is true, or return a
* SQL NULL when it is false.
*/
- def cast(
- v: Variant,
- dataType: DataType,
- failOnError: Boolean,
- zoneStr: Option[String],
- zoneId: ZoneId): Any = {
+ def cast(v: Variant, dataType: DataType, castArgs: VariantCastArgs): Any = {
def invalidCast(): Any = {
- if (failOnError) {
- throw QueryExecutionErrors.invalidVariantCast(v.toJson(zoneId),
dataType)
+ if (castArgs.failOnError) {
+ throw
QueryExecutionErrors.invalidVariantCast(v.toJson(castArgs.zoneId), dataType)
} else {
null
}
@@ -411,7 +403,7 @@ case object VariantGet {
val input = variantType match {
case Type.OBJECT | Type.ARRAY =>
return if (dataType.isInstanceOf[StringType]) {
- UTF8String.fromString(v.toJson(zoneId))
+ UTF8String.fromString(v.toJson(castArgs.zoneId))
} else {
invalidCast()
}
@@ -457,7 +449,7 @@ case object VariantGet {
}
case _ =>
if (Cast.canAnsiCast(input.dataType, dataType)) {
- val result = Cast(input, dataType, zoneStr, EvalMode.TRY).eval()
+ val result = Cast(input, dataType, castArgs.zoneStr,
EvalMode.TRY).eval()
if (result == null) invalidCast() else result
} else {
invalidCast()
@@ -468,7 +460,7 @@ case object VariantGet {
val size = v.arraySize()
val array = new Array[Any](size)
for (i <- 0 until size) {
- array(i) = cast(v.getElementAtIndex(i), elementType, failOnError,
zoneStr, zoneId)
+ array(i) = cast(v.getElementAtIndex(i), elementType, castArgs)
}
new GenericArrayData(array)
} else {
@@ -482,7 +474,7 @@ case object VariantGet {
for (i <- 0 until size) {
val field = v.getFieldAtIndex(i)
keyArray(i) = UTF8String.fromString(field.key)
- valueArray(i) = cast(field.value, valueType, failOnError, zoneStr,
zoneId)
+ valueArray(i) = cast(field.value, valueType, castArgs)
}
ArrayBasedMapData(keyArray, valueArray)
} else {
@@ -495,8 +487,7 @@ case object VariantGet {
val field = v.getFieldAtIndex(i)
st.getFieldIndex(field.key) match {
case Some(idx) =>
- row.update(idx,
- cast(field.value, fields(idx).dataType, failOnError,
zoneStr, zoneId))
+ row.update(idx, cast(field.value, fields(idx).dataType,
castArgs))
case _ =>
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]