Github user mmccline commented on a diff in the pull request:
https://github.com/apache/spark/pull/19943#discussion_r156475822
--- Diff:
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.scala
---
@@ -0,0 +1,442 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc
+
+import org.apache.hadoop.mapreduce.{InputSplit, RecordReader,
TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.orc._
+import org.apache.orc.mapred.OrcInputFormat
+import org.apache.orc.storage.ql.exec.vector._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.MemoryMode
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.vectorized._
+import org.apache.spark.sql.types._
+
+
+/**
+ * To support vectorization in WholeStageCodeGen, this reader returns
ColumnarBatch.
+ */
+private[orc] class OrcColumnarBatchReader extends RecordReader[Void,
ColumnarBatch] with Logging {
+ /**
+  * ORC file reader; provides file metadata (row count) and row-batch reading.
+  */
+ private var reader: Reader = _
+
+ /**
+  * ORC VectorizedRowBatch that `rows` fills on each `nextBatch` call;
+  * the source the Spark column vectors are copied from.
+  */
+ private var batch: VectorizedRowBatch = _
+
+ /**
+  * Physical ORC column id for each required field; a negative id marks a
+  * column missing from the file (filled with nulls).
+  */
+ private var requestedColIds: Array[Int] = _
+
+ /**
+  * ORC record reader that produces row batches for this split.
+  */
+ private var rows: org.apache.orc.RecordReader = _
+
+ /**
+  * Required (data) schema, excluding partition columns.
+  */
+ private var requiredSchema: StructType = _
+
+ /**
+  * ColumnarBatch handed to whole-stage codegen; covers the full result
+  * schema (required columns followed by partition columns).
+  */
+ private var columnarBatch: ColumnarBatch = _
+
+ /**
+  * Writable column vectors backing `columnarBatch`.
+  */
+ private var columnVectors: Seq[WritableColumnVector] = _
+
+ /**
+  * Number of rows consumed from the file so far (drives `getProgress`).
+  */
+ private var rowsReturned: Long = 0L
+
+ /**
+  * Total number of rows in the file, per ORC file metadata.
+  */
+ private var totalRowCount: Long = 0L
+
+ override def getCurrentKey: Void = null
+
+ override def getCurrentValue: ColumnarBatch = columnarBatch
+
+ override def getProgress: Float = rowsReturned.toFloat / totalRowCount
+
+ override def nextKeyValue(): Boolean = nextBatch()
+
+ override def close(): Unit = {
+ if (columnarBatch != null) {
+ columnarBatch.close()
+ columnarBatch = null
+ }
+ if (rows != null) {
+ rows.close()
+ rows = null
+ }
+ }
+
+ /**
+ * Initialize ORC file reader and batch record reader.
+ * Please note that `setRequiredSchema` is needed to be called after
this.
+ */
+ override def initialize(inputSplit: InputSplit, taskAttemptContext:
TaskAttemptContext): Unit = {
+ val fileSplit = inputSplit.asInstanceOf[FileSplit]
+ val conf = taskAttemptContext.getConfiguration
+ reader = OrcFile.createReader(
+ fileSplit.getPath,
+ OrcFile.readerOptions(conf)
+ .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf))
+ .filesystem(fileSplit.getPath.getFileSystem(conf)))
+
+ val options = OrcInputFormat.buildOptions(conf, reader,
fileSplit.getStart, fileSplit.getLength)
+ rows = reader.rows(options)
+ }
+
+ /**
+ * Set required schema and partition information.
+ * With this information, this creates ColumnarBatch with the full
schema.
+ */
+ def setRequiredSchema(
+ orcSchema: TypeDescription,
+ requestedColIds: Array[Int],
+ resultSchema: StructType,
+ requiredSchema: StructType,
+ partitionValues: InternalRow): Unit = {
+ batch = orcSchema.createRowBatch(OrcColumnarBatchReader.DEFAULT_SIZE)
+ totalRowCount = reader.getNumberOfRows
+ logDebug(s"totalRowCount = $totalRowCount")
+
+ this.requiredSchema = requiredSchema
+ this.requestedColIds = requestedColIds
+
+ val memMode = OrcColumnarBatchReader.DEFAULT_MEMORY_MODE
+ val capacity = ColumnarBatch.DEFAULT_BATCH_SIZE
+ if (memMode == MemoryMode.OFF_HEAP) {
+ columnVectors = OffHeapColumnVector.allocateColumns(capacity,
resultSchema)
+ } else {
+ columnVectors = OnHeapColumnVector.allocateColumns(capacity,
resultSchema)
+ }
+ columnarBatch = new ColumnarBatch(resultSchema, columnVectors.toArray,
capacity)
+
+ if (partitionValues.numFields > 0) {
+ val partitionIdx = requiredSchema.fields.length
+ for (i <- 0 until partitionValues.numFields) {
+ ColumnVectorUtils.populate(columnVectors(i + partitionIdx),
partitionValues, i)
+ columnVectors(i + partitionIdx).setIsConstant()
+ }
+ }
+ }
+
+ /**
+ * Return true if there exists more data in the next batch. If exists,
prepare the next batch
+ * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch
columns.
+ */
+ private def nextBatch(): Boolean = {
+ if (rowsReturned >= totalRowCount) {
+ return false
+ }
+
+ rows.nextBatch(batch)
+ val batchSize = batch.size
+ if (batchSize == 0) {
+ return false
+ }
+ rowsReturned += batchSize
+ columnarBatch.reset()
+ columnarBatch.setNumRows(batchSize)
+
+ var i = 0
+ while (i < requiredSchema.length) {
+ val field = requiredSchema(i)
+ val toColumn = columnVectors(i)
+
+ if (requestedColIds(i) < 0) {
+ toColumn.appendNulls(batchSize)
+ } else {
+ val fromColumn = batch.cols(requestedColIds(i))
+
+ if (fromColumn.isRepeating) {
+ if (fromColumn.isNull(0)) {
+ toColumn.appendNulls(batchSize)
+ } else {
+ field.dataType match {
+ case BooleanType =>
+ val data =
fromColumn.asInstanceOf[LongColumnVector].vector(0) == 1
+ toColumn.appendBooleans(batchSize, data)
+
+ case ByteType =>
+ val data =
fromColumn.asInstanceOf[LongColumnVector].vector(0).toByte
+ toColumn.appendBytes(batchSize, data)
+ case ShortType =>
+ val data =
fromColumn.asInstanceOf[LongColumnVector].vector(0).toShort
+ toColumn.appendShorts(batchSize, data)
+ case IntegerType =>
+ val data =
fromColumn.asInstanceOf[LongColumnVector].vector(0).toInt
+ toColumn.appendInts(batchSize, data)
+ case LongType =>
+ val data =
fromColumn.asInstanceOf[LongColumnVector].vector(0)
+ toColumn.appendLongs(batchSize, data)
+
+ case DateType =>
+ val data =
fromColumn.asInstanceOf[LongColumnVector].vector(0).toInt
+ toColumn.appendInts(batchSize, data)
+
+ case TimestampType =>
+ val data = fromColumn.asInstanceOf[TimestampColumnVector]
+ toColumn.appendLongs(batchSize, data.time(0) * 1000L +
data.nanos(0) / 1000L)
+
+ case FloatType =>
+ val data =
fromColumn.asInstanceOf[DoubleColumnVector].vector(0).toFloat
+ toColumn.appendFloats(batchSize, data)
+ case DoubleType =>
+ val data =
fromColumn.asInstanceOf[DoubleColumnVector].vector(0)
+ toColumn.appendDoubles(batchSize, data)
+
+ case StringType =>
+ val data = fromColumn.asInstanceOf[BytesColumnVector]
+ var index = 0
+ while (index < batchSize) {
+ toColumn.appendByteArray(data.vector(0), data.start(0),
data.length(0))
+ index += 1
+ }
+ case BinaryType =>
+ val data = fromColumn.asInstanceOf[BytesColumnVector]
+ var index = 0
+ while (index < batchSize) {
+ toColumn.appendByteArray(data.vector(0), data.start(0),
data.length(0))
+ index += 1
+ }
+
+ case DecimalType.Fixed(precision, scale) =>
+ val d =
fromColumn.asInstanceOf[DecimalColumnVector].vector(0)
+ val value = Decimal(d.getHiveDecimal.bigDecimalValue,
d.precision(), d.scale)
+ value.changePrecision(precision, scale)
+ if (precision <= Decimal.MAX_INT_DIGITS) {
+ toColumn.appendInts(batchSize,
value.toUnscaledLong.toInt)
+ } else if (precision <= Decimal.MAX_LONG_DIGITS) {
+ toColumn.appendLongs(batchSize, value.toUnscaledLong)
+ } else {
+ val bytes =
value.toJavaBigDecimal.unscaledValue.toByteArray
+ var index = 0
+ while (index < batchSize) {
+ toColumn.appendByteArray(bytes, 0, bytes.length)
+ index += 1
+ }
+ }
+
+ case dt =>
+ throw new UnsupportedOperationException(s"Unsupported Data
Type: $dt")
+ }
+ }
+ } else if (!field.nullable || fromColumn.noNulls) {
--- End diff --
Throwing exceptions on data errors can be nasty -- it might be better to set
the value to a known sentinel and log a warning instead. I recently noticed
the Hive optimizer taking advantage of non-nullability, but we will see
whether any practical issues arise from that...
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]