Github user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/21570#discussion_r195636548
--- Diff:
sql/core/src/test/java/test/org/apache/spark/sql/execution/sort/RecordBinaryComparatorSuite.java
---
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.execution.sort;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.TestMemoryConsumer;
+import org.apache.spark.memory.TestMemoryManager;
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
+import org.apache.spark.sql.execution.RecordBinaryComparator;
+import org.apache.spark.unsafe.Platform;
+import org.apache.spark.unsafe.UnsafeAlignedOffset;
+import org.apache.spark.unsafe.array.LongArray;
+import org.apache.spark.unsafe.memory.MemoryBlock;
+import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.util.collection.unsafe.sort.*;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Test the RecordBinaryComparator, which compares two UnsafeRows by their
binary form.
+ */
+public class RecordBinaryComparatorSuite {
+
+ // Task-level memory manager backed by an on-heap test manager; off-heap
+ // storage is explicitly disabled and the task attempt id is 0.
+ private final TaskMemoryManager memoryManager = new TaskMemoryManager(
+ new TestMemoryManager(new
SparkConf().set("spark.memory.offHeap.enabled", "false")), 0);
+ // Consumer that owns the pointer array and the data page allocated below.
+ private final TestMemoryConsumer consumer = new
TestMemoryConsumer(memoryManager);
+
+ // Byte width of the record-length prefix written before each record.
+ private final int uaoSize = UnsafeAlignedOffset.getUaoSize();
+
+ // Page holding the raw bytes of the inserted rows, and the write cursor into it.
+ private MemoryBlock dataPage;
+ private long pageCursor;
+
+ // Pointer array holding one encoded record address per inserted row.
+ private LongArray array;
+ private int pos;
+
+ @Before
+ public void beforeEach() {
+ // Only compare between two input rows.
+ array = consumer.allocateArray(2);
+ pos = 0;
+
+ dataPage = memoryManager.allocatePage(4096, consumer);
+ pageCursor = dataPage.getBaseOffset();
+ }
+
+ @After
+ public void afterEach() {
+ consumer.freePage(dataPage);
+ dataPage = null;
+ pageCursor = 0;
+
+ consumer.freeArray(array);
+ array = null;
+ pos = 0;
+ }
+
+ private void insertRow(UnsafeRow row) {
+ Object recordBase = row.getBaseObject();
+ long recordOffset = row.getBaseOffset();
+ int recordLength = row.getSizeInBytes();
+
+ Object baseObject = dataPage.getBaseObject();
+ assert(pageCursor + recordLength <= dataPage.getBaseOffset() +
dataPage.size());
+ long recordAddress = memoryManager.encodePageNumberAndOffset(dataPage,
pageCursor);
+ UnsafeAlignedOffset.putSize(baseObject, pageCursor, recordLength);
+ pageCursor += uaoSize;
+ Platform.copyMemory(recordBase, recordOffset, baseObject, pageCursor,
recordLength);
+ pageCursor += recordLength;
+
+ assert(pos < 2);
+ array.set(pos, recordAddress);
+ pos++;
+ }
+
+ private int compare(int index1, int index2) {
+ Object baseObject = dataPage.getBaseObject();
+
+ long recordAddress1 = array.get(index1);
+ long baseOffset1 = memoryManager.getOffsetInPage(recordAddress1) +
uaoSize;
+ int recordLength1 = UnsafeAlignedOffset.getSize(baseObject,
baseOffset1 - uaoSize);
+
+ long recordAddress2 = array.get(index2);
+ long baseOffset2 = memoryManager.getOffsetInPage(recordAddress2) +
uaoSize;
+ int recordLength2 = UnsafeAlignedOffset.getSize(baseObject,
baseOffset2 - uaoSize);
+
+ return binaryComparator.compare(baseObject, baseOffset1,
recordLength1, baseObject,
+ baseOffset2, recordLength2);
+ }
+
+ // Comparator under test: orders records by their raw byte representation.
+ private final RecordComparator binaryComparator = new
RecordBinaryComparator();
+
+ // Compute the most compact size for UnsafeRow's backing data.
+ private int computeSizeInBytes(int originalSize) {
+ // All the UnsafeRows in this suite contains less than 64 columns, so
the bitSetSize shall
+ // always be 8.
+ return 8 + (originalSize + 7) / 8 * 8;
+ }
+
+ // Compute the relative offset of variable-length values.
+ private long relativeOffset(int numFields) {
+ // All the UnsafeRows in this suite contains less than 64 columns, so
the bitSetSize shall
+ // always be 8.
+ return 8 + numFields * Long.BYTES;
+ }
+
+ @Test
+ public void testBinaryComparatorForSingleColumnRow() throws Exception {
+ int numFields = 1;
+
+ UnsafeRow row1 = new UnsafeRow(numFields);
+ byte[] data1 = new byte[100];
+ row1.pointTo(data1, computeSizeInBytes(numFields * Long.BYTES));
+ row1.setInt(0, 11);
+
+ UnsafeRow row2 = new UnsafeRow(numFields);
+ byte[] data2 = new byte[100];
+ row2.pointTo(data2, computeSizeInBytes(numFields * Long.BYTES));
+ row2.setInt(0, 42);
+
+ insertRow(row1);
+ insertRow(row2);
+
+ assert(compare(0, 0) == 0);
+ assert(compare(0, 1) < 0);
+ }
+
+ @Test
+ public void testBinaryComparatorForMultipleColumnRow() throws Exception {
+ int numFields = 5;
+
+ UnsafeRow row1 = new UnsafeRow(numFields);
+ byte[] data1 = new byte[100];
+ row1.pointTo(data1, computeSizeInBytes(numFields * Double.BYTES));
--- End diff ---
nit: ditto. Regardless of the data type, each fixed-width slot is always 8 bytes wide.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]