kazuyukitanimura commented on code in PR #4476:
URL: https://github.com/apache/datafusion-comet/pull/4476#discussion_r3327100269
##########
spark/src/main/scala/org/apache/comet/serde/hash.scala:
##########
@@ -68,72 +73,76 @@ object CometMurmur3Hash extends
CometExpressionSerde[Murmur3Hash] {
}
object CometSha2 extends CometExpressionSerde[Sha2] {
- override def convert(
- expr: Sha2,
- inputs: Seq[Attribute],
- binding: Boolean): Option[ExprOuterClass.Expr] = {
- if (!HashUtils.isSupportedType(expr)) {
- return None
- }
- // It's possible for spark to dynamically compute the number of bits from
input
- // expression, however DataFusion does not support that yet.
+ private val nonFoldableNumBitsReason =
+ "The `numBits` argument must be a foldable literal value"
+
+ override def getUnsupportedReasons(): Seq[String] =
+ HashUtils.unsupportedReasons :+ nonFoldableNumBitsReason
+
+ override def getSupportLevel(expr: Sha2): SupportLevel = {
if (!expr.right.foldable) {
- withInfo(expr, "For Sha2, non literal numBits is not supported")
- return None
+ Unsupported(Some(nonFoldableNumBitsReason))
+ } else {
+ HashUtils.supportLevelForChildren(expr)
}
+ }
+ override def convert(
+ expr: Sha2,
+ inputs: Seq[Attribute],
+ binding: Boolean): Option[ExprOuterClass.Expr] = {
val leftExpr = exprToProtoInternal(expr.left, inputs, binding)
val numBitsExpr = exprToProtoInternal(expr.right, inputs, binding)
scalarFunctionExprToProtoWithReturnType("sha2", StringType, false,
leftExpr, numBitsExpr)
}
}
object CometSha1 extends CometExpressionSerde[Sha1] {
+
+ override def getUnsupportedReasons(): Seq[String] =
HashUtils.unsupportedReasons
+
+ override def getSupportLevel(expr: Sha1): SupportLevel =
+ HashUtils.supportLevelForChildren(expr)
+
override def convert(
expr: Sha1,
inputs: Seq[Attribute],
binding: Boolean): Option[ExprOuterClass.Expr] = {
- if (!HashUtils.isSupportedType(expr)) {
- withInfo(expr, s"HashUtils doesn't support dataType:
${expr.child.dataType}")
- return None
- }
val childExpr = exprToProtoInternal(expr.child, inputs, binding)
scalarFunctionExprToProtoWithReturnType("sha1", StringType, false,
childExpr)
}
}
private object HashUtils {
- def isSupportedType(expr: Expression): Boolean = {
- for (child <- expr.children) {
- if (!isSupportedDataType(expr, child.dataType)) {
- return false
- }
+
+ private val unsupportedDecimalReason =
+ "`DecimalType` with precision > 18 is not supported (Spark hashes via Java
`BigDecimal`)"
+ private val unsupportedTimeTypeReason = "`TimeType` is not supported"
+
+ val unsupportedReasons: Seq[String] =
+ Seq(unsupportedDecimalReason, unsupportedTimeTypeReason, "Unsupported
child data type")
+
+ def supportLevelForChildren(expr: Expression): SupportLevel = {
+ expr.children.iterator
+ .flatMap(c => unsupportedReasonFor(c.dataType).iterator)
+ .toSeq
+ .headOption match {
+ case Some(reason) => Unsupported(Some(reason))
+ case None => Compatible()
}
- true
}
- private def isSupportedDataType(expr: Expression, dt: DataType): Boolean = {
- dt match {
- case d: DecimalType if d.precision > 18 =>
- // Spark converts decimals with precision > 18 into
- // Java BigDecimal before hashing
- withInfo(expr, s"Unsupported datatype: $dt (precision > 18)")
- false
- case s: StructType =>
- s.fields.forall(f => isSupportedDataType(expr, f.dataType))
- case a: ArrayType =>
- isSupportedDataType(expr, a.elementType)
- case m: MapType =>
- isSupportedDataType(expr, m.keyType) && isSupportedDataType(expr,
m.valueType)
- case dt if isTimeType(dt) =>
- withInfo(expr, s"Unsupported datatype $dt")
- false
- case _ if !supportedDataType(dt, allowComplex = true) =>
- withInfo(expr, s"Unsupported datatype $dt")
- false
- case _ =>
- true
- }
+ private def unsupportedReasonFor(dt: DataType): Option[String] = dt match {
+ case d: DecimalType if d.precision > 18 => Some(unsupportedDecimalReason)
+ case s: StructType =>
+ s.fields.iterator.flatMap(f =>
unsupportedReasonFor(f.dataType).iterator).toSeq.headOption
Review Comment:
If `unsupportedReasonFor(f.dataType)` returns `None` for the first element of
`fields` (so its `.iterator` is empty), will `s.fields.iterator.flatMap(f =>
unsupportedReasonFor(f.dataType).iterator).toSeq.headOption` return `None`?
Do we need to make sure that every element of `fields` returns `None`?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]