Github user dongjoon-hyun commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19943#discussion_r160078819
  
    --- Diff: 
sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/JavaOrcColumnarBatchReader.java
 ---
    @@ -0,0 +1,503 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.datasources.orc;
    +
    +import java.io.IOException;
    +
    +import org.apache.hadoop.conf.Configuration;
    +import org.apache.hadoop.mapreduce.InputSplit;
    +import org.apache.hadoop.mapreduce.RecordReader;
    +import org.apache.hadoop.mapreduce.TaskAttemptContext;
    +import org.apache.hadoop.mapreduce.lib.input.FileSplit;
    +import org.apache.orc.OrcConf;
    +import org.apache.orc.OrcFile;
    +import org.apache.orc.Reader;
    +import org.apache.orc.TypeDescription;
    +import org.apache.orc.mapred.OrcInputFormat;
    +import org.apache.orc.storage.common.type.HiveDecimal;
    +import org.apache.orc.storage.ql.exec.vector.*;
    +import org.apache.orc.storage.serde2.io.HiveDecimalWritable;
    +
    +import org.apache.spark.memory.MemoryMode;
    +import org.apache.spark.sql.catalyst.InternalRow;
    +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
    +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
    +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
    +import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
    +import org.apache.spark.sql.types.*;
    +import org.apache.spark.sql.vectorized.ColumnarBatch;
    +
    +
    +/**
    + * To support vectorization in WholeStageCodeGen, this reader returns ColumnarBatch.
    + * After creation, `initialize` and `setRequiredSchema` should be called sequentially.
    + */
    +public class JavaOrcColumnarBatchReader extends RecordReader<Void, 
ColumnarBatch> {
    +
    +  /**
    +   * ORC File Reader.
    +   */
    +  private Reader reader;
    +
    +  /**
    +   * Vectorized Row Batch.
    +   */
    +  private VectorizedRowBatch batch;
    +
    +  /**
    +   * Requested Column IDs.
    +   */
    +  private int[] requestedColIds;
    +
    +  /**
    +   * Record reader from row batch.
    +   */
    +  private org.apache.orc.RecordReader recordReader;
    +
    +  /**
    +   * Required Schema.
    +   */
    +  private StructType requiredSchema;
    +
    +  /**
    +   * ColumnarBatch for vectorized execution by whole-stage codegen.
    +   */
    +  private ColumnarBatch columnarBatch;
    +
    +  /**
    +   * Writable column vectors of ColumnarBatch.
    +   */
    +  private WritableColumnVector[] columnVectors;
    +
    +  /**
    +   * The number of rows read and considered to be returned.
    +   */
    +  private long rowsReturned = 0L;
    +
    +  /**
    +   * Total number of rows.
    +   */
    +  private long totalRowCount = 0L;
    +
    +  @Override
    +  public Void getCurrentKey() throws IOException, InterruptedException {
    +    return null; // The key type is Void: this reader produces no key, only a value.
    +  }
    +
    +  @Override
    +  public ColumnarBatch getCurrentValue() throws IOException, 
InterruptedException {
    +    return columnarBatch; // Null until setRequiredSchema() creates it, and again after close().
    +  }
    +
    +  /**
    +   * Reports how much of the input has been consumed, as the ratio of
    +   * {@code rowsReturned} to {@code totalRowCount}.
    +   *
    +   * @return the ratio of rows returned to the total row count, or 0 when the total row
    +   *         count is not positive (for example before {@code initialize} has run, or for
    +   *         a file without rows) so that callers never observe NaN from a 0/0 division
    +   */
    +  @Override
    +  public float getProgress() throws IOException, InterruptedException {
    +    if (totalRowCount <= 0) {
    +      return 0.0f;
    +    }
    +    return (float) rowsReturned / totalRowCount;
    +  }
    +
    +  @Override
    +  public boolean nextKeyValue() throws IOException, InterruptedException {
    +    return nextBatch(); // Delegates to nextBatch(); its result is returned unchanged.
    +  }
    +
    +  /**
    +   * Releases the ColumnarBatch and the ORC record reader held by this reader.
    +   * Each field is set to null once its resource has been closed, so calling this method
    +   * again is a no-op for resources that were already released.
    +   *
    +   * @throws IOException if closing the ORC record reader fails
    +   */
    +  @Override
    +  public void close() throws IOException {
    +    try {
    +      if (columnarBatch != null) {
    +        columnarBatch.close();
    +        columnarBatch = null;
    +      }
    +    } finally {
    +      // Close the record reader even when releasing the batch failed, so that the
    +      // underlying reader is not leaked.
    +      if (recordReader != null) {
    +        recordReader.close();
    +        recordReader = null;
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Initialize ORC file reader and batch record reader.
    +   * Please note that `setRequiredSchema` is needed to be called after 
this.
    +   */
    +  @Override
    +  public void initialize(InputSplit inputSplit, TaskAttemptContext 
taskAttemptContext)
    +      throws IOException, InterruptedException {
    +    FileSplit fileSplit = (FileSplit)inputSplit;
    +    Configuration conf = taskAttemptContext.getConfiguration();
    +    reader = OrcFile.createReader(
    +      fileSplit.getPath(),
    +      OrcFile.readerOptions(conf)
    +        .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf))
    +        .filesystem(fileSplit.getPath().getFileSystem(conf)));
    +
    +    Reader.Options options =
    +            OrcInputFormat.buildOptions(conf, reader, 
fileSplit.getStart(), fileSplit.getLength());
    +    recordReader = reader.rows(options);
    +    totalRowCount = reader.getNumberOfRows();
    +  }
    +
    +  /**
    +   * Records the required schema and requested column ids, then builds the ColumnarBatch
    +   * over the full result schema: the required fields followed by the partition fields.
    +   * Partition columns are populated from `partitionValues` and marked constant; required
    +   * columns whose requested id is negative are filled with nulls and marked constant.
    +   */
    +  public void setRequiredSchema(
    +    TypeDescription orcSchema,
    +    int[] requestedColIds,
    +    StructType requiredSchema,
    +    StructType partitionSchema,
    +    InternalRow partitionValues) {
    +    batch = orcSchema.createRowBatch(DEFAULT_SIZE);
    +    // `selectedInUse` should be initialized with `false`.
    +    assert(!batch.selectedInUse);
    +
    +    // Full schema = required fields first, partition fields appended after them.
    +    StructType fullSchema = new StructType(requiredSchema.fields());
    +    for (StructField partitionField : partitionSchema.fields()) {
    +      fullSchema = fullSchema.add(partitionField);
    +    }
    +    this.requiredSchema = requiredSchema;
    +    this.requestedColIds = requestedColIds;
    +
    +    final int capacity = DEFAULT_SIZE;
    +    columnVectors = (DEFAULT_MEMORY_MODE == MemoryMode.OFF_HEAP)
    +      ? OffHeapColumnVector.allocateColumns(capacity, fullSchema)
    +      : OnHeapColumnVector.allocateColumns(capacity, fullSchema);
    +    columnarBatch = new ColumnarBatch(fullSchema, columnVectors, capacity);
    +
    +    // Partition vectors are located right after the required-schema vectors.
    +    final int partitionOffset = requiredSchema.fields().length;
    +    final int partitionCount = partitionValues.numFields();
    +    for (int p = 0; p < partitionCount; p++) {
    +      WritableColumnVector partitionVector = columnVectors[partitionOffset + p];
    +      ColumnVectorUtils.populate(partitionVector, partitionValues, p);
    +      partitionVector.setIsConstant();
    +    }
    +
    +    // Initialize the missing columns once.
    +    for (int col = 0; col < requiredSchema.length(); col++) {
    +      if (requestedColIds[col] >= 0) {
    +        continue;
    +      }
    +      WritableColumnVector missingVector = columnVectors[col];
    +      missingVector.putNulls(0, columnarBatch.capacity());
    +      missingVector.setIsConstant();
    +    }
    +  }
    +
    +  /**
    +   * Return true if there exists more data in the next batch. If so, prepare the next batch
    +   * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch columns.
    +   */
    +  private boolean nextBatch() throws IOException {
    +    if (rowsReturned >= totalRowCount) {
    +      return false;
    +    }
    +
    +    recordReader.nextBatch(batch);
    +    int batchSize = batch.size;
    +    if (batchSize == 0) {
    +      return false;
    +    }
    +    rowsReturned += batchSize;
    +    for (WritableColumnVector vector : columnVectors) {
    +      vector.reset();
    +    }
    +    columnarBatch.setNumRows(batchSize);
    +    int i = 0;
    +    while (i < requiredSchema.length()) {
    +      StructField field = requiredSchema.fields()[i];
    +      WritableColumnVector toColumn = columnVectors[i];
    +
    +      if (requestedColIds[i] < 0) {
    +        toColumn.appendNulls(batchSize);
    +      } else {
    +        ColumnVector fromColumn = batch.cols[requestedColIds[i]];
    +
    +        if (fromColumn.isRepeating) {
    +          if (fromColumn.isNull[0]) {
    +            toColumn.appendNulls(batchSize);
    +          } else {
    +            DataType type = field.dataType();
    +            if (type instanceof BooleanType) {
    +              toColumn.appendBooleans(batchSize, 
((LongColumnVector)fromColumn).vector[0] == 1);
    +            } else if (type instanceof ByteType) {
    +              toColumn.appendBytes(batchSize, 
(byte)((LongColumnVector)fromColumn).vector[0]);
    +            } else if (type instanceof ShortType) {
    +              toColumn.appendShorts(batchSize, 
(short)((LongColumnVector)fromColumn).vector[0]);
    +            } else if (type instanceof IntegerType || type instanceof 
DateType) {
    +              toColumn.appendInts(batchSize, 
(int)((LongColumnVector)fromColumn).vector[0]);
    +            } else if (type instanceof LongType) {
    +              toColumn.appendLongs(batchSize, 
((LongColumnVector)fromColumn).vector[0]);
    +            } else if (type instanceof TimestampType) {
    +              toColumn.appendLongs(batchSize, 
fromTimestampColumnVector((TimestampColumnVector)fromColumn, 0));
    +            } else if (type instanceof FloatType) {
    +              toColumn.appendFloats(batchSize, 
(float)((DoubleColumnVector)fromColumn).vector[0]);
    +            } else if (type instanceof DoubleType) {
    +              toColumn.appendDoubles(batchSize, 
((DoubleColumnVector)fromColumn).vector[0]);
    +            } else if (type instanceof StringType || type instanceof 
BinaryType) {
    +              BytesColumnVector data = (BytesColumnVector)fromColumn;
    +              int index = 0;
    +              while (index < batchSize) {
    --- End diff --
    
    Yep. It's done.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to