This is an automated email from the ASF dual-hosted git repository.

mdeepak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/orc.git


The following commit(s) were added to refs/heads/master by this push:
     new d3d00f8  ORC-412: [C++] Fix Char(n) and Varchar(n) writers with UTF-8 
(#317)
d3d00f8 is described below

commit d3d00f81e91bc0b978afd770f7962a8112a7f12e
Author: Gang Wu <[email protected]>
AuthorDate: Tue Oct 16 04:39:03 2018 -0700

    ORC-412: [C++] Fix Char(n) and Varchar(n) writers with UTF-8 (#317)
    
    * ORC-412: [C++] Fix Char(n) and Varchar(n) writers with UTF-8
    
    * remove useless offset param in Utf8Utils
    
    * fix comment
---
 c++/src/ColumnWriter.cc | 114 ++++++++++++++++++++++++++++++++++++++++--------
 c++/test/TestWriter.cc  |  81 ++++++++++++++++++++++++++++++++++
 2 files changed, 178 insertions(+), 17 deletions(-)

diff --git a/c++/src/ColumnWriter.cc b/c++/src/ColumnWriter.cc
index eb2fc40..8feb077 100644
--- a/c++/src/ColumnWriter.cc
+++ b/c++/src/ColumnWriter.cc
@@ -940,16 +940,88 @@ namespace orc {
     lengthEncoder->recordPosition(rowIndexPosition.get());
   }
 
+  struct Utf8Utils {
+    /**
+     * Counts how many utf-8 chars of the input data
+     */
+    static uint64_t charLength(const char * data, uint64_t length) {
+      uint64_t chars = 0;
+      for (uint64_t i = 0; i < length; i++) {
+        if (isUtfStartByte(data[i])) {
+          chars++;
+        }
+      }
+      return chars;
+    }
+
+    /**
+     * Return the number of bytes required to read at most maxCharLength
+     * characters in full from a utf-8 encoded byte array provided
+     * by data. This does not validate utf-8 data, but
+     * operates correctly on already valid utf-8 data.
+     *
+     * @param maxCharLength number of characters required
+     * @param data the bytes of UTF-8
+     * @param length the length of data to truncate
+     */
+    static uint64_t truncateBytesTo(uint64_t maxCharLength,
+                                    const char * data,
+                                    uint64_t length) {
+      uint64_t chars = 0;
+      if (length <= maxCharLength) {
+        return length;
+      }
+      for (uint64_t i = 0; i < length; i++) {
+        if (isUtfStartByte(data[i])) {
+          chars++;
+        }
+        if (chars > maxCharLength) {
+          return i;
+        }
+      }
+      // everything fits
+      return length;
+    }
+
+    /**
+     * Checks if b is the first byte of a UTF-8 character.
+     */
+    inline static bool isUtfStartByte(char b) {
+      return (b & 0xC0) != 0x80;
+    }
+
+    /**
+     * Find the start of the last character that ends in the current string.
+     * @param text the bytes of the utf-8
+     * @param from the first byte location
+     * @param until the last byte location
+     * @return the index of the last character
+    */
+    static uint64_t findLastCharacter(const char * text, uint64_t from, 
uint64_t until) {
+      uint64_t posn = until;
+      /* we don't expect characters more than 5 bytes */
+      while (posn >= from) {
+        if (isUtfStartByte(text[posn])) {
+          return posn;
+        }
+        posn -= 1;
+      }
+      /* beginning of a valid char not found */
+      throw std::logic_error(
+        "Could not truncate string, beginning of a valid char not found");
+    }
+  };
+
   class CharColumnWriter : public StringColumnWriter {
   public:
     CharColumnWriter(const Type& type,
                      const StreamsFactory& factory,
                      const WriterOptions& options) :
                          StringColumnWriter(type, factory, options),
-                         fixedLength(type.getMaximumLength()),
-                         padBuffer(*options.getMemoryPool(),
-                                   type.getMaximumLength()) {
-      // PASS
+                         maxLength(type.getMaximumLength()),
+                         padBuffer(*options.getMemoryPool()) {
+      // utf-8 is currently 4 bytes long, but it could be up to 6
+      padBuffer.resize(maxLength * 6);
     }
 
     virtual void add(ColumnVectorBatch& rowBatch,
@@ -957,7 +1029,7 @@ namespace orc {
                      uint64_t numValues) override;
 
   private:
-    uint64_t fixedLength;
+    uint64_t maxLength;
     DataBuffer<char> padBuffer;
   };
 
@@ -984,17 +1056,24 @@ namespace orc {
 
     for (uint64_t i = 0; i < numValues; ++i) {
       if (!notNull || notNull[i]) {
-        char *charData = data[i];
-        uint64_t oriLength = static_cast<uint64_t>(length[i]);
-        if (oriLength < fixedLength) {
-          memcpy(padBuffer.data(), data[i], oriLength);
-          memset(padBuffer.data() + oriLength, ' ', fixedLength - oriLength);
+        const char * charData = nullptr;
+        uint64_t originLength = static_cast<uint64_t>(length[i]);
+        uint64_t charLength = Utf8Utils::charLength(data[i], originLength);
+        if (charLength >= maxLength) {
+          charData = data[i];
+          length[i] = static_cast<int64_t>(
+            Utf8Utils::truncateBytesTo(maxLength, data[i], originLength));
+        } else {
           charData = padBuffer.data();
+          // the padding is exactly 1 byte per char
+          length[i] = length[i] + static_cast<int64_t>(maxLength - charLength);
+          memcpy(padBuffer.data(), data[i], originLength);
+          memset(padBuffer.data() + originLength,
+                 ' ',
+                 static_cast<size_t>(length[i]) - originLength);
         }
-        length[i] = static_cast<int64_t>(fixedLength);
-        dataStream->write(charData, fixedLength);
-
-        strStats->update(charData, fixedLength);
+        dataStream->write(charData, static_cast<size_t>(length[i]));
+        strStats->update(charData, static_cast<size_t>(length[i]));
         strStats->increase(1);
       } else if (!hasNull) {
         hasNull = true;
@@ -1045,9 +1124,10 @@ namespace orc {
 
     for (uint64_t i = 0; i < numValues; ++i) {
       if (!notNull || notNull[i]) {
-        if (length[i] > static_cast<int64_t>(maxLength)) {
-          length[i] = static_cast<int64_t>(maxLength);
-        }
+        uint64_t itemLength = Utf8Utils::truncateBytesTo(
+          maxLength, data[i], static_cast<uint64_t>(length[i]));
+
+        length[i] = static_cast<int64_t>(itemLength);
         dataStream->write(data[i], static_cast<size_t>(length[i]));
 
         strStats->update(data[i], static_cast<size_t>(length[i]));
diff --git a/c++/test/TestWriter.cc b/c++/test/TestWriter.cc
index f7597e8..330d0d7 100644
--- a/c++/test/TestWriter.cc
+++ b/c++/test/TestWriter.cc
@@ -1181,5 +1181,86 @@ namespace orc {
     }
   }
 
+  TEST_P(WriterTest, writeUTF8CharAndVarcharColumn) {
+    MemoryOutputStream memStream(DEFAULT_MEM_STREAM_SIZE);
+    MemoryPool * pool = getDefaultPool();
+    std::unique_ptr<Type> type(Type::buildTypeFromString(
+      "struct<col1:char(2),col2:varchar(2)>"));
+
+    uint64_t stripeSize = 1024;
+    uint64_t compressionBlockSize = 1024;
+    uint64_t rowCount = 3;
+    std::unique_ptr<Writer> writer = createWriter(stripeSize,
+                                                  compressionBlockSize,
+                                                  CompressionKind_ZLIB,
+                                                  *type,
+                                                  pool,
+                                                  &memStream,
+                                                  fileVersion);
+    std::unique_ptr<ColumnVectorBatch> batch = 
writer->createRowBatch(rowCount);
+    StructVectorBatch * structBatch =
+      dynamic_cast<StructVectorBatch *>(batch.get());
+    StringVectorBatch * charBatch =
+      dynamic_cast<StringVectorBatch *>(structBatch->fields[0]);
+    StringVectorBatch * varcharBatch =
+      dynamic_cast<StringVectorBatch *>(structBatch->fields[1]);
+    std::vector<std::vector<char>> strs;
+
+    // input character is 'à' (0xC3, 0xA0)
+    // in total 3 rows, each has 1, 2, and 3 'à' respectively
+    std::vector<char> vec;
+    for (uint64_t i = 0; i != rowCount; ++i) {
+      vec.push_back(static_cast<char>(0xC3));
+      vec.push_back(static_cast<char>(0xA0));
+      strs.push_back(vec);
+      charBatch->data[i] = varcharBatch->data[i] = strs.back().data();
+      charBatch->length[i] = varcharBatch->length[i] = 
static_cast<int64_t>(strs.back().size());
+    }
+
+    structBatch->numElements = rowCount;
+    charBatch->numElements = rowCount;
+    varcharBatch->numElements = rowCount;
+
+    writer->add(*batch);
+    writer->close();
+
+    // read and verify data
+    std::unique_ptr<InputStream> inStream(
+      new MemoryInputStream (memStream.getData(), memStream.getLength()));
+    std::unique_ptr<Reader> reader = createReader(pool, std::move(inStream));
+    std::unique_ptr<RowReader> rowReader = createRowReader(reader.get());
+    EXPECT_EQ(rowCount, reader->getNumberOfRows());
+
+    batch = rowReader->createRowBatch(rowCount);
+    structBatch = dynamic_cast<StructVectorBatch *>(batch.get());
+    charBatch = dynamic_cast<StringVectorBatch *>(structBatch->fields[0]);
+    varcharBatch = dynamic_cast<StringVectorBatch *>(structBatch->fields[1]);
+
+    EXPECT_EQ(true, rowReader->next(*batch));
+    EXPECT_EQ(rowCount, batch->numElements);
+
+    char expectedPadded[3] = {static_cast<char>(0xC3), 
static_cast<char>(0xA0), ' '};
+    char expectedOneChar[2] = {static_cast<char>(0xC3), 
static_cast<char>(0xA0)};
+    char expectedTwoChars[4] = {static_cast<char>(0xC3), 
static_cast<char>(0xA0),
+                                static_cast<char>(0xC3), 
static_cast<char>(0xA0)};
+
+    EXPECT_EQ(3, charBatch->length[0]);
+    EXPECT_EQ(4, charBatch->length[1]);
+    EXPECT_EQ(4, charBatch->length[2]);
+    EXPECT_TRUE(memcmp(charBatch->data[0], expectedPadded, 3) == 0);
+    EXPECT_TRUE(memcmp(charBatch->data[1], expectedTwoChars, 4) == 0);
+    EXPECT_TRUE(memcmp(charBatch->data[2], expectedTwoChars, 4) == 0);
+
+    EXPECT_EQ(2, varcharBatch->length[0]);
+    EXPECT_EQ(4, varcharBatch->length[1]);
+    EXPECT_EQ(4, varcharBatch->length[2]);
+    EXPECT_TRUE(memcmp(varcharBatch->data[0], expectedOneChar, 2) == 0);
+    EXPECT_TRUE(memcmp(varcharBatch->data[1], expectedTwoChars, 4) == 0);
+    EXPECT_TRUE(memcmp(varcharBatch->data[2], expectedTwoChars, 4) == 0);
+
+    EXPECT_FALSE(rowReader->next(*batch));
+  }
+
+
   INSTANTIATE_TEST_CASE_P(OrcTest, WriterTest, Values(FileVersion::v_0_11(), 
FileVersion::v_0_12()));
 }

Reply via email to