This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push: new a5b36b45d chore: Refactor QueryPlanSerde to allow logic to be moved to individual classes per expression (#1331) a5b36b45d is described below commit a5b36b45de934bf5f3cf768067944a69053da540 Author: Andy Grove <agr...@apache.org> AuthorDate: Fri Jan 24 14:04:53 2025 -0700 chore: Refactor QueryPlanSerde to allow logic to be moved to individual classes per expression (#1331) * Move castToProto to top-level method * move more methods to top-level * more refactoring * revert accidental rename * move CometExpressionSerde to QueryPlanSerde * add scaladoc * address input --- .../org/apache/comet/serde/QueryPlanSerde.scala | 3102 ++++++++++---------- .../main/scala/org/apache/comet/serde/arrays.scala | 25 +- 2 files changed, 1622 insertions(+), 1505 deletions(-) diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala index 22bb7dc82..c3d7ac749 100644 --- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala +++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala @@ -826,723 +826,790 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim } /** - * Convert a Spark expression to protobuf. - * - * @param expr - * The input expression - * @param inputs - * The input attributes - * @param binding - * Whether to bind the expression to the input attributes - * @return - * The protobuf representation of the expression, or None if the expression is not supported + * Wrap an expression in a cast. */ - def exprToProto( + def castToProto( expr: Expression, - input: Seq[Attribute], - binding: Boolean = true): Option[Expr] = { - def castToProto( - timeZoneId: Option[String], - dt: DataType, - childExpr: Option[Expr], - evalMode: CometEvalMode.Value): Option[Expr] = { - val dataType = serializeDataType(dt) - - if (childExpr.isDefined && dataType.isDefined) { + timeZoneId: Option[String], + dt: DataType, + childExpr: Expr, + evalMode: CometEvalMode.Value): Option[Expr] = { + serializeDataType(dt) match { + case Some(dataType) => val castBuilder = ExprOuterClass.Cast.newBuilder() - castBuilder.setChild(childExpr.get) - castBuilder.setDatatype(dataType.get) + castBuilder.setChild(childExpr) + castBuilder.setDatatype(dataType) castBuilder.setEvalMode(evalModeToProto(evalMode)) castBuilder.setAllowIncompat(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) - val timeZone = timeZoneId.getOrElse("UTC") - castBuilder.setTimezone(timeZone) - + castBuilder.setTimezone(timeZoneId.getOrElse("UTC")) Some( ExprOuterClass.Expr .newBuilder() .setCast(castBuilder) .build()) - } else { - if (!dataType.isDefined) { - withInfo(expr, s"Unsupported datatype ${dt}") - } else { - withInfo(expr, s"Unsupported expression $childExpr") - } + case _ => + withInfo(expr, s"Unsupported datatype in castToProto: $dt") None - } } + } - def exprToProtoInternal(expr: Expression, inputs: Seq[Attribute]): Option[Expr] = { - SQLConf.get - - def handleCast( - child: Expression, - inputs: Seq[Attribute], - dt: DataType, - timeZoneId: Option[String], - evalMode: CometEvalMode.Value): Option[Expr] = { - - val childExpr = exprToProtoInternal(child, inputs) - if (childExpr.isDefined) { - val castSupport = - CometCast.isSupported(child.dataType, dt, timeZoneId, evalMode) - - def getIncompatMessage(reason: Option[String]): String = - "Comet does not guarantee correct results for cast " + - s"from ${child.dataType} to $dt " + - s"with timezone 
$timeZoneId and evalMode $evalMode" + - reason.map(str => s" ($str)").getOrElse("") - - castSupport match { - case Compatible(_) => - castToProto(timeZoneId, dt, childExpr, evalMode) - case Incompatible(reason) => - if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) { - logWarning(getIncompatMessage(reason)) - castToProto(timeZoneId, dt, childExpr, evalMode) - } else { - withInfo( - expr, - s"${getIncompatMessage(reason)}. To enable all incompatible casts, set " + - s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true") - None - } - case Unsupported => - withInfo( - expr, - s"Unsupported cast from ${child.dataType} to $dt " + - s"with timezone $timeZoneId and evalMode $evalMode") - None + def handleCast( + expr: Expression, + child: Expression, + inputs: Seq[Attribute], + binding: Boolean, + dt: DataType, + timeZoneId: Option[String], + evalMode: CometEvalMode.Value): Option[Expr] = { + + val childExpr = exprToProtoInternal(child, inputs, binding) + if (childExpr.isDefined) { + val castSupport = + CometCast.isSupported(child.dataType, dt, timeZoneId, evalMode) + + def getIncompatMessage(reason: Option[String]): String = + "Comet does not guarantee correct results for cast " + + s"from ${child.dataType} to $dt " + + s"with timezone $timeZoneId and evalMode $evalMode" + + reason.map(str => s" ($str)").getOrElse("") + + castSupport match { + case Compatible(_) => + castToProto(expr, timeZoneId, dt, childExpr.get, evalMode) + case Incompatible(reason) => + if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) { + logWarning(getIncompatMessage(reason)) + castToProto(expr, timeZoneId, dt, childExpr.get, evalMode) + } else { + withInfo( + expr, + s"${getIncompatMessage(reason)}. To enable all incompatible casts, set " + + s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true") + None } - } else { - withInfo(expr, child) + case Unsupported => + withInfo( + expr, + s"Unsupported cast from ${child.dataType} to $dt " + + s"with timezone $timeZoneId and evalMode $evalMode") None - } } + } else { + withInfo(expr, child) + None + } + } - expr match { - case a @ Alias(_, _) => - val r = exprToProtoInternal(a.child, inputs) - if (r.isEmpty) { - withInfo(expr, a.child) - } - r - - case cast @ Cast(_: Literal, dataType, _, _) => - // This can happen after promoting decimal precisions - val value = cast.eval() - exprToProtoInternal(Literal(value, dataType), inputs) - - case UnaryExpression(child) if expr.prettyName == "trycast" => - val timeZoneId = SQLConf.get.sessionLocalTimeZone - handleCast(child, inputs, expr.dataType, Some(timeZoneId), CometEvalMode.TRY) + /** + * Convert a Spark expression to protobuf. 
+ * + * @param expr + * The input expression + * @param inputs + * The input attributes + * @param binding + * Whether to bind the expression to the input attributes + * @return + * The protobuf representation of the expression, or None if the expression is not supported + */ + def exprToProto( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean = true): Option[Expr] = { - case c @ Cast(child, dt, timeZoneId, _) => - handleCast(child, inputs, dt, timeZoneId, evalMode(c)) + val conf = SQLConf.get + val newExpr = + DecimalPrecision.promote(conf.decimalOperationsAllowPrecisionLoss, expr, !conf.ansiEnabled) + exprToProtoInternal(newExpr, inputs, binding) + } - case add @ Add(left, right, _) if supportedDataType(left.dataType) => - createMathExpression( - left, - right, - inputs, - add.dataType, - getFailOnError(add), - (builder, mathExpr) => builder.setAdd(mathExpr)) + def exprToProtoInternal( + expr: Expression, + inputs: Seq[Attribute], + binding: Boolean): Option[Expr] = { + SQLConf.get + + expr match { + case a @ Alias(_, _) => + val r = exprToProtoInternal(a.child, inputs, binding) + if (r.isEmpty) { + withInfo(expr, a.child) + } + r + + case cast @ Cast(_: Literal, dataType, _, _) => + // This can happen after promoting decimal precisions + val value = cast.eval() + exprToProtoInternal(Literal(value, dataType), inputs, binding) + + case UnaryExpression(child) if expr.prettyName == "trycast" => + val timeZoneId = SQLConf.get.sessionLocalTimeZone + handleCast( + expr, + child, + inputs, + binding, + expr.dataType, + Some(timeZoneId), + CometEvalMode.TRY) + + case c @ Cast(child, dt, timeZoneId, _) => + handleCast(expr, child, inputs, binding, dt, timeZoneId, evalMode(c)) + + case add @ Add(left, right, _) if supportedDataType(left.dataType) => + createMathExpression( + expr, + left, + right, + inputs, + binding, + add.dataType, + getFailOnError(add), + (builder, mathExpr) => builder.setAdd(mathExpr)) + + case add @ Add(left, _, _) if !supportedDataType(left.dataType) => + withInfo(add, s"Unsupported datatype ${left.dataType}") + None - case add @ Add(left, _, _) if !supportedDataType(left.dataType) => - withInfo(add, s"Unsupported datatype ${left.dataType}") - None + case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) => + createMathExpression( + expr, + left, + right, + inputs, + binding, + sub.dataType, + getFailOnError(sub), + (builder, mathExpr) => builder.setSubtract(mathExpr)) + + case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) => + withInfo(sub, s"Unsupported datatype ${left.dataType}") + None - case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) => - createMathExpression( - left, - right, - inputs, - sub.dataType, - getFailOnError(sub), - (builder, mathExpr) => builder.setSubtract(mathExpr)) + case mul @ Multiply(left, right, _) + if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => + createMathExpression( + expr, + left, + right, + inputs, + binding, + mul.dataType, + getFailOnError(mul), + (builder, mathExpr) => builder.setMultiply(mathExpr)) + + case mul @ Multiply(left, _, _) => + if (!supportedDataType(left.dataType)) { + withInfo(mul, s"Unsupported datatype ${left.dataType}") + } + if (decimalBeforeSpark34(left.dataType)) { + withInfo(mul, "Decimal support requires Spark 3.4 or later") + } + None - case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) => - withInfo(sub, s"Unsupported datatype ${left.dataType}") - None + case div @ Divide(left, right, _) + if 
supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => + // Datafusion now throws an exception for dividing by zero + // See https://github.com/apache/arrow-datafusion/pull/6792 + // For now, use NullIf to swap zeros with nulls. + val rightExpr = nullIfWhenPrimitive(right) + + createMathExpression( + expr, + left, + rightExpr, + inputs, + binding, + div.dataType, + getFailOnError(div), + (builder, mathExpr) => builder.setDivide(mathExpr)) + + case div @ Divide(left, _, _) => + if (!supportedDataType(left.dataType)) { + withInfo(div, s"Unsupported datatype ${left.dataType}") + } + if (decimalBeforeSpark34(left.dataType)) { + withInfo(div, "Decimal support requires Spark 3.4 or later") + } + None - case mul @ Multiply(left, right, _) - if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - createMathExpression( - left, - right, - inputs, - mul.dataType, - getFailOnError(mul), - (builder, mathExpr) => builder.setMultiply(mathExpr)) + case rem @ Remainder(left, right, _) + if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => + val rightExpr = nullIfWhenPrimitive(right) - case mul @ Multiply(left, _, _) => - if (!supportedDataType(left.dataType)) { - withInfo(mul, s"Unsupported datatype ${left.dataType}") - } - if (decimalBeforeSpark34(left.dataType)) { - withInfo(mul, "Decimal support requires Spark 3.4 or later") - } - None + createMathExpression( + expr, + left, + rightExpr, + inputs, + binding, + rem.dataType, + getFailOnError(rem), + (builder, mathExpr) => builder.setRemainder(mathExpr)) - case div @ Divide(left, right, _) - if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - // Datafusion now throws an exception for dividing by zero - // See https://github.com/apache/arrow-datafusion/pull/6792 - // For now, use NullIf to swap zeros with nulls. 
- val rightExpr = nullIfWhenPrimitive(right) + case rem @ Remainder(left, _, _) => + if (!supportedDataType(left.dataType)) { + withInfo(rem, s"Unsupported datatype ${left.dataType}") + } + if (decimalBeforeSpark34(left.dataType)) { + withInfo(rem, "Decimal support requires Spark 3.4 or later") + } + None - createMathExpression( - left, - rightExpr, - inputs, - div.dataType, - getFailOnError(div), - (builder, mathExpr) => builder.setDivide(mathExpr)) + case EqualTo(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setEq(binaryExpr)) + + case Not(EqualTo(left, right)) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setNeq(binaryExpr)) + + case EqualNullSafe(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setEqNullSafe(binaryExpr)) + + case Not(EqualNullSafe(left, right)) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setNeqNullSafe(binaryExpr)) + + case GreaterThan(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setGt(binaryExpr)) + + case GreaterThanOrEqual(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setGtEq(binaryExpr)) + + case LessThan(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setLt(binaryExpr)) + + case LessThanOrEqual(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setLtEq(binaryExpr)) + + case Literal(value, dataType) if supportedDataType(dataType, allowStruct = value == null) => + val exprBuilder = ExprOuterClass.Literal.newBuilder() + + if (value == null) { + exprBuilder.setIsNull(true) + } else { + exprBuilder.setIsNull(false) + dataType match { + case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean]) + case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte]) + case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short]) + case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int]) + case _: LongType => exprBuilder.setLongVal(value.asInstanceOf[Long]) + case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float]) + case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double]) + case _: StringType => + exprBuilder.setStringVal(value.asInstanceOf[UTF8String].toString) + case _: TimestampType => exprBuilder.setLongVal(value.asInstanceOf[Long]) + case _: DecimalType => + // Pass decimal literal as bytes. 
+ val unscaled = value.asInstanceOf[Decimal].toBigDecimal.underlying.unscaledValue + exprBuilder.setDecimalVal( + com.google.protobuf.ByteString.copyFrom(unscaled.toByteArray)) + case _: BinaryType => + val byteStr = + com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]]) + exprBuilder.setBytesVal(byteStr) + case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int]) + case dt if isTimestampNTZType(dt) => + exprBuilder.setLongVal(value.asInstanceOf[Long]) + case dt => + logWarning(s"Unexpected date type '$dt' for literal value '$value'") + } + } + + val dt = serializeDataType(dataType) + + if (dt.isDefined) { + exprBuilder.setDatatype(dt.get) - case div @ Divide(left, _, _) => - if (!supportedDataType(left.dataType)) { - withInfo(div, s"Unsupported datatype ${left.dataType}") - } - if (decimalBeforeSpark34(left.dataType)) { - withInfo(div, "Decimal support requires Spark 3.4 or later") - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setLiteral(exprBuilder) + .build()) + } else { + withInfo(expr, s"Unsupported datatype $dataType") None + } + case Literal(_, dataType) if !supportedDataType(dataType) => + withInfo(expr, s"Unsupported datatype $dataType") + None - case rem @ Remainder(left, right, _) - if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) => - val rightExpr = nullIfWhenPrimitive(right) + case Substring(str, Literal(pos, _), Literal(len, _)) => + val strExpr = exprToProtoInternal(str, inputs, binding) - createMathExpression( - left, - rightExpr, - inputs, - rem.dataType, - getFailOnError(rem), - (builder, mathExpr) => builder.setRemainder(mathExpr)) + if (strExpr.isDefined) { + val builder = ExprOuterClass.Substring.newBuilder() + builder.setChild(strExpr.get) + builder.setStart(pos.asInstanceOf[Int]) + builder.setLen(len.asInstanceOf[Int]) - case rem @ Remainder(left, _, _) => - if (!supportedDataType(left.dataType)) { - withInfo(rem, s"Unsupported datatype ${left.dataType}") - } - if (decimalBeforeSpark34(left.dataType)) { - withInfo(rem, "Decimal support requires Spark 3.4 or later") - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setSubstring(builder) + .build()) + } else { + withInfo(expr, str) None + } - case EqualTo(left, right) => - createBinaryExpr( - left, - right, - inputs, - (builder, binaryExpr) => builder.setEq(binaryExpr)) - - case Not(EqualTo(left, right)) => - createBinaryExpr( - left, - right, - inputs, - (builder, binaryExpr) => builder.setNeq(binaryExpr)) - - case EqualNullSafe(left, right) => - createBinaryExpr( - left, - right, - inputs, - (builder, binaryExpr) => builder.setEqNullSafe(binaryExpr)) - - case Not(EqualNullSafe(left, right)) => - createBinaryExpr( - left, - right, - inputs, - (builder, binaryExpr) => builder.setNeqNullSafe(binaryExpr)) - - case GreaterThan(left, right) => - createBinaryExpr( - left, - right, - inputs, - (builder, binaryExpr) => builder.setGt(binaryExpr)) - - case GreaterThanOrEqual(left, right) => - createBinaryExpr( - left, - right, - inputs, - (builder, binaryExpr) => builder.setGtEq(binaryExpr)) - - case LessThan(left, right) => - createBinaryExpr( - left, - right, - inputs, - (builder, binaryExpr) => builder.setLt(binaryExpr)) - - case LessThanOrEqual(left, right) => - createBinaryExpr( - left, - right, - inputs, - (builder, binaryExpr) => builder.setLtEq(binaryExpr)) - - case Literal(value, dataType) - if supportedDataType(dataType, allowStruct = value == null) => - val exprBuilder = ExprOuterClass.Literal.newBuilder() - - if (value == null) { - 
exprBuilder.setIsNull(true) - } else { - exprBuilder.setIsNull(false) - dataType match { - case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean]) - case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte]) - case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short]) - case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int]) - case _: LongType => exprBuilder.setLongVal(value.asInstanceOf[Long]) - case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float]) - case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double]) - case _: StringType => - exprBuilder.setStringVal(value.asInstanceOf[UTF8String].toString) - case _: TimestampType => exprBuilder.setLongVal(value.asInstanceOf[Long]) - case _: DecimalType => - // Pass decimal literal as bytes. - val unscaled = value.asInstanceOf[Decimal].toBigDecimal.underlying.unscaledValue - exprBuilder.setDecimalVal( - com.google.protobuf.ByteString.copyFrom(unscaled.toByteArray)) - case _: BinaryType => - val byteStr = - com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]]) - exprBuilder.setBytesVal(byteStr) - case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int]) - case dt if isTimestampNTZType(dt) => - exprBuilder.setLongVal(value.asInstanceOf[Long]) - case dt => - logWarning(s"Unexpected date type '$dt' for literal value '$value'") - } - } - - val dt = serializeDataType(dataType) - - if (dt.isDefined) { - exprBuilder.setDatatype(dt.get) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setLiteral(exprBuilder) - .build()) - } else { - withInfo(expr, s"Unsupported datatype $dataType") - None - } - case Literal(_, dataType) if !supportedDataType(dataType) => - withInfo(expr, s"Unsupported datatype $dataType") + case StructsToJson(options, child, timezoneId) => + if (options.nonEmpty) { + withInfo(expr, "StructsToJson with options is not supported") None + } else { - case Substring(str, Literal(pos, _), Literal(len, _)) => - val strExpr = exprToProtoInternal(str, inputs) - - if (strExpr.isDefined) { - val builder = ExprOuterClass.Substring.newBuilder() - builder.setChild(strExpr.get) - builder.setStart(pos.asInstanceOf[Int]) - builder.setLen(len.asInstanceOf[Int]) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setSubstring(builder) - .build()) - } else { - withInfo(expr, str) - None - } - - case StructsToJson(options, child, timezoneId) => - if (options.nonEmpty) { - withInfo(expr, "StructsToJson with options is not supported") - None - } else { - - def isSupportedType(dt: DataType): Boolean = { - dt match { - case StructType(fields) => - fields.forall(f => isSupportedType(f.dataType)) - case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType | - DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType | - DataTypes.DoubleType | DataTypes.StringType => - true - case DataTypes.DateType | DataTypes.TimestampType => - // TODO implement these types with tests for formatting options and timezone - false - case _: MapType | _: ArrayType => - // Spark supports map and array in StructsToJson but this is not yet - // implemented in Comet - false - case _ => false - } - } - - val isSupported = child.dataType match { - case s: StructType => - s.fields.forall(f => isSupportedType(f.dataType)) + def isSupportedType(dt: DataType): Boolean = { + dt match { + case StructType(fields) => + fields.forall(f => isSupportedType(f.dataType)) + case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType | + DataTypes.IntegerType | 
DataTypes.LongType | DataTypes.FloatType | + DataTypes.DoubleType | DataTypes.StringType => + true + case DataTypes.DateType | DataTypes.TimestampType => + // TODO implement these types with tests for formatting options and timezone + false case _: MapType | _: ArrayType => // Spark supports map and array in StructsToJson but this is not yet // implemented in Comet false - case _ => - false + case _ => false } + } - if (isSupported) { - exprToProto(child, input, binding) match { - case Some(p) => - val toJson = ExprOuterClass.ToJson - .newBuilder() - .setChild(p) - .setTimezone(timezoneId.getOrElse("UTC")) - .setIgnoreNullFields(true) - .build() - Some( - ExprOuterClass.Expr - .newBuilder() - .setToJson(toJson) - .build()) - case _ => - withInfo(expr, child) - None - } - } else { - withInfo(expr, "Unsupported data type", child) - None - } + val isSupported = child.dataType match { + case s: StructType => + s.fields.forall(f => isSupportedType(f.dataType)) + case _: MapType | _: ArrayType => + // Spark supports map and array in StructsToJson but this is not yet + // implemented in Comet + false + case _ => + false } - case Like(left, right, escapeChar) => - if (escapeChar == '\\') { - createBinaryExpr( - left, - right, - inputs, - (builder, binaryExpr) => builder.setLike(binaryExpr)) + if (isSupported) { + exprToProto(child, inputs, binding) match { + case Some(p) => + val toJson = ExprOuterClass.ToJson + .newBuilder() + .setChild(p) + .setTimezone(timezoneId.getOrElse("UTC")) + .setIgnoreNullFields(true) + .build() + Some( + ExprOuterClass.Expr + .newBuilder() + .setToJson(toJson) + .build()) + case _ => + withInfo(expr, child) + None + } } else { - // TODO custom escape char - withInfo(expr, s"custom escape character $escapeChar not supported in LIKE") + withInfo(expr, "Unsupported data type", child) None } + } - case RLike(left, right) => - // we currently only support scalar regex patterns - right match { - case Literal(pattern, DataTypes.StringType) => - if (!RegExp.isSupportedPattern(pattern.toString) && - !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) { - withInfo( - expr, - s"Regexp pattern $pattern is not compatible with Spark. " + - s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " + - "to allow it anyway.") - return None - } - case _ => - withInfo(expr, "Only scalar regexp patterns are supported") - return None - } - + case Like(left, right, escapeChar) => + if (escapeChar == '\\') { createBinaryExpr( + expr, left, right, inputs, - (builder, binaryExpr) => builder.setRlike(binaryExpr)) + binding, + (builder, binaryExpr) => builder.setLike(binaryExpr)) + } else { + // TODO custom escape char + withInfo(expr, s"custom escape character $escapeChar not supported in LIKE") + None + } - case StartsWith(left, right) => - createBinaryExpr( - left, - right, - inputs, - (builder, binaryExpr) => builder.setStartsWith(binaryExpr)) + case RLike(left, right) => + // we currently only support scalar regex patterns + right match { + case Literal(pattern, DataTypes.StringType) => + if (!RegExp.isSupportedPattern(pattern.toString) && + !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) { + withInfo( + expr, + s"Regexp pattern $pattern is not compatible with Spark. 
" + + s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " + + "to allow it anyway.") + return None + } + case _ => + withInfo(expr, "Only scalar regexp patterns are supported") + return None + } - case EndsWith(left, right) => - createBinaryExpr( - left, - right, - inputs, - (builder, binaryExpr) => builder.setEndsWith(binaryExpr)) + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setRlike(binaryExpr)) + + case StartsWith(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setStartsWith(binaryExpr)) + + case EndsWith(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setEndsWith(binaryExpr)) + + case Contains(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setContains(binaryExpr)) + + case StringSpace(child) => + createUnaryExpr( + expr, + child, + inputs, + binding, + (builder, unaryExpr) => builder.setStringSpace(unaryExpr)) + + case Hour(child, timeZoneId) => + val childExpr = exprToProtoInternal(child, inputs, binding) - case Contains(left, right) => - createBinaryExpr( - left, - right, - inputs, - (builder, binaryExpr) => builder.setContains(binaryExpr)) + if (childExpr.isDefined) { + val builder = ExprOuterClass.Hour.newBuilder() + builder.setChild(childExpr.get) - case StringSpace(child) => - createUnaryExpr( - child, - inputs, - (builder, unaryExpr) => builder.setStringSpace(unaryExpr)) + val timeZone = timeZoneId.getOrElse("UTC") + builder.setTimezone(timeZone) - case Hour(child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) + Some( + ExprOuterClass.Expr + .newBuilder() + .setHour(builder) + .build()) + } else { + withInfo(expr, child) + None + } - if (childExpr.isDefined) { - val builder = ExprOuterClass.Hour.newBuilder() - builder.setChild(childExpr.get) + case Minute(child, timeZoneId) => + val childExpr = exprToProtoInternal(child, inputs, binding) - val timeZone = timeZoneId.getOrElse("UTC") - builder.setTimezone(timeZone) + if (childExpr.isDefined) { + val builder = ExprOuterClass.Minute.newBuilder() + builder.setChild(childExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setHour(builder) - .build()) - } else { - withInfo(expr, child) - None - } + val timeZone = timeZoneId.getOrElse("UTC") + builder.setTimezone(timeZone) - case Minute(child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) + Some( + ExprOuterClass.Expr + .newBuilder() + .setMinute(builder) + .build()) + } else { + withInfo(expr, child) + None + } - if (childExpr.isDefined) { - val builder = ExprOuterClass.Minute.newBuilder() - builder.setChild(childExpr.get) + case DateAdd(left, right) => + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) + val optExpr = scalarExprToProtoWithReturnType("date_add", DateType, leftExpr, rightExpr) + optExprWithInfo(optExpr, expr, left, right) - val timeZone = timeZoneId.getOrElse("UTC") - builder.setTimezone(timeZone) + case DateSub(left, right) => + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) + val optExpr = scalarExprToProtoWithReturnType("date_sub", DateType, leftExpr, rightExpr) + optExprWithInfo(optExpr, expr, left, right) - Some( - ExprOuterClass.Expr - .newBuilder() - .setMinute(builder) - .build()) - } else { - 
withInfo(expr, child) - None - } + case TruncDate(child, format) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val formatExpr = exprToProtoInternal(format, inputs, binding) - case DateAdd(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - val optExpr = scalarExprToProtoWithReturnType("date_add", DateType, leftExpr, rightExpr) - optExprWithInfo(optExpr, expr, left, right) + if (childExpr.isDefined && formatExpr.isDefined) { + val builder = ExprOuterClass.TruncDate.newBuilder() + builder.setChild(childExpr.get) + builder.setFormat(formatExpr.get) - case DateSub(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - val optExpr = scalarExprToProtoWithReturnType("date_sub", DateType, leftExpr, rightExpr) - optExprWithInfo(optExpr, expr, left, right) + Some( + ExprOuterClass.Expr + .newBuilder() + .setTruncDate(builder) + .build()) + } else { + withInfo(expr, child, format) + None + } - case TruncDate(child, format) => - val childExpr = exprToProtoInternal(child, inputs) - val formatExpr = exprToProtoInternal(format, inputs) + case TruncTimestamp(format, child, timeZoneId) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val formatExpr = exprToProtoInternal(format, inputs, binding) - if (childExpr.isDefined && formatExpr.isDefined) { - val builder = ExprOuterClass.TruncDate.newBuilder() - builder.setChild(childExpr.get) - builder.setFormat(formatExpr.get) + if (childExpr.isDefined && formatExpr.isDefined) { + val builder = ExprOuterClass.TruncTimestamp.newBuilder() + builder.setChild(childExpr.get) + builder.setFormat(formatExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setTruncDate(builder) - .build()) - } else { - withInfo(expr, child, format) - None - } + val timeZone = timeZoneId.getOrElse("UTC") + builder.setTimezone(timeZone) - case TruncTimestamp(format, child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) - val formatExpr = exprToProtoInternal(format, inputs) + Some( + ExprOuterClass.Expr + .newBuilder() + .setTruncTimestamp(builder) + .build()) + } else { + withInfo(expr, child, format) + None + } - if (childExpr.isDefined && formatExpr.isDefined) { - val builder = ExprOuterClass.TruncTimestamp.newBuilder() - builder.setChild(childExpr.get) - builder.setFormat(formatExpr.get) + case Second(child, timeZoneId) => + val childExpr = exprToProtoInternal(child, inputs, binding) - val timeZone = timeZoneId.getOrElse("UTC") - builder.setTimezone(timeZone) + if (childExpr.isDefined) { + val builder = ExprOuterClass.Second.newBuilder() + builder.setChild(childExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setTruncTimestamp(builder) - .build()) - } else { - withInfo(expr, child, format) - None - } + val timeZone = timeZoneId.getOrElse("UTC") + builder.setTimezone(timeZone) + + Some( + ExprOuterClass.Expr + .newBuilder() + .setSecond(builder) + .build()) + } else { + withInfo(expr, child) + None + } - case Second(child, timeZoneId) => - val childExpr = exprToProtoInternal(child, inputs) + case Year(child) => + val periodType = exprToProtoInternal(Literal("year"), inputs, binding) + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = scalarExprToProto("datepart", Seq(periodType, childExpr): _*) + .map(e => { + Expr + .newBuilder() + .setCast( + ExprOuterClass.Cast + .newBuilder() + .setChild(e) + .setDatatype(serializeDataType(IntegerType).get) + 
.setEvalMode(ExprOuterClass.EvalMode.LEGACY) + .setAllowIncompat(false) + .build()) + .build() + }) + optExprWithInfo(optExpr, expr, child) + + case IsNull(child) => + createUnaryExpr( + expr, + child, + inputs, + binding, + (builder, unaryExpr) => builder.setIsNull(unaryExpr)) + + case IsNotNull(child) => + createUnaryExpr( + expr, + child, + inputs, + binding, + (builder, unaryExpr) => builder.setIsNotNull(unaryExpr)) + + case IsNaN(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = + scalarExprToProtoWithReturnType("isnan", BooleanType, childExpr) + + optExprWithInfo(optExpr, expr, child) + + case SortOrder(child, direction, nullOrdering, _) => + val childExpr = exprToProtoInternal(child, inputs, binding) - if (childExpr.isDefined) { - val builder = ExprOuterClass.Second.newBuilder() - builder.setChild(childExpr.get) + if (childExpr.isDefined) { + val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder() + sortOrderBuilder.setChild(childExpr.get) - val timeZone = timeZoneId.getOrElse("UTC") - builder.setTimezone(timeZone) + direction match { + case Ascending => sortOrderBuilder.setDirectionValue(0) + case Descending => sortOrderBuilder.setDirectionValue(1) + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setSecond(builder) - .build()) - } else { - withInfo(expr, child) - None + nullOrdering match { + case NullsFirst => sortOrderBuilder.setNullOrderingValue(0) + case NullsLast => sortOrderBuilder.setNullOrderingValue(1) } - case Year(child) => - val periodType = exprToProtoInternal(Literal("year"), inputs) - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = scalarExprToProto("datepart", Seq(periodType, childExpr): _*) - .map(e => { - Expr - .newBuilder() - .setCast( - ExprOuterClass.Cast - .newBuilder() - .setChild(e) - .setDatatype(serializeDataType(IntegerType).get) - .setEvalMode(ExprOuterClass.EvalMode.LEGACY) - .setAllowIncompat(false) - .build()) - .build() - }) - optExprWithInfo(optExpr, expr, child) + Some( + ExprOuterClass.Expr + .newBuilder() + .setSortOrder(sortOrderBuilder) + .build()) + } else { + withInfo(expr, child) + None + } - case IsNull(child) => - createUnaryExpr(child, inputs, (builder, unaryExpr) => builder.setIsNull(unaryExpr)) + case And(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setAnd(binaryExpr)) + + case Or(left, right) => + createBinaryExpr( + expr, + left, + right, + inputs, + binding, + (builder, binaryExpr) => builder.setOr(binaryExpr)) + + case UnaryExpression(child) if expr.prettyName == "promote_precision" => + // `UnaryExpression` includes `PromotePrecision` for Spark 3.3 + // `PromotePrecision` is just a wrapper, don't need to serialize it. 
+ exprToProtoInternal(child, inputs, binding) + + case CheckOverflow(child, dt, nullOnOverflow) => + val childExpr = exprToProtoInternal(child, inputs, binding) - case IsNotNull(child) => - createUnaryExpr(child, inputs, (builder, unaryExpr) => builder.setIsNotNull(unaryExpr)) + if (childExpr.isDefined) { + val builder = ExprOuterClass.CheckOverflow.newBuilder() + builder.setChild(childExpr.get) + builder.setFailOnError(!nullOnOverflow) - case IsNaN(child) => - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = - scalarExprToProtoWithReturnType("isnan", BooleanType, childExpr) + // `dataType` must be decimal type + val dataType = serializeDataType(dt) + builder.setDatatype(dataType.get) - optExprWithInfo(optExpr, expr, child) + Some( + ExprOuterClass.Expr + .newBuilder() + .setCheckOverflow(builder) + .build()) + } else { + withInfo(expr, child) + None + } - case SortOrder(child, direction, nullOrdering, _) => - val childExpr = exprToProtoInternal(child, inputs) + case attr: AttributeReference => + val dataType = serializeDataType(attr.dataType) - if (childExpr.isDefined) { - val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder() - sortOrderBuilder.setChild(childExpr.get) + if (dataType.isDefined) { + if (binding) { + // Spark may produce unresolvable attributes in some cases, + // for example https://github.com/apache/datafusion-comet/issues/925. + // So, we allow the binding to fail. + val boundRef: Any = BindReferences + .bindReference(attr, inputs, allowFailures = true) - direction match { - case Ascending => sortOrderBuilder.setDirectionValue(0) - case Descending => sortOrderBuilder.setDirectionValue(1) + if (boundRef.isInstanceOf[AttributeReference]) { + withInfo(attr, s"cannot resolve $attr among ${inputs.mkString(", ")}") + return None } - nullOrdering match { - case NullsFirst => sortOrderBuilder.setNullOrderingValue(0) - case NullsLast => sortOrderBuilder.setNullOrderingValue(1) - } + val boundExpr = ExprOuterClass.BoundReference + .newBuilder() + .setIndex(boundRef.asInstanceOf[BoundReference].ordinal) + .setDatatype(dataType.get) + .build() Some( ExprOuterClass.Expr .newBuilder() - .setSortOrder(sortOrderBuilder) + .setBound(boundExpr) .build()) } else { - withInfo(expr, child) - None - } - - case And(left, right) => - createBinaryExpr( - left, - right, - inputs, - (builder, binaryExpr) => builder.setAnd(binaryExpr)) - - case Or(left, right) => - createBinaryExpr( - left, - right, - inputs, - (builder, binaryExpr) => builder.setOr(binaryExpr)) - - case UnaryExpression(child) if expr.prettyName == "promote_precision" => - // `UnaryExpression` includes `PromotePrecision` for Spark 3.3 - // `PromotePrecision` is just a wrapper, don't need to serialize it. 
- exprToProtoInternal(child, inputs) - - case CheckOverflow(child, dt, nullOnOverflow) => - val childExpr = exprToProtoInternal(child, inputs) - - if (childExpr.isDefined) { - val builder = ExprOuterClass.CheckOverflow.newBuilder() - builder.setChild(childExpr.get) - builder.setFailOnError(!nullOnOverflow) - - // `dataType` must be decimal type - val dataType = serializeDataType(dt) - builder.setDatatype(dataType.get) + val unboundRef = ExprOuterClass.UnboundReference + .newBuilder() + .setName(attr.name) + .setDatatype(dataType.get) + .build() Some( ExprOuterClass.Expr .newBuilder() - .setCheckOverflow(builder) + .setUnbound(unboundRef) .build()) - } else { - withInfo(expr, child) - None - } - - case attr: AttributeReference => - val dataType = serializeDataType(attr.dataType) - - if (dataType.isDefined) { - if (binding) { - // Spark may produce unresolvable attributes in some cases, - // for example https://github.com/apache/datafusion-comet/issues/925. - // So, we allow the binding to fail. - val boundRef: Any = BindReferences - .bindReference(attr, inputs, allowFailures = true) - - if (boundRef.isInstanceOf[AttributeReference]) { - withInfo(attr, s"cannot resolve $attr among ${inputs.mkString(", ")}") - return None - } - - val boundExpr = ExprOuterClass.BoundReference - .newBuilder() - .setIndex(boundRef.asInstanceOf[BoundReference].ordinal) - .setDatatype(dataType.get) - .build() - - Some( - ExprOuterClass.Expr - .newBuilder() - .setBound(boundExpr) - .build()) - } else { - val unboundRef = ExprOuterClass.UnboundReference - .newBuilder() - .setName(attr.name) - .setDatatype(dataType.get) - .build() - - Some( - ExprOuterClass.Expr - .newBuilder() - .setUnbound(unboundRef) - .build()) - } - } else { - withInfo(attr, s"unsupported datatype: ${attr.dataType}") - None } + } else { + withInfo(attr, s"unsupported datatype: ${attr.dataType}") + None + } - // abs implementation is not correct - // https://github.com/apache/datafusion-comet/issues/666 + // abs implementation is not correct + // https://github.com/apache/datafusion-comet/issues/666 // case Abs(child, failOnErr) => // val childExpr = exprToProtoInternal(child, inputs) // if (childExpr.isDefined) { @@ -1557,950 +1624,967 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim // None // } - case Acos(child) => - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = scalarExprToProto("acos", childExpr) - optExprWithInfo(optExpr, expr, child) - - case Asin(child) => - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = scalarExprToProto("asin", childExpr) - optExprWithInfo(optExpr, expr, child) - - case Atan(child) => - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = scalarExprToProto("atan", childExpr) - optExprWithInfo(optExpr, expr, child) - - case Atan2(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr) - optExprWithInfo(optExpr, expr, left, right) - - case Hex(child) => - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = - scalarExprToProtoWithReturnType("hex", StringType, childExpr) - - optExprWithInfo(optExpr, expr, child) - - case e: Unhex => - val unHex = unhexSerde(e) - - val childExpr = exprToProtoInternal(unHex._1, inputs) - val failOnErrorExpr = exprToProtoInternal(unHex._2, inputs) - - val optExpr = - scalarExprToProtoWithReturnType("unhex", e.dataType, childExpr, failOnErrorExpr) - 
optExprWithInfo(optExpr, expr, unHex._1) - - case e @ Ceil(child) => - val childExpr = exprToProtoInternal(child, inputs) - child.dataType match { - case t: DecimalType if t.scale == 0 => // zero scale is no-op - childExpr - case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 - withInfo(e, s"Decimal type $t has negative scale") - None - case _ => - val optExpr = scalarExprToProtoWithReturnType("ceil", e.dataType, childExpr) - optExprWithInfo(optExpr, expr, child) - } - - case Cos(child) => - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = scalarExprToProto("cos", childExpr) - optExprWithInfo(optExpr, expr, child) - - case Exp(child) => - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = scalarExprToProto("exp", childExpr) - optExprWithInfo(optExpr, expr, child) - - case e @ Floor(child) => - val childExpr = exprToProtoInternal(child, inputs) - child.dataType match { - case t: DecimalType if t.scale == 0 => // zero scale is no-op - childExpr - case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 - withInfo(e, s"Decimal type $t has negative scale") - None - case _ => - val optExpr = scalarExprToProtoWithReturnType("floor", e.dataType, childExpr) - optExprWithInfo(optExpr, expr, child) - } - - // The expression for `log` functions is defined as null on numbers less than or equal - // to 0. This matches Spark and Hive behavior, where non positive values eval to null - // instead of NaN or -Infinity. - case Log(child) => - val childExpr = exprToProtoInternal(nullIfNegative(child), inputs) - val optExpr = scalarExprToProto("ln", childExpr) - optExprWithInfo(optExpr, expr, child) - - case Log10(child) => - val childExpr = exprToProtoInternal(nullIfNegative(child), inputs) - val optExpr = scalarExprToProto("log10", childExpr) - optExprWithInfo(optExpr, expr, child) - - case Log2(child) => - val childExpr = exprToProtoInternal(nullIfNegative(child), inputs) - val optExpr = scalarExprToProto("log2", childExpr) - optExprWithInfo(optExpr, expr, child) - - case Pow(left, right) => - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - val optExpr = scalarExprToProto("pow", leftExpr, rightExpr) - optExprWithInfo(optExpr, expr, left, right) - - case r: Round => - // _scale s a constant, copied from Spark's RoundBase because it is a protected val - val scaleV: Any = r.scale.eval(EmptyRow) - val _scale: Int = scaleV.asInstanceOf[Int] - - lazy val childExpr = exprToProtoInternal(r.child, inputs) - r.child.dataType match { - case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 - withInfo(r, "Decimal type has negative scale") - None - case _ if scaleV == null => - exprToProtoInternal(Literal(null), inputs) - case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 => - childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark - case _: FloatType | DoubleType => - // We cannot properly match with the Spark behavior for floating-point numbers. - // Spark uses BigDecimal for rounding float/double, and BigDecimal fist converts a - // double to string internally in order to create its own internal representation. - // The problem is BigDecimal uses java.lang.Double.toString() and it has complicated - // rounding algorithm. E.g. -5.81855622136895E8 is actually - // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of - // 5. 
Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a - // difference when rounding at 5th digit, I.e. round(-5.81855622136895E8, 5) should be - // -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that - // toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can - // be rounded up to 6.13171162472835E18 that still represents the same double number. - // I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not. - // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead - // of 6.1317116247283999E18. - withInfo(r, "Comet does not support Spark's BigDecimal rounding") - None - case _ => - // `scale` must be Int64 type in DataFusion - val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs) - val optExpr = - scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr) - optExprWithInfo(optExpr, expr, r.child) - } + case Acos(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = scalarExprToProto("acos", childExpr) + optExprWithInfo(optExpr, expr, child) + + case Asin(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = scalarExprToProto("asin", childExpr) + optExprWithInfo(optExpr, expr, child) + + case Atan(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = scalarExprToProto("atan", childExpr) + optExprWithInfo(optExpr, expr, child) + + case Atan2(left, right) => + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) + val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr) + optExprWithInfo(optExpr, expr, left, right) + + case Hex(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = + scalarExprToProtoWithReturnType("hex", StringType, childExpr) + + optExprWithInfo(optExpr, expr, child) + + case e: Unhex => + val unHex = unhexSerde(e) + + val childExpr = exprToProtoInternal(unHex._1, inputs, binding) + val failOnErrorExpr = exprToProtoInternal(unHex._2, inputs, binding) + + val optExpr = + scalarExprToProtoWithReturnType("unhex", e.dataType, childExpr, failOnErrorExpr) + optExprWithInfo(optExpr, expr, unHex._1) + + case e @ Ceil(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + child.dataType match { + case t: DecimalType if t.scale == 0 => // zero scale is no-op + childExpr + case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + withInfo(e, s"Decimal type $t has negative scale") + None + case _ => + val optExpr = scalarExprToProtoWithReturnType("ceil", e.dataType, childExpr) + optExprWithInfo(optExpr, expr, child) + } + + case Cos(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = scalarExprToProto("cos", childExpr) + optExprWithInfo(optExpr, expr, child) + + case Exp(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = scalarExprToProto("exp", childExpr) + optExprWithInfo(optExpr, expr, child) + + case e @ Floor(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + child.dataType match { + case t: DecimalType if t.scale == 0 => // zero scale is no-op + childExpr + case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + withInfo(e, s"Decimal type $t has negative scale") + None + case _ => + val optExpr = scalarExprToProtoWithReturnType("floor", 
e.dataType, childExpr) + optExprWithInfo(optExpr, expr, child) + } + + // The expression for `log` functions is defined as null on numbers less than or equal + // to 0. This matches Spark and Hive behavior, where non positive values eval to null + // instead of NaN or -Infinity. + case Log(child) => + val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, binding) + val optExpr = scalarExprToProto("ln", childExpr) + optExprWithInfo(optExpr, expr, child) + + case Log10(child) => + val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, binding) + val optExpr = scalarExprToProto("log10", childExpr) + optExprWithInfo(optExpr, expr, child) + + case Log2(child) => + val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, binding) + val optExpr = scalarExprToProto("log2", childExpr) + optExprWithInfo(optExpr, expr, child) + + case Pow(left, right) => + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) + val optExpr = scalarExprToProto("pow", leftExpr, rightExpr) + optExprWithInfo(optExpr, expr, left, right) + + case r: Round => + // _scale s a constant, copied from Spark's RoundBase because it is a protected val + val scaleV: Any = r.scale.eval(EmptyRow) + val _scale: Int = scaleV.asInstanceOf[Int] + + lazy val childExpr = exprToProtoInternal(r.child, inputs, binding) + r.child.dataType match { + case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252 + withInfo(r, "Decimal type has negative scale") + None + case _ if scaleV == null => + exprToProtoInternal(Literal(null), inputs, binding) + case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 => + childExpr // _scale(I.e. decimal place) >= 0 is a no-op for integer types in Spark + case _: FloatType | DoubleType => + // We cannot properly match with the Spark behavior for floating-point numbers. + // Spark uses BigDecimal for rounding float/double, and BigDecimal fist converts a + // double to string internally in order to create its own internal representation. + // The problem is BigDecimal uses java.lang.Double.toString() and it has complicated + // rounding algorithm. E.g. -5.81855622136895E8 is actually + // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of + // 5. Java(Scala)'s toString() rounds it up to -581855622.136895. This makes a + // difference when rounding at 5th digit, I.e. round(-5.81855622136895E8, 5) should be + // -5.818556221369E8, instead of -5.8185562213689E8. There is also an example that + // toString() does NOT round up. 6.1317116247283497E18 is 6131711624728349696. It can + // be rounded up to 6.13171162472835E18 that still represents the same double number. + // I.e. 6.13171162472835E18 == 6.1317116247283497E18. However, toString() does not. + // That results in round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead + // of 6.1317116247283999E18. 
+ withInfo(r, "Comet does not support Spark's BigDecimal rounding") + None + case _ => + // `scale` must be Int64 type in DataFusion + val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, LongType), inputs, binding) + val optExpr = + scalarExprToProtoWithReturnType("round", r.dataType, childExpr, scaleExpr) + optExprWithInfo(optExpr, expr, r.child) + } - // TODO enable once https://github.com/apache/datafusion/issues/11557 is fixed or - // when we have a Spark-compatible version implemented in Comet + // TODO enable once https://github.com/apache/datafusion/issues/11557 is fixed or + // when we have a Spark-compatible version implemented in Comet // case Signum(child) => // val childExpr = exprToProtoInternal(child, inputs) // val optExpr = scalarExprToProto("signum", childExpr) // optExprWithInfo(optExpr, expr, child) - case Sin(child) => - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = scalarExprToProto("sin", childExpr) - optExprWithInfo(optExpr, expr, child) - - case Sqrt(child) => - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = scalarExprToProto("sqrt", childExpr) - optExprWithInfo(optExpr, expr, child) - - case Tan(child) => - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = scalarExprToProto("tan", childExpr) - optExprWithInfo(optExpr, expr, child) - - case Ascii(child) => - val castExpr = Cast(child, StringType) - val childExpr = exprToProtoInternal(castExpr, inputs) - val optExpr = scalarExprToProto("ascii", childExpr) - optExprWithInfo(optExpr, expr, castExpr) - - case BitLength(child) => - val castExpr = Cast(child, StringType) - val childExpr = exprToProtoInternal(castExpr, inputs) - val optExpr = scalarExprToProto("bit_length", childExpr) - optExprWithInfo(optExpr, expr, castExpr) - - case If(predicate, trueValue, falseValue) => - val predicateExpr = exprToProtoInternal(predicate, inputs) - val trueExpr = exprToProtoInternal(trueValue, inputs) - val falseExpr = exprToProtoInternal(falseValue, inputs) - if (predicateExpr.isDefined && trueExpr.isDefined && falseExpr.isDefined) { - val builder = ExprOuterClass.IfExpr.newBuilder() - builder.setIfExpr(predicateExpr.get) - builder.setTrueExpr(trueExpr.get) - builder.setFalseExpr(falseExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setIf(builder) - .build()) - } else { - withInfo(expr, predicate, trueValue, falseValue) - None - } + case Sin(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = scalarExprToProto("sin", childExpr) + optExprWithInfo(optExpr, expr, child) + + case Sqrt(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = scalarExprToProto("sqrt", childExpr) + optExprWithInfo(optExpr, expr, child) + + case Tan(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = scalarExprToProto("tan", childExpr) + optExprWithInfo(optExpr, expr, child) + + case Ascii(child) => + val castExpr = Cast(child, StringType) + val childExpr = exprToProtoInternal(castExpr, inputs, binding) + val optExpr = scalarExprToProto("ascii", childExpr) + optExprWithInfo(optExpr, expr, castExpr) + + case BitLength(child) => + val castExpr = Cast(child, StringType) + val childExpr = exprToProtoInternal(castExpr, inputs, binding) + val optExpr = scalarExprToProto("bit_length", childExpr) + optExprWithInfo(optExpr, expr, castExpr) + + case If(predicate, trueValue, falseValue) => + val predicateExpr = exprToProtoInternal(predicate, inputs, binding) + val trueExpr = 
exprToProtoInternal(trueValue, inputs, binding) + val falseExpr = exprToProtoInternal(falseValue, inputs, binding) + if (predicateExpr.isDefined && trueExpr.isDefined && falseExpr.isDefined) { + val builder = ExprOuterClass.IfExpr.newBuilder() + builder.setIfExpr(predicateExpr.get) + builder.setTrueExpr(trueExpr.get) + builder.setFalseExpr(falseExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setIf(builder) + .build()) + } else { + withInfo(expr, predicate, trueValue, falseValue) + None + } - case CaseWhen(branches, elseValue) => - var allBranches: Seq[Expression] = Seq() - val whenSeq = branches.map(elements => { - allBranches = allBranches :+ elements._1 - exprToProtoInternal(elements._1, inputs) - }) - val thenSeq = branches.map(elements => { - allBranches = allBranches :+ elements._2 - exprToProtoInternal(elements._2, inputs) - }) - assert(whenSeq.length == thenSeq.length) - if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined)) { - val builder = ExprOuterClass.CaseWhen.newBuilder() - builder.addAllWhen(whenSeq.map(_.get).asJava) - builder.addAllThen(thenSeq.map(_.get).asJava) - if (elseValue.isDefined) { - val elseValueExpr = - exprToProtoInternal(elseValue.get, inputs) - if (elseValueExpr.isDefined) { - builder.setElseExpr(elseValueExpr.get) - } else { - withInfo(expr, elseValue.get) - return None - } + case CaseWhen(branches, elseValue) => + var allBranches: Seq[Expression] = Seq() + val whenSeq = branches.map(elements => { + allBranches = allBranches :+ elements._1 + exprToProtoInternal(elements._1, inputs, binding) + }) + val thenSeq = branches.map(elements => { + allBranches = allBranches :+ elements._2 + exprToProtoInternal(elements._2, inputs, binding) + }) + assert(whenSeq.length == thenSeq.length) + if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined)) { + val builder = ExprOuterClass.CaseWhen.newBuilder() + builder.addAllWhen(whenSeq.map(_.get).asJava) + builder.addAllThen(thenSeq.map(_.get).asJava) + if (elseValue.isDefined) { + val elseValueExpr = + exprToProtoInternal(elseValue.get, inputs, binding) + if (elseValueExpr.isDefined) { + builder.setElseExpr(elseValueExpr.get) + } else { + withInfo(expr, elseValue.get) + return None } - Some( - ExprOuterClass.Expr - .newBuilder() - .setCaseWhen(builder) - .build()) - } else { - withInfo(expr, allBranches: _*) - None - } - case ConcatWs(children) => - var childExprs: Seq[Expression] = Seq() - val exprs = children.map(e => { - val castExpr = Cast(e, StringType) - childExprs = childExprs :+ castExpr - exprToProtoInternal(castExpr, inputs) - }) - val optExpr = scalarExprToProto("concat_ws", exprs: _*) - optExprWithInfo(optExpr, expr, childExprs: _*) - - case Chr(child) => - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = scalarExprToProto("chr", childExpr) - optExprWithInfo(optExpr, expr, child) - - case InitCap(child) => - if (CometConf.COMET_EXEC_INITCAP_ENABLED.get()) { - val castExpr = Cast(child, StringType) - val childExpr = exprToProtoInternal(castExpr, inputs) - val optExpr = scalarExprToProto("initcap", childExpr) - optExprWithInfo(optExpr, expr, castExpr) - } else { - withInfo( - expr, - "Comet initCap is not compatible with Spark yet. " + - "See https://github.com/apache/datafusion-comet/issues/1052 ." 
+ - s"Set ${CometConf.COMET_EXEC_INITCAP_ENABLED.key}=true to enable it anyway.") - None } - - case Length(child) => + Some( + ExprOuterClass.Expr + .newBuilder() + .setCaseWhen(builder) + .build()) + } else { + withInfo(expr, allBranches: _*) + None + } + case ConcatWs(children) => + var childExprs: Seq[Expression] = Seq() + val exprs = children.map(e => { + val castExpr = Cast(e, StringType) + childExprs = childExprs :+ castExpr + exprToProtoInternal(castExpr, inputs, binding) + }) + val optExpr = scalarExprToProto("concat_ws", exprs: _*) + optExprWithInfo(optExpr, expr, childExprs: _*) + + case Chr(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = scalarExprToProto("chr", childExpr) + optExprWithInfo(optExpr, expr, child) + + case InitCap(child) => + if (CometConf.COMET_EXEC_INITCAP_ENABLED.get()) { val castExpr = Cast(child, StringType) - val childExpr = exprToProtoInternal(castExpr, inputs) - val optExpr = scalarExprToProto("length", childExpr) + val childExpr = exprToProtoInternal(castExpr, inputs, binding) + val optExpr = scalarExprToProto("initcap", childExpr) optExprWithInfo(optExpr, expr, castExpr) + } else { + withInfo( + expr, + "Comet initCap is not compatible with Spark yet. " + + "See https://github.com/apache/datafusion-comet/issues/1052 ." + + s"Set ${CometConf.COMET_EXEC_INITCAP_ENABLED.key}=true to enable it anyway.") + None + } - case Md5(child) => - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = scalarExprToProto("md5", childExpr) - optExprWithInfo(optExpr, expr, child) - - case OctetLength(child) => + case Length(child) => + val castExpr = Cast(child, StringType) + val childExpr = exprToProtoInternal(castExpr, inputs, binding) + val optExpr = scalarExprToProto("length", childExpr) + optExprWithInfo(optExpr, expr, castExpr) + + case Md5(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = scalarExprToProto("md5", childExpr) + optExprWithInfo(optExpr, expr, child) + + case OctetLength(child) => + val castExpr = Cast(child, StringType) + val childExpr = exprToProtoInternal(castExpr, inputs, binding) + val optExpr = scalarExprToProto("octet_length", childExpr) + optExprWithInfo(optExpr, expr, castExpr) + + case Reverse(child) => + val castExpr = Cast(child, StringType) + val childExpr = exprToProtoInternal(castExpr, inputs, binding) + val optExpr = scalarExprToProto("reverse", childExpr) + optExprWithInfo(optExpr, expr, castExpr) + + case StringInstr(str, substr) => + val leftCast = Cast(str, StringType) + val rightCast = Cast(substr, StringType) + val leftExpr = exprToProtoInternal(leftCast, inputs, binding) + val rightExpr = exprToProtoInternal(rightCast, inputs, binding) + val optExpr = scalarExprToProto("strpos", leftExpr, rightExpr) + optExprWithInfo(optExpr, expr, leftCast, rightCast) + + case StringRepeat(str, times) => + val leftCast = Cast(str, StringType) + val rightCast = Cast(times, LongType) + val leftExpr = exprToProtoInternal(leftCast, inputs, binding) + val rightExpr = exprToProtoInternal(rightCast, inputs, binding) + val optExpr = scalarExprToProto("repeat", leftExpr, rightExpr) + optExprWithInfo(optExpr, expr, leftCast, rightCast) + + case StringReplace(src, search, replace) => + val srcCast = Cast(src, StringType) + val searchCast = Cast(search, StringType) + val replaceCast = Cast(replace, StringType) + val srcExpr = exprToProtoInternal(srcCast, inputs, binding) + val searchExpr = exprToProtoInternal(searchCast, inputs, binding) + val replaceExpr = 
exprToProtoInternal(replaceCast, inputs, binding) + val optExpr = scalarExprToProto("replace", srcExpr, searchExpr, replaceExpr) + optExprWithInfo(optExpr, expr, srcCast, searchCast, replaceCast) + + case StringTranslate(src, matching, replace) => + val srcCast = Cast(src, StringType) + val matchingCast = Cast(matching, StringType) + val replaceCast = Cast(replace, StringType) + val srcExpr = exprToProtoInternal(srcCast, inputs, binding) + val matchingExpr = exprToProtoInternal(matchingCast, inputs, binding) + val replaceExpr = exprToProtoInternal(replaceCast, inputs, binding) + val optExpr = scalarExprToProto("translate", srcExpr, matchingExpr, replaceExpr) + optExprWithInfo(optExpr, expr, srcCast, matchingCast, replaceCast) + + case StringTrim(srcStr, trimStr) => + trim(expr, srcStr, trimStr, inputs, binding, "trim") + + case StringTrimLeft(srcStr, trimStr) => + trim(expr, srcStr, trimStr, inputs, binding, "ltrim") + + case StringTrimRight(srcStr, trimStr) => + trim(expr, srcStr, trimStr, inputs, binding, "rtrim") + + case StringTrimBoth(srcStr, trimStr, _) => + trim(expr, srcStr, trimStr, inputs, binding, "btrim") + + case Upper(child) => + if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) { val castExpr = Cast(child, StringType) - val childExpr = exprToProtoInternal(castExpr, inputs) - val optExpr = scalarExprToProto("octet_length", childExpr) + val childExpr = exprToProtoInternal(castExpr, inputs, binding) + val optExpr = scalarExprToProto("upper", childExpr) optExprWithInfo(optExpr, expr, castExpr) + } else { + withInfo( + expr, + "Comet is not compatible with Spark for case conversion in " + + s"locale-specific cases. Set ${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " + + "to enable it anyway.") + None + } - case Reverse(child) => + case Lower(child) => + if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) { val castExpr = Cast(child, StringType) - val childExpr = exprToProtoInternal(castExpr, inputs) - val optExpr = scalarExprToProto("reverse", childExpr) + val childExpr = exprToProtoInternal(castExpr, inputs, binding) + val optExpr = scalarExprToProto("lower", childExpr) optExprWithInfo(optExpr, expr, castExpr) + } else { + withInfo( + expr, + "Comet is not compatible with Spark for case conversion in " + + s"locale-specific cases. 
Set ${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " + + "to enable it anyway.") + None + } - case StringInstr(str, substr) => - val leftCast = Cast(str, StringType) - val rightCast = Cast(substr, StringType) - val leftExpr = exprToProtoInternal(leftCast, inputs) - val rightExpr = exprToProtoInternal(rightCast, inputs) - val optExpr = scalarExprToProto("strpos", leftExpr, rightExpr) - optExprWithInfo(optExpr, expr, leftCast, rightCast) - - case StringRepeat(str, times) => - val leftCast = Cast(str, StringType) - val rightCast = Cast(times, LongType) - val leftExpr = exprToProtoInternal(leftCast, inputs) - val rightExpr = exprToProtoInternal(rightCast, inputs) - val optExpr = scalarExprToProto("repeat", leftExpr, rightExpr) - optExprWithInfo(optExpr, expr, leftCast, rightCast) - - case StringReplace(src, search, replace) => - val srcCast = Cast(src, StringType) - val searchCast = Cast(search, StringType) - val replaceCast = Cast(replace, StringType) - val srcExpr = exprToProtoInternal(srcCast, inputs) - val searchExpr = exprToProtoInternal(searchCast, inputs) - val replaceExpr = exprToProtoInternal(replaceCast, inputs) - val optExpr = scalarExprToProto("replace", srcExpr, searchExpr, replaceExpr) - optExprWithInfo(optExpr, expr, srcCast, searchCast, replaceCast) - - case StringTranslate(src, matching, replace) => - val srcCast = Cast(src, StringType) - val matchingCast = Cast(matching, StringType) - val replaceCast = Cast(replace, StringType) - val srcExpr = exprToProtoInternal(srcCast, inputs) - val matchingExpr = exprToProtoInternal(matchingCast, inputs) - val replaceExpr = exprToProtoInternal(replaceCast, inputs) - val optExpr = scalarExprToProto("translate", srcExpr, matchingExpr, replaceExpr) - optExprWithInfo(optExpr, expr, srcCast, matchingCast, replaceCast) - - case StringTrim(srcStr, trimStr) => - trim(expr, srcStr, trimStr, inputs, "trim") - - case StringTrimLeft(srcStr, trimStr) => - trim(expr, srcStr, trimStr, inputs, "ltrim") - - case StringTrimRight(srcStr, trimStr) => - trim(expr, srcStr, trimStr, inputs, "rtrim") - - case StringTrimBoth(srcStr, trimStr, _) => - trim(expr, srcStr, trimStr, inputs, "btrim") - - case Upper(child) => - if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) { - val castExpr = Cast(child, StringType) - val childExpr = exprToProtoInternal(castExpr, inputs) - val optExpr = scalarExprToProto("upper", childExpr) - optExprWithInfo(optExpr, expr, castExpr) - } else { - withInfo( - expr, - "Comet is not compatible with Spark for case conversion in " + - s"locale-specific cases. Set ${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " + - "to enable it anyway.") - None - } - - case Lower(child) => - if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) { - val castExpr = Cast(child, StringType) - val childExpr = exprToProtoInternal(castExpr, inputs) - val optExpr = scalarExprToProto("lower", childExpr) - optExprWithInfo(optExpr, expr, castExpr) - } else { - withInfo( - expr, - "Comet is not compatible with Spark for case conversion in " + - s"locale-specific cases. 
Set ${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " +
-              "to enable it anyway.")
-          None
-        }
-
-        case BitwiseAnd(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
-
-        case BitwiseNot(child) =>
-          createUnaryExpr(child, inputs, (builder, unaryExpr) => builder.setBitwiseNot(unaryExpr))
-
-        case BitwiseOr(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setBitwiseOr(binaryExpr))
+      case BitwiseAnd(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
+
+      case BitwiseNot(child) =>
+        createUnaryExpr(
+          expr,
+          child,
+          inputs,
+          binding,
+          (builder, unaryExpr) => builder.setBitwiseNot(unaryExpr))
+
+      case BitwiseOr(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseOr(binaryExpr))
+
+      case BitwiseXor(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
+
+      case ShiftRight(left, right) =>
+        // DataFusion bitwise shift right expression requires
+        // same data type between left and right side
+        val rightExpression = if (left.dataType == LongType) {
+          Cast(right, LongType)
+        } else {
+          right
+        }
+
+        createBinaryExpr(
+          expr,
+          left,
+          rightExpression,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseShiftRight(binaryExpr))
+
+      case ShiftLeft(left, right) =>
+        // DataFusion bitwise shift left expression requires
+        // same data type between left and right side
+        val rightExpression = if (left.dataType == LongType) {
+          Cast(right, LongType)
+        } else {
+          right
+        }
+
+        createBinaryExpr(
+          expr,
+          left,
+          rightExpression,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseShiftLeft(binaryExpr))
+      case In(value, list) =>
+        in(expr, value, list, inputs, binding, negate = false)
+
+      case InSet(value, hset) =>
+        val valueDataType = value.dataType
+        val list = hset.map { setVal =>
+          Literal(setVal, valueDataType)
+        }.toSeq
+        // Change `InSet` to `In` expression
+        // We do the Spark `InSet` optimization on the native (DataFusion) side.
+        in(expr, value, list, inputs, binding, negate = false)
+
+      case Not(In(value, list)) =>
+        in(expr, value, list, inputs, binding, negate = true)
+
+      case Not(child) =>
+        createUnaryExpr(
+          expr,
+          child,
+          inputs,
+          binding,
+          (builder, unaryExpr) => builder.setNot(unaryExpr))
+
+      case UnaryMinus(child, failOnError) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        if (childExpr.isDefined) {
+          val builder = ExprOuterClass.UnaryMinus.newBuilder()
+          builder.setChild(childExpr.get)
+          builder.setFailOnError(failOnError)
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setUnaryMinus(builder)
+              .build())
+        } else {
+          withInfo(expr, child)
+          None
+        }
-        case BitwiseXor(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
+      case a @ Coalesce(_) =>
+        val exprChildren = a.children.map(exprToProtoInternal(_, inputs, binding))
+        scalarExprToProto("coalesce", exprChildren: _*)
+
+      // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for
+      // char types.
+ // See https://github.com/apache/spark/pull/38151 + case s: StaticInvoke + if s.staticObject.isInstanceOf[Class[CharVarcharCodegenUtils]] && + s.dataType.isInstanceOf[StringType] && + s.functionName == "readSidePadding" && + s.arguments.size == 2 && + s.propagateNull && + !s.returnNullable && + s.isDeterministic => + val argsExpr = Seq( + exprToProtoInternal(Cast(s.arguments(0), StringType), inputs, binding), + exprToProtoInternal(s.arguments(1), inputs, binding)) + + if (argsExpr.forall(_.isDefined)) { + val builder = ExprOuterClass.ScalarFunc.newBuilder() + builder.setFunc("read_side_padding") + argsExpr.foreach(arg => builder.addArgs(arg.get)) + + Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()) + } else { + withInfo(expr, s.arguments: _*) + None + } - case ShiftRight(left, right) => - // DataFusion bitwise shift right expression requires - // same data type between left and right side - val rightExpression = if (left.dataType == LongType) { - Cast(right, LongType) - } else { - right - } + case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) => + val dataType = serializeDataType(expr.dataType) + if (dataType.isEmpty) { + withInfo(expr, s"Unsupported datatype ${expr.dataType}") + return None + } + val ex = exprToProtoInternal(expr, inputs, binding) + ex.map { child => + val builder = ExprOuterClass.NormalizeNaNAndZero + .newBuilder() + .setChild(child) + .setDatatype(dataType.get) + ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build() + } - createBinaryExpr( - left, - rightExpression, - inputs, - (builder, binaryExpr) => builder.setBitwiseShiftRight(binaryExpr)) + case s @ execution.ScalarSubquery(_, _) if supportedDataType(s.dataType) => + val dataType = serializeDataType(s.dataType) + if (dataType.isEmpty) { + withInfo(s, s"Scalar subquery returns unsupported datatype ${s.dataType}") + return None + } - case ShiftLeft(left, right) => - // DataFusion bitwise shift right expression requires - // same data type between left and right side - val rightExpression = if (left.dataType == LongType) { - Cast(right, LongType) - } else { - right - } + val builder = ExprOuterClass.Subquery + .newBuilder() + .setId(s.exprId.id) + .setDatatype(dataType.get) + Some(ExprOuterClass.Expr.newBuilder().setSubquery(builder).build()) + + case UnscaledValue(child) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = scalarExprToProtoWithReturnType("unscaled_value", LongType, childExpr) + optExprWithInfo(optExpr, expr, child) + + case MakeDecimal(child, precision, scale, true) => + val childExpr = exprToProtoInternal(child, inputs, binding) + val optExpr = scalarExprToProtoWithReturnType( + "make_decimal", + DecimalType(precision, scale), + childExpr) + optExprWithInfo(optExpr, expr, child) + + case b @ BloomFilterMightContain(_, _) => + val bloomFilter = b.left + val value = b.right + val bloomFilterExpr = exprToProtoInternal(bloomFilter, inputs, binding) + val valueExpr = exprToProtoInternal(value, inputs, binding) + if (bloomFilterExpr.isDefined && valueExpr.isDefined) { + val builder = ExprOuterClass.BloomFilterMightContain.newBuilder() + builder.setBloomFilter(bloomFilterExpr.get) + builder.setValue(valueExpr.get) + Some( + ExprOuterClass.Expr + .newBuilder() + .setBloomFilterMightContain(builder) + .build()) + } else { + withInfo(expr, bloomFilter, value) + None + } - createBinaryExpr( - left, - rightExpression, - inputs, - (builder, binaryExpr) => builder.setBitwiseShiftLeft(binaryExpr)) - case In(value, list) => - in(expr, 
value, list, inputs, false) - - case InSet(value, hset) => - val valueDataType = value.dataType - val list = hset.map { setVal => - Literal(setVal, valueDataType) - }.toSeq - // Change `InSet` to `In` expression - // We do Spark `InSet` optimization in native (DataFusion) side. - in(expr, value, list, inputs, false) - - case Not(In(value, list)) => - in(expr, value, list, inputs, true) - - case Not(child) => - createUnaryExpr(child, inputs, (builder, unaryExpr) => builder.setNot(unaryExpr)) - - case UnaryMinus(child, failOnError) => - val childExpr = exprToProtoInternal(child, inputs) - if (childExpr.isDefined) { - val builder = ExprOuterClass.UnaryMinus.newBuilder() - builder.setChild(childExpr.get) - builder.setFailOnError(failOnError) - Some( - ExprOuterClass.Expr - .newBuilder() - .setUnaryMinus(builder) - .build()) - } else { - withInfo(expr, child) - None - } + case Murmur3Hash(children, seed) => + val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType)) + if (firstUnSupportedInput.isDefined) { + withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}") + return None + } + val exprs = children.map(exprToProtoInternal(_, inputs, binding)) + val seedBuilder = ExprOuterClass.Literal + .newBuilder() + .setDatatype(serializeDataType(IntegerType).get) + .setIntVal(seed) + val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build()) + // the seed is put at the end of the arguments + scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*) + + case XxHash64(children, seed) => + val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType)) + if (firstUnSupportedInput.isDefined) { + withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}") + return None + } + val exprs = children.map(exprToProtoInternal(_, inputs, binding)) + val seedBuilder = ExprOuterClass.Literal + .newBuilder() + .setDatatype(serializeDataType(LongType).get) + .setLongVal(seed) + val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build()) + // the seed is put at the end of the arguments + scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ seedExpr: _*) + + case Sha2(left, numBits) => + if (!numBits.foldable) { + withInfo(expr, "non literal numBits is not supported") + return None + } + // it's possible for spark to dynamically compute the number of bits from input + // expression, however DataFusion does not support that yet. + val childExpr = exprToProtoInternal(left, inputs, binding) + val bits = numBits.eval().asInstanceOf[Int] + val algorithm = bits match { + case 224 => "sha224" + case 256 | 0 => "sha256" + case 384 => "sha384" + case 512 => "sha512" + case _ => + null + } + if (algorithm == null) { + exprToProtoInternal(Literal(null, StringType), inputs, binding) + } else { + scalarExprToProtoWithReturnType(algorithm, StringType, childExpr) + } - case a @ Coalesce(_) => - val exprChildren = a.children.map(exprToProtoInternal(_, inputs)) - scalarExprToProto("coalesce", exprChildren: _*) - - // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called to pad spaces for - // char types. 
- // See https://github.com/apache/spark/pull/38151 - case s: StaticInvoke - if s.staticObject.isInstanceOf[Class[CharVarcharCodegenUtils]] && - s.dataType.isInstanceOf[StringType] && - s.functionName == "readSidePadding" && - s.arguments.size == 2 && - s.propagateNull && - !s.returnNullable && - s.isDeterministic => - val argsExpr = Seq( - exprToProtoInternal(Cast(s.arguments(0), StringType), inputs), - exprToProtoInternal(s.arguments(1), inputs)) - - if (argsExpr.forall(_.isDefined)) { - val builder = ExprOuterClass.ScalarFunc.newBuilder() - builder.setFunc("read_side_padding") - argsExpr.foreach(arg => builder.addArgs(arg.get)) - - Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build()) - } else { - withInfo(expr, s.arguments: _*) - None - } + case struct @ CreateNamedStruct(_) => + if (struct.names.length != struct.names.distinct.length) { + withInfo(expr, "CreateNamedStruct with duplicate field names are not supported") + return None + } - case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) => - val dataType = serializeDataType(expr.dataType) - if (dataType.isEmpty) { - withInfo(expr, s"Unsupported datatype ${expr.dataType}") - return None - } - val ex = exprToProtoInternal(expr, inputs) - ex.map { child => - val builder = ExprOuterClass.NormalizeNaNAndZero - .newBuilder() - .setChild(child) - .setDatatype(dataType.get) - ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build() - } + val valExprs = struct.valExprs.map(exprToProto(_, inputs, binding)) - case s @ execution.ScalarSubquery(_, _) if supportedDataType(s.dataType) => - val dataType = serializeDataType(s.dataType) - if (dataType.isEmpty) { - withInfo(s, s"Scalar subquery returns unsupported datatype ${s.dataType}") - return None - } + if (valExprs.forall(_.isDefined)) { + val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder() + structBuilder.addAllValues(valExprs.map(_.get).asJava) + structBuilder.addAllNames(struct.names.map(_.toString).asJava) - val builder = ExprOuterClass.Subquery - .newBuilder() - .setId(s.exprId.id) - .setDatatype(dataType.get) - Some(ExprOuterClass.Expr.newBuilder().setSubquery(builder).build()) - - case UnscaledValue(child) => - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = scalarExprToProtoWithReturnType("unscaled_value", LongType, childExpr) - optExprWithInfo(optExpr, expr, child) - - case MakeDecimal(child, precision, scale, true) => - val childExpr = exprToProtoInternal(child, inputs) - val optExpr = scalarExprToProtoWithReturnType( - "make_decimal", - DecimalType(precision, scale), - childExpr) - optExprWithInfo(optExpr, expr, child) - - case b @ BloomFilterMightContain(_, _) => - val bloomFilter = b.left - val value = b.right - val bloomFilterExpr = exprToProtoInternal(bloomFilter, inputs) - val valueExpr = exprToProtoInternal(value, inputs) - if (bloomFilterExpr.isDefined && valueExpr.isDefined) { - val builder = ExprOuterClass.BloomFilterMightContain.newBuilder() - builder.setBloomFilter(bloomFilterExpr.get) - builder.setValue(valueExpr.get) - Some( - ExprOuterClass.Expr - .newBuilder() - .setBloomFilterMightContain(builder) - .build()) - } else { - withInfo(expr, bloomFilter, value) - None - } + Some( + ExprOuterClass.Expr + .newBuilder() + .setCreateNamedStruct(structBuilder) + .build()) + } else { + withInfo(expr, "unsupported arguments for CreateNamedStruct", struct.valExprs: _*) + None + } - case Murmur3Hash(children, seed) => - val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType)) - if 
(firstUnSupportedInput.isDefined) { - withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}") - return None - } - val exprs = children.map(exprToProtoInternal(_, inputs)) - val seedBuilder = ExprOuterClass.Literal - .newBuilder() - .setDatatype(serializeDataType(IntegerType).get) - .setIntVal(seed) - val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build()) - // the seed is put at the end of the arguments - scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*) - - case XxHash64(children, seed) => - val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType)) - if (firstUnSupportedInput.isDefined) { - withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}") - return None - } - val exprs = children.map(exprToProtoInternal(_, inputs)) - val seedBuilder = ExprOuterClass.Literal + case GetStructField(child, ordinal, _) => + exprToProto(child, inputs, binding).map { childExpr => + val getStructFieldBuilder = ExprOuterClass.GetStructField .newBuilder() - .setDatatype(serializeDataType(LongType).get) - .setLongVal(seed) - val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build()) - // the seed is put at the end of the arguments - scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ seedExpr: _*) - - case Sha2(left, numBits) => - if (!numBits.foldable) { - withInfo(expr, "non literal numBits is not supported") - return None - } - // it's possible for spark to dynamically compute the number of bits from input - // expression, however DataFusion does not support that yet. - val childExpr = exprToProtoInternal(left, inputs) - val bits = numBits.eval().asInstanceOf[Int] - val algorithm = bits match { - case 224 => "sha224" - case 256 | 0 => "sha256" - case 384 => "sha384" - case 512 => "sha512" - case _ => - null - } - if (algorithm == null) { - exprToProtoInternal(Literal(null, StringType), inputs) - } else { - scalarExprToProtoWithReturnType(algorithm, StringType, childExpr) - } + .setChild(childExpr) + .setOrdinal(ordinal) - case struct @ CreateNamedStruct(_) => - if (struct.names.length != struct.names.distinct.length) { - withInfo(expr, "CreateNamedStruct with duplicate field names are not supported") - return None - } + ExprOuterClass.Expr + .newBuilder() + .setGetStructField(getStructFieldBuilder) + .build() + } - val valExprs = struct.valExprs.map(exprToProto(_, inputs, binding)) + case CreateArray(children, _) => + val childExprs = children.map(exprToProto(_, inputs, binding)) - if (valExprs.forall(_.isDefined)) { - val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder() - structBuilder.addAllValues(valExprs.map(_.get).asJava) - structBuilder.addAllNames(struct.names.map(_.toString).asJava) + if (childExprs.forall(_.isDefined)) { + scalarExprToProto("make_array", childExprs: _*) + } else { + withInfo(expr, "unsupported arguments for CreateArray", children: _*) + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setCreateNamedStruct(structBuilder) - .build()) - } else { - withInfo(expr, "unsupported arguments for CreateNamedStruct", struct.valExprs: _*) - None - } + case GetArrayItem(child, ordinal, failOnError) => + val childExpr = exprToProto(child, inputs, binding) + val ordinalExpr = exprToProto(ordinal, inputs, binding) - case GetStructField(child, ordinal, _) => - exprToProto(child, inputs, binding).map { childExpr => - val getStructFieldBuilder = ExprOuterClass.GetStructField - .newBuilder() - .setChild(childExpr) - 
.setOrdinal(ordinal) + if (childExpr.isDefined && ordinalExpr.isDefined) { + val listExtractBuilder = ExprOuterClass.ListExtract + .newBuilder() + .setChild(childExpr.get) + .setOrdinal(ordinalExpr.get) + .setOneBased(false) + .setFailOnError(failOnError) + Some( ExprOuterClass.Expr .newBuilder() - .setGetStructField(getStructFieldBuilder) - .build() - } - - case CreateArray(children, _) => - val childExprs = children.map(exprToProto(_, inputs, binding)) - - if (childExprs.forall(_.isDefined)) { - scalarExprToProto("make_array", childExprs: _*) - } else { - withInfo(expr, "unsupported arguments for CreateArray", children: _*) - None - } + .setListExtract(listExtractBuilder) + .build()) + } else { + withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal) + None + } - case GetArrayItem(child, ordinal, failOnError) => - val childExpr = exprToProto(child, inputs, binding) - val ordinalExpr = exprToProto(ordinal, inputs, binding) + case expr if expr.prettyName == "array_insert" => + val srcExprProto = exprToProto(expr.children(0), inputs, binding) + val posExprProto = exprToProto(expr.children(1), inputs, binding) + val itemExprProto = exprToProto(expr.children(2), inputs, binding) + val legacyNegativeIndex = + SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean + if (srcExprProto.isDefined && posExprProto.isDefined && itemExprProto.isDefined) { + val arrayInsertBuilder = ExprOuterClass.ArrayInsert + .newBuilder() + .setSrcArrayExpr(srcExprProto.get) + .setPosExpr(posExprProto.get) + .setItemExpr(itemExprProto.get) + .setLegacyNegativeIndex(legacyNegativeIndex) - if (childExpr.isDefined && ordinalExpr.isDefined) { - val listExtractBuilder = ExprOuterClass.ListExtract + Some( + ExprOuterClass.Expr .newBuilder() - .setChild(childExpr.get) - .setOrdinal(ordinalExpr.get) - .setOneBased(false) - .setFailOnError(failOnError) - - Some( - ExprOuterClass.Expr - .newBuilder() - .setListExtract(listExtractBuilder) - .build()) - } else { - withInfo(expr, "unsupported arguments for GetArrayItem", child, ordinal) - None - } + .setArrayInsert(arrayInsertBuilder) + .build()) + } else { + withInfo( + expr, + "unsupported arguments for ArrayInsert", + expr.children(0), + expr.children(1), + expr.children(2)) + None + } - case expr if expr.prettyName == "array_insert" => - val srcExprProto = exprToProto(expr.children(0), inputs, binding) - val posExprProto = exprToProto(expr.children(1), inputs, binding) - val itemExprProto = exprToProto(expr.children(2), inputs, binding) - val legacyNegativeIndex = - SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean - if (srcExprProto.isDefined && posExprProto.isDefined && itemExprProto.isDefined) { - val arrayInsertBuilder = ExprOuterClass.ArrayInsert - .newBuilder() - .setSrcArrayExpr(srcExprProto.get) - .setPosExpr(posExprProto.get) - .setItemExpr(itemExprProto.get) - .setLegacyNegativeIndex(legacyNegativeIndex) + case ElementAt(child, ordinal, defaultValue, failOnError) + if child.dataType.isInstanceOf[ArrayType] => + val childExpr = exprToProto(child, inputs, binding) + val ordinalExpr = exprToProto(ordinal, inputs, binding) + val defaultExpr = defaultValue.flatMap(exprToProto(_, inputs, binding)) - Some( - ExprOuterClass.Expr - .newBuilder() - .setArrayInsert(arrayInsertBuilder) - .build()) - } else { - withInfo( - expr, - "unsupported arguments for ArrayInsert", - expr.children(0), - expr.children(1), - expr.children(2)) - None - } + if (childExpr.isDefined && ordinalExpr.isDefined && + 
defaultExpr.isDefined == defaultValue.isDefined) { + val arrayExtractBuilder = ExprOuterClass.ListExtract + .newBuilder() + .setChild(childExpr.get) + .setOrdinal(ordinalExpr.get) + .setOneBased(true) + .setFailOnError(failOnError) - case ElementAt(child, ordinal, defaultValue, failOnError) - if child.dataType.isInstanceOf[ArrayType] => - val childExpr = exprToProto(child, inputs, binding) - val ordinalExpr = exprToProto(ordinal, inputs, binding) - val defaultExpr = defaultValue.flatMap(exprToProto(_, inputs, binding)) + defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_)) - if (childExpr.isDefined && ordinalExpr.isDefined && - defaultExpr.isDefined == defaultValue.isDefined) { - val arrayExtractBuilder = ExprOuterClass.ListExtract + Some( + ExprOuterClass.Expr .newBuilder() - .setChild(childExpr.get) - .setOrdinal(ordinalExpr.get) - .setOneBased(true) - .setFailOnError(failOnError) - - defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_)) + .setListExtract(arrayExtractBuilder) + .build()) + } else { + withInfo(expr, "unsupported arguments for ElementAt", child, ordinal) + None + } - Some( - ExprOuterClass.Expr - .newBuilder() - .setListExtract(arrayExtractBuilder) - .build()) - } else { - withInfo(expr, "unsupported arguments for ElementAt", child, ordinal) - None - } + case GetArrayStructFields(child, _, ordinal, _, _) => + val childExpr = exprToProto(child, inputs, binding) - case GetArrayStructFields(child, _, ordinal, _, _) => - val childExpr = exprToProto(child, inputs, binding) + if (childExpr.isDefined) { + val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields + .newBuilder() + .setChild(childExpr.get) + .setOrdinal(ordinal) - if (childExpr.isDefined) { - val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields + Some( + ExprOuterClass.Expr .newBuilder() - .setChild(childExpr.get) - .setOrdinal(ordinal) - - Some( - ExprOuterClass.Expr + .setGetArrayStructFields(arrayStructFieldsBuilder) + .build()) + } else { + withInfo(expr, "unsupported arguments for GetArrayStructFields", child) + None + } + case expr: ArrayRemove => CometArrayRemove.convert(expr, inputs, binding) + case expr if expr.prettyName == "array_contains" => + createBinaryExpr( + expr, + expr.children(0), + expr.children(1), + inputs, + binding, + (builder, binaryExpr) => builder.setArrayContains(binaryExpr)) + case _ if expr.prettyName == "array_append" => + createBinaryExpr( + expr, + expr.children(0), + expr.children(1), + inputs, + binding, + (builder, binaryExpr) => builder.setArrayAppend(binaryExpr)) + case _ if expr.prettyName == "array_intersect" => + createBinaryExpr( + expr, + expr.children(0), + expr.children(1), + inputs, + binding, + (builder, binaryExpr) => builder.setArrayIntersect(binaryExpr)) + case ArrayJoin(arrayExpr, delimiterExpr, nullReplacementExpr) => + val arrayExprProto = exprToProto(arrayExpr, inputs, binding) + val delimiterExprProto = exprToProto(delimiterExpr, inputs, binding) + + if (arrayExprProto.isDefined && delimiterExprProto.isDefined) { + val arrayJoinBuilder = nullReplacementExpr match { + case Some(nrExpr) => + val nullReplacementExprProto = exprToProto(nrExpr, inputs, binding) + ExprOuterClass.ArrayJoin .newBuilder() - .setGetArrayStructFields(arrayStructFieldsBuilder) - .build()) - } else { - withInfo(expr, "unsupported arguments for GetArrayStructFields", child) - None - } - case expr: ArrayRemove => - if (CometArrayRemove.checkSupport(expr)) { - createBinaryExpr( - expr.children(0), - expr.children(1), - inputs, - (builder, binaryExpr) => 
builder.setArrayRemove(binaryExpr)) - } else { - None - } - case expr if expr.prettyName == "array_contains" => - createBinaryExpr( - expr.children(0), - expr.children(1), - inputs, - (builder, binaryExpr) => builder.setArrayContains(binaryExpr)) - case _ if expr.prettyName == "array_append" => - createBinaryExpr( - expr.children(0), - expr.children(1), - inputs, - (builder, binaryExpr) => builder.setArrayAppend(binaryExpr)) - case _ if expr.prettyName == "array_intersect" => - createBinaryExpr( - expr.children(0), - expr.children(1), - inputs, - (builder, binaryExpr) => builder.setArrayIntersect(binaryExpr)) - case ArrayJoin(arrayExpr, delimiterExpr, nullReplacementExpr) => - val arrayExprProto = exprToProto(arrayExpr, inputs, binding) - val delimiterExprProto = exprToProto(delimiterExpr, inputs, binding) - - if (arrayExprProto.isDefined && delimiterExprProto.isDefined) { - val arrayJoinBuilder = nullReplacementExpr match { - case Some(nrExpr) => - val nullReplacementExprProto = exprToProto(nrExpr, inputs, binding) - ExprOuterClass.ArrayJoin - .newBuilder() - .setArrayExpr(arrayExprProto.get) - .setDelimiterExpr(delimiterExprProto.get) - .setNullReplacementExpr(nullReplacementExprProto.get) - case None => - ExprOuterClass.ArrayJoin - .newBuilder() - .setArrayExpr(arrayExprProto.get) - .setDelimiterExpr(delimiterExprProto.get) - } - Some( - ExprOuterClass.Expr + .setArrayExpr(arrayExprProto.get) + .setDelimiterExpr(delimiterExprProto.get) + .setNullReplacementExpr(nullReplacementExprProto.get) + case None => + ExprOuterClass.ArrayJoin .newBuilder() - .setArrayJoin(arrayJoinBuilder) - .build()) - } else { - val exprs: List[Expression] = nullReplacementExpr match { - case Some(nrExpr) => List(arrayExpr, delimiterExpr, nrExpr) - case None => List(arrayExpr, delimiterExpr) - } - withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*) - None + .setArrayExpr(arrayExprProto.get) + .setDelimiterExpr(delimiterExprProto.get) } - case _ => - withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*) - None - } - } - - /** - * Creates a UnaryExpr by calling exprToProtoInternal for the provided child expression and - * then invokes the supplied function to wrap this UnaryExpr in a top-level Expr. 
- * - * @param child - * Spark expression - * @param inputs - * Inputs to the expression - * @param f - * Function that accepts an Expr.Builder and a UnaryExpr and builds the specific top-level - * Expr - * @return - * Some(Expr) or None if not supported - */ - def createUnaryExpr( - child: Expression, - inputs: Seq[Attribute], - f: (ExprOuterClass.Expr.Builder, ExprOuterClass.UnaryExpr) => ExprOuterClass.Expr.Builder) - : Option[ExprOuterClass.Expr] = { - val childExpr = exprToProtoInternal(child, inputs) - if (childExpr.isDefined) { - // create the generic UnaryExpr message - val inner = ExprOuterClass.UnaryExpr - .newBuilder() - .setChild(childExpr.get) - .build() - // call the user-supplied function to wrap UnaryExpr in a top-level Expr - // such as Expr.IsNull or Expr.IsNotNull - Some( - f( + Some( ExprOuterClass.Expr - .newBuilder(), - inner).build()) - } else { - withInfo(expr, child) + .newBuilder() + .setArrayJoin(arrayJoinBuilder) + .build()) + } else { + val exprs: List[Expression] = nullReplacementExpr match { + case Some(nrExpr) => List(arrayExpr, delimiterExpr, nrExpr) + case None => List(arrayExpr, delimiterExpr) + } + withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*) + None + } + case _ => + withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*) None - } } + } - def createBinaryExpr( - left: Expression, - right: Expression, - inputs: Seq[Attribute], - f: ( - ExprOuterClass.Expr.Builder, - ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder) - : Option[ExprOuterClass.Expr] = { - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - if (leftExpr.isDefined && rightExpr.isDefined) { - // create the generic BinaryExpr message - val inner = ExprOuterClass.BinaryExpr - .newBuilder() - .setLeft(leftExpr.get) - .setRight(rightExpr.get) - .build() - // call the user-supplied function to wrap BinaryExpr in a top-level Expr - // such as Expr.And or Expr.Or - Some( - f( - ExprOuterClass.Expr - .newBuilder(), - inner).build()) - } else { - withInfo(expr, left, right) - None - } + /** + * Creates a UnaryExpr by calling exprToProtoInternal for the provided child expression and then + * invokes the supplied function to wrap this UnaryExpr in a top-level Expr. 
+ * + * @param child + * Spark expression + * @param inputs + * Inputs to the expression + * @param f + * Function that accepts an Expr.Builder and a UnaryExpr and builds the specific top-level + * Expr + * @return + * Some(Expr) or None if not supported + */ + def createUnaryExpr( + expr: Expression, + child: Expression, + inputs: Seq[Attribute], + binding: Boolean, + f: (ExprOuterClass.Expr.Builder, ExprOuterClass.UnaryExpr) => ExprOuterClass.Expr.Builder) + : Option[ExprOuterClass.Expr] = { + val childExpr = exprToProtoInternal(child, inputs, binding) // TODO review + if (childExpr.isDefined) { + // create the generic UnaryExpr message + val inner = ExprOuterClass.UnaryExpr + .newBuilder() + .setChild(childExpr.get) + .build() + // call the user-supplied function to wrap UnaryExpr in a top-level Expr + // such as Expr.IsNull or Expr.IsNotNull + Some( + f( + ExprOuterClass.Expr + .newBuilder(), + inner).build()) + } else { + withInfo(expr, child) + None } + } - def createMathExpression( - left: Expression, - right: Expression, - inputs: Seq[Attribute], - dataType: DataType, - failOnError: Boolean, - f: (ExprOuterClass.Expr.Builder, ExprOuterClass.MathExpr) => ExprOuterClass.Expr.Builder) - : Option[ExprOuterClass.Expr] = { - val leftExpr = exprToProtoInternal(left, inputs) - val rightExpr = exprToProtoInternal(right, inputs) - - if (leftExpr.isDefined && rightExpr.isDefined) { - // create the generic MathExpr message - val builder = ExprOuterClass.MathExpr.newBuilder() - builder.setLeft(leftExpr.get) - builder.setRight(rightExpr.get) - builder.setFailOnError(failOnError) - serializeDataType(dataType).foreach { t => - builder.setReturnType(t) - } - val inner = builder.build() - // call the user-supplied function to wrap MathExpr in a top-level Expr - // such as Expr.Add or Expr.Divide - Some( - f( - ExprOuterClass.Expr - .newBuilder(), - inner).build()) - } else { - withInfo(expr, left, right) - None - } + def createBinaryExpr( + expr: Expression, + left: Expression, + right: Expression, + inputs: Seq[Attribute], + binding: Boolean, + f: (ExprOuterClass.Expr.Builder, ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder) + : Option[ExprOuterClass.Expr] = { + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) + if (leftExpr.isDefined && rightExpr.isDefined) { + // create the generic BinaryExpr message + val inner = ExprOuterClass.BinaryExpr + .newBuilder() + .setLeft(leftExpr.get) + .setRight(rightExpr.get) + .build() + // call the user-supplied function to wrap BinaryExpr in a top-level Expr + // such as Expr.And or Expr.Or + Some( + f( + ExprOuterClass.Expr + .newBuilder(), + inner).build()) + } else { + withInfo(expr, left, right) + None } + } - def trim( - expr: Expression, // parent expression - srcStr: Expression, - trimStr: Option[Expression], - inputs: Seq[Attribute], - trimType: String): Option[Expr] = { - val srcCast = Cast(srcStr, StringType) - val srcExpr = exprToProtoInternal(srcCast, inputs) - if (trimStr.isDefined) { - val trimCast = Cast(trimStr.get, StringType) - val trimExpr = exprToProtoInternal(trimCast, inputs) - val optExpr = scalarExprToProto(trimType, srcExpr, trimExpr) - optExprWithInfo(optExpr, expr, srcCast, trimCast) - } else { - val optExpr = scalarExprToProto(trimType, srcExpr) - optExprWithInfo(optExpr, expr, srcCast) + def createMathExpression( + expr: Expression, + left: Expression, + right: Expression, + inputs: Seq[Attribute], + binding: Boolean, + dataType: DataType, + failOnError: 
Boolean, + f: (ExprOuterClass.Expr.Builder, ExprOuterClass.MathExpr) => ExprOuterClass.Expr.Builder) + : Option[ExprOuterClass.Expr] = { + val leftExpr = exprToProtoInternal(left, inputs, binding) + val rightExpr = exprToProtoInternal(right, inputs, binding) + + if (leftExpr.isDefined && rightExpr.isDefined) { + // create the generic MathExpr message + val builder = ExprOuterClass.MathExpr.newBuilder() + builder.setLeft(leftExpr.get) + builder.setRight(rightExpr.get) + builder.setFailOnError(failOnError) + serializeDataType(dataType).foreach { t => + builder.setReturnType(t) } + val inner = builder.build() + // call the user-supplied function to wrap MathExpr in a top-level Expr + // such as Expr.Add or Expr.Divide + Some( + f( + ExprOuterClass.Expr + .newBuilder(), + inner).build()) + } else { + withInfo(expr, left, right) + None } + } - def in( - expr: Expression, - value: Expression, - list: Seq[Expression], - inputs: Seq[Attribute], - negate: Boolean): Option[Expr] = { - val valueExpr = exprToProtoInternal(value, inputs) - val listExprs = list.map(exprToProtoInternal(_, inputs)) - if (valueExpr.isDefined && listExprs.forall(_.isDefined)) { - val builder = ExprOuterClass.In.newBuilder() - builder.setInValue(valueExpr.get) - builder.addAllLists(listExprs.map(_.get).asJava) - builder.setNegated(negate) - Some( - ExprOuterClass.Expr - .newBuilder() - .setIn(builder) - .build()) - } else { - val allExprs = list ++ Seq(value) - withInfo(expr, allExprs: _*) - None - } + def trim( + expr: Expression, // parent expression + srcStr: Expression, + trimStr: Option[Expression], + inputs: Seq[Attribute], + binding: Boolean, + trimType: String): Option[Expr] = { + val srcCast = Cast(srcStr, StringType) + val srcExpr = exprToProtoInternal(srcCast, inputs, binding) + if (trimStr.isDefined) { + val trimCast = Cast(trimStr.get, StringType) + val trimExpr = exprToProtoInternal(trimCast, inputs, binding) + val optExpr = scalarExprToProto(trimType, srcExpr, trimExpr) + optExprWithInfo(optExpr, expr, srcCast, trimCast) + } else { + val optExpr = scalarExprToProto(trimType, srcExpr) + optExprWithInfo(optExpr, expr, srcCast) } + } - val conf = SQLConf.get - val newExpr = - DecimalPrecision.promote(conf.decimalOperationsAllowPrecisionLoss, expr, !conf.ansiEnabled) - exprToProtoInternal(newExpr, input) + def in( + expr: Expression, + value: Expression, + list: Seq[Expression], + inputs: Seq[Attribute], + binding: Boolean, + negate: Boolean): Option[Expr] = { + val valueExpr = exprToProtoInternal(value, inputs, binding) + val listExprs = list.map(exprToProtoInternal(_, inputs, binding)) + if (valueExpr.isDefined && listExprs.forall(_.isDefined)) { + val builder = ExprOuterClass.In.newBuilder() + builder.setInValue(valueExpr.get) + builder.addAllLists(listExprs.map(_.get).asJava) + builder.setNegated(negate) + Some( + ExprOuterClass.Expr + .newBuilder() + .setIn(builder) + .build()) + } else { + val allExprs = list ++ Seq(value) + withInfo(expr, allExprs: _*) + None + } } def scalarExprToProtoWithReturnType( @@ -3364,3 +3448,29 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim nativeScanBuilder.addFilePartitions(partitionBuilder.build()) } } + +/** + * Trait for providing serialization logic for expressions. + */ +trait CometExpressionSerde { + + /** + * Convert a Spark expression into a protocol buffer representation that can be passed into + * native code. + * + * @param expr + * The Spark expression. + * @param inputs + * The input attributes. 
+   * @param binding
+   *   Whether the attributes are bound (this is only relevant in aggregate expressions).
+   * @return
+   *   Protocol buffer representation, or None if the expression could not be converted. In this
+   *   case it is expected that the input expression will have been tagged with reasons why it
+   *   could not be converted.
+   */
+  def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr]
+}
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index ecb4f5a13..9058a641e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -19,18 +19,16 @@
 package org.apache.comet.serde
 
-import org.apache.spark.sql.catalyst.expressions.{ArrayRemove, Expression}
+import org.apache.spark.sql.catalyst.expressions.{ArrayRemove, Attribute, Expression}
 import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, StructType}
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.QueryPlanSerde.createBinaryExpr
 import org.apache.comet.shims.CometExprShim
 
-trait CometExpression {
-  def checkSupport(expr: Expression): Boolean
-}
-
-object CometArrayRemove extends CometExpression with CometExprShim {
+object CometArrayRemove extends CometExpressionSerde with CometExprShim {
+
+  /** Exposed for unit testing */
   def isTypeSupported(dt: DataType): Boolean = {
     import DataTypes._
     dt match {
@@ -46,15 +44,24 @@ object CometArrayRemove extends CometExpression with CometExprShim {
     }
   }
 
-  override def checkSupport(expr: Expression): Boolean = {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
     val ar = expr.asInstanceOf[ArrayRemove]
     val inputTypes: Set[DataType] = ar.children.map(_.dataType).toSet
     for (dt <- inputTypes) {
       if (!isTypeSupported(dt)) {
         withInfo(expr, s"data type not supported: $dt")
-        return false
+        return None
       }
     }
-    true
+    createBinaryExpr(
+      expr,
+      expr.children(0),
+      expr.children(1),
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setArrayRemove(binaryExpr))
   }
 }
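For readers tracking the refactor, the `CometArrayRemove` change above is the template this commit establishes: an object implements `CometExpressionSerde.convert`, performs its own support checks via `withInfo`, and reuses the shared builders from `QueryPlanSerde`. The sketch below shows how another expression could be migrated in a later change. It is illustrative only and not part of this commit: `CometStringRepeat` is a hypothetical object, and it assumes that the `exprToProto`, `scalarExprToProto`, and `optExprWithInfo` helpers seen in this diff are importable from `QueryPlanSerde` the same way `createBinaryExpr` is imported in `arrays.scala`.

```scala
package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Expression, StringRepeat}
import org.apache.spark.sql.types.{LongType, StringType}

import org.apache.comet.serde.QueryPlanSerde.{exprToProto, optExprWithInfo, scalarExprToProto}

// Hypothetical follow-up (not in this commit): the StringRepeat case from the
// QueryPlanSerde match, moved into its own serde object, mirroring CometArrayRemove.
object CometStringRepeat extends CometExpressionSerde {

  override def convert(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val r = expr.asInstanceOf[StringRepeat]
    // Apply the same casts as the existing QueryPlanSerde case: the string
    // argument to StringType and the repeat count to LongType.
    val leftCast = Cast(r.str, StringType)
    val rightCast = Cast(r.times, LongType)
    val leftExpr = exprToProto(leftCast, inputs, binding)
    val rightExpr = exprToProto(rightCast, inputs, binding)
    // Delegate to the DataFusion `repeat` scalar function, tagging the original
    // expression with fallback info if either argument failed to convert.
    val optExpr = scalarExprToProto("repeat", leftExpr, rightExpr)
    optExprWithInfo(optExpr, expr, leftCast, rightCast)
  }
}
```

The corresponding arm in the `QueryPlanSerde` match would then collapse to a one-line delegation, just as this commit does for `ArrayRemove`: `case s: StringRepeat => CometStringRepeat.convert(s, inputs, binding)`.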