[https://issues.apache.org/jira/browse/ORC-412?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16651542#comment-16651542]

ASF GitHub Bot commented on ORC-412:
------------------------------------

majetideepak closed pull request #317: ORC-412: [C++] Fix Char(n) and Varchar(n) writers with UTF-8
URL: https://github.com/apache/orc/pull/317
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not otherwise display it for merged fork PRs:

diff --git a/c++/src/ColumnWriter.cc b/c++/src/ColumnWriter.cc
index eb2fc40de7..8feb077130 100644
--- a/c++/src/ColumnWriter.cc
+++ b/c++/src/ColumnWriter.cc
@@ -940,16 +940,88 @@ namespace orc {
     lengthEncoder->recordPosition(rowIndexPosition.get());
   }
 
+  struct Utf8Utils {
+    /**
+     * Counts how many utf-8 chars of the input data
+     */
+    static uint64_t charLength(const char * data, uint64_t length) {
+      uint64_t chars = 0;
+      for (uint64_t i = 0; i < length; i++) {
+        if (isUtfStartByte(data[i])) {
+          chars++;
+        }
+      }
+      return chars;
+    }
+
+    /**
+     * Return the number of bytes required to read at most maxCharLength
+     * characters in full from a utf-8 encoded byte array provided
+     * by data. This does not validate utf-8 data, but
+     * operates correctly on already valid utf-8 data.
+     *
+     * @param maxCharLength number of characters required
+     * @param data the bytes of UTF-8
+     * @param length the length of data to truncate
+     */
+    static uint64_t truncateBytesTo(uint64_t maxCharLength,
+                                    const char * data,
+                                    uint64_t length) {
+      uint64_t chars = 0;
+      if (length <= maxCharLength) {
+        return length;
+      }
+      for (uint64_t i = 0; i < length; i++) {
+        if (isUtfStartByte(data[i])) {
+          chars++;
+        }
+        if (chars > maxCharLength) {
+          return i;
+        }
+      }
+      // everything fits
+      return length;
+    }
+
+    /**
+     * Checks if b is the first byte of a UTF-8 character.
+     */
+    inline static bool isUtfStartByte(char b) {
+      return (b & 0xC0) != 0x80;
+    }
+
+    /**
+     * Find the start of the last character that ends in the current string.
+     * @param text the bytes of the utf-8
+     * @param from the first byte location
+     * @param until the last byte location
+     * @return the index of the last character
+    */
+    static uint64_t findLastCharacter(const char * text, uint64_t from, 
uint64_t until) {
+      uint64_t posn = until;
+      /* we don't expect characters more than 5 bytes */
+      while (posn >= from) {
+        if (isUtfStartByte(text[posn])) {
+          return posn;
+        }
+        posn -= 1;
+      }
+      /* beginning of a valid char not found */
+      throw std::logic_error(
+        "Could not truncate string, beginning of a valid char not found");
+    }
+  };
+
   class CharColumnWriter : public StringColumnWriter {
   public:
     CharColumnWriter(const Type& type,
                      const StreamsFactory& factory,
                      const WriterOptions& options) :
                          StringColumnWriter(type, factory, options),
-                         fixedLength(type.getMaximumLength()),
-                         padBuffer(*options.getMemoryPool(),
-                                   type.getMaximumLength()) {
-      // PASS
+                         maxLength(type.getMaximumLength()),
+                         padBuffer(*options.getMemoryPool()) {
+      // utf-8 is currently 4 bytes long, but it could be up to 6
+      padBuffer.resize(maxLength * 6);
     }
 
     virtual void add(ColumnVectorBatch& rowBatch,
@@ -957,7 +1029,7 @@ namespace orc {
                      uint64_t numValues) override;
 
   private:
-    uint64_t fixedLength;
+    uint64_t maxLength;
     DataBuffer<char> padBuffer;
   };
 
@@ -984,17 +1056,24 @@ namespace orc {
 
     for (uint64_t i = 0; i < numValues; ++i) {
       if (!notNull || notNull[i]) {
-        char *charData = data[i];
-        uint64_t oriLength = static_cast<uint64_t>(length[i]);
-        if (oriLength < fixedLength) {
-          memcpy(padBuffer.data(), data[i], oriLength);
-          memset(padBuffer.data() + oriLength, ' ', fixedLength - oriLength);
+        const char * charData = nullptr;
+        uint64_t originLength = static_cast<uint64_t>(length[i]);
+        uint64_t charLength = Utf8Utils::charLength(data[i], originLength);
+        if (charLength >= maxLength) {
+          charData = data[i];
+          length[i] = static_cast<int64_t>(
+            Utf8Utils::truncateBytesTo(maxLength, data[i], originLength));
+        } else {
           charData = padBuffer.data();
+          // the padding is exactly 1 byte per char
+          length[i] = length[i] + static_cast<int64_t>(maxLength - charLength);
+          memcpy(padBuffer.data(), data[i], originLength);
+          memset(padBuffer.data() + originLength,
+                 ' ',
+                 static_cast<size_t>(length[i]) - originLength);
         }
-        length[i] = static_cast<int64_t>(fixedLength);
-        dataStream->write(charData, fixedLength);
-
-        strStats->update(charData, fixedLength);
+        dataStream->write(charData, static_cast<size_t>(length[i]));
+        strStats->update(charData, static_cast<size_t>(length[i]));
         strStats->increase(1);
       } else if (!hasNull) {
         hasNull = true;
@@ -1045,9 +1124,10 @@ namespace orc {
 
     for (uint64_t i = 0; i < numValues; ++i) {
       if (!notNull || notNull[i]) {
-        if (length[i] > static_cast<int64_t>(maxLength)) {
-          length[i] = static_cast<int64_t>(maxLength);
-        }
+        uint64_t itemLength = Utf8Utils::truncateBytesTo(
+          maxLength, data[i], static_cast<uint64_t>(length[i]));
+
+        length[i] = static_cast<int64_t>(itemLength);
         dataStream->write(data[i], static_cast<size_t>(length[i]));
 
         strStats->update(data[i], static_cast<size_t>(length[i]));
diff --git a/c++/test/TestWriter.cc b/c++/test/TestWriter.cc
index f7597e8682..330d0d7130 100644
--- a/c++/test/TestWriter.cc
+++ b/c++/test/TestWriter.cc
@@ -1181,5 +1181,86 @@ namespace orc {
     }
   }
 
+  TEST_P(WriterTest, writeUTF8CharAndVarcharColumn) {
+    MemoryOutputStream memStream(DEFAULT_MEM_STREAM_SIZE);
+    MemoryPool * pool = getDefaultPool();
+    std::unique_ptr<Type> type(Type::buildTypeFromString(
+      "struct<col1:char(2),col2:varchar(2)>"));
+
+    uint64_t stripeSize = 1024;
+    uint64_t compressionBlockSize = 1024;
+    uint64_t rowCount = 3;
+    std::unique_ptr<Writer> writer = createWriter(stripeSize,
+                                                  compressionBlockSize,
+                                                  CompressionKind_ZLIB,
+                                                  *type,
+                                                  pool,
+                                                  &memStream,
+                                                  fileVersion);
+    std::unique_ptr<ColumnVectorBatch> batch = 
writer->createRowBatch(rowCount);
+    StructVectorBatch * structBatch =
+      dynamic_cast<StructVectorBatch *>(batch.get());
+    StringVectorBatch * charBatch =
+      dynamic_cast<StringVectorBatch *>(structBatch->fields[0]);
+    StringVectorBatch * varcharBatch =
+      dynamic_cast<StringVectorBatch *>(structBatch->fields[1]);
+    std::vector<std::vector<char>> strs;
+
+    // input character is 'à' (0xC3, 0xA0)
+    // in total 3 rows, each has 1, 2, and 3 'à' respectively
+    std::vector<char> vec;
+    for (uint64_t i = 0; i != rowCount; ++i) {
+      vec.push_back(static_cast<char>(0xC3));
+      vec.push_back(static_cast<char>(0xA0));
+      strs.push_back(vec);
+      charBatch->data[i] = varcharBatch->data[i] = strs.back().data();
+      charBatch->length[i] = varcharBatch->length[i] = 
static_cast<int64_t>(strs.back().size());
+    }
+
+    structBatch->numElements = rowCount;
+    charBatch->numElements = rowCount;
+    varcharBatch->numElements = rowCount;
+
+    writer->add(*batch);
+    writer->close();
+
+    // read and verify data
+    std::unique_ptr<InputStream> inStream(
+      new MemoryInputStream (memStream.getData(), memStream.getLength()));
+    std::unique_ptr<Reader> reader = createReader(pool, std::move(inStream));
+    std::unique_ptr<RowReader> rowReader = createRowReader(reader.get());
+    EXPECT_EQ(rowCount, reader->getNumberOfRows());
+
+    batch = rowReader->createRowBatch(rowCount);
+    structBatch = dynamic_cast<StructVectorBatch *>(batch.get());
+    charBatch = dynamic_cast<StringVectorBatch *>(structBatch->fields[0]);
+    varcharBatch = dynamic_cast<StringVectorBatch *>(structBatch->fields[1]);
+
+    EXPECT_EQ(true, rowReader->next(*batch));
+    EXPECT_EQ(rowCount, batch->numElements);
+
+    char expectedPadded[3] = {static_cast<char>(0xC3), 
static_cast<char>(0xA0), ' '};
+    char expectedOneChar[2] = {static_cast<char>(0xC3), 
static_cast<char>(0xA0)};
+    char expectedTwoChars[4] = {static_cast<char>(0xC3), 
static_cast<char>(0xA0),
+                                static_cast<char>(0xC3), 
static_cast<char>(0xA0)};
+
+    EXPECT_EQ(3, charBatch->length[0]);
+    EXPECT_EQ(4, charBatch->length[1]);
+    EXPECT_EQ(4, charBatch->length[2]);
+    EXPECT_TRUE(memcmp(charBatch->data[0], expectedPadded, 3) == 0);
+    EXPECT_TRUE(memcmp(charBatch->data[1], expectedTwoChars, 4) == 0);
+    EXPECT_TRUE(memcmp(charBatch->data[2], expectedTwoChars, 4) == 0);
+
+    EXPECT_EQ(2, varcharBatch->length[0]);
+    EXPECT_EQ(4, varcharBatch->length[1]);
+    EXPECT_EQ(4, varcharBatch->length[2]);
+    EXPECT_TRUE(memcmp(varcharBatch->data[0], expectedOneChar, 2) == 0);
+    EXPECT_TRUE(memcmp(varcharBatch->data[1], expectedTwoChars, 4) == 0);
+    EXPECT_TRUE(memcmp(varcharBatch->data[2], expectedTwoChars, 4) == 0);
+
+    EXPECT_FALSE(rowReader->next(*batch));
+  }
+
+
   INSTANTIATE_TEST_CASE_P(OrcTest, WriterTest, Values(FileVersion::v_0_11(), 
FileVersion::v_0_12()));
 }


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


> [C++] ORC: Char(n) and Varchar(n) writers truncate to n bytes and corrupt
> multi-byte data
> ----------------------------------------------------------------------------------------
>
>                 Key: ORC-412
>                 URL: https://issues.apache.org/jira/browse/ORC-412
>             Project: ORC
>          Issue Type: Bug
>    Affects Versions: 1.5.2
>            Reporter: Gang Wu
>            Assignee: Gang Wu
>            Priority: Major
>
> https://github.com/apache/orc/blob/master/java/core/src/java/org/apache/orc/impl/writer/CharTreeWriter.java#L41
> {code}
>     itemLength = schema.getMaxLength();
>     padding = new byte[itemLength];
>   }
> {code}
> https://github.com/apache/orc/blob/master/java/core/src/java/org/apache/orc/impl/writer/VarcharTreeWriter.java#L48
> {code}
>       if (vector.noNulls || !vector.isNull[0]) {
>         int itemLength = Math.min(vec.length[0], maxLength);
> {code}



--
This message was sent by Atlassian JIRA
(v7.6.3#76005)

Reply via email to