Github user henrify commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19943#discussion_r160318348
  
    --- Diff: sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java ---
    @@ -0,0 +1,605 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.datasources.orc;
    +
    +import java.io.IOException;
    +import java.util.stream.IntStream;
    +
    +import org.apache.hadoop.conf.Configuration;
    +import org.apache.hadoop.mapreduce.InputSplit;
    +import org.apache.hadoop.mapreduce.RecordReader;
    +import org.apache.hadoop.mapreduce.TaskAttemptContext;
    +import org.apache.hadoop.mapreduce.lib.input.FileSplit;
    +import org.apache.orc.OrcConf;
    +import org.apache.orc.OrcFile;
    +import org.apache.orc.Reader;
    +import org.apache.orc.TypeDescription;
    +import org.apache.orc.mapred.OrcInputFormat;
    +import org.apache.orc.storage.common.type.HiveDecimal;
    +import org.apache.orc.storage.ql.exec.vector.*;
    +import org.apache.orc.storage.serde2.io.HiveDecimalWritable;
    +
    +import org.apache.spark.memory.MemoryMode;
    +import org.apache.spark.sql.catalyst.InternalRow;
    +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
    +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
    +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
    +import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
    +import org.apache.spark.sql.types.*;
    +import org.apache.spark.sql.vectorized.ColumnarBatch;
    +
    +
    +/**
    + * To support vectorization in WholeStageCodeGen, this reader returns ColumnarBatch.
    + * After creating, `initialize` and `initBatch` should be called sequentially.
    + */
    +public class OrcColumnarBatchReader extends RecordReader<Void, ColumnarBatch> {
    +  /**
    +   * The default size of batch. We use this value for both ORC and Spark consistently
    +   * because they have different default values like the following.
    +   *
    +   * - ORC's VectorizedRowBatch.DEFAULT_SIZE = 1024
    +   * - Spark's ColumnarBatch.DEFAULT_BATCH_SIZE = 4 * 1024
    +   */
    +  public static final int DEFAULT_SIZE = 4 * 1024;
    +
    +  /**
    +   * Returns the number of micros since epoch from an element of TimestampColumnVector.
    +   */
    +  private static long fromTimestampColumnVector(TimestampColumnVector vector, int index) {
    +    return vector.time[index] * 1000L + vector.nanos[index] / 1000L;
    +  }
    +
    +  // ORC File Reader
    +  private Reader reader;
    +
    +  // Vectorized ORC Row Batch
    +  private VectorizedRowBatch batch;
    +
    +  /**
    +   * The column IDs of the physical ORC file schema which are required by this reader.
    +   * -1 means this required column doesn't exist in the ORC file.
    +   */
    +  private int[] requestedColIds;
    +
    +  // Record reader from ORC row batch.
    +  private org.apache.orc.RecordReader recordReader;
    +
    +  private StructField[] requiredFields;
    +
    +  // The result columnar batch for vectorized execution by whole-stage codegen.
    +  private ColumnarBatch columnarBatch;
    +
    +  // Writable column vectors of the result columnar batch.
    +  private WritableColumnVector[] columnVectors;
    +
    +  // The number of rows read and considered to be returned.
    +  private long rowsReturned = 0L;
    +
    +  private long totalRowCount = 0L;
    +
    +  /**
    +   * The memory mode of the columnarBatch
    +   */
    +  private final MemoryMode MEMORY_MODE;
    +
    +  public OrcColumnarBatchReader(boolean useOffHeap) {
    +    MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP;
    +  }
    +
    +
    +  @Override
    +  public Void getCurrentKey() throws IOException, InterruptedException {
    +    return null;
    +  }
    +
    +  @Override
    +  public ColumnarBatch getCurrentValue() throws IOException, InterruptedException {
    +    return columnarBatch;
    +  }
    +
    +  @Override
    +  public float getProgress() throws IOException, InterruptedException {
    +    return recordReader.getProgress();
    +  }
    +
    +  @Override
    +  public boolean nextKeyValue() throws IOException, InterruptedException {
    +    return nextBatch();
    +  }
    +
    +  @Override
    +  public void close() throws IOException {
    +    if (columnarBatch != null) {
    +      columnarBatch.close();
    +      columnarBatch = null;
    +    }
    +    if (recordReader != null) {
    +      recordReader.close();
    +      recordReader = null;
    +    }
    +  }
    +
    +  /**
    +   * Initialize ORC file reader and batch record reader.
    +   * Please note that `initBatch` needs to be called after this.
    +   */
    +  @Override
    +  public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext)
    +      throws IOException, InterruptedException {
    +    FileSplit fileSplit = (FileSplit)inputSplit;
    +    Configuration conf = taskAttemptContext.getConfiguration();
    +    reader = OrcFile.createReader(
    +      fileSplit.getPath(),
    +      OrcFile.readerOptions(conf)
    +        .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf))
    +        .filesystem(fileSplit.getPath().getFileSystem(conf)));
    +
    +    Reader.Options options =
    +      OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength());
    +    recordReader = reader.rows(options);
    +    totalRowCount = reader.getNumberOfRows();
    +  }
    +
    +  /**
    +   * Initialize columnar batch by setting required schema and partition information.
    +   * With this information, this creates ColumnarBatch with the full schema.
    +   */
    +  public void initBatch(
    +      TypeDescription orcSchema,
    +      int[] requestedColIds,
    +      StructField[] requiredFields,
    +      StructType partitionSchema,
    +      InternalRow partitionValues) {
    +    batch = orcSchema.createRowBatch(DEFAULT_SIZE);
    +    assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`.
    +
    +    this.requiredFields = requiredFields;
    +    this.requestedColIds = requestedColIds;
    +    assert(requiredFields.length == requestedColIds.length);
    +
    +    StructType resultSchema = new StructType(requiredFields);
    +    for (StructField f : partitionSchema.fields()) {
    +      resultSchema = resultSchema.add(f);
    +    }
    +
    +    int capacity = DEFAULT_SIZE;
    +    if (MEMORY_MODE == MemoryMode.OFF_HEAP) {
    +      columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema);
    +    } else {
    +      columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema);
    +    }
    +    columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity);
    +
    +    if (partitionValues.numFields() > 0) {
    +      int partitionIdx = requiredFields.length;
    +      for (int i = 0; i < partitionValues.numFields(); i++) {
    +        ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i);
    +        columnVectors[i + partitionIdx].setIsConstant();
    +      }
    +    }
    +
    +    // Initialize the missing columns once.
    +    for (int i = 0; i < requiredFields.length; i++) {
    +      if (requestedColIds[i] == -1) {
    +        columnVectors[i].putNulls(0, columnarBatch.capacity());
    +        columnVectors[i].setIsConstant();
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Return true if more data exists in the next batch. If so, prepare the next batch
    +   * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns.
    +   */
    +  private boolean nextBatch() throws IOException {
    +    if (rowsReturned >= totalRowCount) {
    +      return false;
    +    }
    +
    +    recordReader.nextBatch(batch);
    +    int batchSize = batch.size;
    +    if (batchSize == 0) {
    +      return false;
    +    }
    +    rowsReturned += batchSize;
    +    for (WritableColumnVector vector : columnVectors) {
    +      vector.reset();
    +    }
    +    columnarBatch.setNumRows(batchSize);
    +    for (int i = 0; i < requiredFields.length; i++) {
    +      StructField field = requiredFields[i];
    +      WritableColumnVector toColumn = columnVectors[i];
    +
    +      if (requestedColIds[i] >= 0) {
    +        ColumnVector fromColumn = batch.cols[requestedColIds[i]];
    +
    +        if (fromColumn.isRepeating) {
    +          putRepeatingValues(batchSize, field, fromColumn, toColumn);
    +        } else if (fromColumn.noNulls) {
    +          putNonNullValues(batchSize, field, fromColumn, toColumn);
    +        } else {
    +          putValues(batchSize, field, fromColumn, toColumn);
    +        }
    +      }
    +    }
    +    return true;
    +  }
    +
    +  private void putRepeatingValues(
    +      int batchSize,
    +      StructField field,
    +      ColumnVector fromColumn,
    +      WritableColumnVector toColumn) {
    +    if (fromColumn.isNull[0]) {
    +      toColumn.putNulls(0, batchSize);
    +    } else {
    +      DataType type = field.dataType();
    +      if (type instanceof BooleanType) {
    +        toColumn.putBooleans(0, batchSize, ((LongColumnVector)fromColumn).vector[0] == 1);
    +      } else if (type instanceof ByteType) {
    +        toColumn.putBytes(0, batchSize, (byte)((LongColumnVector)fromColumn).vector[0]);
    +      } else if (type instanceof ShortType) {
    +        toColumn.putShorts(0, batchSize, (short)((LongColumnVector)fromColumn).vector[0]);
    +      } else if (type instanceof IntegerType || type instanceof DateType) {
    +        toColumn.putInts(0, batchSize, (int)((LongColumnVector)fromColumn).vector[0]);
    +      } else if (type instanceof LongType) {
    +        toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector[0]);
    +      } else if (type instanceof TimestampType) {
    +        toColumn.putLongs(0, batchSize,
    +          fromTimestampColumnVector((TimestampColumnVector)fromColumn, 0));
    +      } else if (type instanceof FloatType) {
    +        toColumn.putFloats(0, batchSize, (float)((DoubleColumnVector)fromColumn).vector[0]);
    +      } else if (type instanceof DoubleType) {
    +        toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]);
    +      } else if (type instanceof StringType || type instanceof BinaryType) {
    +        putByteArrays(batchSize, toColumn, ((BytesColumnVector)fromColumn).vector[0]);
    +      } else if (type instanceof DecimalType) {
    +        DecimalType decimalType = (DecimalType)type;
    +        putDecimalWritables(
    +          toColumn,
    +          batchSize,
    +          decimalType.precision(),
    +          decimalType.scale(),
    +          ((DecimalColumnVector)fromColumn).vector[0]);
    +      } else {
    +        throw new UnsupportedOperationException("Unsupported Data Type: " + type);
    +      }
    +    }
    +  }
    +
    +  private void putNonNullValues(
    +      int batchSize,
    +      StructField field,
    +      ColumnVector fromColumn,
    +      WritableColumnVector toColumn) {
    +    DataType type = field.dataType();
    +    if (type instanceof BooleanType) {
    +      putNonNullBooleans(batchSize, (LongColumnVector) fromColumn, toColumn);
    +    } else if (type instanceof ByteType) {
    +      putNonNullBytes(batchSize, (LongColumnVector) fromColumn, toColumn);
    +    } else if (type instanceof ShortType) {
    +      putNonNullShorts(batchSize, (LongColumnVector) fromColumn, toColumn);
    +    } else if (type instanceof IntegerType || type instanceof DateType) {
    +      putNonNullInts(batchSize, (LongColumnVector) fromColumn, toColumn);
    +    } else if (type instanceof LongType) {
    +      toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector, 0);
    +    } else if (type instanceof TimestampType) {
    +      putNonNullTimestamps(batchSize, (TimestampColumnVector) fromColumn, toColumn);
    +    } else if (type instanceof FloatType) {
    +      putNonNullFloats(batchSize, (DoubleColumnVector) fromColumn, toColumn);
    +    } else if (type instanceof DoubleType) {
    +      toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0);
    +    } else if (type instanceof StringType || type instanceof BinaryType) {
    +      putNonNullByteArray(batchSize, (BytesColumnVector) fromColumn, toColumn);
    +    } else if (type instanceof DecimalType) {
    +      putNonNullDecimals(batchSize, (DecimalColumnVector) fromColumn, toColumn, (DecimalType) type);
    +    } else {
    +      throw new UnsupportedOperationException("Unsupported Data Type: " + type);
    +    }
    +  }
    +
    +
    +  private void putValues(
    +      int batchSize,
    +      StructField field,
    +      ColumnVector fromColumn,
    +      WritableColumnVector toColumn) {
    +    DataType type = field.dataType();
    +    if (type instanceof BooleanType) {
    +      putBooleans(batchSize, (LongColumnVector)fromColumn, toColumn);
    +    } else if (type instanceof ByteType) {
    +      putBytes(batchSize, (LongColumnVector)fromColumn, toColumn);
    +    } else if (type instanceof ShortType) {
    +      putShorts(batchSize, (LongColumnVector)fromColumn, toColumn);
    +    } else if (type instanceof IntegerType || type instanceof DateType) {
    +      putInts(batchSize, (LongColumnVector)fromColumn, toColumn);
    +    } else if (type instanceof LongType) {
    +      putLongs(batchSize, (LongColumnVector)fromColumn, toColumn);
    +    } else if (type instanceof TimestampType) {
    +      putTimestamps(batchSize, (TimestampColumnVector)fromColumn, toColumn);
    +    } else if (type instanceof FloatType) {
    +      putFloats(batchSize, (DoubleColumnVector)fromColumn, toColumn);
    +    } else if (type instanceof DoubleType) {
    +      putDoubles(batchSize, (DoubleColumnVector)fromColumn, toColumn);
    +    } else if (type instanceof StringType || type instanceof BinaryType) {
    +      putByteArrays(batchSize, (BytesColumnVector)fromColumn, toColumn);
    +    } else if (type instanceof DecimalType) {
    +      putDecimals(batchSize, (DecimalColumnVector)fromColumn, toColumn, (DecimalType) type);
    +    } else {
    +      throw new UnsupportedOperationException("Unsupported Data Type: " + type);
    +    }
    +  }
    +
    +  // --------------------------------------------------------------------------
    +  // Put a value
    +  // --------------------------------------------------------------------------
    +
    +  private static void putDecimalWritable(
    +      WritableColumnVector toColumn,
    +      int index,
    +      int precision,
    +      int scale,
    +      HiveDecimalWritable decimalWritable) {
    +    HiveDecimal decimal = decimalWritable.getHiveDecimal();
    +    Decimal value =
    +      Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale());
    +    value.changePrecision(precision, scale);
    +
    +    if (precision <= Decimal.MAX_INT_DIGITS()) {
    +      toColumn.putInt(index, (int) value.toUnscaledLong());
    +    } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
    +      toColumn.putLong(index, value.toUnscaledLong());
    +    } else {
    +      byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray();
    +      WritableColumnVector arrayData = toColumn.getChildColumn(0);
    +      arrayData.putBytes(index * 16, bytes.length, bytes, 0);
    +      toColumn.putArray(index, index * 16, bytes.length);
    +    }
    +  }
    +
    +  // --------------------------------------------------------------------------
    +  // Put repeating values
    +  // --------------------------------------------------------------------------
    +
    +  private void putByteArrays(int count, WritableColumnVector toColumn, byte[] bytes) {
    +    WritableColumnVector arrayData = toColumn.getChildColumn(0);
    +    int size = bytes.length;
    +    arrayData.reserve(size);
    +    arrayData.putBytes(0, size, bytes, 0);
    +    for (int index = 0; index < count; index++) {
    +      toColumn.putArray(index, 0, size);
    +    }
    +  }
    +
    +  private static void putDecimalWritables(
    +      WritableColumnVector toColumn,
    +      int size,
    +      int precision,
    +      int scale,
    +      HiveDecimalWritable decimalWritable) {
    +    HiveDecimal decimal = decimalWritable.getHiveDecimal();
    +    Decimal value =
    +      Decimal.apply(decimal.bigDecimalValue(), decimal.precision(), decimal.scale());
    +    value.changePrecision(precision, scale);
    +
    +    if (precision <= Decimal.MAX_INT_DIGITS()) {
    +      toColumn.putInts(0, size, (int) value.toUnscaledLong());
    +    } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
    +      toColumn.putLongs(0, size, value.toUnscaledLong());
    +    } else {
    +      byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray();
    +      WritableColumnVector arrayData = toColumn.getChildColumn(0);
    +      arrayData.reserve(bytes.length);
    +      arrayData.putBytes(0, bytes.length, bytes, 0);
    +      for (int index = 0; index < size; index++) {
    +        toColumn.putArray(index, 0, bytes.length);
    +      }
    +    }
    +  }
    +
    +  // --------------------------------------------------------------------------
    +  // Put non-null values
    +  // --------------------------------------------------------------------------
    +
    +  private void putNonNullBooleans(int count, LongColumnVector fromColumn, WritableColumnVector toColumn) {
    +    long[] data = fromColumn.vector;
    +    for (int index = 0; index < count; index++) {
    +      toColumn.putBoolean(index, data[index] == 1);
    +    }
    +  }
    +
    +  private void putNonNullByteArray(int count, BytesColumnVector fromColumn, WritableColumnVector toColumn) {
    +    BytesColumnVector data = fromColumn;
    +    WritableColumnVector arrayData = toColumn.getChildColumn(0);
    +    int totalNumBytes = IntStream.of(data.length).sum();
    +    arrayData.reserve(totalNumBytes);
    +    for (int index = 0, pos = 0; index < count; pos += data.length[index], index++) {
    +      arrayData.putBytes(pos, data.length[index], data.vector[index], data.start[index]);
    +      toColumn.putArray(index, pos, data.length[index]);
    +    }
    +  }
    +
    +  private void putNonNullBytes(int count, LongColumnVector fromColumn, WritableColumnVector toColumn) {
    --- End diff --
    
    It almost seemed that they hurt performance: the MR and Hive results were up and down randomly, as expected, but the Vectorized results were down in almost every benchmark.
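    
    For anyone following along without the full PR open: the class javadoc above says `initialize` and then `initBatch` must be called after construction, and batches are then pulled through the usual `RecordReader` loop. Below is a minimal, illustrative sketch of that call sequence, using only the constructor and public methods visible in the diff and the same imports as the top of the file. The parameter names are placeholders (in the actual PR these values would come from the calling file-format code), so treat it as an assumption-labeled usage outline rather than code from the change.
    
      // Illustrative sketch only, not code from this PR; all parameters are hypothetical.
      static void readAllBatches(
          InputSplit split,
          TaskAttemptContext context,
          TypeDescription orcSchema,        // physical ORC file schema
          int[] requestedColIds,            // ORC column id per required field, -1 if missing
          StructField[] requiredFields,
          StructType partitionSchema,
          InternalRow partitionValues) throws Exception {
        OrcColumnarBatchReader reader = new OrcColumnarBatchReader(/* useOffHeap */ false);
        try {
          reader.initialize(split, context);                 // open the ORC file and row reader
          reader.initBatch(orcSchema, requestedColIds,       // allocate the reused ColumnarBatch
              requiredFields, partitionSchema, partitionValues);
          while (reader.nextKeyValue()) {                    // copies one ORC row batch per call
            ColumnarBatch batch = reader.getCurrentValue();  // same instance, refilled each time
            // consume batch.numRows() rows here
          }
        } finally {
          reader.close();
        }
      }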

