This is an automated email from the ASF dual-hosted git repository.

hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new 8cb3e70b9d [GLUTEN-7313][VL] Explicit Arrow transitions, part 4: explicit Arrow-to-Velox transition (#7392)
8cb3e70b9d is described below

commit 8cb3e70b9d835a6ee06b0bb8c83679711dc7d99a
Author: Hongze Zhang <[email protected]>
AuthorDate: Tue Oct 8 11:01:53 2024 +0800

    [GLUTEN-7313][VL] Explicit Arrow transitions, part 4: explicit Arrow-to-Velox transition (#7392)
---
 .../VeloxColumnarBatchJniWrapper.java              | 26 ++-----
 .../gluten/columnarbatch/VeloxColumnarBatches.java | 85 +++++++++++++++++++++
 .../apache/gluten/columnarbatch/VeloxBatch.scala   |  6 +-
 .../ArrowColumnarToVeloxColumnarExec.scala}        | 27 ++++---
 .../spark/sql/execution/utils/ExecUtil.scala       | 13 ++--
 .../gluten/columnarbatch/ColumnarBatchTest.java    | 39 +++++++++-
 .../columnar/transition/VeloxTransitionSuite.scala | 28 ++++---
 cpp/core/compute/Runtime.h                         |  2 +-
 cpp/core/jni/JniWrapper.cc                         | 21 ------
 cpp/core/memory/ColumnarBatch.cc                   | 86 ----------------------
 cpp/core/memory/ColumnarBatch.h                    | 33 ---------
 cpp/velox/benchmarks/common/BenchmarkUtils.h       |  8 +-
 cpp/velox/compute/VeloxRuntime.cc                  |  2 +-
 cpp/velox/compute/VeloxRuntime.h                   |  2 +-
 cpp/velox/jni/VeloxJniWrapper.cc                   | 38 ++++++++++
 cpp/velox/memory/VeloxColumnarBatch.cc             | 72 +++++++++++-------
 cpp/velox/memory/VeloxColumnarBatch.h              | 12 ++-
 cpp/velox/shuffle/VeloxHashShuffleWriter.cc        | 16 ++--
 cpp/velox/shuffle/VeloxRssSortShuffleWriter.cc     | 16 ++--
 cpp/velox/shuffle/VeloxSortShuffleWriter.cc        | 17 +++--
 cpp/velox/tests/RuntimeTest.cc                     |  2 +-
 cpp/velox/utils/tests/VeloxShuffleWriterTestBase.h |  8 +-
 .../columnarbatch/ColumnarBatchJniWrapper.java     |  2 -
 .../gluten/columnarbatch/ColumnarBatches.java      | 71 ++++++++----------
 ...rBatchUtil.java => SparkColumnarBatchUtil.java} |  8 +-
 .../apache/gluten/columnarbatch/ArrowBatches.scala | 14 ++--
 .../extension/columnar/transition/Transition.scala | 12 +--
 .../columnar/transition/TransitionGraph.scala      | 49 +++++++++---
 .../extension/columnar/transition/package.scala    |  9 +++
 .../columnar/transition/TransitionSuite.scala      | 10 +--
 30 files changed, 398 insertions(+), 336 deletions(-)

diff --git a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatchJniWrapper.java b/backends-velox/src/main/java/org/apache/gluten/columnarbatch/VeloxColumnarBatchJniWrapper.java
similarity index 61%
copy from gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatchJniWrapper.java
copy to backends-velox/src/main/java/org/apache/gluten/columnarbatch/VeloxColumnarBatchJniWrapper.java
index 464ac1aecb..4a86012b55 100644
--- a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatchJniWrapper.java
+++ b/backends-velox/src/main/java/org/apache/gluten/columnarbatch/VeloxColumnarBatchJniWrapper.java
@@ -19,37 +19,21 @@ package org.apache.gluten.columnarbatch;
 import org.apache.gluten.runtime.Runtime;
 import org.apache.gluten.runtime.RuntimeAware;
 
-public class ColumnarBatchJniWrapper implements RuntimeAware {
+public class VeloxColumnarBatchJniWrapper implements RuntimeAware {
   private final Runtime runtime;
 
-  private ColumnarBatchJniWrapper(Runtime runtime) {
+  private VeloxColumnarBatchJniWrapper(Runtime runtime) {
     this.runtime = runtime;
   }
 
-  public static ColumnarBatchJniWrapper create(Runtime runtime) {
-    return new ColumnarBatchJniWrapper(runtime);
+  public static VeloxColumnarBatchJniWrapper create(Runtime runtime) {
+    return new VeloxColumnarBatchJniWrapper(runtime);
   }
 
-  public native long createWithArrowArray(long cSchema, long cArray);
-
-  public native long getForEmptySchema(int numRows);
-
-  public native String getType(long batch);
-
-  public native long numColumns(long batch);
-
-  public native long numRows(long batch);
-
-  public native long numBytes(long batch);
+  public native long from(long batch);
 
   public native long compose(long[] batches);
 
-  public native void exportToArrow(long batch, long cSchema, long cArray);
-
-  public native long select(long batch, int[] columnIndices);
-
-  public native void close(long batch);
-
   @Override
   public long rtHandle() {
     return runtime.getHandle();
diff --git a/backends-velox/src/main/java/org/apache/gluten/columnarbatch/VeloxColumnarBatches.java b/backends-velox/src/main/java/org/apache/gluten/columnarbatch/VeloxColumnarBatches.java
new file mode 100644
index 0000000000..36d5a360d0
--- /dev/null
+++ b/backends-velox/src/main/java/org/apache/gluten/columnarbatch/VeloxColumnarBatches.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.columnarbatch;
+
+import org.apache.gluten.runtime.Runtime;
+import org.apache.gluten.runtime.Runtimes;
+
+import com.google.common.base.Preconditions;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+import org.apache.spark.sql.vectorized.SparkColumnarBatchUtil;
+
+import java.util.Arrays;
+import java.util.Objects;
+
+public final class VeloxColumnarBatches {
+  public static final String COMPREHENSIVE_TYPE_VELOX = "velox";
+
+  public static void checkVeloxBatch(ColumnarBatch batch) {
+    final String comprehensiveType = ColumnarBatches.getComprehensiveLightBatchType(batch);
+    Preconditions.checkArgument(
+        Objects.equals(comprehensiveType, COMPREHENSIVE_TYPE_VELOX),
+        String.format(
+            "Expected comprehensive batch type %s, but got %s",
+            COMPREHENSIVE_TYPE_VELOX, comprehensiveType));
+  }
+
+  public static void checkNonVeloxBatch(ColumnarBatch batch) {
+    final String comprehensiveType = ColumnarBatches.getComprehensiveLightBatchType(batch);
+    Preconditions.checkArgument(
+        !Objects.equals(comprehensiveType, COMPREHENSIVE_TYPE_VELOX),
+        String.format("Comprehensive batch type is already %s", COMPREHENSIVE_TYPE_VELOX));
+  }
+
+  public static ColumnarBatch toVeloxBatch(ColumnarBatch input) {
+    checkNonVeloxBatch(input);
+    final Runtime runtime = Runtimes.contextInstance("VeloxColumnarBatches#toVeloxBatch");
+    final long handle = ColumnarBatches.getNativeHandle(input);
+    final long outHandle = VeloxColumnarBatchJniWrapper.create(runtime).from(handle);
+    final ColumnarBatch output = ColumnarBatches.create(outHandle);
+
+    // Follow input's reference count. This might be optimized using
+    // automatic clean-up or once the extensibility of ColumnarBatch is enriched
+    final long refCnt = ColumnarBatches.getRefCntLight(input);
+    final IndicatorVector giv = (IndicatorVector) output.column(0);
+    for (long i = 0; i < (refCnt - 1); i++) {
+      giv.retain();
+    }
+
+    // close the input one
+    for (long i = 0; i < refCnt; i++) {
+      input.close();
+    }
+
+    // Populate new vectors to input.
+    SparkColumnarBatchUtil.transferVectors(output, input);
+
+    return input;
+  }
+
+  /**
+   * Combine multiple columnar batches horizontally, assuming each of them is already offloaded.
+   * Otherwise {@link UnsupportedOperationException} will be thrown.
+   */
+  public static ColumnarBatch compose(ColumnarBatch... batches) {
+    final Runtime runtime = Runtimes.contextInstance("VeloxColumnarBatches#compose");
+    final long[] handles =
+        Arrays.stream(batches).mapToLong(ColumnarBatches::getNativeHandle).toArray();
+    final long handle = VeloxColumnarBatchJniWrapper.create(runtime).compose(handles);
+    return ColumnarBatches.create(handle);
+  }
+}
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/columnarbatch/VeloxBatch.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/columnarbatch/VeloxBatch.scala
index 0fdf0dc532..0c7600c856 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/columnarbatch/VeloxBatch.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/columnarbatch/VeloxBatch.scala
@@ -16,14 +16,12 @@
  */
 package org.apache.gluten.columnarbatch
 
-import org.apache.gluten.execution.{RowToVeloxColumnarExec, 
VeloxColumnarToRowExec}
+import org.apache.gluten.execution.{ArrowColumnarToVeloxColumnarExec, 
RowToVeloxColumnarExec, VeloxColumnarToRowExec}
 import org.apache.gluten.extension.columnar.transition.{Convention, Transition}
 
 object VeloxBatch extends Convention.BatchType {
   fromRow(RowToVeloxColumnarExec.apply)
   toRow(VeloxColumnarToRowExec.apply)
-  // TODO: Add explicit transitions between Arrow native batch and Velox batch.
-  //  See https://github.com/apache/incubator-gluten/issues/7313.
-  fromBatch(ArrowBatches.ArrowNativeBatch, Transition.empty)
+  fromBatch(ArrowBatches.ArrowNativeBatch, 
ArrowColumnarToVeloxColumnarExec.apply)
   toBatch(ArrowBatches.ArrowNativeBatch, Transition.empty)
 }
diff --git 
a/backends-velox/src/main/scala/org/apache/gluten/columnarbatch/VeloxBatch.scala
 
b/backends-velox/src/main/scala/org/apache/gluten/execution/ArrowColumnarToVeloxColumnarExec.scala
similarity index 51%
copy from 
backends-velox/src/main/scala/org/apache/gluten/columnarbatch/VeloxBatch.scala
copy to 
backends-velox/src/main/scala/org/apache/gluten/execution/ArrowColumnarToVeloxColumnarExec.scala
index 0fdf0dc532..0ab51772e1 100644
--- 
a/backends-velox/src/main/scala/org/apache/gluten/columnarbatch/VeloxBatch.scala
+++ 
b/backends-velox/src/main/scala/org/apache/gluten/execution/ArrowColumnarToVeloxColumnarExec.scala
@@ -14,16 +14,23 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
-package org.apache.gluten.columnarbatch
+package org.apache.gluten.execution
 
-import org.apache.gluten.execution.{RowToVeloxColumnarExec, 
VeloxColumnarToRowExec}
-import org.apache.gluten.extension.columnar.transition.{Convention, Transition}
+import org.apache.gluten.columnarbatch.{VeloxBatch, VeloxColumnarBatches}
+import org.apache.gluten.columnarbatch.ArrowBatches.ArrowNativeBatch
 
-object VeloxBatch extends Convention.BatchType {
-  fromRow(RowToVeloxColumnarExec.apply)
-  toRow(VeloxColumnarToRowExec.apply)
-  // TODO: Add explicit transitions between Arrow native batch and Velox batch.
-  //  See https://github.com/apache/incubator-gluten/issues/7313.
-  fromBatch(ArrowBatches.ArrowNativeBatch, Transition.empty)
-  toBatch(ArrowBatches.ArrowNativeBatch, Transition.empty)
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+case class ArrowColumnarToVeloxColumnarExec(override val child: SparkPlan)
+  extends ColumnarToColumnarExec(ArrowNativeBatch, VeloxBatch) {
+  override protected def mapIterator(in: Iterator[ColumnarBatch]): 
Iterator[ColumnarBatch] = {
+    in.map {
+      b =>
+        val out = VeloxColumnarBatches.toVeloxBatch(b)
+        out
+    }
+  }
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    ArrowColumnarToVeloxColumnarExec(child = newChild)
 }
diff --git 
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala
 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala
index 2e8a6a479a..32bac02045 100644
--- 
a/backends-velox/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala
+++ 
b/backends-velox/src/main/scala/org/apache/spark/sql/execution/utils/ExecUtil.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.spark.sql.execution.utils
 
-import org.apache.gluten.columnarbatch.ColumnarBatches
+import org.apache.gluten.columnarbatch.{ColumnarBatches, VeloxColumnarBatches}
 import org.apache.gluten.iterator.Iterators
 import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
 import org.apache.gluten.runtime.Runtimes
@@ -145,13 +145,14 @@ object ExecUtil {
                     val pid = 
rangePartitioner.get.getPartition(partitionKeyExtractor(row))
                     pidVec.putInt(i, pid)
                 }
-                val pidBatch = ColumnarBatches.offload(
-                  ArrowBufferAllocators.contextInstance(),
-                  new ColumnarBatch(Array[ColumnVector](pidVec), cb.numRows))
-                val newHandle = ColumnarBatches.compose(pidBatch, cb)
+                val pidBatch = VeloxColumnarBatches.toVeloxBatch(
+                  ColumnarBatches.offload(
+                    ArrowBufferAllocators.contextInstance(),
+                    new ColumnarBatch(Array[ColumnVector](pidVec), 
cb.numRows)))
+                val newBatch = VeloxColumnarBatches.compose(pidBatch, cb)
                 // Composed batch already hold pidBatch's shared ref, so close 
is safe.
                 ColumnarBatches.forceClose(pidBatch)
-                (0, ColumnarBatches.create(newHandle))
+                (0, newBatch)
             })
         .recyclePayload(p => ColumnarBatches.forceClose(p._2)) // FIXME why 
force close?
         .create()
diff --git 
a/backends-velox/src/test/java/org/apache/gluten/columnarbatch/ColumnarBatchTest.java
 
b/backends-velox/src/test/java/org/apache/gluten/columnarbatch/ColumnarBatchTest.java
index a207a4b326..04221ec3ad 100644
--- 
a/backends-velox/src/test/java/org/apache/gluten/columnarbatch/ColumnarBatchTest.java
+++ 
b/backends-velox/src/test/java/org/apache/gluten/columnarbatch/ColumnarBatchTest.java
@@ -21,6 +21,7 @@ import 
org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators;
 import org.apache.gluten.test.VeloxBackendTestBase;
 import org.apache.gluten.vectorized.ArrowWritableColumnVector;
 
+import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
@@ -128,6 +129,40 @@ public class ColumnarBatchTest extends 
VeloxBackendTestBase {
         });
   }
 
+  @Test
+  public void testCompose() {
+    TaskResources$.MODULE$.runUnsafe(
+        () -> {
+          final int numRows = 20;
+          final ColumnarBatch batch1 = newArrowBatch("a boolean, b int", 
numRows);
+          final ColumnarBatch batch2 = newArrowBatch("b int, a boolean", 
numRows);
+          final ArrowWritableColumnVector col0 = (ArrowWritableColumnVector) 
batch1.column(0);
+          final ArrowWritableColumnVector col1 = (ArrowWritableColumnVector) 
batch1.column(1);
+          final ArrowWritableColumnVector col2 = (ArrowWritableColumnVector) 
batch2.column(0);
+          final ArrowWritableColumnVector col3 = (ArrowWritableColumnVector) 
batch2.column(1);
+          for (int j = 0; j < numRows; j++) {
+            col0.putBoolean(j, j % 2 == 0);
+            col1.putInt(j, 15 - j);
+            col2.putInt(j, 15 - j);
+            col3.putBoolean(j, j % 2 == 0);
+          }
+          ColumnarBatches.offload(ArrowBufferAllocators.contextInstance(), 
batch1);
+          ColumnarBatches.offload(ArrowBufferAllocators.contextInstance(), 
batch2);
+          VeloxColumnarBatches.toVeloxBatch(batch1);
+          VeloxColumnarBatches.toVeloxBatch(batch2);
+          final ColumnarBatch batch3 = VeloxColumnarBatches.compose(batch1, 
batch2);
+          Assert.assertEquals(
+              VeloxColumnarBatches.COMPREHENSIVE_TYPE_VELOX,
+              ColumnarBatches.getComprehensiveLightBatchType(batch3));
+
+          Assert.assertEquals(numRows, batch3.numRows());
+          Assert.assertEquals(4, batch3.numCols());
+          Assert.assertEquals(
+              "[false,14,14,false]\n[true,13,13,true]", 
ColumnarBatches.toString(batch3, 1, 2));
+          return null;
+        });
+  }
+
   @Test
   public void testToString() {
     TaskResources$.MODULE$.runUnsafe(
@@ -146,7 +181,9 @@ public class ColumnarBatchTest extends VeloxBackendTestBase 
{
           structType = structType.add("b", DataTypes.IntegerType, true);
           ColumnarBatch veloxBatch =
               RowToVeloxColumnarExec.toColumnarBatchIterator(
-                      JavaConverters.asScalaIterator(batch.rowIterator()), 
structType, numRows)
+                      
JavaConverters.<InternalRow>asScalaIterator(batch.rowIterator()),
+                      structType,
+                      numRows)
                   .next();
           Assert.assertEquals("[true,15]\n[false,14]", 
ColumnarBatches.toString(veloxBatch, 0, 2));
           Assert.assertEquals(
diff --git 
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala
 
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala
index 24540743e1..d12faae0f7 100644
--- 
a/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala
+++ 
b/backends-velox/src/test/scala/org/apache/gluten/extension/columnar/transition/VeloxTransitionSuite.scala
@@ -19,7 +19,7 @@ package org.apache.gluten.extension.columnar.transition
 import org.apache.gluten.backendsapi.velox.VeloxListenerApi
 import org.apache.gluten.columnarbatch.ArrowBatches.{ArrowJavaBatch, 
ArrowNativeBatch}
 import org.apache.gluten.columnarbatch.VeloxBatch
-import org.apache.gluten.execution.{LoadArrowDataExec, OffloadArrowDataExec, 
RowToVeloxColumnarExec, VeloxColumnarToRowExec}
+import org.apache.gluten.execution.{ArrowColumnarToVeloxColumnarExec, 
LoadArrowDataExec, OffloadArrowDataExec, RowToVeloxColumnarExec, 
VeloxColumnarToRowExec}
 import 
org.apache.gluten.extension.columnar.transition.Convention.BatchType.VanillaBatch
 import org.apache.gluten.test.MockVeloxBackend
 
@@ -52,21 +52,21 @@ class VeloxTransitionSuite extends SharedSparkSession {
   test("ArrowNative C2R - outputs row") {
     val in = BatchLeaf(ArrowNativeBatch)
     val out = Transitions.insertTransitions(in, outputsColumnar = false)
-    assert(out == VeloxColumnarToRowExec(BatchLeaf(ArrowNativeBatch)))
+    assert(out == 
ColumnarToRowExec(LoadArrowDataExec(BatchLeaf(ArrowNativeBatch))))
   }
 
   test("ArrowNative C2R - requires row input") {
     val in = RowUnary(BatchLeaf(ArrowNativeBatch))
     val out = Transitions.insertTransitions(in, outputsColumnar = false)
-    assert(out == 
RowUnary(VeloxColumnarToRowExec(BatchLeaf(ArrowNativeBatch))))
+    assert(out == 
RowUnary(ColumnarToRowExec(LoadArrowDataExec(BatchLeaf(ArrowNativeBatch)))))
   }
 
   test("ArrowNative R2C - requires Arrow input") {
     val in = BatchUnary(ArrowNativeBatch, RowLeaf())
     val out = Transitions.insertTransitions(in, outputsColumnar = false)
     assert(
-      out == VeloxColumnarToRowExec(
-        BatchUnary(ArrowNativeBatch, RowToVeloxColumnarExec(RowLeaf()))))
+      out == ColumnarToRowExec(
+        LoadArrowDataExec(BatchUnary(ArrowNativeBatch, 
RowToVeloxColumnarExec(RowLeaf())))))
   }
 
   test("ArrowNative-to-Velox C2C") {
@@ -75,23 +75,27 @@ class VeloxTransitionSuite extends SharedSparkSession {
     // No explicit transition needed for ArrowNative-to-Velox.
     // FIXME: Add explicit transitions.
     //  See https://github.com/apache/incubator-gluten/issues/7313.
-    assert(out == VeloxColumnarToRowExec(BatchUnary(VeloxBatch, 
BatchLeaf(ArrowNativeBatch))))
+    assert(
+      out == VeloxColumnarToRowExec(
+        BatchUnary(VeloxBatch, 
ArrowColumnarToVeloxColumnarExec(BatchLeaf(ArrowNativeBatch)))))
   }
 
   test("Velox-to-ArrowNative C2C") {
     val in = BatchUnary(ArrowNativeBatch, BatchLeaf(VeloxBatch))
     val out = Transitions.insertTransitions(in, outputsColumnar = false)
-    assert(out == VeloxColumnarToRowExec(BatchUnary(ArrowNativeBatch, 
BatchLeaf(VeloxBatch))))
+    assert(
+      out == ColumnarToRowExec(
+        LoadArrowDataExec(BatchUnary(ArrowNativeBatch, 
BatchLeaf(VeloxBatch)))))
   }
 
   test("Vanilla-to-ArrowNative C2C") {
     val in = BatchUnary(ArrowNativeBatch, BatchLeaf(VanillaBatch))
     val out = Transitions.insertTransitions(in, outputsColumnar = false)
     assert(
-      out == VeloxColumnarToRowExec(
-        BatchUnary(
+      out == ColumnarToRowExec(
+        LoadArrowDataExec(BatchUnary(
           ArrowNativeBatch,
-          RowToVeloxColumnarExec(ColumnarToRowExec(BatchLeaf(VanillaBatch))))))
+          
RowToVeloxColumnarExec(ColumnarToRowExec(BatchLeaf(VanillaBatch)))))))
   }
 
   test("ArrowNative-to-Vanilla C2C") {
@@ -127,7 +131,9 @@ class VeloxTransitionSuite extends SharedSparkSession {
     val out = Transitions.insertTransitions(in, outputsColumnar = false)
     assert(
       out == VeloxColumnarToRowExec(
-        BatchUnary(VeloxBatch, 
OffloadArrowDataExec(BatchLeaf(ArrowJavaBatch)))))
+        BatchUnary(
+          VeloxBatch,
+          
ArrowColumnarToVeloxColumnarExec(OffloadArrowDataExec(BatchLeaf(ArrowJavaBatch))))))
   }
 
   test("Velox-to-ArrowJava C2C") {
diff --git a/cpp/core/compute/Runtime.h b/cpp/core/compute/Runtime.h
index 8bdf95cd73..60be778ef1 100644
--- a/cpp/core/compute/Runtime.h
+++ b/cpp/core/compute/Runtime.h
@@ -89,7 +89,7 @@ class Runtime : public std::enable_shared_from_this<Runtime> {
 
   virtual std::shared_ptr<ColumnarBatch> createOrGetEmptySchemaBatch(int32_t 
numRows) = 0;
 
-  virtual std::shared_ptr<ColumnarBatch> 
select(std::shared_ptr<ColumnarBatch>, std::vector<int32_t>) = 0;
+  virtual std::shared_ptr<ColumnarBatch> 
select(std::shared_ptr<ColumnarBatch>, const std::vector<int32_t>&) = 0;
 
   virtual MemoryManager* memoryManager() {
     return memoryManager_.get();
diff --git a/cpp/core/jni/JniWrapper.cc b/cpp/core/jni/JniWrapper.cc
index 8409001cee..13ea8492cb 100644
--- a/cpp/core/jni/JniWrapper.cc
+++ b/cpp/core/jni/JniWrapper.cc
@@ -667,27 +667,6 @@ JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_columnarbatch_ColumnarBatchJniWra
   JNI_METHOD_END(kInvalidObjectHandle)
 }
 
-JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_columnarbatch_ColumnarBatchJniWrapper_compose( // NOLINT
-    JNIEnv* env,
-    jobject wrapper,
-    jlongArray batchHandles) {
-  JNI_METHOD_START
-  auto ctx = gluten::getRuntime(env, wrapper);
-
-  int handleCount = env->GetArrayLength(batchHandles);
-  auto safeArray = gluten::getLongArrayElementsSafe(env, batchHandles);
-
-  std::vector<std::shared_ptr<ColumnarBatch>> batches;
-  for (int i = 0; i < handleCount; ++i) {
-    int64_t handle = safeArray.elems()[i];
-    auto batch = ObjectStore::retrieve<ColumnarBatch>(handle);
-    batches.push_back(batch);
-  }
-  auto newBatch = CompositeColumnarBatch::create(std::move(batches));
-  return ctx->saveObject(newBatch);
-  JNI_METHOD_END(kInvalidObjectHandle)
-}
-
 JNIEXPORT void JNICALL 
Java_org_apache_gluten_columnarbatch_ColumnarBatchJniWrapper_exportToArrow( // 
NOLINT
     JNIEnv* env,
     jobject wrapper,
diff --git a/cpp/core/memory/ColumnarBatch.cc b/cpp/core/memory/ColumnarBatch.cc
index 66834c7e39..3662e08605 100644
--- a/cpp/core/memory/ColumnarBatch.cc
+++ b/cpp/core/memory/ColumnarBatch.cc
@@ -127,90 +127,4 @@ std::vector<char> 
ArrowCStructColumnarBatch::toUnsafeRow(int32_t rowId) const {
   throw gluten::GlutenException("#toUnsafeRow of ArrowCStructColumnarBatch is 
not implemented");
 }
 
-std::shared_ptr<ColumnarBatch> 
CompositeColumnarBatch::create(std::vector<std::shared_ptr<ColumnarBatch>> 
batches) {
-  int32_t numRows = -1;
-  int32_t numColumns = 0;
-  for (const auto& batch : batches) {
-    if (numRows == -1) {
-      numRows = batch->numRows();
-    } else if (batch->numRows() != numRows) {
-      throw GlutenException("Mismatched row counts among the input batches 
during creating CompositeColumnarBatch");
-    }
-    numColumns += batch->numColumns();
-  }
-  return std::shared_ptr<ColumnarBatch>(new CompositeColumnarBatch(numColumns, 
numRows, std::move(batches)));
-}
-
-std::string CompositeColumnarBatch::getType() const {
-  return "composite";
-}
-
-int64_t CompositeColumnarBatch::numBytes() {
-  if (compositeBatch_) {
-    return compositeBatch_->numBytes();
-  } else {
-    int64_t numBytes = 0L;
-    for (const auto& batch : batches_) {
-      numBytes += batch->numBytes();
-    }
-    return numBytes;
-  }
-}
-
-std::shared_ptr<ArrowArray> CompositeColumnarBatch::exportArrowArray() {
-  ensureUnderlyingBatchCreated();
-  return compositeBatch_->exportArrowArray();
-}
-
-std::shared_ptr<ArrowSchema> CompositeColumnarBatch::exportArrowSchema() {
-  ensureUnderlyingBatchCreated();
-  return compositeBatch_->exportArrowSchema();
-}
-
-const std::vector<std::shared_ptr<ColumnarBatch>>& 
CompositeColumnarBatch::getBatches() const {
-  return batches_;
-}
-
-std::vector<char> CompositeColumnarBatch::toUnsafeRow(int32_t rowId) const {
-  throw gluten::GlutenException("#toUnsafeRow of CompositeColumnarBatch is not 
implemented");
-}
-
-CompositeColumnarBatch::CompositeColumnarBatch(
-    int32_t numColumns,
-    int32_t numRows,
-    std::vector<std::shared_ptr<ColumnarBatch>> batches)
-    : ColumnarBatch(numColumns, numRows) {
-  this->batches_ = std::move(batches);
-}
-
-void CompositeColumnarBatch::ensureUnderlyingBatchCreated() {
-  if (compositeBatch_ != nullptr) {
-    return;
-  }
-  std::vector<std::shared_ptr<arrow::RecordBatch>> arrowBatches;
-  for (const auto& batch : batches_) {
-    auto cSchema = batch->exportArrowSchema();
-    auto cArray = batch->exportArrowArray();
-    auto arrowBatch = 
gluten::arrowGetOrThrow(arrow::ImportRecordBatch(cArray.get(), cSchema.get()));
-    arrowBatches.push_back(std::move(arrowBatch));
-  }
-
-  std::vector<std::shared_ptr<arrow::Field>> fields;
-  std::vector<std::shared_ptr<arrow::ArrayData>> arrays;
-
-  for (const auto& batch : arrowBatches) {
-    if (batch->schema()->metadata() != nullptr) {
-      throw gluten::GlutenException("Schema metadata not allowed");
-    }
-    for (const auto& field : batch->schema()->fields()) {
-      fields.push_back(field);
-    }
-    for (const auto& col : batch->column_data()) {
-      arrays.push_back(col);
-    }
-  }
-  compositeBatch_ = std::make_shared<ArrowColumnarBatch>(
-      arrow::RecordBatch::Make(std::make_shared<arrow::Schema>(fields), 
numRows(), arrays));
-}
-
 } // namespace gluten
diff --git a/cpp/core/memory/ColumnarBatch.h b/cpp/core/memory/ColumnarBatch.h
index 2495de2485..e0bab25418 100644
--- a/cpp/core/memory/ColumnarBatch.h
+++ b/cpp/core/memory/ColumnarBatch.h
@@ -103,39 +103,6 @@ class ArrowCStructColumnarBatch final : public 
ColumnarBatch {
   std::shared_ptr<ArrowArray> cArray_ = std::make_shared<ArrowArray>();
 };
 
-/**
- * A columnar batch implementations that creates a view on top of a couple of 
child batches.
- * Fields in the child batches will be organized horizontally in the parent 
batch.
- */
-class CompositeColumnarBatch final : public ColumnarBatch {
- public:
-  static std::shared_ptr<ColumnarBatch> 
create(std::vector<std::shared_ptr<ColumnarBatch>> batches);
-
-  std::string getType() const override;
-
-  int64_t numBytes() override;
-
-  std::shared_ptr<ArrowArray> exportArrowArray() override;
-
-  std::shared_ptr<ArrowSchema> exportArrowSchema() override;
-
-  const std::vector<std::shared_ptr<ColumnarBatch>>& getBatches() const;
-
-  std::vector<char> toUnsafeRow(int32_t rowId) const override;
-
- private:
-  explicit CompositeColumnarBatch(
-      int32_t numColumns,
-      int32_t numRows,
-      std::vector<std::shared_ptr<ColumnarBatch>> batches);
-
-  // We use ArrowColumnarBatch as the way to compose columnar batches
-  void ensureUnderlyingBatchCreated();
-
-  std::vector<std::shared_ptr<ColumnarBatch>> batches_;
-  std::shared_ptr<ColumnarBatch> compositeBatch_ = nullptr;
-};
-
 std::shared_ptr<ColumnarBatch> createZeroColumnBatch(int32_t numRows);
 
 } // namespace gluten
diff --git a/cpp/velox/benchmarks/common/BenchmarkUtils.h 
b/cpp/velox/benchmarks/common/BenchmarkUtils.h
index 4f4f5b75df..0108f1d448 100644
--- a/cpp/velox/benchmarks/common/BenchmarkUtils.h
+++ b/cpp/velox/benchmarks/common/BenchmarkUtils.h
@@ -89,13 +89,7 @@ bool checkPathExists(const std::string& filepath);
 void abortIfFileNotExists(const std::string& filepath);
 
 inline std::shared_ptr<gluten::ColumnarBatch> 
convertBatch(std::shared_ptr<gluten::ColumnarBatch> cb) {
-  if (cb->getType() != "velox") {
-    auto vp = facebook::velox::importFromArrowAsOwner(
-        *cb->exportArrowSchema(), *cb->exportArrowArray(), 
gluten::defaultLeafVeloxMemoryPool().get());
-    return 
std::make_shared<gluten::VeloxColumnarBatch>(std::dynamic_pointer_cast<facebook::velox::RowVector>(vp));
-  } else {
-    return cb;
-  }
+  return 
gluten::VeloxColumnarBatch::from(gluten::defaultLeafVeloxMemoryPool().get(), 
cb);
 }
 
 /// Return whether the data ends with suffix.
diff --git a/cpp/velox/compute/VeloxRuntime.cc 
b/cpp/velox/compute/VeloxRuntime.cc
index 996fcb8507..5b93fb7d53 100644
--- a/cpp/velox/compute/VeloxRuntime.cc
+++ b/cpp/velox/compute/VeloxRuntime.cc
@@ -174,7 +174,7 @@ std::shared_ptr<ColumnarBatch> 
VeloxRuntime::createOrGetEmptySchemaBatch(int32_t
 
 std::shared_ptr<ColumnarBatch> VeloxRuntime::select(
     std::shared_ptr<ColumnarBatch> batch,
-    std::vector<int32_t> columnIndices) {
+    const std::vector<int32_t>& columnIndices) {
   auto veloxPool = vmm_->getLeafMemoryPool();
   auto veloxBatch = gluten::VeloxColumnarBatch::from(veloxPool.get(), batch);
   auto outputBatch = veloxBatch->select(veloxPool.get(), 
std::move(columnIndices));
diff --git a/cpp/velox/compute/VeloxRuntime.h b/cpp/velox/compute/VeloxRuntime.h
index 3460677d91..74f59639d7 100644
--- a/cpp/velox/compute/VeloxRuntime.h
+++ b/cpp/velox/compute/VeloxRuntime.h
@@ -53,7 +53,7 @@ class VeloxRuntime final : public Runtime {
 
   std::shared_ptr<ColumnarBatch> createOrGetEmptySchemaBatch(int32_t numRows) 
override;
 
-  std::shared_ptr<ColumnarBatch> select(std::shared_ptr<ColumnarBatch> batch, 
std::vector<int32_t> columnIndices)
+  std::shared_ptr<ColumnarBatch> select(std::shared_ptr<ColumnarBatch> batch, 
const std::vector<int32_t>& columnIndices)
       override;
 
   std::shared_ptr<RowToColumnarConverter> createRow2ColumnarConverter(struct 
ArrowSchema* cSchema) override;
diff --git a/cpp/velox/jni/VeloxJniWrapper.cc b/cpp/velox/jni/VeloxJniWrapper.cc
index 37c90643a5..22136ad297 100644
--- a/cpp/velox/jni/VeloxJniWrapper.cc
+++ b/cpp/velox/jni/VeloxJniWrapper.cc
@@ -28,6 +28,7 @@
 #include "config/GlutenConfig.h"
 #include "jni/JniError.h"
 #include "jni/JniFileSystem.h"
+#include "memory/VeloxColumnarBatch.h"
 #include "memory/VeloxMemoryManager.h"
 #include "substrait/SubstraitToVeloxPlanValidator.h"
 #include "utils/ObjectStore.h"
@@ -152,6 +153,43 @@ 
Java_org_apache_gluten_vectorized_PlanEvaluatorJniWrapper_nativeValidateWithFail
   JNI_METHOD_END(nullptr)
 }
 
+JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_columnarbatch_VeloxColumnarBatchJniWrapper_from( // 
NOLINT
+    JNIEnv* env,
+    jobject wrapper,
+    jlong handle) {
+  JNI_METHOD_START
+  auto ctx = gluten::getRuntime(env, wrapper);
+  auto runtime = dynamic_cast<gluten::VeloxRuntime*>(ctx);
+
+  auto batch = gluten::ObjectStore::retrieve<gluten::ColumnarBatch>(handle);
+  auto newBatch = 
gluten::VeloxColumnarBatch::from(runtime->memoryManager()->getLeafMemoryPool().get(),
 batch);
+  return ctx->saveObject(newBatch);
+  JNI_METHOD_END(gluten::kInvalidObjectHandle)
+}
+
+JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_columnarbatch_VeloxColumnarBatchJniWrapper_compose( // 
NOLINT
+    JNIEnv* env,
+    jobject wrapper,
+    jlongArray batchHandles) {
+  JNI_METHOD_START
+  auto ctx = gluten::getRuntime(env, wrapper);
+  auto runtime = dynamic_cast<gluten::VeloxRuntime*>(ctx);
+
+  int handleCount = env->GetArrayLength(batchHandles);
+  auto safeArray = gluten::getLongArrayElementsSafe(env, batchHandles);
+
+  std::vector<std::shared_ptr<gluten::ColumnarBatch>> batches;
+  for (int i = 0; i < handleCount; ++i) {
+    int64_t handle = safeArray.elems()[i];
+    auto batch = gluten::ObjectStore::retrieve<gluten::ColumnarBatch>(handle);
+    batches.push_back(batch);
+  }
+  auto newBatch =
+      
gluten::VeloxColumnarBatch::compose(runtime->memoryManager()->getLeafMemoryPool().get(),
 std::move(batches));
+  return ctx->saveObject(newBatch);
+  JNI_METHOD_END(gluten::kInvalidObjectHandle)
+}
+
 JNIEXPORT jlong JNICALL 
Java_org_apache_gluten_utils_VeloxBloomFilterJniWrapper_empty( // NOLINT
     JNIEnv* env,
     jobject wrapper,
diff --git a/cpp/velox/memory/VeloxColumnarBatch.cc 
b/cpp/velox/memory/VeloxColumnarBatch.cc
index 0d8db31272..5b7fba9796 100644
--- a/cpp/velox/memory/VeloxColumnarBatch.cc
+++ b/cpp/velox/memory/VeloxColumnarBatch.cc
@@ -30,17 +30,18 @@ using namespace facebook::velox;
 namespace {
 
 RowVectorPtr makeRowVector(
-    std::vector<std::string> childNames,
-    const std::vector<VectorPtr>& children,
+    velox::memory::MemoryPool* pool,
     int32_t numRows,
-    velox::memory::MemoryPool* pool) {
+    std::vector<std::string> childNames,
+    BufferPtr nulls,
+    const std::vector<VectorPtr>& children) {
   std::vector<std::shared_ptr<const Type>> childTypes;
   childTypes.resize(children.size());
   for (int i = 0; i < children.size(); i++) {
     childTypes[i] = children[i]->type();
   }
   auto rowType = ROW(std::move(childNames), std::move(childTypes));
-  return std::make_shared<RowVector>(pool, rowType, BufferPtr(nullptr), 
numRows, std::move(children));
+  return std::make_shared<RowVector>(pool, rowType, nulls, numRows, 
std::move(children));
 }
 } // namespace
 
@@ -97,34 +98,51 @@ std::shared_ptr<VeloxColumnarBatch> 
VeloxColumnarBatch::from(
   if (cb->getType() == "velox") {
     return std::dynamic_pointer_cast<VeloxColumnarBatch>(cb);
   }
-  if (cb->getType() == "composite") {
-    auto composite = 
std::dynamic_pointer_cast<gluten::CompositeColumnarBatch>(cb);
-    auto children = composite->getBatches();
-    std::vector<std::string> childNames;
-    std::vector<VectorPtr> childVectors;
-
-    for (const auto& child : children) {
-      auto asVelox = from(pool, child);
-      auto names = 
facebook::velox::asRowType(asVelox->getRowVector()->type())->names();
-      for (const auto& name : names) {
-        childNames.push_back(name);
-      }
-      auto vectors = asVelox->getRowVector()->children();
-      for (const auto& vector : vectors) {
-        childVectors.push_back(vector);
-      }
+  auto vp = velox::importFromArrowAsOwner(*cb->exportArrowSchema(), 
*cb->exportArrowArray(), pool);
+  return 
std::make_shared<VeloxColumnarBatch>(std::dynamic_pointer_cast<velox::RowVector>(vp));
+}
+
+std::shared_ptr<VeloxColumnarBatch> VeloxColumnarBatch::compose(
+    facebook::velox::memory::MemoryPool* pool,
+    const std::vector<std::shared_ptr<ColumnarBatch>>& batches) {
+  GLUTEN_CHECK(!batches.empty(), "No batches to compose");
+
+  int32_t numRows = -1;
+  for (const auto& batch : batches) {
+    GLUTEN_CHECK(batch->getType() == kType, "At least one of the input batches 
is not in Velox format");
+    if (numRows == -1) {
+      numRows = batch->numRows();
+      continue;
+    }
+    if (batch->numRows() != numRows) {
+      throw GlutenException("Mismatched row counts among the input batches 
during composing columnar batches");
     }
+    auto vb = std::dynamic_pointer_cast<VeloxColumnarBatch>(batch);
+    auto rv = vb->getRowVector();
+    GLUTEN_CHECK(rv->nulls() == nullptr, "Vectors to compose contain null 
bits");
+  }
+
+  GLUTEN_CHECK(numRows > 0, "Illegal state");
 
-    auto compositeVeloxVector = makeRowVector(childNames, childVectors, 
cb->numRows(), pool);
-    return std::make_shared<VeloxColumnarBatch>(compositeVeloxVector);
+  std::vector<std::string> childNames;
+  std::vector<VectorPtr> children;
+  for (const auto& batch : batches) {
+    auto vb = std::dynamic_pointer_cast<VeloxColumnarBatch>(batch);
+    auto rv = vb->getRowVector();
+    for (const std::string& name : rv->type()->asRow().names()) {
+      childNames.push_back(name);
+    }
+    for (const VectorPtr& vec : rv->children()) {
+      children.push_back(vec);
+    }
   }
-  auto vp = velox::importFromArrowAsOwner(*cb->exportArrowSchema(), 
*cb->exportArrowArray(), pool);
-  return 
std::make_shared<VeloxColumnarBatch>(std::dynamic_pointer_cast<velox::RowVector>(vp));
+  RowVectorPtr outRv = makeRowVector(pool, numRows, std::move(childNames), 
nullptr, std::move(children));
+  return std::make_shared<VeloxColumnarBatch>(outRv);
 }
 
-std::shared_ptr<ColumnarBatch> VeloxColumnarBatch::select(
+std::shared_ptr<VeloxColumnarBatch> VeloxColumnarBatch::select(
     facebook::velox::memory::MemoryPool* pool,
-    std::vector<int32_t> columnIndices) {
+    const std::vector<int32_t>& columnIndices) {
   std::vector<std::string> childNames;
   std::vector<VectorPtr> childVectors;
   childNames.reserve(columnIndices.size());
@@ -139,7 +157,7 @@ std::shared_ptr<ColumnarBatch> VeloxColumnarBatch::select(
     childVectors.push_back(child);
   }
 
-  auto rowVector = makeRowVector(std::move(childNames), 
std::move(childVectors), numRows(), pool);
+  auto rowVector = makeRowVector(pool, numRows(), std::move(childNames), 
vector->nulls(), std::move(childVectors));
   return std::make_shared<VeloxColumnarBatch>(rowVector);
 }
 
diff --git a/cpp/velox/memory/VeloxColumnarBatch.h 
b/cpp/velox/memory/VeloxColumnarBatch.h
index 6c79f2772d..cecadd3cdf 100644
--- a/cpp/velox/memory/VeloxColumnarBatch.h
+++ b/cpp/velox/memory/VeloxColumnarBatch.h
@@ -30,19 +30,25 @@ class VeloxColumnarBatch final : public ColumnarBatch {
       : ColumnarBatch(rowVector->childrenSize(), rowVector->size()), 
rowVector_(rowVector) {}
 
   std::string getType() const override {
-    return "velox";
+    return kType;
   }
 
   static std::shared_ptr<VeloxColumnarBatch> from(
       facebook::velox::memory::MemoryPool* pool,
       std::shared_ptr<ColumnarBatch> cb);
 
+  static std::shared_ptr<VeloxColumnarBatch> compose(
+      facebook::velox::memory::MemoryPool* pool,
+      const std::vector<std::shared_ptr<ColumnarBatch>>& batches);
+
   int64_t numBytes() override;
 
   std::shared_ptr<ArrowSchema> exportArrowSchema() override;
   std::shared_ptr<ArrowArray> exportArrowArray() override;
   std::vector<char> toUnsafeRow(int32_t rowId) const override;
-  std::shared_ptr<ColumnarBatch> select(facebook::velox::memory::MemoryPool* 
pool, std::vector<int32_t> columnIndices);
+  std::shared_ptr<VeloxColumnarBatch> select(
+      facebook::velox::memory::MemoryPool* pool,
+      const std::vector<int32_t>& columnIndices);
   facebook::velox::RowVectorPtr getRowVector() const;
   facebook::velox::RowVectorPtr getFlattenedRowVector();
 
@@ -51,6 +57,8 @@ class VeloxColumnarBatch final : public ColumnarBatch {
 
   facebook::velox::RowVectorPtr rowVector_ = nullptr;
   bool flattened_ = false;
+
+  inline static const std::string kType{"velox"};
 };
 
 } // namespace gluten
diff --git a/cpp/velox/shuffle/VeloxHashShuffleWriter.cc 
b/cpp/velox/shuffle/VeloxHashShuffleWriter.cc
index 4001a91dec..f044736142 100644
--- a/cpp/velox/shuffle/VeloxHashShuffleWriter.cc
+++ b/cpp/velox/shuffle/VeloxHashShuffleWriter.cc
@@ -262,11 +262,11 @@ arrow::Status 
VeloxHashShuffleWriter::write(std::shared_ptr<ColumnarBatch> cb, i
     }
     RETURN_NOT_OK(evictBuffers(0, rv.size(), std::move(buffers), false));
   } else if (options_.partitioning == Partitioning::kRange) {
-    auto compositeBatch = 
std::dynamic_pointer_cast<CompositeColumnarBatch>(cb);
-    VELOX_CHECK_NOT_NULL(compositeBatch);
-    auto batches = compositeBatch->getBatches();
-    VELOX_CHECK_EQ(batches.size(), 2);
-    auto pidBatch = VeloxColumnarBatch::from(veloxPool_.get(), batches[0]);
+    auto veloxColumnBatch = VeloxColumnarBatch::from(veloxPool_.get(), cb);
+    VELOX_CHECK_NOT_NULL(veloxColumnBatch);
+    const int32_t numColumns = veloxColumnBatch->numColumns();
+    VELOX_CHECK(numColumns >= 2);
+    auto pidBatch = veloxColumnBatch->select(veloxPool_.get(), {0});
     auto pidArr = getFirstColumn(*(pidBatch->getRowVector()));
     START_TIMING(cpuWallTimingList_[CpuWallTimingCompute]);
     std::fill(std::begin(partition2RowCount_), std::end(partition2RowCount_), 
0);
@@ -275,7 +275,11 @@ arrow::Status 
VeloxHashShuffleWriter::write(std::shared_ptr<ColumnarBatch> cb, i
       partition2RowCount_[pid]++;
     }
     END_TIMING();
-    auto rvBatch = VeloxColumnarBatch::from(veloxPool_.get(), batches[1]);
+    std::vector<int32_t> range;
+    for (int32_t i = 1; i < numColumns; i++) {
+      range.push_back(i);
+    }
+    auto rvBatch = veloxColumnBatch->select(veloxPool_.get(), range);
     auto& rv = *rvBatch->getFlattenedRowVector();
     RETURN_NOT_OK(initFromRowVector(rv));
     RETURN_NOT_OK(doSplit(rv, memLimit));
diff --git a/cpp/velox/shuffle/VeloxRssSortShuffleWriter.cc 
b/cpp/velox/shuffle/VeloxRssSortShuffleWriter.cc
index a3fdf5e273..34796e378e 100644
--- a/cpp/velox/shuffle/VeloxRssSortShuffleWriter.cc
+++ b/cpp/velox/shuffle/VeloxRssSortShuffleWriter.cc
@@ -70,17 +70,21 @@ arrow::Status 
VeloxRssSortShuffleWriter::write(std::shared_ptr<ColumnarBatch> cb
     RETURN_NOT_OK(initFromRowVector(*rv.get()));
     RETURN_NOT_OK(doSort(rv, 
partitionWriter_.get()->options().sortBufferMaxSize));
   } else if (options_.partitioning == Partitioning::kRange) {
-    auto compositeBatch = 
std::dynamic_pointer_cast<CompositeColumnarBatch>(cb);
-    VELOX_CHECK_NOT_NULL(compositeBatch);
-    auto batches = compositeBatch->getBatches();
-    VELOX_CHECK_EQ(batches.size(), 2);
-    auto pidBatch = VeloxColumnarBatch::from(veloxPool_.get(), batches[0]);
+    auto veloxColumnBatch = VeloxColumnarBatch::from(veloxPool_.get(), cb);
+    VELOX_CHECK_NOT_NULL(veloxColumnBatch);
+    const int32_t numColumns = veloxColumnBatch->numColumns();
+    VELOX_CHECK(numColumns >= 2);
+    auto pidBatch = veloxColumnBatch->select(veloxPool_.get(), {0});
     auto pidArr = getFirstColumn(*(pidBatch->getRowVector()));
     START_TIMING(cpuWallTimingList_[CpuWallTimingCompute]);
     setSortState(SortState::kSort);
     RETURN_NOT_OK(partitioner_->compute(pidArr, pidBatch->numRows(), 
batches_.size(), rowVectorIndexMap_));
     END_TIMING();
-    auto rvBatch = VeloxColumnarBatch::from(veloxPool_.get(), batches[1]);
+    std::vector<int32_t> range;
+    for (int32_t i = 1; i < numColumns; i++) {
+      range.push_back(i);
+    }
+    auto rvBatch = veloxColumnBatch->select(veloxPool_.get(), range);
     auto rv = rvBatch->getFlattenedRowVector();
     RETURN_NOT_OK(initFromRowVector(*rv.get()));
     RETURN_NOT_OK(doSort(rv, 
partitionWriter_.get()->options().sortBufferMaxSize));
diff --git a/cpp/velox/shuffle/VeloxSortShuffleWriter.cc 
b/cpp/velox/shuffle/VeloxSortShuffleWriter.cc
index 55aa739e7a..f87eaabb56 100644
--- a/cpp/velox/shuffle/VeloxSortShuffleWriter.cc
+++ b/cpp/velox/shuffle/VeloxSortShuffleWriter.cc
@@ -131,16 +131,19 @@ void VeloxSortShuffleWriter::initRowType(const 
facebook::velox::RowVectorPtr& rv
 arrow::Result<facebook::velox::RowVectorPtr> 
VeloxSortShuffleWriter::getPeeledRowVector(
     const std::shared_ptr<ColumnarBatch>& cb) {
   if (options_.partitioning == Partitioning::kRange) {
-    auto compositeBatch = 
std::dynamic_pointer_cast<CompositeColumnarBatch>(cb);
-    VELOX_CHECK_NOT_NULL(compositeBatch);
-    auto batches = compositeBatch->getBatches();
-    VELOX_CHECK_EQ(batches.size(), 2);
-
-    auto pidBatch = VeloxColumnarBatch::from(veloxPool_.get(), batches[0]);
+    auto veloxColumnBatch = VeloxColumnarBatch::from(veloxPool_.get(), cb);
+    VELOX_CHECK_NOT_NULL(veloxColumnBatch);
+    const int32_t numColumns = veloxColumnBatch->numColumns();
+    VELOX_CHECK(numColumns >= 2);
+    auto pidBatch = veloxColumnBatch->select(veloxPool_.get(), {0});
     auto pidArr = getFirstColumn(*(pidBatch->getRowVector()));
     RETURN_NOT_OK(partitioner_->compute(pidArr, pidBatch->numRows(), 
row2Partition_));
 
-    auto rvBatch = VeloxColumnarBatch::from(veloxPool_.get(), batches[1]);
+    std::vector<int32_t> range;
+    for (int32_t i = 1; i < numColumns; i++) {
+      range.push_back(i);
+    }
+    auto rvBatch = veloxColumnBatch->select(veloxPool_.get(), range);
     return rvBatch->getFlattenedRowVector();
   } else {
     auto veloxColumnBatch = VeloxColumnarBatch::from(veloxPool_.get(), cb);
diff --git a/cpp/velox/tests/RuntimeTest.cc b/cpp/velox/tests/RuntimeTest.cc
index 1a353816f7..b2aa3e1a86 100644
--- a/cpp/velox/tests/RuntimeTest.cc
+++ b/cpp/velox/tests/RuntimeTest.cc
@@ -89,7 +89,7 @@ class DummyRuntime final : public Runtime {
   std::unique_ptr<ColumnarBatchSerializer> 
createColumnarBatchSerializer(struct ArrowSchema* cSchema) override {
     throw GlutenException("Not yet implemented");
   }
-  std::shared_ptr<ColumnarBatch> select(std::shared_ptr<ColumnarBatch>, 
std::vector<int32_t>) override {
+  std::shared_ptr<ColumnarBatch> select(std::shared_ptr<ColumnarBatch>, const 
std::vector<int32_t>&) override {
     throw GlutenException("Not yet implemented");
   }
   std::string planString(bool details, const std::unordered_map<std::string, 
std::string>& sessionConf) override {
diff --git a/cpp/velox/utils/tests/VeloxShuffleWriterTestBase.h 
b/cpp/velox/utils/tests/VeloxShuffleWriterTestBase.h
index d9a2c1e2ea..f5dd6f4f38 100644
--- a/cpp/velox/utils/tests/VeloxShuffleWriterTestBase.h
+++ b/cpp/velox/utils/tests/VeloxShuffleWriterTestBase.h
@@ -492,13 +492,13 @@ class RangePartitioningShuffleWriter : public 
MultiplePartitioningShuffleWriter
 
     auto pid1 = makeRowVector({makeFlatVector<int32_t>({0, 1, 0, 1, 0, 1, 0, 
1, 0, 1})});
     auto rangeVector1 = makeRowVector(inputVector1_->children());
-    compositeBatch1_ = CompositeColumnarBatch::create(
-        {std::make_shared<VeloxColumnarBatch>(pid1), 
std::make_shared<VeloxColumnarBatch>(rangeVector1)});
+    compositeBatch1_ = VeloxColumnarBatch::compose(
+        pool(), {std::make_shared<VeloxColumnarBatch>(pid1), 
std::make_shared<VeloxColumnarBatch>(rangeVector1)});
 
     auto pid2 = makeRowVector({makeFlatVector<int32_t>({0, 1})});
     auto rangeVector2 = makeRowVector(inputVector2_->children());
-    compositeBatch2_ = CompositeColumnarBatch::create(
-        {std::make_shared<VeloxColumnarBatch>(pid2), 
std::make_shared<VeloxColumnarBatch>(rangeVector2)});
+    compositeBatch2_ = VeloxColumnarBatch::compose(
+        pool(), {std::make_shared<VeloxColumnarBatch>(pid2), 
std::make_shared<VeloxColumnarBatch>(rangeVector2)});
   }
 
   std::shared_ptr<VeloxShuffleWriter> createShuffleWriter(arrow::MemoryPool* 
arrowPool) override {
diff --git 
a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatchJniWrapper.java
 
b/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatchJniWrapper.java
index 464ac1aecb..94312a2cf5 100644
--- 
a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatchJniWrapper.java
+++ 
b/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatchJniWrapper.java
@@ -42,8 +42,6 @@ public class ColumnarBatchJniWrapper implements RuntimeAware {
 
   public native long numBytes(long batch);
 
-  public native long compose(long[] batches);
-
   public native void exportToArrow(long batch, long cSchema, long cArray);
 
   public native long select(long batch, int[] columnIndices);
diff --git 
a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java
 
b/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java
index 3f48d8f293..6b5376d9b2 100644
--- 
a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java
+++ 
b/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.utils.SparkArrowUtil;
 import org.apache.spark.sql.vectorized.ColumnVector;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
-import org.apache.spark.sql.vectorized.ColumnarBatchUtil;
+import org.apache.spark.sql.vectorized.SparkColumnarBatchUtil;
 
 import java.util.Arrays;
 import java.util.Iterator;
@@ -45,11 +45,11 @@ import java.util.NoSuchElementException;
 
 import scala.collection.JavaConverters;
 
-public class ColumnarBatches {
+public final class ColumnarBatches {
 
   private ColumnarBatches() {}
 
-  enum BatchType {
+  private enum BatchType {
     LIGHT,
     HEAVY
   }
@@ -173,23 +173,23 @@ public class ColumnarBatches {
       ColumnarBatch output =
           ArrowAbiUtil.importToSparkColumnarBatch(allocator, arrowSchema, 
cArray);
 
-      // Follow gluten input's reference count. This might be optimized using
-      // automatic clean-up or once the extensibility of ColumnarBatch is 
enriched
-      IndicatorVector giv = (IndicatorVector) input.column(0);
-      for (long i = 0; i < (giv.refCnt() - 1); i++) {
+      // Follow input's reference count. This might be optimized using
+      // automatic clean-up or once the extensibility of ColumnarBatch is 
enriched.
+      long refCnt = getRefCntLight(input);
+      for (long i = 0; i < (refCnt - 1); i++) {
         for (int j = 0; j < output.numCols(); j++) {
           final ArrowWritableColumnVector col = (ArrowWritableColumnVector) 
output.column(j);
           col.retain();
         }
       }
 
-      // close the input one
-      for (long i = 0; i < giv.refCnt(); i++) {
+      // Close the input one.
+      for (long i = 0; i < refCnt; i++) {
         input.close();
       }
 
-      // populate new vectors to input
-      ColumnarBatchUtil.transferVectors(output, input);
+      // Populate new vectors to input.
+      SparkColumnarBatchUtil.transferVectors(output, input);
 
       return output;
     }
@@ -212,20 +212,20 @@ public class ColumnarBatches {
       ColumnarBatch output = ColumnarBatches.create(handle);
 
       // Follow input's reference count. This might be optimized using
-      // automatic clean-up or once the extensibility of ColumnarBatch is 
enriched
+      // automatic clean-up or once the extensibility of ColumnarBatch is 
enriched.
       long refCnt = getRefCntHeavy(input);
       final IndicatorVector giv = (IndicatorVector) output.column(0);
       for (long i = 0; i < (refCnt - 1); i++) {
         giv.retain();
       }
 
-      // close the input one
+      // Close the input one.
       for (long i = 0; i < refCnt; i++) {
         input.close();
       }
 
-      // populate new vectors to input
-      ColumnarBatchUtil.transferVectors(output, input);
+      // Populate new vectors to input.
+      SparkColumnarBatchUtil.transferVectors(output, input);
       return input;
     }
   }
@@ -251,7 +251,7 @@ public class ColumnarBatches {
     };
   }
 
-  private static long getRefCntLight(ColumnarBatch input) {
+  static long getRefCntLight(ColumnarBatch input) {
     if (!isLightBatch(input)) {
       throw new UnsupportedOperationException("Input batch is not light 
batch");
     }
@@ -259,7 +259,7 @@ public class ColumnarBatches {
     return iv.refCnt();
   }
 
-  private static long getRefCntHeavy(ColumnarBatch input) {
+  static long getRefCntHeavy(ColumnarBatch input) {
     if (!isHeavyBatch(input)) {
       throw new UnsupportedOperationException("Input batch is not heavy 
batch");
     }
@@ -303,27 +303,6 @@ public class ColumnarBatches {
     }
   }
 
-  private static IndicatorVector getIndicatorVector(ColumnarBatch input) {
-    if (!isLightBatch(input)) {
-      throw new UnsupportedOperationException("Input batch is not light 
batch");
-    }
-    return (IndicatorVector) input.column(0);
-  }
-
-  /**
-   * Combine multiple columnar batches horizontally, assuming each of them is 
already offloaded.
-   * Otherwise {@link UnsupportedOperationException} will be thrown.
-   */
-  public static long compose(ColumnarBatch... batches) {
-    IndicatorVector[] ivs =
-        Arrays.stream(batches)
-            .map(ColumnarBatches::getIndicatorVector)
-            .toArray(IndicatorVector[]::new);
-    final long[] handles = 
Arrays.stream(ivs).mapToLong(IndicatorVector::handle).toArray();
-    return 
ColumnarBatchJniWrapper.create(Runtimes.contextInstance("ColumnarBatches#compose"))
-        .compose(handles);
-  }
-
   private static ColumnarBatch create(IndicatorVector iv) {
     int numColumns = Math.toIntExact(iv.getNumColumns());
     int numRows = Math.toIntExact(iv.getNumRows());
@@ -365,14 +344,28 @@ public class ColumnarBatches {
     b.close();
   }
 
+  private static IndicatorVector getIndicatorVector(ColumnarBatch input) {
+    if (!isLightBatch(input)) {
+      throw new UnsupportedOperationException("Input batch is not light 
batch");
+    }
+    return (IndicatorVector) input.column(0);
+  }
+
   public static long getNativeHandle(ColumnarBatch batch) {
     return getIndicatorVector(batch).handle();
   }
 
+  static String getComprehensiveLightBatchType(ColumnarBatch batch) {
+    return getIndicatorVector(batch).getType();
+  }
+
   public static String toString(ColumnarBatch batch, int start, int length) {
     ColumnarBatch loadedBatch = 
ensureLoaded(ArrowBufferAllocators.contextInstance(), batch);
     StructType type = 
SparkArrowUtil.fromArrowSchema(ArrowUtil.toSchema(loadedBatch));
     return InternalRowUtl.toString(
-        type, JavaConverters.asScalaIterator(loadedBatch.rowIterator()), 
start, length);
+        type,
+        JavaConverters.<InternalRow>asScalaIterator(loadedBatch.rowIterator()),
+        start,
+        length);
   }
 }
diff --git 
a/gluten-arrow/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchUtil.java
 
b/gluten-arrow/src/main/java/org/apache/spark/sql/vectorized/SparkColumnarBatchUtil.java
similarity index 90%
rename from 
gluten-arrow/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchUtil.java
rename to 
gluten-arrow/src/main/java/org/apache/spark/sql/vectorized/SparkColumnarBatchUtil.java
index 0e2c748130..9ba743d7ea 100644
--- 
a/gluten-arrow/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatchUtil.java
+++ 
b/gluten-arrow/src/main/java/org/apache/spark/sql/vectorized/SparkColumnarBatchUtil.java
@@ -16,12 +16,11 @@
  */
 package org.apache.spark.sql.vectorized;
 
-import org.apache.gluten.columnarbatch.ColumnarBatches;
 import org.apache.gluten.exception.GlutenException;
 
 import java.lang.reflect.Field;
 
-public class ColumnarBatchUtil {
+public class SparkColumnarBatchUtil {
 
   private static final Field FIELD_COLUMNS;
   private static final Field FIELD_COLUMNAR_BATCH_ROW;
@@ -61,10 +60,7 @@ public class ColumnarBatchUtil {
         newVectors[i] = from.column(i);
       }
       FIELD_COLUMNS.set(target, newVectors);
-      // Light batch does not need the row.
-      if (ColumnarBatches.isHeavyBatch(target)) {
-        setColumnarBatchRow(from, newVectors, target);
-      }
+      setColumnarBatchRow(from, newVectors, target);
     } catch (IllegalAccessException e) {
       throw new GlutenException(e);
     }
diff --git 
a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ArrowBatches.scala 
b/gluten-arrow/src/main/scala/org/apache/gluten/columnarbatch/ArrowBatches.scala
similarity index 77%
rename from 
gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ArrowBatches.scala
rename to 
gluten-arrow/src/main/scala/org/apache/gluten/columnarbatch/ArrowBatches.scala
index 0b387a02c5..5ae3863c57 100644
--- 
a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ArrowBatches.scala
+++ 
b/gluten-arrow/src/main/scala/org/apache/gluten/columnarbatch/ArrowBatches.scala
@@ -23,13 +23,13 @@ import 
org.apache.gluten.extension.columnar.transition.Convention.BatchType.Vani
 object ArrowBatches {
 
   /**
-   * JavaArrowBatch stands for Gluten's Java Arrow-based columnar batch 
implementation.
+   * ArrowJavaBatch stands for Gluten's Java Arrow-based columnar batch 
implementation.
    *
-   * JavaArrowBatch should have 
[[org.apache.gluten.vectorized.ArrowWritableColumnVector]]s
-   * populated in it. JavaArrowBatch can be offloaded to NativeArrowBatch 
through API in
+   * ArrowJavaBatch should have 
[[org.apache.gluten.vectorized.ArrowWritableColumnVector]]s
+   * populated in it. ArrowJavaBatch can be offloaded to ArrowNativeBatch 
through API in
    * [[ColumnarBatches]].
    *
-   * JavaArrowBatch is compatible with vanilla batch since it provides valid 
#get<type>(...)
+   * ArrowJavaBatch is compatible with vanilla batch since it provides valid 
#get<type>(...)
    * implementations.
    */
   object ArrowJavaBatch extends Convention.BatchType {
@@ -37,10 +37,10 @@ object ArrowBatches {
   }
 
   /**
-   * NativeArrowBatch stands for Gluten's native Arrow-based columnar batch 
implementation.
+   * ArrowNativeBatch stands for Gluten's native Arrow-based columnar batch 
implementation.
    *
-   * NativeArrowBatch should have 
[[org.apache.gluten.columnarbatch.IndicatorVector]] set as the
-   * first vector. NativeArrowBatch can be loaded to JavaArrowBatch through 
API in
+   * ArrowNativeBatch should have 
[[org.apache.gluten.columnarbatch.IndicatorVector]] set as the
+   * first vector. ArrowNativeBatch can be loaded to ArrowJavaBatch through 
API in
    * [[ColumnarBatches]].
    */
   object ArrowNativeBatch extends Convention.BatchType {
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
index 503f4145f9..ced9378ad6 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/Transition.scala
@@ -18,10 +18,7 @@ package org.apache.gluten.extension.columnar.transition
 
 import org.apache.gluten.exception.GlutenException
 
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan}
+import org.apache.spark.sql.execution.SparkPlan
 
 /**
  * Transition is a simple function to convert a query plan to interested 
[[ConventionReq]].
@@ -30,7 +27,6 @@ import org.apache.spark.sql.execution.{LeafExecNode, 
SparkPlan}
  * [[org.apache.gluten.extension.columnar.transition.Convention.BatchType]]'s 
definition.
  */
 trait Transition {
-  import Transition._
   final def apply(plan: SparkPlan): SparkPlan = {
     val out = apply0(plan)
     out
@@ -132,10 +128,4 @@ object Transition {
       }
     }
   }
-
-  private case class DummySparkPlan() extends LeafExecNode {
-    override def supportsColumnar: Boolean = true // To bypass the assertion 
in ColumnarToRowExec.
-    override protected def doExecute(): RDD[InternalRow] = throw new 
UnsupportedOperationException()
-    override def output: Seq[Attribute] = Nil
-  }
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
index 9cafcae8b5..23fe9ed57e 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/TransitionGraph.scala
@@ -67,24 +67,53 @@ object TransitionGraph {
     }
   }
 
-  private case class TransitionCost(count: Int) extends 
FloydWarshallGraph.Cost {
+  private case class TransitionCost(count: Int, nodeNames: Seq[String])
+    extends FloydWarshallGraph.Cost {
     override def +(other: FloydWarshallGraph.Cost): TransitionCost = {
       other match {
-        case TransitionCost(otherCount) => TransitionCost(count + otherCount)
+        case TransitionCost(otherCount, otherNodeNames) =>
+          TransitionCost(count + otherCount, nodeNames ++ otherNodeNames)
       }
     }
   }
 
   private object TransitionCostModel extends 
FloydWarshallGraph.CostModel[Transition] {
-    override def zero(): FloydWarshallGraph.Cost = TransitionCost(0)
-    override def costOf(transition: Transition): FloydWarshallGraph.Cost = 
costOf0(transition)
-    override def costComparator(): Ordering[FloydWarshallGraph.Cost] = 
Ordering.Int.on {
-      case TransitionCost(c) => c
+    override def zero(): TransitionCost = TransitionCost(0, Nil)
+    override def costOf(transition: Transition): TransitionCost = {
+      costOf0(transition)
     }
-    private def costOf0(transition: Transition): TransitionCost = transition 
match {
-      case t if t.isEmpty => TransitionCost(0)
-      case ChainedTransition(f, s) => costOf0(f) + costOf0(s)
-      case _ => TransitionCost(1)
+    override def costComparator(): Ordering[FloydWarshallGraph.Cost] = {
+      (x: FloydWarshallGraph.Cost, y: FloydWarshallGraph.Cost) =>
+        (x, y) match {
+          case (TransitionCost(count, nodeNames), TransitionCost(otherCount, 
otherNodeNames)) =>
+            if (count != otherCount) {
+              count - otherCount
+            } else {
+              // To make the output order stable.
+              nodeNames.hashCode() - otherNodeNames.hashCode()
+            }
+        }
+    }
+
+    private def costOf0(transition: Transition): TransitionCost = {
+      val leaf = DummySparkPlan()
+
+      /**
+       * The calculation considers C2C's cost as half of C2R / R2C's cost. So 
query planner prefers
+       * C2C than C2R / R2C.
+       */
+      def costOfPlan(plan: SparkPlan): TransitionCost = plan
+        .map {
+          case p if p == leaf => TransitionCost(0, Nil)
+          case node @ RowToColumnarLike(_) => TransitionCost(2, 
Seq(node.nodeName))
+          case node @ ColumnarToRowLike(_) => TransitionCost(2, 
Seq(node.nodeName))
+          case node @ ColumnarToColumnarLike(_) => TransitionCost(1, 
Seq(node.nodeName))
+        }
+        .reduce((l, r) => l + r)
+
+      val plan = transition.apply(leaf)
+      val cost = costOfPlan(plan)
+      cost
     }
   }
 }
diff --git 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala
 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala
index 0a0deb17de..f02a513362 100644
--- 
a/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala
+++ 
b/gluten-core/src/main/scala/org/apache/gluten/extension/columnar/transition/package.scala
@@ -18,6 +18,9 @@ package org.apache.gluten.extension.columnar
 
 import org.apache.gluten.execution.ColumnarToColumnarExec
 
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.AQEShuffleReadExec
 import org.apache.spark.sql.execution.debug.DebugExec
@@ -69,4 +72,10 @@ package object transition {
       }
     }
   }
+
+  case class DummySparkPlan() extends LeafExecNode {
+    override def supportsColumnar: Boolean = true // To bypass the assertion 
in ColumnarToRowExec.
+    override protected def doExecute(): RDD[InternalRow] = throw new 
UnsupportedOperationException()
+    override def output: Seq[Attribute] = Nil
+  }
 }
diff --git 
a/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
 
b/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
index ff43dbe190..5c35cb5020 100644
--- 
a/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
+++ 
b/gluten-substrait/src/test/scala/org/apache/gluten/extension/columnar/transition/TransitionSuite.scala
@@ -17,6 +17,7 @@
 package org.apache.gluten.extension.columnar.transition
 
 import org.apache.gluten.exception.GlutenException
+import org.apache.gluten.execution.ColumnarToColumnarExec
 import org.apache.gluten.extension.GlutenPlan
 
 import org.apache.spark.rdd.RDD
@@ -24,6 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 class TransitionSuite extends SharedSparkSession {
   import TransitionSuite._
@@ -128,14 +130,12 @@ object TransitionSuite extends TransitionSuiteBase {
       from: Convention.BatchType,
       to: Convention.BatchType,
       override val child: SparkPlan)
-    extends UnaryExecNode
-    with GlutenPlan {
-    override def supportsColumnar: Boolean = true
-    override protected def batchType0(): Convention.BatchType = to
+    extends ColumnarToColumnarExec(from, to) {
     override protected def withNewChildInternal(newChild: SparkPlan): 
SparkPlan =
       copy(child = newChild)
     override protected def doExecute(): RDD[InternalRow] = throw new 
UnsupportedOperationException()
-    override def output: Seq[Attribute] = child.output
+    override protected def mapIterator(in: Iterator[ColumnarBatch]): 
Iterator[ColumnarBatch] =
+      throw new UnsupportedOperationException()
   }
 
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to