Github user fhueske commented on a diff in the pull request:
https://github.com/apache/flink/pull/3511#discussion_r139928885
--- Diff:
flink-runtime/src/main/java/org/apache/flink/runtime/codegeneration/SorterTemplateModel.java
---
@@ -0,0 +1,424 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.codegeneration;
+
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.runtime.operators.sort.NormalizedKeySorter;
+
+import java.nio.ByteOrder;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * {@link SorterTemplateModel} is a class that implements code generation
logic for a given
+ * {@link TypeComparator}.
+ *
+ * <p>The swap and compare methods in {@link NormalizedKeySorter} work on
a sequence of bytes.
+ * We speed up these operations by splitting this sequence of bytes into
chunks that can
+ * be handled by primitive operations such as Integer and Long
operations.</p>
+ */
+class SorterTemplateModel {
+
+ //
------------------------------------------------------------------------
+ // Constants
+ //
------------------------------------------------------------------------
+
+ static final String TEMPLATE_NAME = "sorter.ftlh";
+
+ /** We don't split to chunks above this size. */
+ private static final int SPLITTING_THRESHOLD = 32;
+
+ /**
+ * POSSIBLE_CHUNK_SIZES must be in descending order,
+ * because methods that using it are using greedy approach.
+ */
+ private static final Integer[] POSSIBLE_CHUNK_SIZES = {8, 4, 2, 1};
+
+ /** Mapping from chunk sizes to primitive operators. */
+ private static final HashMap<Integer, String> byteOperatorMapping = new
HashMap<Integer, String>(){
+ {
+ put(8, "Long");
+ put(4, "Int");
+ put(2, "Short");
+ put(1, "Byte");
+ }
+ };
+
+ //
------------------------------------------------------------------------
+ // Attributes
+ //
------------------------------------------------------------------------
+
+ private final TypeComparator typeComparator;
+
+ /**
+ * Sizes of the chunks. Empty, if we are not splitting to chunks. (See
calculateChunks())
+ */
+ private final ArrayList<Integer> primitiveChunks;
+
+ private final String sorterName;
+
+ /**
+ * Shows whether the order of records can be completely determined by
the normalized
+ * sorting key, or the sorter has to also deserialize records if their
keys are equal to
+ * really confirm the order.
+ */
+ private final boolean normalizedKeyFullyDetermines;
+
+ /**
+ * Constructor.
+ * @param typeComparator
+ * The type information of underlying data
+ */
+ SorterTemplateModel(TypeComparator typeComparator){
+ this.typeComparator = typeComparator;
+
+ // number of bytes of the sorting key
+ int numKeyBytes;
+
+ // compute no. bytes for sorting records and check whether
these bytes are just a prefix or not.
+ if (this.typeComparator.supportsNormalizedKey()) {
+ // compute the max normalized key length
+ int numPartialKeys;
+ try {
+ numPartialKeys =
this.typeComparator.getFlatComparators().length;
+ } catch (Throwable t) {
+ numPartialKeys = 1;
+ }
+
+ int maxLen =
Math.min(NormalizedKeySorter.DEFAULT_MAX_NORMALIZED_KEY_LEN,
NormalizedKeySorter.MAX_NORMALIZED_KEY_LEN_PER_ELEMENT * numPartialKeys);
+
+ numKeyBytes =
Math.min(this.typeComparator.getNormalizeKeyLen(), maxLen);
+ this.normalizedKeyFullyDetermines =
!this.typeComparator.isNormalizedKeyPrefixOnly(numKeyBytes);
+ }
+ else {
+ numKeyBytes = 0;
+ this.normalizedKeyFullyDetermines = false;
+ }
+
+ this.primitiveChunks = calculateChunks(numKeyBytes);
+
+ this.sorterName = generateCodeFilename(this.primitiveChunks,
this.normalizedKeyFullyDetermines);
+ }
+
+ //
------------------------------------------------------------------------
+ // Public Methods
+ //
------------------------------------------------------------------------
+
+ /**
+ * Generate suitable sequence of operators for creating custom
NormalizedKeySorter.
+ * @return map of procedures and corresponding code
+ */
+ Map<String, String> getTemplateVariables() {
+
+ Map<String, String> templateVariables = new HashMap<>();
+
+ templateVariables.put("name", this.sorterName);
+
+ String swapProcedures = generateSwapProcedures();
+ String writeProcedures = generateWriteProcedures();
+ String compareProcedures = generateCompareProcedures();
+
+ templateVariables.put("writeProcedures", writeProcedures);
+ templateVariables.put("swapProcedures", swapProcedures);
+ templateVariables.put("compareProcedures", compareProcedures);
+
+ return templateVariables;
+ }
+
+ /**
+ * Getter for sorterName (generated in the constructor).
+ * @return name of the sorter
+ */
+ String getSorterName(){
+ return this.sorterName;
+ }
+
+ //
------------------------------------------------------------------------
+ // Protected Methods
+ //
------------------------------------------------------------------------
+
+ /**
+ * Getter for primitiveChunks.
+ * this method is for testing purposes
+ */
+ ArrayList<Integer> getPrimitiveChunks(){
+ return primitiveChunks;
+ }
+
+ //
------------------------------------------------------------------------
+ // Private Methods
+ //
------------------------------------------------------------------------
+
+ /**
+ * Given no. of bytes, break it into chunks that can be handled by
+ * primitive operations (e.g., integer or long operations)
+ * @return ArrayList of chunk sizes
+ */
+ private ArrayList<Integer> calculateChunks(int numKeyBytes){
+ ArrayList<Integer> chunks = new ArrayList<>();
+
+ // if no. of bytes is too large, we don't split
+ if (numKeyBytes > SPLITTING_THRESHOLD) {
+ return chunks;
+ }
+
+ // also include the offset because of the pointer
+ numKeyBytes += NormalizedKeySorter.OFFSET_LEN;
+
+ // greedy finding of chunk sizes
+ int i = 0;
+ while (numKeyBytes > 0) {
+ int bytes = POSSIBLE_CHUNK_SIZES[i];
+ if (bytes <= numKeyBytes) {
+ chunks.add(bytes);
+ numKeyBytes -= bytes;
+ } else {
+ i++;
+ }
+ }
+
+ // generateCompareProcedures and generateWriteProcedures skip
the
+ // first 8 bytes, because it contains the pointer.
+ // They do this by skipping the first entry of primitiveChunks,
because that
+ // should always be 8 in this case.
+ if (!(NormalizedKeySorter.OFFSET_LEN == 8 &&
chunks.get(0).equals(8))) {
+ throw new RuntimeException("Bug: Incorrect OFFSET_LEN
or primitiveChunks");
+ }
+
+ return chunks;
+ }
+
+ /**
+ * Based on primitiveChunks variable, generate the most suitable
operators
+ * for swapping function.
+ *
+ * @return code used in the swap method
+ */
+ private String generateSwapProcedures(){
+ /* Example generated code, for 20 bytes (8+8+4):
+
+ long temp1 = segI.getLong(segmentOffsetI);
+ long temp2 = segI.getLong(segmentOffsetI+8);
+ int temp3 = segI.getInt(segmentOffsetI+16);
+
+ segI.putLong(segmentOffsetI, segJ.getLong(segmentOffsetJ));
+ segI.putLong(segmentOffsetI+8, segJ.getLong(segmentOffsetJ+8));
+ segI.putInt(segmentOffsetI+16, segJ.getInt(segmentOffsetJ+16));
+
+ segJ.putLong(segmentOffsetJ, temp1);
+ segJ.putLong(segmentOffsetJ+8, temp2);
+ segJ.putInt(segmentOffsetJ+16, temp3);
+ */
+
+ String procedures = "";
+
+ if (this.primitiveChunks.size() > 0) {
+ StringBuilder temporaryString = new StringBuilder();
+ StringBuilder firstSegmentString = new StringBuilder();
+ StringBuilder secondSegmentString = new StringBuilder();
+
+ int accOffset = 0;
+ for (int i = 0; i < primitiveChunks.size(); i++){
+ int numberByte = primitiveChunks.get(i);
+ int varIndex = i + 1;
+
+ String primitiveClass =
byteOperatorMapping.get(numberByte);
+ String primitiveType =
primitiveClass.toLowerCase();
+
+ String offsetString = "";
+ if (i > 0) {
+ accOffset += primitiveChunks.get(i - 1);
+ offsetString = "+" + accOffset;
+ }
+
+ temporaryString.append(String.format("%s temp%d
= segI.get%s(segmentOffsetI%s);\n",
+ primitiveType, varIndex,
primitiveClass, offsetString));
+
+
firstSegmentString.append(String.format("segI.put%s(segmentOffsetI%s,
segJ.get%s(segmentOffsetJ%s));\n",
+ primitiveClass, offsetString,
primitiveClass, offsetString));
+
+
secondSegmentString.append(String.format("segJ.put%s(segmentOffsetJ%s,
temp%d);\n",
+ primitiveClass, offsetString,
varIndex));
+
+ }
+
+ procedures = temporaryString.toString()
+ + "\n" + firstSegmentString.toString()
+ + "\n" + secondSegmentString.toString();
+ } else {
+ procedures = "segI.swapBytes(this.swapBuffer, segJ,
segmentOffsetI, segmentOffsetJ, this.indexEntrySize);";
+ }
+
+ return procedures;
+ }
+
+ /**
+ * Based on primitiveChunks variable, generate reverse byte operators
for little endian machine
+ * for writing a record to MemorySegment, such that later during
comparison
+ * we can directly use native byte order to do unsigned comparison.
+ *
+ * @return code used in the write method
+ */
+ private String generateWriteProcedures(){
+ /* Example generated code, for 12 bytes (8+4):
+
+ long temp1 =
Long.reverseBytes(this.currentSortIndexSegment.getLong(this.currentSortIndexOffset+8));
+
this.currentSortIndexSegment.putLong(this.currentSortIndexOffset + 8, temp1);
+ int temp2 =
Integer.reverseBytes(this.currentSortIndexSegment.getInt(this.currentSortIndexOffset+16));
+ this.currentSortIndexSegment.putInt(this.currentSortIndexOffset
+ 16, temp2);
+ */
+
+ StringBuilder procedures = new StringBuilder();
+ // skip the first chunk, which is the pointer before the key
+ if (primitiveChunks.size() > 1 && ByteOrder.nativeOrder() ==
ByteOrder.LITTLE_ENDIAN) {
+ int offset = 0;
+ // starts from 1 because of skipping the first chunk
+ for (int i = 1; i < primitiveChunks.size(); i++){
+ int noBytes = primitiveChunks.get(i);
+ if (noBytes == 1){
--- End diff --
Please add a space between `)` and `{`, i.e. `if (noBytes == 1) {`.
---