GitHub user dongjoon-hyun commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19943#discussion_r160311799
  
    --- Diff: sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java ---
    @@ -0,0 +1,528 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.datasources.orc;
    +
    +import java.io.IOException;
    +import java.util.stream.IntStream;
    +
    +import org.apache.hadoop.conf.Configuration;
    +import org.apache.hadoop.mapreduce.InputSplit;
    +import org.apache.hadoop.mapreduce.RecordReader;
    +import org.apache.hadoop.mapreduce.TaskAttemptContext;
    +import org.apache.hadoop.mapreduce.lib.input.FileSplit;
    +import org.apache.orc.OrcConf;
    +import org.apache.orc.OrcFile;
    +import org.apache.orc.Reader;
    +import org.apache.orc.TypeDescription;
    +import org.apache.orc.mapred.OrcInputFormat;
    +import org.apache.orc.storage.common.type.HiveDecimal;
    +import org.apache.orc.storage.ql.exec.vector.*;
    +import org.apache.orc.storage.serde2.io.HiveDecimalWritable;
    +
    +import org.apache.spark.memory.MemoryMode;
    +import org.apache.spark.sql.catalyst.InternalRow;
    +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
    +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
    +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
    +import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
    +import org.apache.spark.sql.types.*;
    +import org.apache.spark.sql.vectorized.ColumnarBatch;
    +
    +
    +/**
    + * To support vectorization in WholeStageCodeGen, this reader returns ColumnarBatch.
    + * After creation, `initialize` and `initBatch` should be called sequentially.
    + */
    +public class OrcColumnarBatchReader extends RecordReader<Void, ColumnarBatch> {
    +
    +  // ORC File Reader
    +  private Reader reader;
    +
    +  // Vectorized ORC Row Batch
    +  private VectorizedRowBatch batch;
    +
    +  /**
    +   * The column IDs of the physical ORC file schema which are required by this reader.
    +   * -1 means this required column doesn't exist in the ORC file.
    +   */
    +  private int[] requestedColIds;
    +
    +  // Record reader from ORC row batch.
    +  private org.apache.orc.RecordReader recordReader;
    +
    +  private StructField[] requiredFields;
    +
    +  // The result columnar batch for vectorized execution by whole-stage codegen.
    +  private ColumnarBatch columnarBatch;
    +
    +  // Writable column vectors of the result columnar batch.
    +  private WritableColumnVector[] columnVectors;
    +
    +  // The number of rows read so far and counted as returned.
    +  private long rowsReturned = 0L;
    +
    +  private long totalRowCount = 0L;
    +
    +  /**
    +   * The memory mode of the columnarBatch.
    +   */
    +  private final MemoryMode MEMORY_MODE;
    +
    +  public OrcColumnarBatchReader(boolean useOffHeap) {
    +    MEMORY_MODE = useOffHeap ? MemoryMode.OFF_HEAP : MemoryMode.ON_HEAP;
    +  }
    +
    +
    +  @Override
    +  public Void getCurrentKey() throws IOException, InterruptedException {
    +    return null;
    +  }
    +
    +  @Override
    +  public ColumnarBatch getCurrentValue() throws IOException, InterruptedException {
    +    return columnarBatch;
    +  }
    +
    +  @Override
    +  public float getProgress() throws IOException, InterruptedException {
    +    return (float) rowsReturned / totalRowCount;
    +  }
    +
    +  @Override
    +  public boolean nextKeyValue() throws IOException, InterruptedException {
    +    return nextBatch();
    +  }
    +
    +  @Override
    +  public void close() throws IOException {
    +    if (columnarBatch != null) {
    +      columnarBatch.close();
    +      columnarBatch = null;
    +    }
    +    if (recordReader != null) {
    +      recordReader.close();
    +      recordReader = null;
    +    }
    +  }
    +
    +  /**
    +   * Initialize ORC file reader and batch record reader.
    +   * Please note that `initBatch` must be called after this.
    +   */
    +  @Override
    +  public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext)
    +      throws IOException, InterruptedException {
    +    FileSplit fileSplit = (FileSplit)inputSplit;
    +    Configuration conf = taskAttemptContext.getConfiguration();
    +    reader = OrcFile.createReader(
    +      fileSplit.getPath(),
    +      OrcFile.readerOptions(conf)
    +        .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf))
    +        .filesystem(fileSplit.getPath().getFileSystem(conf)));
    +
    +    Reader.Options options =
    +      OrcInputFormat.buildOptions(conf, reader, fileSplit.getStart(), fileSplit.getLength());
    +    recordReader = reader.rows(options);
    +    totalRowCount = reader.getNumberOfRows();
    +  }
    +
    +  /**
    +   * Initialize columnar batch by setting required schema and partition information.
    +   * With this information, this creates a ColumnarBatch with the full schema.
    +   */
    +  public void initBatch(
    +      TypeDescription orcSchema,
    +      int[] requestedColIds,
    +      StructField[] requiredFields,
    +      StructType partitionSchema,
    +      InternalRow partitionValues) {
    +    batch = orcSchema.createRowBatch(DEFAULT_SIZE);
    +    assert(!batch.selectedInUse); // `selectedInUse` should be initialized with `false`.
    +
    +    this.requiredFields = requiredFields;
    +    this.requestedColIds = requestedColIds;
    +    assert(requiredFields.length == requestedColIds.length);
    +
    +    StructType resultSchema = new StructType(requiredFields);
    +    for (StructField f : partitionSchema.fields()) {
    +      resultSchema = resultSchema.add(f);
    +    }
    +
    +    int capacity = DEFAULT_SIZE;
    +    if (MEMORY_MODE == MemoryMode.OFF_HEAP) {
    +      columnVectors = OffHeapColumnVector.allocateColumns(capacity, resultSchema);
    +    } else {
    +      columnVectors = OnHeapColumnVector.allocateColumns(capacity, resultSchema);
    +    }
    +    columnarBatch = new ColumnarBatch(resultSchema, columnVectors, capacity);
    +
    +    if (partitionValues.numFields() > 0) {
    +      int partitionIdx = requiredFields.length;
    +      for (int i = 0; i < partitionValues.numFields(); i++) {
    +        ColumnVectorUtils.populate(columnVectors[i + partitionIdx], partitionValues, i);
    +        columnVectors[i + partitionIdx].setIsConstant();
    +      }
    +    }
    +
    +    // Initialize the missing columns once.
    +    for (int i = 0; i < requiredFields.length; i++) {
    +      if (requestedColIds[i] == -1) {
    +        columnVectors[i].putNulls(0, columnarBatch.capacity());
    +        columnVectors[i].setIsConstant();
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Return true if more data exists in the next batch. If so, prepare that batch
    +   * by copying from the ORC VectorizedRowBatch columns to the Spark ColumnarBatch columns.
    +   */
    +  private boolean nextBatch() throws IOException {
    +    if (rowsReturned >= totalRowCount) {
    +      return false;
    +    }
    +
    +    recordReader.nextBatch(batch);
    +    int batchSize = batch.size;
    +    if (batchSize == 0) {
    +      return false;
    +    }
    +    rowsReturned += batchSize;
    +    for (WritableColumnVector vector : columnVectors) {
    +      vector.reset();
    +    }
    +    columnarBatch.setNumRows(batchSize);
    +    for (int i = 0; i < requiredFields.length; i++) {
    +      StructField field = requiredFields[i];
    +      WritableColumnVector toColumn = columnVectors[i];
    +
    +      if (requestedColIds[i] >= 0) {
    +        ColumnVector fromColumn = batch.cols[requestedColIds[i]];
    +
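    +        // An ORC column vector arrives in one of three states: a single
    +        // repeated value, all values non-null, or a mix of nulls and
    +        // values. Dispatch to the matching copy routine below.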
    +        if (fromColumn.isRepeating) {
    +          putRepeatingValues(batchSize, field, fromColumn, toColumn);
    +        } else if (fromColumn.noNulls) {
    +          putNonNullValues(batchSize, field, fromColumn, toColumn);
    +        } else {
    +          putValues(batchSize, field, fromColumn, toColumn);
    +        }
    +      }
    +    }
    +    return true;
    +  }
    +
    +  private void putRepeatingValues(
    +      int batchSize,
    +      StructField field,
    +      ColumnVector fromColumn,
    +      WritableColumnVector toColumn) {
    +    if (fromColumn.isNull[0]) {
    +      toColumn.putNulls(0, batchSize);
    +    } else {
    +      DataType type = field.dataType();
    +      if (type instanceof BooleanType) {
    +        toColumn.putBooleans(0, batchSize, ((LongColumnVector)fromColumn).vector[0] == 1);
    +      } else if (type instanceof ByteType) {
    +        toColumn.putBytes(0, batchSize, (byte)((LongColumnVector)fromColumn).vector[0]);
    +      } else if (type instanceof ShortType) {
    +        toColumn.putShorts(0, batchSize, (short)((LongColumnVector)fromColumn).vector[0]);
    +      } else if (type instanceof IntegerType || type instanceof DateType) {
    +        toColumn.putInts(0, batchSize, (int)((LongColumnVector)fromColumn).vector[0]);
    +      } else if (type instanceof LongType) {
    +        toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector[0]);
    +      } else if (type instanceof TimestampType) {
    +        toColumn.putLongs(0, batchSize,
    +          fromTimestampColumnVector((TimestampColumnVector)fromColumn, 0));
    +      } else if (type instanceof FloatType) {
    +        toColumn.putFloats(0, batchSize, (float)((DoubleColumnVector)fromColumn).vector[0]);
    +      } else if (type instanceof DoubleType) {
    +        toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector[0]);
    +      } else if (type instanceof StringType || type instanceof BinaryType) {
    +        BytesColumnVector data = (BytesColumnVector)fromColumn;
    +        WritableColumnVector arrayData = toColumn.getChildColumn(0);
    +        int size = data.vector[0].length;
    +        arrayData.reserve(size);
    +        arrayData.putBytes(0, size, data.vector[0], 0);
    +        for (int index = 0; index < batchSize; index++) {
    +          toColumn.putArray(index, 0, size);
    +        }
    +      } else if (type instanceof DecimalType) {
    +        DecimalType decimalType = (DecimalType)type;
    +        putDecimalWritables(
    +          toColumn,
    +          batchSize,
    +          decimalType.precision(),
    +          decimalType.scale(),
    +          ((DecimalColumnVector)fromColumn).vector[0]);
    +      } else {
    +        throw new UnsupportedOperationException("Unsupported Data Type: " + type);
    +      }
    +    }
    +  }
    +
    +  private void putNonNullValues(
    +      int batchSize,
    +      StructField field,
    +      ColumnVector fromColumn,
    +      WritableColumnVector toColumn) {
    +    DataType type = field.dataType();
    +    if (type instanceof BooleanType) {
    +      long[] data = ((LongColumnVector)fromColumn).vector;
    +      for (int index = 0; index < batchSize; index++) {
    +        toColumn.putBoolean(index, data[index] == 1);
    +      }
    +    } else if (type instanceof ByteType) {
    +      long[] data = ((LongColumnVector)fromColumn).vector;
    +      for (int index = 0; index < batchSize; index++) {
    +        toColumn.putByte(index, (byte)data[index]);
    +      }
    +    } else if (type instanceof ShortType) {
    +      long[] data = ((LongColumnVector)fromColumn).vector;
    +      for (int index = 0; index < batchSize; index++) {
    +        toColumn.putShort(index, (short)data[index]);
    +      }
    +    } else if (type instanceof IntegerType || type instanceof DateType) {
    +      long[] data = ((LongColumnVector)fromColumn).vector;
    +      for (int index = 0; index < batchSize; index++) {
    +        toColumn.putInt(index, (int)data[index]);
    +      }
    +    } else if (type instanceof LongType) {
    +      toColumn.putLongs(0, batchSize, ((LongColumnVector)fromColumn).vector, 0);
    +    } else if (type instanceof TimestampType) {
    +      TimestampColumnVector data = ((TimestampColumnVector)fromColumn);
    +      for (int index = 0; index < batchSize; index++) {
    +        toColumn.putLong(index, fromTimestampColumnVector(data, index));
    +      }
    +    } else if (type instanceof FloatType) {
    +      double[] data = ((DoubleColumnVector)fromColumn).vector;
    +      for (int index = 0; index < batchSize; index++) {
    +        toColumn.putFloat(index, (float)data[index]);
    +      }
    +    } else if (type instanceof DoubleType) {
    +      toColumn.putDoubles(0, batchSize, ((DoubleColumnVector)fromColumn).vector, 0);
    +    } else if (type instanceof StringType || type instanceof BinaryType) {
    +      BytesColumnVector data = ((BytesColumnVector)fromColumn);
    +      WritableColumnVector arrayData = toColumn.getChildColumn(0);
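    +      // Sum all the per-row lengths up front so the child byte vector can
    +      // be reserved with a single call before copying.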
    +      int totalNumBytes = IntStream.of(data.length).sum();
    +      arrayData.reserve(totalNumBytes);
    +      for (int index = 0, pos = 0; index < batchSize; pos += data.length[index], index++) {
    +        arrayData.putBytes(pos, data.length[index], data.vector[index], data.start[index]);
    +        toColumn.putArray(index, pos, data.length[index]);
    +      }
    +    } else if (type instanceof DecimalType) {
    +      DecimalType decimalType = (DecimalType)type;
    +      DecimalColumnVector data = ((DecimalColumnVector)fromColumn);
    +      if (decimalType.precision() > Decimal.MAX_INT_DIGITS()) {
    +        WritableColumnVector arrayData = toColumn.getChildColumn(0);
    +        arrayData.reserve(batchSize * 16);
    +      }
    +      for (int index = 0; index < batchSize; index++) {
    +        putDecimalWritable(
    +          toColumn,
    +          index,
    +          decimalType.precision(),
    +          decimalType.scale(),
    +          data.vector[index]);
    +      }
    +    } else {
    +      throw new UnsupportedOperationException("Unsupported Data Type: " + type);
    +    }
    +  }
    +
    +  private void putValues(
    +      int batchSize,
    +      StructField field,
    +      ColumnVector fromColumn,
    +      WritableColumnVector toColumn) {
    +    DataType type = field.dataType();
    +    if (type instanceof BooleanType) {
    +      long[] vector = ((LongColumnVector)fromColumn).vector;
    +      for (int index = 0; index < batchSize; index++) {
    +        if (fromColumn.isNull[index]) {
    +          toColumn.putNull(index);
    +        } else {
    +          toColumn.putBoolean(index, vector[index] == 1);
    +        }
    +      }
    +    } else if (type instanceof ByteType) {
    +      long[] vector = ((LongColumnVector)fromColumn).vector;
    +      for (int index = 0; index < batchSize; index++) {
    +        if (fromColumn.isNull[index]) {
    +          toColumn.putNull(index);
    +        } else {
    +          toColumn.putByte(index, (byte)vector[index]);
    +        }
    +      }
    +    } else if (type instanceof ShortType) {
    +      long[] vector = ((LongColumnVector)fromColumn).vector;
    +      for (int index = 0; index < batchSize; index++) {
    +        if (fromColumn.isNull[index]) {
    +          toColumn.putNull(index);
    +        } else {
    +          toColumn.putShort(index, (short)vector[index]);
    +        }
    +      }
    +    } else if (type instanceof IntegerType || type instanceof DateType) {
    +      long[] vector = ((LongColumnVector)fromColumn).vector;
    +      for (int index = 0; index < batchSize; index++) {
    +        if (fromColumn.isNull[index]) {
    +          toColumn.putNull(index);
    +        } else {
    +          toColumn.putInt(index, (int)vector[index]);
    +        }
    +      }
    +    } else if (type instanceof LongType) {
    +      long[] vector = ((LongColumnVector)fromColumn).vector;
    +      for (int index = 0; index < batchSize; index++) {
    +        if (fromColumn.isNull[index]) {
    +          toColumn.putNull(index);
    +        } else {
    +          toColumn.putLong(index, vector[index]);
    +        }
    +      }
    +    } else if (type instanceof TimestampType) {
    +      TimestampColumnVector vector = ((TimestampColumnVector)fromColumn);
    +      for (int index = 0; index < batchSize; index++) {
    +        if (fromColumn.isNull[index]) {
    +          toColumn.putNull(index);
    +        } else {
    +          toColumn.putLong(index, fromTimestampColumnVector(vector, index));
    +        }
    +      }
    +    } else if (type instanceof FloatType) {
    +      double[] vector = ((DoubleColumnVector)fromColumn).vector;
    +      for (int index = 0; index < batchSize; index++) {
    +        if (fromColumn.isNull[index]) {
    +          toColumn.putNull(index);
    +        } else {
    +          toColumn.putFloat(index, (float)vector[index]);
    +        }
    +      }
    +    } else if (type instanceof DoubleType) {
    +      double[] vector = ((DoubleColumnVector)fromColumn).vector;
    +      for (int index = 0; index < batchSize; index++) {
    +        if (fromColumn.isNull[index]) {
    +          toColumn.putNull(index);
    +        } else {
    +          toColumn.putDouble(index, vector[index]);
    +        }
    +      }
    +    } else if (type instanceof StringType || type instanceof BinaryType) {
    +      BytesColumnVector vector = (BytesColumnVector)fromColumn;
    +      WritableColumnVector arrayData = toColumn.getChildColumn(0);
    +      int totalNumBytes = IntStream.of(vector.length).sum();
    +      arrayData.reserve(totalNumBytes);
    +      for (int index = 0, pos = 0; index < batchSize; pos += vector.length[index], index++) {
    +        if (fromColumn.isNull[index]) {
    +          toColumn.putNull(index);
    +        } else {
    +          arrayData.putBytes(pos, vector.length[index], vector.vector[index], vector.start[index]);
    +          toColumn.putArray(index, pos, vector.length[index]);
    +        }
    +      }
    +    } else if (type instanceof DecimalType) {
    +      DecimalType decimalType = (DecimalType)type;
    +      HiveDecimalWritable[] vector = ((DecimalColumnVector)fromColumn).vector;
    +      for (int index = 0; index < batchSize; index++) {
    --- End diff --
    
    Oh, I see. It's the following. Thanks!
    ```
    if (type.precision() > Decimal.MAX_LONG_DIGITS()) {
    ...
    ```
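    
    For reference, here is a minimal sketch of what a `putDecimalWritable` helper along these lines could look like. The real helper is outside this diff hunk, so treat the exact body as an assumption; it only uses the `WritableColumnVector` and `Decimal` APIs already seen above:
    
    ```
    private static void putDecimalWritable(
        WritableColumnVector toColumn,
        int index,
        int precision,
        int scale,
        HiveDecimalWritable decimalWritable) {
      HiveDecimal decimal = decimalWritable.getHiveDecimal();
      // Convert the Hive decimal into a Spark Decimal at the required precision/scale.
      Decimal value = Decimal.apply(decimal.bigDecimalValue(), precision, scale);
      value.changePrecision(precision, scale);
    
      if (precision <= Decimal.MAX_INT_DIGITS()) {
        // Small decimals fit in an int-backed vector.
        toColumn.putInt(index, (int) value.toUnscaledLong());
      } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
        // Medium decimals fit in a long-backed vector.
        toColumn.putLong(index, value.toUnscaledLong());
      } else {
        // Wider decimals are stored as unscaled big-endian bytes in the child
        // vector, which is why callers reserve 16 bytes per row up front.
        byte[] bytes = value.toJavaBigDecimal().unscaledValue().toByteArray();
        WritableColumnVector arrayData = toColumn.getChildColumn(0);
        arrayData.putBytes(index * 16, bytes.length, bytes, 0);
        toColumn.putArray(index, index * 16, bytes.length);
      }
    }
    ```
    
    So `Decimal.MAX_LONG_DIGITS()` marks the point where the unscaled value no longer fits in a long and the byte-array path takes over.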



---
