This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new fe22ec74db78 [SPARK-46598][SQL] OrcColumnarBatchReader should respect the memory mode when creating column vectors for the missing column
fe22ec74db78 is described below

commit fe22ec74db7895c6ea1f39236162ae39027111f4
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Sat Jan 6 12:38:35 2024 -0800

    [SPARK-46598][SQL] OrcColumnarBatchReader should respect the memory mode when creating column vectors for the missing column
    
    This PR fixes a long-standing bug: `OrcColumnarBatchReader` does not respect the memory mode when creating column vectors for missing columns.
    
    To not violate the memory mode requirement
    
    No
    
    new test
    
    no
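
    For context, the memory mode these readers now receive is derived from the session's column-vector setting, `spark.sql.columnVector.offheap.enabled` (see `OrcFileFormat` and `OrcScan` in the diff below). A minimal sketch, not part of this commit, of a session configured so the ORC vectorized reader takes the off-heap path fixed here; the input path and schema are illustrative only:

        import org.apache.spark.sql.SparkSession

        val spark = SparkSession.builder()
          .master("local[*]")
          .config("spark.sql.columnVector.offheap.enabled", "true")
          .getOrCreate()

        // Requesting a column (col3) that does not exist in the ORC file makes
        // OrcColumnarBatchReader allocate a vector for the missing column; with
        // this fix that vector is off-heap whenever the config above is enabled.
        spark.read
          .schema("col1 int, col2 int, col3 int")
          .orc("/tmp/example.orc")  // hypothetical input path
          .show()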
    
    Closes #44598 from cloud-fan/orc.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
    (cherry picked from commit 0c1c5e93e376b97a6d2dae99e973b9385155727a)
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../datasources/orc/OrcColumnarBatchReader.java        | 17 ++++++++++++-----
 .../sql/execution/datasources/orc/OrcFileFormat.scala  |  9 ++++++++-
 .../datasources/v2/orc/OrcPartitionReaderFactory.scala |  6 ++++--
 .../sql/execution/datasources/v2/orc/OrcScan.scala     |  8 +++++++-
 .../datasources/orc/OrcColumnarBatchReaderSuite.scala  | 18 +++++++++++++++---
 5 files changed, 46 insertions(+), 12 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
index b6184baa2e0e..5bfe22450f36 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
@@ -31,12 +31,11 @@ import org.apache.orc.Reader;
 import org.apache.orc.TypeDescription;
 import org.apache.orc.mapred.OrcInputFormat;
 
+import org.apache.spark.memory.MemoryMode;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns;
 import org.apache.spark.sql.execution.datasources.orc.OrcShimUtils.VectorizedRowBatchWrap;
-import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
-import org.apache.spark.sql.execution.vectorized.ConstantColumnVector;
-import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
+import org.apache.spark.sql.execution.vectorized.*;
 import org.apache.spark.sql.types.*;
 import org.apache.spark.sql.vectorized.ColumnarBatch;
 
@@ -73,11 +72,14 @@ public class OrcColumnarBatchReader extends RecordReader<Void, ColumnarBatch> {
   @VisibleForTesting
   public ColumnarBatch columnarBatch;
 
+  private final MemoryMode memoryMode;
+
   // The wrapped ORC column vectors.
   private org.apache.spark.sql.vectorized.ColumnVector[] orcVectorWrappers;
 
-  public OrcColumnarBatchReader(int capacity) {
+  public OrcColumnarBatchReader(int capacity, MemoryMode memoryMode) {
     this.capacity = capacity;
+    this.memoryMode = memoryMode;
   }
 
 
@@ -177,7 +179,12 @@ public class OrcColumnarBatchReader extends RecordReader<Void, ColumnarBatch> {
         int colId = requestedDataColIds[i];
         // Initialize the missing columns once.
         if (colId == -1) {
-          OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt);
+          final WritableColumnVector missingCol;
+          if (memoryMode == MemoryMode.OFF_HEAP) {
+            missingCol = new OffHeapColumnVector(capacity, dt);
+          } else {
+            missingCol = new OnHeapColumnVector(capacity, dt);
+          }
           // Check if the missing column has an associated default value in the schema metadata.
           // If so, fill the corresponding column vector with the value.
           Object defaultValue = ResolveDefaultColumns.existenceDefaultValues(requiredSchema)[i];
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index b7e6f11f67d6..53d2b08431f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -31,6 +31,7 @@ import org.apache.orc.mapred.OrcStruct
 import org.apache.orc.mapreduce._
 
 import org.apache.spark.TaskContext
+import org.apache.spark.memory.MemoryMode
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -152,6 +153,12 @@ class OrcFileFormat
       assert(supportBatch(sparkSession, resultSchema))
     }
 
+    val memoryMode = if (sqlConf.offHeapColumnVectorEnabled) {
+      MemoryMode.OFF_HEAP
+    } else {
+      MemoryMode.ON_HEAP
+    }
+
     OrcConf.IS_SCHEMA_EVOLUTION_CASE_SENSITIVE.setBoolean(hadoopConf, sqlConf.caseSensitiveAnalysis)
 
     val broadcastedConf =
@@ -196,7 +203,7 @@ class OrcFileFormat
         val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)
 
         if (enableVectorizedReader) {
-          val batchReader = new OrcColumnarBatchReader(capacity)
+          val batchReader = new OrcColumnarBatchReader(capacity, memoryMode)
           // SPARK-23399 Register a task completion listener first to call `close()` in all cases.
           // There is a possibility that `initialize` and `initBatch` hit some errors (like OOM)
           // after opening a file.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
index 2b7bdae6b31b..b23071e50cbe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
@@ -26,6 +26,7 @@ import org.apache.orc.mapred.OrcStruct
 import org.apache.orc.mapreduce.OrcInputFormat
 
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.memory.MemoryMode
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
@@ -57,7 +58,8 @@ case class OrcPartitionReaderFactory(
     partitionSchema: StructType,
     filters: Array[Filter],
     aggregation: Option[Aggregation],
-    options: OrcOptions) extends FilePartitionReaderFactory {
+    options: OrcOptions,
+    memoryMode: MemoryMode) extends FilePartitionReaderFactory {
   private val resultSchema = StructType(readDataSchema.fields ++ partitionSchema.fields)
   private val isCaseSensitive = sqlConf.caseSensitiveAnalysis
   private val capacity = sqlConf.orcVectorizedReaderBatchSize
@@ -146,7 +148,7 @@ case class OrcPartitionReaderFactory(
       val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
       val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)
 
-      val batchReader = new OrcColumnarBatchReader(capacity)
+      val batchReader = new OrcColumnarBatchReader(capacity, memoryMode)
       batchReader.initialize(fileSplit, taskAttemptContext)
       val requestedPartitionColIds =
         Array.fill(readDataSchema.length)(-1) ++ Range(0, partitionSchema.length)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
index 072ab26774e5..ca37d22eeb1e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala
@@ -21,6 +21,7 @@ import scala.collection.JavaConverters.mapAsScalaMapConverter
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 
+import org.apache.spark.memory.MemoryMode
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
@@ -64,11 +65,16 @@ case class OrcScan(
   override def createReaderFactory(): PartitionReaderFactory = {
     val broadcastedConf = sparkSession.sparkContext.broadcast(
       new SerializableConfiguration(hadoopConf))
+    val memoryMode = if (sparkSession.sessionState.conf.offHeapColumnVectorEnabled) {
+      MemoryMode.OFF_HEAP
+    } else {
+      MemoryMode.ON_HEAP
+    }
     // The partition values are already truncated in `FileScan.partitions`.
     // We should use `readPartitionSchema` as the partition schema here.
     OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf,
       dataSchema, readDataSchema, readPartitionSchema, pushedFilters, pushedAggregate,
-      new OrcOptions(options.asScala.toMap, sparkSession.sessionState.conf))
+      new OrcOptions(options.asScala.toMap, sparkSession.sessionState.conf), memoryMode)
   }
 
   override def equals(obj: Any): Boolean = obj match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala
index a9389c1c21b4..06ea12f83ce7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReaderSuite.scala
@@ -26,11 +26,12 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
 import org.apache.orc.TypeDescription
 
 import org.apache.spark.TestUtils
+import org.apache.spark.memory.MemoryMode
 import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
-import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
+import org.apache.spark.sql.execution.vectorized.{ConstantColumnVector, 
OffHeapColumnVector}
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -53,7 +54,7 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession {
         requestedDataColIds: Array[Int],
         requestedPartitionColIds: Array[Int],
         resultFields: Array[StructField]): OrcColumnarBatchReader = {
-      val reader = new OrcColumnarBatchReader(4096)
+      val reader = new OrcColumnarBatchReader(4096, MemoryMode.ON_HEAP)
       reader.initBatch(
         orcFileSchema,
         resultFields,
@@ -117,7 +118,7 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession {
         val fileSplit = new FileSplit(new Path(file.getCanonicalPath), 0L, file.length, Array.empty)
         val taskConf = sqlContext.sessionState.newHadoopConf()
         val orcFileSchema = TypeDescription.fromString(schema.simpleString)
-        val vectorizedReader = new OrcColumnarBatchReader(4096)
+        val vectorizedReader = new OrcColumnarBatchReader(4096, MemoryMode.ON_HEAP)
         val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0)
         val taskAttemptContext = new TaskAttemptContextImpl(taskConf, attemptId)
 
@@ -148,4 +149,15 @@ class OrcColumnarBatchReaderSuite extends QueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-46598: off-heap mode") {
+    val reader = new OrcColumnarBatchReader(4096, MemoryMode.OFF_HEAP)
+    reader.initBatch(
+      TypeDescription.fromString("struct<col1:int,col2:int>"),
+      StructType.fromDDL("col1 int, col2 int, col3 int").fields,
+      Array(0, 1, -1),
+      Array(-1, -1, -1),
+      InternalRow.empty)
+    assert(reader.columnarBatch.column(2).isInstanceOf[OffHeapColumnVector])
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
