This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a5c8866 [SPARK-34859][SQL] Handle column index when using vectorized Parquet reader
a5c8866 is described below
commit a5c886619dd1573e96bbba058db099b47f0c147c
Author: Chao Sun <[email protected]>
AuthorDate: Wed Jun 30 14:21:18 2021 -0700
[SPARK-34859][SQL] Handle column index when using vectorized Parquet reader
### What changes were proposed in this pull request?
Make the current vectorized Parquet reader work with the column index
introduced in Parquet 1.11. In particular, this PR makes the following changes:
1. In `ParquetReadState`, track the row ranges derived from the row indexes
returned by `PageReadStore.getRowIndexes`, as well as the first row index of
each page via `DataPage.getFirstRowIndex`.
2. Introduce a new API `ParquetVectorUpdater.skipValues`, which skips a
batch of values from a Parquet value reader. As part of this, also rename the
existing `updateBatch` to `readValues` and `update` to `readValue` to keep
the method names consistent.
3. Correspondingly, introduce new APIs `VectorizedValuesReader.skipXXX` for
the different data types, along with their implementations. These are useful
when the reader knows that a given batch of values can be skipped, for
instance because the batch is not covered by the row ranges generated by
column index filtering.
4. Change `VectorizedRleValuesReader` to handle column index filtering.
This is done by comparing the range that is going to be read next within the
current RLE/PACKED block (call this the block range) against the current row
range. There are three cases:
* if the block range is before the current row range, skip all the
values in the block range;
* if the block range is after the current row range, advance the row
range and repeat the steps;
* if the block range overlaps with the current row range, read only the
values within the overlapping area and skip the rest.
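The three-case comparison above can be illustrated with a small, self-contained sketch. It is hypothetical and heavily simplified:

* It assumes a flat schema without nulls, so each row holds exactly one value.
* It models an RLE/PACKED block as a plain `int[]`.
* It simply drops skipped values instead of calling a `skipXXX` method.

The names `RowRangeFilterSketch` and `processBlock` are illustrative and not part of Spark. The `END_ROW_RANGE` sentinel mirrors the one added to `ParquetReadState` in this commit.

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

// Hypothetical sketch, not Spark code: flat schema, no nulls, so one value per row.
public class RowRangeFilterSketch {
  /** Inclusive range of row indexes [start, end]. */
  static final class RowRange {
    final long start;
    final long end;

    RowRange(long start, long end) {
      this.start = start;
      this.end = end;
    }
  }

  // Sentinel for "all ranges consumed": start exceeds any row id, so every later row is skipped.
  static final RowRange END_ROW_RANGE = new RowRange(Long.MAX_VALUE, Long.MIN_VALUE);

  private final Iterator<RowRange> rowRanges;
  private RowRange currentRange;

  RowRangeFilterSketch(List<RowRange> ranges) {
    this.rowRanges = ranges.iterator();
    nextRange();
  }

  private void nextRange() {
    currentRange = rowRanges.hasNext() ? rowRanges.next() : END_ROW_RANGE;
  }

  /** Appends to `out` the values of `block` whose row index falls inside a row range. */
  void processBlock(int[] block, long firstRowId, List<Integer> out) {
    long rowId = firstRowId; // row index of block[pos]
    int pos = 0;
    while (pos < block.length) {
      long blockEnd = rowId + (block.length - pos) - 1; // last row index left in the block
      if (blockEnd < currentRange.start) {
        // Case 1: block range is before the row range, so skip the rest of the block.
        return;
      } else if (rowId > currentRange.end) {
        // Case 2: block range is after the row range, so advance the range and retry.
        nextRange();
      } else {
        // Case 3: overlap. Skip values before the range start, then read up to the range end.
        int skip = (int) Math.max(0, currentRange.start - rowId);
        pos += skip;
        rowId += skip;
        int n = (int) (Math.min(blockEnd, currentRange.end) - rowId + 1);
        for (int i = 0; i < n; i++) {
          out.add(block[pos + i]);
        }
        pos += n;
        rowId += n;
      }
    }
  }

  public static void main(String[] args) {
    List<RowRange> ranges = new ArrayList<>();
    ranges.add(new RowRange(0, 2));
    ranges.add(new RowRange(4, 5));
    ranges.add(new RowRange(7, 9));
    RowRangeFilterSketch reader = new RowRangeFilterSketch(ranges);
    List<Integer> out = new ArrayList<>();
    // Two blocks covering rows 0-4 and 5-9; each value is its row index times 10.
    reader.processBlock(new int[] {0, 10, 20, 30, 40}, 0, out);
    reader.processBlock(new int[] {50, 60, 70, 80, 90}, 5, out);
    System.out.println(out); // [0, 10, 20, 40, 50, 70, 80, 90]
  }
}
```

The current row range is reader state rather than per-block state. This is why the second block continues from range `[4-5]` where the first block left off, much like `ParquetReadState` keeps `currentRange` across calls.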
### Why are the changes needed?
[Parquet Column
Index](https://github.com/apache/parquet-format/blob/master/PageIndex.md) is a
new feature in Parquet 1.11 that allows very efficient filtering at the page
level (some benchmark numbers can be found
[here](https://blog.cloudera.com/speeding-up-select-queries-with-parquet-page-indexes/)),
especially when data is sorted. The feature is largely implemented in
parquet-mr (via classes such as `ColumnIndex` and `ColumnIndexFilter`). In
Spark, the non-vectorized Parquet reader c [...]
Previously,
[SPARK-26345](https://issues.apache.org/jira/browse/SPARK-26345) / (#31393)
updated Spark to scan only the pages filtered by column index on the parquet-mr
side. This is done by calling the `ParquetFileReader.readNextFilteredRowGroup`
and `ParquetFileReader.getFilteredRecordCount` APIs. The implementation, however,
only works for a few limited cases: in the scenario where there are multiple
columns and their type widths are different (e.g., `int` and `bigint`), it could
return incorrec [...]
To fix the above, Spark needs to leverage the APIs
`PageReadStore.getRowIndexes` and `DataPage.getFirstRowIndex`. The former
returns the indexes of all rows remaining after filtering within a Parquet row
group (note the difference between rows and values: for a flat schema the two
are the same, but for a nested schema they differ). The latter returns the
index of the first row within a single data page. Combining the two, one is
able to know which rows/values [...]
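As a rough illustration of combining the two APIs, the sketch below maps the filtered row indexes of a row group onto a single page, given that page's first row index and row count. It is hypothetical and makes several simplifying assumptions:

* It assumes a flat schema (one value per row).
* It takes the indexes as a `PrimitiveIterator.OfLong` built from a `LongStream`, rather than obtaining them from parquet-mr.
* It consumes the iterator, so it only handles one page.

`selectedOffsets` is an illustrative name, not a Spark or parquet-mr API. The actual change instead compresses the indexes into row ranges inside `ParquetReadState`.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PrimitiveIterator;
import java.util.stream.LongStream;

// Hypothetical sketch, not Spark or parquet-mr code. Assumes a flat schema (one value per row).
public class PageRowSelectionSketch {
  /**
   * Returns the page-local offsets of the rows kept by column index filtering.
   *
   * @param rowIndexes ascending row indexes (within the row group) that survive filtering
   * @param pageFirstRowIndex index of the first row of the page within the row group
   * @param pageRowCount number of rows in the page
   */
  static List<Integer> selectedOffsets(
      PrimitiveIterator.OfLong rowIndexes, long pageFirstRowIndex, int pageRowCount) {
    List<Integer> offsets = new ArrayList<>();
    long pageEnd = pageFirstRowIndex + pageRowCount; // exclusive
    while (rowIndexes.hasNext()) {
      long idx = rowIndexes.nextLong();
      if (idx >= pageEnd) {
        break; // indexes are ascending, so no later index falls inside this page
      }
      if (idx >= pageFirstRowIndex) {
        offsets.add((int) (idx - pageFirstRowIndex));
      }
    }
    return offsets;
  }

  public static void main(String[] args) {
    // Filtered row indexes for a row group, and a page holding rows 4-7.
    PrimitiveIterator.OfLong rowIndexes =
        LongStream.of(0L, 1L, 2L, 4L, 5L, 7L, 8L, 9L).iterator();
    System.out.println(selectedOffsets(rowIndexes, 4, 4)); // [0, 1, 3]
  }
}
```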
### Does this PR introduce _any_ user-facing change?
Yes. Now the vectorized Parquet reader should work correctly with column
index.
### How was this patch tested?
Borrowed tests from #31998 and added a few more tests.
Closes #32753 from sunchao/SPARK-34859.
Lead-authored-by: Chao Sun <[email protected]>
Co-authored-by: Li Xian <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../datasources/parquet/ParquetReadState.java | 120 ++++++++++-
.../datasources/parquet/ParquetVectorUpdater.java | 12 +-
.../parquet/ParquetVectorUpdaterFactory.java | 216 ++++++++++++++-----
.../parquet/VectorizedColumnReader.java | 28 ++-
.../parquet/VectorizedParquetRecordReader.java | 1 +
.../parquet/VectorizedPlainValuesReader.java | 51 +++++
.../parquet/VectorizedRleValuesReader.java | 239 ++++++++++++++++-----
.../parquet/VectorizedValuesReader.java | 13 ++
.../parquet/ParquetColumnIndexSuite.scala | 126 +++++++++++
.../datasources/parquet/ParquetIOSuite.scala | 72 ++++++-
10 files changed, 746 insertions(+), 132 deletions(-)
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java
index 28dcc44..b260887 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java
@@ -17,13 +17,38 @@
package org.apache.spark.sql.execution.datasources.parquet;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.PrimitiveIterator;
+
/**
* Helper class to store intermediate state while reading a Parquet column chunk.
*/
final class ParquetReadState {
- /** Maximum definition level */
+ /** A special row range used when there is no row indexes (hence all rows must be included) */
+ private static final RowRange MAX_ROW_RANGE = new RowRange(Long.MIN_VALUE, Long.MAX_VALUE);
+
+ /**
+ * A special row range used when the row indexes are present AND all the row ranges have been
+ * processed. This serves as a sentinel at the end indicating that all rows come after the last
+ * row range should be skipped.
+ */
+ private static final RowRange END_ROW_RANGE = new RowRange(Long.MAX_VALUE, Long.MIN_VALUE);
+
+ /** Iterator over all row ranges, only not-null if column index is present */
+ private final Iterator<RowRange> rowRanges;
+
+ /** The current row range */
+ private RowRange currentRange;
+
+ /** Maximum definition level for the Parquet column */
final int maxDefinitionLevel;
+ /** The current index over all rows within the column chunk. This is used to check if the
+ * current row should be skipped by comparing against the row ranges. */
+ long rowId;
+
/** The offset in the current batch to put the next value */
int offset;
@@ -33,31 +58,108 @@ final class ParquetReadState {
/** The remaining number of values to read in the current batch */
int valuesToReadInBatch;
- ParquetReadState(int maxDefinitionLevel) {
+ ParquetReadState(int maxDefinitionLevel, PrimitiveIterator.OfLong rowIndexes) {
this.maxDefinitionLevel = maxDefinitionLevel;
+ this.rowRanges = constructRanges(rowIndexes);
+ nextRange();
}
/**
- * Called at the beginning of reading a new batch.
+ * Construct a list of row ranges from the given `rowIndexes`. For example, suppose the
+ * `rowIndexes` are `[0, 1, 2, 4, 5, 7, 8, 9]`, it will be converted into 3 row ranges:
+ * `[0-2], [4-5], [7-9]`.
*/
- void resetForBatch(int batchSize) {
+ private Iterator<RowRange> constructRanges(PrimitiveIterator.OfLong rowIndexes) {
+ if (rowIndexes == null) {
+ return null;
+ }
+
+ List<RowRange> rowRanges = new ArrayList<>();
+ long currentStart = Long.MIN_VALUE;
+ long previous = Long.MIN_VALUE;
+
+ while (rowIndexes.hasNext()) {
+ long idx = rowIndexes.nextLong();
+ if (currentStart == Long.MIN_VALUE) {
+ currentStart = idx;
+ } else if (previous + 1 != idx) {
+ RowRange range = new RowRange(currentStart, previous);
+ rowRanges.add(range);
+ currentStart = idx;
+ }
+ previous = idx;
+ }
+
+ if (previous != Long.MIN_VALUE) {
+ rowRanges.add(new RowRange(currentStart, previous));
+ }
+
+ return rowRanges.iterator();
+ }
+
+ /**
+ * Must be called at the beginning of reading a new batch.
+ */
+ void resetForNewBatch(int batchSize) {
this.offset = 0;
this.valuesToReadInBatch = batchSize;
}
/**
- * Called at the beginning of reading a new page.
+ * Must be called at the beginning of reading a new page.
*/
- void resetForPage(int totalValuesInPage) {
+ void resetForNewPage(int totalValuesInPage, long pageFirstRowIndex) {
this.valuesToReadInPage = totalValuesInPage;
+ this.rowId = pageFirstRowIndex;
}
/**
- * Advance the current offset to the new values.
+ * Returns the start index of the current row range.
*/
- void advanceOffset(int newOffset) {
+ long currentRangeStart() {
+ return currentRange.start;
+ }
+
+ /**
+ * Returns the end index of the current row range.
+ */
+ long currentRangeEnd() {
+ return currentRange.end;
+ }
+
+ /**
+ * Advance the current offset and rowId to the new values.
+ */
+ void advanceOffsetAndRowId(int newOffset, long newRowId) {
valuesToReadInBatch -= (newOffset - offset);
- valuesToReadInPage -= (newOffset - offset);
+ valuesToReadInPage -= (newRowId - rowId);
offset = newOffset;
+ rowId = newRowId;
+ }
+
+ /**
+ * Advance to the next range.
+ */
+ void nextRange() {
+ if (rowRanges == null) {
+ currentRange = MAX_ROW_RANGE;
+ } else if (!rowRanges.hasNext()) {
+ currentRange = END_ROW_RANGE;
+ } else {
+ currentRange = rowRanges.next();
+ }
+ }
+
+ /**
+ * Helper struct to represent a range of row indexes `[start, end]`.
+ */
+ private static class RowRange {
+ final long start;
+ final long end;
+
+ RowRange(long start, long end) {
+ this.start = start;
+ this.end = end;
+ }
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdater.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdater.java
index b91d507..9bb8529 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdater.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdater.java
@@ -30,20 +30,28 @@ public interface ParquetVectorUpdater {
* @param values destination values vector
* @param valuesReader reader to read values from
*/
- void updateBatch(
+ void readValues(
int total,
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader);
/**
+ * Skip a batch of `total` values from `valuesReader`.
+ *
+ * @param total total number of values to skip
+ * @param valuesReader reader to skip values from
+ */
+ void skipValues(int total, VectorizedValuesReader valuesReader);
+
+ /**
* Read a single value from `valuesReader` into `values`, at `offset`.
*
* @param offset offset in `values` to put the new value
* @param values destination value vector
* @param valuesReader reader to read values from
*/
- void update(int offset, WritableColumnVector values, VectorizedValuesReader valuesReader);
+ void readValue(int offset, WritableColumnVector values, VectorizedValuesReader valuesReader);
/**
* Process a batch of `total` values starting from `offset` in `values`, whose null slots
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
index 62e34fe..2282dc7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java
@@ -185,7 +185,7 @@ public class ParquetVectorUpdaterFactory {
private static class BooleanUpdater implements ParquetVectorUpdater {
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
@@ -194,7 +194,12 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipBooleans(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -213,7 +218,7 @@ public class ParquetVectorUpdaterFactory {
private static class IntegerUpdater implements ParquetVectorUpdater {
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
@@ -222,7 +227,12 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipIntegers(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -241,7 +251,7 @@ public class ParquetVectorUpdaterFactory {
private static class UnsignedIntegerUpdater implements ParquetVectorUpdater {
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
@@ -250,7 +260,12 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipIntegers(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -270,7 +285,7 @@ public class ParquetVectorUpdaterFactory {
private static class ByteUpdater implements ParquetVectorUpdater {
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
@@ -279,7 +294,12 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipBytes(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -298,7 +318,7 @@ public class ParquetVectorUpdaterFactory {
private static class ShortUpdater implements ParquetVectorUpdater {
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
@@ -307,7 +327,12 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipShorts(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -332,7 +357,7 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
@@ -341,7 +366,12 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipIntegers(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -362,7 +392,7 @@ public class ParquetVectorUpdaterFactory {
private static class LongUpdater implements ParquetVectorUpdater {
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
@@ -371,7 +401,12 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipLongs(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -390,7 +425,7 @@ public class ParquetVectorUpdaterFactory {
private static class DowncastLongUpdater implements ParquetVectorUpdater {
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
@@ -401,7 +436,12 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipLongs(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -420,7 +460,7 @@ public class ParquetVectorUpdaterFactory {
private static class UnsignedLongUpdater implements ParquetVectorUpdater {
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
@@ -429,7 +469,12 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipLongs(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -457,7 +502,7 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
@@ -466,7 +511,12 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipLongs(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -487,18 +537,23 @@ public class ParquetVectorUpdaterFactory {
private static class LongAsMicrosUpdater implements ParquetVectorUpdater {
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
for (int i = 0; i < total; ++i) {
- update(offset + i, values, valuesReader);
+ readValue(offset + i, values, valuesReader);
}
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipLongs(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -524,18 +579,23 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
for (int i = 0; i < total; ++i) {
- update(offset + i, values, valuesReader);
+ readValue(offset + i, values, valuesReader);
}
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipLongs(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -557,7 +617,7 @@ public class ParquetVectorUpdaterFactory {
private static class FloatUpdater implements ParquetVectorUpdater {
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
@@ -566,7 +626,12 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipFloats(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -585,7 +650,7 @@ public class ParquetVectorUpdaterFactory {
private static class DoubleUpdater implements ParquetVectorUpdater {
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
@@ -594,7 +659,12 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipDoubles(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -613,7 +683,7 @@ public class ParquetVectorUpdaterFactory {
private static class BinaryUpdater implements ParquetVectorUpdater {
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
@@ -622,7 +692,12 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipBinary(total);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -642,18 +717,23 @@ public class ParquetVectorUpdaterFactory {
private static class BinaryToSQLTimestampUpdater implements ParquetVectorUpdater {
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
for (int i = 0; i < total; i++) {
- update(offset + i, values, valuesReader);
+ readValue(offset + i, values, valuesReader);
}
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipFixedLenByteArray(total, 12);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -681,18 +761,23 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
for (int i = 0; i < total; i++) {
- update(offset + i, values, valuesReader);
+ readValue(offset + i, values, valuesReader);
}
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipFixedLenByteArray(total, 12);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -723,18 +808,23 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
for (int i = 0; i < total; i++) {
- update(offset + i, values, valuesReader);
+ readValue(offset + i, values, valuesReader);
}
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipFixedLenByteArray(total, 12);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -767,18 +857,23 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
for (int i = 0; i < total; i++) {
- update(offset + i, values, valuesReader);
+ readValue(offset + i, values, valuesReader);
}
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipFixedLenByteArray(total, 12);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -811,18 +906,23 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
for (int i = 0; i < total; i++) {
- update(offset + i, values, valuesReader);
+ readValue(offset + i, values, valuesReader);
}
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipFixedLenByteArray(total, arrayLen);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -848,18 +948,23 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
for (int i = 0; i < total; i++) {
- update(offset + i, values, valuesReader);
+ readValue(offset + i, values, valuesReader);
}
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipFixedLenByteArray(total, arrayLen);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
@@ -886,18 +991,23 @@ public class ParquetVectorUpdaterFactory {
}
@Override
- public void updateBatch(
+ public void readValues(
int total,
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
for (int i = 0; i < total; i++) {
- update(offset + i, values, valuesReader);
+ readValue(offset + i, values, valuesReader);
}
}
@Override
- public void update(
+ public void skipValues(int total, VectorizedValuesReader valuesReader) {
+ valuesReader.skipFixedLenByteArray(total, arrayLen);
+ }
+
+ @Override
+ public void readValue(
int offset,
WritableColumnVector values,
VectorizedValuesReader valuesReader) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index c61ee46..92dea08 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet;
import java.io.IOException;
import java.time.ZoneId;
+import java.util.PrimitiveIterator;
import org.apache.parquet.bytes.ByteBufferInputStream;
import org.apache.parquet.bytes.BytesInput;
@@ -74,6 +75,12 @@ public class VectorizedColumnReader {
*/
private final ParquetReadState readState;
+ /**
+ * The index for the first row in the current page, among all rows across all pages in the
+ * column chunk for this reader. If there is no column index, the value is 0.
+ */
+ private long pageFirstRowIndex;
+
private final PageReader pageReader;
private final ColumnDescriptor descriptor;
private final LogicalTypeAnnotation logicalTypeAnnotation;
@@ -83,12 +90,13 @@ public class VectorizedColumnReader {
ColumnDescriptor descriptor,
LogicalTypeAnnotation logicalTypeAnnotation,
PageReader pageReader,
+ PrimitiveIterator.OfLong rowIndexes,
ZoneId convertTz,
String datetimeRebaseMode,
String int96RebaseMode) throws IOException {
this.descriptor = descriptor;
this.pageReader = pageReader;
- this.readState = new ParquetReadState(descriptor.getMaxDefinitionLevel());
+ this.readState = new ParquetReadState(descriptor.getMaxDefinitionLevel(), rowIndexes);
this.logicalTypeAnnotation = logicalTypeAnnotation;
this.updaterFactory = new ParquetVectorUpdaterFactory(
logicalTypeAnnotation, convertTz, datetimeRebaseMode, int96RebaseMode);
@@ -151,18 +159,19 @@ public class VectorizedColumnReader {
// page.
dictionaryIds = column.reserveDictionaryIds(total);
}
- readState.resetForBatch(total);
+ readState.resetForNewBatch(total);
while (readState.valuesToReadInBatch > 0) {
- // Compute the number of values we want to read in this page.
if (readState.valuesToReadInPage == 0) {
int pageValueCount = readPage();
- readState.resetForPage(pageValueCount);
+ readState.resetForNewPage(pageValueCount, pageFirstRowIndex);
}
PrimitiveType.PrimitiveTypeName typeName =
descriptor.getPrimitiveType().getPrimitiveTypeName();
if (isCurrentPageDictionaryEncoded) {
// Save starting offset in case we need to decode dictionary IDs.
int startOffset = readState.offset;
+ // Save starting row index so we can check if we need to eagerly decode dict ids later
+ long startRowId = readState.rowId;
// Read and decode dictionary ids.
defColumn.readIntegers(readState, dictionaryIds, column,
@@ -170,10 +179,12 @@ public class VectorizedColumnReader {
// TIMESTAMP_MILLIS encoded as INT64 can't be lazily decoded as we need to post process
// the values to add microseconds precision.
- if (column.hasDictionary() || (startOffset == 0 && isLazyDecodingSupported(typeName))) {
+ if (column.hasDictionary() || (startRowId == pageFirstRowIndex &&
+ isLazyDecodingSupported(typeName))) {
// Column vector supports lazy decoding of dictionary values so just set the dictionary.
- // We can't do this if rowId != 0 AND the column doesn't have a dictionary (i.e. some
- // non-dictionary encoded values have already been added).
+ // We can't do this if startRowId is not the first row index in the page AND the column
+ // doesn't have a dictionary (i.e. some non-dictionary encoded values have already been
+ // added).
PrimitiveType primitiveType = descriptor.getPrimitiveType();
// We need to make sure that we initialize the right type for the dictionary otherwise
@@ -213,6 +224,8 @@ public class VectorizedColumnReader {
private int readPage() {
DataPage page = pageReader.readPage();
+ this.pageFirstRowIndex = page.getFirstRowIndex().orElse(0L);
+
return page.accept(new DataPage.Visitor<Integer>() {
@Override
public Integer visit(DataPageV1 dataPageV1) {
@@ -268,7 +281,6 @@ public class VectorizedColumnReader {
}
private int readPageV1(DataPageV1 page) throws IOException {
- // Initialize the decoders.
if (page.getDlEncoding() != Encoding.RLE && descriptor.getMaxDefinitionLevel() != 0) {
throw new UnsupportedOperationException("Unsupported encoding: " + page.getDlEncoding());
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
index 3245527..9f7836a 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
@@ -334,6 +334,7 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
columns.get(i),
types.get(i).getLogicalTypeAnnotation(),
pages.getPageReader(columns.get(i)),
+ pages.getRowIndexes().orElse(null),
convertTz,
datetimeRebaseMode,
int96RebaseMode);
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
index 6a0038d..39591be 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java
@@ -61,6 +61,14 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
}
+ @Override
+ public final void skipBooleans(int total) {
+ // TODO: properly vectorize this
+ for (int i = 0; i < total; i++) {
+ readBoolean();
+ }
+ }
+
private ByteBuffer getBuffer(int length) {
try {
return in.slice(length).order(ByteOrder.LITTLE_ENDIAN);
@@ -85,6 +93,11 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
@Override
+ public void skipIntegers(int total) {
+ in.skip(total * 4L);
+ }
+
+ @Override
public final void readUnsignedIntegers(int total, WritableColumnVector c, int rowId) {
int requiredBytes = total * 4;
ByteBuffer buffer = getBuffer(requiredBytes);
@@ -141,6 +154,11 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
@Override
+ public void skipLongs(int total) {
+ in.skip(total * 8L);
+ }
+
+ @Override
public final void readUnsignedLongs(int total, WritableColumnVector c, int rowId) {
int requiredBytes = total * 8;
ByteBuffer buffer = getBuffer(requiredBytes);
@@ -198,6 +216,11 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
@Override
+ public void skipFloats(int total) {
+ in.skip(total * 4L);
+ }
+
+ @Override
public final void readDoubles(int total, WritableColumnVector c, int rowId) {
int requiredBytes = total * 8;
ByteBuffer buffer = getBuffer(requiredBytes);
@@ -213,6 +236,11 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
@Override
+ public void skipDoubles(int total) {
+ in.skip(total * 8L);
+ }
+
+ @Override
public final void readBytes(int total, WritableColumnVector c, int rowId) {
// Bytes are stored as a 4-byte little endian int. Just read the first byte.
// TODO: consider pushing this in ColumnVector by adding a readBytes with a stride.
@@ -227,6 +255,11 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
@Override
+ public final void skipBytes(int total) {
+ in.skip(total * 4L);
+ }
+
+ @Override
public final void readShorts(int total, WritableColumnVector c, int rowId) {
int requiredBytes = total * 4;
ByteBuffer buffer = getBuffer(requiredBytes);
@@ -237,6 +270,11 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
@Override
+ public void skipShorts(int total) {
+ in.skip(total * 4L);
+ }
+
+ @Override
public final boolean readBoolean() {
// TODO: vectorize decoding and keep boolean[] instead of currentByte
if (bitOffset == 0) {
@@ -301,6 +339,14 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
}
@Override
+ public void skipBinary(int total) {
+ for (int i = 0; i < total; i++) {
+ int len = readInteger();
+ in.skip(len);
+ }
+ }
+
+ @Override
public final Binary readBinary(int len) {
ByteBuffer buffer = getBuffer(len);
if (buffer.hasArray()) {
@@ -312,4 +358,9 @@ public class VectorizedPlainValuesReader extends ValuesReader implements Vectori
return Binary.fromConstantByteArray(bytes);
}
}
+
+ @Override
+ public void skipFixedLenByteArray(int total, int len) {
+ in.skip(total * (long) len);
+ }
}
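The PLAIN skip methods above reduce to position arithmetic. Fixed-width values are skipped by count times width: 4 bytes for INT32 and FLOAT (byte and short columns are also stored as 4-byte INT32, as the existing `readBytes` comment notes), 8 bytes for INT64 and DOUBLE, and `len` bytes for fixed-length byte arrays. BYTE_ARRAY values must be walked one at a time because each one is prefixed by a 4-byte little-endian length. The following is a minimal sketch over a plain `java.nio.ByteBuffer` rather than the input stream `in` used in the diff; `PlainSkipper` is a hypothetical name, not a Spark class.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/** Hypothetical helper: skips PLAIN-encoded values by moving a buffer's position. */
final class PlainSkipper {
  private final ByteBuffer in;

  PlainSkipper(ByteBuffer in) {
    // PLAIN encoding is little-endian.
    this.in = in.order(ByteOrder.LITTLE_ENDIAN);
  }

  /** INT32 and FLOAT values are 4 bytes each; byte and short columns are stored as INT32. */
  void skipIntegers(int total) { advanceBytes(total * 4L); }

  /** INT64 and DOUBLE values are 8 bytes each. */
  void skipLongs(int total) { advanceBytes(total * 8L); }

  /** FIXED_LEN_BYTE_ARRAY values are `len` bytes each. */
  void skipFixedLenByteArray(int total, int len) { advanceBytes(total * (long) len); }

  /** BYTE_ARRAY values carry a 4-byte length prefix, so they are walked one by one. */
  void skipBinary(int total) {
    for (int i = 0; i < total; i++) {
      int len = in.getInt();
      advanceBytes(len);
    }
  }

  int readInteger() { return in.getInt(); }

  private void advanceBytes(long n) {
    // Multiply in long (as in `total * 4L`) to avoid int overflow; toIntExact throws if
    // the new position does not fit in an int, and position() throws past the limit.
    in.position(Math.toIntExact(in.position() + n));
  }
}
```

For example, after writing three INT32 values, a length-prefixed `abc`, and `42`, calling `skipIntegers(2)` leaves the buffer at the third integer, and `skipBinary(1)` steps over the 4-byte length plus the 3 payload bytes.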
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
index 538b698..03bda0f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
@@ -156,18 +156,12 @@ public final class VectorizedRleValuesReader extends ValuesReader
}
/**
- * Reads `total` ints into `c` filling them in starting at `c[rowId]`. This reader
- * reads the definition levels and then will read from `data` for the non-null values.
- * If the value is null, c will be populated with `nullValue`. Note that `nullValue` is only
- * necessary for readIntegers because we also use it to decode dictionaryIds and want to make
- * sure it always has a value in range.
- *
- * This is a batched version of this logic:
- * if (this.readInt() == level) {
- * c[rowId] = data.readInteger();
- * } else {
- * c[rowId] = null;
- * }
+ * Reads a batch of values into vector `values`, using `valueReader`. The related states such
+ * as row index, offset, number of values left in the batch and page, etc, are tracked by
+ * `state`. The type-specific `updater` is used to update or skip values.
+ * <p>
+ * This reader reads the definition levels and then will read from `valueReader` for the
+ * non-null values. If the value is null, `values` will be populated with null value.
*/
public void readBatch(
ParquetReadState state,
@@ -175,36 +169,68 @@ public final class VectorizedRleValuesReader extends ValuesReader
VectorizedValuesReader valueReader,
ParquetVectorUpdater updater) throws IOException {
int offset = state.offset;
- int left = Math.min(state.valuesToReadInBatch, state.valuesToReadInPage);
+ long rowId = state.rowId;
+ int leftInBatch = state.valuesToReadInBatch;
+ int leftInPage = state.valuesToReadInPage;
- while (left > 0) {
+ while (leftInBatch > 0 && leftInPage > 0) {
if (this.currentCount == 0) this.readNextGroup();
- int n = Math.min(left, this.currentCount);
-
- switch (mode) {
- case RLE:
- if (currentValue == state.maxDefinitionLevel) {
- updater.updateBatch(n, offset, values, valueReader);
- } else {
- values.putNulls(offset, n);
- }
- break;
- case PACKED:
- for (int i = 0; i < n; ++i) {
- if (currentBuffer[currentBufferIdx++] == state.maxDefinitionLevel) {
- updater.update(offset + i, values, valueReader);
+ int n = Math.min(leftInBatch, Math.min(leftInPage, this.currentCount));
+
+ long rangeStart = state.currentRangeStart();
+ long rangeEnd = state.currentRangeEnd();
+
+ if (rowId + n < rangeStart) {
+ updater.skipValues(n, valueReader);
+ advance(n);
+ rowId += n;
+ leftInPage -= n;
+ } else if (rowId > rangeEnd) {
+ state.nextRange();
+ } else {
+ // the range [rowId, rowId + n) overlaps with the current row range in state
+ long start = Math.max(rangeStart, rowId);
+ long end = Math.min(rangeEnd, rowId + n - 1);
+
+ // skip the part [rowId, start)
+ int toSkip = (int) (start - rowId);
+ if (toSkip > 0) {
+ updater.skipValues(toSkip, valueReader);
+ advance(toSkip);
+ rowId += toSkip;
+ leftInPage -= toSkip;
+ }
+
+ // read the part [start, end]
+ n = (int) (end - start + 1);
+
+ switch (mode) {
+ case RLE:
+ if (currentValue == state.maxDefinitionLevel) {
+ updater.readValues(n, offset, values, valueReader);
} else {
- values.putNull(offset + i);
+ values.putNulls(offset, n);
}
- }
- break;
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == state.maxDefinitionLevel) {
+ updater.readValue(offset + i, values, valueReader);
+ } else {
+ values.putNull(offset + i);
+ }
+ }
+ break;
+ }
+ offset += n;
+ leftInBatch -= n;
+ rowId += n;
+ leftInPage -= n;
+ currentCount -= n;
}
- offset += n;
- left -= n;
- currentCount -= n;
}
- state.advanceOffset(offset);
+ state.advanceOffsetAndRowId(offset, rowId);
}
/**
@@ -217,36 +243,68 @@ public final class VectorizedRleValuesReader extends ValuesReader
WritableColumnVector nulls,
VectorizedValuesReader data) throws IOException {
int offset = state.offset;
- int left = Math.min(state.valuesToReadInBatch, state.valuesToReadInPage);
+ long rowId = state.rowId;
+ int leftInBatch = state.valuesToReadInBatch;
+ int leftInPage = state.valuesToReadInPage;
- while (left > 0) {
+ while (leftInBatch > 0 && leftInPage > 0) {
if (this.currentCount == 0) this.readNextGroup();
- int n = Math.min(left, this.currentCount);
-
- switch (mode) {
- case RLE:
- if (currentValue == state.maxDefinitionLevel) {
- data.readIntegers(n, values, offset);
- } else {
- nulls.putNulls(offset, n);
- }
- break;
- case PACKED:
- for (int i = 0; i < n; ++i) {
- if (currentBuffer[currentBufferIdx++] == state.maxDefinitionLevel) {
- values.putInt(offset + i, data.readInteger());
+ int n = Math.min(leftInBatch, Math.min(leftInPage, this.currentCount));
+
+ long rangeStart = state.currentRangeStart();
+ long rangeEnd = state.currentRangeEnd();
+
+ if (rowId + n < rangeStart) {
+ data.skipIntegers(n);
+ advance(n);
+ rowId += n;
+ leftInPage -= n;
+ } else if (rowId > rangeEnd) {
+ state.nextRange();
+ } else {
+ // the range [rowId, rowId + n) overlaps with the current row range in state
+ long start = Math.max(rangeStart, rowId);
+ long end = Math.min(rangeEnd, rowId + n - 1);
+
+ // skip the part [rowId, start)
+ int toSkip = (int) (start - rowId);
+ if (toSkip > 0) {
+ data.skipIntegers(toSkip);
+ advance(toSkip);
+ rowId += toSkip;
+ leftInPage -= toSkip;
+ }
+
+ // read the part [start, end]
+ n = (int) (end - start + 1);
+
+ switch (mode) {
+ case RLE:
+ if (currentValue == state.maxDefinitionLevel) {
+ data.readIntegers(n, values, offset);
} else {
- nulls.putNull(offset + i);
+ nulls.putNulls(offset, n);
}
- }
- break;
+ break;
+ case PACKED:
+ for (int i = 0; i < n; ++i) {
+ if (currentBuffer[currentBufferIdx++] == state.maxDefinitionLevel) {
+ values.putInt(offset + i, data.readInteger());
+ } else {
+ nulls.putNull(offset + i);
+ }
+ }
+ break;
+ }
+ rowId += n;
+ leftInPage -= n;
+ offset += n;
+ leftInBatch -= n;
+ currentCount -= n;
}
- offset += n;
- left -= n;
- currentCount -= n;
}
- state.advanceOffset(offset);
+ state.advanceOffsetAndRowId(offset, rowId);
}
@@ -346,6 +404,71 @@ public final class VectorizedRleValuesReader extends ValuesReader
throw new UnsupportedOperationException("only readInts is valid.");
}
+ @Override
+ public void skipIntegers(int total) {
+ int left = total;
+ while (left > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int n = Math.min(left, this.currentCount);
+ advance(n);
+ left -= n;
+ }
+ }
+
+ @Override
+ public void skipBooleans(int total) {
+ throw new UnsupportedOperationException("only skipIntegers is valid");
+ }
+
+ @Override
+ public void skipBytes(int total) {
+ throw new UnsupportedOperationException("only skipIntegers is valid");
+ }
+
+ @Override
+ public void skipShorts(int total) {
+ throw new UnsupportedOperationException("only skipIntegers is valid");
+ }
+
+ @Override
+ public void skipLongs(int total) {
+ throw new UnsupportedOperationException("only skipIntegers is valid");
+ }
+
+ @Override
+ public void skipFloats(int total) {
+ throw new UnsupportedOperationException("only skipIntegers is valid");
+ }
+
+ @Override
+ public void skipDoubles(int total) {
+ throw new UnsupportedOperationException("only skipIntegers is valid");
+ }
+
+ @Override
+ public void skipBinary(int total) {
+ throw new UnsupportedOperationException("only skipIntegers is valid");
+ }
+
+ @Override
+ public void skipFixedLenByteArray(int total, int len) {
+ throw new UnsupportedOperationException("only skipIntegers is valid");
+ }
+
+ /**
+ * Advance and skip the next `n` values in the current block. `n` MUST be <= `currentCount`.
+ */
+ private void advance(int n) {
+ switch (mode) {
+ case RLE:
+ break;
+ case PACKED:
+ currentBufferIdx += n;
+ break;
+ }
+ currentCount -= n;
+ }
+
/**
* Reads the next varint encoded int.
*/
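The row-range handling added to `readBatch` and `readIntegers` can be modeled in isolation: for each block `[rowId, rowId + n)`, either skip it (it ends before the current range), advance to the next range (the block starts after it), or skip the leading part and read the overlap. The sketch below is a simplified, hypothetical model, not Spark code: `RowRanges` stands in for the ranges tracked by `ParquetReadState` (inclusive bounds, and it is assumed here to report `Long.MAX_VALUE` once exhausted so trailing rows are skipped), blocks of at most `blockSize` rows stand in for RLE/PACKED runs, and every row is treated as non-null so one row equals one encoded value.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical stand-in for the row ranges in ParquetReadState; bounds are inclusive. */
final class RowRanges {
  private final long[][] ranges; // sorted, non-overlapping {start, end} pairs
  private int idx = 0;

  RowRanges(long[][] ranges) { this.ranges = ranges; }

  // Assumption: once the ranges are exhausted, report Long.MAX_VALUE so that
  // every remaining row falls before the "current" range and is skipped.
  long currentStart() { return idx < ranges.length ? ranges[idx][0] : Long.MAX_VALUE; }
  long currentEnd() { return idx < ranges.length ? ranges[idx][1] : Long.MAX_VALUE; }
  void next() { idx++; }
}

final class RangeAwareReader {
  /**
   * Walks a page of `pageRows` rows starting at `firstRowId`, consumed in blocks of at
   * most `blockSize` rows, and returns the row ids that would be read rather than skipped.
   */
  static List<Long> selectRows(long firstRowId, int pageRows, int blockSize, RowRanges ranges) {
    List<Long> read = new ArrayList<>();
    long rowId = firstRowId;
    int leftInPage = pageRows;
    int currentCount = 0; // rows left in the current block
    while (leftInPage > 0) {
      if (currentCount == 0) currentCount = Math.min(blockSize, leftInPage);
      int n = Math.min(leftInPage, currentCount);
      long rangeStart = ranges.currentStart();
      long rangeEnd = ranges.currentEnd();

      if (rowId + n <= rangeStart) {
        // Block [rowId, rowId + n) ends before the range: skip all of it.
        rowId += n;
        leftInPage -= n;
        currentCount -= n;
      } else if (rowId > rangeEnd) {
        // Block starts after the range: advance to the next range and retry.
        ranges.next();
      } else {
        // Overlap: skip [rowId, start), then read [start, end].
        long start = Math.max(rangeStart, rowId);
        long end = Math.min(rangeEnd, rowId + n - 1);
        int toSkip = (int) (start - rowId);
        rowId += toSkip;
        leftInPage -= toSkip;
        currentCount -= toSkip;
        int toRead = (int) (end - start + 1);
        for (int i = 0; i < toRead; i++) read.add(rowId + i);
        rowId += toRead;
        leftInPage -= toRead;
        currentCount -= toRead;
      }
    }
    return read;
  }
}
```

With ranges `{2, 4}` and `{8, 9}`, a 12-row page starting at row 0 and read in blocks of 5 yields rows 2, 3, 4, 8 and 9. The sketch tests `rowId + n <= rangeStart`, which is exactly "the block ends before the range"; the diff uses `<`, and in the equality case it falls through to the overlap branch, skips all `n` values and then reads `end - start + 1 = 0` values, so both forms should behave the same.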
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
index a2d663f..fc4eac9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java
@@ -50,4 +50,17 @@ public interface VectorizedValuesReader {
void readFloats(int total, WritableColumnVector c, int rowId);
void readDoubles(int total, WritableColumnVector c, int rowId);
void readBinary(int total, WritableColumnVector c, int rowId);
+
+ /*
+ * Skips `total` values
+ */
+ void skipBooleans(int total);
+ void skipBytes(int total);
+ void skipShorts(int total);
+ void skipIntegers(int total);
+ void skipLongs(int total);
+ void skipFloats(int total);
+ void skipDoubles(int total);
+ void skipBinary(int total);
+ void skipFixedLenByteArray(int total, int len);
}
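The `skipXXX` methods in this interface skip encoded values, which in a nullable column is not the same as skipping rows: a Parquet data page stores values only for entries whose definition level equals the maximum definition level. A caller that wants to skip a run of rows therefore has to derive the value count from the definition levels first, which matters whenever a skipped block mixes null and non-null rows. A minimal sketch of that bookkeeping follows; `EncodedInts` and `NullAwareSkip` are hypothetical names, not Spark classes, and a maximum definition level of 1 (a top-level optional column) is assumed in the example.

```java
/**
 * Hypothetical in-memory stream of the encoded INT32 values of one data page.
 * Only non-null entries are stored, so it can be shorter than the page's row count.
 */
final class EncodedInts {
  private final int[] data;
  private int pos = 0;

  EncodedInts(int... data) { this.data = data; }

  void skipIntegers(int total) { pos += total; }

  int readInteger() { return data[pos++]; }
}

final class NullAwareSkip {
  /** Counts the encoded values behind rows [from, from + rows) using their definition levels. */
  static int encodedValuesIn(int[] defLevels, int from, int rows, int maxDefinitionLevel) {
    int count = 0;
    for (int i = from; i < from + rows; i++) {
      // Only rows at the maximum definition level have a stored value.
      if (defLevels[i] == maxDefinitionLevel) count++;
    }
    return count;
  }

  /** Skips `rows` rows starting at `from`, consuming only the values those rows stored. */
  static void skipRows(int[] defLevels, int from, int rows, int maxDefinitionLevel,
      EncodedInts values) {
    values.skipIntegers(encodedValuesIn(defLevels, from, rows, maxDefinitionLevel));
  }
}
```

For definition levels `{1, 0, 1, 1, 0, 1}` and stored values `10, 30, 40, 60`, skipping rows 0 to 2 consumes two values, so the next read returns 40 (row 3); skipping three values instead would land on row 5's value.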
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala
new file mode 100644
index 0000000..f10b701
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.test.SharedSparkSession
+
+class ParquetColumnIndexSuite extends QueryTest with ParquetTest with SharedSparkSession {
+ import testImplicits._
+
+ /**
+ * create parquet file with two columns and unaligned pages
+ * pages will be of the following layout
+ * col_1 500 500 500 500
+ * |---------|---------|---------|---------|
+ * |-------|-----|-----|---|---|---|---|---|
+ * col_2 400 300 200 200 200 200 200 200
+ */
+ def checkUnalignedPages(actions: (DataFrame => DataFrame)*): Unit = {
+ withTempPath(file => {
+ val ds = spark.range(0, 2000).map(i => (i, i + ":" + "o" * (i / 100).toInt))
+ ds.coalesce(1)
+ .write
+ .option("parquet.page.size", "4096")
+ .parquet(file.getCanonicalPath)
+
+ val parquetDf = spark.read.parquet(file.getCanonicalPath)
+
+ actions.foreach { action =>
+ checkAnswer(action(parquetDf), action(ds.toDF()))
+ }
+ })
+ }
+
+ test("reading from unaligned pages - test filters") {
+ checkUnalignedPages(
+ // single value filter
+ df => df.filter("_1 = 500"),
+ df => df.filter("_1 = 500 or _1 = 1500"),
+ df => df.filter("_1 = 500 or _1 = 501 or _1 = 1500"),
+ df => df.filter("_1 = 500 or _1 = 501 or _1 = 1000 or _1 = 1500"),
+ // range filter
+ df => df.filter("_1 >= 500 and _1 < 1000"),
+ df => df.filter("(_1 >= 500 and _1 < 1000) or (_1 >= 1500 and _1 < 1600)")
+ )
+ }
+
+ test("test reading unaligned pages - test all types") {
+ withTempPath(file => {
+ val df = spark.range(0, 2000).selectExpr(
+ "id as _1",
+ "cast(id as short) as _3",
+ "cast(id as int) as _4",
+ "cast(id as float) as _5",
+ "cast(id as double) as _6",
+ "cast(id as decimal(20,0)) as _7",
+ "cast(cast(1618161925000 + id * 1000 * 60 * 60 * 24 as timestamp) as date) as _9",
+ "cast(1618161925000 + id as timestamp) as _10"
+ )
+ df.coalesce(1)
+ .write
+ .option("parquet.page.size", "4096")
+ .parquet(file.getCanonicalPath)
+
+ val parquetDf = spark.read.parquet(file.getCanonicalPath)
+ val singleValueFilterExpr = "_1 = 500 or _1 = 1500"
+ checkAnswer(
+ parquetDf.filter(singleValueFilterExpr),
+ df.filter(singleValueFilterExpr)
+ )
+ val rangeFilterExpr = "_1 > 500 "
+ checkAnswer(
+ parquetDf.filter(rangeFilterExpr),
+ df.filter(rangeFilterExpr)
+ )
+ })
+ }
+
+ test("test reading unaligned pages - test all types (dict encode)") {
+ withTempPath(file => {
+ val df = spark.range(0, 2000).selectExpr(
+ "id as _1",
+ "cast(id % 10 as byte) as _2",
+ "cast(id % 10 as short) as _3",
+ "cast(id % 10 as int) as _4",
+ "cast(id % 10 as float) as _5",
+ "cast(id % 10 as double) as _6",
+ "cast(id % 10 as decimal(20,0)) as _7",
+ "cast(id % 2 as boolean) as _8",
+ "cast(cast(1618161925000 + (id % 10) * 1000 * 60 * 60 * 24 as timestamp) as date) as _9",
+ "cast(1618161925000 + (id % 10) as timestamp) as _10"
+ )
+ df.coalesce(1)
+ .write
+ .option("parquet.page.size", "4096")
+ .parquet(file.getCanonicalPath)
+
+ val parquetDf = spark.read.parquet(file.getCanonicalPath)
+ val singleValueFilterExpr = "_1 = 500 or _1 = 1500"
+ checkAnswer(
+ parquetDf.filter(singleValueFilterExpr),
+ df.filter(singleValueFilterExpr)
+ )
+ val rangeFilterExpr = "_1 > 500"
+ checkAnswer(
+ parquetDf.filter(rangeFilterExpr),
+ df.filter(rangeFilterExpr)
+ )
+ })
+ }
+}
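The unaligned-page layout in `checkUnalignedPages` is what forces the reader to skip rows inside a page: the row ranges come from one column's pages, but another column's pages start at different rows. As a worked example, take the `col_2` page sizes from the diagram as an illustrative layout (the real boundaries depend on the writer and `parquet.page.size`). Under the diagram's layout, a filter such as `_1 = 500` would select only the second `col_1` page, i.e. rows [500, 999]. The sketch computes each page's first row index (the information this change obtains via `DataPage.getFirstRowIndex`) and which `col_2` pages overlap that range; `PageOverlap` is a hypothetical helper.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical helper for reasoning about page boundaries within one column chunk. */
final class PageOverlap {
  /** First row index of each page, given each page's row count in file order. */
  static long[] firstRowIndexes(int[] pageRowCounts) {
    long[] first = new long[pageRowCounts.length];
    long row = 0;
    for (int i = 0; i < pageRowCounts.length; i++) {
      first[i] = row;
      row += pageRowCounts[i];
    }
    return first;
  }

  /** Pages whose rows [first, first + count) intersect the inclusive row range [start, end]. */
  static List<Integer> pagesOverlapping(int[] pageRowCounts, long start, long end) {
    long[] first = firstRowIndexes(pageRowCounts);
    List<Integer> pages = new ArrayList<>();
    for (int i = 0; i < first.length; i++) {
      long last = first[i] + pageRowCounts[i] - 1;
      if (first[i] <= end && last >= start) pages.add(i);
    }
    return pages;
  }
}
```

Under this layout `col_2` pages 1, 2 and 3 (starting at rows 400, 700 and 900) overlap [500, 999]. Page 1 begins at row 400, so the reader has to skip its first 100 rows before reaching row 500, and page 3 covers rows 900 to 1099, so its last 100 rows are skipped as well.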
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index bc4234f..a330b82 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -368,7 +368,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest
with SharedSparkSession
private def createParquetWriter(
schema: MessageType,
path: Path,
- dictionaryEnabled: Boolean = false): ParquetWriter[Group] = {
+ dictionaryEnabled: Boolean = false,
+ pageSize: Int = 1024,
+ dictionaryPageSize: Int = 1024): ParquetWriter[Group] = {
val hadoopConf = spark.sessionState.newHadoopConf()
ExampleParquetWriter
@@ -378,11 +380,77 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
.withWriterVersion(PARQUET_1_0)
.withCompressionCodec(GZIP)
.withRowGroupSize(1024 * 1024)
- .withPageSize(1024)
+ .withPageSize(pageSize)
+ .withDictionaryPageSize(dictionaryPageSize)
.withConf(hadoopConf)
.build()
}
+ test("SPARK-34859: test multiple pages with different sizes and nulls") {
+ def makeRawParquetFile(
+ path: Path,
+ dictionaryEnabled: Boolean,
+ n: Int,
+ pageSize: Int): Seq[Option[Int]] = {
+ val schemaStr =
+ """
+ |message root {
+ | optional boolean _1;
+ | optional int32 _2;
+ | optional int64 _3;
+ | optional float _4;
+ | optional double _5;
+ |}
+ """.stripMargin
+
+ val schema = MessageTypeParser.parseMessageType(schemaStr)
+ val writer = createParquetWriter(schema, path,
+ dictionaryEnabled = dictionaryEnabled, pageSize = pageSize, dictionaryPageSize = pageSize)
+
+ val rand = scala.util.Random
+ val expected = (0 until n).map { i =>
+ if (rand.nextBoolean()) {
+ None
+ } else {
+ Some(i)
+ }
+ }
+ expected.foreach { opt =>
+ val record = new SimpleGroup(schema)
+ opt match {
+ case Some(i) =>
+ record.add(0, i % 2 == 0)
+ record.add(1, i)
+ record.add(2, i.toLong)
+ record.add(3, i.toFloat)
+ record.add(4, i.toDouble)
+ case _ =>
+ }
+ writer.write(record)
+ }
+
+ writer.close()
+ expected
+ }
+
+ Seq(true, false).foreach { dictionaryEnabled =>
+ Seq(64, 128, 89).foreach { pageSize =>
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "part-r-0.parquet")
+ val expected = makeRawParquetFile(path, dictionaryEnabled, 1000, pageSize)
+ readParquetFile(path.toString) { df =>
+ checkAnswer(df, expected.map {
+ case None =>
+ Row(null, null, null, null, null)
+ case Some(i) =>
+ Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
+ })
+ }
+ }
+ }
+ }
+ }
+
test("read raw Parquet file") {
def makeRawParquetFile(path: Path): Unit = {
val schemaStr =