This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-paimon.git


The following commit(s) were added to refs/heads/master by this push:
     new 2e0e2366b [core] Introduce userDefineSeqComparator for MergeSorter (#2936)
2e0e2366b is described below

commit 2e0e2366bb49ccb7f6d11a5a8036d600c09c2546
Author: Jingsong Lee <[email protected]>
AuthorDate: Tue Mar 5 10:40:35 2024 +0800

    [core] Introduce userDefineSeqComparator for MergeSorter (#2936)
---
 .../apache/paimon/codegen/CodeGeneratorImpl.java   |  18 +--
 .../org/apache/paimon/codegen/CodeGenerator.java   |  16 +--
 .../org/apache/paimon/utils/FieldsComparator.java  |  23 +---
 .../org/apache/paimon/data/BinaryStringTest.java   |  78 ++++++-----
 .../org/apache/paimon/codegen/CodeGenUtils.java    |  21 ++-
 .../paimon/crosspartition/GlobalIndexAssigner.java |   3 +-
 .../org/apache/paimon/lookup/RocksDBState.java     |   2 +-
 .../org/apache/paimon/mergetree/MergeSorter.java   |  33 +++--
 .../apache/paimon/mergetree/MergeTreeReaders.java  |   8 +-
 .../paimon/mergetree/SortBufferWriteBuffer.java    |   7 +-
 .../compact/ChangelogMergeTreeRewriter.java        |   1 +
 .../paimon/mergetree/compact/SortMergeReader.java  |   8 +-
 .../compact/SortMergeReaderWithLoserTree.java      |  19 ++-
 .../compact/SortMergeReaderWithMinHeap.java        |  10 ++
 .../org/apache/paimon/operation/DiffReader.java    |   3 +
 .../paimon/operation/KeyValueFileStoreRead.java    |   2 +
 .../paimon/sort/BinaryExternalSortBuffer.java      |  10 +-
 .../apache/paimon/utils/KeyComparatorSupplier.java |   6 +-
 .../apache/paimon/mergetree/MergeSorterTest.java   | 144 +++++++++++++++------
 .../mergetree/compact/SortMergeReaderTestBase.java |   1 +
 .../apache/paimon/flink/sorter/SortOperator.java   |   4 +-
 21 files changed, 268 insertions(+), 149 deletions(-)

diff --git a/paimon-codegen/src/main/java/org/apache/paimon/codegen/CodeGeneratorImpl.java b/paimon-codegen/src/main/java/org/apache/paimon/codegen/CodeGeneratorImpl.java
index 8cb446664..c41895422 100644
--- a/paimon-codegen/src/main/java/org/apache/paimon/codegen/CodeGeneratorImpl.java
+++ b/paimon-codegen/src/main/java/org/apache/paimon/codegen/CodeGeneratorImpl.java
@@ -37,20 +37,20 @@ public class CodeGeneratorImpl implements CodeGenerator {
 
     @Override
     public GeneratedClass<NormalizedKeyComputer> generateNormalizedKeyComputer(
-            List<DataType> fieldTypes, String name) {
+            List<DataType> inputTypes, int[] sortFields, String name) {
         return new SortCodeGenerator(
-                        RowType.builder().fields(fieldTypes).build(),
-                        getAscendingSortSpec(fieldTypes.size()))
+                        RowType.builder().fields(inputTypes).build(),
+                        getAscendingSortSpec(sortFields))
                 .generateNormalizedKeyComputer(name);
     }
 
     @Override
     public GeneratedClass<RecordComparator> generateRecordComparator(
-            List<DataType> fieldTypes, String name) {
+            List<DataType> inputTypes, int[] sortFields, String name) {
         return ComparatorCodeGenerator.gen(
                 name,
-                RowType.builder().fields(fieldTypes).build(),
-                getAscendingSortSpec(fieldTypes.size()));
+                RowType.builder().fields(inputTypes).build(),
+                getAscendingSortSpec(sortFields));
     }
 
     /** Generate a {@link RecordEqualiser}. */
@@ -61,10 +61,10 @@ public class CodeGeneratorImpl implements CodeGenerator {
                 .generateRecordEqualiser(name);
     }
 
-    private SortSpec getAscendingSortSpec(int numFields) {
+    private SortSpec getAscendingSortSpec(int[] sortFields) {
         SortSpec.SortSpecBuilder builder = SortSpec.builder();
-        for (int i = 0; i < numFields; i++) {
-            builder.addField(i, true, false);
+        for (int sortField : sortFields) {
+            builder.addField(sortField, true, false);
         }
         return builder.build();
     }
diff --git a/paimon-common/src/main/java/org/apache/paimon/codegen/CodeGenerator.java b/paimon-common/src/main/java/org/apache/paimon/codegen/CodeGenerator.java
index 5ee64a0ec..842098dbc 100644
--- a/paimon-common/src/main/java/org/apache/paimon/codegen/CodeGenerator.java
+++ b/paimon-common/src/main/java/org/apache/paimon/codegen/CodeGenerator.java
@@ -32,22 +32,22 @@ public interface CodeGenerator {
     /**
      * Generate a {@link NormalizedKeyComputer}.
      *
-     * @param fieldTypes Both the input row field types and the sort key field types. Records are
-     *     compared by the first field, then the second field, then the third field and so on. All
-     *     fields are compared in ascending order.
+     * @param inputTypes input types.
+     * @param sortFields the sort key fields. Records are compared by the first field, then the
+     *     second field, then the third field and so on. All fields are compared in ascending order.
      */
     GeneratedClass<NormalizedKeyComputer> generateNormalizedKeyComputer(
-            List<DataType> fieldTypes, String name);
+            List<DataType> inputTypes, int[] sortFields, String name);
 
     /**
      * Generate a {@link RecordComparator}.
      *
-     * @param fieldTypes Both the input row field types and the sort key field types. Records are *
-     *     compared by the first field, then the second field, then the third field and so on. All *
-     *     fields are compared in ascending order.
+     * @param inputTypes input types.
+     * @param sortFields the sort key fields. Records are compared by the first field, then the
+     *     second field, then the third field and so on. All fields are compared in ascending order.
      */
     GeneratedClass<RecordComparator> generateRecordComparator(
-            List<DataType> fieldTypes, String name);
+            List<DataType> inputTypes, int[] sortFields, String name);
 
     /**
      * Generate a {@link RecordEqualiser}.
diff --git a/paimon-core/src/main/java/org/apache/paimon/utils/KeyComparatorSupplier.java b/paimon-common/src/main/java/org/apache/paimon/utils/FieldsComparator.java
similarity index 51%
copy from paimon-core/src/main/java/org/apache/paimon/utils/KeyComparatorSupplier.java
copy to paimon-common/src/main/java/org/apache/paimon/utils/FieldsComparator.java
index 88589cc81..85140d26b 100644
--- a/paimon-core/src/main/java/org/apache/paimon/utils/KeyComparatorSupplier.java
+++ b/paimon-common/src/main/java/org/apache/paimon/utils/FieldsComparator.java
@@ -18,29 +18,12 @@
 
 package org.apache.paimon.utils;
 
-import org.apache.paimon.codegen.CodeGenUtils;
-import org.apache.paimon.codegen.GeneratedClass;
-import org.apache.paimon.codegen.RecordComparator;
 import org.apache.paimon.data.InternalRow;
-import org.apache.paimon.types.RowType;
 
 import java.util.Comparator;
-import java.util.function.Supplier;
 
-/** A {@link Supplier} that returns the comparator for the file store key. */
-public class KeyComparatorSupplier implements 
SerializableSupplier<Comparator<InternalRow>> {
+/** A {@link Comparator} to compare fields for {@link InternalRow}. */
+public interface FieldsComparator extends Comparator<InternalRow> {
 
-    private static final long serialVersionUID = 1L;
-
-    private final GeneratedClass<RecordComparator> genRecordComparator;
-
-    public KeyComparatorSupplier(RowType keyType) {
-        genRecordComparator =
-                CodeGenUtils.generateRecordComparator(keyType.getFieldTypes(), 
"KeyComparator");
-    }
-
-    @Override
-    public RecordComparator get() {
-        return 
genRecordComparator.newInstance(KeyComparatorSupplier.class.getClassLoader());
-    }
+    int[] compareFields();
 }
diff --git a/paimon-common/src/test/java/org/apache/paimon/data/BinaryStringTest.java b/paimon-common/src/test/java/org/apache/paimon/data/BinaryStringTest.java
index a5052c730..c574ae8be 100644
--- a/paimon-common/src/test/java/org/apache/paimon/data/BinaryStringTest.java
+++ b/paimon-common/src/test/java/org/apache/paimon/data/BinaryStringTest.java
@@ -26,7 +26,6 @@ import org.apache.paimon.utils.SortUtil;
 import org.junit.jupiter.api.TestTemplate;
 import org.junit.jupiter.api.extension.ExtendWith;
 
-import java.io.UnsupportedEncodingException;
 import java.math.BigDecimal;
 import java.nio.charset.StandardCharsets;
 import java.util.Arrays;
@@ -34,6 +33,7 @@ import java.util.List;
 import java.util.Random;
 
 import static java.nio.charset.StandardCharsets.UTF_8;
+import static org.apache.paimon.data.BinaryString.EMPTY_UTF8;
 import static org.apache.paimon.data.BinaryString.blankString;
 import static org.apache.paimon.data.BinaryString.fromBytes;
 import static org.apache.paimon.utils.DecimalUtils.castFrom;
@@ -47,14 +47,13 @@ import static org.assertj.core.api.Assertions.assertThat;
 @ExtendWith(ParameterizedTestExtension.class)
 public class BinaryStringTest {
 
-    private BinaryString empty = fromString("");
-
     private final Mode mode;
 
     public BinaryStringTest(Mode mode) {
         this.mode = mode;
     }
 
+    @SuppressWarnings("unused")
     @Parameters(name = "{0}")
     public static List<Mode> getVarSeg() {
         return Arrays.asList(Mode.ONE_SEG, Mode.MULTI_SEGS, Mode.STRING, 
Mode.RANDOM);
@@ -78,7 +77,7 @@ public class BinaryStringTest {
                 mode = Mode.ONE_SEG;
             } else if (rnd == 1) {
                 mode = Mode.MULTI_SEGS;
-            } else if (rnd == 2) {
+            } else {
                 mode = Mode.STRING;
             }
         }
@@ -143,10 +142,10 @@ public class BinaryStringTest {
 
     @TestTemplate
     public void emptyStringTest() {
-        assertThat(fromString("")).isEqualTo(empty);
-        assertThat(fromBytes(new byte[0])).isEqualTo(empty);
-        assertThat(empty.numChars()).isEqualTo(0);
-        assertThat(empty.getSizeInBytes()).isEqualTo(0);
+        assertThat(fromString("")).isEqualTo(EMPTY_UTF8);
+        assertThat(fromBytes(new byte[0])).isEqualTo(EMPTY_UTF8);
+        assertThat(EMPTY_UTF8.numChars()).isEqualTo(0);
+        assertThat(EMPTY_UTF8.getSizeInBytes()).isEqualTo(0);
     }
 
     @TestTemplate
@@ -224,7 +223,7 @@ public class BinaryStringTest {
 
     @TestTemplate
     public void contains() {
-        assertThat(empty.contains(empty)).isTrue();
+        assertThat(EMPTY_UTF8.contains(EMPTY_UTF8)).isTrue();
         assertThat(fromString("hello").contains(fromString("ello"))).isTrue();
         
assertThat(fromString("hello").contains(fromString("vello"))).isFalse();
         
assertThat(fromString("hello").contains(fromString("hellooo"))).isFalse();
@@ -235,7 +234,7 @@ public class BinaryStringTest {
 
     @TestTemplate
     public void startsWith() {
-        assertThat(empty.startsWith(empty)).isTrue();
+        assertThat(EMPTY_UTF8.startsWith(EMPTY_UTF8)).isTrue();
         
assertThat(fromString("hello").startsWith(fromString("hell"))).isTrue();
         
assertThat(fromString("hello").startsWith(fromString("ell"))).isFalse();
         
assertThat(fromString("hello").startsWith(fromString("hellooo"))).isFalse();
@@ -246,7 +245,7 @@ public class BinaryStringTest {
 
     @TestTemplate
     public void endsWith() {
-        assertThat(empty.endsWith(empty)).isTrue();
+        assertThat(EMPTY_UTF8.endsWith(EMPTY_UTF8)).isTrue();
         assertThat(fromString("hello").endsWith(fromString("ello"))).isTrue();
         
assertThat(fromString("hello").endsWith(fromString("ellov"))).isFalse();
         
assertThat(fromString("hello").endsWith(fromString("hhhello"))).isFalse();
@@ -257,7 +256,7 @@ public class BinaryStringTest {
 
     @TestTemplate
     public void substring() {
-        assertThat(fromString("hello").substring(0, 0)).isEqualTo(empty);
+        assertThat(fromString("hello").substring(0, 0)).isEqualTo(EMPTY_UTF8);
         assertThat(fromString("hello").substring(1, 
3)).isEqualTo(fromString("el"));
         assertThat(fromString("数据砖头").substring(0, 
1)).isEqualTo(fromString("数"));
         assertThat(fromString("数据砖头").substring(1, 
3)).isEqualTo(fromString("据砖"));
@@ -267,9 +266,9 @@ public class BinaryStringTest {
 
     @TestTemplate
     public void indexOf() {
-        assertThat(empty.indexOf(empty, 0)).isEqualTo(0);
-        assertThat(empty.indexOf(fromString("l"), 0)).isEqualTo(-1);
-        assertThat(fromString("hello").indexOf(empty, 0)).isEqualTo(0);
+        assertThat(EMPTY_UTF8.indexOf(EMPTY_UTF8, 0)).isEqualTo(0);
+        assertThat(EMPTY_UTF8.indexOf(fromString("l"), 0)).isEqualTo(-1);
+        assertThat(fromString("hello").indexOf(EMPTY_UTF8, 0)).isEqualTo(0);
         assertThat(fromString("hello").indexOf(fromString("l"), 
0)).isEqualTo(2);
         assertThat(fromString("hello").indexOf(fromString("l"), 
3)).isEqualTo(3);
         assertThat(fromString("hello").indexOf(fromString("a"), 
0)).isEqualTo(-1);
@@ -300,24 +299,21 @@ public class BinaryStringTest {
         writer.writeString(5, BinaryString.fromString("!@#$%^*"));
         writer.complete();
 
-        assertThat(((BinaryString) 
row.getString(0)).toUpperCase()).isEqualTo(fromString("A"));
-        assertThat(((BinaryString) 
row.getString(1)).toUpperCase()).isEqualTo(fromString("我是中国人"));
-        assertThat(((BinaryString) 
row.getString(1)).toLowerCase()).isEqualTo(fromString("我是中国人"));
-        assertThat(((BinaryString) row.getString(3)).toUpperCase())
-                .isEqualTo(fromString("ABCDEFG"));
-        assertThat(((BinaryString) row.getString(3)).toLowerCase())
-                .isEqualTo(fromString("abcdefg"));
-        assertThat(((BinaryString) row.getString(5)).toUpperCase())
-                .isEqualTo(fromString("!@#$%^*"));
-        assertThat(((BinaryString) row.getString(5)).toLowerCase())
-                .isEqualTo(fromString("!@#$%^*"));
+        assertThat(row.getString(0).toUpperCase()).isEqualTo(fromString("A"));
+        
assertThat(row.getString(1).toUpperCase()).isEqualTo(fromString("我是中国人"));
+        
assertThat(row.getString(1).toLowerCase()).isEqualTo(fromString("我是中国人"));
+        
assertThat(row.getString(3).toUpperCase()).isEqualTo(fromString("ABCDEFG"));
+        
assertThat(row.getString(3).toLowerCase()).isEqualTo(fromString("abcdefg"));
+        
assertThat(row.getString(5).toUpperCase()).isEqualTo(fromString("!@#$%^*"));
+        
assertThat(row.getString(5).toLowerCase()).isEqualTo(fromString("!@#$%^*"));
     }
 
     @TestTemplate
-    public void testcastFrom() {
+    public void testCastFrom() {
         class DecimalTestData {
-            private String str;
-            private int precision, scale;
+            private final String str;
+            private final int precision;
+            private final int scale;
 
             private DecimalTestData(String str, int precision, int scale) {
                 this.str = str;
@@ -391,7 +387,7 @@ public class BinaryStringTest {
         writer.complete();
         for (int i = 0; i < data.length; i++) {
             DecimalTestData d = data[i];
-            assertThat(castFrom((BinaryString) row.getString(i), d.precision, 
d.scale))
+            assertThat(castFrom(row.getString(i), d.precision, d.scale))
                     .isEqualTo(Decimal.fromBigDecimal(new BigDecimal(d.str), 
d.precision, d.scale));
         }
     }
@@ -407,21 +403,21 @@ public class BinaryStringTest {
             str3 = BinaryString.fromAddress(segments, 15, 0);
         }
 
-        assertThat(BinaryString.EMPTY_UTF8.compareTo(str2)).isLessThan(0);
-        assertThat(str2.compareTo(BinaryString.EMPTY_UTF8)).isGreaterThan(0);
+        assertThat(EMPTY_UTF8.compareTo(str2)).isLessThan(0);
+        assertThat(str2.compareTo(EMPTY_UTF8)).isGreaterThan(0);
 
-        assertThat(BinaryString.EMPTY_UTF8.compareTo(str3)).isEqualTo(0);
-        assertThat(str3.compareTo(BinaryString.EMPTY_UTF8)).isEqualTo(0);
+        assertThat(EMPTY_UTF8.compareTo(str3)).isEqualTo(0);
+        assertThat(str3.compareTo(EMPTY_UTF8)).isEqualTo(0);
 
-        assertThat(str2).isNotEqualTo(BinaryString.EMPTY_UTF8);
-        assertThat(BinaryString.EMPTY_UTF8).isNotEqualTo(str2);
+        assertThat(str2).isNotEqualTo(EMPTY_UTF8);
+        assertThat(EMPTY_UTF8).isNotEqualTo(str2);
 
-        assertThat(str3).isEqualTo(BinaryString.EMPTY_UTF8);
-        assertThat(BinaryString.EMPTY_UTF8).isEqualTo(str3);
+        assertThat(str3).isEqualTo(EMPTY_UTF8);
+        assertThat(EMPTY_UTF8).isEqualTo(str3);
     }
 
     @TestTemplate
-    public void testEncodeWithIllegalCharacter() throws 
UnsupportedEncodingException {
+    public void testEncodeWithIllegalCharacter() {
 
         // Tis char array has some illegal character, such as 55357
         // the jdk ignores theses character and cast them to '?'
@@ -434,11 +430,11 @@ public class BinaryStringTest {
 
         String str = new String(chars);
 
-        
assertThat(BinaryString.encodeUTF8(str)).isEqualTo(str.getBytes("UTF-8"));
+        
assertThat(BinaryString.encodeUTF8(str)).isEqualTo(str.getBytes(UTF_8));
     }
 
     @TestTemplate
-    public void testDecodeWithIllegalUtf8Bytes() throws 
UnsupportedEncodingException {
+    public void testDecodeWithIllegalUtf8Bytes() {
 
         // illegal utf-8 bytes
         byte[] bytes =
diff --git a/paimon-core/src/main/java/org/apache/paimon/codegen/CodeGenUtils.java b/paimon-core/src/main/java/org/apache/paimon/codegen/CodeGenUtils.java
index 868734b0f..57abd915a 100644
--- a/paimon-core/src/main/java/org/apache/paimon/codegen/CodeGenUtils.java
+++ b/paimon-core/src/main/java/org/apache/paimon/codegen/CodeGenUtils.java
@@ -23,6 +23,7 @@ import org.apache.paimon.types.DataType;
 import org.apache.paimon.types.RowType;
 
 import java.util.List;
+import java.util.stream.IntStream;
 
 /** Utils for code generations. */
 public class CodeGenUtils {
@@ -46,15 +47,16 @@ public class CodeGenUtils {
     }
 
     public static NormalizedKeyComputer newNormalizedKeyComputer(
-            List<DataType> fieldTypes, String name) {
+            List<DataType> inputTypes, int[] sortFields, String name) {
         return CodeGenLoader.getCodeGenerator()
-                .generateNormalizedKeyComputer(fieldTypes, name)
+                .generateNormalizedKeyComputer(inputTypes, sortFields, name)
                 .newInstance(CodeGenUtils.class.getClassLoader());
     }
 
     public static GeneratedClass<RecordComparator> generateRecordComparator(
-            List<DataType> fieldTypes, String name) {
-        return 
CodeGenLoader.getCodeGenerator().generateRecordComparator(fieldTypes, name);
+            List<DataType> inputTypes, int[] sortFields, String name) {
+        return CodeGenLoader.getCodeGenerator()
+                .generateRecordComparator(inputTypes, sortFields, name);
     }
 
     public static GeneratedClass<RecordEqualiser> generateRecordEqualiser(
@@ -62,8 +64,15 @@ public class CodeGenUtils {
         return 
CodeGenLoader.getCodeGenerator().generateRecordEqualiser(fieldTypes, name);
     }
 
-    public static RecordComparator newRecordComparator(List<DataType> 
fieldTypes, String name) {
-        return generateRecordComparator(fieldTypes, name)
+    public static RecordComparator newRecordComparator(
+            List<DataType> inputTypes, int[] sortFields, String name) {
+        return generateRecordComparator(inputTypes, sortFields, name)
+                .newInstance(CodeGenUtils.class.getClassLoader());
+    }
+
+    public static RecordComparator newRecordComparator(List<DataType> 
inputTypes, String name) {
+        return generateRecordComparator(
+                        inputTypes, IntStream.range(0, 
inputTypes.size()).toArray(), name)
                 .newInstance(CodeGenUtils.class.getClassLoader());
     }
 }
diff --git a/paimon-core/src/main/java/org/apache/paimon/crosspartition/GlobalIndexAssigner.java b/paimon-core/src/main/java/org/apache/paimon/crosspartition/GlobalIndexAssigner.java
index d4383d4c2..787fc9e0f 100644
--- a/paimon-core/src/main/java/org/apache/paimon/crosspartition/GlobalIndexAssigner.java
+++ b/paimon-core/src/main/java/org/apache/paimon/crosspartition/GlobalIndexAssigner.java
@@ -67,6 +67,7 @@ import java.util.UUID;
 import java.util.concurrent.ThreadLocalRandom;
 import java.util.function.BiConsumer;
 import java.util.function.Function;
+import java.util.stream.IntStream;
 
 import static org.apache.paimon.lookup.RocksDBOptions.BLOCK_CACHE_SIZE;
 import static org.apache.paimon.utils.Preconditions.checkArgument;
@@ -285,8 +286,8 @@ public class GlobalIndexAssigner implements Serializable, 
Closeable {
         BinaryExternalSortBuffer keyIdBuffer =
                 BinaryExternalSortBuffer.create(
                         ioManager,
-                        keyWithIdType,
                         keyWithRowType,
+                        IntStream.range(0, 
keyWithIdType.getFieldCount()).toArray(),
                         coreOptions.writeBufferSize() / 2,
                         coreOptions.pageSize(),
                         coreOptions.localSortMaxNumFileHandles(),
diff --git a/paimon-core/src/main/java/org/apache/paimon/lookup/RocksDBState.java b/paimon-core/src/main/java/org/apache/paimon/lookup/RocksDBState.java
index 5ffcaf146..2e10acb0e 100644
--- a/paimon-core/src/main/java/org/apache/paimon/lookup/RocksDBState.java
+++ b/paimon-core/src/main/java/org/apache/paimon/lookup/RocksDBState.java
@@ -107,8 +107,8 @@ public abstract class RocksDBState<K, V, CacheV> {
             IOManager ioManager, CoreOptions options) {
         return BinaryExternalSortBuffer.create(
                 ioManager,
-                RowType.of(DataTypes.BYTES()),
                 RowType.of(DataTypes.BYTES(), DataTypes.BYTES()),
+                new int[] {0},
                 options.writeBufferSize() / 2,
                 options.pageSize(),
                 options.localSortMaxNumFileHandles(),
diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeSorter.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeSorter.java
index 26a6d7fc6..5e30d16aa 100644
--- a/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeSorter.java
+++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeSorter.java
@@ -37,12 +37,11 @@ import org.apache.paimon.sort.BinaryExternalSortBuffer;
 import org.apache.paimon.sort.SortBuffer;
 import org.apache.paimon.types.BigIntType;
 import org.apache.paimon.types.DataField;
-import org.apache.paimon.types.DataType;
-import org.apache.paimon.types.DataTypes;
 import org.apache.paimon.types.IntType;
 import org.apache.paimon.types.RowKind;
 import org.apache.paimon.types.RowType;
 import org.apache.paimon.types.TinyIntType;
+import org.apache.paimon.utils.FieldsComparator;
 import org.apache.paimon.utils.IOUtils;
 import org.apache.paimon.utils.MutableObjectIterator;
 import org.apache.paimon.utils.OffsetRow;
@@ -53,6 +52,7 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.List;
+import java.util.stream.IntStream;
 
 import static org.apache.paimon.schema.SystemColumns.SEQUENCE_NUMBER;
 import static org.apache.paimon.schema.SystemColumns.VALUE_KIND;
@@ -103,10 +103,12 @@ public class MergeSorter {
     public <T> RecordReader<T> mergeSort(
             List<ReaderSupplier<KeyValue>> lazyReaders,
             Comparator<InternalRow> keyComparator,
+            @Nullable FieldsComparator userDefinedSeqComparator,
             MergeFunctionWrapper<T> mergeFunction)
             throws IOException {
         if (ioManager != null && lazyReaders.size() > spillThreshold) {
-            return spillMergeSort(lazyReaders, keyComparator, mergeFunction);
+            return spillMergeSort(
+                    lazyReaders, keyComparator, userDefinedSeqComparator, 
mergeFunction);
         }
 
         List<RecordReader<KeyValue>> readers = new 
ArrayList<>(lazyReaders.size());
@@ -121,15 +123,16 @@ public class MergeSorter {
         }
 
         return SortMergeReader.createSortMergeReader(
-                readers, keyComparator, mergeFunction, sortEngine);
+                readers, keyComparator, userDefinedSeqComparator, 
mergeFunction, sortEngine);
     }
 
     private <T> RecordReader<T> spillMergeSort(
             List<ReaderSupplier<KeyValue>> readers,
             Comparator<InternalRow> keyComparator,
+            @Nullable FieldsComparator userDefinedSeqComparator,
             MergeFunctionWrapper<T> mergeFunction)
             throws IOException {
-        ExternalSorterWithLevel sorter = new ExternalSorterWithLevel();
+        ExternalSorterWithLevel sorter = new 
ExternalSorterWithLevel(userDefinedSeqComparator);
         ConcatRecordReader.create(readers).forIOEachRemaining(sorter::put);
         sorter.flushMemory();
 
@@ -176,15 +179,25 @@ public class MergeSorter {
 
         private final SortBuffer buffer;
 
-        public ExternalSorterWithLevel() {
+        public ExternalSorterWithLevel(@Nullable FieldsComparator 
userDefinedSeqComparator) {
             if (memoryPool.freePages() < 3) {
                 throw new IllegalArgumentException(
                         "Write buffer requires a minimum of 3 page memory, 
please increase write buffer memory size.");
             }
 
-            // user key + sequenceNumber
-            List<DataType> sortKeyTypes = new 
ArrayList<>(keyType.getFieldTypes());
-            sortKeyTypes.add(new BigIntType(false));
+            // key fields
+            IntStream sortFields = IntStream.range(0, keyType.getFieldCount());
+
+            // user define sequence fields
+            if (userDefinedSeqComparator != null) {
+                IntStream udsFields =
+                        IntStream.of(userDefinedSeqComparator.compareFields())
+                                .map(operand -> operand + 
keyType.getFieldCount() + 3);
+                sortFields = IntStream.concat(sortFields, udsFields);
+            }
+
+            // sequence field
+            sortFields = IntStream.concat(sortFields, 
IntStream.of(keyType.getFieldCount()));
 
             // row type
             List<DataField> fields = new ArrayList<>(keyType.getFields());
@@ -196,8 +209,8 @@ public class MergeSorter {
             this.buffer =
                     BinaryExternalSortBuffer.create(
                             ioManager,
-                            DataTypes.ROW(sortKeyTypes.toArray(new 
DataType[0])),
                             new RowType(fields),
+                            sortFields.toArray(),
                             memoryPool,
                             spillSortMaxNumFiles,
                             compression);
diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeTreeReaders.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeTreeReaders.java
index 3fe28c073..186371b81 100644
--- a/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeTreeReaders.java
+++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/MergeTreeReaders.java
@@ -28,6 +28,9 @@ import org.apache.paimon.mergetree.compact.MergeFunction;
 import org.apache.paimon.mergetree.compact.MergeFunctionWrapper;
 import org.apache.paimon.mergetree.compact.ReducerMergeFunctionWrapper;
 import org.apache.paimon.reader.RecordReader;
+import org.apache.paimon.utils.FieldsComparator;
+
+import javax.annotation.Nullable;
 
 import java.io.IOException;
 import java.util.ArrayList;
@@ -55,6 +58,7 @@ public class MergeTreeReaders {
                                     section,
                                     readerFactory,
                                     userKeyComparator,
+                                    null,
                                     new 
ReducerMergeFunctionWrapper(mergeFunction),
                                     mergeSorter));
         }
@@ -69,6 +73,7 @@ public class MergeTreeReaders {
             List<SortedRun> section,
             KeyValueFileReaderFactory readerFactory,
             Comparator<InternalRow> userKeyComparator,
+            @Nullable FieldsComparator userDefinedSeqComparator,
             MergeFunctionWrapper<T> mergeFunctionWrapper,
             MergeSorter mergeSorter)
             throws IOException {
@@ -76,7 +81,8 @@ public class MergeTreeReaders {
         for (SortedRun run : section) {
             readers.add(() -> readerForRun(run, readerFactory));
         }
-        return mergeSorter.mergeSort(readers, userKeyComparator, 
mergeFunctionWrapper);
+        return mergeSorter.mergeSort(
+                readers, userKeyComparator, userDefinedSeqComparator, 
mergeFunctionWrapper);
     }
 
     public static RecordReader<KeyValue> readerForRun(
diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/SortBufferWriteBuffer.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/SortBufferWriteBuffer.java
index 744c4f31a..ae877f5cb 100644
--- a/paimon-core/src/main/java/org/apache/paimon/mergetree/SortBufferWriteBuffer.java
+++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/SortBufferWriteBuffer.java
@@ -48,6 +48,7 @@ import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Comparator;
 import java.util.List;
+import java.util.stream.IntStream;
 
 /** A {@link WriteBuffer} which stores records in {@link 
BinaryInMemorySortBuffer}. */
 public class SortBufferWriteBuffer implements WriteBuffer {
@@ -74,10 +75,12 @@ public class SortBufferWriteBuffer implements WriteBuffer {
         sortKeyTypes.add(new BigIntType(false));
 
         // for sort binary buffer
+        int[] sortFields = IntStream.range(0, sortKeyTypes.size()).toArray();
         NormalizedKeyComputer normalizedKeyComputer =
-                CodeGenUtils.newNormalizedKeyComputer(sortKeyTypes, 
"MemTableKeyComputer");
+                CodeGenUtils.newNormalizedKeyComputer(
+                        sortKeyTypes, sortFields, "MemTableKeyComputer");
         RecordComparator keyComparator =
-                CodeGenUtils.newRecordComparator(sortKeyTypes, 
"MemTableComparator");
+                CodeGenUtils.newRecordComparator(sortKeyTypes, sortFields, 
"MemTableComparator");
 
         if (memoryPool.freePages() < 3) {
             throw new IllegalArgumentException(
diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/ChangelogMergeTreeRewriter.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/ChangelogMergeTreeRewriter.java
index f056fafaa..99f5f052d 100644
--- a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/ChangelogMergeTreeRewriter.java
+++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/ChangelogMergeTreeRewriter.java
@@ -112,6 +112,7 @@ public abstract class ChangelogMergeTreeRewriter extends 
MergeTreeCompactRewrite
                                     section,
                                     readerFactory,
                                     keyComparator,
+                                    null,
                                     createMergeWrapper(outputLevel),
                                     mergeSorter));
         }
diff --git a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReader.java b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReader.java
index d545efd35..598ca2aa6 100644
--- a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReader.java
+++ b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReader.java
@@ -22,6 +22,9 @@ import org.apache.paimon.CoreOptions.SortEngine;
 import org.apache.paimon.KeyValue;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.reader.RecordReader;
+import org.apache.paimon.utils.FieldsComparator;
+
+import javax.annotation.Nullable;
 
 import java.util.Comparator;
 import java.util.List;
@@ -38,15 +41,16 @@ public interface SortMergeReader<T> extends RecordReader<T> 
{
     static <T> SortMergeReader<T> createSortMergeReader(
             List<RecordReader<KeyValue>> readers,
             Comparator<InternalRow> userKeyComparator,
+            @Nullable FieldsComparator userDefinedSeqComparator,
             MergeFunctionWrapper<T> mergeFunctionWrapper,
             SortEngine sortEngine) {
         switch (sortEngine) {
             case MIN_HEAP:
                 return new SortMergeReaderWithMinHeap<>(
-                        readers, userKeyComparator, mergeFunctionWrapper);
+                        readers, userKeyComparator, userDefinedSeqComparator, 
mergeFunctionWrapper);
             case LOSER_TREE:
                 return new SortMergeReaderWithLoserTree<>(
-                        readers, userKeyComparator, mergeFunctionWrapper);
+                        readers, userKeyComparator, userDefinedSeqComparator, 
mergeFunctionWrapper);
             default:
                 throw new UnsupportedOperationException("Unsupported sort 
engine: " + sortEngine);
         }
diff --git 
a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithLoserTree.java
 
b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithLoserTree.java
index f7a20ee67..3ca3d288e 100644
--- 
a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithLoserTree.java
+++ 
b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithLoserTree.java
@@ -21,6 +21,7 @@ package org.apache.paimon.mergetree.compact;
 import org.apache.paimon.KeyValue;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.reader.RecordReader;
+import org.apache.paimon.utils.FieldsComparator;
 import org.apache.paimon.utils.Preconditions;
 
 import javax.annotation.Nullable;
@@ -38,13 +39,29 @@ public class SortMergeReaderWithLoserTree<T> implements 
SortMergeReader<T> {
     public SortMergeReaderWithLoserTree(
             List<RecordReader<KeyValue>> readers,
             Comparator<InternalRow> userKeyComparator,
+            @Nullable FieldsComparator userDefinedSeqComparator,
             MergeFunctionWrapper<T> mergeFunctionWrapper) {
         this.mergeFunctionWrapper = mergeFunctionWrapper;
         this.loserTree =
                 new LoserTree<>(
                         readers,
                         (e1, e2) -> userKeyComparator.compare(e2.key(), 
e1.key()),
-                        (e1, e2) -> Long.compare(e2.sequenceNumber(), 
e1.sequenceNumber()));
+                        createSequenceComparator(userDefinedSeqComparator));
+    }
+
+    private Comparator<KeyValue> createSequenceComparator(
+            @Nullable FieldsComparator userDefinedSeqComparator) {
+        if (userDefinedSeqComparator == null) {
+            return (e1, e2) -> Long.compare(e2.sequenceNumber(), 
e1.sequenceNumber());
+        }
+
+        return (o1, o2) -> {
+            int result = userDefinedSeqComparator.compare(o2.value(), 
o1.value());
+            if (result != 0) {
+                return result;
+            }
+            return Long.compare(o2.sequenceNumber(), o1.sequenceNumber());
+        };
     }
 
     /** Compared with heapsort, {@link LoserTree} will only produce one batch. 
*/
diff --git 
a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithMinHeap.java
 
b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithMinHeap.java
index adca53fdb..a78ef334f 100644
--- 
a/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithMinHeap.java
+++ 
b/paimon-core/src/main/java/org/apache/paimon/mergetree/compact/SortMergeReaderWithMinHeap.java
@@ -21,6 +21,7 @@ package org.apache.paimon.mergetree.compact;
 import org.apache.paimon.KeyValue;
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.reader.RecordReader;
+import org.apache.paimon.utils.FieldsComparator;
 import org.apache.paimon.utils.Preconditions;
 
 import javax.annotation.Nullable;
@@ -44,6 +45,7 @@ public class SortMergeReaderWithMinHeap<T> implements 
SortMergeReader<T> {
     public SortMergeReaderWithMinHeap(
             List<RecordReader<KeyValue>> readers,
             Comparator<InternalRow> userKeyComparator,
+            @Nullable FieldsComparator userDefinedSeqComparator,
             MergeFunctionWrapper<T> mergeFunctionWrapper) {
         this.nextBatchReaders = new ArrayList<>(readers);
         this.userKeyComparator = userKeyComparator;
@@ -56,6 +58,14 @@ public class SortMergeReaderWithMinHeap<T> implements 
SortMergeReader<T> {
                             if (result != 0) {
                                 return result;
                             }
+                            if (userDefinedSeqComparator != null) {
+                                result =
+                                        userDefinedSeqComparator.compare(
+                                                e1.kv.value(), e2.kv.value());
+                                if (result != 0) {
+                                    return result;
+                                }
+                            }
                             return Long.compare(e1.kv.sequenceNumber(), 
e2.kv.sequenceNumber());
                         });
         this.polled = new ArrayList<>();
diff --git 
a/paimon-core/src/main/java/org/apache/paimon/operation/DiffReader.java 
b/paimon-core/src/main/java/org/apache/paimon/operation/DiffReader.java
index 324bff935..bc5153600 100644
--- a/paimon-core/src/main/java/org/apache/paimon/operation/DiffReader.java
+++ b/paimon-core/src/main/java/org/apache/paimon/operation/DiffReader.java
@@ -24,6 +24,7 @@ import org.apache.paimon.mergetree.MergeSorter;
 import org.apache.paimon.mergetree.compact.MergeFunctionWrapper;
 import org.apache.paimon.reader.RecordReader;
 import org.apache.paimon.types.RowKind;
+import org.apache.paimon.utils.FieldsComparator;
 
 import javax.annotation.Nullable;
 
@@ -43,6 +44,7 @@ public class DiffReader {
             RecordReader<KeyValue> beforeReader,
             RecordReader<KeyValue> afterReader,
             Comparator<InternalRow> keyComparator,
+            @Nullable FieldsComparator userDefinedSeqComparator,
             MergeSorter sorter,
             boolean keepDelete)
             throws IOException {
@@ -51,6 +53,7 @@ public class DiffReader {
                         () -> wrapLevelToReader(beforeReader, BEFORE_LEVEL),
                         () -> wrapLevelToReader(afterReader, AFTER_LEVEL)),
                 keyComparator,
+                userDefinedSeqComparator,
                 new DiffMerger(keepDelete));
     }
 
diff --git 
a/paimon-core/src/main/java/org/apache/paimon/operation/KeyValueFileStoreRead.java
 
b/paimon-core/src/main/java/org/apache/paimon/operation/KeyValueFileStoreRead.java
index 09b9cb81b..e8accda5b 100644
--- 
a/paimon-core/src/main/java/org/apache/paimon/operation/KeyValueFileStoreRead.java
+++ 
b/paimon-core/src/main/java/org/apache/paimon/operation/KeyValueFileStoreRead.java
@@ -227,6 +227,7 @@ public class KeyValueFileStoreRead implements 
FileStoreRead<KeyValue> {
                             batchMergeRead(
                                     split.partition(), split.bucket(), 
split.dataFiles(), false),
                             keyComparator,
+                            null,
                             mergeSorter,
                             forceKeepDelete);
         }
@@ -255,6 +256,7 @@ public class KeyValueFileStoreRead implements 
FileStoreRead<KeyValue> {
                                             ? overlappedSectionFactory
                                             : nonOverlappedSectionFactory,
                                     keyComparator,
+                                    null,
                                     mergeFuncWrapper,
                                     mergeSorter));
         }
diff --git 
a/paimon-core/src/main/java/org/apache/paimon/sort/BinaryExternalSortBuffer.java
 
b/paimon-core/src/main/java/org/apache/paimon/sort/BinaryExternalSortBuffer.java
index 7d9522f61..4de25a4ea 100644
--- 
a/paimon-core/src/main/java/org/apache/paimon/sort/BinaryExternalSortBuffer.java
+++ 
b/paimon-core/src/main/java/org/apache/paimon/sort/BinaryExternalSortBuffer.java
@@ -91,16 +91,16 @@ public class BinaryExternalSortBuffer implements SortBuffer 
{
 
     public static BinaryExternalSortBuffer create(
             IOManager ioManager,
-            RowType keyType,
             RowType rowType,
+            int[] keyFields,
             long bufferSize,
             int pageSize,
             int maxNumFileHandles,
             String compression) {
         return create(
                 ioManager,
-                keyType,
                 rowType,
+                keyFields,
                 new HeapMemorySegmentPool(bufferSize, pageSize),
                 maxNumFileHandles,
                 compression);
@@ -108,17 +108,17 @@ public class BinaryExternalSortBuffer implements 
SortBuffer {
 
     public static BinaryExternalSortBuffer create(
             IOManager ioManager,
-            RowType keyType,
             RowType rowType,
+            int[] keyFields,
             MemorySegmentPool pool,
             int maxNumFileHandles,
             String compression) {
         RecordComparator comparator =
-                newRecordComparator(keyType.getFieldTypes(), 
"ExternalSort_comparator");
+                newRecordComparator(rowType.getFieldTypes(), keyFields, 
"ExternalSort_comparator");
         BinaryInMemorySortBuffer sortBuffer =
                 BinaryInMemorySortBuffer.createBuffer(
                         newNormalizedKeyComputer(
-                                keyType.getFieldTypes(), 
"ExternalSort_normalized_key"),
+                                rowType.getFieldTypes(), keyFields, 
"ExternalSort_normalized_key"),
                         new InternalRowSerializer(rowType),
                         comparator,
                         pool);
diff --git 
a/paimon-core/src/main/java/org/apache/paimon/utils/KeyComparatorSupplier.java 
b/paimon-core/src/main/java/org/apache/paimon/utils/KeyComparatorSupplier.java
index 88589cc81..30d3ab328 100644
--- 
a/paimon-core/src/main/java/org/apache/paimon/utils/KeyComparatorSupplier.java
+++ 
b/paimon-core/src/main/java/org/apache/paimon/utils/KeyComparatorSupplier.java
@@ -26,6 +26,7 @@ import org.apache.paimon.types.RowType;
 
 import java.util.Comparator;
 import java.util.function.Supplier;
+import java.util.stream.IntStream;
 
 /** A {@link Supplier} that returns the comparator for the file store key. */
 public class KeyComparatorSupplier implements 
SerializableSupplier<Comparator<InternalRow>> {
@@ -36,7 +37,10 @@ public class KeyComparatorSupplier implements 
SerializableSupplier<Comparator<In
 
     public KeyComparatorSupplier(RowType keyType) {
         genRecordComparator =
-                CodeGenUtils.generateRecordComparator(keyType.getFieldTypes(), 
"KeyComparator");
+                CodeGenUtils.generateRecordComparator(
+                        keyType.getFieldTypes(),
+                        IntStream.range(0, keyType.getFieldCount()).toArray(),
+                        "KeyComparator");
     }
 
     @Override
diff --git 
a/paimon-core/src/test/java/org/apache/paimon/mergetree/MergeSorterTest.java 
b/paimon-core/src/test/java/org/apache/paimon/mergetree/MergeSorterTest.java
index b13eea843..d61caa3e2 100644
--- a/paimon-core/src/test/java/org/apache/paimon/mergetree/MergeSorterTest.java
+++ b/paimon-core/src/test/java/org/apache/paimon/mergetree/MergeSorterTest.java
@@ -19,28 +19,35 @@
 package org.apache.paimon.mergetree;
 
 import org.apache.paimon.CoreOptions;
+import org.apache.paimon.CoreOptions.SortEngine;
 import org.apache.paimon.KeyValue;
 import org.apache.paimon.data.GenericRow;
+import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.disk.IOManager;
 import org.apache.paimon.mergetree.compact.ConcatRecordReader.ReaderSupplier;
 import org.apache.paimon.mergetree.compact.MergeFunctionWrapper;
 import org.apache.paimon.options.MemorySize;
 import org.apache.paimon.options.Options;
+import 
org.apache.paimon.testutils.junit.parameterized.ParameterizedTestExtension;
+import org.apache.paimon.testutils.junit.parameterized.Parameters;
 import org.apache.paimon.types.DataTypes;
 import org.apache.paimon.types.RowKind;
 import org.apache.paimon.types.RowType;
+import org.apache.paimon.utils.FieldsComparator;
 import org.apache.paimon.utils.IteratorRecordReader;
 
 import org.jetbrains.annotations.Nullable;
 import org.junit.jupiter.api.AfterEach;
 import org.junit.jupiter.api.BeforeEach;
-import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestTemplate;
+import org.junit.jupiter.api.extension.ExtendWith;
 import org.junit.jupiter.api.io.TempDir;
 
 import java.io.File;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Comparator;
 import java.util.HashSet;
 import java.util.List;
@@ -51,6 +58,7 @@ import java.util.stream.Collectors;
 import static org.assertj.core.api.Assertions.assertThat;
 
 /** Test for {@link MergeSorter}. */
+@ExtendWith(ParameterizedTestExtension.class)
 public class MergeSorterTest {
 
     private static final int MEMORY_SIZE = 1024 * 1024 * 32;
@@ -59,17 +67,30 @@ public class MergeSorterTest {
 
     private final RowType valueType = RowType.builder().field("v", 
DataTypes.INT()).build();
 
+    private final SortEngine sortEngine;
+
     @TempDir Path tempDir;
 
     private IOManager ioManager;
     private MergeSorter sorter;
     private int totalPages;
 
+    public MergeSorterTest(SortEngine sortEngine) {
+        this.sortEngine = sortEngine;
+    }
+
+    @SuppressWarnings("unused")
+    @Parameters(name = "{0}")
+    public static List<SortEngine> getVarSeg() {
+        return Arrays.asList(SortEngine.LOSER_TREE, SortEngine.MIN_HEAP);
+    }
+
     @BeforeEach
     public void beforeTest() {
         ioManager = IOManager.create(tempDir.toString());
         Options options = new Options();
         options.set(CoreOptions.SORT_SPILL_BUFFER_SIZE, new 
MemorySize(MEMORY_SIZE));
+        options.set(CoreOptions.SORT_ENGINE, sortEngine);
         sorter = new MergeSorter(new CoreOptions(options), keyType, valueType, 
ioManager);
         totalPages = sorter.memoryPool().freePages();
     }
@@ -86,60 +107,82 @@ public class MergeSorterTest {
         this.ioManager.close();
     }
 
-    @Test
+    @TestTemplate
     public void testSortAndMerge() throws Exception {
+        innerTest(null);
+    }
+
+    @TestTemplate
+    public void testWithUserDefineSequence() throws Exception {
+        innerTest(
+                new FieldsComparator() {
+                    @Override
+                    public int[] compareFields() {
+                        return new int[] {0};
+                    }
+
+                    @Override
+                    public int compare(InternalRow o1, InternalRow o2) {
+                        return Integer.compare(o1.getInt(0), o2.getInt(0));
+                    }
+                });
+    }
+
+    private void innerTest(FieldsComparator userDefinedSeqComparator) throws 
Exception {
+        Comparator<KeyValue> comparator =
+                Comparator.comparingInt((KeyValue o) -> o.key().getInt(0));
+        if (userDefinedSeqComparator != null) {
+            comparator =
+                    comparator.thenComparing(
+                            (o1, o2) -> 
userDefinedSeqComparator.compare(o1.value(), o2.value()));
+        }
+        comparator = comparator.thenComparingLong(KeyValue::sequenceNumber);
+
         List<ReaderSupplier<KeyValue>> readers = new ArrayList<>();
         Random rnd = new Random();
         List<KeyValue> expectedKvs = new ArrayList<>();
         Set<Long> distinctSeq = new HashSet<>();
-        for (int i = 0; i < 10; i++) {
+        for (int i = 0; i < rnd.nextInt(10) + 3; i++) {
             List<KeyValue> kvs = new ArrayList<>();
+            Set<Integer> distinctKeys = new HashSet<>();
             for (int j = 0; j < 100; j++) {
-                long seq = rnd.nextLong();
-                while (distinctSeq.contains(seq)) {
-                    rnd.nextLong();
+                while (true) {
+                    int key = rnd.nextInt(1000);
+                    if (distinctKeys.contains(key)) {
+                        continue;
+                    }
+
+                    long seq = rnd.nextLong();
+                    while (distinctSeq.contains(seq)) {
+                        seq = rnd.nextLong();
+                    }
+                    distinctSeq.add(seq);
+                    kvs.add(
+                            new KeyValue()
+                                    .replace(
+                                            GenericRow.of(key),
+                                            seq,
+                                            RowKind.fromByteValue((byte) 
rnd.nextInt(4)),
+                                            GenericRow.of(rnd.nextInt(1000)))
+                                    .setLevel(rnd.nextInt(100)));
+                    distinctKeys.add(key);
+                    break;
                 }
-                distinctSeq.add(seq);
-                kvs.add(
-                        new KeyValue()
-                                .replace(
-                                        GenericRow.of(rnd.nextInt(100)),
-                                        seq,
-                                        RowKind.fromByteValue((byte) 
rnd.nextInt(4)),
-                                        GenericRow.of(rnd.nextInt()))
-                                .setLevel(rnd.nextInt(100)));
             }
             expectedKvs.addAll(kvs);
+            kvs.sort(comparator);
             readers.add(() -> new IteratorRecordReader<>(kvs.iterator()));
         }
 
-        expectedKvs.sort(
-                Comparator.comparingInt((KeyValue o) -> o.key().getInt(0))
-                        .thenComparingLong(KeyValue::sequenceNumber));
+        expectedKvs.sort(comparator);
 
-        MergeFunctionWrapper<List<KeyValue>> collectFunc =
-                new MergeFunctionWrapper<List<KeyValue>>() {
-
-                    private List<KeyValue> result;
-
-                    @Override
-                    public void reset() {
-                        result = new ArrayList<>();
-                    }
-
-                    @Override
-                    public void add(KeyValue kv) {
-                        result.add(kv);
-                    }
-
-                    @Nullable
-                    @Override
-                    public List<KeyValue> getResult() {
-                        return result;
-                    }
-                };
+        TestMergeFunctionWrapper collectFunc = new TestMergeFunctionWrapper();
         List<KeyValue> all = new ArrayList<>();
-        sorter.mergeSort(readers, Comparator.comparingInt(o -> o.getInt(0)), 
collectFunc)
+        sorter.mergeSort(
+                        readers,
+                        Comparator.comparingInt(o -> o.getInt(0)),
+                        userDefinedSeqComparator,
+                        collectFunc)
                 .forEachRemaining(all::addAll);
 
         
assertThat(toString(all)).containsExactlyElementsOf(toString(expectedKvs));
@@ -148,4 +191,25 @@ public class MergeSorterTest {
     private List<String> toString(List<KeyValue> kvs) {
         return kvs.stream().map(kv -> kv.toString(keyType, 
valueType)).collect(Collectors.toList());
     }
+
+    private static class TestMergeFunctionWrapper implements 
MergeFunctionWrapper<List<KeyValue>> {
+
+        private List<KeyValue> result;
+
+        @Override
+        public void reset() {
+            result = new ArrayList<>();
+        }
+
+        @Override
+        public void add(KeyValue kv) {
+            result.add(kv);
+        }
+
+        @Nullable
+        @Override
+        public List<KeyValue> getResult() {
+            return result;
+        }
+    }
 }
diff --git 
a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/SortMergeReaderTestBase.java
 
b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/SortMergeReaderTestBase.java
index 6895815cd..81e49b81d 100644
--- 
a/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/SortMergeReaderTestBase.java
+++ 
b/paimon-core/src/test/java/org/apache/paimon/mergetree/compact/SortMergeReaderTestBase.java
@@ -48,6 +48,7 @@ public abstract class SortMergeReaderTestBase extends 
CombiningRecordReaderTestB
         return SortMergeReader.createSortMergeReader(
                 new ArrayList<>(readers),
                 KEY_COMPARATOR,
+                null,
                 new ReducerMergeFunctionWrapper(createMergeFunction()),
                 sortEngine);
     }
diff --git 
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sorter/SortOperator.java
 
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sorter/SortOperator.java
index 4272dc53e..1292c7ce3 100644
--- 
a/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sorter/SortOperator.java
+++ 
b/paimon-flink/paimon-flink-common/src/main/java/org/apache/paimon/flink/sorter/SortOperator.java
@@ -31,6 +31,8 @@ import 
org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.table.runtime.operators.TableStreamOperator;
 
+import java.util.stream.IntStream;
+
 /** SortOperator to sort the `InternalRow`s by the `KeyType`. */
 public class SortOperator extends TableStreamOperator<InternalRow>
         implements OneInputStreamOperator<InternalRow, InternalRow>, 
BoundedOneInput {
@@ -87,8 +89,8 @@ public class SortOperator extends 
TableStreamOperator<InternalRow>
         buffer =
                 BinaryExternalSortBuffer.create(
                         ioManager,
-                        keyType,
                         rowType,
+                        IntStream.range(0, keyType.getFieldCount()).toArray(),
                         maxMemory,
                         pageSize,
                         spillSortMaxNumFiles,

Reply via email to