This is an automated email from the ASF dual-hosted git repository.
mdeepak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/orc.git
The following commit(s) were added to refs/heads/master by this push:
new d3d00f8 ORC-412: [C++] Fix Char(n) and Varchar(n) writers with UTF-8
(#317)
d3d00f8 is described below
commit d3d00f81e91bc0b978afd770f7962a8112a7f12e
Author: Gang Wu <[email protected]>
AuthorDate: Tue Oct 16 04:39:03 2018 -0700
ORC-412: [C++] Fix Char(n) and Varchar(n) writers with UTF-8 (#317)
* ORC-412: [C++] Fix Char(n) and Varchar(n) writers with UTF-8
* remove useless offset param in Utf8Utils
* fix comment
---
c++/src/ColumnWriter.cc | 114 ++++++++++++++++++++++++++++++++++++++++--------
c++/test/TestWriter.cc | 81 ++++++++++++++++++++++++++++++++++
2 files changed, 178 insertions(+), 17 deletions(-)
diff --git a/c++/src/ColumnWriter.cc b/c++/src/ColumnWriter.cc
index eb2fc40..8feb077 100644
--- a/c++/src/ColumnWriter.cc
+++ b/c++/src/ColumnWriter.cc
@@ -940,16 +940,88 @@ namespace orc {
lengthEncoder->recordPosition(rowIndexPosition.get());
}
+ struct Utf8Utils {
+   /**
+    * Count the UTF-8 characters in a byte range.
+    *
+    * Every byte that is not a continuation byte (10xxxxxx) is counted as
+    * the start of one character. The data is not validated.
+    *
+    * @param data the bytes of UTF-8
+    * @param length the number of bytes in data
+    * @return the number of start bytes found in data[0, length)
+    */
+   static uint64_t charLength(const char * data, uint64_t length) {
+     uint64_t chars = 0;
+     for (uint64_t i = 0; i < length; i++) {
+       if (isUtfStartByte(data[i])) {
+         chars++;
+       }
+     }
+     return chars;
+   }
+
+   /**
+    * Return the number of bytes required to read at most maxCharLength
+    * characters in full from a utf-8 encoded byte array provided
+    * by data. This does not validate utf-8 data, but
+    * operates correctly on already valid utf-8 data.
+    *
+    * @param maxCharLength number of characters required
+    * @param data the bytes of UTF-8
+    * @param length the length of data to truncate
+    * @return the byte offset at which character maxCharLength + 1 starts,
+    *         or length if data holds no more than maxCharLength characters
+    */
+   static uint64_t truncateBytesTo(uint64_t maxCharLength,
+                                   const char * data,
+                                   uint64_t length) {
+     uint64_t chars = 0;
+     // a character occupies at least one byte, so this many bytes cannot
+     // hold more than maxCharLength characters; skip the scan
+     if (length <= maxCharLength) {
+       return length;
+     }
+     for (uint64_t i = 0; i < length; i++) {
+       if (isUtfStartByte(data[i])) {
+         chars++;
+       }
+       // i is the start byte of the first character that does not fit
+       if (chars > maxCharLength) {
+         return i;
+       }
+     }
+     // everything fits
+     return length;
+   }
+
+   /**
+    * Checks if b is the first byte of a UTF-8 character, i.e. its top two
+    * bits are anything other than the continuation pattern 10.
+    */
+   inline static bool isUtfStartByte(char b) {
+     return (b & 0xC0) != 0x80;
+   }
+
+   /**
+    * Find the start of the last character that ends in the current string.
+    * Scans backwards from until down to from (both inclusive).
+    *
+    * @param text the bytes of the utf-8
+    * @param from the first byte location
+    * @param until the last byte location
+    * @return the index of the last character
+    * @throws std::logic_error if text[from, until] contains no start byte
+    */
+   static uint64_t findLastCharacter(const char * text, uint64_t from,
+                                     uint64_t until) {
+     uint64_t posn = until;
+     while (posn >= from) {
+       if (isUtfStartByte(text[posn])) {
+         return posn;
+       }
+       // posn is unsigned: decrementing past from == 0 would wrap around
+       // and keep the loop condition true, so stop explicitly at from
+       if (posn == from) {
+         break;
+       }
+       posn -= 1;
+     }
+     /* beginning of a valid char not found */
+     throw std::logic_error(
+       "Could not truncate string, beginning of a valid char not found");
+   }
+ };
+
+
class CharColumnWriter : public StringColumnWriter {
public:
CharColumnWriter(const Type& type,
const StreamsFactory& factory,
const WriterOptions& options) :
StringColumnWriter(type, factory, options),
- fixedLength(type.getMaximumLength()),
- padBuffer(*options.getMemoryPool(),
- type.getMaximumLength()) {
- // PASS
+ maxLength(type.getMaximumLength()),
+ padBuffer(*options.getMemoryPool()) {
+ // utf-8 is currently 4 bytes long, but it could be up to 6
+ padBuffer.resize(maxLength * 6);
}
virtual void add(ColumnVectorBatch& rowBatch,
@@ -957,7 +1029,7 @@ namespace orc {
uint64_t numValues) override;
private:
- uint64_t fixedLength;
+ uint64_t maxLength;
DataBuffer<char> padBuffer;
};
@@ -984,17 +1056,24 @@ namespace orc {
for (uint64_t i = 0; i < numValues; ++i) {
if (!notNull || notNull[i]) {
- char *charData = data[i];
- uint64_t oriLength = static_cast<uint64_t>(length[i]);
- if (oriLength < fixedLength) {
- memcpy(padBuffer.data(), data[i], oriLength);
- memset(padBuffer.data() + oriLength, ' ', fixedLength - oriLength);
+ const char * charData = nullptr;
+ uint64_t originLength = static_cast<uint64_t>(length[i]);
+ uint64_t charLength = Utf8Utils::charLength(data[i], originLength);
+ if (charLength >= maxLength) {
+ charData = data[i];
+ length[i] = static_cast<int64_t>(
+ Utf8Utils::truncateBytesTo(maxLength, data[i], originLength));
+ } else {
charData = padBuffer.data();
+ // the padding is exactly 1 byte per char
+ length[i] = length[i] + static_cast<int64_t>(maxLength - charLength);
+ memcpy(padBuffer.data(), data[i], originLength);
+ memset(padBuffer.data() + originLength,
+ ' ',
+ static_cast<size_t>(length[i]) - originLength);
}
- length[i] = static_cast<int64_t>(fixedLength);
- dataStream->write(charData, fixedLength);
-
- strStats->update(charData, fixedLength);
+ dataStream->write(charData, static_cast<size_t>(length[i]));
+ strStats->update(charData, static_cast<size_t>(length[i]));
strStats->increase(1);
} else if (!hasNull) {
hasNull = true;
@@ -1045,9 +1124,10 @@ namespace orc {
for (uint64_t i = 0; i < numValues; ++i) {
if (!notNull || notNull[i]) {
- if (length[i] > static_cast<int64_t>(maxLength)) {
- length[i] = static_cast<int64_t>(maxLength);
- }
+ uint64_t itemLength = Utf8Utils::truncateBytesTo(
+ maxLength, data[i], static_cast<uint64_t>(length[i]));
+
+ length[i] = static_cast<int64_t>(itemLength);
dataStream->write(data[i], static_cast<size_t>(length[i]));
strStats->update(data[i], static_cast<size_t>(length[i]));
diff --git a/c++/test/TestWriter.cc b/c++/test/TestWriter.cc
index f7597e8..330d0d7 100644
--- a/c++/test/TestWriter.cc
+++ b/c++/test/TestWriter.cc
@@ -1181,5 +1181,86 @@ namespace orc {
}
}
+ TEST_P(WriterTest, writeUTF8CharAndVarcharColumn) {  // round-trips multi-byte UTF-8 through char(2) and varchar(2) columns
+ MemoryOutputStream memStream(DEFAULT_MEM_STREAM_SIZE);
+ MemoryPool * pool = getDefaultPool();
+ std::unique_ptr<Type> type(Type::buildTypeFromString(
+ "struct<col1:char(2),col2:varchar(2)>"));  // both columns are declared with a maximum length of 2
+
+ uint64_t stripeSize = 1024;
+ uint64_t compressionBlockSize = 1024;
+ uint64_t rowCount = 3;
+ std::unique_ptr<Writer> writer = createWriter(stripeSize,
+ compressionBlockSize,
+ CompressionKind_ZLIB,
+ *type,
+ pool,
+ &memStream,
+ fileVersion);
+ std::unique_ptr<ColumnVectorBatch> batch =
writer->createRowBatch(rowCount);
+ StructVectorBatch * structBatch =
+ dynamic_cast<StructVectorBatch *>(batch.get());
+ StringVectorBatch * charBatch =
+ dynamic_cast<StringVectorBatch *>(structBatch->fields[0]);
+ StringVectorBatch * varcharBatch =
+ dynamic_cast<StringVectorBatch *>(structBatch->fields[1]);
+ std::vector<std::vector<char>> strs;  // holds the bytes that the batch data pointers refer to
+
+ // the input character is 'à', encoded in UTF-8 as the two bytes 0xC3 0xA0;
+ // row i holds i + 1 copies of it, so the 3 rows are 1, 2 and 3 characters (2, 4 and 6 bytes) long
+ std::vector<char> vec;  // grows by one 'à' per iteration
+ for (uint64_t i = 0; i != rowCount; ++i) {
+ vec.push_back(static_cast<char>(0xC3));
+ vec.push_back(static_cast<char>(0xA0));
+ strs.push_back(vec);
+ charBatch->data[i] = varcharBatch->data[i] = strs.back().data();  // both columns share the same input bytes
+ charBatch->length[i] = varcharBatch->length[i] =
static_cast<int64_t>(strs.back().size());
+ }
+
+ structBatch->numElements = rowCount;
+ charBatch->numElements = rowCount;
+ varcharBatch->numElements = rowCount;
+
+ writer->add(*batch);
+ writer->close();
+
+ // read the written bytes back from memStream and verify what was stored
+ std::unique_ptr<InputStream> inStream(
+ new MemoryInputStream (memStream.getData(), memStream.getLength()));
+ std::unique_ptr<Reader> reader = createReader(pool, std::move(inStream));
+ std::unique_ptr<RowReader> rowReader = createRowReader(reader.get());
+ EXPECT_EQ(rowCount, reader->getNumberOfRows());
+
+ batch = rowReader->createRowBatch(rowCount);
+ structBatch = dynamic_cast<StructVectorBatch *>(batch.get());
+ charBatch = dynamic_cast<StringVectorBatch *>(structBatch->fields[0]);
+ varcharBatch = dynamic_cast<StringVectorBatch *>(structBatch->fields[1]);
+
+ EXPECT_EQ(true, rowReader->next(*batch));
+ EXPECT_EQ(rowCount, batch->numElements);
+
+ char expectedPadded[3] = {static_cast<char>(0xC3),
static_cast<char>(0xA0), ' '};  // one 'à' followed by one pad space
+ char expectedOneChar[2] = {static_cast<char>(0xC3),
static_cast<char>(0xA0)};  // one 'à', no padding
+ char expectedTwoChars[4] = {static_cast<char>(0xC3),
static_cast<char>(0xA0),
+ static_cast<char>(0xC3),
static_cast<char>(0xA0)};  // two 'à'
+
+ EXPECT_EQ(3, charBatch->length[0]);  // 1 character (2 bytes) + 1 pad space
+ EXPECT_EQ(4, charBatch->length[1]);  // exactly 2 characters, stored as-is
+ EXPECT_EQ(4, charBatch->length[2]);  // 3 characters cut down to 2 (4 bytes)
+ EXPECT_TRUE(memcmp(charBatch->data[0], expectedPadded, 3) == 0);
+ EXPECT_TRUE(memcmp(charBatch->data[1], expectedTwoChars, 4) == 0);
+ EXPECT_TRUE(memcmp(charBatch->data[2], expectedTwoChars, 4) == 0);
+
+ EXPECT_EQ(2, varcharBatch->length[0]);  // 1 character, varchar is not padded
+ EXPECT_EQ(4, varcharBatch->length[1]);  // exactly 2 characters, stored as-is
+ EXPECT_EQ(4, varcharBatch->length[2]);  // 3 characters cut down to 2 (4 bytes)
+ EXPECT_TRUE(memcmp(varcharBatch->data[0], expectedOneChar, 2) == 0);
+ EXPECT_TRUE(memcmp(varcharBatch->data[1], expectedTwoChars, 4) == 0);
+ EXPECT_TRUE(memcmp(varcharBatch->data[2], expectedTwoChars, 4) == 0);
+
+ EXPECT_FALSE(rowReader->next(*batch));  // all rows were consumed by the first batch
+ }
+
+
INSTANTIATE_TEST_CASE_P(OrcTest, WriterTest, Values(FileVersion::v_0_11(),
FileVersion::v_0_12()));
}