This is an automated email from the ASF dual-hosted git repository.
gangwu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/orc.git
The following commit(s) were added to refs/heads/main by this push:
new 511c8c194 ORC-1385: [C++] Support schema evolution of numeric types
511c8c194 is described below
commit 511c8c19497cb70499353a59b6484a0e6a82a539
Author: ffacs <[email protected]>
AuthorDate: Fri Apr 14 09:48:51 2023 +0800
ORC-1385: [C++] Support schema evolution of numeric types
### What changes were proposed in this pull request?
Support schema evolution that converts between any two of the types {boolean, tinyint,
smallint, int, bigint, float, double}.
### Why are the changes needed?
To support schema evolution of numeric types in the C++ reader.
### How was this patch tested?
Added unit tests (TestConvertColumnReader.cc, TestSchemaEvolution.cc); all unit tests pass.
This closes #1454
---
c++/include/orc/Reader.hh | 20 ++
c++/src/CMakeLists.txt | 1 +
c++/src/ColumnReader.cc | 61 +++--
c++/src/ColumnReader.hh | 11 +-
c++/src/ConvertColumnReader.cc | 446 ++++++++++++++++++++++++++++++++++++
c++/src/ConvertColumnReader.hh | 53 +++++
c++/src/Options.hh | 21 ++
c++/src/Reader.cc | 35 ++-
c++/src/Reader.hh | 8 +
c++/src/SchemaEvolution.cc | 68 ++----
c++/src/StripeStream.cc | 4 +
c++/src/StripeStream.hh | 2 +
c++/src/TypeImpl.cc | 6 +-
c++/test/CMakeLists.txt | 3 +
c++/test/MockStripeStreams.cc | 44 ++++
c++/test/MockStripeStreams.hh | 56 +++++
c++/test/TestColumnReader.cc | 51 +----
c++/test/TestConvertColumnReader.cc | 149 ++++++++++++
c++/test/TestSchemaEvolution.cc | 99 ++++++++
19 files changed, 1010 insertions(+), 128 deletions(-)
diff --git a/c++/include/orc/Reader.hh b/c++/include/orc/Reader.hh
index d8f83f94a..b631c2c6e 100644
--- a/c++/include/orc/Reader.hh
+++ b/c++/include/orc/Reader.hh
@@ -336,6 +336,26 @@ namespace orc {
* @return if not set, the default is false
*/
bool getUseTightNumericVector() const;
+
+ /**
+ * Set the read type used for schema evolution.
+ *
+ * When set, batches returned by RowReader::createRowBatch() use this type and
+ * file columns whose type differs are converted while reading. The read type
+ * must contain exactly the column ids of the selected file schema, otherwise
+ * createRowBatch() throws SchemaEvolutionError. Conversion readers require the
+ * useTightNumericVector option. As of this change only conversions among
+ * boolean, tinyint, smallint, int, bigint, float and double are supported.
+ */
+ RowReaderOptions& setReadType(std::shared_ptr<Type> type);
+
+ /**
+ * Get the read type used for schema evolution.
+ * @return the read type, or an empty pointer if none was set (the selected
+ * file schema is then used as-is)
+ */
+ std::shared_ptr<Type>& getReadType() const;
+
+ /**
+ * Set whether the reader throws SchemaEvolutionError or sets the value to null
+ * when a value overflows the read type during schema evolution.
+ * The default is false (overflowing values become null).
+ */
+ RowReaderOptions& throwOnSchemaEvolutionOverflow(bool shouldThrow);
+
+ /**
+ * Whether the reader throws SchemaEvolutionError (true) or sets the value to
+ * null (false) when a value overflows the read type during schema evolution.
+ * @return if not set, the default is false
+ */
+ bool getThrowOnSchemaEvolutionOverflow() const;
};
class RowReader;
diff --git a/c++/src/CMakeLists.txt b/c++/src/CMakeLists.txt
index b9904160b..16b5549b9 100644
--- a/c++/src/CMakeLists.txt
+++ b/c++/src/CMakeLists.txt
@@ -172,6 +172,7 @@ set(SOURCE_FILES
ColumnWriter.cc
Common.cc
Compression.cc
+ ConvertColumnReader.cc
Exceptions.cc
Int128.cc
LzoDecompressor.cc
diff --git a/c++/src/ColumnReader.cc b/c++/src/ColumnReader.cc
index 2a72c80ee..3552acf39 100644
--- a/c++/src/ColumnReader.cc
+++ b/c++/src/ColumnReader.cc
@@ -21,7 +21,9 @@
#include "Adaptor.hh"
#include "ByteRLE.hh"
#include "ColumnReader.hh"
+#include "ConvertColumnReader.hh"
#include "RLE.hh"
+#include "SchemaEvolution.hh"
#include "orc/Exceptions.hh"
#include <math.h>
@@ -828,7 +830,8 @@ namespace orc {
std::vector<std::unique_ptr<ColumnReader>> children;
public:
- StructColumnReader(const Type& type, StripeStreams& stipe, bool
useTightNumericVector = false);
+ StructColumnReader(const Type& type, StripeStreams& stipe, bool
useTightNumericVector = false,
+ bool throwOnSchemaEvolutionOverflow = false);
uint64_t skip(uint64_t numValues) override;
@@ -844,7 +847,8 @@ namespace orc {
};
StructColumnReader::StructColumnReader(const Type& type, StripeStreams&
stripe,
- bool useTightNumericVector)
+ bool useTightNumericVector,
+ bool throwOnSchemaEvolutionOverflow)
: ColumnReader(type, stripe) {
// count the number of selected sub-columns
const std::vector<bool> selectedColumns = stripe.getSelectedColumns();
@@ -853,7 +857,8 @@ namespace orc {
for (unsigned int i = 0; i < type.getSubtypeCount(); ++i) {
const Type& child = *type.getSubtype(i);
if (selectedColumns[static_cast<uint64_t>(child.getColumnId())]) {
- children.push_back(buildReader(child, stripe,
useTightNumericVector));
+ children.push_back(
+ buildReader(child, stripe, useTightNumericVector,
throwOnSchemaEvolutionOverflow));
}
}
break;
@@ -913,7 +918,8 @@ namespace orc {
std::unique_ptr<RleDecoder> rle;
public:
- ListColumnReader(const Type& type, StripeStreams& stipe, bool
useTightNumericVector = false);
+ ListColumnReader(const Type& type, StripeStreams& stipe, bool
useTightNumericVector = false,
+ bool throwOnSchemaEvolutionOverflow = false);
~ListColumnReader() override;
uint64_t skip(uint64_t numValues) override;
@@ -930,7 +936,8 @@ namespace orc {
};
ListColumnReader::ListColumnReader(const Type& type, StripeStreams& stripe,
- bool useTightNumericVector)
+ bool useTightNumericVector,
+ bool throwOnSchemaEvolutionOverflow)
: ColumnReader(type, stripe) {
// count the number of selected sub-columns
const std::vector<bool> selectedColumns = stripe.getSelectedColumns();
@@ -941,7 +948,7 @@ namespace orc {
rle = createRleDecoder(std::move(stream), false, vers, memoryPool,
metrics);
const Type& childType = *type.getSubtype(0);
if (selectedColumns[static_cast<uint64_t>(childType.getColumnId())]) {
- child = buildReader(childType, stripe, useTightNumericVector);
+ child = buildReader(childType, stripe, useTightNumericVector,
throwOnSchemaEvolutionOverflow);
}
}
@@ -1033,7 +1040,8 @@ namespace orc {
std::unique_ptr<RleDecoder> rle;
public:
- MapColumnReader(const Type& type, StripeStreams& stipe, bool
useTightNumericVector = false);
+ MapColumnReader(const Type& type, StripeStreams& stipe, bool
useTightNumericVector = false,
+ bool throwOnSchemaEvolutionOverflow = false);
~MapColumnReader() override;
uint64_t skip(uint64_t numValues) override;
@@ -1050,7 +1058,7 @@ namespace orc {
};
MapColumnReader::MapColumnReader(const Type& type, StripeStreams& stripe,
- bool useTightNumericVector)
+ bool useTightNumericVector, bool
throwOnSchemaEvolutionOverflow)
: ColumnReader(type, stripe) {
// Determine if the key and/or value columns are selected
const std::vector<bool> selectedColumns = stripe.getSelectedColumns();
@@ -1061,11 +1069,13 @@ namespace orc {
rle = createRleDecoder(std::move(stream), false, vers, memoryPool,
metrics);
const Type& keyType = *type.getSubtype(0);
if (selectedColumns[static_cast<uint64_t>(keyType.getColumnId())]) {
- keyReader = buildReader(keyType, stripe, useTightNumericVector);
+ keyReader =
+ buildReader(keyType, stripe, useTightNumericVector,
throwOnSchemaEvolutionOverflow);
}
const Type& elementType = *type.getSubtype(1);
if (selectedColumns[static_cast<uint64_t>(elementType.getColumnId())]) {
- elementReader = buildReader(elementType, stripe, useTightNumericVector);
+ elementReader =
+ buildReader(elementType, stripe, useTightNumericVector,
throwOnSchemaEvolutionOverflow);
}
}
@@ -1175,7 +1185,8 @@ namespace orc {
uint64_t numChildren;
public:
- UnionColumnReader(const Type& type, StripeStreams& stipe, bool
useTightNumericVector = false);
+ UnionColumnReader(const Type& type, StripeStreams& stipe, bool
useTightNumericVector = false,
+ bool throwOnSchemaEvolutionOverflow = false);
uint64_t skip(uint64_t numValues) override;
@@ -1191,7 +1202,8 @@ namespace orc {
};
UnionColumnReader::UnionColumnReader(const Type& type, StripeStreams& stripe,
- bool useTightNumericVector)
+ bool useTightNumericVector,
+ bool throwOnSchemaEvolutionOverflow)
: ColumnReader(type, stripe) {
numChildren = type.getSubtypeCount();
childrenReader.resize(numChildren);
@@ -1206,7 +1218,8 @@ namespace orc {
for (unsigned int i = 0; i < numChildren; ++i) {
const Type& child = *type.getSubtype(i);
if (selectedColumns[static_cast<size_t>(child.getColumnId())]) {
- childrenReader[i] = buildReader(child, stripe, useTightNumericVector);
+ childrenReader[i] =
+ buildReader(child, stripe, useTightNumericVector,
throwOnSchemaEvolutionOverflow);
}
}
}
@@ -1699,7 +1712,15 @@ namespace orc {
* Create a reader for the given stripe.
*/
std::unique_ptr<ColumnReader> buildReader(const Type& type, StripeStreams&
stripe,
- bool useTightNumericVector) {
+ bool useTightNumericVector,
+ bool
throwOnSchemaEvolutionOverflow,
+ bool convertToReadType) {
+ if (convertToReadType && stripe.getSchemaEvolution() &&
+ stripe.getSchemaEvolution()->needConvert(type)) {
+ return buildConvertReader(type, stripe, useTightNumericVector,
+ throwOnSchemaEvolutionOverflow);
+ }
+
switch (static_cast<int64_t>(type.getKind())) {
case SHORT: {
if (useTightNumericVector) {
@@ -1744,16 +1765,20 @@ namespace orc {
return std::make_unique<ByteColumnReader<LongVectorBatch>>(type,
stripe);
case LIST:
- return std::make_unique<ListColumnReader>(type, stripe,
useTightNumericVector);
+ return std::make_unique<ListColumnReader>(type, stripe,
useTightNumericVector,
+
throwOnSchemaEvolutionOverflow);
case MAP:
- return std::make_unique<MapColumnReader>(type, stripe,
useTightNumericVector);
+ return std::make_unique<MapColumnReader>(type, stripe,
useTightNumericVector,
+
throwOnSchemaEvolutionOverflow);
case UNION:
- return std::make_unique<UnionColumnReader>(type, stripe,
useTightNumericVector);
+ return std::make_unique<UnionColumnReader>(type, stripe,
useTightNumericVector,
+
throwOnSchemaEvolutionOverflow);
case STRUCT:
- return std::make_unique<StructColumnReader>(type, stripe,
useTightNumericVector);
+ return std::make_unique<StructColumnReader>(type, stripe,
useTightNumericVector,
+
throwOnSchemaEvolutionOverflow);
case FLOAT: {
if (useTightNumericVector) {
diff --git a/c++/src/ColumnReader.hh b/c++/src/ColumnReader.hh
index 3b765cbe5..f0f3fe1b5 100644
--- a/c++/src/ColumnReader.hh
+++ b/c++/src/ColumnReader.hh
@@ -30,6 +30,8 @@
namespace orc {
+ class SchemaEvolution;
+
class StripeStreams {
public:
virtual ~StripeStreams();
@@ -101,6 +103,11 @@ namespace orc {
* encoded in RLE.
*/
virtual bool isDecimalAsLong() const = 0;
+
+ /**
+ * Get the helper that maps file types to read types for schema evolution.
+ *
+ * May be nullptr: buildReader() only builds conversion readers when this is
+ * non-null and reports that the column needs conversion.
+ *
+ * @return the schema evolution object, or nullptr
+ */
+ virtual const SchemaEvolution* getSchemaEvolution() const = 0;
};
/**
@@ -159,7 +166,9 @@ namespace orc {
* Create a reader for the given stripe.
*/
std::unique_ptr<ColumnReader> buildReader(const Type& type, StripeStreams&
stripe,
- bool useTightNumericVector =
false);
+ bool useTightNumericVector = false,
+ bool
throwOnSchemaEvolutionOverflow = false,
+ bool convertToReadType = true);
} // namespace orc
#endif
diff --git a/c++/src/ConvertColumnReader.cc b/c++/src/ConvertColumnReader.cc
new file mode 100644
index 000000000..c929b69f1
--- /dev/null
+++ b/c++/src/ConvertColumnReader.cc
@@ -0,0 +1,446 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "ConvertColumnReader.hh"
+
+namespace orc {
+
+ // Assume that we are using tight numeric vector batch: boolean columns are then
+ // read into a ByteVectorBatch, which this alias relies on.
+ using BooleanVectorBatch = ByteVectorBatch;
+
+ // Builds an inner reader that decodes the column with its *file* type
+ // (convertToReadType=false, so no nested conversion happens there) and a scratch
+ // batch `data` of the file type, created with capacity 0.
+ // NOTE(review): presumably the inner reader's next() grows `data` as needed — confirm.
+ ConvertColumnReader::ConvertColumnReader(const Type& _readType, const Type&
fileType,
+ StripeStreams& stripe, bool
_throwOnOverflow)
+ : ColumnReader(_readType, stripe), readType(_readType),
throwOnOverflow(_throwOnOverflow) {
+ reader = buildReader(fileType, stripe, /*useTightNumericVector=*/true,
+ /*throwOnOverflow=*/false, /*convertToReadType*/
false);
+ data =
+ fileType.createRowBatch(0, memoryPool, /*encoded=*/false,
/*useTightNumericVector=*/true);
+ }
+
+ // Decodes numValues file values into `data`, then mirrors capacity, element count
+ // and null mask into rowBatch. Values are NOT copied; subclasses convert them.
+ // When there are no nulls every slot is marked not-null, so a subclass can later
+ // clear individual slots (see handleOverflow) and set hasNulls.
+ void ConvertColumnReader::next(ColumnVectorBatch& rowBatch, uint64_t
numValues, char* notNull) {
+ reader->next(*data, numValues, notNull);
+ rowBatch.resize(data->capacity);
+ rowBatch.numElements = data->numElements;
+ rowBatch.hasNulls = data->hasNulls;
+ if (!rowBatch.hasNulls) {
+ memset(rowBatch.notNull.data(), 1, data->notNull.size());
+ } else {
+ memcpy(rowBatch.notNull.data(), data->notNull.data(),
data->notNull.size());
+ }
+ }
+
+ // Skipping needs no conversion; forward to the file-type reader.
+ uint64_t ConvertColumnReader::skip(uint64_t numValues) {
+ return reader->skip(numValues);
+ }
+
+ // Positions belong to the file column's streams; forward to the file-type reader.
+ void ConvertColumnReader::seekToRowGroup(
+ std::unordered_map<uint64_t, PositionProvider>& positions) {
+ reader->seekToRowGroup(positions);
+ }
+
+  // True when `value` lies in [-2^63, 2^63), i.e. truncating it to int64_t is well
+  // defined. NaN and both infinities fail the comparisons and are rejected.
+  static inline bool canFitInLong(double value) {
+    constexpr double kLowerInclusive = -0x1p63;  // INT64_MIN, exactly representable
+    constexpr double kUpperExclusive = 0x1p63;   // INT64_MAX + 1, exactly representable
+    return value >= kLowerInclusive && value < kUpperExclusive;
+  }
+
+  // Reacts to a value that does not fit the read type: throws SchemaEvolutionError when
+  // shouldThrow is set, otherwise marks slot `idx` of dstBatch as null.
+  template <typename FileType, typename ReadType>
+  static inline void handleOverflow(ColumnVectorBatch& dstBatch, uint64_t idx, bool shouldThrow) {
+    if (shouldThrow) {
+      std::ostringstream message;
+      message << "Overflow when convert from " << typeid(FileType).name() << " to "
+              << typeid(ReadType).name();
+      throw SchemaEvolutionError(message.str());
+    }
+    dstBatch.notNull.data()[idx] = 0;
+    dstBatch.hasNulls = true;
+  }
+
+  // Stores inputLong into dstValue (narrowing to ReadType) and reports whether the
+  // value survived the narrowing unchanged; returns false on overflow.
+  template <typename ReadType>
+  static bool downCastToInteger(ReadType& dstValue, int64_t inputLong) {
+    dstValue = static_cast<ReadType>(inputLong);
+    if constexpr (std::is_same<ReadType, int64_t>::value) {
+      return true;  // no narrowing happened
+    } else {
+      return static_cast<int64_t>(dstValue) == inputLong;
+    }
+  }
+
+  // dynamic_cast a batch to the concrete batch pointer type, throwing InvalidArgument
+  // (naming the expected batch type) instead of returning nullptr on mismatch.
+  template <typename DestBatchPtrType>
+  static inline DestBatchPtrType SafeCastBatchTo(ColumnVectorBatch* batch) {
+    using DestBatchType =
+        typename std::remove_const<typename std::remove_pointer<DestBatchPtrType>::type>::type;
+    if (auto casted = dynamic_cast<DestBatchPtrType>(batch)) {
+      return casted;
+    }
+    std::ostringstream ss;
+    ss << "Bad cast when convert from ColumnVectorBatch to " << typeid(DestBatchType).name();
+    throw InvalidArgument(ss.str());
+  }
+
+  // Convert one non-null value from the file type to the read type into destValue.
+  // If the value does not fit, slot `idx` of destBatch is set to null, or
+  // SchemaEvolutionError is thrown when shouldThrow is true (see handleOverflow).
+  template <typename ReadType, typename FileType>
+  static inline void convertNumericElement(const FileType& srcValue, ReadType& destValue,
+                                           ColumnVectorBatch& destBatch, uint64_t idx,
+                                           bool shouldThrow) {
+    constexpr bool isFileTypeFloatingPoint = std::is_floating_point<FileType>::value;
+    constexpr bool isReadTypeFloatingPoint = std::is_floating_point<ReadType>::value;
+    if constexpr (isFileTypeFloatingPoint) {
+      if constexpr (isReadTypeFloatingPoint) {
+        destValue = static_cast<ReadType>(srcValue);
+      } else {
+        // Range-check BEFORE casting: converting a NaN or out-of-range floating-point
+        // value to int64_t is undefined behavior, so the cast only runs once
+        // canFitInLong has accepted the value (|| short-circuits).
+        if (!canFitInLong(static_cast<double>(srcValue)) ||
+            !downCastToInteger(destValue, static_cast<int64_t>(srcValue))) {
+          handleOverflow<FileType, ReadType>(destBatch, idx, shouldThrow);
+        }
+      }
+    } else {
+      if constexpr (isReadTypeFloatingPoint) {
+        destValue = static_cast<ReadType>(srcValue);
+        if (destValue != destValue) {  // check is NaN (defensive; integral input)
+          handleOverflow<FileType, ReadType>(destBatch, idx, shouldThrow);
+        }
+      } else {
+        if (!downCastToInteger(destValue, static_cast<int64_t>(srcValue))) {
+          handleOverflow<FileType, ReadType>(destBatch, idx, shouldThrow);
+        }
+      }
+    }
+  }
+
+  // Converts { boolean, byte, short, int, long, float, double } file columns into
+  // { byte, short, int, long, float, double } read batches, element by element.
+  template <typename FileTypeBatch, typename ReadTypeBatch, typename ReadType>
+  class NumericConvertColumnReader : public ConvertColumnReader {
+   public:
+    NumericConvertColumnReader(const Type& _readType, const Type& fileType, StripeStreams& stripe,
+                               bool _throwOnOverflow)
+        : ConvertColumnReader(_readType, fileType, stripe, _throwOnOverflow) {}
+
+    void next(ColumnVectorBatch& rowBatch, uint64_t numValues, char* notNull) override {
+      // Decode file values into `data` and copy size/null mask into rowBatch.
+      ConvertColumnReader::next(rowBatch, numValues, notNull);
+      const FileTypeBatch& src = *SafeCastBatchTo<const FileTypeBatch*>(data.get());
+      ReadTypeBatch& dst = *SafeCastBatchTo<ReadTypeBatch*>(&rowBatch);
+      const uint64_t count = rowBatch.numElements;
+      auto convertAt = [&](uint64_t row) {
+        convertNumericElement<ReadType>(src.data[row], dst.data[row], rowBatch, row,
+                                        throwOnOverflow);
+      };
+      if (!rowBatch.hasNulls) {
+        for (uint64_t row = 0; row < count; ++row) {
+          convertAt(row);
+        }
+        return;
+      }
+      for (uint64_t row = 0; row < count; ++row) {
+        if (rowBatch.notNull[row]) {
+          convertAt(row);
+        }
+      }
+    }
+  };
+
+  // { boolean, byte, short, int, long, float, double } -> { boolean }
+  template <typename FileTypeBatch>
+  class NumericConvertColumnReader<FileTypeBatch, BooleanVectorBatch, bool>
+      : public ConvertColumnReader {
+   public:
+    NumericConvertColumnReader(const Type& _readType, const Type& fileType, StripeStreams& stripe,
+                               bool _throwOnOverflow)
+        : ConvertColumnReader(_readType, fileType, stripe, _throwOnOverflow) {}
+
+    void next(ColumnVectorBatch& rowBatch, uint64_t numValues, char* notNull) override {
+      ConvertColumnReader::next(rowBatch, numValues, notNull);
+      const auto& srcBatch = *SafeCastBatchTo<const FileTypeBatch*>(data.get());
+      auto& dstBatch = *SafeCastBatchTo<BooleanVectorBatch*>(&rowBatch);
+      if (rowBatch.hasNulls) {
+        for (uint64_t i = 0; i < rowBatch.numElements; ++i) {
+          if (rowBatch.notNull[i]) {
+            convertToBoolean(srcBatch.data[i], dstBatch, i);
+          }
+        }
+      } else {
+        for (uint64_t i = 0; i < rowBatch.numElements; ++i) {
+          convertToBoolean(srcBatch.data[i], dstBatch, i);
+        }
+      }
+    }
+
+   private:
+    // Truncates the value toward zero and stores 0 (false) or 1 (true).
+    // Floating-point values that do not fit in int64_t (including NaN) are treated as
+    // overflow: casting them to int64_t would be undefined behavior.
+    template <typename FileType>
+    void convertToBoolean(const FileType& srcValue, BooleanVectorBatch& dstBatch, uint64_t idx) {
+      if constexpr (std::is_floating_point<FileType>::value) {
+        if (!canFitInLong(static_cast<double>(srcValue))) {
+          handleOverflow<FileType, bool>(dstBatch, idx, throwOnOverflow);
+          return;
+        }
+      }
+      dstBatch.data[idx] = (static_cast<int64_t>(srcValue) == 0 ? 0 : 1);
+    }
+  };
+
+// Declares FROM##To##TO##ColumnReader as an alias of NumericConvertColumnReader that
+// reads a FROM##VectorBatch from the file into a TO##VectorBatch of TYPE elements.
+// buildConvertReader refers to these alias names via CASE_CREATE_READER, and
+// aliases whose TYPE is bool pick the boolean specialization.
+#define DEFINE_NUMERIC_CONVERT_READER(FROM, TO, TYPE) \
+ using FROM##To##TO##ColumnReader = \
+ NumericConvertColumnReader<FROM##VectorBatch, TO##VectorBatch, TYPE>;
+
+ // Integer (and boolean) to integer, plus Float -> Double and Double -> Float
+ DEFINE_NUMERIC_CONVERT_READER(Boolean, Byte, int8_t)
+ DEFINE_NUMERIC_CONVERT_READER(Boolean, Short, int16_t)
+ DEFINE_NUMERIC_CONVERT_READER(Boolean, Int, int32_t)
+ DEFINE_NUMERIC_CONVERT_READER(Boolean, Long, int64_t)
+ DEFINE_NUMERIC_CONVERT_READER(Byte, Short, int16_t)
+ DEFINE_NUMERIC_CONVERT_READER(Byte, Int, int32_t)
+ DEFINE_NUMERIC_CONVERT_READER(Byte, Long, int64_t)
+ DEFINE_NUMERIC_CONVERT_READER(Short, Int, int32_t)
+ DEFINE_NUMERIC_CONVERT_READER(Short, Long, int64_t)
+ DEFINE_NUMERIC_CONVERT_READER(Int, Long, int64_t)
+ DEFINE_NUMERIC_CONVERT_READER(Float, Double, double)
+ DEFINE_NUMERIC_CONVERT_READER(Byte, Boolean, bool)
+ DEFINE_NUMERIC_CONVERT_READER(Short, Boolean, bool)
+ DEFINE_NUMERIC_CONVERT_READER(Short, Byte, int8_t)
+ DEFINE_NUMERIC_CONVERT_READER(Int, Boolean, bool)
+ DEFINE_NUMERIC_CONVERT_READER(Int, Byte, int8_t)
+ DEFINE_NUMERIC_CONVERT_READER(Int, Short, int16_t)
+ DEFINE_NUMERIC_CONVERT_READER(Long, Boolean, bool)
+ DEFINE_NUMERIC_CONVERT_READER(Long, Byte, int8_t)
+ DEFINE_NUMERIC_CONVERT_READER(Long, Short, int16_t)
+ DEFINE_NUMERIC_CONVERT_READER(Long, Int, int32_t)
+ DEFINE_NUMERIC_CONVERT_READER(Double, Float, float)
+ // Floating to integer
+ DEFINE_NUMERIC_CONVERT_READER(Float, Boolean, bool)
+ DEFINE_NUMERIC_CONVERT_READER(Float, Byte, int8_t)
+ DEFINE_NUMERIC_CONVERT_READER(Float, Short, int16_t)
+ DEFINE_NUMERIC_CONVERT_READER(Float, Int, int32_t)
+ DEFINE_NUMERIC_CONVERT_READER(Float, Long, int64_t)
+ DEFINE_NUMERIC_CONVERT_READER(Double, Boolean, bool)
+ DEFINE_NUMERIC_CONVERT_READER(Double, Byte, int8_t)
+ DEFINE_NUMERIC_CONVERT_READER(Double, Short, int16_t)
+ DEFINE_NUMERIC_CONVERT_READER(Double, Int, int32_t)
+ DEFINE_NUMERIC_CONVERT_READER(Double, Long, int64_t)
+ // Integer to Floating
+ DEFINE_NUMERIC_CONVERT_READER(Boolean, Float, float)
+ DEFINE_NUMERIC_CONVERT_READER(Byte, Float, float)
+ DEFINE_NUMERIC_CONVERT_READER(Short, Float, float)
+ DEFINE_NUMERIC_CONVERT_READER(Int, Float, float)
+ DEFINE_NUMERIC_CONVERT_READER(Long, Float, float)
+ DEFINE_NUMERIC_CONVERT_READER(Boolean, Double, double)
+ DEFINE_NUMERIC_CONVERT_READER(Byte, Double, double)
+ DEFINE_NUMERIC_CONVERT_READER(Short, Double, double)
+ DEFINE_NUMERIC_CONVERT_READER(Int, Double, double)
+ DEFINE_NUMERIC_CONVERT_READER(Long, Double, double)
+
+// Helpers for buildConvertReader's switch statements. They expand in place and rely
+// on the locals `_readType`, `fileType`, `stripe` and `throwOnOverflow` being in scope.
+// (The line continuations below were split by mail wrapping; as wrapped,
+// CASE_EXCEPTION would define nothing and leave `default:` at file scope.)
+#define CASE_CREATE_READER(TYPE, CONVERT) \
+  case TYPE:                              \
+    return std::make_unique<CONVERT##ColumnReader>(_readType, fileType, stripe, throwOnOverflow);
+
+#define CASE_EXCEPTION                                                                 \
+  default:                                                                             \
+    throw SchemaEvolutionError("Cannot convert from " + fileType.toString() + " to " + \
+                               _readType.toString());
+
+ // Build a reader that decodes `fileType` and converts it to the read type that
+ // stripe.getSchemaEvolution() maps it to. Throws SchemaEvolutionError when tight
+ // numeric vectors are not in use or the (file, read) type pair is unsupported.
+ // NOTE(review): assumes stripe.getSchemaEvolution() is non-null; buildReader
+ // checks this before calling here.
+ std::unique_ptr<ColumnReader> buildConvertReader(const Type& fileType,
StripeStreams& stripe,
+ bool useTightNumericVector,
+ bool throwOnOverflow) {
+ // Conversion readers only produce tight numeric vector batches.
+ if (!useTightNumericVector) {
+ throw SchemaEvolutionError(
+ "SchemaEvolution only support tight vector, please create
ColumnVectorBatch with "
+ "option useTightNumericVector");
+ }
+ const auto& _readType =
*stripe.getSchemaEvolution()->getReadType(fileType);
+
+ // Every inner switch ends in CASE_EXCEPTION (a `default:` that throws), so each
+ // outer case either returns or throws and never falls through to the next.
+ switch (fileType.getKind()) {
+ case BOOLEAN: {
+ switch (_readType.getKind()) {
+ CASE_CREATE_READER(BYTE, BooleanToByte);
+ CASE_CREATE_READER(SHORT, BooleanToShort);
+ CASE_CREATE_READER(INT, BooleanToInt);
+ CASE_CREATE_READER(LONG, BooleanToLong);
+ CASE_CREATE_READER(FLOAT, BooleanToFloat);
+ CASE_CREATE_READER(DOUBLE, BooleanToDouble);
+ // Same-kind reads are presumably filtered out by needConvert() in buildReader;
+ // here they are rejected together with the non-numeric kinds.
+ case BOOLEAN:
+ case STRING:
+ case BINARY:
+ case TIMESTAMP:
+ case LIST:
+ case MAP:
+ case STRUCT:
+ case UNION:
+ case DECIMAL:
+ case DATE:
+ case VARCHAR:
+ case CHAR:
+ case TIMESTAMP_INSTANT:
+ CASE_EXCEPTION
+ }
+ }
+ case BYTE: {
+ switch (_readType.getKind()) {
+ CASE_CREATE_READER(BOOLEAN, ByteToBoolean);
+ CASE_CREATE_READER(SHORT, ByteToShort);
+ CASE_CREATE_READER(INT, ByteToInt);
+ CASE_CREATE_READER(LONG, ByteToLong);
+ CASE_CREATE_READER(FLOAT, ByteToFloat);
+ CASE_CREATE_READER(DOUBLE, ByteToDouble);
+ case BYTE:
+ case STRING:
+ case BINARY:
+ case TIMESTAMP:
+ case LIST:
+ case MAP:
+ case STRUCT:
+ case UNION:
+ case DECIMAL:
+ case DATE:
+ case VARCHAR:
+ case CHAR:
+ case TIMESTAMP_INSTANT:
+ CASE_EXCEPTION
+ }
+ }
+ case SHORT: {
+ switch (_readType.getKind()) {
+ CASE_CREATE_READER(BOOLEAN, ShortToBoolean);
+ CASE_CREATE_READER(BYTE, ShortToByte);
+ CASE_CREATE_READER(INT, ShortToInt);
+ CASE_CREATE_READER(LONG, ShortToLong);
+ CASE_CREATE_READER(FLOAT, ShortToFloat);
+ CASE_CREATE_READER(DOUBLE, ShortToDouble);
+ case SHORT:
+ case STRING:
+ case BINARY:
+ case TIMESTAMP:
+ case LIST:
+ case MAP:
+ case STRUCT:
+ case UNION:
+ case DECIMAL:
+ case DATE:
+ case VARCHAR:
+ case CHAR:
+ case TIMESTAMP_INSTANT:
+ CASE_EXCEPTION
+ }
+ }
+ case INT: {
+ switch (_readType.getKind()) {
+ CASE_CREATE_READER(BOOLEAN, IntToBoolean);
+ CASE_CREATE_READER(BYTE, IntToByte);
+ CASE_CREATE_READER(SHORT, IntToShort);
+ CASE_CREATE_READER(LONG, IntToLong);
+ CASE_CREATE_READER(FLOAT, IntToFloat);
+ CASE_CREATE_READER(DOUBLE, IntToDouble);
+ case INT:
+ case STRING:
+ case BINARY:
+ case TIMESTAMP:
+ case LIST:
+ case MAP:
+ case STRUCT:
+ case UNION:
+ case DECIMAL:
+ case DATE:
+ case VARCHAR:
+ case CHAR:
+ case TIMESTAMP_INSTANT:
+ CASE_EXCEPTION
+ }
+ }
+ case LONG: {
+ switch (_readType.getKind()) {
+ CASE_CREATE_READER(BOOLEAN, LongToBoolean);
+ CASE_CREATE_READER(BYTE, LongToByte);
+ CASE_CREATE_READER(SHORT, LongToShort);
+ CASE_CREATE_READER(INT, LongToInt);
+ CASE_CREATE_READER(FLOAT, LongToFloat);
+ CASE_CREATE_READER(DOUBLE, LongToDouble);
+ case LONG:
+ case STRING:
+ case BINARY:
+ case TIMESTAMP:
+ case LIST:
+ case MAP:
+ case STRUCT:
+ case UNION:
+ case DECIMAL:
+ case DATE:
+ case VARCHAR:
+ case CHAR:
+ case TIMESTAMP_INSTANT:
+ CASE_EXCEPTION
+ }
+ }
+ case FLOAT: {
+ switch (_readType.getKind()) {
+ CASE_CREATE_READER(BOOLEAN, FloatToBoolean);
+ CASE_CREATE_READER(BYTE, FloatToByte);
+ CASE_CREATE_READER(SHORT, FloatToShort);
+ CASE_CREATE_READER(INT, FloatToInt);
+ CASE_CREATE_READER(LONG, FloatToLong);
+ CASE_CREATE_READER(DOUBLE, FloatToDouble);
+ case FLOAT:
+ case STRING:
+ case BINARY:
+ case TIMESTAMP:
+ case LIST:
+ case MAP:
+ case STRUCT:
+ case UNION:
+ case DECIMAL:
+ case DATE:
+ case VARCHAR:
+ case CHAR:
+ case TIMESTAMP_INSTANT:
+ CASE_EXCEPTION
+ }
+ }
+ case DOUBLE: {
+ switch (_readType.getKind()) {
+ CASE_CREATE_READER(BOOLEAN, DoubleToBoolean);
+ CASE_CREATE_READER(BYTE, DoubleToByte);
+ CASE_CREATE_READER(SHORT, DoubleToShort);
+ CASE_CREATE_READER(INT, DoubleToInt);
+ CASE_CREATE_READER(LONG, DoubleToLong);
+ CASE_CREATE_READER(FLOAT, DoubleToFloat);
+ case DOUBLE:
+ case STRING:
+ case BINARY:
+ case TIMESTAMP:
+ case LIST:
+ case MAP:
+ case STRUCT:
+ case UNION:
+ case DECIMAL:
+ case DATE:
+ case VARCHAR:
+ case CHAR:
+ case TIMESTAMP_INSTANT:
+ CASE_EXCEPTION
+ }
+ }
+ // Non-numeric file types have no conversion readers in this change.
+ case STRING:
+ case BINARY:
+ case TIMESTAMP:
+ case LIST:
+ case MAP:
+ case STRUCT:
+ case UNION:
+ case DECIMAL:
+ case DATE:
+ case VARCHAR:
+ case CHAR:
+ case TIMESTAMP_INSTANT:
+ CASE_EXCEPTION
+ }
+ }
+
+#undef DEFINE_NUMERIC_CONVERT_READER
+#undef CASE_CREATE_READER
+#undef CASE_EXCEPTION
+
+} // namespace orc
diff --git a/c++/src/ConvertColumnReader.hh b/c++/src/ConvertColumnReader.hh
new file mode 100644
index 000000000..6ed4d0170
--- /dev/null
+++ b/c++/src/ConvertColumnReader.hh
@@ -0,0 +1,53 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ORC_CONVERT_COLUMN_READER_HH
+#define ORC_CONVERT_COLUMN_READER_HH
+
+#include "ColumnReader.hh"
+#include "SchemaEvolution.hh"
+
+namespace orc {
+
+  /**
+   * Base class of readers that decode a column with its file type and present it
+   * as a batch of the requested read type.
+   *
+   * An inner reader decodes the file column into the scratch batch `data`.
+   * ConvertColumnReader::next() only mirrors the batch size and null mask into the
+   * output batch; subclasses override next() to convert the values themselves.
+   */
+  class ConvertColumnReader : public ColumnReader {
+   public:
+    /**
+     * @param readType type of the batches handed back to the caller
+     * @param fileType type of the column as stored in the file
+     * @param stripe stripe to read the file column from
+     * @param throwOnOverflow throw SchemaEvolutionError on overflow instead of
+     *        setting the value to null
+     */
+    ConvertColumnReader(const Type& readType, const Type& fileType, StripeStreams& stripe,
+                        bool throwOnOverflow);
+
+    // override next() to implement convert logic
+    void next(ColumnVectorBatch& rowBatch, uint64_t numValues, char* notNull) override;
+
+    uint64_t skip(uint64_t numValues) override;
+
+    void seekToRowGroup(std::unordered_map<uint64_t, PositionProvider>& positions) override;
+
+   protected:
+    // Never assigned by the constructor. Conversion readers always decode the file
+    // column with tight numeric vectors, so default it to true instead of leaving it
+    // indeterminate (reading an uninitialized bool is undefined behavior).
+    bool useTightNumericVector = true;
+    // Read type of this column; referenced, so it must outlive the reader.
+    const Type& readType;
+    // Reader of the unconverted file column.
+    std::unique_ptr<ColumnReader> reader;
+    // Scratch batch holding values decoded with the file type.
+    std::unique_ptr<ColumnVectorBatch> data;
+    // Throw SchemaEvolutionError on overflow instead of setting the value to null.
+    const bool throwOnOverflow;
+  };
+
+  /**
+   * Build a reader that reads `fileType` from the stripe and converts it to the read
+   * type that stripe.getSchemaEvolution() maps it to (assumed non-null).
+   *
+   * @throws SchemaEvolutionError if useTightNumericVector is false or the conversion
+   *         is not supported
+   */
+  std::unique_ptr<ColumnReader> buildConvertReader(const Type& fileType, StripeStreams& stripe,
+                                                   bool useTightNumericVector,
+                                                   bool throwOnOverflow);
+
+}  // namespace orc
+
+#endif  // ORC_CONVERT_COLUMN_READER_HH
diff --git a/c++/src/Options.hh b/c++/src/Options.hh
index 151434e8d..51cd8efd6 100644
--- a/c++/src/Options.hh
+++ b/c++/src/Options.hh
@@ -139,6 +139,8 @@ namespace orc {
std::string readerTimezone;
RowReaderOptions::IdReadIntentMap idReadIntentMap;
bool useTightNumericVector;
+ std::shared_ptr<Type> readType;
+ bool throwOnSchemaEvolutionOverflow;
RowReaderOptionsPrivate() {
selection = ColumnSelection_NONE;
@@ -149,6 +151,7 @@ namespace orc {
enableLazyDecoding = false;
readerTimezone = "GMT";
useTightNumericVector = false;
+ throwOnSchemaEvolutionOverflow = false;
}
};
@@ -257,6 +260,15 @@ namespace orc {
return privateBits->throwOnHive11DecimalOverflow;
}
+  // Choose between throwing and nulling out values that overflow during schema evolution.
+  RowReaderOptions& RowReaderOptions::throwOnSchemaEvolutionOverflow(bool shouldThrow) {
+    RowReaderOptionsPrivate& bits = *privateBits;
+    bits.throwOnSchemaEvolutionOverflow = shouldThrow;
+    return *this;
+  }
+
+  // Defaults to false (see RowReaderOptionsPrivate's constructor).
+  bool RowReaderOptions::getThrowOnSchemaEvolutionOverflow() const {
+    const RowReaderOptionsPrivate& bits = *privateBits;
+    return bits.throwOnSchemaEvolutionOverflow;
+  }
+
RowReaderOptions& RowReaderOptions::forcedScaleOnHive11Decimal(int32_t
forcedScale) {
privateBits->forcedScaleOnHive11Decimal = forcedScale;
return *this;
@@ -305,6 +317,15 @@ namespace orc {
bool RowReaderOptions::getUseTightNumericVector() const {
return privateBits->useTightNumericVector;
}
+
+  // Takes the read type by value and moves it into the options (sink parameter).
+  RowReaderOptions& RowReaderOptions::setReadType(std::shared_ptr<Type> type) {
+    RowReaderOptionsPrivate& bits = *privateBits;
+    bits.readType = std::move(type);
+    return *this;
+  }
+
+  // NOTE(review): returns a mutable reference from a const getter, so callers can
+  // replace the stored read type through it; kept to match the public declaration.
+  std::shared_ptr<Type>& RowReaderOptions::getReadType() const {
+    RowReaderOptionsPrivate& bits = *privateBits;
+    return bits.readType;
+  }
} // namespace orc
#endif
diff --git a/c++/src/Reader.cc b/c++/src/Reader.cc
index 8ccb0ec90..b52675abb 100644
--- a/c++/src/Reader.cc
+++ b/c++/src/Reader.cc
@@ -255,7 +255,8 @@ namespace orc {
footer(contents->footer.get()),
firstRowOfStripe(*contents->pool, 0),
enableEncodedBlock(opts.getEnableLazyDecoding()),
- readerTimezone(getTimezoneByName(opts.getTimezoneName())) {
+ readerTimezone(getTimezoneByName(opts.getTimezoneName())),
+ schemaEvolution(opts.getReadType(), contents->schema.get()) {
uint64_t numberOfStripes;
numberOfStripes = static_cast<uint64_t>(footer->stripes_size());
currentStripe = numberOfStripes;
@@ -264,6 +265,7 @@ namespace orc {
rowsInCurrentStripe = 0;
numRowGroupsInStripeRange = 0;
useTightNumericVector = opts.getUseTightNumericVector();
+ throwOnSchemaEvolutionOverflow = opts.getThrowOnSchemaEvolutionOverflow();
uint64_t rowTotal = 0;
firstRowOfStripe.resize(numberOfStripes);
@@ -1091,7 +1093,8 @@ namespace orc {
StripeStreamsImpl stripeStreams(*this, currentStripe, currentStripeInfo,
currentStripeFooter,
currentStripeInfo.offset(),
*contents->stream, writerTimezone,
readerTimezone);
- reader = buildReader(*contents->schema, stripeStreams,
useTightNumericVector);
+ reader = buildReader(*contents->schema, stripeStreams,
useTightNumericVector,
+ throwOnSchemaEvolutionOverflow,
/*convertToReadType=*/true);
if (sargsApplier) {
// move to the 1st selected row group when PPD is enabled.
@@ -1204,9 +1207,33 @@ namespace orc {
return rowsInCurrentStripe;
}
+  // Add the column id of `type` and of every type nested beneath it to columnIds.
+  static void getColumnIds(const Type* type, std::set<uint64_t>& columnIds) {
+    columnIds.insert(type->getColumnId());
+    const uint64_t childCount = type->getSubtypeCount();
+    for (uint64_t child = 0; child < childCount; ++child) {
+      getColumnIds(type->getSubtype(child), columnIds);
+    }
+  }
+
std::unique_ptr<ColumnVectorBatch> RowReaderImpl::createRowBatch(uint64_t
capacity) const {
- return getSelectedType().createRowBatch(capacity, *contents->pool,
enableEncodedBlock,
- useTightNumericVector);
+ // If the read type is specified, then check that the selected schema
matches the read type
+ // on the first call to createRowBatch.
+ if (schemaEvolution.getReadType() && selectedSchema.get() == nullptr) {
+ auto fileSchema = &getSelectedType();
+ auto readType = schemaEvolution.getReadType();
+ std::set<uint64_t> readColumns, fileColumns;
+ getColumnIds(readType, readColumns);
+ getColumnIds(fileSchema, fileColumns);
+ if (readColumns != fileColumns) {
+ std::ostringstream ss;
+ ss << "The selected schema " << fileSchema->toString() << " doesn't
match read type "
+ << readType->toString();
+ throw SchemaEvolutionError(ss.str());
+ }
+ }
+ const Type& readType =
+ schemaEvolution.getReadType() ? *schemaEvolution.getReadType() :
getSelectedType();
+ return readType.createRowBatch(capacity, *contents->pool,
enableEncodedBlock,
+ useTightNumericVector);
}
void ensureOrcFooter(InputStream* stream, DataBuffer<char>* buffer, uint64_t
postscriptLength) {
diff --git a/c++/src/Reader.hh b/c++/src/Reader.hh
index ea6db3aad..a1367e4bd 100644
--- a/c++/src/Reader.hh
+++ b/c++/src/Reader.hh
@@ -159,6 +159,7 @@ namespace orc {
bool enableEncodedBlock;
bool useTightNumericVector;
+ bool throwOnSchemaEvolutionOverflow;
// internal methods
void startNextStripe();
inline void markEndOfFile();
@@ -172,6 +173,9 @@ namespace orc {
// desired timezone to return data of timestamp types.
const Timezone& readerTimezone;
+ // match read and file types
+ SchemaEvolution schemaEvolution;
+
// load stripe index if not done so
void loadStripeIndex();
@@ -237,6 +241,10 @@ namespace orc {
bool getThrowOnHive11DecimalOverflow() const;
bool getIsDecimalAsLong() const;
int32_t getForcedScaleOnHive11Decimal() const;
+
+    // Returns the address of this reader's own SchemaEvolution member.
+    const SchemaEvolution* getSchemaEvolution() const { return &schemaEvolution; }
};
class ReaderImpl : public Reader {
diff --git a/c++/src/SchemaEvolution.cc b/c++/src/SchemaEvolution.cc
index adbedd5fa..d694a49d4 100644
--- a/c++/src/SchemaEvolution.cc
+++ b/c++/src/SchemaEvolution.cc
@@ -49,19 +49,10 @@ namespace orc {
}
};
- // map from file type to read type. it does not contain identity mapping.
- using TypeSet = std::unordered_set<TypeKind, EnumClassHash>;
- using ConvertMap = std::unordered_map<TypeKind, TypeSet, EnumClassHash>;
-
- inline bool supportConversion(const Type& readType, const Type& fileType) {
- static const ConvertMap& SUPPORTED_CONVERSIONS = *new ConvertMap{
- // support nothing now
- };
- auto iter = SUPPORTED_CONVERSIONS.find(fileType.getKind());
- if (iter == SUPPORTED_CONVERSIONS.cend()) {
- return false;
- }
- return iter->second.find(readType.getKind()) != iter->second.cend();
+  // File-local predicate: true for the numeric kinds (boolean, tinyint, smallint, int, bigint,
+  // float, double) that the conversion check below allows to be read as one another.
+  // `static` keeps this helper out of the library's exported symbols.
+  static bool isNumeric(const Type& type) {
+    const auto kind = type.getKind();
+    return kind == BOOLEAN || kind == BYTE || kind == SHORT || kind == INT || kind == LONG ||
+           kind == FLOAT || kind == DOUBLE;
}
struct ConversionCheckResult {
@@ -74,10 +65,10 @@ namespace orc {
if (readType.getKind() == fileType.getKind()) {
ret.isValid = true;
if (fileType.getKind() == CHAR || fileType.getKind() == VARCHAR) {
- ret.needConvert = readType.getMaximumLength() <
fileType.getMaximumLength();
+ ret.isValid = readType.getMaximumLength() ==
fileType.getMaximumLength();
} else if (fileType.getKind() == DECIMAL) {
- ret.needConvert = readType.getPrecision() != fileType.getPrecision() ||
- readType.getScale() != fileType.getScale();
+ ret.isValid = readType.getPrecision() == fileType.getPrecision() &&
+ readType.getScale() == fileType.getScale();
}
} else {
switch (fileType.getKind()) {
@@ -87,48 +78,19 @@ namespace orc {
case INT:
case LONG:
case FLOAT:
- case DOUBLE:
- case DECIMAL: {
- ret.isValid = ret.needConvert =
- (readType.getKind() != DATE && readType.getKind() != BINARY);
- break;
- }
- case STRING: {
- ret.isValid = ret.needConvert = true;
+ case DOUBLE: {
+ ret.isValid = ret.needConvert = isNumeric(readType);
break;
}
+ case DECIMAL:
+ case STRING:
case CHAR:
- case VARCHAR: {
- ret.isValid = true;
- if (readType.getKind() == STRING) {
- ret.needConvert = false;
- } else if (readType.getKind() == CHAR || readType.getKind() ==
VARCHAR) {
- ret.needConvert = readType.getMaximumLength() <
fileType.getMaximumLength();
- } else {
- ret.needConvert = true;
- }
- break;
- }
+ case VARCHAR:
case TIMESTAMP:
- case TIMESTAMP_INSTANT: {
- if (readType.getKind() == TIMESTAMP || readType.getKind() ==
TIMESTAMP_INSTANT) {
- ret = {true, false};
- } else {
- ret.isValid = ret.needConvert = (readType.getKind() != BINARY);
- }
- break;
- }
- case DATE: {
- ret.isValid = ret.needConvert =
- readType.getKind() == STRING || readType.getKind() == CHAR ||
- readType.getKind() == VARCHAR || readType.getKind() == TIMESTAMP
||
- readType.getKind() == TIMESTAMP_INSTANT;
- break;
- }
+ case TIMESTAMP_INSTANT:
+ case DATE:
case BINARY: {
- ret.isValid = ret.needConvert = readType.getKind() == STRING ||
- readType.getKind() == CHAR ||
- readType.getKind() == VARCHAR;
+          // Converting these file types to a different type is not supported.
break;
}
case STRUCT:
diff --git a/c++/src/StripeStream.cc b/c++/src/StripeStream.cc
index 1f43da4f2..6b95a4dc4 100644
--- a/c++/src/StripeStream.cc
+++ b/c++/src/StripeStream.cc
@@ -129,6 +129,10 @@ namespace orc {
return reader.getForcedScaleOnHive11Decimal();
}
+  // Delegates to the row reader referenced by `reader`.
+  const SchemaEvolution* StripeStreamsImpl::getSchemaEvolution() const {
+    const SchemaEvolution* evolution = reader.getSchemaEvolution();
+    return evolution;
+  }
+
void StripeInformationImpl::ensureStripeFooterLoaded() const {
if (stripeFooter.get() == nullptr) {
std::unique_ptr<SeekableInputStream> pbStream =
diff --git a/c++/src/StripeStream.hh b/c++/src/StripeStream.hh
index 74bebda6f..a3b748c6e 100644
--- a/c++/src/StripeStream.hh
+++ b/c++/src/StripeStream.hh
@@ -77,6 +77,8 @@ namespace orc {
bool isDecimalAsLong() const override;
int32_t getForcedScaleOnHive11Decimal() const override;
+
+ const SchemaEvolution* getSchemaEvolution() const override;
};
/**
diff --git a/c++/src/TypeImpl.cc b/c++/src/TypeImpl.cc
index 0075d0478..7e9af806f 100644
--- a/c++/src/TypeImpl.cc
+++ b/c++/src/TypeImpl.cc
@@ -307,15 +307,17 @@ namespace orc {
}
}
case LONG:
- case DATE:
+ case DATE: {
return std::make_unique<LongVectorBatch>(capacity, memoryPool);
+ }
case FLOAT:
if (useTightNumericVector) {
return std::make_unique<FloatVectorBatch>(capacity, memoryPool);
}
- case DOUBLE:
+ case DOUBLE: {
return std::make_unique<DoubleVectorBatch>(capacity, memoryPool);
+ }
case STRING:
case BINARY:
diff --git a/c++/test/CMakeLists.txt b/c++/test/CMakeLists.txt
index 387ce9dbf..ead2f5e4a 100644
--- a/c++/test/CMakeLists.txt
+++ b/c++/test/CMakeLists.txt
@@ -26,6 +26,7 @@ set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${CXX17_FLAGS}
${WARN_FLAGS}")
add_executable (orc-test
MemoryInputStream.cc
MemoryOutputStream.cc
+ MockStripeStreams.cc
TestAttributes.cc
TestBlockBuffer.cc
TestBufferedOutputStream.cc
@@ -36,6 +37,7 @@ add_executable (orc-test
TestColumnReader.cc
TestColumnStatistics.cc
TestCompression.cc
+ TestConvertColumnReader.cc
TestDecompression.cc
TestDecimal.cc
TestDictionaryEncoding.cc
@@ -50,6 +52,7 @@ add_executable (orc-test
TestRLEV2Util.cc
TestSargsApplier.cc
TestSearchArgument.cc
+ TestSchemaEvolution.cc
TestStripeIndexStatistics.cc
TestTimestampStatistics.cc
TestTimezone.cc
diff --git a/c++/test/MockStripeStreams.cc b/c++/test/MockStripeStreams.cc
new file mode 100644
index 000000000..edd6c0d76
--- /dev/null
+++ b/c++/test/MockStripeStreams.cc
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MockStripeStreams.hh"
+
+namespace orc {
+  // Non-mocked MockStripeStreams accessors: each returns a fixed default.
+
+  MemoryPool& MockStripeStreams::getMemoryPool() const {
+    return *getDefaultPool();
+  }
+
+  ReaderMetrics* MockStripeStreams::getReaderMetrics() const {
+    return getDefaultReaderMetrics();
+  }
+
+  const Timezone& MockStripeStreams::getWriterTimezone() const {
+    return getTimezoneByName("America/Los_Angeles");
+  }
+
+  const Timezone& MockStripeStreams::getReaderTimezone() const {
+    return getTimezoneByName("GMT");
+  }
+
+  // Forwards to the mocked getStreamProxy() and takes ownership of the raw pointer it returns
+  // (a nullptr result yields an empty unique_ptr).
+  std::unique_ptr<SeekableInputStream> MockStripeStreams::getStream(uint64_t columnId,
+                                                                    proto::Stream_Kind kind,
+                                                                    bool shouldStream) const {
+    return std::unique_ptr<SeekableInputStream>(getStreamProxy(columnId, kind, shouldStream));
+  }
+
+}  // namespace orc
diff --git a/c++/test/MockStripeStreams.hh b/c++/test/MockStripeStreams.hh
new file mode 100644
index 000000000..dd32ad599
--- /dev/null
+++ b/c++/test/MockStripeStreams.hh
@@ -0,0 +1,56 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef ORC_MOCKSTRIPESTREAM_HH
+#define ORC_MOCKSTRIPESTREAM_HH
+
+#include "ColumnReader.hh"
+
+#include "wrap/gmock.h"
+#include "wrap/gtest-wrapper.h"
+#include "wrap/orc-proto-wrapper.hh"
+
+namespace orc {
+  // gMock-based StripeStreams for unit tests. Column selection, encodings, streams and the
+  // schema-evolution lookup are mocked; the memory pool, reader metrics and timezones return
+  // fixed defaults (see MockStripeStreams.cc).
+  class MockStripeStreams : public StripeStreams {
+   public:
+    ~MockStripeStreams() override = default;
+
+    // Returns getStreamProxy(columnId, kind, shouldStream) wrapped in a unique_ptr.
+    std::unique_ptr<SeekableInputStream> getStream(uint64_t columnId, proto::Stream_Kind kind,
+                                                   bool shouldStream) const override;
+
+    MOCK_CONST_METHOD0(getSelectedColumns, const std::vector<bool>());
+    MOCK_CONST_METHOD1(getEncoding, proto::ColumnEncoding(uint64_t));
+    // Raw-pointer hook behind getStream(); every non-null pointer an expectation returns is
+    // owned (and eventually deleted) by the unique_ptr that getStream() builds from it.
+    MOCK_CONST_METHOD3(getStreamProxy, SeekableInputStream*(uint64_t, proto::Stream_Kind, bool));
+    MOCK_CONST_METHOD0(getErrorStream, std::ostream*());
+    MOCK_CONST_METHOD0(getThrowOnHive11DecimalOverflow, bool());
+    MOCK_CONST_METHOD0(getForcedScaleOnHive11Decimal, int32_t());
+    MOCK_CONST_METHOD0(isDecimalAsLong, bool());
+    MOCK_CONST_METHOD0(getSchemaEvolution, const SchemaEvolution*());
+
+    MemoryPool& getMemoryPool() const override;
+
+    ReaderMetrics* getReaderMetrics() const override;
+
+    const Timezone& getWriterTimezone() const override;
+
+    const Timezone& getReaderTimezone() const override;
+  };
+
+} // namespace orc
+
+#endif
diff --git a/c++/test/TestColumnReader.cc b/c++/test/TestColumnReader.cc
index 6230c2e50..ec02cabe9 100644
--- a/c++/test/TestColumnReader.cc
+++ b/c++/test/TestColumnReader.cc
@@ -18,13 +18,10 @@
#include "Adaptor.hh"
#include "ColumnReader.hh"
+#include "MockStripeStreams.hh"
#include "OrcTest.hh"
#include "orc/Exceptions.hh"
-#include "wrap/gmock.h"
-#include "wrap/gtest-wrapper.h"
-#include "wrap/orc-proto-wrapper.hh"
-
#include <cmath>
#include <iostream>
#include <vector>
@@ -41,52 +38,6 @@ namespace orc {
using ::testing::TestWithParam;
using ::testing::Values;
- class MockStripeStreams : public StripeStreams {
- public:
- ~MockStripeStreams() override;
-
- std::unique_ptr<SeekableInputStream> getStream(uint64_t columnId,
proto::Stream_Kind kind,
- bool stream) const override;
-
- MOCK_CONST_METHOD0(getSelectedColumns,
-
- const std::vector<bool>()
-
- );
- MOCK_CONST_METHOD1(getEncoding, proto::ColumnEncoding(uint64_t));
- MOCK_CONST_METHOD3(getStreamProxy, SeekableInputStream*(uint64_t,
proto::Stream_Kind, bool));
- MOCK_CONST_METHOD0(getErrorStream, std::ostream*());
- MOCK_CONST_METHOD0(getThrowOnHive11DecimalOverflow, bool());
- MOCK_CONST_METHOD0(getForcedScaleOnHive11Decimal, int32_t());
- MOCK_CONST_METHOD0(isDecimalAsLong, bool());
-
- MemoryPool& getMemoryPool() const override {
- return *getDefaultPool();
- }
-
- ReaderMetrics* getReaderMetrics() const override {
- return getDefaultReaderMetrics();
- }
-
- const Timezone& getWriterTimezone() const override {
- return getTimezoneByName("America/Los_Angeles");
- }
-
- const Timezone& getReaderTimezone() const override {
- return getTimezoneByName("GMT");
- }
- };
-
- MockStripeStreams::~MockStripeStreams() {
- // PASS
- }
-
- std::unique_ptr<SeekableInputStream> MockStripeStreams::getStream(uint64_t
columnId,
-
proto::Stream_Kind kind,
- bool
shouldStream) const {
- return std::unique_ptr<SeekableInputStream>(getStreamProxy(columnId, kind,
shouldStream));
- }
-
bool isNotNull(tm* timeptr) {
return timeptr != nullptr;
}
diff --git a/c++/test/TestConvertColumnReader.cc
b/c++/test/TestConvertColumnReader.cc
new file mode 100644
index 000000000..c756845cf
--- /dev/null
+++ b/c++/test/TestConvertColumnReader.cc
@@ -0,0 +1,149 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "orc/Type.hh"
+#include "wrap/gtest-wrapper.h"
+
+#include "MockStripeStreams.hh"
+#include "OrcTest.hh"
+#include "SchemaEvolution.hh"
+
+#include "ConvertColumnReader.hh"
+#include "MemoryInputStream.hh"
+#include "MemoryOutputStream.hh"
+
+namespace orc {
+
+  // Opens an ORC Reader over `stream` whose allocations come from `memoryPool`.
+  static std::unique_ptr<Reader> createReader(MemoryPool& memoryPool,
+                                              std::unique_ptr<InputStream> stream) {
+    ReaderOptions readerOpts;
+    readerOpts.setMemoryPool(memoryPool);
+    std::unique_ptr<Reader> opened = createReader(std::move(stream), readerOpts);
+    return opened;
+  }
+
+  // Writes boolean/int/double/float columns and reads them back as int/boolean/bigint/boolean.
+  // The written values are chosen to be representable in their read types.
+  TEST(ConvertColumnReader, betweenNumericWithoutOverflows) {
+    constexpr int DEFAULT_MEM_STREAM_SIZE = 10 * 1024;
+    constexpr int TEST_CASES = 1024;
+    MemoryOutputStream memStream(DEFAULT_MEM_STREAM_SIZE);
+    std::unique_ptr<Type> fileType(
+        Type::buildTypeFromString("struct<t1:boolean,t2:int,t3:double,t4:float>"));
+    std::shared_ptr<Type> readType(
+        Type::buildTypeFromString("struct<t1:int,t2:boolean,t3:bigint,t4:boolean>"));
+    WriterOptions options;
+    options.setUseTightNumericVector(true);
+    auto writer = createWriter(*fileType, &memStream, options);
+    auto batch = writer->createRowBatch(TEST_CASES);
+    auto& structBatch = dynamic_cast<StructVectorBatch&>(*batch);
+    auto& c0 = dynamic_cast<ByteVectorBatch&>(*structBatch.fields[0]);
+    auto& c1 = dynamic_cast<IntVectorBatch&>(*structBatch.fields[1]);
+    auto& c2 = dynamic_cast<DoubleVectorBatch&>(*structBatch.fields[2]);
+    auto& c3 = dynamic_cast<FloatVectorBatch&>(*structBatch.fields[3]);
+
+    structBatch.numElements = c0.numElements = c1.numElements = c2.numElements = c3.numElements =
+        TEST_CASES;
+
+    for (size_t i = 0; i < TEST_CASES; i++) {
+      // false only when i is a multiple of 6
+      c0.data[i] = i % 2 || i % 3 ? true : false;
+      // Signed arithmetic: with size_t `i`, `TEST_CASES / 2 - i` wrapped around for
+      // i > TEST_CASES / 2. Values span [-511 * 1024, 512 * 1024] and are zero at i == 512.
+      c1.data[i] = (TEST_CASES / 2 - static_cast<int>(i)) * TEST_CASES;
+      c2.data[i] = static_cast<double>(TEST_CASES - i) / (TEST_CASES / 2);
+      c3.data[i] = static_cast<float>(TEST_CASES - i) / (TEST_CASES / 2);
+    }
+
+    writer->add(*batch);
+    writer->close();
+
+    auto inStream = std::make_unique<MemoryInputStream>(memStream.getData(), memStream.getLength());
+    auto pool = getDefaultPool();
+    auto reader = createReader(*pool, std::move(inStream));
+    RowReaderOptions rowReaderOpts;
+    rowReaderOpts.setReadType(readType);
+    rowReaderOpts.setUseTightNumericVector(true);
+    auto rowReader = reader->createRowReader(rowReaderOpts);
+    auto readBatch = rowReader->createRowBatch(TEST_CASES);
+    EXPECT_EQ(true, rowReader->next(*readBatch));
+    auto& readStructBatch = dynamic_cast<StructVectorBatch&>(*readBatch);
+    auto& readC0 = dynamic_cast<IntVectorBatch&>(*readStructBatch.fields[0]);
+    auto& readC1 = dynamic_cast<ByteVectorBatch&>(*readStructBatch.fields[1]);
+    auto& readC2 = dynamic_cast<LongVectorBatch&>(*readStructBatch.fields[2]);
+    auto& readC3 = dynamic_cast<ByteVectorBatch&>(*readStructBatch.fields[3]);
+
+    for (size_t i = 0; i < TEST_CASES; i++) {
+      EXPECT_EQ(readC0.data[i], i % 2 || i % 3 ? 1 : 0);
+      EXPECT_TRUE(readC1.data[i] == true || i == TEST_CASES / 2);
+      EXPECT_EQ(readC2.data[i],
+                i > TEST_CASES / 2 ? 0 : static_cast<int64_t>((TEST_CASES - i) / (TEST_CASES / 2)));
+      EXPECT_TRUE(readC3.data[i] == true || i > TEST_CASES / 2);
+    }
+
+    // The same conversion without tight numeric vectors is expected to raise an error.
+    rowReaderOpts.setUseTightNumericVector(false);
+    rowReader = reader->createRowReader(rowReaderOpts);
+    readBatch = rowReader->createRowBatch(TEST_CASES);
+    EXPECT_THROW(rowReader->next(*readBatch), SchemaEvolutionError);
+  }
+
+  // Values that do not fit their read type are expected to be read as nulls by default, and to
+  // raise SchemaEvolutionError once throwOnSchemaEvolutionOverflow(true) is set.
+  TEST(ConvertColumnReader, betweenNumericOverflows) {
+    constexpr int DEFAULT_MEM_STREAM_SIZE = 10 * 1024;
+    MemoryOutputStream memStream(DEFAULT_MEM_STREAM_SIZE);
+    std::unique_ptr<Type> fileType(
+        Type::buildTypeFromString("struct<t1:double,t2:bigint,t3:bigint>"));
+    std::shared_ptr<Type> readType(Type::buildTypeFromString("struct<t1:int,t2:int,t3:float>"));
+    WriterOptions options;
+    auto writer = createWriter(*fileType, &memStream, options);
+    auto batch = writer->createRowBatch(2);
+    auto& structBatch = dynamic_cast<StructVectorBatch&>(*batch);
+    auto& c0 = dynamic_cast<DoubleVectorBatch&>(*structBatch.fields[0]);
+    auto& c1 = dynamic_cast<LongVectorBatch&>(*structBatch.fields[1]);
+    auto& c2 = dynamic_cast<LongVectorBatch&>(*structBatch.fields[2]);
+
+    // Row 0 holds t1/t2 values outside the int32 range; row 1 holds values that fit.
+    c0.data[0] = 1e35;
+    c0.data[1] = 1e9 + 7;
+    c1.data[0] = (1LL << 31);      // INT32_MAX + 1
+    c1.data[1] = (1LL << 31) - 1;  // INT32_MAX
+    c2.data[0] = (1LL << 62) + 112312;
+    c2.data[1] = (1LL << 20) + 77553;
+
+    structBatch.numElements = c0.numElements = c1.numElements = c2.numElements = 2;
+    writer->add(*batch);
+    writer->close();
+
+    auto inStream = std::make_unique<MemoryInputStream>(memStream.getData(), memStream.getLength());
+    auto pool = getDefaultPool();
+    auto reader = createReader(*pool, std::move(inStream));
+    RowReaderOptions rowReaderOpts;
+    rowReaderOpts.setReadType(readType);
+    rowReaderOpts.setUseTightNumericVector(true);
+    auto rowReader = reader->createRowReader(rowReaderOpts);
+    auto readBatch = rowReader->createRowBatch(2);
+    EXPECT_EQ(true, rowReader->next(*readBatch));
+    auto& readStructBatch = dynamic_cast<StructVectorBatch&>(*readBatch);
+    auto& readC0 = dynamic_cast<IntVectorBatch&>(*readStructBatch.fields[0]);
+    auto& readC1 = dynamic_cast<IntVectorBatch&>(*readStructBatch.fields[1]);
+    auto& readC2 = dynamic_cast<FloatVectorBatch&>(*readStructBatch.fields[2]);
+    // Overflowing int values become nulls; bigint -> float stays non-null (float's range covers
+    // every bigint).
+    EXPECT_FALSE(readC0.notNull[0]);
+    EXPECT_FALSE(readC1.notNull[0]);
+    EXPECT_TRUE(readC2.notNull[0]);
+    EXPECT_TRUE(readC0.notNull[1]);
+    EXPECT_TRUE(readC1.notNull[1]);
+    EXPECT_TRUE(readC2.notNull[1]);
+
+    // With overflow checking enabled, the same data must raise an error instead.
+    rowReaderOpts.throwOnSchemaEvolutionOverflow(true);
+    rowReader = reader->createRowReader(rowReaderOpts);
+    readBatch = rowReader->createRowBatch(2);
+    EXPECT_THROW(rowReader->next(*readBatch), SchemaEvolutionError);
+  }
+} // namespace orc
diff --git a/c++/test/TestSchemaEvolution.cc b/c++/test/TestSchemaEvolution.cc
new file mode 100644
index 000000000..9f6f776dc
--- /dev/null
+++ b/c++/test/TestSchemaEvolution.cc
@@ -0,0 +1,99 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MockStripeStreams.hh"
+#include "SchemaEvolution.hh"
+#include "TypeImpl.hh"
+
+#include "wrap/gtest-wrapper.h"
+
+namespace orc {
+
+  // Checks SchemaEvolution for a file schema `file` read as `read`.
+  // If `can` is false, constructing SchemaEvolution must throw SchemaEvolutionError.
+  // Otherwise:
+  //   - the root struct must not need conversion;
+  //   - its first child must need conversion exactly when `need` is true;
+  //   - buildReader() must return a reader.
+  // Always returns true.
+  bool testConvertReader(const std::string& file, const std::string& read, bool can, bool need) {
+    auto fileType = std::shared_ptr<Type>(Type::buildTypeFromString(file).release());
+    auto readType = std::shared_ptr<Type>(Type::buildTypeFromString(read).release());
+
+    if (!can) {
+      EXPECT_THROW(SchemaEvolution(readType, fileType.get()), SchemaEvolutionError)
+          << "fileType: " << fileType->toString() << "\nreadType: " << readType->toString();
+    } else {
+      // The conversion is allowed: constructing SchemaEvolution must not throw, and a column
+      // reader must be buildable from it.
+      SchemaEvolution se(readType, fileType.get());
+      EXPECT_FALSE(se.needConvert(*fileType));
+      EXPECT_EQ(need, se.needConvert(*fileType->getSubtype(0)))
+          << "fileType: " << fileType->toString() << "\nreadType: " << readType->toString();
+
+      // Declared before `streams` so it outlives every stream the mock can hand out.
+      std::string dummyStream("dummy");
+      MockStripeStreams streams;
+      // NOTE(review): every column is unselected here, so the struct reader presumably builds no
+      // child readers and the convert reader itself may not be constructed - confirm.
+      std::vector<bool> selectedColumns(2);
+      EXPECT_CALL(streams, getSelectedColumns()).WillRepeatedly(testing::Return(selectedColumns));
+      proto::ColumnEncoding directEncoding;
+      directEncoding.set_kind(proto::ColumnEncoding_Kind_DIRECT);
+      EXPECT_CALL(streams, getEncoding(testing::_)).WillRepeatedly(testing::Return(directEncoding));
+
+      // Catch-all: no stream is available unless a later expectation says otherwise.
+      EXPECT_CALL(streams, getStreamProxy(testing::_, testing::_, testing::_))
+          .WillRepeatedly(testing::Return(nullptr));
+
+      // This later, more specific expectation takes precedence for column 1's SECONDARY stream.
+      // An ON_CALL default is shadowed by the catch-all above, so it never fires. Using
+      // Return(new ...) would also leak its single pre-allocated stream. Instead, a fresh stream
+      // is allocated per call, because getStream() takes ownership of every returned pointer.
+      EXPECT_CALL(streams, getStreamProxy(1, proto::Stream_Kind_SECONDARY, testing::_))
+          .WillRepeatedly(testing::InvokeWithoutArgs([&dummyStream]() -> SeekableInputStream* {
+            return new SeekableArrayInputStream(dummyStream.c_str(), dummyStream.length());
+          }));
+
+      EXPECT_CALL(streams, getSchemaEvolution()).WillRepeatedly(testing::Return(&se));
+
+      EXPECT_TRUE(buildReader(*fileType, streams) != nullptr);
+    }
+    return true;
+  }
+
+  // Runs testConvertReader over every (file type, read type) pair from the table below.
+  TEST(SchemaEvolution, createConvertReader) {
+    const std::vector<std::string> types = {
+        "struct<t1:boolean>",        "struct<t1:tinyint>",
+        "struct<t1:smallint>",       "struct<t1:int>",
+        "struct<t1:bigint>",         "struct<t1:float>",
+        "struct<t1:double>",         "struct<t1:string>",
+        "struct<t1:char(5)>",        "struct<t1:varchar(5)>",
+        "struct<t1:char(3)>",        "struct<t1:varchar(3)>",
+        "struct<t1:decimal(25,2)>",  "struct<t1:decimal(15,2)>",
+        "struct<t1:timestamp>",      "struct<t1:timestamp with local time zone>",
+        "struct<t1:date>"};
+
+    const size_t typesSize = types.size();
+    std::vector<std::vector<bool>> needConvert(typesSize, std::vector<bool>(typesSize, false));
+    std::vector<std::vector<bool>> canConvert(typesSize, std::vector<bool>(typesSize, false));
+
+    // Every type can be read as itself, without any conversion.
+    for (size_t idx = 0; idx < typesSize; ++idx) {
+      canConvert[idx][idx] = true;
+    }
+
+    // Entries 0..6 are the numeric types: each can be read as any other numeric type, and a
+    // conversion is needed whenever the two differ.
+    constexpr size_t lastNumeric = 6;
+    for (size_t from = 0; from <= lastNumeric; ++from) {
+      for (size_t to = 0; to <= lastNumeric; ++to) {
+        canConvert[from][to] = true;
+        needConvert[from][to] = from != to;
+      }
+    }
+
+    for (size_t from = 0; from < typesSize; ++from) {
+      for (size_t to = 0; to < typesSize; ++to) {
+        testConvertReader(types[from], types[to], canConvert[from][to], needConvert[from][to]);
+      }
+    }
+  }
+
+} // namespace orc