This is an automated email from the ASF dual-hosted git repository.

ffacs pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/orc.git


The following commit(s) were added to refs/heads/main by this push:
     new 5d8e70280 ORC-1389: [C++] Support schema evolution from string group to numeric/string group
5d8e70280 is described below

commit 5d8e702809badd9de2f057834d601916b38d3f08
Author: ffacs <[email protected]>
AuthorDate: Thu May 16 22:37:55 2024 +0800

    ORC-1389: [C++] Support schema evolution from string group to numeric/string group
    
    ### What changes were proposed in this pull request?
    Support conversion from
    
    {string, char, varchar}
    to
    {boolean, byte, short, int, long, float, double, string, char, varchar}
    
    ### Why are the changes needed?
    To support schema evolution from string-group types on the C++ side.
    
    ### How was this patch tested?
    Unit tests passed (new cases in TestConvertColumnReader.cc and TestSchemaEvolution.cc).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    NO
    
    Closes #1931 from ffacs/ORC-1389.
    
    Authored-by: ffacs <[email protected]>
    Signed-off-by: ffacs <[email protected]>
---
 c++/src/ColumnWriter.cc             |  70 +---------
 c++/src/ConvertColumnReader.cc      | 259 ++++++++++++++++++++++++++++++++----
 c++/src/SchemaEvolution.cc          |   7 +-
 c++/src/Utils.hh                    |  70 ++++++++++
 c++/test/TestConvertColumnReader.cc | 161 ++++++++++++++++++++++
 c++/test/TestSchemaEvolution.cc     |  16 +++
 6 files changed, 485 insertions(+), 98 deletions(-)

diff --git a/c++/src/ColumnWriter.cc b/c++/src/ColumnWriter.cc
index 86e30ce90..05ffd3a2d 100644
--- a/c++/src/ColumnWriter.cc
+++ b/c++/src/ColumnWriter.cc
@@ -24,6 +24,7 @@
 #include "RLE.hh"
 #include "Statistics.hh"
 #include "Timezone.hh"
+#include "Utils.hh"
 
 namespace orc {
   StreamsFactory::~StreamsFactory() {
@@ -1356,75 +1357,6 @@ namespace orc {
     deleteDictStreams();
   }
 
-  struct Utf8Utils {
-    /**
-     * Counts how many utf-8 chars of the input data
-     */
-    static uint64_t charLength(const char* data, uint64_t length) {
-      uint64_t chars = 0;
-      for (uint64_t i = 0; i < length; i++) {
-        if (isUtfStartByte(data[i])) {
-          chars++;
-        }
-      }
-      return chars;
-    }
-
-    /**
-     * Return the number of bytes required to read at most maxCharLength
-     * characters in full from a utf-8 encoded byte array provided
-     * by data. This does not validate utf-8 data, but
-     * operates correctly on already valid utf-8 data.
-     *
-     * @param maxCharLength number of characters required
-     * @param data the bytes of UTF-8
-     * @param length the length of data to truncate
-     */
-    static uint64_t truncateBytesTo(uint64_t maxCharLength, const char* data, 
uint64_t length) {
-      uint64_t chars = 0;
-      if (length <= maxCharLength) {
-        return length;
-      }
-      for (uint64_t i = 0; i < length; i++) {
-        if (isUtfStartByte(data[i])) {
-          chars++;
-        }
-        if (chars > maxCharLength) {
-          return i;
-        }
-      }
-      // everything fits
-      return length;
-    }
-
-    /**
-     * Checks if b is the first byte of a UTF-8 character.
-     */
-    inline static bool isUtfStartByte(char b) {
-      return (b & 0xC0) != 0x80;
-    }
-
-    /**
-     * Find the start of the last character that ends in the current string.
-     * @param text the bytes of the utf-8
-     * @param from the first byte location
-     * @param until the last byte location
-     * @return the index of the last character
-     */
-    static uint64_t findLastCharacter(const char* text, uint64_t from, 
uint64_t until) {
-      uint64_t posn = until;
-      /* we don't expect characters more than 5 bytes */
-      while (posn >= from) {
-        if (isUtfStartByte(text[posn])) {
-          return posn;
-        }
-        posn -= 1;
-      }
-      /* beginning of a valid char not found */
-      throw std::logic_error("Could not truncate string, beginning of a valid 
char not found");
-    }
-  };
-
   class CharColumnWriter : public StringColumnWriter {
    public:
     CharColumnWriter(const Type& type, const StreamsFactory& factory, const 
WriterOptions& options)
diff --git a/c++/src/ConvertColumnReader.cc b/c++/src/ConvertColumnReader.cc
index 67ee6d6c4..a24b8cb05 100644
--- a/c++/src/ConvertColumnReader.cc
+++ b/c++/src/ConvertColumnReader.cc
@@ -17,6 +17,7 @@
  */
 
 #include "ConvertColumnReader.hh"
+#include "Utils.hh"
 
 namespace orc {
 
@@ -694,6 +695,112 @@ namespace orc {
     const int32_t scale_;
   };
 
+  /**
+   * Reads a string-group file column (STRING/CHAR/VARCHAR) as a numeric read
+   * type: BOOLEAN/BYTE/SHORT/INT/LONG are parsed with std::stoll, FLOAT with
+   * std::stof and DOUBLE with std::stod.
+   *
+   * A value converts only when the whole string is a number. Leading whitespace
+   * (skipped by std::sto*) and trailing spaces (e.g. CHAR padding) are
+   * tolerated; any other trailing text is not. Unparsable or out-of-range
+   * values are handed to handleOverflow(). BOOLEAN is true for any non-zero
+   * integer.
+   */
+  template <typename ReadTypeBatch, typename ReadType>
+  class StringVariantToNumericColumnReader : public ConvertColumnReader {
+   public:
+    StringVariantToNumericColumnReader(const Type& readType, const Type& fileType,
+                                       StripeStreams& stripe, bool throwOnOverflow)
+        : ConvertColumnReader(readType, fileType, stripe, throwOnOverflow) {}
+
+    void next(ColumnVectorBatch& rowBatch, uint64_t numValues, char* notNull) override {
+      ConvertColumnReader::next(rowBatch, numValues, notNull);
+
+      const auto& srcBatch = *SafeCastBatchTo<const StringVectorBatch*>(data.get());
+      auto& dstBatch = *SafeCastBatchTo<ReadTypeBatch*>(&rowBatch);
+      for (uint64_t i = 0; i < numValues; ++i) {
+        if (!rowBatch.hasNulls || rowBatch.notNull[i]) {
+          if constexpr (std::is_floating_point_v<ReadType>) {
+            convertToDouble(dstBatch, srcBatch, i);
+          } else {
+            convertToInteger(dstBatch, srcBatch, i);
+          }
+        }
+      }
+    }
+
+   private:
+    // True when s[pos, end) is empty or consists only of spaces. std::sto*
+    // stop at the first character they cannot parse and report its index, so
+    // this rejects inputs like "12abc" (or "1.5" for integers) while still
+    // accepting the blank padding of CHAR values.
+    static bool onlyTrailingBlanks(const std::string& s, size_t pos) {
+      while (pos < s.size() && s[pos] == ' ') {
+        ++pos;
+      }
+      return pos == s.size();
+    }
+
+    void convertToInteger(ReadTypeBatch& dstBatch, const StringVectorBatch& srcBatch,
+                          uint64_t idx) {
+      const std::string str(srcBatch.data[idx], static_cast<size_t>(srcBatch.length[idx]));
+      int64_t longValue = 0;
+      size_t parsed = 0;
+      try {
+        longValue = std::stoll(str, &parsed);
+      } catch (const std::logic_error&) {
+        // std::invalid_argument (not a number) or std::out_of_range (beyond int64).
+        handleOverflow<std::string, ReadType>(dstBatch, idx, throwOnOverflow);
+        return;
+      }
+      if (!onlyTrailingBlanks(str, parsed)) {
+        handleOverflow<std::string, ReadType>(dstBatch, idx, throwOnOverflow);
+        return;
+      }
+      if constexpr (std::is_same_v<ReadType, bool>) {
+        dstBatch.data[idx] = longValue == 0 ? 0 : 1;
+      } else {
+        if (!downCastToInteger(dstBatch.data[idx], longValue)) {
+          handleOverflow<std::string, ReadType>(dstBatch, idx, throwOnOverflow);
+        }
+      }
+    }
+
+    void convertToDouble(ReadTypeBatch& dstBatch, const StringVectorBatch& srcBatch, uint64_t idx) {
+      const std::string str(srcBatch.data[idx], static_cast<size_t>(srcBatch.length[idx]));
+      size_t parsed = 0;
+      try {
+        if constexpr (std::is_same_v<ReadType, float>) {
+          dstBatch.data[idx] = std::stof(str, &parsed);
+        } else {
+          dstBatch.data[idx] = std::stod(str, &parsed);
+        }
+      } catch (const std::logic_error&) {
+        // std::invalid_argument (not a number) or std::out_of_range (not representable).
+        handleOverflow<std::string, ReadType>(dstBatch, idx, throwOnOverflow);
+        return;
+      }
+      if (!onlyTrailingBlanks(str, parsed)) {
+        handleOverflow<std::string, ReadType>(dstBatch, idx, throwOnOverflow);
+      }
+    }
+  };
+
+  /**
+   * Converts between string-group types (STRING, CHAR, VARCHAR). Each non-null
+   * value is copied into strBuffer and adapted to the read type:
+   *  - STRING:     bytes copied unchanged;
+   *  - VARCHAR(n): truncated to at most n UTF-8 characters;
+   *  - CHAR(n):    truncated to n UTF-8 characters, then right-padded with
+   *                spaces when it has fewer than n characters.
+   */
+  class StringVariantConvertColumnReader : public ConvertToStringVariantColumnReader {
+   public:
+    StringVariantConvertColumnReader(const Type& readType, const Type& fileType,
+                                     StripeStreams& stripe, bool throwOnOverflow)
+        : ConvertToStringVariantColumnReader(readType, fileType, stripe, throwOnOverflow) {}
+
+    /**
+     * Fills strBuffer[0, numValues) for every non-null row of rowBatch.
+     * @return the total number of bytes stored in strBuffer
+     * @throws SchemaEvolutionError if the read type is not a string-group type
+     */
+    uint64_t convertToStrBuffer(ColumnVectorBatch& rowBatch, uint64_t numValues) override {
+      const auto kind = readType.getKind();
+      if (kind != STRING && kind != VARCHAR && kind != CHAR) {
+        throw SchemaEvolutionError("Invalid type for string variant conversion: " +
+                                   readType.toString());
+      }
+
+      uint64_t size = 0;
+      strBuffer.resize(numValues);
+      const auto& srcBatch = *SafeCastBatchTo<const StringVectorBatch*>(data.get());
+      const uint64_t maxLength = readType.getMaximumLength();
+      for (uint64_t i = 0; i < numValues; ++i) {
+        if (rowBatch.hasNulls && !rowBatch.notNull[i]) {
+          continue;
+        }
+        const char* src = srcBatch.data[i];
+        const uint64_t srcLength = srcBatch.length[i];
+        if (kind == STRING) {
+          strBuffer[i].assign(src, srcLength);
+        } else {
+          // Truncate on a character boundary so no multi-byte UTF-8 char is split.
+          strBuffer[i].assign(src, Utf8Utils::truncateBytesTo(maxLength, src, srcLength));
+          if (kind == CHAR) {
+            // The padding is exactly 1 byte (a space) per missing character.
+            const uint64_t chars = Utf8Utils::charLength(src, srcLength);
+            if (chars < maxLength) {
+              strBuffer[i].append(maxLength - chars, ' ');
+            }
+          }
+        }
+        size += strBuffer[i].size();
+      }
+      return size;
+    }
+  };
+
 #define DEFINE_NUMERIC_CONVERT_READER(FROM, TO, TYPE) \
   using FROM##To##TO##ColumnReader =                  \
       NumericConvertColumnReader<FROM##VectorBatch, TO##VectorBatch, TYPE>;
@@ -730,6 +837,12 @@ namespace orc {
   using Decimal64To##TO##ColumnReader = 
DecimalToStringVariantColumnReader<Decimal64VectorBatch>; \
   using Decimal128To##TO##ColumnReader = 
DecimalToStringVariantColumnReader<Decimal128VectorBatch>;
 
+// Declares FROM##To##TO##ColumnReader as the string-variant -> numeric reader for
+// the TO##VectorBatch / element TYPE pair. FROM only contributes to the alias name:
+// every string-group source is read through a StringVectorBatch.
+#define DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(FROM, TO, TYPE) \
+  using FROM##To##TO##ColumnReader = StringVariantToNumericColumnReader<TO##VectorBatch, TYPE>;
+
+// Declares FROM##To##TO##ColumnReader as the shared string-variant -> string-variant
+// reader; the target kind and length are taken from the read type at runtime.
+#define DEFINE_STRING_VARIANT_CONVERT_READER(FROM, TO) \
+  using FROM##To##TO##ColumnReader = StringVariantConvertColumnReader;
+
   DEFINE_NUMERIC_CONVERT_READER(Boolean, Byte, int8_t)
   DEFINE_NUMERIC_CONVERT_READER(Boolean, Short, int16_t)
   DEFINE_NUMERIC_CONVERT_READER(Boolean, Int, int32_t)
@@ -834,8 +947,41 @@ namespace orc {
   DEFINE_DECIMAL_CONVERT_TO_STRING_VARINT_READER(Char)
   DEFINE_DECIMAL_CONVERT_TO_STRING_VARINT_READER(Varchar)
 
+  // String variant to numeric.
+  // String, Char and Varchar sources map to the same
+  // StringVariantToNumericColumnReader instantiation per target type; the
+  // source kind only appears in the alias name.
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(String, Boolean, bool)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(String, Byte, int8_t)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(String, Short, int16_t)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(String, Int, int32_t)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(String, Long, int64_t)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(String, Float, float)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(String, Double, double)
+
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Char, Boolean, bool)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Char, Byte, int8_t)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Char, Short, int16_t)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Char, Int, int32_t)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Char, Long, int64_t)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Char, Float, float)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Char, Double, double)
+
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Varchar, Boolean, bool)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Varchar, Byte, int8_t)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Varchar, Short, int16_t)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Varchar, Int, int32_t)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Varchar, Long, int64_t)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Varchar, Float, float)
+  DEFINE_STRING_VARIANT_CONVERT_TO_NUMERIC_READER(Varchar, Double, double)
+
+  // String variant to string variant.
+  // All pairs share StringVariantConvertColumnReader, which adapts each value to
+  // the read type's kind and, for CHAR/VARCHAR, its maximum length.
+  DEFINE_STRING_VARIANT_CONVERT_READER(String, Char)
+  DEFINE_STRING_VARIANT_CONVERT_READER(String, Varchar)
+  DEFINE_STRING_VARIANT_CONVERT_READER(Char, String)
+  DEFINE_STRING_VARIANT_CONVERT_READER(Char, Varchar)
+  DEFINE_STRING_VARIANT_CONVERT_READER(Varchar, String)
+  DEFINE_STRING_VARIANT_CONVERT_READER(Varchar, Char)
+
 #define CREATE_READER(NAME) \
-  return std::make_unique<NAME>(_readType, fileType, stripe, throwOnOverflow);
+  return std::make_unique<NAME>(readType, fileType, stripe, throwOnOverflow);
 
 #define CASE_CREATE_READER(TYPE, CONVERT) \
   case TYPE:                              \
@@ -858,7 +1004,7 @@ namespace orc {
 
 #define CASE_CREATE_DECIMAL_READER(FROM)            \
   case DECIMAL: {                                   \
-    if (isDecimal64(_readType)) {                   \
+    if (isDecimal64(readType)) {                    \
       CREATE_READER(FROM##ToDecimal64ColumnReader)  \
     } else {                                        \
       CREATE_READER(FROM##ToDecimal128ColumnReader) \
@@ -868,7 +1014,7 @@ namespace orc {
 #define CASE_EXCEPTION                                                         
        \
   default:                                                                     
        \
     throw SchemaEvolutionError("Cannot convert from " + fileType.toString() + 
" to " + \
-                               _readType.toString());
+                               readType.toString());
 
   std::unique_ptr<ColumnReader> buildConvertReader(const Type& fileType, 
StripeStreams& stripe,
                                                    bool useTightNumericVector,
@@ -878,11 +1024,11 @@ namespace orc {
           "SchemaEvolution only support tight vector, please create 
ColumnVectorBatch with "
           "option useTightNumericVector");
     }
-    const auto& _readType = 
*stripe.getSchemaEvolution()->getReadType(fileType);
+    const auto& readType = *stripe.getSchemaEvolution()->getReadType(fileType);
 
     switch (fileType.getKind()) {
       case BOOLEAN: {
-        switch (_readType.getKind()) {
+        switch (readType.getKind()) {
           CASE_CREATE_READER(BYTE, BooleanToByte)
           CASE_CREATE_READER(SHORT, BooleanToShort)
           CASE_CREATE_READER(INT, BooleanToInt)
@@ -906,7 +1052,7 @@ namespace orc {
         }
       }
       case BYTE: {
-        switch (_readType.getKind()) {
+        switch (readType.getKind()) {
           CASE_CREATE_READER(BOOLEAN, ByteToBoolean)
           CASE_CREATE_READER(SHORT, ByteToShort)
           CASE_CREATE_READER(INT, ByteToInt)
@@ -930,7 +1076,7 @@ namespace orc {
         }
       }
       case SHORT: {
-        switch (_readType.getKind()) {
+        switch (readType.getKind()) {
           CASE_CREATE_READER(BOOLEAN, ShortToBoolean)
           CASE_CREATE_READER(BYTE, ShortToByte)
           CASE_CREATE_READER(INT, ShortToInt)
@@ -954,7 +1100,7 @@ namespace orc {
         }
       }
       case INT: {
-        switch (_readType.getKind()) {
+        switch (readType.getKind()) {
           CASE_CREATE_READER(BOOLEAN, IntToBoolean)
           CASE_CREATE_READER(BYTE, IntToByte)
           CASE_CREATE_READER(SHORT, IntToShort)
@@ -978,7 +1124,7 @@ namespace orc {
         }
       }
       case LONG: {
-        switch (_readType.getKind()) {
+        switch (readType.getKind()) {
           CASE_CREATE_READER(BOOLEAN, LongToBoolean)
           CASE_CREATE_READER(BYTE, LongToByte)
           CASE_CREATE_READER(SHORT, LongToShort)
@@ -1002,7 +1148,7 @@ namespace orc {
         }
       }
       case FLOAT: {
-        switch (_readType.getKind()) {
+        switch (readType.getKind()) {
           CASE_CREATE_READER(BOOLEAN, FloatToBoolean)
           CASE_CREATE_READER(BYTE, FloatToByte)
           CASE_CREATE_READER(SHORT, FloatToShort)
@@ -1026,7 +1172,7 @@ namespace orc {
         }
       }
       case DOUBLE: {
-        switch (_readType.getKind()) {
+        switch (readType.getKind()) {
           CASE_CREATE_READER(BOOLEAN, DoubleToBoolean)
           CASE_CREATE_READER(BYTE, DoubleToByte)
           CASE_CREATE_READER(SHORT, DoubleToShort)
@@ -1050,7 +1196,7 @@ namespace orc {
         }
       }
       case DECIMAL: {
-        switch (_readType.getKind()) {
+        switch (readType.getKind()) {
           CASE_CREATE_FROM_DECIMAL_READER(BOOLEAN, Boolean)
           CASE_CREATE_FROM_DECIMAL_READER(BYTE, Byte)
           CASE_CREATE_FROM_DECIMAL_READER(SHORT, Short)
@@ -1065,13 +1211,13 @@ namespace orc {
           CASE_CREATE_FROM_DECIMAL_READER(TIMESTAMP_INSTANT, Timestamp)
           case DECIMAL: {
             if (isDecimal64(fileType)) {
-              if (isDecimal64(_readType)) {
+              if (isDecimal64(readType)) {
                 CREATE_READER(Decimal64ToDecimal64ColumnReader)
               } else {
                 CREATE_READER(Decimal64ToDecimal128ColumnReader)
               }
             } else {
-              if (isDecimal64(_readType)) {
+              if (isDecimal64(readType)) {
                 CREATE_READER(Decimal128ToDecimal64ColumnReader)
               } else {
                 CREATE_READER(Decimal128ToDecimal128ColumnReader)
@@ -1087,7 +1233,78 @@ namespace orc {
             CASE_EXCEPTION
         }
       }
-      case STRING:
+      case STRING: {
+        // File column is STRING: numeric targets parse the text; CHAR/VARCHAR
+        // targets truncate (and for CHAR, pad) to the read type's length.
+        switch (readType.getKind()) {
+          CASE_CREATE_READER(BOOLEAN, StringToBoolean)
+          CASE_CREATE_READER(BYTE, StringToByte)
+          CASE_CREATE_READER(SHORT, StringToShort)
+          CASE_CREATE_READER(INT, StringToInt)
+          CASE_CREATE_READER(LONG, StringToLong)
+          CASE_CREATE_READER(FLOAT, StringToFloat)
+          CASE_CREATE_READER(DOUBLE, StringToDouble)
+          CASE_CREATE_READER(CHAR, StringToChar)
+          CASE_CREATE_READER(VARCHAR, StringToVarchar)
+          // NOTE(review): STRING -> STRING is rejected here; presumably same-kind
+          // STRING reads never need conversion and so never reach this switch —
+          // confirm against SchemaEvolution.
+          case STRING:
+          case BINARY:
+          case TIMESTAMP:
+          case LIST:
+          case MAP:
+          case STRUCT:
+          case UNION:
+          case DATE:
+          case TIMESTAMP_INSTANT:
+          case DECIMAL:
+            CASE_EXCEPTION
+        }
+      }
+      case CHAR: {
+        // File column is CHAR(n): numeric targets parse the (space padded) text;
+        // string-group targets are truncated/padded to the read type's length.
+        switch (readType.getKind()) {
+          CASE_CREATE_READER(BOOLEAN, CharToBoolean)
+          CASE_CREATE_READER(BYTE, CharToByte)
+          CASE_CREATE_READER(SHORT, CharToShort)
+          CASE_CREATE_READER(INT, CharToInt)
+          CASE_CREATE_READER(LONG, CharToLong)
+          CASE_CREATE_READER(FLOAT, CharToFloat)
+          CASE_CREATE_READER(DOUBLE, CharToDouble)
+          CASE_CREATE_READER(STRING, CharToString)
+          CASE_CREATE_READER(VARCHAR, CharToVarchar)
+          case CHAR:
+            // CHAR(n) read as CHAR(m), m != n: schema evolution treats this as a
+            // valid read that needs conversion, so it must not be rejected below.
+            // The string-variant reader truncates/pads to the read length.
+            CREATE_READER(StringVariantConvertColumnReader)
+          case BINARY:
+          case TIMESTAMP:
+          case LIST:
+          case MAP:
+          case STRUCT:
+          case UNION:
+          case DATE:
+          case TIMESTAMP_INSTANT:
+          case DECIMAL:
+            CASE_EXCEPTION
+        }
+      }
+      case VARCHAR: {
+        // File column is VARCHAR(n): numeric targets parse the text; string-group
+        // targets are truncated (and for CHAR, padded) to the read type's length.
+        switch (readType.getKind()) {
+          CASE_CREATE_READER(BOOLEAN, VarcharToBoolean)
+          CASE_CREATE_READER(BYTE, VarcharToByte)
+          CASE_CREATE_READER(SHORT, VarcharToShort)
+          CASE_CREATE_READER(INT, VarcharToInt)
+          CASE_CREATE_READER(LONG, VarcharToLong)
+          CASE_CREATE_READER(FLOAT, VarcharToFloat)
+          CASE_CREATE_READER(DOUBLE, VarcharToDouble)
+          CASE_CREATE_READER(STRING, VarcharToString)
+          CASE_CREATE_READER(CHAR, VarcharToChar)
+          case VARCHAR:
+            // VARCHAR(n) read as VARCHAR(m), m != n: schema evolution treats this
+            // as a valid read that needs conversion, so it must not be rejected
+            // below. The string-variant reader truncates to the read length.
+            CREATE_READER(StringVariantConvertColumnReader)
+          case BINARY:
+          case TIMESTAMP:
+          case LIST:
+          case MAP:
+          case STRUCT:
+          case UNION:
+          case DATE:
+          case TIMESTAMP_INSTANT:
+          case DECIMAL:
+            CASE_EXCEPTION
+        }
+      }
       case BINARY:
       case TIMESTAMP:
       case LIST:
@@ -1095,21 +1312,9 @@ namespace orc {
       case STRUCT:
       case UNION:
       case DATE:
-      case VARCHAR:
-      case CHAR:
       case TIMESTAMP_INSTANT:
         CASE_EXCEPTION
     }
   }
 
-#undef DEFINE_NUMERIC_CONVERT_READER
-#undef DEFINE_NUMERIC_CONVERT_TO_STRING_VARINT_READER
-#undef DEFINE_NUMERIC_CONVERT_TO_DECIMAL_READER
-#undef DEFINE_NUMERIC_CONVERT_TO_TIMESTAMP_READER
-#undef DEFINE_DECIMAL_CONVERT_TO_NUMERIC_READER
-#undef DEFINE_DECIMAL_CONVERT_TO_DECIMAL_READER
-#undef CASE_CREATE_FROM_DECIMAL_READER
-#undef CASE_CREATE_READER
-#undef CASE_EXCEPTION
-
 }  // namespace orc
diff --git a/c++/src/SchemaEvolution.cc b/c++/src/SchemaEvolution.cc
index 4099818ff..ab4007309 100644
--- a/c++/src/SchemaEvolution.cc
+++ b/c++/src/SchemaEvolution.cc
@@ -80,7 +80,7 @@ namespace orc {
     if (readType.getKind() == fileType.getKind()) {
       ret.isValid = true;
       if (fileType.getKind() == CHAR || fileType.getKind() == VARCHAR) {
-        ret.isValid = readType.getMaximumLength() == 
fileType.getMaximumLength();
+        ret.needConvert = readType.getMaximumLength() != 
fileType.getMaximumLength();
       } else if (fileType.getKind() == DECIMAL) {
         ret.needConvert = readType.getPrecision() != fileType.getPrecision() ||
                           readType.getScale() != fileType.getScale();
@@ -105,7 +105,10 @@ namespace orc {
         }
         case STRING:
         case CHAR:
-        case VARCHAR:
+        case VARCHAR: {
+          ret.isValid = ret.needConvert = isStringVariant(readType) || 
isNumeric(readType);
+          break;
+        }
         case TIMESTAMP:
         case TIMESTAMP_INSTANT:
         case DATE:
diff --git a/c++/src/Utils.hh b/c++/src/Utils.hh
index 4a609788f..851d0af15 100644
--- a/c++/src/Utils.hh
+++ b/c++/src/Utils.hh
@@ -21,6 +21,7 @@
 
 #include <atomic>
 #include <chrono>
+#include <stdexcept>
 
 namespace orc {
 
@@ -70,6 +71,75 @@ namespace orc {
 #define SCOPED_MINUS_STOPWATCH(METRICS_PTR, LATENCY_VAR)
 #endif
 
+  /**
+   * Helpers for UTF-8 encoded byte ranges. They do not validate the encoding;
+   * they only distinguish lead bytes from continuation bytes (10xxxxxx).
+   */
+  struct Utf8Utils {
+    /**
+     * Counts the UTF-8 characters in data[0, length) by counting the bytes
+     * that are not continuation bytes.
+     */
+    static uint64_t charLength(const char* data, uint64_t length) {
+      uint64_t chars = 0;
+      for (uint64_t i = 0; i < length; i++) {
+        if (isUtfStartByte(data[i])) {
+          chars++;
+        }
+      }
+      return chars;
+    }
+
+    /**
+     * Return the number of bytes required to read at most maxCharLength
+     * characters in full from a utf-8 encoded byte array provided
+     * by data. This does not validate utf-8 data, but
+     * operates correctly on already valid utf-8 data.
+     *
+     * @param maxCharLength number of characters required
+     * @param data the bytes of UTF-8
+     * @param length the length of data to truncate
+     */
+    static uint64_t truncateBytesTo(uint64_t maxCharLength, const char* data, uint64_t length) {
+      uint64_t chars = 0;
+      // Each character takes at least one byte, so a short input always fits.
+      if (length <= maxCharLength) {
+        return length;
+      }
+      for (uint64_t i = 0; i < length; i++) {
+        if (isUtfStartByte(data[i])) {
+          chars++;
+        }
+        // data[i] starts character number maxCharLength + 1: cut right before it.
+        if (chars > maxCharLength) {
+          return i;
+        }
+      }
+      // everything fits
+      return length;
+    }
+
+    /**
+     * Checks if b is the first byte of a UTF-8 character.
+     */
+    inline static bool isUtfStartByte(char b) {
+      return (b & 0xC0) != 0x80;
+    }
+
+    /**
+     * Find the start of the last character that ends in the current string.
+     * @param text the bytes of the utf-8
+     * @param from the first byte location
+     * @param until the last byte location
+     * @return the index of the last character
+     * @throws std::logic_error if no lead byte exists in text[from, until]
+     */
+    static uint64_t findLastCharacter(const char* text, uint64_t from, uint64_t until) {
+      uint64_t posn = until;
+      while (posn >= from) {
+        if (isUtfStartByte(text[posn])) {
+          return posn;
+        }
+        // Stop before stepping below `from`: with from == 0 the unsigned index
+        // would otherwise wrap to UINT64_MAX and read out of bounds.
+        if (posn == from) {
+          break;
+        }
+        posn -= 1;
+      }
+      /* beginning of a valid char not found */
+      throw std::logic_error("Could not truncate string, beginning of a valid char not found");
+    }
+  };
+
 }  // namespace orc
 
 #endif
diff --git a/c++/test/TestConvertColumnReader.cc 
b/c++/test/TestConvertColumnReader.cc
index 83798289d..f9f7ac61d 100644
--- a/c++/test/TestConvertColumnReader.cc
+++ b/c++/test/TestConvertColumnReader.cc
@@ -815,4 +815,165 @@ namespace orc {
     }
   }
 
+  TEST(ConvertColumnReader, TestConvertStringVariantToNumeric) {
+    // Writes char(25)/varchar(25)/string columns and reads them back as
+    // boolean/int/float. Row 0 is written as null; rows 4 (out of range) and
+    // 5 (not a number) must also come back as null.
+    constexpr int DEFAULT_MEM_STREAM_SIZE = 10 * 1024 * 1024;
+    constexpr int TEST_CASES = 6;
+    MemoryOutputStream memStream(DEFAULT_MEM_STREAM_SIZE);
+    std::unique_ptr<Type> fileType(
+        Type::buildTypeFromString("struct<c1:char(25),c2:varchar(25),c3:string>"));
+    std::shared_ptr<Type> readType(Type::buildTypeFromString("struct<c1:boolean,c2:int,c3:float>"));
+    WriterOptions options;
+    auto writer = createWriter(*fileType, &memStream, options);
+    auto batch = writer->createRowBatch(TEST_CASES);
+    auto* structBatch = dynamic_cast<StructVectorBatch*>(batch.get());
+    auto& c1 = dynamic_cast<StringVectorBatch&>(*structBatch->fields[0]);
+    auto& c2 = dynamic_cast<StringVectorBatch&>(*structBatch->fields[1]);
+    auto& c3 = dynamic_cast<StringVectorBatch&>(*structBatch->fields[2]);
+
+    // Entry 0 is a placeholder for the null row and is never written.
+    std::vector<std::string> integerText{"",          "123456",
+                                         "0",         "-1234567890",
+                                         "999999999999999999999999", "error"};
+    std::vector<std::string> floatText{"",
+                                       "123456",
+                                       "-0.0",
+                                       "-123456789.0123",
+                                       "1000000000000000000000000000000000000000",
+                                       "error"};
+
+    // Marks row 0 null and points rows [1, TEST_CASES) at the given text.
+    auto fillColumn = [&](StringVectorBatch& column, std::vector<std::string>& text) {
+      column.notNull[0] = false;
+      for (int row = 1; row < TEST_CASES; ++row) {
+        column.data[row] = text[row].data();
+        column.length[row] = text[row].length();
+        column.notNull[row] = true;
+      }
+      column.numElements = TEST_CASES;
+      column.hasNulls = true;
+    };
+    fillColumn(c1, integerText);
+    fillColumn(c2, integerText);
+    fillColumn(c3, floatText);
+    structBatch->numElements = TEST_CASES;
+    structBatch->hasNulls = true;
+    writer->add(*batch);
+    writer->close();
+
+    auto inStream = std::make_unique<MemoryInputStream>(memStream.getData(), memStream.getLength());
+    auto pool = getDefaultPool();
+    auto reader = createReader(*pool, std::move(inStream));
+    RowReaderOptions rowReaderOptions;
+    rowReaderOptions.setUseTightNumericVector(true);
+    rowReaderOptions.setReadType(readType);
+    auto rowReader = reader->createRowReader(rowReaderOptions);
+    auto readBatch = rowReader->createRowBatch(TEST_CASES);
+    EXPECT_EQ(true, rowReader->next(*readBatch));
+
+    auto& readStructBatch = dynamic_cast<StructVectorBatch&>(*readBatch);
+    auto& readC1 = dynamic_cast<BooleanVectorBatch&>(*readStructBatch.fields[0]);
+    auto& readC2 = dynamic_cast<IntVectorBatch&>(*readStructBatch.fields[1]);
+    auto& readC3 = dynamic_cast<FloatVectorBatch&>(*readStructBatch.fields[2]);
+
+    // Only rows 1..3 hold convertible values; every other row reads back as null.
+    for (int row = 0; row < TEST_CASES; ++row) {
+      const bool convertible = row >= 1 && row <= 3;
+      EXPECT_EQ(convertible, static_cast<bool>(readC1.notNull[row])) << row;
+      EXPECT_EQ(convertible, static_cast<bool>(readC2.notNull[row])) << row;
+      EXPECT_EQ(convertible, static_cast<bool>(readC3.notNull[row])) << row;
+    }
+
+    EXPECT_EQ(readC1.data[1], 1);
+    EXPECT_EQ(readC2.data[1], 123456);
+    EXPECT_FLOAT_EQ(readC3.data[1], 123456);
+
+    EXPECT_EQ(readC1.data[2], 0);
+    EXPECT_EQ(readC2.data[2], 0);
+    EXPECT_FLOAT_EQ(readC3.data[2], -0.0);
+
+    EXPECT_EQ(readC1.data[3], 1);
+    EXPECT_EQ(readC2.data[3], -1234567890);
+    EXPECT_FLOAT_EQ(readC3.data[3], -123456789.0123);
+  }
+
+  TEST(ConvertColumnReader, TestConvertStringVariant) {
+    // Reads char(5) as string, varchar(5) as char(4) and string as varchar(4).
+    constexpr int DEFAULT_MEM_STREAM_SIZE = 10 * 1024 * 1024;
+    constexpr int TEST_CASES = 4;
+    MemoryOutputStream memStream(DEFAULT_MEM_STREAM_SIZE);
+    std::unique_ptr<Type> fileType(
+        Type::buildTypeFromString("struct<c1:char(5),c2:varchar(5),c3:string>"));
+    std::shared_ptr<Type> readType(
+        Type::buildTypeFromString("struct<c1:string,c2:char(4),c3:varchar(4)>"));
+    WriterOptions options;
+    auto writer = createWriter(*fileType, &memStream, options);
+    auto batch = writer->createRowBatch(TEST_CASES);
+    auto* structBatch = dynamic_cast<StructVectorBatch*>(batch.get());
+    auto& c1 = dynamic_cast<StringVectorBatch&>(*structBatch->fields[0]);
+    auto& c2 = dynamic_cast<StringVectorBatch&>(*structBatch->fields[1]);
+    auto& c3 = dynamic_cast<StringVectorBatch&>(*structBatch->fields[2]);
+
+    // All three columns are written from the same text; row 0 is null.
+    std::vector<std::string> text{"", "12345", "1", "1234"};
+    for (auto* column : {&c1, &c2, &c3}) {
+      column->notNull[0] = false;
+      for (int row = 1; row < TEST_CASES; ++row) {
+        column->data[row] = text[row].data();
+        column->length[row] = text[row].length();
+        column->notNull[row] = true;
+      }
+      column->numElements = TEST_CASES;
+      column->hasNulls = true;
+    }
+    structBatch->numElements = TEST_CASES;
+    structBatch->hasNulls = true;
+    writer->add(*batch);
+    writer->close();
+
+    auto inStream = std::make_unique<MemoryInputStream>(memStream.getData(), memStream.getLength());
+    auto pool = getDefaultPool();
+    auto reader = createReader(*pool, std::move(inStream));
+    RowReaderOptions rowReaderOptions;
+    rowReaderOptions.setUseTightNumericVector(true);
+    rowReaderOptions.setReadType(readType);
+    auto rowReader = reader->createRowReader(rowReaderOptions);
+    auto readBatch = rowReader->createRowBatch(TEST_CASES);
+    EXPECT_EQ(true, rowReader->next(*readBatch));
+
+    auto& readStructBatch = dynamic_cast<StructVectorBatch&>(*readBatch);
+    auto& readC1 = dynamic_cast<StringVectorBatch&>(*readStructBatch.fields[0]);
+    auto& readC2 = dynamic_cast<StringVectorBatch&>(*readStructBatch.fields[1]);
+    auto& readC3 = dynamic_cast<StringVectorBatch&>(*readStructBatch.fields[2]);
+
+    // Row 0 stays null; every other row is present in all three columns.
+    for (int row = 0; row < TEST_CASES; ++row) {
+      const bool present = row != 0;
+      EXPECT_EQ(present, static_cast<bool>(readC1.notNull[row])) << row;
+      EXPECT_EQ(present, static_cast<bool>(readC2.notNull[row])) << row;
+      EXPECT_EQ(present, static_cast<bool>(readC3.notNull[row])) << row;
+    }
+
+    // Expected text for rows 1..3 as {c1 as string, c2 as char(4), c3 as varchar(4)}:
+    // char(5) keeps its padding when read as string, char(4) truncates/pads to
+    // 4 characters, and varchar(4) only truncates.
+    const std::vector<std::vector<std::string>> expected{
+        {"12345", "1234", "1234"}, {"1    ", "1   ", "1"}, {"1234 ", "1234", "1234"}};
+    for (int row = 1; row < TEST_CASES; ++row) {
+      const auto& want = expected[row - 1];
+      EXPECT_EQ(std::string(readC1.data[row], readC1.length[row]), want[0]);
+      EXPECT_EQ(std::string(readC2.data[row], readC2.length[row]), want[1]);
+      EXPECT_EQ(std::string(readC3.data[row], readC3.length[row]), want[2]);
+    }
+  }
+
 }  // namespace orc
diff --git a/c++/test/TestSchemaEvolution.cc b/c++/test/TestSchemaEvolution.cc
index c52ba009f..12001fca6 100644
--- a/c++/test/TestSchemaEvolution.cc
+++ b/c++/test/TestSchemaEvolution.cc
@@ -148,6 +148,22 @@ namespace orc {
       }
     }
 
+    // conversion from string variant to numeric
+    for (size_t i = 7; i <= 11; i++) {
+      for (size_t j = 0; j <= 6; j++) {
+        canConvert[i][j] = true;
+        needConvert[i][j] = true;
+      }
+    }
+
+    // conversion from string variant to string variant
+    for (size_t i = 7; i <= 11; i++) {
+      for (size_t j = 7; j <= 11; j++) {
+        canConvert[i][j] = true;
+        needConvert[i][j] = (i != j);
+      }
+    }
+
     for (size_t i = 0; i < typesSize; i++) {
       for (size_t j = 0; j < typesSize; j++) {
         testConvertReader(types[i], types[j], canConvert[i][j], 
needConvert[i][j]);

Reply via email to