This is an automated email from the ASF dual-hosted git repository.

chengchengjin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 6c86537532 [GLUTEN-10966][VL] Support codegen for ArrowProjection 
(#10968)
6c86537532 is described below

commit 6c86537532362c8eaa0ae23fa8e065e76bd652c0
Author: Zhen Wang <[email protected]>
AuthorDate: Wed Oct 29 19:00:37 2025 +0800

    [GLUTEN-10966][VL] Support codegen for ArrowProjection (#10968)
---
 .../apache/gluten/expression/ArrowProjection.scala |  16 ++-
 .../expression/GenerateArrowProjection.scala       | 158 +++++++++++++++++++++
 2 files changed, 171 insertions(+), 3 deletions(-)

diff --git 
a/gluten-arrow/src/main/scala/org/apache/gluten/expression/ArrowProjection.scala
 
b/gluten-arrow/src/main/scala/org/apache/gluten/expression/ArrowProjection.scala
index 3216d4f3f9..e3bd6e35db 100644
--- 
a/gluten-arrow/src/main/scala/org/apache/gluten/expression/ArrowProjection.scala
+++ 
b/gluten-arrow/src/main/scala/org/apache/gluten/expression/ArrowProjection.scala
@@ -19,8 +19,9 @@ package org.apache.gluten.expression
 import org.apache.gluten.vectorized.ArrowColumnarRow
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, 
Expression, ExpressionsEvaluator}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, 
CodeGeneratorWithInterpretedFallback, Expression, ExpressionsEvaluator}
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}
 
 // Not thread safe.
@@ -32,7 +33,16 @@ abstract class ArrowProjection extends (InternalRow => 
ArrowColumnarRow) with Ex
 }
 
 /** The factory object for `ArrowProjection`. */
-object ArrowProjection {
+object ArrowProjection
+  extends CodeGeneratorWithInterpretedFallback[Seq[Expression], 
ArrowProjection] {
+
+  override protected def createCodeGeneratedObject(in: Seq[Expression]): 
ArrowProjection = {
+    GenerateArrowProjection.generate(in, 
SQLConf.get.subexpressionEliminationEnabled)
+  }
+
+  override protected def createInterpretedObject(in: Seq[Expression]): 
ArrowProjection = {
+    InterpretedArrowProjection.createProjection(in)
+  }
 
   /**
    * Returns an ArrowProjection for given StructType.
@@ -52,7 +62,7 @@ object ArrowProjection {
 
   /** Returns an ArrowProjection for given sequence of bound Expressions. */
   def create(exprs: Seq[Expression]): ArrowProjection = {
-    InterpretedArrowProjection.createProjection(exprs)
+    createObject(exprs)
   }
 
   def create(expr: Expression): ArrowProjection = create(Seq(expr))
diff --git 
a/gluten-arrow/src/main/scala/org/apache/gluten/expression/GenerateArrowProjection.scala
 
b/gluten-arrow/src/main/scala/org/apache/gluten/expression/GenerateArrowProjection.scala
new file mode 100644
index 0000000000..71a2c4db08
--- /dev/null
+++ 
b/gluten-arrow/src/main/scala/org/apache/gluten/expression/GenerateArrowProjection.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.expression
+
+import org.apache.gluten.vectorized.ArrowColumnarRow
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, 
GenericInternalRow}
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp
+import org.apache.spark.sql.catalyst.expressions.codegen._
+
+// ArrowProjection is not accessible in Java
+abstract class BaseArrowProjection extends ArrowProjection
+
+/**
+ * Generates byte code that produces a [[InternalRow]] object that can update 
itself based on a new
+ * input [[InternalRow]] for a fixed set of [[Expression Expressions]]. It 
exposes a `target`
+ * method, which is used to set the row that will be updated. The internal 
[[InternalRow]] object
+ * created internally is used only when `target` is not used.
+ */
+object GenerateArrowProjection extends CodeGenerator[Seq[Expression], 
ArrowProjection] {
+
+  protected def canonicalize(in: Seq[Expression]): Seq[Expression] =
+    in.map(ExpressionCanonicalizer.execute)
+
+  protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): 
Seq[Expression] =
+    bindReferences(in, inputSchema)
+
+  def generate(
+      expressions: Seq[Expression],
+      inputSchema: Seq[Attribute],
+      useSubexprElimination: Boolean): ArrowProjection = {
+    create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination)
+  }
+
+  def generate(expressions: Seq[Expression], useSubexprElimination: Boolean): 
ArrowProjection = {
+    create(canonicalize(expressions), useSubexprElimination)
+  }
+
+  protected def create(expressions: Seq[Expression]): ArrowProjection = {
+    create(expressions, false)
+  }
+
+  private def create(
+      expressions: Seq[Expression],
+      useSubexprElimination: Boolean): ArrowProjection = {
+    val ctx = newCodeGenContext()
+    val validExpr = expressions.zipWithIndex.filter {
+      case (NoOp, _) => false
+      case _ => true
+    }
+    val exprVals = ctx.generateExpressions(validExpr.map(_._1), 
useSubexprElimination)
+
+    // 4-tuples: (code for projection, isNull variable name, value variable 
name, column index)
+    val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map {
+      case ((e, i), ev) =>
+        val value = JavaCode
+          .global(ctx.addMutableState(CodeGenerator.javaType(e.dataType), 
"value"), e.dataType)
+        val (code, isNull) = if (e.nullable) {
+          val isNull = ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, 
"isNull")
+          (
+            s"""
+               |${ev.code}
+               |$isNull = ${ev.isNull};
+               |$value = ${ev.value};
+            """.stripMargin,
+            JavaCode.isNullGlobal(isNull))
+        } else {
+          (
+            s"""
+               |${ev.code}
+               |$value = ${ev.value};
+            """.stripMargin,
+            FalseLiteral)
+        }
+        // update value into intermediate
+        val update = CodeGenerator
+          .updateColumn("intermediate", e.dataType, i, ExprCode(isNull, 
value), e.nullable)
+        (code, update)
+    }
+
+    // Evaluate all the subexpressions.
+    val evalSubexpr = ctx.subexprFunctionsCode
+
+    val allProjections = 
ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._1))
+    val allUpdates = 
ctx.splitExpressionsWithCurrentInputs(projectionCodes.map(_._2))
+
+    val codeBody = s"""
+      public java.lang.Object generate(Object[] references) {
+        return new SpecificArrowProjection(references);
+      }
+
+      class SpecificArrowProjection extends 
${classOf[BaseArrowProjection].getName} {
+
+        private Object[] references;
+        private ${classOf[ArrowColumnarRow].getName} mutableRow;
+        private ${classOf[GenericInternalRow].getName} intermediate;
+        ${ctx.declareMutableStates()}
+
+        public SpecificArrowProjection(Object[] references) {
+          this.references = references;
+          mutableRow = null;
+          intermediate = new 
${classOf[GenericInternalRow].getName}(${expressions.size});
+          ${ctx.initMutableStates()}
+        }
+
+        public void initialize(int partitionIndex) {
+          ${ctx.initPartition()}
+        }
+
+        public ${classOf[BaseArrowProjection].getName} target(
+            ${classOf[ArrowColumnarRow].getName} row) {
+          mutableRow = row;
+          return this;
+        }
+
+        /* Provide immutable access to the last projected row. */
+        public ${classOf[ArrowColumnarRow].getName} currentValue() {
+          return (${classOf[ArrowColumnarRow].getName}) mutableRow;
+        }
+
+        public java.lang.Object apply(java.lang.Object _i) {
+          InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
+          $evalSubexpr
+          $allProjections
+          // copy all the results into intermediate
+          $allUpdates
+          // write intermediate to mutableRow
+          mutableRow.writeRow(intermediate);
+          return mutableRow;
+        }
+
+        ${ctx.declareAddedFunctions()}
+      }
+    """
+
+    val code = CodeFormatter.stripOverlappingComments(
+      new CodeAndComment(codeBody, ctx.getPlaceHolderToComments()))
+    logDebug(s"code for 
${expressions.mkString(",")}:\n${CodeFormatter.format(code)}")
+
+    val (clazz, _) = CodeGenerator.compile(code)
+    clazz.generate(ctx.references.toArray).asInstanceOf[ArrowProjection]
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to