This is an automated email from the ASF dual-hosted git repository.
ffacs pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/orc.git
The following commit(s) were added to refs/heads/main by this push:
new 5d8e70280 ORC-1389: [C++] Support schema evolution from string group
to numeric/string group
5d8e70280 is described below
commit 5d8e702809badd9de2f057834d601916b38d3f08
Author: ffacs <[email protected]>
AuthorDate: Thu May 16 22:37:55 2024 +0800
ORC-1389: [C++] Support schema evolution from string group to
numeric/string group
### What changes were proposed in this pull request?
Support conversion from
{string, char, varchar}
to
{boolean, byte, short, int, long, float, double, string, char, varchar}
### Why are the changes needed?
To support schema evolution on the C++ side.
### How was this patch tested?
Unit tests passed.
### Was this patch authored or co-authored using generative AI tooling?
NO
Closes #1931 from ffacs/ORC-1389.
Authored-by: ffacs <[email protected]>
Signed-off-by: ffacs <[email protected]>
---
c++/src/ColumnWriter.cc | 70 +---------
c++/src/ConvertColumnReader.cc | 259 ++++++++++++++++++++++++++++++++----
c++/src/SchemaEvolution.cc | 7 +-
c++/src/Utils.hh | 70 ++++++++++
c++/test/TestConvertColumnReader.cc | 161 ++++++++++++++++++++++
c++/test/TestSchemaEvolution.cc | 16 +++
6 files changed, 485 insertions(+), 98 deletions(-)
diff --git a/c++/src/ColumnWriter.cc b/c++/src/ColumnWriter.cc
index 86e30ce90..05ffd3a2d 100644
--- a/c++/src/ColumnWriter.cc
+++ b/c++/src/ColumnWriter.cc
@@ -24,6 +24,7 @@
#include "RLE.hh"
#include "Statistics.hh"
#include "Timezone.hh"
+#include "Utils.hh"
namespace orc {
StreamsFactory::~StreamsFactory() {
@@ -1356,75 +1357,6 @@ namespace orc {
deleteDictStreams();
}
- struct Utf8Utils {
- /**
- * Counts how many utf-8 chars of the input data
- */
- static uint64_t charLength(const char* data, uint64_t length) {
- uint64_t chars = 0;
- for (uint64_t i = 0; i < length; i++) {
- if (isUtfStartByte(data[i])) {
- chars++;
- }
- }
- return chars;
- }
-
- /**
- * Return the number of bytes required to read at most maxCharLength
- * characters in full from a utf-8 encoded byte array provided
- * by data. This does not validate utf-8 data, but
- * operates correctly on already valid utf-8 data.
- *
- * @param maxCharLength number of characters required
- * @param data the bytes of UTF-8
- * @param length the length of data to truncate
- */
- static uint64_t truncateBytesTo(uint64_t maxCharLength, const char* data,
uint64_t length) {
- uint64_t chars = 0;
- if (length <= maxCharLength) {
- return length;
- }
- for (uint64_t i = 0; i < length; i++) {
- if (isUtfStartByte(data[i])) {
- chars++;
- }
- if (chars > maxCharLength) {
- return i;
- }
- }
- // everything fits
- return length;
- }
-
- /**
- * Checks if b is the first byte of a UTF-8 character.
- */
- inline static bool isUtfStartByte(char b) {
- return (b & 0xC0) != 0x80;
- }
-
- /**
- * Find the start of the last character that ends in the current string.
- * @param text the bytes of the utf-8
- * @param from the first byte location
- * @param until the last byte location
- * @return the index of the last character
- */
- static uint64_t findLastCharacter(const char* text, uint64_t from,
uint64_t until) {
- uint64_t posn = until;
- /* we don't expect characters more than 5 bytes */
- while (posn >= from) {
- if (isUtfStartByte(text[posn])) {
- return posn;
- }
- posn -= 1;
- }
- /* beginning of a valid char not found */
- throw std::logic_error("Could not truncate string, beginning of a valid
char not found");
- }
- };
-
class CharColumnWriter : public StringColumnWriter {
public:
CharColumnWriter(const Type& type, const StreamsFactory& factory, const
WriterOptions& options)
diff --git a/c++/src/ConvertColumnReader.cc b/c++/src/ConvertColumnReader.cc
index 67ee6d6c4..a24b8cb05 100644
--- a/c++/src/ConvertColumnReader.cc
+++ b/c++/src/ConvertColumnReader.cc
@@ -17,6 +17,7 @@
*/
#include "ConvertColumnReader.hh"
+#include "Utils.hh"
namespace orc {
@@ -694,6 +695,112 @@ namespace orc {
const int32_t scale_;
};
+  /**
+   * Reads a STRING, CHAR or VARCHAR column and converts every non-null value to
+   * ReadType (bool, int8_t, int16_t, int32_t, int64_t, float or double), writing
+   * the result into a ReadTypeBatch.
+   *
+   * Integer targets are parsed as a signed 64-bit base-10 value with std::stoll
+   * and then narrowed to ReadType with downCastToInteger(); a bool target stores
+   * 0 for zero and 1 for any other value. Floating-point targets are parsed with
+   * std::stof (float) or std::stod (double).
+   *
+   * Whitespace around the number is accepted (CHAR values carry trailing space
+   * padding), but the rest of the value must be consumed by the parser: text such
+   * as "12abc", or "1.5" for an integer target, is a conversion failure rather
+   * than being silently truncated to its numeric prefix. Every failure
+   * (unparsable text, or a value that does not fit ReadType) is reported through
+   * handleOverflow<std::string, ReadType>(..., throwOnOverflow).
+   */
+  template <typename ReadTypeBatch, typename ReadType>
+  class StringVariantToNumericColumnReader : public ConvertColumnReader {
+   public:
+    StringVariantToNumericColumnReader(const Type& readType, const Type& fileType,
+                                       StripeStreams& stripe, bool throwOnOverflow)
+        : ConvertColumnReader(readType, fileType, stripe, throwOnOverflow) {}
+
+    void next(ColumnVectorBatch& rowBatch, uint64_t numValues, char* notNull) override {
+      ConvertColumnReader::next(rowBatch, numValues, notNull);
+
+      const auto& srcBatch = *SafeCastBatchTo<const StringVectorBatch*>(data.get());
+      auto& dstBatch = *SafeCastBatchTo<ReadTypeBatch*>(&rowBatch);
+      for (uint64_t i = 0; i < numValues; ++i) {
+        if (!rowBatch.hasNulls || rowBatch.notNull[i]) {
+          if constexpr (std::is_floating_point_v<ReadType>) {
+            convertToDouble(dstBatch, srcBatch, i);
+          } else {
+            convertToInteger(dstBatch, srcBatch, i);
+          }
+        }
+      }
+    }
+
+   private:
+    // True when text[pos, end) contains only whitespace (e.g. CHAR padding).
+    // std::stoll/stof/stod already skip leading whitespace themselves.
+    static bool isBlankFrom(const std::string& text, size_t pos) {
+      return text.find_first_not_of(" \t\n\v\f\r", pos) == std::string::npos;
+    }
+
+    // Parses all of `text` as a base-10 64-bit integer. Returns false when the
+    // text is not a number, has trailing non-whitespace, or overflows int64_t.
+    // Only the parser's own exceptions are caught; anything else propagates.
+    static bool parseInt64(const std::string& text, int64_t& value) {
+      size_t consumed = 0;
+      try {
+        value = std::stoll(text, &consumed);
+      } catch (const std::invalid_argument&) {
+        return false;
+      } catch (const std::out_of_range&) {
+        return false;
+      }
+      return isBlankFrom(text, consumed);
+    }
+
+    // Parses all of `text` as a float/double (matching ReadType). Returns false
+    // when the text is not a number, has trailing non-whitespace, or is out of
+    // range for ReadType.
+    static bool parseFloating(const std::string& text, ReadType& value) {
+      size_t consumed = 0;
+      try {
+        if constexpr (std::is_same_v<ReadType, float>) {
+          value = std::stof(text, &consumed);
+        } else {
+          value = std::stod(text, &consumed);
+        }
+      } catch (const std::invalid_argument&) {
+        return false;
+      } catch (const std::out_of_range&) {
+        return false;
+      }
+      return isBlankFrom(text, consumed);
+    }
+
+    void convertToInteger(ReadTypeBatch& dstBatch, const StringVectorBatch& srcBatch,
+                          uint64_t idx) {
+      const std::string text(srcBatch.data[idx], srcBatch.length[idx]);
+      int64_t longValue = 0;
+      if (!parseInt64(text, longValue)) {
+        handleOverflow<std::string, ReadType>(dstBatch, idx, throwOnOverflow);
+        return;
+      }
+      if constexpr (std::is_same_v<ReadType, bool>) {
+        dstBatch.data[idx] = longValue == 0 ? 0 : 1;
+      } else {
+        // Narrow to the (possibly smaller) integer read type with a range check.
+        if (!downCastToInteger(dstBatch.data[idx], longValue)) {
+          handleOverflow<std::string, ReadType>(dstBatch, idx, throwOnOverflow);
+        }
+      }
+    }
+
+    void convertToDouble(ReadTypeBatch& dstBatch, const StringVectorBatch& srcBatch,
+                         uint64_t idx) {
+      const std::string text(srcBatch.data[idx], srcBatch.length[idx]);
+      ReadType value{};
+      if (parseFloating(text, value)) {
+        dstBatch.data[idx] = value;
+      } else {
+        handleOverflow<std::string, ReadType>(dstBatch, idx, throwOnOverflow);
+      }
+    }
+  };
+
+  /**
+   * Converts between the string-group types STRING, CHAR(n) and VARCHAR(n).
+   *
+   * Each non-null source value is copied into strBuffer and reshaped for the
+   * read type:
+   *  - STRING:     copied verbatim (any CHAR padding stored in the file is kept);
+   *  - VARCHAR(n): truncated to at most n UTF-8 characters;
+   *  - CHAR(n):    truncated to at most n UTF-8 characters, then right-padded
+   *                with one space byte per missing character up to n characters.
+   * Any other read kind is rejected with SchemaEvolutionError.
+   *
+   * @return the total number of bytes written into strBuffer for this batch.
+   */
+  class StringVariantConvertColumnReader : public ConvertToStringVariantColumnReader {
+   public:
+    StringVariantConvertColumnReader(const Type& readType, const Type& fileType,
+                                     StripeStreams& stripe, bool throwOnOverflow)
+        : ConvertToStringVariantColumnReader(readType, fileType, stripe, throwOnOverflow) {}
+
+    uint64_t convertToStrBuffer(ColumnVectorBatch& rowBatch, uint64_t numValues) override {
+      const auto kind = readType.getKind();
+      if (kind != STRING && kind != VARCHAR && kind != CHAR) {
+        throw SchemaEvolutionError("Invalid type for string variant conversion: " +
+                                   readType.toString());
+      }
+
+      uint64_t size = 0;
+      strBuffer.resize(numValues);
+      const auto& srcBatch = *SafeCastBatchTo<const StringVectorBatch*>(data.get());
+      const auto maxLength = readType.getMaximumLength();
+      for (uint64_t i = 0; i < numValues; ++i) {
+        if (rowBatch.hasNulls && !rowBatch.notNull[i]) {
+          continue;
+        }
+        const char* charData = srcBatch.data[i];
+        const uint64_t originLength = srcBatch.length[i];
+        if (kind == STRING) {
+          strBuffer[i].assign(charData, originLength);
+        } else {
+          // VARCHAR and CHAR both cap the value at maxLength UTF-8 characters.
+          const uint64_t itemLength =
+              Utf8Utils::truncateBytesTo(maxLength, charData, originLength);
+          strBuffer[i].assign(charData, itemLength);
+          if (kind == CHAR) {
+            // Only a value shorter than maxLength characters is untruncated, so
+            // the padding is exactly one space byte per missing character.
+            const uint64_t charLength = Utf8Utils::charLength(charData, originLength);
+            if (charLength < maxLength) {
+              strBuffer[i].resize(itemLength + maxLength - charLength, ' ');
+            }
+          }
+        }
+        size += strBuffer[i].size();
+      }
+      return size;
+    }
+  };
+
#define DEFINE_NUMERIC_CONVERT_READER(FROM, TO, TYPE) \
using FROM##To##TO##ColumnReader = \
NumericConvertColumnReader<FROM##VectorBatch, TO##VectorBatch, TYPE>;
@@ -730,6 +837,12 @@ namespace orc {
using Decimal64To##TO##ColumnReader =
DecimalToStringVariantColumnReader<Decimal64VectorBatch>; \
using Decimal128To##TO##ColumnReader =
DecimalToStringVariantColumnReader<Decimal128VectorBatch>;
+// Declares SRC##To##DST##ColumnReader as the reader that parses string-group
+// values from the file into a DST##VectorBatch holding CPP_TYPE elements.
+#define DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(SRC, DST, CPP_TYPE) \
+  using SRC##To##DST##ColumnReader = StringVariantToNumericColumnReader<DST##VectorBatch, CPP_TYPE>;
+
+// Declares SRC##To##DST##ColumnReader for a string-group to string-group pair;
+// every such pair is served by the single StringVariantConvertColumnReader.
+#define DEFINE_STRING_VARIANT_CONVERT_READER(SRC, DST) \
+  using SRC##To##DST##ColumnReader = StringVariantConvertColumnReader;
+
DEFINE_NUMERIC_CONVERT_READER(Boolean, Byte, int8_t)
DEFINE_NUMERIC_CONVERT_READER(Boolean, Short, int16_t)
DEFINE_NUMERIC_CONVERT_READER(Boolean, Int, int32_t)
@@ -834,8 +947,41 @@ namespace orc {
DEFINE_DECIMAL_CONVERT_TO_STRING_VARINT_READER(Char)
DEFINE_DECIMAL_CONVERT_TO_STRING_VARINT_READER(Varchar)
+ // String variant to numeric
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(String, Boolean, bool)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(String, Byte, int8_t)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(String, Short, int16_t)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(String, Int, int32_t)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(String, Long, int64_t)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(String, Float, float)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(String, Double, double)
+
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Char, Boolean, bool)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Char, Byte, int8_t)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Char, Short, int16_t)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Char, Int, int32_t)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Char, Long, int64_t)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Char, Float, float)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Char, Double, double)
+
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Varchar, Boolean, bool)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Varchar, Byte, int8_t)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Varchar, Short, int16_t)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Varchar, Int, int32_t)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Varchar, Long, int64_t)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Varchar, Float, float)
+ DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Varchar, Double, double)
+
+ // String variant to string variant
+ DEFINE_STRING_VARIANT_CONVERT_READER(String, Char)
+ DEFINE_STRING_VARIANT_CONVERT_READER(String, Varchar)
+ DEFINE_STRING_VARIANT_CONVERT_READER(Char, String)
+ DEFINE_STRING_VARIANT_CONVERT_READER(Char, Varchar)
+ DEFINE_STRING_VARIANT_CONVERT_READER(Varchar, String)
+ DEFINE_STRING_VARIANT_CONVERT_READER(Varchar, Char)
+
#define CREATE_READER(NAME) \
- return std::make_unique<NAME>(_readType, fileType, stripe, throwOnOverflow);
+ return std::make_unique<NAME>(readType, fileType, stripe, throwOnOverflow);
#define CASE_CREATE_READER(TYPE, CONVERT) \
case TYPE: \
@@ -858,7 +1004,7 @@ namespace orc {
#define CASE_CREATE_DECIMAL_READER(FROM) \
case DECIMAL: { \
- if (isDecimal64(_readType)) { \
+ if (isDecimal64(readType)) { \
CREATE_READER(FROM##ToDecimal64ColumnReader) \
} else { \
CREATE_READER(FROM##ToDecimal128ColumnReader) \
@@ -868,7 +1014,7 @@ namespace orc {
#define CASE_EXCEPTION
\
default:
\
throw SchemaEvolutionError("Cannot convert from " + fileType.toString() +
" to " + \
- _readType.toString());
+ readType.toString());
std::unique_ptr<ColumnReader> buildConvertReader(const Type& fileType,
StripeStreams& stripe,
bool useTightNumericVector,
@@ -878,11 +1024,11 @@ namespace orc {
"SchemaEvolution only support tight vector, please create
ColumnVectorBatch with "
"option useTightNumericVector");
}
- const auto& _readType =
*stripe.getSchemaEvolution()->getReadType(fileType);
+ const auto& readType = *stripe.getSchemaEvolution()->getReadType(fileType);
switch (fileType.getKind()) {
case BOOLEAN: {
- switch (_readType.getKind()) {
+ switch (readType.getKind()) {
CASE_CREATE_READER(BYTE, BooleanToByte)
CASE_CREATE_READER(SHORT, BooleanToShort)
CASE_CREATE_READER(INT, BooleanToInt)
@@ -906,7 +1052,7 @@ namespace orc {
}
}
case BYTE: {
- switch (_readType.getKind()) {
+ switch (readType.getKind()) {
CASE_CREATE_READER(BOOLEAN, ByteToBoolean)
CASE_CREATE_READER(SHORT, ByteToShort)
CASE_CREATE_READER(INT, ByteToInt)
@@ -930,7 +1076,7 @@ namespace orc {
}
}
case SHORT: {
- switch (_readType.getKind()) {
+ switch (readType.getKind()) {
CASE_CREATE_READER(BOOLEAN, ShortToBoolean)
CASE_CREATE_READER(BYTE, ShortToByte)
CASE_CREATE_READER(INT, ShortToInt)
@@ -954,7 +1100,7 @@ namespace orc {
}
}
case INT: {
- switch (_readType.getKind()) {
+ switch (readType.getKind()) {
CASE_CREATE_READER(BOOLEAN, IntToBoolean)
CASE_CREATE_READER(BYTE, IntToByte)
CASE_CREATE_READER(SHORT, IntToShort)
@@ -978,7 +1124,7 @@ namespace orc {
}
}
case LONG: {
- switch (_readType.getKind()) {
+ switch (readType.getKind()) {
CASE_CREATE_READER(BOOLEAN, LongToBoolean)
CASE_CREATE_READER(BYTE, LongToByte)
CASE_CREATE_READER(SHORT, LongToShort)
@@ -1002,7 +1148,7 @@ namespace orc {
}
}
case FLOAT: {
- switch (_readType.getKind()) {
+ switch (readType.getKind()) {
CASE_CREATE_READER(BOOLEAN, FloatToBoolean)
CASE_CREATE_READER(BYTE, FloatToByte)
CASE_CREATE_READER(SHORT, FloatToShort)
@@ -1026,7 +1172,7 @@ namespace orc {
}
}
case DOUBLE: {
- switch (_readType.getKind()) {
+ switch (readType.getKind()) {
CASE_CREATE_READER(BOOLEAN, DoubleToBoolean)
CASE_CREATE_READER(BYTE, DoubleToByte)
CASE_CREATE_READER(SHORT, DoubleToShort)
@@ -1050,7 +1196,7 @@ namespace orc {
}
}
case DECIMAL: {
- switch (_readType.getKind()) {
+ switch (readType.getKind()) {
CASE_CREATE_FROM_DECIMAL_READER(BOOLEAN, Boolean)
CASE_CREATE_FROM_DECIMAL_READER(BYTE, Byte)
CASE_CREATE_FROM_DECIMAL_READER(SHORT, Short)
@@ -1065,13 +1211,13 @@ namespace orc {
CASE_CREATE_FROM_DECIMAL_READER(TIMESTAMP_INSTANT, Timestamp)
case DECIMAL: {
if (isDecimal64(fileType)) {
- if (isDecimal64(_readType)) {
+ if (isDecimal64(readType)) {
CREATE_READER(Decimal64ToDecimal64ColumnReader)
} else {
CREATE_READER(Decimal64ToDecimal128ColumnReader)
}
} else {
- if (isDecimal64(_readType)) {
+ if (isDecimal64(readType)) {
CREATE_READER(Decimal128ToDecimal64ColumnReader)
} else {
CREATE_READER(Decimal128ToDecimal128ColumnReader)
@@ -1087,7 +1233,78 @@ namespace orc {
CASE_EXCEPTION
}
}
- case STRING:
+      case STRING: {
+        switch (readType.getKind()) {
+          CASE_CREATE_READER(BOOLEAN, StringToBoolean)
+          CASE_CREATE_READER(BYTE, StringToByte)
+          CASE_CREATE_READER(SHORT, StringToShort)
+          CASE_CREATE_READER(INT, StringToInt)
+          CASE_CREATE_READER(LONG, StringToLong)
+          CASE_CREATE_READER(FLOAT, StringToFloat)
+          CASE_CREATE_READER(DOUBLE, StringToDouble)
+          CASE_CREATE_READER(CHAR, StringToChar)
+          CASE_CREATE_READER(VARCHAR, StringToVarchar)
+          // Same-kind STRING is not flagged needConvert by SchemaEvolution.
+          case STRING:
+          case BINARY:
+          case TIMESTAMP:
+          case LIST:
+          case MAP:
+          case STRUCT:
+          case UNION:
+          case DATE:
+          case TIMESTAMP_INSTANT:
+          case DECIMAL:
+            CASE_EXCEPTION
+        }
+      }
+      case CHAR: {
+        switch (readType.getKind()) {
+          CASE_CREATE_READER(BOOLEAN, CharToBoolean)
+          CASE_CREATE_READER(BYTE, CharToByte)
+          CASE_CREATE_READER(SHORT, CharToShort)
+          CASE_CREATE_READER(INT, CharToInt)
+          CASE_CREATE_READER(LONG, CharToLong)
+          CASE_CREATE_READER(FLOAT, CharToFloat)
+          CASE_CREATE_READER(DOUBLE, CharToDouble)
+          CASE_CREATE_READER(STRING, CharToString)
+          CASE_CREATE_READER(VARCHAR, CharToVarchar)
+          case CHAR:
+            // SchemaEvolution accepts char(n) -> char(m) and sets needConvert when
+            // n != m, so this pair must get a reader that re-truncates and re-pads
+            // to the new length instead of throwing.
+            CREATE_READER(StringVariantConvertColumnReader)
+          case BINARY:
+          case TIMESTAMP:
+          case LIST:
+          case MAP:
+          case STRUCT:
+          case UNION:
+          case DATE:
+          case TIMESTAMP_INSTANT:
+          case DECIMAL:
+            CASE_EXCEPTION
+        }
+      }
+      case VARCHAR: {
+        switch (readType.getKind()) {
+          CASE_CREATE_READER(BOOLEAN, VarcharToBoolean)
+          CASE_CREATE_READER(BYTE, VarcharToByte)
+          CASE_CREATE_READER(SHORT, VarcharToShort)
+          CASE_CREATE_READER(INT, VarcharToInt)
+          CASE_CREATE_READER(LONG, VarcharToLong)
+          CASE_CREATE_READER(FLOAT, VarcharToFloat)
+          CASE_CREATE_READER(DOUBLE, VarcharToDouble)
+          CASE_CREATE_READER(STRING, VarcharToString)
+          CASE_CREATE_READER(CHAR, VarcharToChar)
+          case VARCHAR:
+            // SchemaEvolution accepts varchar(n) -> varchar(m) and sets needConvert
+            // when n != m; the value must be re-truncated to the new maximum length.
+            CREATE_READER(StringVariantConvertColumnReader)
+          case BINARY:
+          case TIMESTAMP:
+          case LIST:
+          case MAP:
+          case STRUCT:
+          case UNION:
+          case DATE:
+          case TIMESTAMP_INSTANT:
+          case DECIMAL:
+            CASE_EXCEPTION
+        }
+      }
case BINARY:
case TIMESTAMP:
case LIST:
@@ -1095,21 +1312,9 @@ namespace orc {
case STRUCT:
case UNION:
case DATE:
- case VARCHAR:
- case CHAR:
case TIMESTAMP_INSTANT:
CASE_EXCEPTION
}
}
-#undef DEFINE_NUMERIC_CONVERT_READER
-#undef DEFINE_NUMERIC_CONVERT_TO_STRING_VARINT_READER
-#undef DEFINE_NUMERIC_CONVERT_TO_DECIMAL_READER
-#undef DEFINE_NUMERIC_CONVERT_TO_TIMESTAMP_READER
-#undef DEFINE_DECIMAL_CONVERT_TO_NUMERIC_READER
-#undef DEFINE_DECIMAL_CONVERT_TO_DECIMAL_READER
-#undef CASE_CREATE_FROM_DECIMAL_READER
-#undef CASE_CREATE_READER
-#undef CASE_EXCEPTION
-
} // namespace orc
diff --git a/c++/src/SchemaEvolution.cc b/c++/src/SchemaEvolution.cc
index 4099818ff..ab4007309 100644
--- a/c++/src/SchemaEvolution.cc
+++ b/c++/src/SchemaEvolution.cc
@@ -80,7 +80,7 @@ namespace orc {
if (readType.getKind() == fileType.getKind()) {
ret.isValid = true;
if (fileType.getKind() == CHAR || fileType.getKind() == VARCHAR) {
- ret.isValid = readType.getMaximumLength() ==
fileType.getMaximumLength();
+ ret.needConvert = readType.getMaximumLength() !=
fileType.getMaximumLength();
} else if (fileType.getKind() == DECIMAL) {
ret.needConvert = readType.getPrecision() != fileType.getPrecision() ||
readType.getScale() != fileType.getScale();
@@ -105,7 +105,10 @@ namespace orc {
}
case STRING:
case CHAR:
- case VARCHAR:
+ case VARCHAR: {
+ ret.isValid = ret.needConvert = isStringVariant(readType) ||
isNumeric(readType);
+ break;
+ }
case TIMESTAMP:
case TIMESTAMP_INSTANT:
case DATE:
diff --git a/c++/src/Utils.hh b/c++/src/Utils.hh
index 4a609788f..851d0af15 100644
--- a/c++/src/Utils.hh
+++ b/c++/src/Utils.hh
@@ -21,6 +21,7 @@
#include <atomic>
#include <chrono>
+#include <stdexcept>
namespace orc {
@@ -70,6 +71,75 @@ namespace orc {
#define SCOPED_MINUS_STOPWATCH(METRICS_PTR, LATENCY_VAR)
#endif
+  /**
+   * Helpers for counting and truncating UTF-8 encoded byte strings.
+   * None of these functions validate their input; they give correct results
+   * only for data that is already valid UTF-8.
+   */
+  struct Utf8Utils {
+    /**
+     * Counts the UTF-8 characters in data[0, length), i.e. the number of bytes
+     * that are not continuation bytes.
+     */
+    static uint64_t charLength(const char* data, uint64_t length) {
+      uint64_t chars = 0;
+      for (uint64_t i = 0; i < length; i++) {
+        if (isUtfStartByte(data[i])) {
+          chars++;
+        }
+      }
+      return chars;
+    }
+
+    /**
+     * Return the number of bytes required to read at most maxCharLength
+     * characters in full from a utf-8 encoded byte array provided
+     * by data. This does not validate utf-8 data, but
+     * operates correctly on already valid utf-8 data.
+     *
+     * @param maxCharLength number of characters required
+     * @param data the bytes of UTF-8
+     * @param length the length of data to truncate
+     * @return a byte count in [0, length] that never splits a character
+     */
+    static uint64_t truncateBytesTo(uint64_t maxCharLength, const char* data, uint64_t length) {
+      // Every character is at least one byte, so a short enough input always fits.
+      if (length <= maxCharLength) {
+        return length;
+      }
+      uint64_t chars = 0;
+      for (uint64_t i = 0; i < length; i++) {
+        if (isUtfStartByte(data[i])) {
+          chars++;
+        }
+        // data[i] starts character number maxCharLength + 1: cut just before it.
+        if (chars > maxCharLength) {
+          return i;
+        }
+      }
+      // everything fits
+      return length;
+    }
+
+    /**
+     * Checks if b is the first byte of a UTF-8 character, i.e. it is not a
+     * continuation byte of the form 10xxxxxx.
+     */
+    inline static bool isUtfStartByte(char b) {
+      return (b & 0xC0) != 0x80;
+    }
+
+    /**
+     * Find the start of the last character that ends in the current string.
+     * Scans backwards from `until` down to, and including, `from`.
+     * @param text the bytes of the utf-8
+     * @param from the first byte location
+     * @param until the last byte location
+     * @return the index of the last character
+     * @throws std::logic_error if no character start byte exists in [from, until]
+     */
+    static uint64_t findLastCharacter(const char* text, uint64_t from, uint64_t until) {
+      if (until >= from) {
+        // Test for `from` before decrementing: with an unsigned index a
+        // `posn >= from` loop condition never fails when from == 0, and
+        // decrementing past 0 would wrap to UINT64_MAX and read out of bounds.
+        for (uint64_t posn = until;; --posn) {
+          if (isUtfStartByte(text[posn])) {
+            return posn;
+          }
+          if (posn == from) {
+            break;
+          }
+        }
+      }
+      /* beginning of a valid char not found */
+      throw std::logic_error("Could not truncate string, beginning of a valid char not found");
+    }
+  };
+
} // namespace orc
#endif
diff --git a/c++/test/TestConvertColumnReader.cc
b/c++/test/TestConvertColumnReader.cc
index 83798289d..f9f7ac61d 100644
--- a/c++/test/TestConvertColumnReader.cc
+++ b/c++/test/TestConvertColumnReader.cc
@@ -815,4 +815,165 @@ namespace orc {
}
}
+  TEST(ConvertColumnReader, TestConvertStringVariantToNumeric) {
+    constexpr int DEFAULT_MEM_STREAM_SIZE = 10 * 1024 * 1024;
+    constexpr int TEST_CASES = 6;
+    MemoryOutputStream memStream(DEFAULT_MEM_STREAM_SIZE);
+    std::unique_ptr<Type> fileType(
+        Type::buildTypeFromString("struct<c1:char(25),c2:varchar(25),c3:string>"));
+    std::shared_ptr<Type> readType(
+        Type::buildTypeFromString("struct<c1:boolean,c2:int,c3:float>"));
+    WriterOptions options;
+    auto writer = createWriter(*fileType, &memStream, options);
+    auto batch = writer->createRowBatch(TEST_CASES);
+    auto structBatch = dynamic_cast<StructVectorBatch*>(batch.get());
+    auto& c1 = dynamic_cast<StringVectorBatch&>(*structBatch->fields[0]);
+    auto& c2 = dynamic_cast<StringVectorBatch&>(*structBatch->fields[1]);
+    auto& c3 = dynamic_cast<StringVectorBatch&>(*structBatch->fields[2]);
+
+    // Row 0 is null, rows 1-3 convert cleanly, row 4 does not fit the target
+    // type and row 5 is not a number at all.
+    std::vector<std::string> raw1{"",          "123456", "0", "-1234567890",
+                                  "999999999999999999999999", "error"};
+    std::vector<std::string> raw2 = raw1;
+    std::vector<std::string> raw3{"",
+                                  "123456",
+                                  "-0.0",
+                                  "-123456789.0123",
+                                  "1000000000000000000000000000000000000000",
+                                  "error"};
+
+    // Points `column` at the strings in `raw` (which must outlive the write)
+    // and leaves row 0 null.
+    auto fillColumn = [&](StringVectorBatch& column, std::vector<std::string>& raw) {
+      column.notNull[0] = false;
+      for (int row = 1; row < TEST_CASES; row++) {
+        column.data[row] = raw[row].data();
+        column.length[row] = raw[row].length();
+        column.notNull[row] = true;
+      }
+      column.numElements = TEST_CASES;
+      column.hasNulls = true;
+    };
+    fillColumn(c1, raw1);
+    fillColumn(c2, raw2);
+    fillColumn(c3, raw3);
+    structBatch->numElements = TEST_CASES;
+    structBatch->hasNulls = true;
+    writer->add(*batch);
+    writer->close();
+
+    auto inStream = std::make_unique<MemoryInputStream>(memStream.getData(), memStream.getLength());
+    auto pool = getDefaultPool();
+    auto reader = createReader(*pool, std::move(inStream));
+    RowReaderOptions rowReaderOptions;
+    rowReaderOptions.setUseTightNumericVector(true);
+    rowReaderOptions.setReadType(readType);
+    auto rowReader = reader->createRowReader(rowReaderOptions);
+    auto readBatch = rowReader->createRowBatch(TEST_CASES);
+    EXPECT_EQ(true, rowReader->next(*readBatch));
+
+    auto& readStructBatch = dynamic_cast<StructVectorBatch&>(*readBatch);
+    auto& readC1 = dynamic_cast<BooleanVectorBatch&>(*readStructBatch.fields[0]);
+    auto& readC2 = dynamic_cast<IntVectorBatch&>(*readStructBatch.fields[1]);
+    auto& readC3 = dynamic_cast<FloatVectorBatch&>(*readStructBatch.fields[2]);
+
+    // Written nulls stay null, rows 1-3 are present, failed conversions become null.
+    for (int row = 0; row < TEST_CASES; row++) {
+      const bool converted = row >= 1 && row <= 3;
+      if (converted) {
+        EXPECT_TRUE(readC1.notNull[row]);
+        EXPECT_TRUE(readC2.notNull[row]);
+        EXPECT_TRUE(readC3.notNull[row]);
+      } else if (row == 0) {
+        EXPECT_FALSE(readC1.notNull[row]);
+        EXPECT_FALSE(readC2.notNull[row]);
+        EXPECT_FALSE(readC3.notNull[row]);
+      } else {
+        EXPECT_FALSE(readC1.notNull[row]) << row;
+        EXPECT_FALSE(readC2.notNull[row]) << row;
+        EXPECT_FALSE(readC3.notNull[row]) << row;
+      }
+    }
+
+    EXPECT_EQ(readC1.data[1], 1);
+    EXPECT_EQ(readC2.data[1], 123456);
+    EXPECT_FLOAT_EQ(readC3.data[1], 123456);
+
+    EXPECT_EQ(readC1.data[2], 0);
+    EXPECT_EQ(readC2.data[2], 0);
+    EXPECT_FLOAT_EQ(readC3.data[2], -0.0);
+
+    EXPECT_EQ(readC1.data[3], 1);
+    EXPECT_EQ(readC2.data[3], -1234567890);
+    EXPECT_FLOAT_EQ(readC3.data[3], -123456789.0123);
+  }
+
+  // Converts char(5) -> string, varchar(5) -> char(4) and string -> varchar(4),
+  // covering truncation (row 1), padding (rows 2-3) and nulls (row 0).
+  TEST(ConvertColumnReader, TestConvertStringVariant) {
+    constexpr int DEFAULT_MEM_STREAM_SIZE = 10 * 1024 * 1024;
+    constexpr int TEST_CASES = 4;
+    MemoryOutputStream memStream(DEFAULT_MEM_STREAM_SIZE);
+    std::unique_ptr<Type> fileType(
+        Type::buildTypeFromString("struct<c1:char(5),c2:varchar(5),c3:string>"));
+    std::shared_ptr<Type> readType(
+        Type::buildTypeFromString("struct<c1:string,c2:char(4),c3:varchar(4)>"));
+    WriterOptions options;
+    auto writer = createWriter(*fileType, &memStream, options);
+    auto batch = writer->createRowBatch(TEST_CASES);
+    auto structBatch = dynamic_cast<StructVectorBatch*>(batch.get());
+    auto& c1 = dynamic_cast<StringVectorBatch&>(*structBatch->fields[0]);
+    auto& c2 = dynamic_cast<StringVectorBatch&>(*structBatch->fields[1]);
+    auto& c3 = dynamic_cast<StringVectorBatch&>(*structBatch->fields[2]);
+
+    std::vector<std::string> raw1{"", "12345", "1", "1234"};
+    std::vector<std::string> raw2{"", "12345", "1", "1234"};
+    std::vector<std::string> raw3{"", "12345", "1", "1234"};
+
+    c1.notNull[0] = c2.notNull[0] = c3.notNull[0] = false;
+    for (int i = 1; i < TEST_CASES; i++) {
+      c1.data[i] = raw1[i].data();
+      c1.length[i] = raw1[i].length();
+      c1.notNull[i] = true;
+
+      c2.data[i] = raw2[i].data();
+      c2.length[i] = raw2[i].length();
+      c2.notNull[i] = true;
+
+      c3.data[i] = raw3[i].data();
+      c3.length[i] = raw3[i].length();
+      c3.notNull[i] = true;
+    }
+    structBatch->numElements = c1.numElements = c2.numElements = c3.numElements = TEST_CASES;
+    structBatch->hasNulls = c1.hasNulls = c2.hasNulls = c3.hasNulls = true;
+    writer->add(*batch);
+    writer->close();
+
+    auto inStream = std::make_unique<MemoryInputStream>(memStream.getData(), memStream.getLength());
+    auto pool = getDefaultPool();
+    auto reader = createReader(*pool, std::move(inStream));
+    RowReaderOptions rowReaderOptions;
+    rowReaderOptions.setUseTightNumericVector(true);
+    rowReaderOptions.setReadType(readType);
+    auto rowReader = reader->createRowReader(rowReaderOptions);
+    auto readBatch = rowReader->createRowBatch(TEST_CASES);
+    EXPECT_EQ(true, rowReader->next(*readBatch));
+
+    auto& readStructBatch = dynamic_cast<StructVectorBatch&>(*readBatch);
+    auto& readC1 = dynamic_cast<StringVectorBatch&>(*readStructBatch.fields[0]);
+    auto& readC2 = dynamic_cast<StringVectorBatch&>(*readStructBatch.fields[1]);
+    auto& readC3 = dynamic_cast<StringVectorBatch&>(*readStructBatch.fields[2]);
+
+    EXPECT_FALSE(readC1.notNull[0]);
+    EXPECT_FALSE(readC2.notNull[0]);
+    EXPECT_FALSE(readC3.notNull[0]);
+
+    for (int i = 1; i < TEST_CASES; i++) {
+      EXPECT_TRUE(readC1.notNull[i]);
+      EXPECT_TRUE(readC2.notNull[i]);
+      EXPECT_TRUE(readC3.notNull[i]);
+    }
+
+    EXPECT_EQ(std::string(readC1.data[1], readC1.length[1]), "12345");
+    EXPECT_EQ(std::string(readC2.data[1], readC2.length[1]), "1234");
+    EXPECT_EQ(std::string(readC3.data[1], readC3.length[1]), "1234");
+
+    // char(5) -> string keeps the 4 pad spaces written for "1"; varchar(5) -> char(4)
+    // pads "1" up to 4 characters; string -> varchar(4) never pads.
+    EXPECT_EQ(std::string(readC1.data[2], readC1.length[2]), "1" + std::string(4, ' '));
+    EXPECT_EQ(std::string(readC2.data[2], readC2.length[2]), "1" + std::string(3, ' '));
+    EXPECT_EQ(std::string(readC3.data[2], readC3.length[2]), "1");
+
+    EXPECT_EQ(std::string(readC1.data[3], readC1.length[3]), "1234 ");
+    EXPECT_EQ(std::string(readC2.data[3], readC2.length[3]), "1234");
+    EXPECT_EQ(std::string(readC3.data[3], readC3.length[3]), "1234");
+  }
+
} // namespace orc
diff --git a/c++/test/TestSchemaEvolution.cc b/c++/test/TestSchemaEvolution.cc
index c52ba009f..12001fca6 100644
--- a/c++/test/TestSchemaEvolution.cc
+++ b/c++/test/TestSchemaEvolution.cc
@@ -148,6 +148,22 @@ namespace orc {
}
}
+ // conversion from string variant to numeric
+ for (size_t i = 7; i <= 11; i++) {
+ for (size_t j = 0; j <= 6; j++) {
+ canConvert[i][j] = true;
+ needConvert[i][j] = true;
+ }
+ }
+
+ // conversion from string variant to string variant
+ for (size_t i = 7; i <= 11; i++) {
+ for (size_t j = 7; j <= 11; j++) {
+ canConvert[i][j] = true;
+ needConvert[i][j] = (i != j);
+ }
+ }
+
for (size_t i = 0; i < typesSize; i++) {
for (size_t j = 0; j < typesSize; j++) {
testConvertReader(types[i], types[j], canConvert[i][j],
needConvert[i][j]);