This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new a5b36b45d chore: Refactor QueryPlanSerde to allow logic to be moved to individual classes per expression (#1331)
a5b36b45d is described below

commit a5b36b45de934bf5f3cf768067944a69053da540
Author: Andy Grove <agr...@apache.org>
AuthorDate: Fri Jan 24 14:04:53 2025 -0700

    chore: Refactor QueryPlanSerde to allow logic to be moved to individual classes per expression (#1331)
    
    * Move castToProto to top-level method
    
    * move more methods to top-level
    
    * more refactoring
    
    * revert accidental rename
    
    * move CometExpressionSerde to QueryPlanSerde
    
    * add scaladoc
    
    * address input
---
 .../org/apache/comet/serde/QueryPlanSerde.scala    | 3102 ++++++++++----------
 .../main/scala/org/apache/comet/serde/arrays.scala |   25 +-
 2 files changed, 1622 insertions(+), 1505 deletions(-)
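
For context: the commit message above describes moving per-expression serialization logic out of the single large match in QueryPlanSerde and behind the CometExpressionSerde trait. A minimal sketch of that pattern follows. It is illustrative only: the trait name comes from the commit message, but the method name and signature here are assumptions, and ExprOuterClass refers to the generated protobuf bindings used throughout the diff below.

    import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
    import org.apache.comet.serde.ExprOuterClass

    // One serde object per Spark expression, so QueryPlanSerde.exprToProtoInternal
    // can delegate to small classes instead of inlining every case.
    trait CometExpressionSerde {
      // Returns the protobuf form of `expr`, or None when the expression is not
      // supported (with the reason recorded on the plan node, e.g. via withInfo).
      def convert(
          expr: Expression,
          inputs: Seq[Attribute],
          binding: Boolean): Option[ExprOuterClass.Expr]
    }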

diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index 22bb7dc82..c3d7ac749 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -826,723 +826,790 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
   }
 
   /**
-   * Convert a Spark expression to protobuf.
-   *
-   * @param expr
-   *   The input expression
-   * @param inputs
-   *   The input attributes
-   * @param binding
-   *   Whether to bind the expression to the input attributes
-   * @return
-   *   The protobuf representation of the expression, or None if the expression is not supported
+   * Wrap an expression in a cast.
    */
-  def exprToProto(
+  def castToProto(
       expr: Expression,
-      input: Seq[Attribute],
-      binding: Boolean = true): Option[Expr] = {
-    def castToProto(
-        timeZoneId: Option[String],
-        dt: DataType,
-        childExpr: Option[Expr],
-        evalMode: CometEvalMode.Value): Option[Expr] = {
-      val dataType = serializeDataType(dt)
-
-      if (childExpr.isDefined && dataType.isDefined) {
+      timeZoneId: Option[String],
+      dt: DataType,
+      childExpr: Expr,
+      evalMode: CometEvalMode.Value): Option[Expr] = {
+    serializeDataType(dt) match {
+      case Some(dataType) =>
         val castBuilder = ExprOuterClass.Cast.newBuilder()
-        castBuilder.setChild(childExpr.get)
-        castBuilder.setDatatype(dataType.get)
+        castBuilder.setChild(childExpr)
+        castBuilder.setDatatype(dataType)
         castBuilder.setEvalMode(evalModeToProto(evalMode))
         castBuilder.setAllowIncompat(CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get())
-        val timeZone = timeZoneId.getOrElse("UTC")
-        castBuilder.setTimezone(timeZone)
-
+        castBuilder.setTimezone(timeZoneId.getOrElse("UTC"))
         Some(
           ExprOuterClass.Expr
             .newBuilder()
             .setCast(castBuilder)
             .build())
-      } else {
-        if (!dataType.isDefined) {
-          withInfo(expr, s"Unsupported datatype ${dt}")
-        } else {
-          withInfo(expr, s"Unsupported expression $childExpr")
-        }
+      case _ =>
+        withInfo(expr, s"Unsupported datatype in castToProto: $dt")
         None
-      }
     }
+  }
 
-    def exprToProtoInternal(expr: Expression, inputs: Seq[Attribute]): Option[Expr] = {
-      SQLConf.get
-
-      def handleCast(
-          child: Expression,
-          inputs: Seq[Attribute],
-          dt: DataType,
-          timeZoneId: Option[String],
-          evalMode: CometEvalMode.Value): Option[Expr] = {
-
-        val childExpr = exprToProtoInternal(child, inputs)
-        if (childExpr.isDefined) {
-          val castSupport =
-            CometCast.isSupported(child.dataType, dt, timeZoneId, evalMode)
-
-          def getIncompatMessage(reason: Option[String]): String =
-            "Comet does not guarantee correct results for cast " +
-              s"from ${child.dataType} to $dt " +
-              s"with timezone $timeZoneId and evalMode $evalMode" +
-              reason.map(str => s" ($str)").getOrElse("")
-
-          castSupport match {
-            case Compatible(_) =>
-              castToProto(timeZoneId, dt, childExpr, evalMode)
-            case Incompatible(reason) =>
-              if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
-                logWarning(getIncompatMessage(reason))
-                castToProto(timeZoneId, dt, childExpr, evalMode)
-              } else {
-                withInfo(
-                  expr,
-                  s"${getIncompatMessage(reason)}. To enable all incompatible casts, set " +
-                    s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
-                None
-              }
-            case Unsupported =>
-              withInfo(
-                expr,
-                s"Unsupported cast from ${child.dataType} to $dt " +
-                  s"with timezone $timeZoneId and evalMode $evalMode")
-              None
+  def handleCast(
+      expr: Expression,
+      child: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      dt: DataType,
+      timeZoneId: Option[String],
+      evalMode: CometEvalMode.Value): Option[Expr] = {
+
+    val childExpr = exprToProtoInternal(child, inputs, binding)
+    if (childExpr.isDefined) {
+      val castSupport =
+        CometCast.isSupported(child.dataType, dt, timeZoneId, evalMode)
+
+      def getIncompatMessage(reason: Option[String]): String =
+        "Comet does not guarantee correct results for cast " +
+          s"from ${child.dataType} to $dt " +
+          s"with timezone $timeZoneId and evalMode $evalMode" +
+          reason.map(str => s" ($str)").getOrElse("")
+
+      castSupport match {
+        case Compatible(_) =>
+          castToProto(expr, timeZoneId, dt, childExpr.get, evalMode)
+        case Incompatible(reason) =>
+          if (CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.get()) {
+            logWarning(getIncompatMessage(reason))
+            castToProto(expr, timeZoneId, dt, childExpr.get, evalMode)
+          } else {
+            withInfo(
+              expr,
+              s"${getIncompatMessage(reason)}. To enable all incompatible casts, set " +
+                s"${CometConf.COMET_CAST_ALLOW_INCOMPATIBLE.key}=true")
+            None
           }
-        } else {
-          withInfo(expr, child)
+        case Unsupported =>
+          withInfo(
+            expr,
+            s"Unsupported cast from ${child.dataType} to $dt " +
+              s"with timezone $timeZoneId and evalMode $evalMode")
           None
-        }
       }
+    } else {
+      withInfo(expr, child)
+      None
+    }
+  }
 
-      expr match {
-        case a @ Alias(_, _) =>
-          val r = exprToProtoInternal(a.child, inputs)
-          if (r.isEmpty) {
-            withInfo(expr, a.child)
-          }
-          r
-
-        case cast @ Cast(_: Literal, dataType, _, _) =>
-          // This can happen after promoting decimal precisions
-          val value = cast.eval()
-          exprToProtoInternal(Literal(value, dataType), inputs)
-
-        case UnaryExpression(child) if expr.prettyName == "trycast" =>
-          val timeZoneId = SQLConf.get.sessionLocalTimeZone
-          handleCast(child, inputs, expr.dataType, Some(timeZoneId), CometEvalMode.TRY)
+  /**
+   * Convert a Spark expression to protobuf.
+   *
+   * @param expr
+   *   The input expression
+   * @param inputs
+   *   The input attributes
+   * @param binding
+   *   Whether to bind the expression to the input attributes
+   * @return
+   *   The protobuf representation of the expression, or None if the expression is not supported
+   */
+  def exprToProto(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean = true): Option[Expr] = {
 
-        case c @ Cast(child, dt, timeZoneId, _) =>
-          handleCast(child, inputs, dt, timeZoneId, evalMode(c))
+    val conf = SQLConf.get
+    val newExpr =
+      DecimalPrecision.promote(conf.decimalOperationsAllowPrecisionLoss, expr, !conf.ansiEnabled)
+    exprToProtoInternal(newExpr, inputs, binding)
+  }
 
-        case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
-          createMathExpression(
-            left,
-            right,
-            inputs,
-            add.dataType,
-            getFailOnError(add),
-            (builder, mathExpr) => builder.setAdd(mathExpr))
+  def exprToProtoInternal(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[Expr] = {
+    SQLConf.get
+
+    expr match {
+      case a @ Alias(_, _) =>
+        val r = exprToProtoInternal(a.child, inputs, binding)
+        if (r.isEmpty) {
+          withInfo(expr, a.child)
+        }
+        r
+
+      case cast @ Cast(_: Literal, dataType, _, _) =>
+        // This can happen after promoting decimal precisions
+        val value = cast.eval()
+        exprToProtoInternal(Literal(value, dataType), inputs, binding)
+
+      case UnaryExpression(child) if expr.prettyName == "trycast" =>
+        val timeZoneId = SQLConf.get.sessionLocalTimeZone
+        handleCast(
+          expr,
+          child,
+          inputs,
+          binding,
+          expr.dataType,
+          Some(timeZoneId),
+          CometEvalMode.TRY)
+
+      case c @ Cast(child, dt, timeZoneId, _) =>
+        handleCast(expr, child, inputs, binding, dt, timeZoneId, evalMode(c))
+
+      case add @ Add(left, right, _) if supportedDataType(left.dataType) =>
+        createMathExpression(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          add.dataType,
+          getFailOnError(add),
+          (builder, mathExpr) => builder.setAdd(mathExpr))
+
+      case add @ Add(left, _, _) if !supportedDataType(left.dataType) =>
+        withInfo(add, s"Unsupported datatype ${left.dataType}")
+        None
 
-        case add @ Add(left, _, _) if !supportedDataType(left.dataType) =>
-          withInfo(add, s"Unsupported datatype ${left.dataType}")
-          None
+      case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) =>
+        createMathExpression(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          sub.dataType,
+          getFailOnError(sub),
+          (builder, mathExpr) => builder.setSubtract(mathExpr))
+
+      case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) =>
+        withInfo(sub, s"Unsupported datatype ${left.dataType}")
+        None
 
-        case sub @ Subtract(left, right, _) if supportedDataType(left.dataType) =>
-          createMathExpression(
-            left,
-            right,
-            inputs,
-            sub.dataType,
-            getFailOnError(sub),
-            (builder, mathExpr) => builder.setSubtract(mathExpr))
+      case mul @ Multiply(left, right, _)
+          if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
+        createMathExpression(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          mul.dataType,
+          getFailOnError(mul),
+          (builder, mathExpr) => builder.setMultiply(mathExpr))
+
+      case mul @ Multiply(left, _, _) =>
+        if (!supportedDataType(left.dataType)) {
+          withInfo(mul, s"Unsupported datatype ${left.dataType}")
+        }
+        if (decimalBeforeSpark34(left.dataType)) {
+          withInfo(mul, "Decimal support requires Spark 3.4 or later")
+        }
+        None
 
-        case sub @ Subtract(left, _, _) if !supportedDataType(left.dataType) =>
-          withInfo(sub, s"Unsupported datatype ${left.dataType}")
-          None
+      case div @ Divide(left, right, _)
+          if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
+        // Datafusion now throws an exception for dividing by zero
+        // See https://github.com/apache/arrow-datafusion/pull/6792
+        // For now, use NullIf to swap zeros with nulls.
+        val rightExpr = nullIfWhenPrimitive(right)
+
+        createMathExpression(
+          expr,
+          left,
+          rightExpr,
+          inputs,
+          binding,
+          div.dataType,
+          getFailOnError(div),
+          (builder, mathExpr) => builder.setDivide(mathExpr))
+
+      case div @ Divide(left, _, _) =>
+        if (!supportedDataType(left.dataType)) {
+          withInfo(div, s"Unsupported datatype ${left.dataType}")
+        }
+        if (decimalBeforeSpark34(left.dataType)) {
+          withInfo(div, "Decimal support requires Spark 3.4 or later")
+        }
+        None
 
-        case mul @ Multiply(left, right, _)
-            if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
-          createMathExpression(
-            left,
-            right,
-            inputs,
-            mul.dataType,
-            getFailOnError(mul),
-            (builder, mathExpr) => builder.setMultiply(mathExpr))
+      case rem @ Remainder(left, right, _)
+          if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
+        val rightExpr = nullIfWhenPrimitive(right)
 
-        case mul @ Multiply(left, _, _) =>
-          if (!supportedDataType(left.dataType)) {
-            withInfo(mul, s"Unsupported datatype ${left.dataType}")
-          }
-          if (decimalBeforeSpark34(left.dataType)) {
-            withInfo(mul, "Decimal support requires Spark 3.4 or later")
-          }
-          None
+        createMathExpression(
+          expr,
+          left,
+          rightExpr,
+          inputs,
+          binding,
+          rem.dataType,
+          getFailOnError(rem),
+          (builder, mathExpr) => builder.setRemainder(mathExpr))
 
-        case div @ Divide(left, right, _)
-            if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
-          // Datafusion now throws an exception for dividing by zero
-          // See https://github.com/apache/arrow-datafusion/pull/6792
-          // For now, use NullIf to swap zeros with nulls.
-          val rightExpr = nullIfWhenPrimitive(right)
+      case rem @ Remainder(left, _, _) =>
+        if (!supportedDataType(left.dataType)) {
+          withInfo(rem, s"Unsupported datatype ${left.dataType}")
+        }
+        if (decimalBeforeSpark34(left.dataType)) {
+          withInfo(rem, "Decimal support requires Spark 3.4 or later")
+        }
+        None
 
-          createMathExpression(
-            left,
-            rightExpr,
-            inputs,
-            div.dataType,
-            getFailOnError(div),
-            (builder, mathExpr) => builder.setDivide(mathExpr))
+      case EqualTo(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setEq(binaryExpr))
+
+      case Not(EqualTo(left, right)) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setNeq(binaryExpr))
+
+      case EqualNullSafe(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setEqNullSafe(binaryExpr))
+
+      case Not(EqualNullSafe(left, right)) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setNeqNullSafe(binaryExpr))
+
+      case GreaterThan(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setGt(binaryExpr))
+
+      case GreaterThanOrEqual(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setGtEq(binaryExpr))
+
+      case LessThan(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setLt(binaryExpr))
+
+      case LessThanOrEqual(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setLtEq(binaryExpr))
+
+      case Literal(value, dataType) if supportedDataType(dataType, allowStruct = value == null) =>
+        val exprBuilder = ExprOuterClass.Literal.newBuilder()
+
+        if (value == null) {
+          exprBuilder.setIsNull(true)
+        } else {
+          exprBuilder.setIsNull(false)
+          dataType match {
+            case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean])
+            case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte])
+            case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short])
+            case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int])
+            case _: LongType => exprBuilder.setLongVal(value.asInstanceOf[Long])
+            case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float])
+            case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double])
+            case _: StringType =>
+              exprBuilder.setStringVal(value.asInstanceOf[UTF8String].toString)
+            case _: TimestampType => exprBuilder.setLongVal(value.asInstanceOf[Long])
+            case _: DecimalType =>
+              // Pass decimal literal as bytes.
+              val unscaled = value.asInstanceOf[Decimal].toBigDecimal.underlying.unscaledValue
+              exprBuilder.setDecimalVal(
+                com.google.protobuf.ByteString.copyFrom(unscaled.toByteArray))
+            case _: BinaryType =>
+              val byteStr =
+                com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]])
+              exprBuilder.setBytesVal(byteStr)
+            case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int])
+            case dt if isTimestampNTZType(dt) =>
+              exprBuilder.setLongVal(value.asInstanceOf[Long])
+            case dt =>
+              logWarning(s"Unexpected data type '$dt' for literal value '$value'")
+          }
+        }
+
+        val dt = serializeDataType(dataType)
+
+        if (dt.isDefined) {
+          exprBuilder.setDatatype(dt.get)
 
-        case div @ Divide(left, _, _) =>
-          if (!supportedDataType(left.dataType)) {
-            withInfo(div, s"Unsupported datatype ${left.dataType}")
-          }
-          if (decimalBeforeSpark34(left.dataType)) {
-            withInfo(div, "Decimal support requires Spark 3.4 or later")
-          }
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setLiteral(exprBuilder)
+              .build())
+        } else {
+          withInfo(expr, s"Unsupported datatype $dataType")
           None
+        }
+      case Literal(_, dataType) if !supportedDataType(dataType) =>
+        withInfo(expr, s"Unsupported datatype $dataType")
+        None
 
-        case rem @ Remainder(left, right, _)
-            if supportedDataType(left.dataType) && !decimalBeforeSpark34(left.dataType) =>
-          val rightExpr = nullIfWhenPrimitive(right)
+      case Substring(str, Literal(pos, _), Literal(len, _)) =>
+        val strExpr = exprToProtoInternal(str, inputs, binding)
 
-          createMathExpression(
-            left,
-            rightExpr,
-            inputs,
-            rem.dataType,
-            getFailOnError(rem),
-            (builder, mathExpr) => builder.setRemainder(mathExpr))
+        if (strExpr.isDefined) {
+          val builder = ExprOuterClass.Substring.newBuilder()
+          builder.setChild(strExpr.get)
+          builder.setStart(pos.asInstanceOf[Int])
+          builder.setLen(len.asInstanceOf[Int])
 
-        case rem @ Remainder(left, _, _) =>
-          if (!supportedDataType(left.dataType)) {
-            withInfo(rem, s"Unsupported datatype ${left.dataType}")
-          }
-          if (decimalBeforeSpark34(left.dataType)) {
-            withInfo(rem, "Decimal support requires Spark 3.4 or later")
-          }
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setSubstring(builder)
+              .build())
+        } else {
+          withInfo(expr, str)
           None
+        }
 
-        case EqualTo(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setEq(binaryExpr))
-
-        case Not(EqualTo(left, right)) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setNeq(binaryExpr))
-
-        case EqualNullSafe(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setEqNullSafe(binaryExpr))
-
-        case Not(EqualNullSafe(left, right)) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setNeqNullSafe(binaryExpr))
-
-        case GreaterThan(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setGt(binaryExpr))
-
-        case GreaterThanOrEqual(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setGtEq(binaryExpr))
-
-        case LessThan(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setLt(binaryExpr))
-
-        case LessThanOrEqual(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setLtEq(binaryExpr))
-
-        case Literal(value, dataType)
-            if supportedDataType(dataType, allowStruct = value == null) =>
-          val exprBuilder = ExprOuterClass.Literal.newBuilder()
-
-          if (value == null) {
-            exprBuilder.setIsNull(true)
-          } else {
-            exprBuilder.setIsNull(false)
-            dataType match {
-              case _: BooleanType => exprBuilder.setBoolVal(value.asInstanceOf[Boolean])
-              case _: ByteType => exprBuilder.setByteVal(value.asInstanceOf[Byte])
-              case _: ShortType => exprBuilder.setShortVal(value.asInstanceOf[Short])
-              case _: IntegerType => exprBuilder.setIntVal(value.asInstanceOf[Int])
-              case _: LongType => exprBuilder.setLongVal(value.asInstanceOf[Long])
-              case _: FloatType => exprBuilder.setFloatVal(value.asInstanceOf[Float])
-              case _: DoubleType => exprBuilder.setDoubleVal(value.asInstanceOf[Double])
-              case _: StringType =>
-                exprBuilder.setStringVal(value.asInstanceOf[UTF8String].toString)
-              case _: TimestampType => exprBuilder.setLongVal(value.asInstanceOf[Long])
-              case _: DecimalType =>
-                // Pass decimal literal as bytes.
-                val unscaled = value.asInstanceOf[Decimal].toBigDecimal.underlying.unscaledValue
-                exprBuilder.setDecimalVal(
-                  com.google.protobuf.ByteString.copyFrom(unscaled.toByteArray))
-              case _: BinaryType =>
-                val byteStr =
-                  com.google.protobuf.ByteString.copyFrom(value.asInstanceOf[Array[Byte]])
-                exprBuilder.setBytesVal(byteStr)
-              case _: DateType => exprBuilder.setIntVal(value.asInstanceOf[Int])
-              case dt if isTimestampNTZType(dt) =>
-                exprBuilder.setLongVal(value.asInstanceOf[Long])
-              case dt =>
-                logWarning(s"Unexpected data type '$dt' for literal value '$value'")
-            }
-          }
-
-          val dt = serializeDataType(dataType)
-
-          if (dt.isDefined) {
-            exprBuilder.setDatatype(dt.get)
-
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setLiteral(exprBuilder)
-                .build())
-          } else {
-            withInfo(expr, s"Unsupported datatype $dataType")
-            None
-          }
-        case Literal(_, dataType) if !supportedDataType(dataType) =>
-          withInfo(expr, s"Unsupported datatype $dataType")
+      case StructsToJson(options, child, timezoneId) =>
+        if (options.nonEmpty) {
+          withInfo(expr, "StructsToJson with options is not supported")
           None
+        } else {
 
-        case Substring(str, Literal(pos, _), Literal(len, _)) =>
-          val strExpr = exprToProtoInternal(str, inputs)
-
-          if (strExpr.isDefined) {
-            val builder = ExprOuterClass.Substring.newBuilder()
-            builder.setChild(strExpr.get)
-            builder.setStart(pos.asInstanceOf[Int])
-            builder.setLen(len.asInstanceOf[Int])
-
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setSubstring(builder)
-                .build())
-          } else {
-            withInfo(expr, str)
-            None
-          }
-
-        case StructsToJson(options, child, timezoneId) =>
-          if (options.nonEmpty) {
-            withInfo(expr, "StructsToJson with options is not supported")
-            None
-          } else {
-
-            def isSupportedType(dt: DataType): Boolean = {
-              dt match {
-                case StructType(fields) =>
-                  fields.forall(f => isSupportedType(f.dataType))
-                case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
-                    DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType |
-                    DataTypes.DoubleType | DataTypes.StringType =>
-                  true
-                case DataTypes.DateType | DataTypes.TimestampType =>
-                  // TODO implement these types with tests for formatting options and timezone
-                  false
-                case _: MapType | _: ArrayType =>
-                  // Spark supports map and array in StructsToJson but this is not yet
-                  // implemented in Comet
-                  false
-                case _ => false
-              }
-            }
-
-            val isSupported = child.dataType match {
-              case s: StructType =>
-                s.fields.forall(f => isSupportedType(f.dataType))
+          def isSupportedType(dt: DataType): Boolean = {
+            dt match {
+              case StructType(fields) =>
+                fields.forall(f => isSupportedType(f.dataType))
+              case DataTypes.BooleanType | DataTypes.ByteType | DataTypes.ShortType |
+                  DataTypes.IntegerType | DataTypes.LongType | DataTypes.FloatType |
+                  DataTypes.DoubleType | DataTypes.StringType =>
+                true
+              case DataTypes.DateType | DataTypes.TimestampType =>
+                // TODO implement these types with tests for formatting options and timezone
+                false
               case _: MapType | _: ArrayType =>
                // Spark supports map and array in StructsToJson but this is not yet
                 // implemented in Comet
                 false
-              case _ =>
-                false
+              case _ => false
             }
+          }
 
-            if (isSupported) {
-              exprToProto(child, input, binding) match {
-                case Some(p) =>
-                  val toJson = ExprOuterClass.ToJson
-                    .newBuilder()
-                    .setChild(p)
-                    .setTimezone(timezoneId.getOrElse("UTC"))
-                    .setIgnoreNullFields(true)
-                    .build()
-                  Some(
-                    ExprOuterClass.Expr
-                      .newBuilder()
-                      .setToJson(toJson)
-                      .build())
-                case _ =>
-                  withInfo(expr, child)
-                  None
-              }
-            } else {
-              withInfo(expr, "Unsupported data type", child)
-              None
-            }
+          val isSupported = child.dataType match {
+            case s: StructType =>
+              s.fields.forall(f => isSupportedType(f.dataType))
+            case _: MapType | _: ArrayType =>
+              // Spark supports map and array in StructsToJson but this is not yet
+              // implemented in Comet
+              false
+            case _ =>
+              false
           }
 
-        case Like(left, right, escapeChar) =>
-          if (escapeChar == '\\') {
-            createBinaryExpr(
-              left,
-              right,
-              inputs,
-              (builder, binaryExpr) => builder.setLike(binaryExpr))
+          if (isSupported) {
+            exprToProto(child, inputs, binding) match {
+              case Some(p) =>
+                val toJson = ExprOuterClass.ToJson
+                  .newBuilder()
+                  .setChild(p)
+                  .setTimezone(timezoneId.getOrElse("UTC"))
+                  .setIgnoreNullFields(true)
+                  .build()
+                Some(
+                  ExprOuterClass.Expr
+                    .newBuilder()
+                    .setToJson(toJson)
+                    .build())
+              case _ =>
+                withInfo(expr, child)
+                None
+            }
           } else {
-            // TODO custom escape char
-            withInfo(expr, s"custom escape character $escapeChar not supported in LIKE")
+            withInfo(expr, "Unsupported data type", child)
             None
           }
+        }
 
-        case RLike(left, right) =>
-          // we currently only support scalar regex patterns
-          right match {
-            case Literal(pattern, DataTypes.StringType) =>
-              if (!RegExp.isSupportedPattern(pattern.toString) &&
-                !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) {
-                withInfo(
-                  expr,
-                  s"Regexp pattern $pattern is not compatible with Spark. " +
-                    s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " +
-                    "to allow it anyway.")
-                return None
-              }
-            case _ =>
-              withInfo(expr, "Only scalar regexp patterns are supported")
-              return None
-          }
-
+      case Like(left, right, escapeChar) =>
+        if (escapeChar == '\\') {
           createBinaryExpr(
+            expr,
             left,
             right,
             inputs,
-            (builder, binaryExpr) => builder.setRlike(binaryExpr))
+            binding,
+            (builder, binaryExpr) => builder.setLike(binaryExpr))
+        } else {
+          // TODO custom escape char
+          withInfo(expr, s"custom escape character $escapeChar not supported in LIKE")
+          None
+        }
 
-        case StartsWith(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setStartsWith(binaryExpr))
+      case RLike(left, right) =>
+        // we currently only support scalar regex patterns
+        right match {
+          case Literal(pattern, DataTypes.StringType) =>
+            if (!RegExp.isSupportedPattern(pattern.toString) &&
+              !CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.get()) {
+              withInfo(
+                expr,
+                s"Regexp pattern $pattern is not compatible with Spark. " +
+                  s"Set ${CometConf.COMET_REGEXP_ALLOW_INCOMPATIBLE.key}=true " +
+                  "to allow it anyway.")
+              return None
+            }
+          case _ =>
+            withInfo(expr, "Only scalar regexp patterns are supported")
+            return None
+        }
 
-        case EndsWith(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setEndsWith(binaryExpr))
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setRlike(binaryExpr))
+
+      case StartsWith(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setStartsWith(binaryExpr))
+
+      case EndsWith(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setEndsWith(binaryExpr))
+
+      case Contains(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setContains(binaryExpr))
+
+      case StringSpace(child) =>
+        createUnaryExpr(
+          expr,
+          child,
+          inputs,
+          binding,
+          (builder, unaryExpr) => builder.setStringSpace(unaryExpr))
+
+      case Hour(child, timeZoneId) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
 
-        case Contains(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setContains(binaryExpr))
+        if (childExpr.isDefined) {
+          val builder = ExprOuterClass.Hour.newBuilder()
+          builder.setChild(childExpr.get)
 
-        case StringSpace(child) =>
-          createUnaryExpr(
-            child,
-            inputs,
-            (builder, unaryExpr) => builder.setStringSpace(unaryExpr))
+          val timeZone = timeZoneId.getOrElse("UTC")
+          builder.setTimezone(timeZone)
 
-        case Hour(child, timeZoneId) =>
-          val childExpr = exprToProtoInternal(child, inputs)
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setHour(builder)
+              .build())
+        } else {
+          withInfo(expr, child)
+          None
+        }
 
-          if (childExpr.isDefined) {
-            val builder = ExprOuterClass.Hour.newBuilder()
-            builder.setChild(childExpr.get)
+      case Minute(child, timeZoneId) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
 
-            val timeZone = timeZoneId.getOrElse("UTC")
-            builder.setTimezone(timeZone)
+        if (childExpr.isDefined) {
+          val builder = ExprOuterClass.Minute.newBuilder()
+          builder.setChild(childExpr.get)
 
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setHour(builder)
-                .build())
-          } else {
-            withInfo(expr, child)
-            None
-          }
+          val timeZone = timeZoneId.getOrElse("UTC")
+          builder.setTimezone(timeZone)
 
-        case Minute(child, timeZoneId) =>
-          val childExpr = exprToProtoInternal(child, inputs)
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setMinute(builder)
+              .build())
+        } else {
+          withInfo(expr, child)
+          None
+        }
 
-          if (childExpr.isDefined) {
-            val builder = ExprOuterClass.Minute.newBuilder()
-            builder.setChild(childExpr.get)
+      case DateAdd(left, right) =>
+        val leftExpr = exprToProtoInternal(left, inputs, binding)
+        val rightExpr = exprToProtoInternal(right, inputs, binding)
+        val optExpr = scalarExprToProtoWithReturnType("date_add", DateType, leftExpr, rightExpr)
+        optExprWithInfo(optExpr, expr, left, right)
 
-            val timeZone = timeZoneId.getOrElse("UTC")
-            builder.setTimezone(timeZone)
+      case DateSub(left, right) =>
+        val leftExpr = exprToProtoInternal(left, inputs, binding)
+        val rightExpr = exprToProtoInternal(right, inputs, binding)
+        val optExpr = scalarExprToProtoWithReturnType("date_sub", DateType, leftExpr, rightExpr)
+        optExprWithInfo(optExpr, expr, left, right)
 
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setMinute(builder)
-                .build())
-          } else {
-            withInfo(expr, child)
-            None
-          }
+      case TruncDate(child, format) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val formatExpr = exprToProtoInternal(format, inputs, binding)
 
-        case DateAdd(left, right) =>
-          val leftExpr = exprToProtoInternal(left, inputs)
-          val rightExpr = exprToProtoInternal(right, inputs)
-          val optExpr = scalarExprToProtoWithReturnType("date_add", DateType, leftExpr, rightExpr)
-          optExprWithInfo(optExpr, expr, left, right)
+        if (childExpr.isDefined && formatExpr.isDefined) {
+          val builder = ExprOuterClass.TruncDate.newBuilder()
+          builder.setChild(childExpr.get)
+          builder.setFormat(formatExpr.get)
 
-        case DateSub(left, right) =>
-          val leftExpr = exprToProtoInternal(left, inputs)
-          val rightExpr = exprToProtoInternal(right, inputs)
-          val optExpr = scalarExprToProtoWithReturnType("date_sub", DateType, leftExpr, rightExpr)
-          optExprWithInfo(optExpr, expr, left, right)
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setTruncDate(builder)
+              .build())
+        } else {
+          withInfo(expr, child, format)
+          None
+        }
 
-        case TruncDate(child, format) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val formatExpr = exprToProtoInternal(format, inputs)
+      case TruncTimestamp(format, child, timeZoneId) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val formatExpr = exprToProtoInternal(format, inputs, binding)
 
-          if (childExpr.isDefined && formatExpr.isDefined) {
-            val builder = ExprOuterClass.TruncDate.newBuilder()
-            builder.setChild(childExpr.get)
-            builder.setFormat(formatExpr.get)
+        if (childExpr.isDefined && formatExpr.isDefined) {
+          val builder = ExprOuterClass.TruncTimestamp.newBuilder()
+          builder.setChild(childExpr.get)
+          builder.setFormat(formatExpr.get)
 
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setTruncDate(builder)
-                .build())
-          } else {
-            withInfo(expr, child, format)
-            None
-          }
+          val timeZone = timeZoneId.getOrElse("UTC")
+          builder.setTimezone(timeZone)
 
-        case TruncTimestamp(format, child, timeZoneId) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val formatExpr = exprToProtoInternal(format, inputs)
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setTruncTimestamp(builder)
+              .build())
+        } else {
+          withInfo(expr, child, format)
+          None
+        }
 
-          if (childExpr.isDefined && formatExpr.isDefined) {
-            val builder = ExprOuterClass.TruncTimestamp.newBuilder()
-            builder.setChild(childExpr.get)
-            builder.setFormat(formatExpr.get)
+      case Second(child, timeZoneId) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
 
-            val timeZone = timeZoneId.getOrElse("UTC")
-            builder.setTimezone(timeZone)
+        if (childExpr.isDefined) {
+          val builder = ExprOuterClass.Second.newBuilder()
+          builder.setChild(childExpr.get)
 
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setTruncTimestamp(builder)
-                .build())
-          } else {
-            withInfo(expr, child, format)
-            None
-          }
+          val timeZone = timeZoneId.getOrElse("UTC")
+          builder.setTimezone(timeZone)
+
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setSecond(builder)
+              .build())
+        } else {
+          withInfo(expr, child)
+          None
+        }
 
-        case Second(child, timeZoneId) =>
-          val childExpr = exprToProtoInternal(child, inputs)
+      case Year(child) =>
+        val periodType = exprToProtoInternal(Literal("year"), inputs, binding)
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr = scalarExprToProto("datepart", Seq(periodType, childExpr): _*)
+          .map(e => {
+            Expr
+              .newBuilder()
+              .setCast(
+                ExprOuterClass.Cast
+                  .newBuilder()
+                  .setChild(e)
+                  .setDatatype(serializeDataType(IntegerType).get)
+                  .setEvalMode(ExprOuterClass.EvalMode.LEGACY)
+                  .setAllowIncompat(false)
+                  .build())
+              .build()
+          })
+        optExprWithInfo(optExpr, expr, child)
+
+      case IsNull(child) =>
+        createUnaryExpr(
+          expr,
+          child,
+          inputs,
+          binding,
+          (builder, unaryExpr) => builder.setIsNull(unaryExpr))
+
+      case IsNotNull(child) =>
+        createUnaryExpr(
+          expr,
+          child,
+          inputs,
+          binding,
+          (builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
+
+      case IsNaN(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr =
+          scalarExprToProtoWithReturnType("isnan", BooleanType, childExpr)
+
+        optExprWithInfo(optExpr, expr, child)
+
+      case SortOrder(child, direction, nullOrdering, _) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
 
-          if (childExpr.isDefined) {
-            val builder = ExprOuterClass.Second.newBuilder()
-            builder.setChild(childExpr.get)
+        if (childExpr.isDefined) {
+          val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder()
+          sortOrderBuilder.setChild(childExpr.get)
 
-            val timeZone = timeZoneId.getOrElse("UTC")
-            builder.setTimezone(timeZone)
+          direction match {
+            case Ascending => sortOrderBuilder.setDirectionValue(0)
+            case Descending => sortOrderBuilder.setDirectionValue(1)
+          }
 
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setSecond(builder)
-                .build())
-          } else {
-            withInfo(expr, child)
-            None
+          nullOrdering match {
+            case NullsFirst => sortOrderBuilder.setNullOrderingValue(0)
+            case NullsLast => sortOrderBuilder.setNullOrderingValue(1)
           }
 
-        case Year(child) =>
-          val periodType = exprToProtoInternal(Literal("year"), inputs)
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("datepart", Seq(periodType, childExpr): _*)
-            .map(e => {
-              Expr
-                .newBuilder()
-                .setCast(
-                  ExprOuterClass.Cast
-                    .newBuilder()
-                    .setChild(e)
-                    .setDatatype(serializeDataType(IntegerType).get)
-                    .setEvalMode(ExprOuterClass.EvalMode.LEGACY)
-                    .setAllowIncompat(false)
-                    .build())
-                .build()
-            })
-          optExprWithInfo(optExpr, expr, child)
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setSortOrder(sortOrderBuilder)
+              .build())
+        } else {
+          withInfo(expr, child)
+          None
+        }
 
-        case IsNull(child) =>
-          createUnaryExpr(child, inputs, (builder, unaryExpr) => builder.setIsNull(unaryExpr))
+      case And(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setAnd(binaryExpr))
+
+      case Or(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setOr(binaryExpr))
+
+      case UnaryExpression(child) if expr.prettyName == "promote_precision" =>
+        // `UnaryExpression` includes `PromotePrecision` for Spark 3.3
+        // `PromotePrecision` is just a wrapper, don't need to serialize it.
+        exprToProtoInternal(child, inputs, binding)
+
+      case CheckOverflow(child, dt, nullOnOverflow) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
 
-        case IsNotNull(child) =>
-          createUnaryExpr(child, inputs, (builder, unaryExpr) => builder.setIsNotNull(unaryExpr))
+        if (childExpr.isDefined) {
+          val builder = ExprOuterClass.CheckOverflow.newBuilder()
+          builder.setChild(childExpr.get)
+          builder.setFailOnError(!nullOnOverflow)
 
-        case IsNaN(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr =
-            scalarExprToProtoWithReturnType("isnan", BooleanType, childExpr)
+          // `dataType` must be decimal type
+          val dataType = serializeDataType(dt)
+          builder.setDatatype(dataType.get)
 
-          optExprWithInfo(optExpr, expr, child)
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setCheckOverflow(builder)
+              .build())
+        } else {
+          withInfo(expr, child)
+          None
+        }
 
-        case SortOrder(child, direction, nullOrdering, _) =>
-          val childExpr = exprToProtoInternal(child, inputs)
+      case attr: AttributeReference =>
+        val dataType = serializeDataType(attr.dataType)
 
-          if (childExpr.isDefined) {
-            val sortOrderBuilder = ExprOuterClass.SortOrder.newBuilder()
-            sortOrderBuilder.setChild(childExpr.get)
+        if (dataType.isDefined) {
+          if (binding) {
+            // Spark may produce unresolvable attributes in some cases,
+            // for example https://github.com/apache/datafusion-comet/issues/925.
+            // So, we allow the binding to fail.
+            val boundRef: Any = BindReferences
+              .bindReference(attr, inputs, allowFailures = true)
 
-            direction match {
-              case Ascending => sortOrderBuilder.setDirectionValue(0)
-              case Descending => sortOrderBuilder.setDirectionValue(1)
+            if (boundRef.isInstanceOf[AttributeReference]) {
+              withInfo(attr, s"cannot resolve $attr among ${inputs.mkString(", ")}")
+              return None
             }
 
-            nullOrdering match {
-              case NullsFirst => sortOrderBuilder.setNullOrderingValue(0)
-              case NullsLast => sortOrderBuilder.setNullOrderingValue(1)
-            }
+            val boundExpr = ExprOuterClass.BoundReference
+              .newBuilder()
+              .setIndex(boundRef.asInstanceOf[BoundReference].ordinal)
+              .setDatatype(dataType.get)
+              .build()
 
             Some(
               ExprOuterClass.Expr
                 .newBuilder()
-                .setSortOrder(sortOrderBuilder)
+                .setBound(boundExpr)
                 .build())
           } else {
-            withInfo(expr, child)
-            None
-          }
-
-        case And(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setAnd(binaryExpr))
-
-        case Or(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setOr(binaryExpr))
-
-        case UnaryExpression(child) if expr.prettyName == "promote_precision" =>
-          // `UnaryExpression` includes `PromotePrecision` for Spark 3.3
-          // `PromotePrecision` is just a wrapper, don't need to serialize it.
-          exprToProtoInternal(child, inputs)
-
-        case CheckOverflow(child, dt, nullOnOverflow) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-
-          if (childExpr.isDefined) {
-            val builder = ExprOuterClass.CheckOverflow.newBuilder()
-            builder.setChild(childExpr.get)
-            builder.setFailOnError(!nullOnOverflow)
-
-            // `dataType` must be decimal type
-            val dataType = serializeDataType(dt)
-            builder.setDatatype(dataType.get)
+            val unboundRef = ExprOuterClass.UnboundReference
+              .newBuilder()
+              .setName(attr.name)
+              .setDatatype(dataType.get)
+              .build()
 
             Some(
               ExprOuterClass.Expr
                 .newBuilder()
-                .setCheckOverflow(builder)
+                .setUnbound(unboundRef)
                 .build())
-          } else {
-            withInfo(expr, child)
-            None
-          }
-
-        case attr: AttributeReference =>
-          val dataType = serializeDataType(attr.dataType)
-
-          if (dataType.isDefined) {
-            if (binding) {
-              // Spark may produce unresolvable attributes in some cases,
-              // for example https://github.com/apache/datafusion-comet/issues/925.
-              // So, we allow the binding to fail.
-              val boundRef: Any = BindReferences
-                .bindReference(attr, inputs, allowFailures = true)
-
-              if (boundRef.isInstanceOf[AttributeReference]) {
-                withInfo(attr, s"cannot resolve $attr among ${inputs.mkString(", ")}")
-                return None
-              }
-
-              val boundExpr = ExprOuterClass.BoundReference
-                .newBuilder()
-                .setIndex(boundRef.asInstanceOf[BoundReference].ordinal)
-                .setDatatype(dataType.get)
-                .build()
-
-              Some(
-                ExprOuterClass.Expr
-                  .newBuilder()
-                  .setBound(boundExpr)
-                  .build())
-            } else {
-              val unboundRef = ExprOuterClass.UnboundReference
-                .newBuilder()
-                .setName(attr.name)
-                .setDatatype(dataType.get)
-                .build()
-
-              Some(
-                ExprOuterClass.Expr
-                  .newBuilder()
-                  .setUnbound(unboundRef)
-                  .build())
-            }
-          } else {
-            withInfo(attr, s"unsupported datatype: ${attr.dataType}")
-            None
           }
+        } else {
+          withInfo(attr, s"unsupported datatype: ${attr.dataType}")
+          None
+        }
 
-        // abs implementation is not correct
-        // https://github.com/apache/datafusion-comet/issues/666
+      // abs implementation is not correct
+      // https://github.com/apache/datafusion-comet/issues/666
 //        case Abs(child, failOnErr) =>
 //          val childExpr = exprToProtoInternal(child, inputs)
 //          if (childExpr.isDefined) {
@@ -1557,950 +1624,967 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
 //            None
 //          }
 
-        case Acos(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("acos", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Asin(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("asin", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Atan(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("atan", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Atan2(left, right) =>
-          val leftExpr = exprToProtoInternal(left, inputs)
-          val rightExpr = exprToProtoInternal(right, inputs)
-          val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr)
-          optExprWithInfo(optExpr, expr, left, right)
-
-        case Hex(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr =
-            scalarExprToProtoWithReturnType("hex", StringType, childExpr)
-
-          optExprWithInfo(optExpr, expr, child)
-
-        case e: Unhex =>
-          val unHex = unhexSerde(e)
-
-          val childExpr = exprToProtoInternal(unHex._1, inputs)
-          val failOnErrorExpr = exprToProtoInternal(unHex._2, inputs)
-
-          val optExpr =
-            scalarExprToProtoWithReturnType("unhex", e.dataType, childExpr, failOnErrorExpr)
-          optExprWithInfo(optExpr, expr, unHex._1)
-
-        case e @ Ceil(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          child.dataType match {
-            case t: DecimalType if t.scale == 0 => // zero scale is no-op
-              childExpr
-            case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
-              withInfo(e, s"Decimal type $t has negative scale")
-              None
-            case _ =>
-              val optExpr = scalarExprToProtoWithReturnType("ceil", e.dataType, childExpr)
-              optExprWithInfo(optExpr, expr, child)
-          }
-
-        case Cos(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("cos", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Exp(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("exp", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case e @ Floor(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          child.dataType match {
-            case t: DecimalType if t.scale == 0 => // zero scale is no-op
-              childExpr
-            case t: DecimalType if t.scale < 0 => // Spark disallows negative scale SPARK-30252
-              withInfo(e, s"Decimal type $t has negative scale")
-              None
-            case _ =>
-              val optExpr = scalarExprToProtoWithReturnType("floor", e.dataType, childExpr)
-              optExprWithInfo(optExpr, expr, child)
-          }
-
-        // The expression for `log` functions is defined as null on numbers less than or equal
-        // to 0. This matches Spark and Hive behavior, where non positive values eval to null
-        // instead of NaN or -Infinity.
-        case Log(child) =>
-          val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
-          val optExpr = scalarExprToProto("ln", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Log10(child) =>
-          val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
-          val optExpr = scalarExprToProto("log10", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Log2(child) =>
-          val childExpr = exprToProtoInternal(nullIfNegative(child), inputs)
-          val optExpr = scalarExprToProto("log2", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Pow(left, right) =>
-          val leftExpr = exprToProtoInternal(left, inputs)
-          val rightExpr = exprToProtoInternal(right, inputs)
-          val optExpr = scalarExprToProto("pow", leftExpr, rightExpr)
-          optExprWithInfo(optExpr, expr, left, right)
-
-        case r: Round =>
-          // _scale is a constant, copied from Spark's RoundBase because it is a protected val
-          val scaleV: Any = r.scale.eval(EmptyRow)
-          val _scale: Int = scaleV.asInstanceOf[Int]
-
-          lazy val childExpr = exprToProtoInternal(r.child, inputs)
-          r.child.dataType match {
-            case t: DecimalType if t.scale < 0 => // Spark disallows negative 
scale SPARK-30252
-              withInfo(r, "Decimal type has negative scale")
-              None
-            case _ if scaleV == null =>
-              exprToProtoInternal(Literal(null), inputs)
-            case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 =>
-              childExpr // _scale (i.e. decimal places) >= 0 is a no-op for integer types in Spark
-            case _: FloatType | DoubleType =>
-              // We cannot properly match the Spark behavior for floating-point numbers.
-              // Spark uses BigDecimal for rounding float/double, and BigDecimal first converts a
-              // double to a string internally in order to create its own internal representation.
-              // The problem is that BigDecimal uses java.lang.Double.toString(), which has a
-              // complicated rounding algorithm. E.g. -5.81855622136895E8 is actually
-              // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of
-              // 5. Java (Scala)'s toString() rounds it up to -581855622.136895. This makes a
-              // difference when rounding at the 5th digit, i.e. round(-5.81855622136895E8, 5)
-              // should be -5.818556221369E8 instead of -5.8185562213689E8. There is also an
-              // example where toString() does NOT round up: 6.1317116247283497E18 is
-              // 6131711624728349696. It can be rounded up to 6.13171162472835E18, which still
-              // represents the same double number, i.e. 6.13171162472835E18 == 6.1317116247283497E18.
-              // However, toString() does not round up, which results in
-              // round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead of
-              // 6.1317116247283999E18.
-              withInfo(r, "Comet does not support Spark's BigDecimal rounding")
-              None
-            case _ =>
-              // `scale` must be Int64 type in DataFusion
-              val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, 
LongType), inputs)
-              val optExpr =
-                scalarExprToProtoWithReturnType("round", r.dataType, 
childExpr, scaleExpr)
-              optExprWithInfo(optExpr, expr, r.child)
-          }
+      case Acos(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr = scalarExprToProto("acos", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Asin(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr = scalarExprToProto("asin", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Atan(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr = scalarExprToProto("atan", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Atan2(left, right) =>
+        val leftExpr = exprToProtoInternal(left, inputs, binding)
+        val rightExpr = exprToProtoInternal(right, inputs, binding)
+        val optExpr = scalarExprToProto("atan2", leftExpr, rightExpr)
+        optExprWithInfo(optExpr, expr, left, right)
+
+      case Hex(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr =
+          scalarExprToProtoWithReturnType("hex", StringType, childExpr)
+
+        optExprWithInfo(optExpr, expr, child)
+
+      case e: Unhex =>
+        val unHex = unhexSerde(e)
+
+        val childExpr = exprToProtoInternal(unHex._1, inputs, binding)
+        val failOnErrorExpr = exprToProtoInternal(unHex._2, inputs, binding)
+
+        val optExpr =
+          scalarExprToProtoWithReturnType("unhex", e.dataType, childExpr, 
failOnErrorExpr)
+        optExprWithInfo(optExpr, expr, unHex._1)
+
+      case e @ Ceil(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        child.dataType match {
+          case t: DecimalType if t.scale == 0 => // zero scale is no-op
+            childExpr
+          case t: DecimalType if t.scale < 0 => // Spark disallows negative 
scale SPARK-30252
+            withInfo(e, s"Decimal type $t has negative scale")
+            None
+          case _ =>
+            val optExpr = scalarExprToProtoWithReturnType("ceil", e.dataType, 
childExpr)
+            optExprWithInfo(optExpr, expr, child)
+        }
+
+      case Cos(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr = scalarExprToProto("cos", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Exp(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr = scalarExprToProto("exp", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case e @ Floor(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        child.dataType match {
+          case t: DecimalType if t.scale == 0 => // zero scale is no-op
+            childExpr
+          case t: DecimalType if t.scale < 0 => // Spark disallows negative 
scale SPARK-30252
+            withInfo(e, s"Decimal type $t has negative scale")
+            None
+          case _ =>
+            val optExpr = scalarExprToProtoWithReturnType("floor", e.dataType, 
childExpr)
+            optExprWithInfo(optExpr, expr, child)
+        }
+
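
As an aside for reviewers: the negative-scale guard in the Ceil and Floor
cases above mirrors SPARK-30252, under which Spark itself rejects such
decimals unless a legacy flag is set. A minimal spark-shell probe, using
the standard Spark config rather than anything from this commit:

    // Fails with AnalysisException by default since Spark 3.0 (SPARK-30252):
    spark.sql("SELECT CAST(1 AS DECIMAL(2, -2))")
    // Only permitted with spark.sql.legacy.allowNegativeScaleOfDecimal=true.
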
+      // The expression for `log` functions is defined as null on numbers less than or equal
+      // to 0. This matches Spark and Hive behavior, where non-positive values evaluate to null
+      // instead of NaN or -Infinity.
+      case Log(child) =>
+        val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, 
binding)
+        val optExpr = scalarExprToProto("ln", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Log10(child) =>
+        val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, 
binding)
+        val optExpr = scalarExprToProto("log10", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Log2(child) =>
+        val childExpr = exprToProtoInternal(nullIfNegative(child), inputs, 
binding)
+        val optExpr = scalarExprToProto("log2", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
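
For context on the `nullIfNegative` wrapper used by the three log cases
above, a quick spark-shell check of the behavior the comment describes
(standard Spark semantics, not part of this commit):

    spark.sql("SELECT ln(0), ln(-1), log10(-1)").show()
    // All three columns are NULL in Spark/Hive, rather than -Infinity or NaN.
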
+      case Pow(left, right) =>
+        val leftExpr = exprToProtoInternal(left, inputs, binding)
+        val rightExpr = exprToProtoInternal(right, inputs, binding)
+        val optExpr = scalarExprToProto("pow", leftExpr, rightExpr)
+        optExprWithInfo(optExpr, expr, left, right)
+
+      case r: Round =>
+        // _scale is a constant, copied from Spark's RoundBase because it is a protected val
+        val scaleV: Any = r.scale.eval(EmptyRow)
+        val _scale: Int = scaleV.asInstanceOf[Int]
+
+        lazy val childExpr = exprToProtoInternal(r.child, inputs, binding)
+        r.child.dataType match {
+          case t: DecimalType if t.scale < 0 => // Spark disallows negative 
scale SPARK-30252
+            withInfo(r, "Decimal type has negative scale")
+            None
+          case _ if scaleV == null =>
+            exprToProtoInternal(Literal(null), inputs, binding)
+          case _: ByteType | ShortType | IntegerType | LongType if _scale >= 0 =>
+            childExpr // _scale (i.e. decimal places) >= 0 is a no-op for integer types in Spark
+          case _: FloatType | DoubleType =>
+            // We cannot properly match the Spark behavior for floating-point numbers.
+            // Spark uses BigDecimal for rounding float/double, and BigDecimal first converts a
+            // double to a string internally in order to create its own internal representation.
+            // The problem is that BigDecimal uses java.lang.Double.toString(), which has a
+            // complicated rounding algorithm. E.g. -5.81855622136895E8 is actually
+            // -581855622.13689494132995605468750. Note the 5th fractional digit is 4 instead of
+            // 5. Java (Scala)'s toString() rounds it up to -581855622.136895. This makes a
+            // difference when rounding at the 5th digit, i.e. round(-5.81855622136895E8, 5)
+            // should be -5.818556221369E8 instead of -5.8185562213689E8. There is also an
+            // example where toString() does NOT round up: 6.1317116247283497E18 is
+            // 6131711624728349696. It can be rounded up to 6.13171162472835E18, which still
+            // represents the same double number, i.e. 6.13171162472835E18 == 6.1317116247283497E18.
+            // However, toString() does not round up, which results in
+            // round(6.1317116247283497E18, -5) == 6.1317116247282995E18 instead of
+            // 6.1317116247283999E18.
+            withInfo(r, "Comet does not support Spark's BigDecimal rounding")
+            None
+          case _ =>
+            // `scale` must be Int64 type in DataFusion
+            val scaleExpr = exprToProtoInternal(Literal(_scale.toLong, 
LongType), inputs, binding)
+            val optExpr =
+              scalarExprToProtoWithReturnType("round", r.dataType, childExpr, 
scaleExpr)
+            optExprWithInfo(optExpr, expr, r.child)
+        }
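
A short REPL illustration of the Double.toString() behavior described in
the comment above, using the same value (explanatory only, not part of
this commit):

    // Exact binary value (per the comment above): -581855622.13689494...
    new java.math.BigDecimal(-5.81855622136895E8)
    // Via Double.toString(), which rounds: -581855622.136895
    java.math.BigDecimal.valueOf(-5.81855622136895E8)
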
 
-        // TODO enable once https://github.com/apache/datafusion/issues/11557 
is fixed or
-        // when we have a Spark-compatible version implemented in Comet
+      // TODO enable once https://github.com/apache/datafusion/issues/11557 is 
fixed or
+      // when we have a Spark-compatible version implemented in Comet
 //        case Signum(child) =>
 //          val childExpr = exprToProtoInternal(child, inputs)
 //          val optExpr = scalarExprToProto("signum", childExpr)
 //          optExprWithInfo(optExpr, expr, child)
 
-        case Sin(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("sin", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Sqrt(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("sqrt", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Tan(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("tan", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case Ascii(child) =>
-          val castExpr = Cast(child, StringType)
-          val childExpr = exprToProtoInternal(castExpr, inputs)
-          val optExpr = scalarExprToProto("ascii", childExpr)
-          optExprWithInfo(optExpr, expr, castExpr)
-
-        case BitLength(child) =>
-          val castExpr = Cast(child, StringType)
-          val childExpr = exprToProtoInternal(castExpr, inputs)
-          val optExpr = scalarExprToProto("bit_length", childExpr)
-          optExprWithInfo(optExpr, expr, castExpr)
-
-        case If(predicate, trueValue, falseValue) =>
-          val predicateExpr = exprToProtoInternal(predicate, inputs)
-          val trueExpr = exprToProtoInternal(trueValue, inputs)
-          val falseExpr = exprToProtoInternal(falseValue, inputs)
-          if (predicateExpr.isDefined && trueExpr.isDefined && 
falseExpr.isDefined) {
-            val builder = ExprOuterClass.IfExpr.newBuilder()
-            builder.setIfExpr(predicateExpr.get)
-            builder.setTrueExpr(trueExpr.get)
-            builder.setFalseExpr(falseExpr.get)
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setIf(builder)
-                .build())
-          } else {
-            withInfo(expr, predicate, trueValue, falseValue)
-            None
-          }
+      case Sin(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr = scalarExprToProto("sin", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Sqrt(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr = scalarExprToProto("sqrt", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Tan(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr = scalarExprToProto("tan", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case Ascii(child) =>
+        val castExpr = Cast(child, StringType)
+        val childExpr = exprToProtoInternal(castExpr, inputs, binding)
+        val optExpr = scalarExprToProto("ascii", childExpr)
+        optExprWithInfo(optExpr, expr, castExpr)
+
+      case BitLength(child) =>
+        val castExpr = Cast(child, StringType)
+        val childExpr = exprToProtoInternal(castExpr, inputs, binding)
+        val optExpr = scalarExprToProto("bit_length", childExpr)
+        optExprWithInfo(optExpr, expr, castExpr)
+
+      case If(predicate, trueValue, falseValue) =>
+        val predicateExpr = exprToProtoInternal(predicate, inputs, binding)
+        val trueExpr = exprToProtoInternal(trueValue, inputs, binding)
+        val falseExpr = exprToProtoInternal(falseValue, inputs, binding)
+        if (predicateExpr.isDefined && trueExpr.isDefined && 
falseExpr.isDefined) {
+          val builder = ExprOuterClass.IfExpr.newBuilder()
+          builder.setIfExpr(predicateExpr.get)
+          builder.setTrueExpr(trueExpr.get)
+          builder.setFalseExpr(falseExpr.get)
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setIf(builder)
+              .build())
+        } else {
+          withInfo(expr, predicate, trueValue, falseValue)
+          None
+        }
 
-        case CaseWhen(branches, elseValue) =>
-          var allBranches: Seq[Expression] = Seq()
-          val whenSeq = branches.map(elements => {
-            allBranches = allBranches :+ elements._1
-            exprToProtoInternal(elements._1, inputs)
-          })
-          val thenSeq = branches.map(elements => {
-            allBranches = allBranches :+ elements._2
-            exprToProtoInternal(elements._2, inputs)
-          })
-          assert(whenSeq.length == thenSeq.length)
-          if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined)) {
-            val builder = ExprOuterClass.CaseWhen.newBuilder()
-            builder.addAllWhen(whenSeq.map(_.get).asJava)
-            builder.addAllThen(thenSeq.map(_.get).asJava)
-            if (elseValue.isDefined) {
-              val elseValueExpr =
-                exprToProtoInternal(elseValue.get, inputs)
-              if (elseValueExpr.isDefined) {
-                builder.setElseExpr(elseValueExpr.get)
-              } else {
-                withInfo(expr, elseValue.get)
-                return None
-              }
+      case CaseWhen(branches, elseValue) =>
+        var allBranches: Seq[Expression] = Seq()
+        val whenSeq = branches.map(elements => {
+          allBranches = allBranches :+ elements._1
+          exprToProtoInternal(elements._1, inputs, binding)
+        })
+        val thenSeq = branches.map(elements => {
+          allBranches = allBranches :+ elements._2
+          exprToProtoInternal(elements._2, inputs, binding)
+        })
+        assert(whenSeq.length == thenSeq.length)
+        if (whenSeq.forall(_.isDefined) && thenSeq.forall(_.isDefined)) {
+          val builder = ExprOuterClass.CaseWhen.newBuilder()
+          builder.addAllWhen(whenSeq.map(_.get).asJava)
+          builder.addAllThen(thenSeq.map(_.get).asJava)
+          if (elseValue.isDefined) {
+            val elseValueExpr =
+              exprToProtoInternal(elseValue.get, inputs, binding)
+            if (elseValueExpr.isDefined) {
+              builder.setElseExpr(elseValueExpr.get)
+            } else {
+              withInfo(expr, elseValue.get)
+              return None
             }
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setCaseWhen(builder)
-                .build())
-          } else {
-            withInfo(expr, allBranches: _*)
-            None
-          }
-        case ConcatWs(children) =>
-          var childExprs: Seq[Expression] = Seq()
-          val exprs = children.map(e => {
-            val castExpr = Cast(e, StringType)
-            childExprs = childExprs :+ castExpr
-            exprToProtoInternal(castExpr, inputs)
-          })
-          val optExpr = scalarExprToProto("concat_ws", exprs: _*)
-          optExprWithInfo(optExpr, expr, childExprs: _*)
-
-        case Chr(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("chr", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case InitCap(child) =>
-          if (CometConf.COMET_EXEC_INITCAP_ENABLED.get()) {
-            val castExpr = Cast(child, StringType)
-            val childExpr = exprToProtoInternal(castExpr, inputs)
-            val optExpr = scalarExprToProto("initcap", childExpr)
-            optExprWithInfo(optExpr, expr, castExpr)
-          } else {
-            withInfo(
-              expr,
-              "Comet initCap is not compatible with Spark yet. " +
-                "See https://github.com/apache/datafusion-comet/issues/1052 ." 
+
-                s"Set ${CometConf.COMET_EXEC_INITCAP_ENABLED.key}=true to 
enable it anyway.")
-            None
           }
-
-        case Length(child) =>
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setCaseWhen(builder)
+              .build())
+        } else {
+          withInfo(expr, allBranches: _*)
+          None
+        }
+      case ConcatWs(children) =>
+        var childExprs: Seq[Expression] = Seq()
+        val exprs = children.map(e => {
+          val castExpr = Cast(e, StringType)
+          childExprs = childExprs :+ castExpr
+          exprToProtoInternal(castExpr, inputs, binding)
+        })
+        val optExpr = scalarExprToProto("concat_ws", exprs: _*)
+        optExprWithInfo(optExpr, expr, childExprs: _*)
+
+      case Chr(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr = scalarExprToProto("chr", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case InitCap(child) =>
+        if (CometConf.COMET_EXEC_INITCAP_ENABLED.get()) {
           val castExpr = Cast(child, StringType)
-          val childExpr = exprToProtoInternal(castExpr, inputs)
-          val optExpr = scalarExprToProto("length", childExpr)
+          val childExpr = exprToProtoInternal(castExpr, inputs, binding)
+          val optExpr = scalarExprToProto("initcap", childExpr)
           optExprWithInfo(optExpr, expr, castExpr)
+        } else {
+          withInfo(
+            expr,
+            "Comet initCap is not compatible with Spark yet. " +
+              "See https://github.com/apache/datafusion-comet/issues/1052 ." +
+              s"Set ${CometConf.COMET_EXEC_INITCAP_ENABLED.key}=true to enable 
it anyway.")
+          None
+        }
 
-        case Md5(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProto("md5", childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case OctetLength(child) =>
+      case Length(child) =>
+        val castExpr = Cast(child, StringType)
+        val childExpr = exprToProtoInternal(castExpr, inputs, binding)
+        val optExpr = scalarExprToProto("length", childExpr)
+        optExprWithInfo(optExpr, expr, castExpr)
+
+      case Md5(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr = scalarExprToProto("md5", childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case OctetLength(child) =>
+        val castExpr = Cast(child, StringType)
+        val childExpr = exprToProtoInternal(castExpr, inputs, binding)
+        val optExpr = scalarExprToProto("octet_length", childExpr)
+        optExprWithInfo(optExpr, expr, castExpr)
+
+      case Reverse(child) =>
+        val castExpr = Cast(child, StringType)
+        val childExpr = exprToProtoInternal(castExpr, inputs, binding)
+        val optExpr = scalarExprToProto("reverse", childExpr)
+        optExprWithInfo(optExpr, expr, castExpr)
+
+      case StringInstr(str, substr) =>
+        val leftCast = Cast(str, StringType)
+        val rightCast = Cast(substr, StringType)
+        val leftExpr = exprToProtoInternal(leftCast, inputs, binding)
+        val rightExpr = exprToProtoInternal(rightCast, inputs, binding)
+        val optExpr = scalarExprToProto("strpos", leftExpr, rightExpr)
+        optExprWithInfo(optExpr, expr, leftCast, rightCast)
+
+      case StringRepeat(str, times) =>
+        val leftCast = Cast(str, StringType)
+        val rightCast = Cast(times, LongType)
+        val leftExpr = exprToProtoInternal(leftCast, inputs, binding)
+        val rightExpr = exprToProtoInternal(rightCast, inputs, binding)
+        val optExpr = scalarExprToProto("repeat", leftExpr, rightExpr)
+        optExprWithInfo(optExpr, expr, leftCast, rightCast)
+
+      case StringReplace(src, search, replace) =>
+        val srcCast = Cast(src, StringType)
+        val searchCast = Cast(search, StringType)
+        val replaceCast = Cast(replace, StringType)
+        val srcExpr = exprToProtoInternal(srcCast, inputs, binding)
+        val searchExpr = exprToProtoInternal(searchCast, inputs, binding)
+        val replaceExpr = exprToProtoInternal(replaceCast, inputs, binding)
+        val optExpr = scalarExprToProto("replace", srcExpr, searchExpr, 
replaceExpr)
+        optExprWithInfo(optExpr, expr, srcCast, searchCast, replaceCast)
+
+      case StringTranslate(src, matching, replace) =>
+        val srcCast = Cast(src, StringType)
+        val matchingCast = Cast(matching, StringType)
+        val replaceCast = Cast(replace, StringType)
+        val srcExpr = exprToProtoInternal(srcCast, inputs, binding)
+        val matchingExpr = exprToProtoInternal(matchingCast, inputs, binding)
+        val replaceExpr = exprToProtoInternal(replaceCast, inputs, binding)
+        val optExpr = scalarExprToProto("translate", srcExpr, matchingExpr, 
replaceExpr)
+        optExprWithInfo(optExpr, expr, srcCast, matchingCast, replaceCast)
+
+      case StringTrim(srcStr, trimStr) =>
+        trim(expr, srcStr, trimStr, inputs, binding, "trim")
+
+      case StringTrimLeft(srcStr, trimStr) =>
+        trim(expr, srcStr, trimStr, inputs, binding, "ltrim")
+
+      case StringTrimRight(srcStr, trimStr) =>
+        trim(expr, srcStr, trimStr, inputs, binding, "rtrim")
+
+      case StringTrimBoth(srcStr, trimStr, _) =>
+        trim(expr, srcStr, trimStr, inputs, binding, "btrim")
+
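
For quick reference, the four trim cases above differ only in the native
function name passed to the shared `trim` helper:

    // StringTrim      -> "trim"
    // StringTrimLeft  -> "ltrim"
    // StringTrimRight -> "rtrim"
    // StringTrimBoth  -> "btrim"
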
+      case Upper(child) =>
+        if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) {
           val castExpr = Cast(child, StringType)
-          val childExpr = exprToProtoInternal(castExpr, inputs)
-          val optExpr = scalarExprToProto("octet_length", childExpr)
+          val childExpr = exprToProtoInternal(castExpr, inputs, binding)
+          val optExpr = scalarExprToProto("upper", childExpr)
           optExprWithInfo(optExpr, expr, castExpr)
+        } else {
+          withInfo(
+            expr,
+            "Comet is not compatible with Spark for case conversion in " +
+              s"locale-specific cases. Set 
${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " +
+              "to enable it anyway.")
+          None
+        }
 
-        case Reverse(child) =>
+      case Lower(child) =>
+        if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) {
           val castExpr = Cast(child, StringType)
-          val childExpr = exprToProtoInternal(castExpr, inputs)
-          val optExpr = scalarExprToProto("reverse", childExpr)
+          val childExpr = exprToProtoInternal(castExpr, inputs, binding)
+          val optExpr = scalarExprToProto("lower", childExpr)
           optExprWithInfo(optExpr, expr, castExpr)
+        } else {
+          withInfo(
+            expr,
+            "Comet is not compatible with Spark for case conversion in " +
+              s"locale-specific cases. Set 
${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " +
+              "to enable it anyway.")
+          None
+        }
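
The locale sensitivity guarded against in the Upper and Lower cases above
is the usual JVM one; a small illustration using plain Java locale
behavior (not Comet code):

    import java.util.Locale
    "I".toLowerCase(Locale.forLanguageTag("tr")) // "ı" (Turkish dotless i)
    "I".toLowerCase(Locale.ROOT)                 // "i"
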
 
-        case StringInstr(str, substr) =>
-          val leftCast = Cast(str, StringType)
-          val rightCast = Cast(substr, StringType)
-          val leftExpr = exprToProtoInternal(leftCast, inputs)
-          val rightExpr = exprToProtoInternal(rightCast, inputs)
-          val optExpr = scalarExprToProto("strpos", leftExpr, rightExpr)
-          optExprWithInfo(optExpr, expr, leftCast, rightCast)
-
-        case StringRepeat(str, times) =>
-          val leftCast = Cast(str, StringType)
-          val rightCast = Cast(times, LongType)
-          val leftExpr = exprToProtoInternal(leftCast, inputs)
-          val rightExpr = exprToProtoInternal(rightCast, inputs)
-          val optExpr = scalarExprToProto("repeat", leftExpr, rightExpr)
-          optExprWithInfo(optExpr, expr, leftCast, rightCast)
-
-        case StringReplace(src, search, replace) =>
-          val srcCast = Cast(src, StringType)
-          val searchCast = Cast(search, StringType)
-          val replaceCast = Cast(replace, StringType)
-          val srcExpr = exprToProtoInternal(srcCast, inputs)
-          val searchExpr = exprToProtoInternal(searchCast, inputs)
-          val replaceExpr = exprToProtoInternal(replaceCast, inputs)
-          val optExpr = scalarExprToProto("replace", srcExpr, searchExpr, 
replaceExpr)
-          optExprWithInfo(optExpr, expr, srcCast, searchCast, replaceCast)
-
-        case StringTranslate(src, matching, replace) =>
-          val srcCast = Cast(src, StringType)
-          val matchingCast = Cast(matching, StringType)
-          val replaceCast = Cast(replace, StringType)
-          val srcExpr = exprToProtoInternal(srcCast, inputs)
-          val matchingExpr = exprToProtoInternal(matchingCast, inputs)
-          val replaceExpr = exprToProtoInternal(replaceCast, inputs)
-          val optExpr = scalarExprToProto("translate", srcExpr, matchingExpr, 
replaceExpr)
-          optExprWithInfo(optExpr, expr, srcCast, matchingCast, replaceCast)
-
-        case StringTrim(srcStr, trimStr) =>
-          trim(expr, srcStr, trimStr, inputs, "trim")
-
-        case StringTrimLeft(srcStr, trimStr) =>
-          trim(expr, srcStr, trimStr, inputs, "ltrim")
-
-        case StringTrimRight(srcStr, trimStr) =>
-          trim(expr, srcStr, trimStr, inputs, "rtrim")
-
-        case StringTrimBoth(srcStr, trimStr, _) =>
-          trim(expr, srcStr, trimStr, inputs, "btrim")
-
-        case Upper(child) =>
-          if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) {
-            val castExpr = Cast(child, StringType)
-            val childExpr = exprToProtoInternal(castExpr, inputs)
-            val optExpr = scalarExprToProto("upper", childExpr)
-            optExprWithInfo(optExpr, expr, castExpr)
-          } else {
-            withInfo(
-              expr,
-              "Comet is not compatible with Spark for case conversion in " +
-                s"locale-specific cases. Set 
${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " +
-                "to enable it anyway.")
-            None
-          }
-
-        case Lower(child) =>
-          if (CometConf.COMET_CASE_CONVERSION_ENABLED.get()) {
-            val castExpr = Cast(child, StringType)
-            val childExpr = exprToProtoInternal(castExpr, inputs)
-            val optExpr = scalarExprToProto("lower", childExpr)
-            optExprWithInfo(optExpr, expr, castExpr)
-          } else {
-            withInfo(
-              expr,
-              "Comet is not compatible with Spark for case conversion in " +
-                s"locale-specific cases. Set 
${CometConf.COMET_CASE_CONVERSION_ENABLED.key}=true " +
-                "to enable it anyway.")
-            None
-          }
-
-        case BitwiseAnd(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
-
-        case BitwiseNot(child) =>
-          createUnaryExpr(child, inputs, (builder, unaryExpr) => 
builder.setBitwiseNot(unaryExpr))
-
-        case BitwiseOr(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setBitwiseOr(binaryExpr))
+      case BitwiseAnd(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseAnd(binaryExpr))
+
+      case BitwiseNot(child) =>
+        createUnaryExpr(
+          expr,
+          child,
+          inputs,
+          binding,
+          (builder, unaryExpr) => builder.setBitwiseNot(unaryExpr))
+
+      case BitwiseOr(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseOr(binaryExpr))
+
+      case BitwiseXor(left, right) =>
+        createBinaryExpr(
+          expr,
+          left,
+          right,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
+
+      case ShiftRight(left, right) =>
+        // DataFusion's bitwise shift right expression requires the
+        // same data type for the left and right side
+        val rightExpression = if (left.dataType == LongType) {
+          Cast(right, LongType)
+        } else {
+          right
+        }
+
+        createBinaryExpr(
+          expr,
+          left,
+          rightExpression,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseShiftRight(binaryExpr))
+
+      case ShiftLeft(left, right) =>
+        // DataFusion's bitwise shift left expression requires the
+        // same data type for the left and right side
+        val rightExpression = if (left.dataType == LongType) {
+          Cast(right, LongType)
+        } else {
+          right
+        }
+
+        createBinaryExpr(
+          expr,
+          left,
+          rightExpression,
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setBitwiseShiftLeft(binaryExpr))
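
The cast inserted for ShiftLeft/ShiftRight above exists because Spark
types the shift amount as IntegerType regardless of the base type; a
sketch of the rewrite using Spark's own Literal/Cast (illustrative only):

    // shiftleft(col: LongType, 2) arrives as children (LongType, IntegerType);
    // the rewrite widens the right side so DataFusion sees matching Int64s:
    val amount: Expression = Literal(2)   // IntegerType in Spark
    val widened = Cast(amount, LongType)  // applied only when the left is LongType
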
+      case In(value, list) =>
+        in(expr, value, list, inputs, binding, negate = false)
+
+      case InSet(value, hset) =>
+        val valueDataType = value.dataType
+        val list = hset.map { setVal =>
+          Literal(setVal, valueDataType)
+        }.toSeq
+        // Change the `InSet` expression to an `In` expression.
+        // We do the Spark `InSet` optimization on the native (DataFusion) side.
+        in(expr, value, list, inputs, binding, negate = false)
+
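
To make the `InSet` rewrite above concrete: the hash set is simply
re-materialized as literal children so both paths share a single native
conversion (sketch, names as in the diff):

    // InSet(col, Set(1, 2, 3)) becomes, before serialization:
    // In(col, Seq(Literal(1, IntegerType), Literal(2, IntegerType), Literal(3, IntegerType)))
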
+      case Not(In(value, list)) =>
+        in(expr, value, list, inputs, binding, negate = true)
+
+      case Not(child) =>
+        createUnaryExpr(
+          expr,
+          child,
+          inputs,
+          binding,
+          (builder, unaryExpr) => builder.setNot(unaryExpr))
+
+      case UnaryMinus(child, failOnError) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        if (childExpr.isDefined) {
+          val builder = ExprOuterClass.UnaryMinus.newBuilder()
+          builder.setChild(childExpr.get)
+          builder.setFailOnError(failOnError)
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setUnaryMinus(builder)
+              .build())
+        } else {
+          withInfo(expr, child)
+          None
+        }
 
-        case BitwiseXor(left, right) =>
-          createBinaryExpr(
-            left,
-            right,
-            inputs,
-            (builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))
+      case a @ Coalesce(_) =>
+        val exprChildren = a.children.map(exprToProtoInternal(_, inputs, 
binding))
+        scalarExprToProto("coalesce", exprChildren: _*)
+
+      // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called 
to pad spaces for
+      // char types.
+      // See https://github.com/apache/spark/pull/38151
+      case s: StaticInvoke
+          if s.staticObject.isInstanceOf[Class[CharVarcharCodegenUtils]] &&
+            s.dataType.isInstanceOf[StringType] &&
+            s.functionName == "readSidePadding" &&
+            s.arguments.size == 2 &&
+            s.propagateNull &&
+            !s.returnNullable &&
+            s.isDeterministic =>
+        val argsExpr = Seq(
+          exprToProtoInternal(Cast(s.arguments(0), StringType), inputs, 
binding),
+          exprToProtoInternal(s.arguments(1), inputs, binding))
+
+        if (argsExpr.forall(_.isDefined)) {
+          val builder = ExprOuterClass.ScalarFunc.newBuilder()
+          builder.setFunc("read_side_padding")
+          argsExpr.foreach(arg => builder.addArgs(arg.get))
+
+          Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
+        } else {
+          withInfo(expr, s.arguments: _*)
+          None
+        }
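
The StaticInvoke pattern above targets the read-side padding Spark adds
for CHAR(n) columns; an illustrative round trip using standard Spark
behavior (not part of this commit):

    spark.sql("CREATE TABLE t (c CHAR(5)) USING parquet")
    spark.sql("INSERT INTO t VALUES ('ab')")
    // The scan plans CharVarcharCodegenUtils.readSidePadding(c, 5), which the
    // case above maps to the native read_side_padding function.
    spark.sql("SELECT length(c) FROM t").show() // 5: 'ab' is padded with spaces
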
 
-        case ShiftRight(left, right) =>
-          // DataFusion's bitwise shift right expression requires the
-          // same data type for the left and right side
-          val rightExpression = if (left.dataType == LongType) {
-            Cast(right, LongType)
-          } else {
-            right
-          }
+      case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) =>
+        val dataType = serializeDataType(expr.dataType)
+        if (dataType.isEmpty) {
+          withInfo(expr, s"Unsupported datatype ${expr.dataType}")
+          return None
+        }
+        val ex = exprToProtoInternal(expr, inputs, binding)
+        ex.map { child =>
+          val builder = ExprOuterClass.NormalizeNaNAndZero
+            .newBuilder()
+            .setChild(child)
+            .setDatatype(dataType.get)
+          
ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build()
+        }
 
-          createBinaryExpr(
-            left,
-            rightExpression,
-            inputs,
-            (builder, binaryExpr) => builder.setBitwiseShiftRight(binaryExpr))
+      case s @ execution.ScalarSubquery(_, _) if supportedDataType(s.dataType) 
=>
+        val dataType = serializeDataType(s.dataType)
+        if (dataType.isEmpty) {
+          withInfo(s, s"Scalar subquery returns unsupported datatype 
${s.dataType}")
+          return None
+        }
 
-        case ShiftLeft(left, right) =>
-          // DataFusion's bitwise shift left expression requires the
-          // same data type for the left and right side
-          val rightExpression = if (left.dataType == LongType) {
-            Cast(right, LongType)
-          } else {
-            right
-          }
+        val builder = ExprOuterClass.Subquery
+          .newBuilder()
+          .setId(s.exprId.id)
+          .setDatatype(dataType.get)
+        Some(ExprOuterClass.Expr.newBuilder().setSubquery(builder).build())
+
+      case UnscaledValue(child) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr = scalarExprToProtoWithReturnType("unscaled_value", 
LongType, childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case MakeDecimal(child, precision, scale, true) =>
+        val childExpr = exprToProtoInternal(child, inputs, binding)
+        val optExpr = scalarExprToProtoWithReturnType(
+          "make_decimal",
+          DecimalType(precision, scale),
+          childExpr)
+        optExprWithInfo(optExpr, expr, child)
+
+      case b @ BloomFilterMightContain(_, _) =>
+        val bloomFilter = b.left
+        val value = b.right
+        val bloomFilterExpr = exprToProtoInternal(bloomFilter, inputs, binding)
+        val valueExpr = exprToProtoInternal(value, inputs, binding)
+        if (bloomFilterExpr.isDefined && valueExpr.isDefined) {
+          val builder = ExprOuterClass.BloomFilterMightContain.newBuilder()
+          builder.setBloomFilter(bloomFilterExpr.get)
+          builder.setValue(valueExpr.get)
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setBloomFilterMightContain(builder)
+              .build())
+        } else {
+          withInfo(expr, bloomFilter, value)
+          None
+        }
 
-          createBinaryExpr(
-            left,
-            rightExpression,
-            inputs,
-            (builder, binaryExpr) => builder.setBitwiseShiftLeft(binaryExpr))
-        case In(value, list) =>
-          in(expr, value, list, inputs, false)
-
-        case InSet(value, hset) =>
-          val valueDataType = value.dataType
-          val list = hset.map { setVal =>
-            Literal(setVal, valueDataType)
-          }.toSeq
-          // Change the `InSet` expression to an `In` expression.
-          // We do the Spark `InSet` optimization on the native (DataFusion) side.
-          in(expr, value, list, inputs, false)
-
-        case Not(In(value, list)) =>
-          in(expr, value, list, inputs, true)
-
-        case Not(child) =>
-          createUnaryExpr(child, inputs, (builder, unaryExpr) => 
builder.setNot(unaryExpr))
-
-        case UnaryMinus(child, failOnError) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          if (childExpr.isDefined) {
-            val builder = ExprOuterClass.UnaryMinus.newBuilder()
-            builder.setChild(childExpr.get)
-            builder.setFailOnError(failOnError)
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setUnaryMinus(builder)
-                .build())
-          } else {
-            withInfo(expr, child)
-            None
-          }
+      case Murmur3Hash(children, seed) =>
+        val firstUnSupportedInput = children.find(c => 
!supportedDataType(c.dataType))
+        if (firstUnSupportedInput.isDefined) {
+          withInfo(expr, s"Unsupported datatype 
${firstUnSupportedInput.get.dataType}")
+          return None
+        }
+        val exprs = children.map(exprToProtoInternal(_, inputs, binding))
+        val seedBuilder = ExprOuterClass.Literal
+          .newBuilder()
+          .setDatatype(serializeDataType(IntegerType).get)
+          .setIntVal(seed)
+        val seedExpr = 
Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
+        // the seed is put at the end of the arguments
+        scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ 
seedExpr: _*)
+
+      case XxHash64(children, seed) =>
+        val firstUnSupportedInput = children.find(c => 
!supportedDataType(c.dataType))
+        if (firstUnSupportedInput.isDefined) {
+          withInfo(expr, s"Unsupported datatype 
${firstUnSupportedInput.get.dataType}")
+          return None
+        }
+        val exprs = children.map(exprToProtoInternal(_, inputs, binding))
+        val seedBuilder = ExprOuterClass.Literal
+          .newBuilder()
+          .setDatatype(serializeDataType(LongType).get)
+          .setLongVal(seed)
+        val seedExpr = 
Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
+        // the seed is put at the end of the arguments
+        scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ 
seedExpr: _*)
+
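
Both hash conversions above append the seed as a trailing literal
argument, so the native function receives it last (sketch; 42 is Spark's
default seed for both functions):

    // murmur3_hash(c1, c2) => murmur3_hash(c1, c2, 42)
    // xxhash64(c1, c2)     => xxhash64(c1, c2, 42L)
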
+      case Sha2(left, numBits) =>
+        if (!numBits.foldable) {
+          withInfo(expr, "non literal numBits is not supported")
+          return None
+        }
+        // it's possible for spark to dynamically compute the number of bits 
from input
+        // expression, however DataFusion does not support that yet.
+        val childExpr = exprToProtoInternal(left, inputs, binding)
+        val bits = numBits.eval().asInstanceOf[Int]
+        val algorithm = bits match {
+          case 224 => "sha224"
+          case 256 | 0 => "sha256"
+          case 384 => "sha384"
+          case 512 => "sha512"
+          case _ =>
+            null
+        }
+        if (algorithm == null) {
+          exprToProtoInternal(Literal(null, StringType), inputs, binding)
+        } else {
+          scalarExprToProtoWithReturnType(algorithm, StringType, childExpr)
+        }
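
Restating the Sha2 dispatch above for quick reference (0 is treated as
256; any other literal bit width falls through to a NULL literal,
matching Spark's behavior for unsupported bit lengths):

    // sha2(col, 224) -> "sha224"    sha2(col, 384) -> "sha384"
    // sha2(col, 256) -> "sha256"    sha2(col, 512) -> "sha512"
    // sha2(col, 0)   -> "sha256"
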
 
-        case a @ Coalesce(_) =>
-          val exprChildren = a.children.map(exprToProtoInternal(_, inputs))
-          scalarExprToProto("coalesce", exprChildren: _*)
-
-        // With Spark 3.4, CharVarcharCodegenUtils.readSidePadding gets called 
to pad spaces for
-        // char types.
-        // See https://github.com/apache/spark/pull/38151
-        case s: StaticInvoke
-            if s.staticObject.isInstanceOf[Class[CharVarcharCodegenUtils]] &&
-              s.dataType.isInstanceOf[StringType] &&
-              s.functionName == "readSidePadding" &&
-              s.arguments.size == 2 &&
-              s.propagateNull &&
-              !s.returnNullable &&
-              s.isDeterministic =>
-          val argsExpr = Seq(
-            exprToProtoInternal(Cast(s.arguments(0), StringType), inputs),
-            exprToProtoInternal(s.arguments(1), inputs))
-
-          if (argsExpr.forall(_.isDefined)) {
-            val builder = ExprOuterClass.ScalarFunc.newBuilder()
-            builder.setFunc("read_side_padding")
-            argsExpr.foreach(arg => builder.addArgs(arg.get))
-
-            
Some(ExprOuterClass.Expr.newBuilder().setScalarFunc(builder).build())
-          } else {
-            withInfo(expr, s.arguments: _*)
-            None
-          }
+      case struct @ CreateNamedStruct(_) =>
+        if (struct.names.length != struct.names.distinct.length) {
+          withInfo(expr, "CreateNamedStruct with duplicate field names are not 
supported")
+          return None
+        }
 
-        case KnownFloatingPointNormalized(NormalizeNaNAndZero(expr)) =>
-          val dataType = serializeDataType(expr.dataType)
-          if (dataType.isEmpty) {
-            withInfo(expr, s"Unsupported datatype ${expr.dataType}")
-            return None
-          }
-          val ex = exprToProtoInternal(expr, inputs)
-          ex.map { child =>
-            val builder = ExprOuterClass.NormalizeNaNAndZero
-              .newBuilder()
-              .setChild(child)
-              .setDatatype(dataType.get)
-            
ExprOuterClass.Expr.newBuilder().setNormalizeNanAndZero(builder).build()
-          }
+        val valExprs = struct.valExprs.map(exprToProto(_, inputs, binding))
 
-        case s @ execution.ScalarSubquery(_, _) if 
supportedDataType(s.dataType) =>
-          val dataType = serializeDataType(s.dataType)
-          if (dataType.isEmpty) {
-            withInfo(s, s"Scalar subquery returns unsupported datatype 
${s.dataType}")
-            return None
-          }
+        if (valExprs.forall(_.isDefined)) {
+          val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder()
+          structBuilder.addAllValues(valExprs.map(_.get).asJava)
+          structBuilder.addAllNames(struct.names.map(_.toString).asJava)
 
-          val builder = ExprOuterClass.Subquery
-            .newBuilder()
-            .setId(s.exprId.id)
-            .setDatatype(dataType.get)
-          Some(ExprOuterClass.Expr.newBuilder().setSubquery(builder).build())
-
-        case UnscaledValue(child) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProtoWithReturnType("unscaled_value", 
LongType, childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case MakeDecimal(child, precision, scale, true) =>
-          val childExpr = exprToProtoInternal(child, inputs)
-          val optExpr = scalarExprToProtoWithReturnType(
-            "make_decimal",
-            DecimalType(precision, scale),
-            childExpr)
-          optExprWithInfo(optExpr, expr, child)
-
-        case b @ BloomFilterMightContain(_, _) =>
-          val bloomFilter = b.left
-          val value = b.right
-          val bloomFilterExpr = exprToProtoInternal(bloomFilter, inputs)
-          val valueExpr = exprToProtoInternal(value, inputs)
-          if (bloomFilterExpr.isDefined && valueExpr.isDefined) {
-            val builder = ExprOuterClass.BloomFilterMightContain.newBuilder()
-            builder.setBloomFilter(bloomFilterExpr.get)
-            builder.setValue(valueExpr.get)
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setBloomFilterMightContain(builder)
-                .build())
-          } else {
-            withInfo(expr, bloomFilter, value)
-            None
-          }
+          Some(
+            ExprOuterClass.Expr
+              .newBuilder()
+              .setCreateNamedStruct(structBuilder)
+              .build())
+        } else {
+          withInfo(expr, "unsupported arguments for CreateNamedStruct", 
struct.valExprs: _*)
+          None
+        }
 
-        case Murmur3Hash(children, seed) =>
-          val firstUnSupportedInput = children.find(c => 
!supportedDataType(c.dataType))
-          if (firstUnSupportedInput.isDefined) {
-            withInfo(expr, s"Unsupported datatype 
${firstUnSupportedInput.get.dataType}")
-            return None
-          }
-          val exprs = children.map(exprToProtoInternal(_, inputs))
-          val seedBuilder = ExprOuterClass.Literal
-            .newBuilder()
-            .setDatatype(serializeDataType(IntegerType).get)
-            .setIntVal(seed)
-          val seedExpr = 
Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
-          // the seed is put at the end of the arguments
-          scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs 
:+ seedExpr: _*)
-
-        case XxHash64(children, seed) =>
-          val firstUnSupportedInput = children.find(c => 
!supportedDataType(c.dataType))
-          if (firstUnSupportedInput.isDefined) {
-            withInfo(expr, s"Unsupported datatype 
${firstUnSupportedInput.get.dataType}")
-            return None
-          }
-          val exprs = children.map(exprToProtoInternal(_, inputs))
-          val seedBuilder = ExprOuterClass.Literal
+      case GetStructField(child, ordinal, _) =>
+        exprToProto(child, inputs, binding).map { childExpr =>
+          val getStructFieldBuilder = ExprOuterClass.GetStructField
             .newBuilder()
-            .setDatatype(serializeDataType(LongType).get)
-            .setLongVal(seed)
-          val seedExpr = 
Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
-          // the seed is put at the end of the arguments
-          scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ 
seedExpr: _*)
-
-        case Sha2(left, numBits) =>
-          if (!numBits.foldable) {
-            withInfo(expr, "non literal numBits is not supported")
-            return None
-          }
-          // it's possible for spark to dynamically compute the number of bits 
from input
-          // expression, however DataFusion does not support that yet.
-          val childExpr = exprToProtoInternal(left, inputs)
-          val bits = numBits.eval().asInstanceOf[Int]
-          val algorithm = bits match {
-            case 224 => "sha224"
-            case 256 | 0 => "sha256"
-            case 384 => "sha384"
-            case 512 => "sha512"
-            case _ =>
-              null
-          }
-          if (algorithm == null) {
-            exprToProtoInternal(Literal(null, StringType), inputs)
-          } else {
-            scalarExprToProtoWithReturnType(algorithm, StringType, childExpr)
-          }
+            .setChild(childExpr)
+            .setOrdinal(ordinal)
 
-        case struct @ CreateNamedStruct(_) =>
-          if (struct.names.length != struct.names.distinct.length) {
-            withInfo(expr, "CreateNamedStruct with duplicate field names are 
not supported")
-            return None
-          }
+          ExprOuterClass.Expr
+            .newBuilder()
+            .setGetStructField(getStructFieldBuilder)
+            .build()
+        }
 
-          val valExprs = struct.valExprs.map(exprToProto(_, inputs, binding))
+      case CreateArray(children, _) =>
+        val childExprs = children.map(exprToProto(_, inputs, binding))
 
-          if (valExprs.forall(_.isDefined)) {
-            val structBuilder = ExprOuterClass.CreateNamedStruct.newBuilder()
-            structBuilder.addAllValues(valExprs.map(_.get).asJava)
-            structBuilder.addAllNames(struct.names.map(_.toString).asJava)
+        if (childExprs.forall(_.isDefined)) {
+          scalarExprToProto("make_array", childExprs: _*)
+        } else {
+          withInfo(expr, "unsupported arguments for CreateArray", children: _*)
+          None
+        }
 
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setCreateNamedStruct(structBuilder)
-                .build())
-          } else {
-            withInfo(expr, "unsupported arguments for CreateNamedStruct", 
struct.valExprs: _*)
-            None
-          }
+      case GetArrayItem(child, ordinal, failOnError) =>
+        val childExpr = exprToProto(child, inputs, binding)
+        val ordinalExpr = exprToProto(ordinal, inputs, binding)
 
-        case GetStructField(child, ordinal, _) =>
-          exprToProto(child, inputs, binding).map { childExpr =>
-            val getStructFieldBuilder = ExprOuterClass.GetStructField
-              .newBuilder()
-              .setChild(childExpr)
-              .setOrdinal(ordinal)
+        if (childExpr.isDefined && ordinalExpr.isDefined) {
+          val listExtractBuilder = ExprOuterClass.ListExtract
+            .newBuilder()
+            .setChild(childExpr.get)
+            .setOrdinal(ordinalExpr.get)
+            .setOneBased(false)
+            .setFailOnError(failOnError)
 
+          Some(
             ExprOuterClass.Expr
               .newBuilder()
-              .setGetStructField(getStructFieldBuilder)
-              .build()
-          }
-
-        case CreateArray(children, _) =>
-          val childExprs = children.map(exprToProto(_, inputs, binding))
-
-          if (childExprs.forall(_.isDefined)) {
-            scalarExprToProto("make_array", childExprs: _*)
-          } else {
-            withInfo(expr, "unsupported arguments for CreateArray", children: 
_*)
-            None
-          }
+              .setListExtract(listExtractBuilder)
+              .build())
+        } else {
+          withInfo(expr, "unsupported arguments for GetArrayItem", child, 
ordinal)
+          None
+        }
 
-        case GetArrayItem(child, ordinal, failOnError) =>
-          val childExpr = exprToProto(child, inputs, binding)
-          val ordinalExpr = exprToProto(ordinal, inputs, binding)
+      case expr if expr.prettyName == "array_insert" =>
+        val srcExprProto = exprToProto(expr.children(0), inputs, binding)
+        val posExprProto = exprToProto(expr.children(1), inputs, binding)
+        val itemExprProto = exprToProto(expr.children(2), inputs, binding)
+        val legacyNegativeIndex =
+          
SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean
+        if (srcExprProto.isDefined && posExprProto.isDefined && 
itemExprProto.isDefined) {
+          val arrayInsertBuilder = ExprOuterClass.ArrayInsert
+            .newBuilder()
+            .setSrcArrayExpr(srcExprProto.get)
+            .setPosExpr(posExprProto.get)
+            .setItemExpr(itemExprProto.get)
+            .setLegacyNegativeIndex(legacyNegativeIndex)
 
-          if (childExpr.isDefined && ordinalExpr.isDefined) {
-            val listExtractBuilder = ExprOuterClass.ListExtract
+          Some(
+            ExprOuterClass.Expr
               .newBuilder()
-              .setChild(childExpr.get)
-              .setOrdinal(ordinalExpr.get)
-              .setOneBased(false)
-              .setFailOnError(failOnError)
-
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setListExtract(listExtractBuilder)
-                .build())
-          } else {
-            withInfo(expr, "unsupported arguments for GetArrayItem", child, 
ordinal)
-            None
-          }
+              .setArrayInsert(arrayInsertBuilder)
+              .build())
+        } else {
+          withInfo(
+            expr,
+            "unsupported arguments for ArrayInsert",
+            expr.children(0),
+            expr.children(1),
+            expr.children(2))
+          None
+        }
 
-        case expr if expr.prettyName == "array_insert" =>
-          val srcExprProto = exprToProto(expr.children(0), inputs, binding)
-          val posExprProto = exprToProto(expr.children(1), inputs, binding)
-          val itemExprProto = exprToProto(expr.children(2), inputs, binding)
-          val legacyNegativeIndex =
-            
SQLConf.get.getConfString("spark.sql.legacy.negativeIndexInArrayInsert").toBoolean
-          if (srcExprProto.isDefined && posExprProto.isDefined && 
itemExprProto.isDefined) {
-            val arrayInsertBuilder = ExprOuterClass.ArrayInsert
-              .newBuilder()
-              .setSrcArrayExpr(srcExprProto.get)
-              .setPosExpr(posExprProto.get)
-              .setItemExpr(itemExprProto.get)
-              .setLegacyNegativeIndex(legacyNegativeIndex)
+      case ElementAt(child, ordinal, defaultValue, failOnError)
+          if child.dataType.isInstanceOf[ArrayType] =>
+        val childExpr = exprToProto(child, inputs, binding)
+        val ordinalExpr = exprToProto(ordinal, inputs, binding)
+        val defaultExpr = defaultValue.flatMap(exprToProto(_, inputs, binding))
 
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setArrayInsert(arrayInsertBuilder)
-                .build())
-          } else {
-            withInfo(
-              expr,
-              "unsupported arguments for ArrayInsert",
-              expr.children(0),
-              expr.children(1),
-              expr.children(2))
-            None
-          }
+        if (childExpr.isDefined && ordinalExpr.isDefined &&
+          defaultExpr.isDefined == defaultValue.isDefined) {
+          val arrayExtractBuilder = ExprOuterClass.ListExtract
+            .newBuilder()
+            .setChild(childExpr.get)
+            .setOrdinal(ordinalExpr.get)
+            .setOneBased(true)
+            .setFailOnError(failOnError)
 
-        case ElementAt(child, ordinal, defaultValue, failOnError)
-            if child.dataType.isInstanceOf[ArrayType] =>
-          val childExpr = exprToProto(child, inputs, binding)
-          val ordinalExpr = exprToProto(ordinal, inputs, binding)
-          val defaultExpr = defaultValue.flatMap(exprToProto(_, inputs, 
binding))
+          defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))
 
-          if (childExpr.isDefined && ordinalExpr.isDefined &&
-            defaultExpr.isDefined == defaultValue.isDefined) {
-            val arrayExtractBuilder = ExprOuterClass.ListExtract
+          Some(
+            ExprOuterClass.Expr
               .newBuilder()
-              .setChild(childExpr.get)
-              .setOrdinal(ordinalExpr.get)
-              .setOneBased(true)
-              .setFailOnError(failOnError)
-
-            defaultExpr.foreach(arrayExtractBuilder.setDefaultValue(_))
+              .setListExtract(arrayExtractBuilder)
+              .build())
+        } else {
+          withInfo(expr, "unsupported arguments for ElementAt", child, ordinal)
+          None
+        }
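
Note `setOneBased(true)` here versus `setOneBased(false)` in the
GetArrayItem case earlier; the two Spark operators index differently
(standard Spark semantics):

    spark.sql("SELECT element_at(array(10, 20, 30), 1)").show() // 10 (one-based)
    spark.sql("SELECT array(10, 20, 30)[1]").show()             // 20 (zero-based)
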
 
-            Some(
-              ExprOuterClass.Expr
-                .newBuilder()
-                .setListExtract(arrayExtractBuilder)
-                .build())
-          } else {
-            withInfo(expr, "unsupported arguments for ElementAt", child, 
ordinal)
-            None
-          }
+      case GetArrayStructFields(child, _, ordinal, _, _) =>
+        val childExpr = exprToProto(child, inputs, binding)
 
-        case GetArrayStructFields(child, _, ordinal, _, _) =>
-          val childExpr = exprToProto(child, inputs, binding)
+        if (childExpr.isDefined) {
+          val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
+            .newBuilder()
+            .setChild(childExpr.get)
+            .setOrdinal(ordinal)
 
-          if (childExpr.isDefined) {
-            val arrayStructFieldsBuilder = ExprOuterClass.GetArrayStructFields
+          Some(
+            ExprOuterClass.Expr
               .newBuilder()
-              .setChild(childExpr.get)
-              .setOrdinal(ordinal)
-
-            Some(
-              ExprOuterClass.Expr
+              .setGetArrayStructFields(arrayStructFieldsBuilder)
+              .build())
+        } else {
+          withInfo(expr, "unsupported arguments for GetArrayStructFields", 
child)
+          None
+        }
+      case expr: ArrayRemove => CometArrayRemove.convert(expr, inputs, binding)
+      case expr if expr.prettyName == "array_contains" =>
+        createBinaryExpr(
+          expr,
+          expr.children(0),
+          expr.children(1),
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setArrayContains(binaryExpr))
+      case _ if expr.prettyName == "array_append" =>
+        createBinaryExpr(
+          expr,
+          expr.children(0),
+          expr.children(1),
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setArrayAppend(binaryExpr))
+      case _ if expr.prettyName == "array_intersect" =>
+        createBinaryExpr(
+          expr,
+          expr.children(0),
+          expr.children(1),
+          inputs,
+          binding,
+          (builder, binaryExpr) => builder.setArrayIntersect(binaryExpr))
+      case ArrayJoin(arrayExpr, delimiterExpr, nullReplacementExpr) =>
+        val arrayExprProto = exprToProto(arrayExpr, inputs, binding)
+        val delimiterExprProto = exprToProto(delimiterExpr, inputs, binding)
+
+        if (arrayExprProto.isDefined && delimiterExprProto.isDefined) {
+          val arrayJoinBuilder = nullReplacementExpr match {
+            case Some(nrExpr) =>
+              val nullReplacementExprProto = exprToProto(nrExpr, inputs, binding)
+              ExprOuterClass.ArrayJoin
                 .newBuilder()
-                .setGetArrayStructFields(arrayStructFieldsBuilder)
-                .build())
-          } else {
-            withInfo(expr, "unsupported arguments for GetArrayStructFields", 
child)
-            None
-          }
-        case expr: ArrayRemove =>
-          if (CometArrayRemove.checkSupport(expr)) {
-            createBinaryExpr(
-              expr.children(0),
-              expr.children(1),
-              inputs,
-              (builder, binaryExpr) => builder.setArrayRemove(binaryExpr))
-          } else {
-            None
-          }
-        case expr if expr.prettyName == "array_contains" =>
-          createBinaryExpr(
-            expr.children(0),
-            expr.children(1),
-            inputs,
-            (builder, binaryExpr) => builder.setArrayContains(binaryExpr))
-        case _ if expr.prettyName == "array_append" =>
-          createBinaryExpr(
-            expr.children(0),
-            expr.children(1),
-            inputs,
-            (builder, binaryExpr) => builder.setArrayAppend(binaryExpr))
-        case _ if expr.prettyName == "array_intersect" =>
-          createBinaryExpr(
-            expr.children(0),
-            expr.children(1),
-            inputs,
-            (builder, binaryExpr) => builder.setArrayIntersect(binaryExpr))
-        case ArrayJoin(arrayExpr, delimiterExpr, nullReplacementExpr) =>
-          val arrayExprProto = exprToProto(arrayExpr, inputs, binding)
-          val delimiterExprProto = exprToProto(delimiterExpr, inputs, binding)
-
-          if (arrayExprProto.isDefined && delimiterExprProto.isDefined) {
-            val arrayJoinBuilder = nullReplacementExpr match {
-              case Some(nrExpr) =>
-                val nullReplacementExprProto = exprToProto(nrExpr, inputs, binding)
-                ExprOuterClass.ArrayJoin
-                  .newBuilder()
-                  .setArrayExpr(arrayExprProto.get)
-                  .setDelimiterExpr(delimiterExprProto.get)
-                  .setNullReplacementExpr(nullReplacementExprProto.get)
-              case None =>
-                ExprOuterClass.ArrayJoin
-                  .newBuilder()
-                  .setArrayExpr(arrayExprProto.get)
-                  .setDelimiterExpr(delimiterExprProto.get)
-            }
-            Some(
-              ExprOuterClass.Expr
+                .setArrayExpr(arrayExprProto.get)
+                .setDelimiterExpr(delimiterExprProto.get)
+                .setNullReplacementExpr(nullReplacementExprProto.get)
+            case None =>
+              ExprOuterClass.ArrayJoin
                 .newBuilder()
-                .setArrayJoin(arrayJoinBuilder)
-                .build())
-          } else {
-            val exprs: List[Expression] = nullReplacementExpr match {
-              case Some(nrExpr) => List(arrayExpr, delimiterExpr, nrExpr)
-              case None => List(arrayExpr, delimiterExpr)
-            }
-            withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*)
-            None
+                .setArrayExpr(arrayExprProto.get)
+                .setDelimiterExpr(delimiterExprProto.get)
           }
-        case _ =>
-          withInfo(expr, s"${expr.prettyName} is not supported", 
expr.children: _*)
-          None
-      }
-    }
-
-    /**
-     * Creates a UnaryExpr by calling exprToProtoInternal for the provided child expression and
-     * then invokes the supplied function to wrap this UnaryExpr in a top-level Expr.
-     *
-     * @param child
-     *   Spark expression
-     * @param inputs
-     *   Inputs to the expression
-     * @param f
-     *   Function that accepts an Expr.Builder and a UnaryExpr and builds the specific top-level
-     *   Expr
-     * @return
-     *   Some(Expr) or None if not supported
-     */
-    def createUnaryExpr(
-        child: Expression,
-        inputs: Seq[Attribute],
-        f: (ExprOuterClass.Expr.Builder, ExprOuterClass.UnaryExpr) => ExprOuterClass.Expr.Builder)
-        : Option[ExprOuterClass.Expr] = {
-      val childExpr = exprToProtoInternal(child, inputs)
-      if (childExpr.isDefined) {
-        // create the generic UnaryExpr message
-        val inner = ExprOuterClass.UnaryExpr
-          .newBuilder()
-          .setChild(childExpr.get)
-          .build()
-        // call the user-supplied function to wrap UnaryExpr in a top-level Expr
-        // such as Expr.IsNull or Expr.IsNotNull
-        Some(
-          f(
+          Some(
             ExprOuterClass.Expr
-              .newBuilder(),
-            inner).build())
-      } else {
-        withInfo(expr, child)
+              .newBuilder()
+              .setArrayJoin(arrayJoinBuilder)
+              .build())
+        } else {
+          val exprs: List[Expression] = nullReplacementExpr match {
+            case Some(nrExpr) => List(arrayExpr, delimiterExpr, nrExpr)
+            case None => List(arrayExpr, delimiterExpr)
+          }
+          withInfo(expr, "unsupported arguments for ArrayJoin", exprs: _*)
+          None
+        }
+      case _ =>
+        withInfo(expr, s"${expr.prettyName} is not supported", expr.children: 
_*)
         None
-      }
     }
+  }
 
-    def createBinaryExpr(
-        left: Expression,
-        right: Expression,
-        inputs: Seq[Attribute],
-        f: (
-            ExprOuterClass.Expr.Builder,
-            ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder)
-        : Option[ExprOuterClass.Expr] = {
-      val leftExpr = exprToProtoInternal(left, inputs)
-      val rightExpr = exprToProtoInternal(right, inputs)
-      if (leftExpr.isDefined && rightExpr.isDefined) {
-        // create the generic BinaryExpr message
-        val inner = ExprOuterClass.BinaryExpr
-          .newBuilder()
-          .setLeft(leftExpr.get)
-          .setRight(rightExpr.get)
-          .build()
-        // call the user-supplied function to wrap BinaryExpr in a top-level Expr
-        // such as Expr.And or Expr.Or
-        Some(
-          f(
-            ExprOuterClass.Expr
-              .newBuilder(),
-            inner).build())
-      } else {
-        withInfo(expr, left, right)
-        None
-      }
+  /**
+   * Creates a UnaryExpr by calling exprToProtoInternal for the provided child expression and then
+   * invokes the supplied function to wrap this UnaryExpr in a top-level Expr.
+   *
+   * @param child
+   *   Spark expression
+   * @param inputs
+   *   Inputs to the expression
+   * @param f
+   *   Function that accepts an Expr.Builder and a UnaryExpr and builds the specific top-level
+   *   Expr
+   * @return
+   *   Some(Expr) or None if not supported
+   */
+  def createUnaryExpr(
+      expr: Expression,
+      child: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      f: (ExprOuterClass.Expr.Builder, ExprOuterClass.UnaryExpr) => ExprOuterClass.Expr.Builder)
+      : Option[ExprOuterClass.Expr] = {
+    val childExpr = exprToProtoInternal(child, inputs, binding) // TODO review
+    if (childExpr.isDefined) {
+      // create the generic UnaryExpr message
+      val inner = ExprOuterClass.UnaryExpr
+        .newBuilder()
+        .setChild(childExpr.get)
+        .build()
+      // call the user-supplied function to wrap UnaryExpr in a top-level Expr
+      // such as Expr.IsNull or Expr.IsNotNull
+      Some(
+        f(
+          ExprOuterClass.Expr
+            .newBuilder(),
+          inner).build())
+    } else {
+      withInfo(expr, child)
+      None
     }
+  }
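For example, with the refactored signature the IsNull case in the expression match can delegate straight to this helper. The following is a sketch only; the actual IsNull call site falls outside the hunks shown here:

    case IsNull(child) =>
      // wraps the serialized child in a top-level Expr.IsNull
      createUnaryExpr(expr, child, inputs, binding,
        (builder, unaryExpr) => builder.setIsNull(unaryExpr))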
 
-    def createMathExpression(
-        left: Expression,
-        right: Expression,
-        inputs: Seq[Attribute],
-        dataType: DataType,
-        failOnError: Boolean,
-        f: (ExprOuterClass.Expr.Builder, ExprOuterClass.MathExpr) => ExprOuterClass.Expr.Builder)
-        : Option[ExprOuterClass.Expr] = {
-      val leftExpr = exprToProtoInternal(left, inputs)
-      val rightExpr = exprToProtoInternal(right, inputs)
-
-      if (leftExpr.isDefined && rightExpr.isDefined) {
-        // create the generic MathExpr message
-        val builder = ExprOuterClass.MathExpr.newBuilder()
-        builder.setLeft(leftExpr.get)
-        builder.setRight(rightExpr.get)
-        builder.setFailOnError(failOnError)
-        serializeDataType(dataType).foreach { t =>
-          builder.setReturnType(t)
-        }
-        val inner = builder.build()
-        // call the user-supplied function to wrap MathExpr in a top-level Expr
-        // such as Expr.Add or Expr.Divide
-        Some(
-          f(
-            ExprOuterClass.Expr
-              .newBuilder(),
-            inner).build())
-      } else {
-        withInfo(expr, left, right)
-        None
-      }
+  def createBinaryExpr(
+      expr: Expression,
+      left: Expression,
+      right: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      f: (ExprOuterClass.Expr.Builder, ExprOuterClass.BinaryExpr) => ExprOuterClass.Expr.Builder)
+      : Option[ExprOuterClass.Expr] = {
+    val leftExpr = exprToProtoInternal(left, inputs, binding)
+    val rightExpr = exprToProtoInternal(right, inputs, binding)
+    if (leftExpr.isDefined && rightExpr.isDefined) {
+      // create the generic BinaryExpr message
+      val inner = ExprOuterClass.BinaryExpr
+        .newBuilder()
+        .setLeft(leftExpr.get)
+        .setRight(rightExpr.get)
+        .build()
+      // call the user-supplied function to wrap BinaryExpr in a top-level Expr
+      // such as Expr.And or Expr.Or
+      Some(
+        f(
+          ExprOuterClass.Expr
+            .newBuilder(),
+          inner).build())
+    } else {
+      withInfo(expr, left, right)
+      None
     }
+  }
 
-    def trim(
-        expr: Expression, // parent expression
-        srcStr: Expression,
-        trimStr: Option[Expression],
-        inputs: Seq[Attribute],
-        trimType: String): Option[Expr] = {
-      val srcCast = Cast(srcStr, StringType)
-      val srcExpr = exprToProtoInternal(srcCast, inputs)
-      if (trimStr.isDefined) {
-        val trimCast = Cast(trimStr.get, StringType)
-        val trimExpr = exprToProtoInternal(trimCast, inputs)
-        val optExpr = scalarExprToProto(trimType, srcExpr, trimExpr)
-        optExprWithInfo(optExpr, expr, srcCast, trimCast)
-      } else {
-        val optExpr = scalarExprToProto(trimType, srcExpr)
-        optExprWithInfo(optExpr, expr, srcCast)
+  def createMathExpression(
+      expr: Expression,
+      left: Expression,
+      right: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      dataType: DataType,
+      failOnError: Boolean,
+      f: (ExprOuterClass.Expr.Builder, ExprOuterClass.MathExpr) => ExprOuterClass.Expr.Builder)
+      : Option[ExprOuterClass.Expr] = {
+    val leftExpr = exprToProtoInternal(left, inputs, binding)
+    val rightExpr = exprToProtoInternal(right, inputs, binding)
+
+    if (leftExpr.isDefined && rightExpr.isDefined) {
+      // create the generic MathExpr message
+      val builder = ExprOuterClass.MathExpr.newBuilder()
+      builder.setLeft(leftExpr.get)
+      builder.setRight(rightExpr.get)
+      builder.setFailOnError(failOnError)
+      serializeDataType(dataType).foreach { t =>
+        builder.setReturnType(t)
       }
+      val inner = builder.build()
+      // call the user-supplied function to wrap MathExpr in a top-level Expr
+      // such as Expr.Add or Expr.Divide
+      Some(
+        f(
+          ExprOuterClass.Expr
+            .newBuilder(),
+          inner).build())
+    } else {
+      withInfo(expr, left, right)
+      None
     }
+  }
 
-    def in(
-        expr: Expression,
-        value: Expression,
-        list: Seq[Expression],
-        inputs: Seq[Attribute],
-        negate: Boolean): Option[Expr] = {
-      val valueExpr = exprToProtoInternal(value, inputs)
-      val listExprs = list.map(exprToProtoInternal(_, inputs))
-      if (valueExpr.isDefined && listExprs.forall(_.isDefined)) {
-        val builder = ExprOuterClass.In.newBuilder()
-        builder.setInValue(valueExpr.get)
-        builder.addAllLists(listExprs.map(_.get).asJava)
-        builder.setNegated(negate)
-        Some(
-          ExprOuterClass.Expr
-            .newBuilder()
-            .setIn(builder)
-            .build())
-      } else {
-        val allExprs = list ++ Seq(value)
-        withInfo(expr, allExprs: _*)
-        None
-      }
+  def trim(
+      expr: Expression, // parent expression
+      srcStr: Expression,
+      trimStr: Option[Expression],
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      trimType: String): Option[Expr] = {
+    val srcCast = Cast(srcStr, StringType)
+    val srcExpr = exprToProtoInternal(srcCast, inputs, binding)
+    if (trimStr.isDefined) {
+      val trimCast = Cast(trimStr.get, StringType)
+      val trimExpr = exprToProtoInternal(trimCast, inputs, binding)
+      val optExpr = scalarExprToProto(trimType, srcExpr, trimExpr)
+      optExprWithInfo(optExpr, expr, srcCast, trimCast)
+    } else {
+      val optExpr = scalarExprToProto(trimType, srcExpr)
+      optExprWithInfo(optExpr, expr, srcCast)
     }
+  }
 
-    val conf = SQLConf.get
-    val newExpr =
-      DecimalPrecision.promote(conf.decimalOperationsAllowPrecisionLoss, expr, !conf.ansiEnabled)
-    exprToProtoInternal(newExpr, input)
+  def in(
+      expr: Expression,
+      value: Expression,
+      list: Seq[Expression],
+      inputs: Seq[Attribute],
+      binding: Boolean,
+      negate: Boolean): Option[Expr] = {
+    val valueExpr = exprToProtoInternal(value, inputs, binding)
+    val listExprs = list.map(exprToProtoInternal(_, inputs, binding))
+    if (valueExpr.isDefined && listExprs.forall(_.isDefined)) {
+      val builder = ExprOuterClass.In.newBuilder()
+      builder.setInValue(valueExpr.get)
+      builder.addAllLists(listExprs.map(_.get).asJava)
+      builder.setNegated(negate)
+      Some(
+        ExprOuterClass.Expr
+          .newBuilder()
+          .setIn(builder)
+          .build())
+    } else {
+      val allExprs = list ++ Seq(value)
+      withInfo(expr, allExprs: _*)
+      None
+    }
   }
 
   def scalarExprToProtoWithReturnType(
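For context, the in(...) helper above is presumably invoked from the membership cases of the expression match along these lines (a sketch; the actual call sites fall outside the hunks shown here):

    case In(value, list) =>
      in(expr, value, list, inputs, binding, negate = false)
    case Not(In(value, list)) =>
      // negated membership reuses the same helper with negate = true
      in(expr, value, list, inputs, binding, negate = true)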
@@ -3364,3 +3448,29 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
     nativeScanBuilder.addFilePartitions(partitionBuilder.build())
   }
 }
+
+/**
+ * Trait for providing serialization logic for expressions.
+ */
+trait CometExpressionSerde {
+
+  /**
+   * Convert a Spark expression into a protocol buffer representation that can be passed into
+   * native code.
+   *
+   * @param expr
+   *   The Spark expression.
+   * @param inputs
+   *   The input attributes.
+   * @param binding
+   *   Whether the attributes are bound (this is only relevant in aggregate expressions).
+   * @return
+   *   Protocol buffer representation, or None if the expression could not be converted. In this
+   *   case it is expected that the input expression will have been tagged with reasons why it
+   *   could not be converted.
+   */
+  def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr]
+}
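To illustrate how the new trait is meant to be used, a hypothetical serde object for another binary array expression could follow the same pattern as the migrated CometArrayRemove below. CometArrayContains here is illustrative only; this commit migrates only ArrayRemove:

    import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}

    import org.apache.comet.serde.QueryPlanSerde.createBinaryExpr

    object CometArrayContains extends CometExpressionSerde {
      override def convert(
          expr: Expression,
          inputs: Seq[Attribute],
          binding: Boolean): Option[ExprOuterClass.Expr] = {
        // createBinaryExpr serializes both children and, on failure,
        // tags `expr` via withInfo and returns None
        createBinaryExpr(
          expr,
          expr.children(0),
          expr.children(1),
          inputs,
          binding,
          (builder, binaryExpr) => builder.setArrayContains(binaryExpr))
      }
    }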
diff --git a/spark/src/main/scala/org/apache/comet/serde/arrays.scala b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
index ecb4f5a13..9058a641e 100644
--- a/spark/src/main/scala/org/apache/comet/serde/arrays.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/arrays.scala
@@ -19,18 +19,16 @@
 
 package org.apache.comet.serde
 
-import org.apache.spark.sql.catalyst.expressions.{ArrayRemove, Expression}
+import org.apache.spark.sql.catalyst.expressions.{ArrayRemove, Attribute, Expression}
 import org.apache.spark.sql.types.{ArrayType, DataType, DataTypes, DecimalType, StructType}
 
 import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.QueryPlanSerde.createBinaryExpr
 import org.apache.comet.shims.CometExprShim
 
-trait CometExpression {
-  def checkSupport(expr: Expression): Boolean
-}
-
-object CometArrayRemove extends CometExpression with CometExprShim {
+object CometArrayRemove extends CometExpressionSerde with CometExprShim {
 
+  /** Exposed for unit testing */
   def isTypeSupported(dt: DataType): Boolean = {
     import DataTypes._
     dt match {
@@ -46,15 +44,24 @@ object CometArrayRemove extends CometExpression with CometExprShim {
     }
   }
 
-  override def checkSupport(expr: Expression): Boolean = {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
     val ar = expr.asInstanceOf[ArrayRemove]
     val inputTypes: Set[DataType] = ar.children.map(_.dataType).toSet
     for (dt <- inputTypes) {
       if (!isTypeSupported(dt)) {
         withInfo(expr, s"data type not supported: $dt")
-        return false
+        return None
       }
     }
-    true
+    createBinaryExpr(
+      expr,
+      expr.children(0),
+      expr.children(1),
+      inputs,
+      binding,
+      (builder, binaryExpr) => builder.setArrayRemove(binaryExpr))
   }
 }
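Since isTypeSupported is now exposed for unit testing, type support can be asserted without building a query plan. A minimal sketch, assuming ScalaTest and that primitive element types such as IntegerType are in the supported set (the suite name is hypothetical):

    import org.apache.spark.sql.types.DataTypes
    import org.scalatest.funsuite.AnyFunSuite

    import org.apache.comet.serde.CometArrayRemove

    class CometArrayRemoveSuite extends AnyFunSuite {
      test("array_remove supports primitive element types") {
        // assumes IntegerType is among the supported element types
        assert(CometArrayRemove.isTypeSupported(DataTypes.IntegerType))
      }
    }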

