This is an automated email from the ASF dual-hosted git repository.

junhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new f4d052133 [Flink] Introduce range strategy for sort compaction. (#2749)
f4d052133 is described below

commit f4d05213395ced891dc5081401bde794cb7353b7
Author: wgcn <[email protected]>
AuthorDate: Mon Mar 18 14:52:05 2024 +0800

    [Flink] Introduce range strategy for sort compaction. (#2749)
---
 .../shortcodes/generated/core_configuration.html   |   7 +
 .../main/java/org/apache/paimon/CoreOptions.java   |  18 ++
 .../paimon/types/InternalRowToSizeVisitor.java     | 299 +++++++++++++++++++++
 .../paimon/types/InternalRowToSizeVisitorTest.java | 195 ++++++++++++++
 .../apache/paimon/flink/shuffle/RangeShuffle.java  | 167 +++++++++---
 .../org/apache/paimon/flink/sorter/SortUtils.java  |   4 +-
 .../SortCompactActionForUnawareBucketITCase.java   |  34 ++-
 .../paimon/flink/shuffle/RangeShuffleTest.java     |  75 ++++++
 8 files changed, 750 insertions(+), 49 deletions(-)

diff --git a/docs/layouts/shortcodes/generated/core_configuration.html 
b/docs/layouts/shortcodes/generated/core_configuration.html
index 52645cf8c..ecf92c2f1 100644
--- a/docs/layouts/shortcodes/generated/core_configuration.html
+++ b/docs/layouts/shortcodes/generated/core_configuration.html
@@ -593,6 +593,13 @@ This config option does not affect the default filesystem 
metastore.</td>
             <td>Duration</td>
             <td>In watermarking, if a source remains idle beyond the specified 
timeout duration, it triggers snapshot advancement and facilitates tag 
creation.</td>
         </tr>
+        <tr>
+            <td><h5>sort-compaction.range-strategy</h5></td>
+            <td style="word-wrap: break-word;">QUANTITY</td>
+            <td><p>Enum</p></td>
+            <td>The range strategy of sort compaction, the default value is 
quantity.
+If the data size allocated for the sorting task is uneven,which may lead to 
performance bottlenecks, the config can be set to size.<br /><br />Possible 
values:<ul><li>"SIZE"</li><li>"QUANTITY"</li></ul></td>
+        </tr>
         <tr>
             <td><h5>sort-engine</h5></td>
             <td style="word-wrap: break-word;">loser-tree</td>
diff --git a/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java 
b/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java
index 19554296c..9deeb6324 100644
--- a/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java
+++ b/paimon-common/src/main/java/org/apache/paimon/CoreOptions.java
@@ -1083,6 +1083,14 @@ public class CoreOptions implements Serializable {
                             "Whether to enable deletion vectors mode. In this 
mode, index files containing deletion"
                                     + " vectors are generated when data is 
written, which marks the data for deletion."
                                     + " During read operations, by applying 
these index files, merging can be avoided.");
+
+    /**
+     * Selects the {@link RangeStrategy} used by sort compaction; defaults to {@link
+     * RangeStrategy#QUANTITY}.
+     */
+    public static final ConfigOption<RangeStrategy> SORT_RANG_STRATEGY =
+            key("sort-compaction.range-strategy")
+                    .enumType(RangeStrategy.class)
+                    .defaultValue(RangeStrategy.QUANTITY)
+                    .withDescription(
+                            "The range strategy of sort compaction, the default value is quantity.\n"
+                                    + "If the data size allocated for the sorting task is uneven,which "
+                                    + "may lead to performance bottlenecks, the config can be set to size.");
 
     private final Options options;
 
@@ -1150,6 +1158,10 @@ public class CoreOptions implements Serializable {
         return options.get(PARTITION_DEFAULT_NAME);
     }
 
+    public boolean sortBySize() {
+        return options.get(SORT_RANG_STRATEGY) == RangeStrategy.SIZE;
+    }
+
     public static FileFormat createFileFormat(
             Options options, ConfigOption<FileFormatType> formatOption) {
         String formatIdentifier = options.get(formatOption).toString();
@@ -2210,6 +2222,12 @@ public class CoreOptions implements Serializable {
         }
     }
 
+    /** Specifies how sort compaction weighs records when computing key ranges. */
+    public enum RangeStrategy {
+        /** Weigh every record by its estimated size in bytes. */
+        SIZE,
+        /** Weigh every record equally, i.e. by record count. */
+        QUANTITY;
+    }
+
     /** Specifies the log consistency mode for table. */
     public enum ConsumerMode implements DescribedEnum {
         EXACTLY_ONCE(
diff --git 
a/paimon-common/src/main/java/org/apache/paimon/types/InternalRowToSizeVisitor.java
 
b/paimon-common/src/main/java/org/apache/paimon/types/InternalRowToSizeVisitor.java
new file mode 100644
index 000000000..4e1ad782c
--- /dev/null
+++ 
b/paimon-common/src/main/java/org/apache/paimon/types/InternalRowToSizeVisitor.java
@@ -0,0 +1,299 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.types;
+
+import org.apache.paimon.data.DataGetters;
+import org.apache.paimon.data.InternalArray;
+import org.apache.paimon.data.InternalMap;
+import org.apache.paimon.data.InternalRow;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.BiFunction;
+
+/** The class is to calculate the occupied space size based on Datatype. */
+public class InternalRowToSizeVisitor
+        implements DataTypeVisitor<BiFunction<DataGetters, Integer, Integer>> {
+
+    public static final int NULL_SIZE = 0;
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(CharType charType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return row.getString(index).toBytes().length;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(VarCharType 
varCharType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return row.getString(index).toBytes().length;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(BooleanType 
booleanType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return 1;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(BinaryType 
binaryType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return row.getBinary(index).length;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(VarBinaryType 
varBinaryType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return row.getBinary(index).length;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(DecimalType 
decimalType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return row.getDecimal(index, decimalType.getPrecision(), 
decimalType.getScale())
+                        .toUnscaledBytes()
+                        .length;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(TinyIntType 
tinyIntType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return 1;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(SmallIntType 
smallIntType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return 2;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(IntType intType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return 4;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(BigIntType 
bigIntType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return 8;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(FloatType 
floatType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return 4;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(DoubleType 
doubleType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return 8;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(DateType dateType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return 4;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(TimeType timeType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return 4;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(TimestampType 
timestampType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return 8;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(
+            LocalZonedTimestampType localZonedTimestampType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                return 8;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(ArrayType 
arrayType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                BiFunction<DataGetters, Integer, Integer> function =
+                        arrayType.getElementType().accept(this);
+                InternalArray internalArray = row.getArray(index);
+
+                int size = 0;
+                for (int i = 0; i < internalArray.size(); i++) {
+                    size += function.apply(internalArray, i);
+                }
+
+                return size;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(MultisetType 
multisetType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                BiFunction<DataGetters, Integer, Integer> function =
+                        multisetType.getElementType().accept(this);
+                InternalMap map = row.getMap(index);
+
+                int size = 0;
+                for (int i = 0; i < map.size(); i++) {
+                    size += function.apply(map.keyArray(), i);
+                }
+
+                return size;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(MapType mapType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+
+                BiFunction<DataGetters, Integer, Integer> keyFunction =
+                        mapType.getKeyType().accept(this);
+                BiFunction<DataGetters, Integer, Integer> valueFunction =
+                        mapType.getValueType().accept(this);
+
+                InternalMap map = row.getMap(index);
+
+                int size = 0;
+                for (int i = 0; i < map.size(); i++) {
+                    size += keyFunction.apply(map.keyArray(), i);
+                }
+
+                for (int i = 0; i < map.size(); i++) {
+                    size += valueFunction.apply(map.valueArray(), i);
+                }
+
+                return size;
+            }
+        };
+    }
+
+    @Override
+    public BiFunction<DataGetters, Integer, Integer> visit(RowType rowType) {
+        return (row, index) -> {
+            if (row.isNullAt(index)) {
+                return NULL_SIZE;
+            } else {
+                int size = 0;
+                List<DataType> fieldTypes = rowType.getFieldTypes();
+                InternalRow nestRow = row.getRow(index, 
rowType.getFieldCount());
+                for (int i = 0; i < fieldTypes.size(); i++) {
+                    DataType dataType = fieldTypes.get(i);
+                    size += dataType.accept(this).apply(nestRow, i);
+                }
+                return size;
+            }
+        };
+    }
+}
diff --git 
a/paimon-common/src/test/java/org/apache/paimon/types/InternalRowToSizeVisitorTest.java
 
b/paimon-common/src/test/java/org/apache/paimon/types/InternalRowToSizeVisitorTest.java
new file mode 100644
index 000000000..cfdae649c
--- /dev/null
+++ 
b/paimon-common/src/test/java/org/apache/paimon/types/InternalRowToSizeVisitorTest.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.types;
+
+import org.apache.paimon.data.BinaryString;
+import org.apache.paimon.data.DataGetters;
+import org.apache.paimon.data.Decimal;
+import org.apache.paimon.data.GenericArray;
+import org.apache.paimon.data.GenericMap;
+import org.apache.paimon.data.GenericRow;
+import org.apache.paimon.data.Timestamp;
+
+import org.assertj.core.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import java.math.BigDecimal;
+import java.util.Collections;
+import java.util.List;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
+
+/** Test for {@link InternalRowToSizeVisitor}. */
+public class InternalRowToSizeVisitorTest {
+
+    /** One size function per field of the row type built in {@link #setUp()}. */
+    private List<BiFunction<DataGetters, Integer, Integer>> fieldSizeCalculator;
+
+    @BeforeEach
+    void setUp() {
+        RowType rowType =
+                RowType.builder()
+                        .field("a0", DataTypes.INT())
+                        .field("a1", DataTypes.TINYINT())
+                        .field("a2", DataTypes.SMALLINT())
+                        .field("a3", DataTypes.BIGINT())
+                        .field("a4", DataTypes.STRING())
+                        .field("a5", DataTypes.DOUBLE())
+                        .field("a6", DataTypes.ARRAY(DataTypes.STRING()))
+                        .field("a7", DataTypes.CHAR(100))
+                        .field("a8", DataTypes.VARCHAR(100))
+                        .field("a9", DataTypes.BOOLEAN())
+                        .field("a10", DataTypes.DATE())
+                        .field("a11", DataTypes.TIME())
+                        .field("a12", DataTypes.TIMESTAMP())
+                        .field("a13", DataTypes.TIMESTAMP_MILLIS())
+                        .field("a14", DataTypes.DECIMAL(3, 3))
+                        .field("a15", DataTypes.BYTES())
+                        .field("a16", DataTypes.FLOAT())
+                        .field("a17", DataTypes.MAP(DataTypes.STRING(), DataTypes.STRING()))
+                        .field("a18", DataTypes.ROW(DataTypes.FIELD(100, "b1", DataTypes.STRING())))
+                        .field("a19", DataTypes.BINARY(100))
+                        .field("a20", DataTypes.VARBINARY(100))
+                        .field("a21", DataTypes.MULTISET(DataTypes.STRING()))
+                        .field(
+                                "a22",
+                                DataTypes.ROW(
+                                        DataTypes.FIELD(
+                                                101,
+                                                "b2",
+                                                DataTypes.ROW(
+                                                        DataTypes.FIELD(
+                                                                102,
+                                                                "b3",
+                                                                DataTypes.MAP(
+                                                                        DataTypes.STRING(),
+                                                                        DataTypes.STRING())),
+                                                        DataTypes.FIELD(
+                                                                103,
+                                                                "b4",
+                                                                DataTypes.ARRAY(
+                                                                        DataTypes.STRING())),
+                                                        DataTypes.FIELD(
+                                                                104,
+                                                                "b5",
+                                                                DataTypes.MULTISET(
+                                                                        DataTypes.STRING()))))))
+                        .field("a23", DataTypes.MULTISET(DataTypes.STRING()))
+                        .build();
+
+        InternalRowToSizeVisitor internalRowToSizeVisitor = new InternalRowToSizeVisitor();
+        fieldSizeCalculator =
+                rowType.getFieldTypes().stream()
+                        .map(dataType -> dataType.accept(internalRowToSizeVisitor))
+                        .collect(Collectors.toList());
+    }
+
+    @Test
+    void testCalculatorSize() {
+        GenericRow row = new GenericRow(24);
+
+        row.setField(0, 1);
+        Assertions.assertThat(fieldSizeCalculator.get(0).apply(row, 0)).isEqualTo(4);
+
+        row.setField(1, (byte) 1);
+        Assertions.assertThat(fieldSizeCalculator.get(1).apply(row, 1)).isEqualTo(1);
+
+        row.setField(2, (short) 1);
+        Assertions.assertThat(fieldSizeCalculator.get(2).apply(row, 2)).isEqualTo(2);
+
+        row.setField(3, 1L);
+        Assertions.assertThat(fieldSizeCalculator.get(3).apply(row, 3)).isEqualTo(8);
+
+        row.setField(4, BinaryString.fromString("a"));
+        Assertions.assertThat(fieldSizeCalculator.get(4).apply(row, 4)).isEqualTo(1);
+
+        row.setField(5, 0.5D);
+        Assertions.assertThat(fieldSizeCalculator.get(5).apply(row, 5)).isEqualTo(8);
+
+        row.setField(6, new GenericArray(new Object[] {BinaryString.fromString("1")}));
+        Assertions.assertThat(fieldSizeCalculator.get(6).apply(row, 6)).isEqualTo(1);
+
+        row.setField(7, BinaryString.fromString("3"));
+        Assertions.assertThat(fieldSizeCalculator.get(7).apply(row, 7)).isEqualTo(1);
+
+        row.setField(8, BinaryString.fromString("3"));
+        Assertions.assertThat(fieldSizeCalculator.get(8).apply(row, 8)).isEqualTo(1);
+
+        row.setField(9, true);
+        Assertions.assertThat(fieldSizeCalculator.get(9).apply(row, 9)).isEqualTo(1);
+
+        row.setField(10, 375);
+        Assertions.assertThat(fieldSizeCalculator.get(10).apply(row, 10)).isEqualTo(4);
+
+        row.setField(11, 100);
+        Assertions.assertThat(fieldSizeCalculator.get(11).apply(row, 11)).isEqualTo(4);
+
+        row.setField(12, Timestamp.fromEpochMillis(1685548953000L));
+        Assertions.assertThat(fieldSizeCalculator.get(12).apply(row, 12)).isEqualTo(8);
+
+        row.setField(13, Timestamp.fromEpochMillis(1685548953000L));
+        Assertions.assertThat(fieldSizeCalculator.get(13).apply(row, 13)).isEqualTo(8);
+
+        row.setField(14, Decimal.fromBigDecimal(new BigDecimal("0.22"), 3, 3));
+        Assertions.assertThat(fieldSizeCalculator.get(14).apply(row, 14)).isEqualTo(2);
+
+        row.setField(15, new byte[] {1, 5, 2});
+        Assertions.assertThat(fieldSizeCalculator.get(15).apply(row, 15)).isEqualTo(3);
+
+        row.setField(16, 0.26F);
+        Assertions.assertThat(fieldSizeCalculator.get(16).apply(row, 16)).isEqualTo(4);
+
+        row.setField(
+                17,
+                new GenericMap(
+                        Collections.singletonMap(
+                                BinaryString.fromString("k"), BinaryString.fromString("v"))));
+        Assertions.assertThat(fieldSizeCalculator.get(17).apply(row, 17)).isEqualTo(2);
+
+        row.setField(18, GenericRow.of(BinaryString.fromString("cc")));
+        Assertions.assertThat(fieldSizeCalculator.get(18).apply(row, 18)).isEqualTo(2);
+
+        row.setField(19, "bb".getBytes());
+        Assertions.assertThat(fieldSizeCalculator.get(19).apply(row, 19)).isEqualTo(2);
+
+        row.setField(20, "aa".getBytes());
+        Assertions.assertThat(fieldSizeCalculator.get(20).apply(row, 20)).isEqualTo(2);
+
+        row.setField(
+                21, new GenericMap(Collections.singletonMap(BinaryString.fromString("set"), 1)));
+
+        Assertions.assertThat(fieldSizeCalculator.get(21).apply(row, 21)).isEqualTo(3);
+
+        row.setField(
+                22,
+                GenericRow.of(
+                        GenericRow.of(
+                                new GenericMap(
+                                        Collections.singletonMap(
+                                                BinaryString.fromString("k"),
+                                                BinaryString.fromString("v"))),
+                                new GenericArray(new Object[] {BinaryString.fromString("1")}),
+                                new GenericMap(
+                                        Collections.singletonMap(
+                                                BinaryString.fromString("set"), 1)))));
+        Assertions.assertThat(fieldSizeCalculator.get(22).apply(row, 22)).isEqualTo(6);
+
+        Assertions.assertThat(fieldSizeCalculator.get(23).apply(row, 23)).isEqualTo(0);
+    }
+
+    @Test
+    void testNullValues() {
+        // Every top-level field of a freshly created GenericRow is null.
+        GenericRow allNulls = new GenericRow(24);
+        for (int i = 0; i < fieldSizeCalculator.size(); i++) {
+            Assertions.assertThat(fieldSizeCalculator.get(i).apply(allNulls, i))
+                    .isEqualTo(InternalRowToSizeVisitor.NULL_SIZE);
+        }
+
+        // A null element nested inside an array contributes nothing to the array's size.
+        GenericRow row = new GenericRow(24);
+        row.setField(6, new GenericArray(new Object[] {null, BinaryString.fromString("ab")}));
+        Assertions.assertThat(fieldSizeCalculator.get(6).apply(row, 6)).isEqualTo(2);
+    }
+}
diff --git 
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/shuffle/RangeShuffle.java
 
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/shuffle/RangeShuffle.java
index a81cc7710..9c67e8855 100644
--- 
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/shuffle/RangeShuffle.java
+++ 
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/shuffle/RangeShuffle.java
@@ -18,18 +18,26 @@
 
 package org.apache.paimon.flink.shuffle;
 
+import org.apache.paimon.annotation.VisibleForTesting;
+import org.apache.paimon.data.DataGetters;
+import org.apache.paimon.flink.FlinkRowWrapper;
+import org.apache.paimon.types.InternalRowToSizeVisitor;
+import org.apache.paimon.types.RowType;
 import org.apache.paimon.utils.Pair;
 import org.apache.paimon.utils.SerializableSupplier;
 
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.functions.Partitioner;
+import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.dag.Transformation;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.api.java.typeutils.ListTypeInfo;
 import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.operators.BoundedOneInput;
 import org.apache.flink.streaming.api.operators.InputSelectable;
@@ -54,12 +62,13 @@ import org.apache.flink.util.XORShiftRandom;
 
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Collections;
 import java.util.Comparator;
 import java.util.Iterator;
 import java.util.List;
 import java.util.PriorityQueue;
 import java.util.Random;
+import java.util.function.BiFunction;
+import java.util.stream.Collectors;
 
 /**
  * RangeShuffle Util to shuffle the input stream by the sampling range. See 
`rangeShuffleBykey`
@@ -89,28 +98,33 @@ public class RangeShuffle {
             TypeInformation<T> keyTypeInformation,
             int sampleSize,
             int rangeNum,
-            int outParallelism) {
+            int outParallelism,
+            RowType valueRowType,
+            boolean isSortBySize) {
         Transformation<Tuple2<T, RowData>> input = 
inputDataStream.getTransformation();
 
-        OneInputTransformation<Tuple2<T, RowData>, T> keyInput =
+        OneInputTransformation<Tuple2<T, RowData>, Tuple2<T, Integer>> 
keyInput =
                 new OneInputTransformation<>(
                         input,
-                        "ABSTRACT KEY",
-                        new StreamMap<>(a -> a.f0),
-                        keyTypeInformation,
+                        "ABSTRACT KEY AND SIZE",
+                        new StreamMap<>(new 
KeyAndSizeExtractor<>(valueRowType, isSortBySize)),
+                        new TupleTypeInfo<>(keyTypeInformation, 
BasicTypeInfo.INT_TYPE_INFO),
                         input.getParallelism());
 
         // 1. Fixed size sample in each partitions.
-        OneInputTransformation<T, Tuple2<Double, T>> localSample =
+        OneInputTransformation<Tuple2<T, Integer>, Tuple3<Double, T, Integer>> 
localSample =
                 new OneInputTransformation<>(
                         keyInput,
                         "LOCAL SAMPLE",
                         new LocalSampleOperator<>(sampleSize),
-                        new TupleTypeInfo<>(BasicTypeInfo.DOUBLE_TYPE_INFO, 
keyTypeInformation),
+                        new TupleTypeInfo<>(
+                                BasicTypeInfo.DOUBLE_TYPE_INFO,
+                                keyTypeInformation,
+                                BasicTypeInfo.INT_TYPE_INFO),
                         keyInput.getParallelism());
 
         // 2. Collect all the samples and gather them into a sorted key range.
-        OneInputTransformation<Tuple2<Double, T>, List<T>> sampleAndHistogram =
+        OneInputTransformation<Tuple3<Double, T, Integer>, List<T>> 
sampleAndHistogram =
                 new OneInputTransformation<>(
                         localSample,
                         "GLOBAL SAMPLE",
@@ -155,6 +169,47 @@ public class RangeShuffle {
                         outParallelism));
     }
 
+    /**
+     * Maps each {@code (key, row)} pair to {@code (key, size)}.
+     *
+     * <p>When sorting by size, {@code size} is the estimated byte size of the row, summed over all
+     * fields of {@code rowType} using {@link InternalRowToSizeVisitor}. Otherwise (quantity
+     * strategy) every record gets the constant size 1.
+     */
+    public static class KeyAndSizeExtractor<T>
+            extends RichMapFunction<Tuple2<T, RowData>, Tuple2<T, Integer>> {
+
+        private static final long serialVersionUID = 1L;
+
+        private final RowType rowType;
+        private final boolean isSortBySize;
+
+        /** Per-field size functions; only initialized in {@link #open} when sorting by size. */
+        private transient List<BiFunction<DataGetters, Integer, Integer>> fieldSizeCalculator;
+
+        public KeyAndSizeExtractor(RowType rowType, boolean isSortBySize) {
+            this.rowType = rowType;
+            this.isSortBySize = isSortBySize;
+        }
+
+        @Override
+        public void open(Configuration parameters) throws Exception {
+            super.open(parameters);
+            // The calculators are only read by map() under the size strategy.
+            if (isSortBySize) {
+                InternalRowToSizeVisitor internalRowToSizeVisitor = new InternalRowToSizeVisitor();
+                fieldSizeCalculator =
+                        rowType.getFieldTypes().stream()
+                                .map(dataType -> dataType.accept(internalRowToSizeVisitor))
+                                .collect(Collectors.toList());
+            }
+        }
+
+        @Override
+        public Tuple2<T, Integer> map(Tuple2<T, RowData> keyAndRowData) throws Exception {
+            if (!isSortBySize) {
+                // when basing on quantity, we don't need the size of the data, so setting it to a
+                // constant of 1 would be sufficient.
+                return new Tuple2<>(keyAndRowData.f0, 1);
+            }
+            // Wrap the Flink row once per record instead of once per field.
+            DataGetters row = new FlinkRowWrapper(keyAndRowData.f1);
+            int size = 0;
+            for (int i = 0; i < fieldSizeCalculator.size(); i++) {
+                size += fieldSizeCalculator.get(i).apply(row, i);
+            }
+            return new Tuple2<>(keyAndRowData.f0, size);
+        }
+    }
+
     /**
      * LocalSampleOperator wraps the sample logic on the partition side (the 
first phase of
      * distributed sample algorithm). Outputs sampled weight with record.
@@ -162,15 +217,17 @@ public class RangeShuffle {
      * <p>See {@link Sampler}.
      */
     @Internal
-    public static class LocalSampleOperator<T> extends 
TableStreamOperator<Tuple2<Double, T>>
-            implements OneInputStreamOperator<T, Tuple2<Double, T>>, 
BoundedOneInput {
+    public static class LocalSampleOperator<T>
+            extends TableStreamOperator<Tuple3<Double, T, Integer>>
+            implements OneInputStreamOperator<Tuple2<T, Integer>, 
Tuple3<Double, T, Integer>>,
+                    BoundedOneInput {
 
         private static final long serialVersionUID = 1L;
 
         private final int numSample;
 
-        private transient Collector<Tuple2<Double, T>> collector;
-        private transient Sampler<T> sampler;
+        private transient Collector<Tuple3<Double, T, Integer>> collector;
+        private transient Sampler<Tuple2<T, Integer>> sampler;
 
         public LocalSampleOperator(int numSample) {
             this.numSample = numSample;
@@ -184,15 +241,17 @@ public class RangeShuffle {
         }
 
         @Override
-        public void processElement(StreamRecord<T> streamRecord) throws 
Exception {
+        public void processElement(StreamRecord<Tuple2<T, Integer>> 
streamRecord) throws Exception {
             sampler.collect(streamRecord.getValue());
         }
 
         @Override
-        public void endInput() throws Exception {
-            Iterator<Tuple2<Double, T>> sampled = sampler.sample();
+        public void endInput() {
+            Iterator<Tuple2<Double, Tuple2<T, Integer>>> sampled = 
sampler.sample();
             while (sampled.hasNext()) {
-                collector.collect(sampled.next());
+                Tuple2<Double, Tuple2<T, Integer>> next = sampled.next();
+
+                collector.collect(new Tuple3<>(next.f0, next.f1.f0, 
next.f1.f1));
             }
         }
     }
@@ -203,7 +262,8 @@ public class RangeShuffle {
      * <p>See {@link Sampler}.
      */
     private static class GlobalSampleOperator<T> extends 
TableStreamOperator<List<T>>
-            implements OneInputStreamOperator<Tuple2<Double, T>, List<T>>, 
BoundedOneInput {
+            implements OneInputStreamOperator<Tuple3<Double, T, Integer>, 
List<T>>,
+                    BoundedOneInput {
 
         private static final long serialVersionUID = 1L;
 
@@ -213,7 +273,7 @@ public class RangeShuffle {
 
         private transient Comparator<T> keyComparator;
         private transient Collector<List<T>> collector;
-        private transient Sampler<T> sampler;
+        private transient Sampler<Tuple2<T, Integer>> sampler;
 
         public GlobalSampleOperator(
                 int numSample,
@@ -233,35 +293,32 @@ public class RangeShuffle {
         }
 
         @Override
-        public void processElement(StreamRecord<Tuple2<Double, T>> record) 
throws Exception {
-            Tuple2<Double, T> tuple = record.getValue();
-            sampler.collect(tuple.f0, tuple.f1);
+        public void processElement(StreamRecord<Tuple3<Double, T, Integer>> 
record)
+                throws Exception {
+            Tuple3<Double, T, Integer> tuple = record.getValue();
+            sampler.collect(tuple.f0, new Tuple2<>(tuple.f1, tuple.f2));
         }
 
         @Override
-        public void endInput() throws Exception {
-            Iterator<Tuple2<Double, T>> sampled = sampler.sample();
+        public void endInput() {
+            Iterator<Tuple2<Double, Tuple2<T, Integer>>> sampled = 
sampler.sample();
+
+            List<Tuple2<T, Integer>> sampledData = new ArrayList<>();
 
-            List<T> sampledData = new ArrayList<>();
             while (sampled.hasNext()) {
                 sampledData.add(sampled.next().f1);
             }
 
-            sampledData.sort(keyComparator);
+            sampledData.sort((o1, o2) -> keyComparator.compare(o1.f0, o2.f0));
 
-            int boundarySize = rangesNum - 1;
-            @SuppressWarnings("unchecked")
-            T[] boundaries = (T[]) new Object[boundarySize];
-            if (sampledData.size() > 0) {
-                double avgRange = sampledData.size() / (double) rangesNum;
-                for (int i = 1; i < rangesNum; i++) {
-                    T record = sampledData.get((int) (i * avgRange));
-                    boundaries[i - 1] = record;
-                }
-                collector.collect(Arrays.asList(boundaries));
+            List<T> range;
+            if (sampledData.isEmpty()) {
+                range = new ArrayList<>();
             } else {
-                collector.collect(Collections.emptyList());
+                range = Arrays.asList(allocateRangeBaseSize(sampledData, 
rangesNum));
             }
+
+            collector.collect(range);
         }
     }
 
@@ -488,4 +545,40 @@ public class RangeShuffle {
             return list.get(RANDOM.nextInt(list.size()));
         }
     }
+
+    /**
+     * Picks {@code rangesNum - 1} boundary keys so that each range carries a roughly equal
+     * share of the total sample weight ({@code f1}). Expects samples already sorted by key.
+     * Returns an empty array when there are no samples.
+     */
+    @VisibleForTesting
+    static <T> T[] allocateRangeBaseSize(List<Tuple2<T, Integer>> sampledData, int rangesNum) {
+        int sampleNum = sampledData.size();
+        // No samples means no candidate keys; avoid sampledData.get(-1) in the padding loop.
+        int boundarySize = sampleNum == 0 ? 0 : rangesNum - 1;
+        @SuppressWarnings("unchecked")
+        T[] boundaries = (T[]) new Object[boundarySize];
+        long restSize = sampledData.stream().mapToLong(t -> (long) t.f1).sum();
+        double stepRange = restSize / (double) rangesNum;
+        int index = 0;
+        // Greedily close each range once its weight reaches an even share of what remains.
+        for (int i = 0; i < boundarySize; i++) {
+            long currentWeight = 0; // long, matching restSize, so the sum cannot overflow
+            while (currentWeight < stepRange && index < sampleNum) {
+                Tuple2<T, Integer> sample = sampledData.get(index++);
+                boundaries[i] = sample.f0;
+                currentWeight += sample.f1;
+                restSize -= sample.f1;
+            }
+            stepRange = restSize / (double) (rangesNum - i - 1);
+        }
+
+        // Samples ran out before every boundary was set: pad with the last sample's key.
+        for (int i = 0; i < boundarySize; i++) {
+            if (boundaries[i] == null) {
+                boundaries[i] = sampledData.get(sampleNum - 1).f0;
+            }
+        }
+        return boundaries;
+    }
 }
diff --git 
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sorter/SortUtils.java
 
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sorter/SortUtils.java
index 4d049ff72..c31aaa1a4 100644
--- 
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sorter/SortUtils.java
+++ 
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sorter/SortUtils.java
@@ -146,7 +146,9 @@ public class SortUtils {
                         keyTypeInformation,
                         sampleSize,
                         rangeNum,
-                        sinkParallelism)
+                        sinkParallelism,
+                        valueRowType,
+                        options.sortBySize())
                 .map(
                         a -> new JoinedRow(convertor.apply(a.f0), new 
FlinkRowWrapper(a.f1)),
                         internalRowType)
diff --git 
a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/action/SortCompactActionForUnawareBucketITCase.java
 
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/action/SortCompactActionForUnawareBucketITCase.java
index ee43ca58e..272b3516a 100644
--- 
a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/action/SortCompactActionForUnawareBucketITCase.java
+++ 
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/action/SortCompactActionForUnawareBucketITCase.java
@@ -331,30 +331,35 @@ public class SortCompactActionForUnawareBucketITCase 
extends ActionITCaseBase {
     }
 
     private void zorder(List<String> columns) throws Exception {
+        String rangeStrategy = RANDOM.nextBoolean() ? "size" : "quantity";
         if (RANDOM.nextBoolean()) {
-            createAction("zorder", columns).run();
+            createAction("zorder", rangeStrategy, columns).run();
         } else {
-            callProcedure("zorder", columns);
+            callProcedure("zorder", rangeStrategy, columns);
         }
     }
 
     private void hilbert(List<String> columns) throws Exception {
+        String rangeStrategy = RANDOM.nextBoolean() ? "size" : "quantity";
         if (RANDOM.nextBoolean()) {
-            createAction("hilbert", columns).run();
+            createAction("hilbert", rangeStrategy, columns).run();
         } else {
-            callProcedure("hilbert", columns);
+            callProcedure("hilbert", rangeStrategy, columns);
         }
     }
 
     private void order(List<String> columns) throws Exception {
+        String rangeStrategy = RANDOM.nextBoolean() ? "size" : "quantity";
         if (RANDOM.nextBoolean()) {
-            createAction("order", columns).run();
+            createAction("order", rangeStrategy, columns).run();
         } else {
-            callProcedure("order", columns);
+            callProcedure("order", rangeStrategy, columns);
         }
     }
 
-    private SortCompactAction createAction(String orderStrategy, List<String> columns) {
+    private SortCompactAction createAction(
+            String orderStrategy, String rangeStrategy, List<String> columns) {
+
         return createAction(
                 SortCompactAction.class,
                 "compact",
@@ -367,14 +372,21 @@ public class SortCompactActionForUnawareBucketITCase extends ActionITCaseBase {
                 "--order_strategy",
                 orderStrategy,
                 "--order_by",
-                String.join(",", columns));
+                String.join(",", columns),
+                "--table_conf", // flag and value must be separate argv entries to be parsed
+                "sort-compaction.range-strategy=" + rangeStrategy);
     }
 
-    private void callProcedure(String orderStrategy, List<String> 
orderByColumns) {
+    private void callProcedure(
+            String orderStrategy, String rangeStrategy, List<String> 
orderByColumns) {
         callProcedure(
                 String.format(
-                        "CALL sys.compact('%s.%s', 'ALL', '%s', '%s')",
-                        database, tableName, orderStrategy, String.join(",", 
orderByColumns)),
+                        "CALL sys.compact('%s.%s', 'ALL', '%s', 
'%s','sort-compaction.range-strategy=%s')",
+                        database,
+                        tableName,
+                        orderStrategy,
+                        String.join(",", orderByColumns),
+                        rangeStrategy),
                 false,
                 true);
     }
diff --git 
a/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/shuffle/RangeShuffleTest.java
 
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/shuffle/RangeShuffleTest.java
new file mode 100644
index 000000000..cb79a8628
--- /dev/null
+++ 
b/paimon-flink/paimon-flink-common/src/test/java/org/apache/paimon/flink/shuffle/RangeShuffleTest.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.paimon.flink.shuffle;
+
+import org.apache.paimon.shade.guava30.com.google.common.collect.Lists;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+/** Test for {@link RangeShuffle}. */
+class RangeShuffleTest {
+
+    @Test
+    void testAllocateRange() {
+
+        // every sample has the same weight
+        List<Tuple2<Integer, Integer>> test0 =
+                Lists.newArrayList(
+                        // key and size
+                        new Tuple2<>(1, 1),
+                        new Tuple2<>(2, 1),
+                        new Tuple2<>(3, 1),
+                        new Tuple2<>(4, 1),
+                        new Tuple2<>(5, 1),
+                        new Tuple2<>(6, 1));
+        Assertions.assertEquals(
+                "[2, 4]", Arrays.deepToString(RangeShuffle.allocateRangeBaseSize(test0, 3)));
+
+        // weights are uneven, but the total can be split evenly by weight
+        List<Tuple2<Integer, Integer>> test1 =
+                Lists.newArrayList(
+                        new Tuple2<>(1, 1),
+                        new Tuple2<>(2, 1),
+                        new Tuple2<>(3, 1),
+                        new Tuple2<>(4, 1),
+                        new Tuple2<>(5, 4),
+                        new Tuple2<>(6, 4),
+                        new Tuple2<>(7, 4));
+        Assertions.assertEquals(
+                "[4, 5, 6]", Arrays.deepToString(RangeShuffle.allocateRangeBaseSize(test1, 4)));
+
+        // weights are uneven and cannot be split evenly
+        List<Tuple2<Integer, Integer>> test2 =
+                Lists.newArrayList(
+                        new Tuple2<>(1, 1),
+                        new Tuple2<>(2, 2),
+                        new Tuple2<>(3, 3),
+                        new Tuple2<>(4, 1),
+                        new Tuple2<>(5, 2),
+                        new Tuple2<>(6, 3));
+
+        Assertions.assertEquals(
+                "[3, 5]", Arrays.deepToString(RangeShuffle.allocateRangeBaseSize(test2, 3)));
+    }
+}


Reply via email to