manolama commented on code in PR #38423:
URL: https://github.com/apache/arrow/pull/38423#discussion_r1385850436
##########
java/vector/src/test/java/org/apache/arrow/vector/ipc/BaseFileTest.java:
##########
@@ -846,4 +862,293 @@ protected void validateListAsMapData(VectorSchemaRoot
root) {
}
}
}
+
+ /**
+ * Utility to write permutations of dictionary encoding.
+ *
+ * <p>Writes four record batches of two rows each to {@code stream}, using an
+ * {@code ArrowFileWriter} when the stream is a {@code FileOutputStream} and an
+ * {@code ArrowStreamWriter} otherwise. {@code state} selects which of the four
+ * dictionary-backed vectors (A, B, C, D) are registered with the provider and
+ * included in the root:
+ *
+ * state == 1, one delta dictionary (vectorA).
+ * state == 2, one standalone dictionary (vectorB).
+ * state == 3, one of each (vectorA, vectorB).
+ * state == 4, delta with nothing at start and end (vectorC).
+ * state == 5, both deltas (vectorA, vectorC).
+ * state == 6, both deltas and standalone (vectorA, vectorB, vectorC).
+ * state == 7, replacement dictionary (vectorD).
+ *
+ * <p>Values are set on all four vectors for every batch regardless of
+ * {@code state}; only the second slot of vectorD in batches 2 and 3, and its
+ * first slot in batch 4, are written solely when {@code state == 7}.
+ *
+ * @param stream destination for the written data
+ * @param state permutation selector, 1 through 7
+ * @throws IOException declared by this method; NOTE(review): presumably raised
+ * by the writer calls - confirm
+ * @throws IllegalStateException if {@code state} is not between 1 and 7
+ */
+ protected void writeDataMultiBatchWithDictionaries(OutputStream stream, int
state) throws IOException {
+ DictionaryProvider.MapDictionaryProvider provider = new
DictionaryProvider.MapDictionaryProvider();
+ DictionaryEncoding deltaEncoding =
+ new DictionaryEncoding(42, false, new ArrowType.Int(16, false), true);
+ DictionaryEncoding replacementEncoding =
+ new DictionaryEncoding(24, false, new ArrowType.Int(16, false), false);
+ DictionaryEncoding deltaCEncoding =
+ new DictionaryEncoding(1, false, new ArrowType.Int(16, false), true);
+ DictionaryEncoding replacementEncodingUpdated =
+ new DictionaryEncoding(2, false, new ArrowType.Int(16, false), false);
+
+ boolean isFile = stream instanceof FileOutputStream; // also passed to newDictionary for each vector
+ try (BatchedDictionary vectorA = newDictionary("vectorA", deltaEncoding,
isFile);
+ BatchedDictionary vectorB = newDictionary("vectorB",
replacementEncoding, isFile);
+ BatchedDictionary vectorC = newDictionary("vectorC", deltaCEncoding,
isFile);
+ BatchedDictionary vectorD = newDictionary("vectorD",
replacementEncodingUpdated, isFile);) {
+ switch (state) { // register the dictionaries used by this state
+ case 1:
+ provider.put(vectorA);
+ break;
+ case 2:
+ provider.put(vectorB);
+ break;
+ case 3:
+ provider.put(vectorA);
+ provider.put(vectorB);
+ break;
+ case 4:
+ provider.put(vectorC);
+ break;
+ case 5:
+ provider.put(vectorA);
+ provider.put(vectorC);
+ break;
+ case 6:
+ provider.put(vectorA);
+ provider.put(vectorB);
+ provider.put(vectorC);
+ break;
+ case 7:
+ provider.put(vectorD);
+ break;
+ default:
+ throw new IllegalStateException("Unsupported state: " + state);
+ }
+
+ VectorSchemaRoot root = null;
+ switch (state) { // build the root from the index vectors of the same dictionaries
+ case 1:
+ root = VectorSchemaRoot.of(vectorA.getIndexVector());
+ break;
+ case 2:
+ root = VectorSchemaRoot.of(vectorB.getIndexVector());
+ break;
+ case 3:
+ root = VectorSchemaRoot.of(vectorA.getIndexVector(),
vectorB.getIndexVector());
+ break;
+ case 4:
+ root = VectorSchemaRoot.of(vectorC.getIndexVector());
+ break;
+ case 5:
+ root = VectorSchemaRoot.of(vectorA.getIndexVector(),
vectorC.getIndexVector());
+ break;
+ case 6:
+ root = VectorSchemaRoot.of(vectorA.getIndexVector(),
vectorB.getIndexVector(), vectorC.getIndexVector());
+ break;
+ case 7:
+ root = VectorSchemaRoot.of(vectorD.getIndexVector());
+ break;
+ default:
+ throw new IllegalStateException("Unsupported state: " + state);
+ }
+
+ ArrowWriter arrowWriter = null;
+ try {
+ if (stream instanceof FileOutputStream) {
+ FileChannel channel = ((FileOutputStream) stream).getChannel();
+ arrowWriter = new ArrowFileWriter(root, provider, channel);
+ } else {
+ arrowWriter = new ArrowStreamWriter(root, provider, stream);
+ }
+
+ vectorA.setSafe(0, "foo".getBytes(StandardCharsets.UTF_8));
+ vectorA.setSafe(1, "bar".getBytes(StandardCharsets.UTF_8));
+ vectorB.setSafe(0, "lorem".getBytes(StandardCharsets.UTF_8));
+ vectorB.setSafe(1, "ipsum".getBytes(StandardCharsets.UTF_8));
+ vectorC.setNull(0);
+ vectorC.setNull(1);
+ vectorD.setSafe(0, "porro".getBytes(StandardCharsets.UTF_8));
+ vectorD.setSafe(1, "amet".getBytes(StandardCharsets.UTF_8));
+
+ // batch 1
+ arrowWriter.start();
+ root.setRowCount(2);
+ arrowWriter.writeBatch();
+
+ // batch 2
+ vectorA.setSafe(0, "meep".getBytes(StandardCharsets.UTF_8));
+ vectorA.setSafe(1, "bar".getBytes(StandardCharsets.UTF_8));
+ vectorB.setSafe(0, "ipsum".getBytes(StandardCharsets.UTF_8));
+ vectorB.setSafe(1, "lorem".getBytes(StandardCharsets.UTF_8));
+ vectorC.setSafe(0, "qui".getBytes(StandardCharsets.UTF_8));
+ vectorC.setSafe(1, "dolor".getBytes(StandardCharsets.UTF_8));
+ vectorD.setSafe(0, "amet".getBytes(StandardCharsets.UTF_8));
+ if (state == 7) {
+ vectorD.setSafe(1, "quia".getBytes(StandardCharsets.UTF_8));
+ }
+
+ root.setRowCount(2);
+ arrowWriter.writeBatch();
+
+ // batch 3
+ vectorA.setNull(0);
+ vectorA.setNull(1);
+ vectorB.setSafe(0, "ipsum".getBytes(StandardCharsets.UTF_8));
+ vectorB.setNull(1);
+ vectorC.setNull(0);
+ vectorC.setSafe(1, "qui".getBytes(StandardCharsets.UTF_8));
+ vectorD.setNull(0);
+ if (state == 7) {
+ vectorD.setSafe(1, "quia".getBytes(StandardCharsets.UTF_8));
+ }
+
+ root.setRowCount(2);
+ arrowWriter.writeBatch();
+
+ // batch 4
+ vectorA.setSafe(0, "bar".getBytes(StandardCharsets.UTF_8));
+ vectorA.setSafe(1, "zap".getBytes(StandardCharsets.UTF_8));
+ vectorB.setNull(0);
+ vectorB.setSafe(1, "lorem".getBytes(StandardCharsets.UTF_8));
+ vectorC.setNull(0);
+ vectorC.setNull(1);
+ if (state == 7) {
+ vectorD.setSafe(0, "quia".getBytes(StandardCharsets.UTF_8));
+ }
+ vectorD.setNull(1);
+
+ root.setRowCount(2);
+ arrowWriter.writeBatch();
+
+ arrowWriter.end();
+ } catch (Exception e) { // the writer is closed only on this failure path, then the exception is rethrown
+ if (arrowWriter != null) {
+ arrowWriter.close();
+ }
+ throw e;
+ }
+ }
+ }
+
+ Map<Integer, String[][]> valuesPerBlock = new HashMap<Integer, String[][]>(); // keys 0-3, one entry per batch written by writeDataMultiBatchWithDictionaries
+
+ { // instance initializer; within each entry the rows follow the order vectorA, vectorB, vectorC, vectorD
+ valuesPerBlock.put(0, new String[][]{ // values set before batch 1
+ {"foo", "bar"},
+ {"lorem", "ipsum"},
+ {null, null},
+ {"porro", "amet"}
+ });
+ valuesPerBlock.put(1, new String[][]{ // values set for batch 2; vectorD's "quia" is only written when state == 7
+ {"meep", "bar"},
+ {"ipsum", "lorem"},
+ {"qui", "dolor"},
+ {"amet", "quia"}
+ });
+ valuesPerBlock.put(2, new String[][]{ // values set for batch 3
+ {null, null},
+ {"ipsum", null},
+ {null, "qui"},
+ {null, "quia"}
+ });
+ valuesPerBlock.put(3, new String[][]{ // values set for batch 4
+ {"bar", "zap"},
+ {null, "lorem"},
+ {null, null},
+ {"quia", null}
+ });
+ }
+
+ protected void assertDictionary(FieldVector encoded, ArrowReader reader,
String... expected) throws Exception {
+ DictionaryEncoding dictionaryEncoding = encoded.getField().getDictionary();
+ BaseDictionary dictionary =
reader.getDictionaryVectors().get(dictionaryEncoding.getId());
+ try (ValueVector decoded = DictionaryEncoder.decode(encoded, dictionary)) {
+ Assertions.assertEquals(expected.length, encoded.getValueCount());
+ for (int i = 0; i < expected.length; i++) {
+ if (expected[i] == null) {
+ Assertions.assertNull(decoded.getObject(i));
Review Comment:
ack
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]