This is an automated email from the ASF dual-hosted git repository.

stevenwu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/iceberg.git

commit 21522697c42cee3bedd758aa399b9530b001e30a
Author: Steven Wu <stevenz...@gmail.com>
AuthorDate: Fri Dec 8 13:04:53 2023 -0800

    Flink: backport PR #9212 to 1.16 for switching to SortKey for data statistics
---
 .../iceberg/flink/sink/shuffle/DataStatistics.java |  12 +-
 .../sink/shuffle/DataStatisticsCoordinator.java    |   3 +-
 .../flink/sink/shuffle/DataStatisticsOperator.java |  25 +-
 .../flink/sink/shuffle/DataStatisticsUtil.java     |   1 -
 .../flink/sink/shuffle/MapDataStatistics.java      |  21 +-
 .../sink/shuffle/MapDataStatisticsSerializer.java  |  56 ++--
 .../flink/sink/shuffle/SortKeySerializer.java      | 353 +++++++++++++++++++++
 .../sink/shuffle/TestAggregatedStatistics.java     |  48 +--
 .../shuffle/TestAggregatedStatisticsTracker.java   |  97 +++---
 .../shuffle/TestDataStatisticsCoordinator.java     |  98 +++---
 .../TestDataStatisticsCoordinatorProvider.java     | 160 +++++-----
 .../sink/shuffle/TestDataStatisticsOperator.java   | 154 +++++----
 .../flink/sink/shuffle/TestMapDataStatistics.java  |  90 ++++++
 .../sink/shuffle/TestSortKeySerializerBase.java    |  65 ++++
 .../shuffle/TestSortKeySerializerNestedStruct.java |  55 ++++
 .../shuffle/TestSortKeySerializerPrimitives.java   |  57 ++++
 16 files changed, 980 insertions(+), 315 deletions(-)
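
In short: the shuffle path stops keying statistics on Flink RowData and keys
them on Iceberg's SortKey instead. A minimal sketch of how the reworked pieces
fit together (illustrative only, not part of the commit; the schema and sort
order mirror the test fixtures below):

    Schema schema = new Schema(Types.NestedField.optional(1, "str", Types.StringType.get()));
    SortOrder sortOrder = SortOrder.builderFor(schema).asc("str").build();
    RowDataWrapper wrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct());
    SortKey sortKey = new SortKey(schema, sortOrder);
    MapDataStatistics statistics = new MapDataStatistics();

    RowData row = GenericRowData.of(StringData.fromString("a"));
    sortKey.wrap(wrapper.wrap(row)); // one reusable SortKey per operator
    statistics.add(sortKey);         // MapDataStatistics copies any key it retains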

diff --git a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java
index 28a05201c0..9d7cf179ab 100644
--- a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java
+++ b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatistics.java
@@ -19,7 +19,7 @@
 package org.apache.iceberg.flink.sink.shuffle;
 
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.SortKey;
 
 /**
  * DataStatistics defines the interface to collect data distribution information.
@@ -29,7 +29,7 @@ import org.apache.flink.table.data.RowData;
  * (sketching) can be used.
  */
 @Internal
-interface DataStatistics<D extends DataStatistics, S> {
+interface DataStatistics<D extends DataStatistics<D, S>, S> {
 
   /**
    * Check if data statistics contains any statistics information.
@@ -38,12 +38,8 @@ interface DataStatistics<D extends DataStatistics, S> {
    */
   boolean isEmpty();
 
-  /**
-   * Add data key to data statistics.
-   *
-   * @param key generate from data by applying key selector
-   */
-  void add(RowData key);
+  /** Add a row's sort key to the data statistics. */
+  void add(SortKey sortKey);
 
   /**
    * Merge current statistics with other statistics.
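
Note on the signature change above: the type parameter is now bound recursively
(D extends DataStatistics<D, S>), so a concrete implementation names itself as
D, and add(...) receives a SortKey that the caller may reuse and mutate. A toy
implementation under those assumptions (the merge(D) signature is inferred from
the javadoc, which this hunk truncates):

    class CountingStatistics implements DataStatistics<CountingStatistics, Long> {
      private long count = 0;

      @Override
      public boolean isEmpty() {
        return count == 0;
      }

      @Override
      public void add(SortKey sortKey) {
        count += 1; // do not retain sortKey without copying; the caller reuses it
      }

      @Override
      public void merge(CountingStatistics other) {
        count += other.statistics();
      }

      @Override
      public Long statistics() {
        return count;
      }
    }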
diff --git a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java
index fcfd798842..c8ac79c61b 100644
--- a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java
+++ b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsCoordinator.java
@@ -172,6 +172,7 @@ class DataStatisticsCoordinator<D extends DataStatistics<D, S>, S> implements Op
     }
   }
 
+  @SuppressWarnings("FutureReturnValueIgnored")
   private void sendDataStatisticsToSubtasks(
       long checkpointId, DataStatistics<D, S> globalDataStatistics) {
     callInCoordinatorThread(
@@ -339,7 +340,7 @@ class DataStatisticsCoordinator<D extends DataStatistics<D, S>, S> implements Op
 
     private OperatorCoordinator.SubtaskGateway getSubtaskGateway(int subtaskIndex) {
       Preconditions.checkState(
-          gateways[subtaskIndex].size() > 0,
+          !gateways[subtaskIndex].isEmpty(),
           "Coordinator of %s subtask %d is not ready yet to receive events",
           operatorName,
           subtaskIndex);
diff --git a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java
index d00d5d2e5a..5157a37cf2 100644
--- a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java
+++ b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsOperator.java
@@ -22,7 +22,6 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.runtime.operators.coordination.OperatorEvent;
 import org.apache.flink.runtime.operators.coordination.OperatorEventGateway;
 import org.apache.flink.runtime.operators.coordination.OperatorEventHandler;
@@ -32,6 +31,12 @@ import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.StructLike;
+import org.apache.iceberg.flink.FlinkSchemaUtil;
+import org.apache.iceberg.flink.RowDataWrapper;
 import org.apache.iceberg.relocated.com.google.common.annotations.VisibleForTesting;
 import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 
@@ -45,11 +50,12 @@ import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
 class DataStatisticsOperator<D extends DataStatistics<D, S>, S>
     extends AbstractStreamOperator<DataStatisticsOrRecord<D, S>>
     implements OneInputStreamOperator<RowData, DataStatisticsOrRecord<D, S>>, OperatorEventHandler {
+
   private static final long serialVersionUID = 1L;
 
   private final String operatorName;
-  // keySelector will be used to generate key from data for collecting data statistics
-  private final KeySelector<RowData, RowData> keySelector;
+  private final RowDataWrapper rowDataWrapper;
+  private final SortKey sortKey;
   private final OperatorEventGateway operatorEventGateway;
   private final TypeSerializer<DataStatistics<D, S>> statisticsSerializer;
   private transient volatile DataStatistics<D, S> localStatistics;
@@ -58,11 +64,13 @@ class DataStatisticsOperator<D extends DataStatistics<D, S>, S>
 
   DataStatisticsOperator(
       String operatorName,
-      KeySelector<RowData, RowData> keySelector,
+      Schema schema,
+      SortOrder sortOrder,
       OperatorEventGateway operatorEventGateway,
       TypeSerializer<DataStatistics<D, S>> statisticsSerializer) {
     this.operatorName = operatorName;
-    this.keySelector = keySelector;
+    this.rowDataWrapper = new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct());
+    this.sortKey = new SortKey(schema, sortOrder);
     this.operatorEventGateway = operatorEventGateway;
     this.statisticsSerializer = statisticsSerializer;
   }
@@ -126,10 +134,11 @@ class DataStatisticsOperator<D extends DataStatistics<D, S>, S>
   }
 
   @Override
-  public void processElement(StreamRecord<RowData> streamRecord) throws Exception {
+  public void processElement(StreamRecord<RowData> streamRecord) {
     RowData record = streamRecord.getValue();
-    RowData key = keySelector.getKey(record);
-    localStatistics.add(key);
+    StructLike struct = rowDataWrapper.wrap(record);
+    sortKey.wrap(struct);
+    localStatistics.add(sortKey);
     output.collect(new StreamRecord<>(DataStatisticsOrRecord.fromRecord(record)));
   }
 
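
The operator change above replaces the per-record KeySelector with one reusable
RowDataWrapper and SortKey. Condensed, the new per-record path looks like this
(hypothetical helper, same calls as in processElement):

    static void collect(
        RowDataWrapper wrapper, SortKey reusableKey, MapDataStatistics stats, RowData row) {
      StructLike struct = wrapper.wrap(row); // struct view over the Flink row
      reusableKey.wrap(struct);              // evaluates the sort transforms in place
      stats.add(reusableKey);                // safe: the statistics copy retained keys
    }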
diff --git a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java
index 2737b1346f..8716cb872d 100644
--- a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java
+++ b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/DataStatisticsUtil.java
@@ -76,7 +76,6 @@ class DataStatisticsUtil {
     return bytes.toByteArray();
   }
 
-  @SuppressWarnings("unchecked")
   static <D extends DataStatistics<D, S>, S>
       AggregatedStatistics<D, S> deserializeAggregatedStatistics(
           byte[] bytes, TypeSerializer<DataStatistics<D, S>> statisticsSerializer)
diff --git a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java
index 246b56526f..0ffffd9cf4 100644
--- a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java
+++ b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatistics.java
@@ -20,20 +20,20 @@ package org.apache.iceberg.flink.sink.shuffle;
 
 import java.util.Map;
 import org.apache.flink.annotation.Internal;
-import org.apache.flink.table.data.RowData;
+import org.apache.iceberg.SortKey;
 import org.apache.iceberg.relocated.com.google.common.base.MoreObjects;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 
 /** MapDataStatistics uses map to count key frequency */
 @Internal
-class MapDataStatistics implements DataStatistics<MapDataStatistics, Map<RowData, Long>> {
-  private final Map<RowData, Long> statistics;
+class MapDataStatistics implements DataStatistics<MapDataStatistics, Map<SortKey, Long>> {
+  private final Map<SortKey, Long> statistics;
 
   MapDataStatistics() {
     this.statistics = Maps.newHashMap();
   }
 
-  MapDataStatistics(Map<RowData, Long> statistics) {
+  MapDataStatistics(Map<SortKey, Long> statistics) {
     this.statistics = statistics;
   }
 
@@ -43,9 +43,14 @@ class MapDataStatistics implements DataStatistics<MapDataStatistics, Map<RowData
   }
 
   @Override
-  public void add(RowData key) {
-    // increase count of occurrence by one in the dataStatistics map
-    statistics.merge(key, 1L, Long::sum);
+  public void add(SortKey sortKey) {
+    if (statistics.containsKey(sortKey)) {
+      statistics.merge(sortKey, 1L, Long::sum);
+    } else {
+      // clone the sort key before adding to map because input sortKey object can be reused
+      SortKey copiedKey = sortKey.copy();
+      statistics.put(copiedKey, 1L);
+    }
   }
 
   @Override
@@ -54,7 +59,7 @@ class MapDataStatistics implements DataStatistics<MapDataStatistics, Map<RowData
   }
 
   @Override
-  public Map<RowData, Long> statistics() {
+  public Map<SortKey, Long> statistics() {
     return statistics;
   }
 
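
The copy-before-put above matters because each subtask mutates a single SortKey
instance; storing the live reference would alias every map entry to the most
recent key. Spelled out (sketch; schema and sortOrder as in the tests):

    MapDataStatistics stats = new MapDataStatistics();
    SortKey key = new SortKey(schema, sortOrder);
    key.set(0, "a");
    stats.add(key);  // first sighting: stored as key.copy()
    key.set(0, "b"); // mutating the reused key must not corrupt the "a" entry
    stats.add(key);
    // stats.statistics() is now {a=1, b=1}, each entry keyed by an independent copy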
diff --git a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatisticsSerializer.java b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatisticsSerializer.java
index 6d07637b29..b6cccd0566 100644
--- a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatisticsSerializer.java
+++ b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/MapDataStatisticsSerializer.java
@@ -29,22 +29,22 @@ import org.apache.flink.api.common.typeutils.base.LongSerializer;
 import org.apache.flink.api.common.typeutils.base.MapSerializer;
 import org.apache.flink.core.memory.DataInputView;
 import org.apache.flink.core.memory.DataOutputView;
-import org.apache.flink.table.data.RowData;
 import org.apache.flink.util.Preconditions;
+import org.apache.iceberg.SortKey;
 import org.apache.iceberg.relocated.com.google.common.collect.Maps;
 
 @Internal
 class MapDataStatisticsSerializer
-    extends TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> {
-  private final MapSerializer<RowData, Long> mapSerializer;
+    extends TypeSerializer<DataStatistics<MapDataStatistics, Map<SortKey, Long>>> {
+  private final MapSerializer<SortKey, Long> mapSerializer;
 
-  static TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> fromKeySerializer(
-      TypeSerializer<RowData> keySerializer) {
+  static MapDataStatisticsSerializer fromSortKeySerializer(
+      TypeSerializer<SortKey> sortKeySerializer) {
     return new MapDataStatisticsSerializer(
-        new MapSerializer<>(keySerializer, LongSerializer.INSTANCE));
+        new MapSerializer<>(sortKeySerializer, LongSerializer.INSTANCE));
   }
 
-  MapDataStatisticsSerializer(MapSerializer<RowData, Long> mapSerializer) {
+  MapDataStatisticsSerializer(MapSerializer<SortKey, Long> mapSerializer) {
     this.mapSerializer = mapSerializer;
   }
 
@@ -55,28 +55,28 @@ class MapDataStatisticsSerializer
 
   @SuppressWarnings("ReferenceEquality")
   @Override
-  public TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>> duplicate() {
-    MapSerializer<RowData, Long> duplicateMapSerializer =
-        (MapSerializer<RowData, Long>) mapSerializer.duplicate();
+  public TypeSerializer<DataStatistics<MapDataStatistics, Map<SortKey, Long>>> duplicate() {
+    MapSerializer<SortKey, Long> duplicateMapSerializer =
+        (MapSerializer<SortKey, Long>) mapSerializer.duplicate();
     return (duplicateMapSerializer == mapSerializer)
         ? this
         : new MapDataStatisticsSerializer(duplicateMapSerializer);
   }
 
   @Override
-  public DataStatistics<MapDataStatistics, Map<RowData, Long>> createInstance() {
+  public MapDataStatistics createInstance() {
     return new MapDataStatistics();
   }
 
   @Override
-  public DataStatistics<MapDataStatistics, Map<RowData, Long>> copy(DataStatistics obj) {
+  public MapDataStatistics copy(DataStatistics<MapDataStatistics, Map<SortKey, Long>> obj) {
     Preconditions.checkArgument(
         obj instanceof MapDataStatistics, "Invalid data statistics type: " + obj.getClass());
     MapDataStatistics from = (MapDataStatistics) obj;
-    TypeSerializer<RowData> keySerializer = mapSerializer.getKeySerializer();
-    Map<RowData, Long> newMap = Maps.newHashMapWithExpectedSize(from.statistics().size());
-    for (Map.Entry<RowData, Long> entry : from.statistics().entrySet()) {
-      RowData newKey = keySerializer.copy(entry.getKey());
+    TypeSerializer<SortKey> keySerializer = mapSerializer.getKeySerializer();
+    Map<SortKey, Long> newMap = Maps.newHashMapWithExpectedSize(from.statistics().size());
+    for (Map.Entry<SortKey, Long> entry : from.statistics().entrySet()) {
+      SortKey newKey = keySerializer.copy(entry.getKey());
       // no need to copy value since it is just a Long
       newMap.put(newKey, entry.getValue());
     }
@@ -85,8 +85,9 @@ class MapDataStatisticsSerializer
   }
 
   @Override
-  public DataStatistics<MapDataStatistics, Map<RowData, Long>> copy(
-      DataStatistics from, DataStatistics reuse) {
+  public DataStatistics<MapDataStatistics, Map<SortKey, Long>> copy(
+      DataStatistics<MapDataStatistics, Map<SortKey, Long>> from,
+      DataStatistics<MapDataStatistics, Map<SortKey, Long>> reuse) {
     // not much benefit to reuse
     return copy(from);
   }
@@ -97,7 +98,9 @@ class MapDataStatisticsSerializer
   }
 
   @Override
-  public void serialize(DataStatistics obj, DataOutputView target) throws IOException {
+  public void serialize(
+      DataStatistics<MapDataStatistics, Map<SortKey, Long>> obj, DataOutputView target)
+      throws IOException {
     Preconditions.checkArgument(
         obj instanceof MapDataStatistics, "Invalid data statistics type: " + obj.getClass());
     MapDataStatistics mapStatistics = (MapDataStatistics) obj;
@@ -105,14 +108,15 @@ class MapDataStatisticsSerializer
   }
 
   @Override
-  public DataStatistics<MapDataStatistics, Map<RowData, Long>> deserialize(DataInputView source)
+  public DataStatistics<MapDataStatistics, Map<SortKey, Long>> deserialize(DataInputView source)
       throws IOException {
     return new MapDataStatistics(mapSerializer.deserialize(source));
   }
 
   @Override
-  public DataStatistics<MapDataStatistics, Map<RowData, Long>> deserialize(
-      DataStatistics reuse, DataInputView source) throws IOException {
+  public DataStatistics<MapDataStatistics, Map<SortKey, Long>> deserialize(
+      DataStatistics<MapDataStatistics, Map<SortKey, Long>> reuse, DataInputView source)
+      throws IOException {
     // not much benefit to reuse
     return deserialize(source);
   }
@@ -138,14 +142,14 @@ class MapDataStatisticsSerializer
   }
 
   @Override
-  public TypeSerializerSnapshot<DataStatistics<MapDataStatistics, Map<RowData, Long>>>
+  public TypeSerializerSnapshot<DataStatistics<MapDataStatistics, Map<SortKey, Long>>>
       snapshotConfiguration() {
     return new MapDataStatisticsSerializerSnapshot(this);
   }
 
   public static class MapDataStatisticsSerializerSnapshot
       extends CompositeTypeSerializerSnapshot<
-          DataStatistics<MapDataStatistics, Map<RowData, Long>>, MapDataStatisticsSerializer> {
+          DataStatistics<MapDataStatistics, Map<SortKey, Long>>, MapDataStatisticsSerializer> {
     private static final int CURRENT_VERSION = 1;
 
     // constructors need to public. Otherwise, Flink state restore would complain
@@ -175,8 +179,8 @@ class MapDataStatisticsSerializer
     protected MapDataStatisticsSerializer createOuterSerializerWithNestedSerializers(
         TypeSerializer<?>[] nestedSerializers) {
       @SuppressWarnings("unchecked")
-      MapSerializer<RowData, Long> mapSerializer =
-          (MapSerializer<RowData, Long>) nestedSerializers[0];
+      MapSerializer<SortKey, Long> mapSerializer =
+          (MapSerializer<SortKey, Long>) nestedSerializers[0];
       return new MapDataStatisticsSerializer(mapSerializer);
     }
   }
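
With fromKeySerializer(TypeSerializer<RowData>) replaced by
fromSortKeySerializer(TypeSerializer<SortKey>), constructing the statistics
serializer now reads as in the updated tests:

    MapDataStatisticsSerializer statisticsSerializer =
        MapDataStatisticsSerializer.fromSortKeySerializer(
            new SortKeySerializer(schema, sortOrder));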
diff --git a/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySerializer.java b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySerializer.java
new file mode 100644
index 0000000000..d03409f2a4
--- /dev/null
+++ b/flink/v1.16/flink/src/main/java/org/apache/iceberg/flink/sink/shuffle/SortKeySerializer.java
@@ -0,0 +1,353 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.io.IOException;
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.nio.ByteBuffer;
+import java.util.List;
+import java.util.Objects;
+import java.util.UUID;
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSchemaCompatibility;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.util.Preconditions;
+import org.apache.flink.util.StringUtils;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SchemaParser;
+import org.apache.iceberg.SortField;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.SortOrderParser;
+import org.apache.iceberg.types.CheckCompatibility;
+import org.apache.iceberg.types.Type;
+import org.apache.iceberg.types.Types;
+
+class SortKeySerializer extends TypeSerializer<SortKey> {
+  private final Schema schema;
+  private final SortOrder sortOrder;
+  private final int size;
+  private final Types.NestedField[] transformedFields;
+
+  private transient SortKey sortKey;
+
+  SortKeySerializer(Schema schema, SortOrder sortOrder) {
+    this.schema = schema;
+    this.sortOrder = sortOrder;
+    this.size = sortOrder.fields().size();
+
+    this.transformedFields = new Types.NestedField[size];
+    for (int i = 0; i < size; ++i) {
+      SortField sortField = sortOrder.fields().get(i);
+      Types.NestedField sourceField = schema.findField(sortField.sourceId());
+      Type resultType = sortField.transform().getResultType(sourceField.type());
+      Types.NestedField transformedField =
+          Types.NestedField.of(
+              sourceField.fieldId(),
+              sourceField.isOptional(),
+              sourceField.name(),
+              resultType,
+              sourceField.doc());
+      transformedFields[i] = transformedField;
+    }
+  }
+
+  private SortKey lazySortKey() {
+    if (sortKey == null) {
+      this.sortKey = new SortKey(schema, sortOrder);
+    }
+
+    return sortKey;
+  }
+
+  @Override
+  public boolean isImmutableType() {
+    return false;
+  }
+
+  @Override
+  public TypeSerializer<SortKey> duplicate() {
+    return new SortKeySerializer(schema, sortOrder);
+  }
+
+  @Override
+  public SortKey createInstance() {
+    return new SortKey(schema, sortOrder);
+  }
+
+  @Override
+  public SortKey copy(SortKey from) {
+    return from.copy();
+  }
+
+  @Override
+  public SortKey copy(SortKey from, SortKey reuse) {
+    // no benefit of reuse
+    return copy(from);
+  }
+
+  @Override
+  public int getLength() {
+    return -1;
+  }
+
+  @Override
+  public void serialize(SortKey record, DataOutputView target) throws IOException {
+    Preconditions.checkArgument(
+        record.size() == size,
+        "Invalid size of the sort key object: %s. Expected %s",
+        record.size(),
+        size);
+    for (int i = 0; i < size; ++i) {
+      int fieldId = transformedFields[i].fieldId();
+      Type.TypeID typeId = transformedFields[i].type().typeId();
+      switch (typeId) {
+        case BOOLEAN:
+          target.writeBoolean(record.get(i, Boolean.class));
+          break;
+        case INTEGER:
+        case DATE:
+          target.writeInt(record.get(i, Integer.class));
+          break;
+        case LONG:
+        case TIME:
+        case TIMESTAMP:
+          target.writeLong(record.get(i, Long.class));
+          break;
+        case FLOAT:
+          target.writeFloat(record.get(i, Float.class));
+          break;
+        case DOUBLE:
+          target.writeDouble(record.get(i, Double.class));
+          break;
+        case STRING:
+          target.writeUTF(record.get(i, CharSequence.class).toString());
+          break;
+        case UUID:
+          UUID uuid = record.get(i, UUID.class);
+          target.writeLong(uuid.getMostSignificantBits());
+          target.writeLong(uuid.getLeastSignificantBits());
+          break;
+        case FIXED:
+        case BINARY:
+          byte[] bytes = record.get(i, ByteBuffer.class).array();
+          target.writeInt(bytes.length);
+          target.write(bytes);
+          break;
+        case DECIMAL:
+          BigDecimal decimal = record.get(i, BigDecimal.class);
+          byte[] decimalBytes = decimal.unscaledValue().toByteArray();
+          target.writeInt(decimalBytes.length);
+          target.write(decimalBytes);
+          target.writeInt(decimal.scale());
+          break;
+        case STRUCT:
+        case MAP:
+        case LIST:
+        default:
+          // SortKey transformation is a flattened struct without list and map
+          throw new UnsupportedOperationException(
+              String.format("Field %d has unsupported field type: %s", fieldId, typeId));
+      }
+    }
+  }
+
+  @Override
+  public SortKey deserialize(DataInputView source) throws IOException {
+    // copying is a little faster than constructing a new SortKey object
+    SortKey deserialized = lazySortKey().copy();
+    deserialize(deserialized, source);
+    return deserialized;
+  }
+
+  @Override
+  public SortKey deserialize(SortKey reuse, DataInputView source) throws IOException {
+    Preconditions.checkArgument(
+        reuse.size() == size,
+        "Invalid size of the sort key object: %s. Expected %s",
+        reuse.size(),
+        size);
+    for (int i = 0; i < size; ++i) {
+      int fieldId = transformedFields[i].fieldId();
+      Type.TypeID typeId = transformedFields[i].type().typeId();
+      switch (typeId) {
+        case BOOLEAN:
+          reuse.set(i, source.readBoolean());
+          break;
+        case INTEGER:
+        case DATE:
+          reuse.set(i, source.readInt());
+          break;
+        case LONG:
+        case TIME:
+        case TIMESTAMP:
+          reuse.set(i, source.readLong());
+          break;
+        case FLOAT:
+          reuse.set(i, source.readFloat());
+          break;
+        case DOUBLE:
+          reuse.set(i, source.readDouble());
+          break;
+        case STRING:
+          reuse.set(i, source.readUTF());
+          break;
+        case UUID:
+          long mostSignificantBits = source.readLong();
+          long leastSignificantBits = source.readLong();
+          reuse.set(i, new UUID(mostSignificantBits, leastSignificantBits));
+          break;
+        case FIXED:
+        case BINARY:
+          byte[] bytes = new byte[source.readInt()];
+          source.read(bytes);
+          reuse.set(i, ByteBuffer.wrap(bytes));
+          break;
+        case DECIMAL:
+          byte[] unscaledBytes = new byte[source.readInt()];
+          source.read(unscaledBytes);
+          int scale = source.readInt();
+          BigDecimal decimal = new BigDecimal(new BigInteger(unscaledBytes), scale);
+          reuse.set(i, decimal);
+          break;
+        case STRUCT:
+        case MAP:
+        case LIST:
+        default:
+          // SortKey transformation is a flattened struct without list and map
+          throw new UnsupportedOperationException(
+              String.format("Field %d has unsupported field type: %s", fieldId, typeId));
+      }
+    }
+
+    return reuse;
+  }
+
+  @Override
+  public void copy(DataInputView source, DataOutputView target) throws IOException {
+    // no optimization here
+    serialize(deserialize(source), target);
+  }
+
+  @Override
+  public boolean equals(Object obj) {
+    if (!(obj instanceof SortKeySerializer)) {
+      return false;
+    }
+
+    SortKeySerializer other = (SortKeySerializer) obj;
+    return Objects.equals(schema.asStruct(), other.schema.asStruct())
+        && Objects.equals(sortOrder, other.sortOrder);
+  }
+
+  @Override
+  public int hashCode() {
+    return schema.asStruct().hashCode() * 31 + sortOrder.hashCode();
+  }
+
+  @Override
+  public TypeSerializerSnapshot<SortKey> snapshotConfiguration() {
+    return new SortKeySerializerSnapshot(schema, sortOrder);
+  }
+
+  public static class SortKeySerializerSnapshot implements TypeSerializerSnapshot<SortKey> {
+    private static final int CURRENT_VERSION = 1;
+
+    private Schema schema;
+    private SortOrder sortOrder;
+
+    @SuppressWarnings({"checkstyle:RedundantModifier", "WeakerAccess"})
+    public SortKeySerializerSnapshot() {
+      // this constructor is used when restoring from a checkpoint.
+    }
+
+    // constructors need to be public. Otherwise, Flink state restore would complain
+    // "The class has no (implicit) public nullary constructor".
+    @SuppressWarnings("checkstyle:RedundantModifier")
+    public SortKeySerializerSnapshot(Schema schema, SortOrder sortOrder) {
+      this.schema = schema;
+      this.sortOrder = sortOrder;
+    }
+
+    @Override
+    public int getCurrentVersion() {
+      return CURRENT_VERSION;
+    }
+
+    @Override
+    public void writeSnapshot(DataOutputView out) throws IOException {
+      Preconditions.checkState(schema != null, "Invalid schema: null");
+      Preconditions.checkState(sortOrder != null, "Invalid sort order: null");
+
+      StringUtils.writeString(SchemaParser.toJson(schema), out);
+      StringUtils.writeString(SortOrderParser.toJson(sortOrder), out);
+    }
+
+    @Override
+    public void readSnapshot(int readVersion, DataInputView in, ClassLoader userCodeClassLoader)
+        throws IOException {
+      if (readVersion == 1) {
+        readV1(in);
+      } else {
+        throw new IllegalArgumentException("Unknown read version: " + readVersion);
+      }
+    }
+
+    @Override
+    public TypeSerializerSchemaCompatibility<SortKey> resolveSchemaCompatibility(
+        TypeSerializer<SortKey> newSerializer) {
+      if (!(newSerializer instanceof SortKeySerializer)) {
+        return TypeSerializerSchemaCompatibility.incompatible();
+      }
+
+      SortKeySerializer newAvroSerializer = (SortKeySerializer) newSerializer;
+      return resolveSchemaCompatibility(newAvroSerializer.schema, schema);
+    }
+
+    @Override
+    public TypeSerializer<SortKey> restoreSerializer() {
+      Preconditions.checkState(schema != null, "Invalid schema: null");
+      Preconditions.checkState(sortOrder != null, "Invalid sort order: null");
+      return new SortKeySerializer(schema, sortOrder);
+    }
+
+    private void readV1(DataInputView in) throws IOException {
+      String schemaJson = StringUtils.readString(in);
+      String sortOrderJson = StringUtils.readString(in);
+      this.schema = SchemaParser.fromJson(schemaJson);
+      this.sortOrder = SortOrderParser.fromJson(sortOrderJson).bind(schema);
+    }
+
+    @VisibleForTesting
+    static <T> TypeSerializerSchemaCompatibility<T> resolveSchemaCompatibility(
+        Schema readSchema, Schema writeSchema) {
+      List<String> compatibilityErrors =
+          CheckCompatibility.writeCompatibilityErrors(readSchema, writeSchema);
+      if (compatibilityErrors.isEmpty()) {
+        return TypeSerializerSchemaCompatibility.compatibleAsIs();
+      }
+
+      return TypeSerializerSchemaCompatibility.incompatible();
+    }
+  }
+}
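
The new serializer encodes each transformed sort field with a type-specific
format and restores values into a reusable SortKey. A round-trip sketch using
Flink's in-memory views (illustrative, not in the commit; assumes the
single-string-field schema from the tests, inside a method that may throw
IOException):

    SortKeySerializer serializer = new SortKeySerializer(schema, sortOrder);
    SortKey key = serializer.createInstance();
    key.set(0, "a");

    DataOutputSerializer out = new DataOutputSerializer(128); // org.apache.flink.core.memory
    serializer.serialize(key, out);
    DataInputDeserializer in = new DataInputDeserializer(out.getCopyOfBuffer());
    SortKey restored = serializer.deserialize(in);
    // restored.get(0, String.class) yields "a"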
diff --git a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java
index dd7fcafe53..890cc361b2 100644
--- a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java
+++ b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatistics.java
@@ -21,41 +21,43 @@ package org.apache.iceberg.flink.sink.shuffle;
 import static org.assertj.core.api.Assertions.assertThat;
 
 import java.util.Map;
-import org.apache.flink.table.data.GenericRowData;
-import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.data.StringData;
-import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
-import org.apache.flink.table.types.logical.RowType;
-import org.apache.flink.table.types.logical.VarCharType;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.types.Types;
 import org.junit.Test;
 
 public class TestAggregatedStatistics {
+  private final Schema schema =
+      new Schema(Types.NestedField.optional(1, "str", Types.StringType.get()));
+  private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("str").build();
+  private final SortKey sortKey = new SortKey(schema, sortOrder);
+  private final MapDataStatisticsSerializer statisticsSerializer =
+      MapDataStatisticsSerializer.fromSortKeySerializer(new SortKeySerializer(schema, sortOrder));
 
   @Test
   public void mergeDataStatisticTest() {
-    GenericRowData rowDataA = GenericRowData.of(StringData.fromString("a"));
-    GenericRowData rowDataB = GenericRowData.of(StringData.fromString("b"));
+    SortKey keyA = sortKey.copy();
+    keyA.set(0, "a");
+    SortKey keyB = sortKey.copy();
+    keyB.set(0, "b");
 
-    AggregatedStatistics<MapDataStatistics, Map<RowData, Long>> aggregatedStatistics =
-        new AggregatedStatistics<>(
-            1,
-            MapDataStatisticsSerializer.fromKeySerializer(
-                new RowDataSerializer(RowType.of(new VarCharType()))));
+    AggregatedStatistics<MapDataStatistics, Map<SortKey, Long>> aggregatedStatistics =
+        new AggregatedStatistics<>(1, statisticsSerializer);
     MapDataStatistics mapDataStatistics1 = new MapDataStatistics();
-    mapDataStatistics1.add(rowDataA);
-    mapDataStatistics1.add(rowDataA);
-    mapDataStatistics1.add(rowDataB);
+    mapDataStatistics1.add(keyA);
+    mapDataStatistics1.add(keyA);
+    mapDataStatistics1.add(keyB);
     aggregatedStatistics.mergeDataStatistic("testOperator", 1, mapDataStatistics1);
     MapDataStatistics mapDataStatistics2 = new MapDataStatistics();
-    mapDataStatistics2.add(rowDataA);
+    mapDataStatistics2.add(keyA);
     aggregatedStatistics.mergeDataStatistic("testOperator", 1, mapDataStatistics2);
-    assertThat(aggregatedStatistics.dataStatistics().statistics().get(rowDataA))
+    assertThat(aggregatedStatistics.dataStatistics().statistics().get(keyA))
         .isEqualTo(
-            mapDataStatistics1.statistics().get(rowDataA)
-                + mapDataStatistics2.statistics().get(rowDataA));
-    assertThat(aggregatedStatistics.dataStatistics().statistics().get(rowDataB))
+            mapDataStatistics1.statistics().get(keyA) + mapDataStatistics2.statistics().get(keyA));
+    assertThat(aggregatedStatistics.dataStatistics().statistics().get(keyB))
         .isEqualTo(
-            mapDataStatistics1.statistics().get(rowDataB)
-                + mapDataStatistics2.statistics().getOrDefault(rowDataB, 0L));
+            mapDataStatistics1.statistics().get(keyB)
+                + mapDataStatistics2.statistics().getOrDefault(keyB, 0L));
   }
 }
diff --git a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java
index 48e4e4d8f9..4c64ce5222 100644
--- a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java
+++ b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestAggregatedStatisticsTracker.java
@@ -21,32 +21,33 @@ package org.apache.iceberg.flink.sink.shuffle;
 import static org.assertj.core.api.Assertions.assertThat;
 
 import java.util.Map;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.table.data.GenericRowData;
-import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.data.StringData;
-import org.apache.flink.table.data.binary.BinaryRowData;
-import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
-import org.apache.flink.table.types.logical.RowType;
-import org.apache.flink.table.types.logical.VarCharType;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.types.Types;
 import org.junit.Before;
 import org.junit.Test;
 
 public class TestAggregatedStatisticsTracker {
   private static final int NUM_SUBTASKS = 2;
-  private final RowType rowType = RowType.of(new VarCharType());
-  // When coordinator handles events from operator, DataStatisticsUtil#deserializeDataStatistics
-  // deserializes bytes into BinaryRowData
-  private final BinaryRowData binaryRowDataA =
-      new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("a")));
-  private final BinaryRowData binaryRowDataB =
-      new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("b")));
-  private final TypeSerializer<RowData> rowSerializer = new RowDataSerializer(rowType);
-  private final TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>>
-      statisticsSerializer = MapDataStatisticsSerializer.fromKeySerializer(rowSerializer);
-  private AggregatedStatisticsTracker<MapDataStatistics, Map<RowData, Long>>
+
+  private final Schema schema =
+      new Schema(Types.NestedField.optional(1, "str", Types.StringType.get()));
+  private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("str").build();
+  private final SortKey sortKey = new SortKey(schema, sortOrder);
+  private final MapDataStatisticsSerializer statisticsSerializer =
+      MapDataStatisticsSerializer.fromSortKeySerializer(new SortKeySerializer(schema, sortOrder));
+  private final SortKey keyA = sortKey.copy();
+  private final SortKey keyB = sortKey.copy();
+
+  private AggregatedStatisticsTracker<MapDataStatistics, Map<SortKey, Long>>
       aggregatedStatisticsTracker;
 
+  public TestAggregatedStatisticsTracker() {
+    keyA.set(0, "a");
+    keyB.set(0, "b");
+  }
+
   @Before
   public void before() throws Exception {
     aggregatedStatisticsTracker =
@@ -56,8 +57,8 @@ public class TestAggregatedStatisticsTracker {
   @Test
   public void receiveNewerDataStatisticEvent() {
     MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics();
-    checkpoint1Subtask0DataStatistic.add(binaryRowDataA);
-    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+    checkpoint1Subtask0DataStatistic.add(keyA);
+    DataStatisticsEvent<MapDataStatistics, Map<SortKey, Long>>
         checkpoint1Subtask0DataStatisticEvent =
             DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer);
     assertThat(
@@ -67,8 +68,8 @@ public class TestAggregatedStatisticsTracker {
     assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId()).isEqualTo(1);
 
     MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics();
-    checkpoint2Subtask0DataStatistic.add(binaryRowDataA);
-    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+    checkpoint2Subtask0DataStatistic.add(keyA);
+    DataStatisticsEvent<MapDataStatistics, Map<SortKey, Long>>
         checkpoint2Subtask0DataStatisticEvent =
             DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer);
     assertThat(
@@ -82,10 +83,10 @@ public class TestAggregatedStatisticsTracker {
   @Test
   public void receiveOlderDataStatisticEventTest() {
     MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics();
-    checkpoint2Subtask0DataStatistic.add(binaryRowDataA);
-    checkpoint2Subtask0DataStatistic.add(binaryRowDataB);
-    checkpoint2Subtask0DataStatistic.add(binaryRowDataB);
-    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+    checkpoint2Subtask0DataStatistic.add(keyA);
+    checkpoint2Subtask0DataStatistic.add(keyB);
+    checkpoint2Subtask0DataStatistic.add(keyB);
+    DataStatisticsEvent<MapDataStatistics, Map<SortKey, Long>>
         checkpoint3Subtask0DataStatisticEvent =
             DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer);
     assertThat(
@@ -94,8 +95,8 @@ public class TestAggregatedStatisticsTracker {
         .isNull();
 
     MapDataStatistics checkpoint1Subtask1DataStatistic = new MapDataStatistics();
-    checkpoint1Subtask1DataStatistic.add(binaryRowDataB);
-    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+    checkpoint1Subtask1DataStatistic.add(keyB);
+    DataStatisticsEvent<MapDataStatistics, Map<SortKey, Long>>
         checkpoint1Subtask1DataStatisticEvent =
             DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, statisticsSerializer);
     // Receive event from old checkpoint, aggregatedStatisticsAggregatorTracker won't return
@@ -110,10 +111,10 @@ public class TestAggregatedStatisticsTracker {
   @Test
   public void receiveCompletedDataStatisticEvent() {
     MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics();
-    checkpoint1Subtask0DataStatistic.add(binaryRowDataA);
-    checkpoint1Subtask0DataStatistic.add(binaryRowDataB);
-    checkpoint1Subtask0DataStatistic.add(binaryRowDataB);
-    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+    checkpoint1Subtask0DataStatistic.add(keyA);
+    checkpoint1Subtask0DataStatistic.add(keyB);
+    checkpoint1Subtask0DataStatistic.add(keyB);
+    DataStatisticsEvent<MapDataStatistics, Map<SortKey, Long>>
         checkpoint1Subtask0DataStatisticEvent =
             DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer);
     assertThat(
@@ -122,14 +123,14 @@ public class TestAggregatedStatisticsTracker {
         .isNull();
 
     MapDataStatistics checkpoint1Subtask1DataStatistic = new MapDataStatistics();
-    checkpoint1Subtask1DataStatistic.add(binaryRowDataA);
-    checkpoint1Subtask1DataStatistic.add(binaryRowDataA);
-    checkpoint1Subtask1DataStatistic.add(binaryRowDataB);
-    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+    checkpoint1Subtask1DataStatistic.add(keyA);
+    checkpoint1Subtask1DataStatistic.add(keyA);
+    checkpoint1Subtask1DataStatistic.add(keyB);
+    DataStatisticsEvent<MapDataStatistics, Map<SortKey, Long>>
         checkpoint1Subtask1DataStatisticEvent =
             DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, statisticsSerializer);
     // Receive data statistics from all subtasks at checkpoint 1
-    AggregatedStatistics<MapDataStatistics, Map<RowData, Long>> completedStatistics =
+    AggregatedStatistics<MapDataStatistics, Map<SortKey, Long>> completedStatistics =
         aggregatedStatisticsTracker.updateAndCheckCompletion(
             1, checkpoint1Subtask1DataStatisticEvent);
 
@@ -137,20 +138,20 @@ public class TestAggregatedStatisticsTracker {
     assertThat(completedStatistics.checkpointId()).isEqualTo(1);
     MapDataStatistics globalDataStatistics =
         (MapDataStatistics) completedStatistics.dataStatistics();
-    assertThat((long) globalDataStatistics.statistics().get(binaryRowDataA))
+    assertThat((long) globalDataStatistics.statistics().get(keyA))
         .isEqualTo(
-            checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataA)
-                + checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataA));
-    assertThat((long) globalDataStatistics.statistics().get(binaryRowDataB))
+            checkpoint1Subtask0DataStatistic.statistics().get(keyA)
+                + checkpoint1Subtask1DataStatistic.statistics().get(keyA));
+    assertThat((long) globalDataStatistics.statistics().get(keyB))
         .isEqualTo(
-            checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataB)
-                + checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataB));
+            checkpoint1Subtask0DataStatistic.statistics().get(keyB)
+                + checkpoint1Subtask1DataStatistic.statistics().get(keyB));
     assertThat(aggregatedStatisticsTracker.inProgressStatistics().checkpointId())
         .isEqualTo(completedStatistics.checkpointId() + 1);
 
     MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics();
-    checkpoint2Subtask0DataStatistic.add(binaryRowDataA);
-    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+    checkpoint2Subtask0DataStatistic.add(keyA);
+    DataStatisticsEvent<MapDataStatistics, Map<SortKey, Long>>
         checkpoint2Subtask0DataStatisticEvent =
             DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer);
     assertThat(
@@ -160,8 +161,8 @@ public class TestAggregatedStatisticsTracker {
     assertThat(completedStatistics.checkpointId()).isEqualTo(1);
 
     MapDataStatistics checkpoint2Subtask1DataStatistic = new MapDataStatistics();
-    checkpoint2Subtask1DataStatistic.add(binaryRowDataB);
-    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+    checkpoint2Subtask1DataStatistic.add(keyB);
+    DataStatisticsEvent<MapDataStatistics, Map<SortKey, Long>>
         checkpoint2Subtask1DataStatisticEvent =
             DataStatisticsEvent.create(2, checkpoint2Subtask1DataStatistic, statisticsSerializer);
     // Receive data statistics from all subtasks at checkpoint 2
diff --git a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java
index 9ec2606e10..3df714059c 100644
--- a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java
+++ b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinator.java
@@ -24,19 +24,15 @@ import static org.assertj.core.api.Assertions.assertThatThrownBy;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.coordination.EventReceivingTasks;
 import org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext;
-import org.apache.flink.table.data.GenericRowData;
-import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.data.StringData;
-import org.apache.flink.table.data.binary.BinaryRowData;
-import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
-import org.apache.flink.table.types.logical.RowType;
-import org.apache.flink.table.types.logical.VarCharType;
 import org.apache.flink.util.ExceptionUtils;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.types.Types;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -44,20 +40,21 @@ public class TestDataStatisticsCoordinator {
   private static final String OPERATOR_NAME = "TestCoordinator";
   private static final OperatorID TEST_OPERATOR_ID = new OperatorID(1234L, 5678L);
   private static final int NUM_SUBTASKS = 2;
-  private TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>>
-      statisticsSerializer;
+
+  private final Schema schema =
+      new Schema(Types.NestedField.optional(1, "str", Types.StringType.get()));
+  private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("str").build();
+  private final SortKey sortKey = new SortKey(schema, sortOrder);
+  private final MapDataStatisticsSerializer statisticsSerializer =
+      MapDataStatisticsSerializer.fromSortKeySerializer(new SortKeySerializer(schema, sortOrder));
 
   private EventReceivingTasks receivingTasks;
-  private DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>>
+  private DataStatisticsCoordinator<MapDataStatistics, Map<SortKey, Long>>
       dataStatisticsCoordinator;
 
   @Before
   public void before() throws Exception {
     receivingTasks = EventReceivingTasks.createForRunningTasks();
-    statisticsSerializer =
-        MapDataStatisticsSerializer.fromKeySerializer(
-            new RowDataSerializer(RowType.of(new VarCharType())));
-
     dataStatisticsCoordinator =
         new DataStatisticsCoordinator<>(
             OPERATOR_NAME,
@@ -93,59 +90,66 @@ public class TestDataStatisticsCoordinator {
   @Test
   public void testDataStatisticsEventHandling() throws Exception {
     tasksReady();
-    // When coordinator handles events from operator, DataStatisticsUtil#deserializeDataStatistics
-    // deserializes bytes into BinaryRowData
-    RowType rowType = RowType.of(new VarCharType());
-    BinaryRowData binaryRowDataA =
-        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("a")));
-    BinaryRowData binaryRowDataB =
-        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("b")));
-    BinaryRowData binaryRowDataC =
-        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("c")));
+    SortKey key = sortKey.copy();
 
     MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics();
-    checkpoint1Subtask0DataStatistic.add(binaryRowDataA);
-    checkpoint1Subtask0DataStatistic.add(binaryRowDataB);
-    checkpoint1Subtask0DataStatistic.add(binaryRowDataB);
-    checkpoint1Subtask0DataStatistic.add(binaryRowDataC);
-    checkpoint1Subtask0DataStatistic.add(binaryRowDataC);
-    checkpoint1Subtask0DataStatistic.add(binaryRowDataC);
-    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+    key.set(0, "a");
+    checkpoint1Subtask0DataStatistic.add(key);
+    key.set(0, "b");
+    checkpoint1Subtask0DataStatistic.add(key);
+    key.set(0, "b");
+    checkpoint1Subtask0DataStatistic.add(key);
+    key.set(0, "c");
+    checkpoint1Subtask0DataStatistic.add(key);
+    key.set(0, "c");
+    checkpoint1Subtask0DataStatistic.add(key);
+    key.set(0, "c");
+    checkpoint1Subtask0DataStatistic.add(key);
+
+    DataStatisticsEvent<MapDataStatistics, Map<SortKey, Long>>
         checkpoint1Subtask0DataStatisticEvent =
             DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer);
+
     MapDataStatistics checkpoint1Subtask1DataStatistic = new MapDataStatistics();
-    checkpoint1Subtask1DataStatistic.add(binaryRowDataA);
-    checkpoint1Subtask1DataStatistic.add(binaryRowDataB);
-    checkpoint1Subtask1DataStatistic.add(binaryRowDataC);
-    checkpoint1Subtask1DataStatistic.add(binaryRowDataC);
-    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
+    key.set(0, "a");
+    checkpoint1Subtask1DataStatistic.add(key);
+    key.set(0, "b");
+    checkpoint1Subtask1DataStatistic.add(key);
+    key.set(0, "c");
+    checkpoint1Subtask1DataStatistic.add(key);
+    key.set(0, "c");
+    checkpoint1Subtask1DataStatistic.add(key);
+
+    DataStatisticsEvent<MapDataStatistics, Map<SortKey, Long>>
         checkpoint1Subtask1DataStatisticEvent =
             DataStatisticsEvent.create(1, checkpoint1Subtask1DataStatistic, statisticsSerializer);
+
     // Handle events from operators for checkpoint 1
     dataStatisticsCoordinator.handleEventFromOperator(0, 0, checkpoint1Subtask0DataStatisticEvent);
     dataStatisticsCoordinator.handleEventFromOperator(1, 0, checkpoint1Subtask1DataStatisticEvent);
 
     waitForCoordinatorToProcessActions(dataStatisticsCoordinator);
+
     // Verify global data statistics is the aggregation of all subtasks data statistics
+    SortKey keyA = sortKey.copy();
+    keyA.set(0, "a");
+    SortKey keyB = sortKey.copy();
+    keyB.set(0, "b");
+    SortKey keyC = sortKey.copy();
+    keyC.set(0, "c");
     MapDataStatistics globalDataStatistics =
         (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics();
     assertThat(globalDataStatistics.statistics())
         .containsExactlyInAnyOrderEntriesOf(
             ImmutableMap.of(
-                binaryRowDataA,
-                checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataA)
-                    + (long) checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataA),
-                binaryRowDataB,
-                checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataB)
-                    + (long) checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataB),
-                binaryRowDataC,
-                checkpoint1Subtask0DataStatistic.statistics().get(binaryRowDataC)
-                    + (long) checkpoint1Subtask1DataStatistic.statistics().get(binaryRowDataC)));
+                keyA, 2L,
+                keyB, 3L,
+                keyC, 5L));
   }
 
   static void setAllTasksReady(
       int subtasks,
-      DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> dataStatisticsCoordinator,
+      DataStatisticsCoordinator<MapDataStatistics, Map<SortKey, Long>> dataStatisticsCoordinator,
       EventReceivingTasks receivingTasks) {
     for (int i = 0; i < subtasks; i++) {
       dataStatisticsCoordinator.executionAttemptReady(
@@ -154,7 +158,7 @@ public class TestDataStatisticsCoordinator {
   }
 
   static void waitForCoordinatorToProcessActions(
-      DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> coordinator) {
+      DataStatisticsCoordinator<MapDataStatistics, Map<SortKey, Long>> coordinator) {
     CompletableFuture<Void> future = new CompletableFuture<>();
     coordinator.callInCoordinatorThread(
         () -> {
diff --git a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java
index cb9d3f48ff..5e0a752be5 100644
--- a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java
+++ b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsCoordinatorProvider.java
@@ -23,18 +23,14 @@ import static org.assertj.core.api.Assertions.assertThat;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutionException;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.coordination.EventReceivingTasks;
 import org.apache.flink.runtime.operators.coordination.MockOperatorCoordinatorContext;
 import org.apache.flink.runtime.operators.coordination.RecreateOnResetOperatorCoordinator;
-import org.apache.flink.table.data.GenericRowData;
-import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.data.StringData;
-import org.apache.flink.table.data.binary.BinaryRowData;
-import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
-import org.apache.flink.table.types.logical.RowType;
-import org.apache.flink.table.types.logical.VarCharType;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.types.Types;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -42,16 +38,18 @@ public class TestDataStatisticsCoordinatorProvider {
   private static final OperatorID OPERATOR_ID = new OperatorID();
   private static final int NUM_SUBTASKS = 1;
 
-  private DataStatisticsCoordinatorProvider<MapDataStatistics, Map<RowData, Long>> provider;
+  private final Schema schema =
+      new Schema(Types.NestedField.optional(1, "str", Types.StringType.get()));
+  private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("str").build();
+  private final SortKey sortKey = new SortKey(schema, sortOrder);
+  private final MapDataStatisticsSerializer statisticsSerializer =
+      MapDataStatisticsSerializer.fromSortKeySerializer(new SortKeySerializer(schema, sortOrder));
+
+  private DataStatisticsCoordinatorProvider<MapDataStatistics, Map<SortKey, Long>> provider;
   private EventReceivingTasks receivingTasks;
-  private TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>>
-      statisticsSerializer;
 
   @Before
   public void before() {
-    statisticsSerializer =
-        MapDataStatisticsSerializer.fromKeySerializer(
-            new RowDataSerializer(RowType.of(new VarCharType())));
     provider =
         new DataStatisticsCoordinatorProvider<>(
             "DataStatisticsCoordinatorProvider", OPERATOR_ID, 
statisticsSerializer);
@@ -61,84 +59,82 @@ public class TestDataStatisticsCoordinatorProvider {
   @Test
   @SuppressWarnings("unchecked")
   public void testCheckpointAndReset() throws Exception {
-    RowType rowType = RowType.of(new VarCharType());
-    // When coordinator handles events from operator, DataStatisticsUtil#deserializeDataStatistics
-    // deserializes bytes into BinaryRowData
-    BinaryRowData binaryRowDataA =
-        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("a")));
-    BinaryRowData binaryRowDataB =
-        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("b")));
-    BinaryRowData binaryRowDataC =
-        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("c")));
-    BinaryRowData binaryRowDataD =
-        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("d")));
-    BinaryRowData binaryRowDataE =
-        new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("e")));
+    SortKey keyA = sortKey.copy();
+    keyA.set(0, "a");
+    SortKey keyB = sortKey.copy();
+    keyB.set(0, "b");
+    SortKey keyC = sortKey.copy();
+    keyC.set(0, "c");
+    SortKey keyD = sortKey.copy();
+    keyD.set(0, "d");
+    SortKey keyE = sortKey.copy();
+    keyE.set(0, "e");
 
-    RecreateOnResetOperatorCoordinator coordinator =
+    try (RecreateOnResetOperatorCoordinator coordinator =
         (RecreateOnResetOperatorCoordinator)
-            provider.create(new MockOperatorCoordinatorContext(OPERATOR_ID, NUM_SUBTASKS));
-    DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> dataStatisticsCoordinator =
-        (DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>>)
-            coordinator.getInternalCoordinator();
+            provider.create(new MockOperatorCoordinatorContext(OPERATOR_ID, NUM_SUBTASKS))) {
+      DataStatisticsCoordinator<MapDataStatistics, Map<SortKey, Long>> dataStatisticsCoordinator =
+          (DataStatisticsCoordinator<MapDataStatistics, Map<SortKey, Long>>)
+              coordinator.getInternalCoordinator();
 
-    // Start the coordinator
-    coordinator.start();
-    TestDataStatisticsCoordinator.setAllTasksReady(
-        NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks);
-    MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics();
-    checkpoint1Subtask0DataStatistic.add(binaryRowDataA);
-    checkpoint1Subtask0DataStatistic.add(binaryRowDataB);
-    checkpoint1Subtask0DataStatistic.add(binaryRowDataC);
-    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
-        checkpoint1Subtask0DataStatisticEvent =
-            DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer);
+      // Start the coordinator
+      coordinator.start();
+      TestDataStatisticsCoordinator.setAllTasksReady(
+          NUM_SUBTASKS, dataStatisticsCoordinator, receivingTasks);
+      MapDataStatistics checkpoint1Subtask0DataStatistic = new MapDataStatistics();
+      checkpoint1Subtask0DataStatistic.add(keyA);
+      checkpoint1Subtask0DataStatistic.add(keyB);
+      checkpoint1Subtask0DataStatistic.add(keyC);
+      DataStatisticsEvent<MapDataStatistics, Map<SortKey, Long>>
+          checkpoint1Subtask0DataStatisticEvent =
+              DataStatisticsEvent.create(1, checkpoint1Subtask0DataStatistic, statisticsSerializer);
 
-    // Handle events from operators for checkpoint 1
-    coordinator.handleEventFromOperator(0, 0, checkpoint1Subtask0DataStatisticEvent);
-    TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator);
-    // Verify checkpoint 1 global data statistics
-    MapDataStatistics checkpoint1GlobalDataStatistics =
-        (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics();
-    assertThat(checkpoint1GlobalDataStatistics.statistics())
-        .isEqualTo(checkpoint1Subtask0DataStatistic.statistics());
-    byte[] checkpoint1Bytes = waitForCheckpoint(1L, dataStatisticsCoordinator);
+      // Handle events from operators for checkpoint 1
+      coordinator.handleEventFromOperator(0, 0, checkpoint1Subtask0DataStatisticEvent);
+      TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator);
+      // Verify checkpoint 1 global data statistics
+      MapDataStatistics checkpoint1GlobalDataStatistics =
+          (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics();
+      assertThat(checkpoint1GlobalDataStatistics.statistics())
+          .isEqualTo(checkpoint1Subtask0DataStatistic.statistics());
+      byte[] checkpoint1Bytes = waitForCheckpoint(1L, dataStatisticsCoordinator);
 
-    MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics();
-    checkpoint2Subtask0DataStatistic.add(binaryRowDataD);
-    checkpoint2Subtask0DataStatistic.add(binaryRowDataE);
-    checkpoint2Subtask0DataStatistic.add(binaryRowDataE);
-    DataStatisticsEvent<MapDataStatistics, Map<RowData, Long>>
-        checkpoint2Subtask0DataStatisticEvent =
-            DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer);
-    // Handle events from operators for checkpoint 2
-    coordinator.handleEventFromOperator(0, 0, checkpoint2Subtask0DataStatisticEvent);
-    TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator);
-    // Verify checkpoint 2 global data statistics
-    MapDataStatistics checkpoint2GlobalDataStatistics =
-        (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics();
-    assertThat(checkpoint2GlobalDataStatistics.statistics())
-        .isEqualTo(checkpoint2Subtask0DataStatistic.statistics());
-    waitForCheckpoint(2L, dataStatisticsCoordinator);
+      MapDataStatistics checkpoint2Subtask0DataStatistic = new MapDataStatistics();
+      checkpoint2Subtask0DataStatistic.add(keyD);
+      checkpoint2Subtask0DataStatistic.add(keyE);
+      checkpoint2Subtask0DataStatistic.add(keyE);
+      DataStatisticsEvent<MapDataStatistics, Map<SortKey, Long>>
+          checkpoint2Subtask0DataStatisticEvent =
+              DataStatisticsEvent.create(2, checkpoint2Subtask0DataStatistic, statisticsSerializer);
+      // Handle events from operators for checkpoint 2
+      coordinator.handleEventFromOperator(0, 0, checkpoint2Subtask0DataStatisticEvent);
+      TestDataStatisticsCoordinator.waitForCoordinatorToProcessActions(dataStatisticsCoordinator);
+      // Verify checkpoint 2 global data statistics
+      MapDataStatistics checkpoint2GlobalDataStatistics =
+          (MapDataStatistics) dataStatisticsCoordinator.completedStatistics().dataStatistics();
+      assertThat(checkpoint2GlobalDataStatistics.statistics())
+          .isEqualTo(checkpoint2Subtask0DataStatistic.statistics());
+      waitForCheckpoint(2L, dataStatisticsCoordinator);
 
-    // Reset coordinator to checkpoint 1
-    coordinator.resetToCheckpoint(1L, checkpoint1Bytes);
-    DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>>
-        restoredDataStatisticsCoordinator =
-            (DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>>)
-                coordinator.getInternalCoordinator();
-    assertThat(dataStatisticsCoordinator).isNotEqualTo(restoredDataStatisticsCoordinator);
-    // Verify restored data statistics
-    MapDataStatistics restoredAggregateDataStatistics =
-        (MapDataStatistics)
-            restoredDataStatisticsCoordinator.completedStatistics().dataStatistics();
-    assertThat(restoredAggregateDataStatistics.statistics())
-        .isEqualTo(checkpoint1GlobalDataStatistics.statistics());
+      // Reset coordinator to checkpoint 1
+      coordinator.resetToCheckpoint(1L, checkpoint1Bytes);
+      DataStatisticsCoordinator<MapDataStatistics, Map<SortKey, Long>>
+          restoredDataStatisticsCoordinator =
+              (DataStatisticsCoordinator<MapDataStatistics, Map<SortKey, Long>>)
+                  coordinator.getInternalCoordinator();
+      assertThat(dataStatisticsCoordinator).isNotEqualTo(restoredDataStatisticsCoordinator);
+      // Verify restored data statistics
+      MapDataStatistics restoredAggregateDataStatistics =
+          (MapDataStatistics)
+              restoredDataStatisticsCoordinator.completedStatistics().dataStatistics();
+      assertThat(restoredAggregateDataStatistics.statistics())
+          .isEqualTo(checkpoint1GlobalDataStatistics.statistics());
+    }
   }
 
   private byte[] waitForCheckpoint(
       long checkpointId,
-      DataStatisticsCoordinator<MapDataStatistics, Map<RowData, Long>> coordinator)
+      DataStatisticsCoordinator<MapDataStatistics, Map<SortKey, Long>> coordinator)
       throws InterruptedException, ExecutionException {
     CompletableFuture<byte[]> future = new CompletableFuture<>();
     coordinator.checkpointCoordinator(checkpointId, future);
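
Note on the idiom above: SortKey#set mutates the key in place, so the test copies a
prototype before each mutation to get independent map keys. A minimal sketch, reusing
only the schema and sort order declared at the top of this test:

    Schema schema = new Schema(Types.NestedField.optional(1, "str", Types.StringType.get()));
    SortOrder sortOrder = SortOrder.builderFor(schema).asc("str").build();
    SortKey prototype = new SortKey(schema, sortOrder);

    // copy() returns an independent key; set() writes into the receiver in place
    SortKey keyA = prototype.copy();
    keyA.set(0, "a");
    SortKey keyB = prototype.copy();
    keyB.set(0, "b");
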
diff --git a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java
index 880cb3d551..0e99a2d74c 100644
--- a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java
+++ b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestDataStatisticsOperator.java
@@ -27,7 +27,6 @@ import java.util.stream.Collectors;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.state.OperatorStateStore;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.fs.CloseableRegistry;
 import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
@@ -50,33 +49,37 @@ import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
 import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.StringData;
-import org.apache.flink.table.data.binary.BinaryRowData;
 import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
+import org.apache.flink.table.types.logical.IntType;
 import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.table.types.logical.VarCharType;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
 import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
 import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.relocated.com.google.common.collect.Maps;
+import org.apache.iceberg.types.Types;
 import org.junit.After;
 import org.junit.Before;
 import org.junit.Test;
 
 public class TestDataStatisticsOperator {
-  private final RowType rowType = RowType.of(new VarCharType());
+  private final Schema schema =
+      new Schema(
+          Types.NestedField.optional(1, "id", Types.StringType.get()),
+          Types.NestedField.optional(2, "number", Types.IntegerType.get()));
+  private final SortOrder sortOrder = SortOrder.builderFor(schema).asc("id").build();
+  private final SortKey sortKey = new SortKey(schema, sortOrder);
+  private final RowType rowType = RowType.of(new VarCharType(), new IntType());
   private final TypeSerializer<RowData> rowSerializer = new RowDataSerializer(rowType);
-  private final GenericRowData genericRowDataA = GenericRowData.of(StringData.fromString("a"));
-  private final GenericRowData genericRowDataB = GenericRowData.of(StringData.fromString("b"));
-  // When operator hands events from coordinator, DataStatisticsUtil#deserializeDataStatistics
-  // deserializes bytes into BinaryRowData
-  private final BinaryRowData binaryRowDataA =
-      new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("a")));
-  private final BinaryRowData binaryRowDataB =
-      new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("b")));
-  private final BinaryRowData binaryRowDataC =
-      new RowDataSerializer(rowType).toBinaryRow(GenericRowData.of(StringData.fromString("c")));
-  private final TypeSerializer<DataStatistics<MapDataStatistics, Map<RowData, Long>>>
-      statisticsSerializer = MapDataStatisticsSerializer.fromKeySerializer(rowSerializer);
-  private DataStatisticsOperator<MapDataStatistics, Map<RowData, Long>> operator;
+  private final TypeSerializer<DataStatistics<MapDataStatistics, Map<SortKey, Long>>>
+      statisticsSerializer =
+          MapDataStatisticsSerializer.fromSortKeySerializer(
+              new SortKeySerializer(schema, sortOrder));
+
+  private DataStatisticsOperator<MapDataStatistics, Map<SortKey, Long>> operator;
 
   private Environment getTestingEnvironment() {
     return new StreamMockEnvironment(
@@ -99,20 +102,10 @@ public class TestDataStatisticsOperator {
         new MockOutput<>(Lists.newArrayList()));
   }
 
-  private DataStatisticsOperator<MapDataStatistics, Map<RowData, Long>> createOperator() {
+  private DataStatisticsOperator<MapDataStatistics, Map<SortKey, Long>> createOperator() {
     MockOperatorEventGateway mockGateway = new MockOperatorEventGateway();
-    KeySelector<RowData, RowData> keySelector =
-        new KeySelector<RowData, RowData>() {
-          private static final long serialVersionUID = 7662520075515707428L;
-
-          @Override
-          public RowData getKey(RowData value) {
-            return value;
-          }
-        };
-
     return new DataStatisticsOperator<>(
-        "testOperator", keySelector, mockGateway, statisticsSerializer);
+        "testOperator", schema, sortOrder, mockGateway, statisticsSerializer);
   }
 
   @After
@@ -123,20 +116,26 @@ public class TestDataStatisticsOperator {
   @Test
   public void testProcessElement() throws Exception {
     try (OneInputStreamOperatorTestHarness<
-            RowData, DataStatisticsOrRecord<MapDataStatistics, Map<RowData, Long>>>
+            RowData, DataStatisticsOrRecord<MapDataStatistics, Map<SortKey, Long>>>
         testHarness = createHarness(this.operator)) {
       StateInitializationContext stateContext = getStateContext();
       operator.initializeState(stateContext);
-      operator.processElement(new StreamRecord<>(genericRowDataA));
-      operator.processElement(new StreamRecord<>(genericRowDataA));
-      operator.processElement(new StreamRecord<>(genericRowDataB));
+      operator.processElement(new StreamRecord<>(GenericRowData.of(StringData.fromString("a"), 5)));
+      operator.processElement(new StreamRecord<>(GenericRowData.of(StringData.fromString("a"), 3)));
+      operator.processElement(new StreamRecord<>(GenericRowData.of(StringData.fromString("b"), 1)));
       assertThat(operator.localDataStatistics()).isInstanceOf(MapDataStatistics.class);
+
+      SortKey keyA = sortKey.copy();
+      keyA.set(0, "a");
+      SortKey keyB = sortKey.copy();
+      keyB.set(0, "b");
+      Map<SortKey, Long> expectedMap = ImmutableMap.of(keyA, 2L, keyB, 1L);
+
+      MapDataStatistics mapDataStatistics = (MapDataStatistics) operator.localDataStatistics();
-      Map<RowData, Long> statsMap = mapDataStatistics.statistics();
+      Map<SortKey, Long> statsMap = mapDataStatistics.statistics();
       assertThat(statsMap).hasSize(2);
-      assertThat(statsMap)
-          .containsExactlyInAnyOrderEntriesOf(
-              ImmutableMap.of(genericRowDataA, 2L, genericRowDataB, 1L));
+      assertThat(statsMap).containsExactlyInAnyOrderEntriesOf(expectedMap);
+
       testHarness.endInput();
     }
   }
@@ -144,11 +143,14 @@ public class TestDataStatisticsOperator {
   @Test
   public void testOperatorOutput() throws Exception {
     try (OneInputStreamOperatorTestHarness<
-            RowData, DataStatisticsOrRecord<MapDataStatistics, Map<RowData, Long>>>
+            RowData, DataStatisticsOrRecord<MapDataStatistics, Map<SortKey, Long>>>
         testHarness = createHarness(this.operator)) {
-      testHarness.processElement(new StreamRecord<>(genericRowDataA));
-      testHarness.processElement(new StreamRecord<>(genericRowDataB));
-      testHarness.processElement(new StreamRecord<>(genericRowDataB));
+      testHarness.processElement(
+          new StreamRecord<>(GenericRowData.of(StringData.fromString("a"), 2)));
+      testHarness.processElement(
+          new StreamRecord<>(GenericRowData.of(StringData.fromString("b"), 3)));
+      testHarness.processElement(
+          new StreamRecord<>(GenericRowData.of(StringData.fromString("b"), 1)));
 
       List<RowData> recordsOutput =
           testHarness.extractOutputValues().stream()
@@ -157,7 +159,10 @@ public class TestDataStatisticsOperator {
               .collect(Collectors.toList());
       assertThat(recordsOutput)
           .containsExactlyInAnyOrderElementsOf(
-              ImmutableList.of(genericRowDataA, genericRowDataB, genericRowDataB));
+              ImmutableList.of(
+                  GenericRowData.of(StringData.fromString("a"), 2),
+                  GenericRowData.of(StringData.fromString("b"), 3),
+                  GenericRowData.of(StringData.fromString("b"), 1)));
     }
   }
 
@@ -165,36 +170,61 @@ public class TestDataStatisticsOperator {
   public void testRestoreState() throws Exception {
     OperatorSubtaskState snapshot;
     try (OneInputStreamOperatorTestHarness<
-            RowData, DataStatisticsOrRecord<MapDataStatistics, Map<RowData, Long>>>
+            RowData, DataStatisticsOrRecord<MapDataStatistics, Map<SortKey, Long>>>
         testHarness1 = createHarness(this.operator)) {
-      DataStatistics<MapDataStatistics, Map<RowData, Long>> mapDataStatistics =
-          new MapDataStatistics();
-      mapDataStatistics.add(binaryRowDataA);
-      mapDataStatistics.add(binaryRowDataA);
-      mapDataStatistics.add(binaryRowDataB);
-      mapDataStatistics.add(binaryRowDataC);
-      operator.handleOperatorEvent(
-          DataStatisticsEvent.create(0, mapDataStatistics, statisticsSerializer));
+      MapDataStatistics mapDataStatistics = new MapDataStatistics();
+
+      SortKey key = sortKey.copy();
+      key.set(0, "a");
+      mapDataStatistics.add(key);
+      key.set(0, "a");
+      mapDataStatistics.add(key);
+      key.set(0, "b");
+      mapDataStatistics.add(key);
+      key.set(0, "c");
+      mapDataStatistics.add(key);
+
+      SortKey keyA = sortKey.copy();
+      keyA.set(0, "a");
+      SortKey keyB = sortKey.copy();
+      keyB.set(0, "b");
+      SortKey keyC = sortKey.copy();
+      keyC.set(0, "c");
+      Map<SortKey, Long> expectedMap = ImmutableMap.of(keyA, 2L, keyB, 1L, keyC, 1L);
+
+      DataStatisticsEvent<MapDataStatistics, Map<SortKey, Long>> event =
+          DataStatisticsEvent.create(0, mapDataStatistics, statisticsSerializer);
+      operator.handleOperatorEvent(event);
       assertThat(operator.globalDataStatistics()).isInstanceOf(MapDataStatistics.class);
       assertThat(operator.globalDataStatistics().statistics())
-          .containsExactlyInAnyOrderEntriesOf(
-              ImmutableMap.of(binaryRowDataA, 2L, binaryRowDataB, 1L, binaryRowDataC, 1L));
+          .containsExactlyInAnyOrderEntriesOf(expectedMap);
       snapshot = testHarness1.snapshot(1L, 0);
     }
 
     // Use the snapshot to initialize state for another new operator and then verify that the global
     // statistics for the new operator is same as before
-    DataStatisticsOperator<MapDataStatistics, Map<RowData, Long>> restoredOperator =
+    DataStatisticsOperator<MapDataStatistics, Map<SortKey, Long>> restoredOperator =
         createOperator();
     try (OneInputStreamOperatorTestHarness<
-            RowData, DataStatisticsOrRecord<MapDataStatistics, Map<RowData, Long>>>
+            RowData, DataStatisticsOrRecord<MapDataStatistics, Map<SortKey, Long>>>
         testHarness2 = new OneInputStreamOperatorTestHarness<>(restoredOperator, 2, 2, 1)) {
       testHarness2.setup();
       testHarness2.initializeState(snapshot);
       assertThat(restoredOperator.globalDataStatistics()).isInstanceOf(MapDataStatistics.class);
-      assertThat(restoredOperator.globalDataStatistics().statistics())
-          .containsExactlyInAnyOrderEntriesOf(
-              ImmutableMap.of(binaryRowDataA, 2L, binaryRowDataB, 1L, binaryRowDataC, 1L));
+
+      // copy the restored statistics into a regular map for comparison
+      Map<SortKey, Long> restoredStatistics = Maps.newHashMap();
+      restoredStatistics.putAll(restoredOperator.globalDataStatistics().statistics());
+
+      SortKey keyA = sortKey.copy();
+      keyA.set(0, "a");
+      SortKey keyB = sortKey.copy();
+      keyB.set(0, "b");
+      SortKey keyC = sortKey.copy();
+      keyC.set(0, "c");
+      Map<SortKey, Long> expectedMap = ImmutableMap.of(keyA, 2L, keyB, 1L, keyC, 1L);
+
+      assertThat(restoredStatistics).containsExactlyInAnyOrderEntriesOf(expectedMap);
     }
   }
 
@@ -209,18 +239,16 @@ public class TestDataStatisticsOperator {
   }
 
   private OneInputStreamOperatorTestHarness<
-          RowData, DataStatisticsOrRecord<MapDataStatistics, Map<RowData, Long>>>
+          RowData, DataStatisticsOrRecord<MapDataStatistics, Map<SortKey, Long>>>
       createHarness(
-          final DataStatisticsOperator<MapDataStatistics, Map<RowData, Long>>
+          final DataStatisticsOperator<MapDataStatistics, Map<SortKey, Long>>
               dataStatisticsOperator)
           throws Exception {
 
     OneInputStreamOperatorTestHarness<
-            RowData, DataStatisticsOrRecord<MapDataStatistics, Map<RowData, Long>>>
+            RowData, DataStatisticsOrRecord<MapDataStatistics, Map<SortKey, Long>>>
         harness = new OneInputStreamOperatorTestHarness<>(dataStatisticsOperator, 1, 1, 0);
-    harness.setup(
-        new DataStatisticsOrRecordSerializer<>(
-            MapDataStatisticsSerializer.fromKeySerializer(rowSerializer), rowSerializer));
+    harness.setup(new DataStatisticsOrRecordSerializer<>(statisticsSerializer, rowSerializer));
     harness.open();
     return harness;
   }
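
The assertions in this test treat MapDataStatistics as a plain key-to-count map. A
hedged sketch of that contract follows; the class name and method bodies here are an
illustration, not the actual implementation. add() counts one occurrence of a sort key,
defensively copying it because callers reuse and rewrap the same SortKey instance, and
merging two statistics sums the counts per key (imports match those in the file above:
java.util.Map, relocated Maps, org.apache.iceberg.SortKey).

    class MapKeyCountSketch {
      private final Map<SortKey, Long> statistics = Maps.newHashMap();

      void add(SortKey sortKey) {
        // copy before storing: the caller may mutate the same instance afterwards
        statistics.merge(sortKey.copy(), 1L, Long::sum);
      }

      void merge(Map<SortKey, Long> otherStatistics) {
        // per-key sum of the two count maps
        otherStatistics.forEach((key, count) -> statistics.merge(key, count, Long::sum));
      }
    }
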
diff --git a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapDataStatistics.java b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapDataStatistics.java
new file mode 100644
index 0000000000..a07808e935
--- /dev/null
+++ b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestMapDataStatistics.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import java.util.Map;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.flink.table.data.StringData;
+import org.apache.flink.table.types.logical.RowType;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.flink.FlinkSchemaUtil;
+import org.apache.iceberg.flink.RowDataWrapper;
+import org.apache.iceberg.flink.TestFixtures;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.assertj.core.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+public class TestMapDataStatistics {
+  private final SortOrder sortOrder = SortOrder.builderFor(TestFixtures.SCHEMA).asc("data").build();
+  private final SortKey sortKey = new SortKey(TestFixtures.SCHEMA, sortOrder);
+  private final RowType rowType = FlinkSchemaUtil.convert(TestFixtures.SCHEMA);
+  private final RowDataWrapper rowWrapper =
+      new RowDataWrapper(rowType, TestFixtures.SCHEMA.asStruct());
+
+  @Test
+  public void testAddsAndGet() {
+    MapDataStatistics dataStatistics = new MapDataStatistics();
+
+    GenericRowData reusedRow =
+        GenericRowData.of(StringData.fromString("a"), 1, 
StringData.fromString("2023-06-20"));
+    sortKey.wrap(rowWrapper.wrap(reusedRow));
+    dataStatistics.add(sortKey);
+
+    reusedRow.setField(0, StringData.fromString("b"));
+    sortKey.wrap(rowWrapper.wrap(reusedRow));
+    dataStatistics.add(sortKey);
+
+    reusedRow.setField(0, StringData.fromString("c"));
+    sortKey.wrap(rowWrapper.wrap(reusedRow));
+    dataStatistics.add(sortKey);
+
+    reusedRow.setField(0, StringData.fromString("b"));
+    sortKey.wrap(rowWrapper.wrap(reusedRow));
+    dataStatistics.add(sortKey);
+
+    reusedRow.setField(0, StringData.fromString("a"));
+    sortKey.wrap(rowWrapper.wrap(reusedRow));
+    dataStatistics.add(sortKey);
+
+    reusedRow.setField(0, StringData.fromString("b"));
+    sortKey.wrap(rowWrapper.wrap(reusedRow));
+    dataStatistics.add(sortKey);
+
+    Map<SortKey, Long> actual = dataStatistics.statistics();
+
+    rowWrapper.wrap(
+        GenericRowData.of(StringData.fromString("a"), 1, StringData.fromString("2023-06-20")));
+    sortKey.wrap(rowWrapper);
+    SortKey keyA = sortKey.copy();
+
+    rowWrapper.wrap(
+        GenericRowData.of(StringData.fromString("b"), 1, StringData.fromString("2023-06-20")));
+    sortKey.wrap(rowWrapper);
+    SortKey keyB = sortKey.copy();
+
+    rowWrapper.wrap(
+        GenericRowData.of(StringData.fromString("c"), 1, StringData.fromString("2023-06-20")));
+    sortKey.wrap(rowWrapper);
+    SortKey keyC = sortKey.copy();
+
+    Map<SortKey, Long> expected = ImmutableMap.of(keyA, 2L, keyB, 3L, keyC, 1L);
+    Assertions.assertThat(actual).isEqualTo(expected);
+  }
+}
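
Note the two-step wrapping in this new test: RowDataWrapper adapts a Flink RowData to
Iceberg's StructLike view, and SortKey#wrap then extracts the sort columns from it;
copy() is what detaches the value from the reused wrapper before the key is stored in a
map. A condensed sketch of the idiom, using the test's own schema and sort order (the
row contents are illustrative):

    RowDataWrapper wrapper = new RowDataWrapper(
        FlinkSchemaUtil.convert(TestFixtures.SCHEMA), TestFixtures.SCHEMA.asStruct());
    SortKey reused = new SortKey(TestFixtures.SCHEMA, sortOrder);

    GenericRowData row =
        GenericRowData.of(StringData.fromString("a"), 1, StringData.fromString("2023-06-20"));
    reused.wrap(wrapper.wrap(row)); // views the current row; no copy yet
    SortKey frozen = reused.copy(); // safe to keep, e.g. as a map key
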
diff --git a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerBase.java b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerBase.java
new file mode 100644
index 0000000000..c7fea01514
--- /dev/null
+++ b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerBase.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import org.apache.flink.api.common.typeutils.SerializerTestBase;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortKey;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.flink.FlinkSchemaUtil;
+import org.apache.iceberg.flink.RowDataWrapper;
+
+public abstract class TestSortKeySerializerBase extends SerializerTestBase<SortKey> {
+
+  protected abstract Schema schema();
+
+  protected abstract SortOrder sortOrder();
+
+  protected abstract GenericRowData rowData();
+
+  @Override
+  protected TypeSerializer<SortKey> createSerializer() {
+    return new SortKeySerializer(schema(), sortOrder());
+  }
+
+  @Override
+  protected int getLength() {
+    return -1;
+  }
+
+  @Override
+  protected Class<SortKey> getTypeClass() {
+    return SortKey.class;
+  }
+
+  @Override
+  protected SortKey[] getTestData() {
+    return new SortKey[] {sortKey()};
+  }
+
+  private SortKey sortKey() {
+    RowDataWrapper rowDataWrapper =
+        new RowDataWrapper(FlinkSchemaUtil.convert(schema()), schema().asStruct());
+    SortKey sortKey = new SortKey(schema(), sortOrder());
+    sortKey.wrap(rowDataWrapper.wrap(rowData()));
+    return sortKey;
+  }
+}
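
Flink's SerializerTestBase drives round-trip, copy, and config-snapshot checks from
getTestData(); returning -1 from getLength() declares the serialized form
variable-length. For intuition, a manual round trip through SortKeySerializer might
look like the sketch below (DataOutputSerializer and DataInputDeserializer are Flink's
in-memory DataOutputView/DataInputView implementations; key, schema, and sortOrder are
assumed to come from the abstract hooks above):

    TypeSerializer<SortKey> serializer = new SortKeySerializer(schema, sortOrder);
    DataOutputSerializer out = new DataOutputSerializer(128); // initial buffer size
    serializer.serialize(key, out);
    DataInputDeserializer in = new DataInputDeserializer(out.getCopyOfBuffer());
    SortKey roundTripped = serializer.deserialize(in);
    // roundTripped should equal key for a correct serializer
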
diff --git a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerNestedStruct.java b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerNestedStruct.java
new file mode 100644
index 0000000000..0000688a8b
--- /dev/null
+++ b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerNestedStruct.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.iceberg.NullOrder;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortDirection;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.expressions.Expressions;
+import org.apache.iceberg.flink.DataGenerator;
+import org.apache.iceberg.flink.DataGenerators;
+
+public class TestSortKeySerializerNestedStruct extends TestSortKeySerializerBase {
+  private final DataGenerator generator = new DataGenerators.StructOfStruct();
+
+  @Override
+  protected Schema schema() {
+    return generator.icebergSchema();
+  }
+
+  @Override
+  protected SortOrder sortOrder() {
+    return SortOrder.builderFor(schema())
+        .asc("row_id")
+        .sortBy(
+            Expressions.bucket("struct_of_struct.id", 4), SortDirection.DESC, NullOrder.NULLS_LAST)
+        .sortBy(
+            Expressions.truncate("struct_of_struct.person_struct.name", 16),
+            SortDirection.ASC,
+            NullOrder.NULLS_FIRST)
+        .build();
+  }
+
+  @Override
+  protected GenericRowData rowData() {
+    return generator.generateFlinkRowData();
+  }
+}
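
This subclass exercises transform-based sort fields (bucket and truncate) on columns
inside nested structs. The DataGenerators fixtures pair an Iceberg schema with a
matching Flink row, which is what lets the base class wrap rowData() into a SortKey. A
hedged sketch of that pairing, using only the generator methods referenced above:

    DataGenerator generator = new DataGenerators.StructOfStruct();
    Schema schema = generator.icebergSchema();
    GenericRowData row = generator.generateFlinkRowData(); // row matches the schema
    RowDataWrapper wrapper =
        new RowDataWrapper(FlinkSchemaUtil.convert(schema), schema.asStruct());
    SortKey key = new SortKey(schema, SortOrder.builderFor(schema).asc("row_id").build());
    key.wrap(wrapper.wrap(row)); // sort-field values are extracted during wrap
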
diff --git a/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerPrimitives.java b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerPrimitives.java
new file mode 100644
index 0000000000..291302aef4
--- /dev/null
+++ b/flink/v1.16/flink/src/test/java/org/apache/iceberg/flink/sink/shuffle/TestSortKeySerializerPrimitives.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.iceberg.flink.sink.shuffle;
+
+import org.apache.flink.table.data.GenericRowData;
+import org.apache.iceberg.NullOrder;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.SortDirection;
+import org.apache.iceberg.SortOrder;
+import org.apache.iceberg.expressions.Expressions;
+import org.apache.iceberg.flink.DataGenerator;
+import org.apache.iceberg.flink.DataGenerators;
+
+public class TestSortKeySerializerPrimitives extends TestSortKeySerializerBase {
+  private final DataGenerator generator = new DataGenerators.Primitives();
+
+  @Override
+  protected Schema schema() {
+    return generator.icebergSchema();
+  }
+
+  @Override
+  protected SortOrder sortOrder() {
+    return SortOrder.builderFor(schema())
+        .asc("boolean_field")
+        .sortBy(Expressions.bucket("int_field", 4), SortDirection.DESC, 
NullOrder.NULLS_LAST)
+        .sortBy(Expressions.truncate("string_field", 2), SortDirection.ASC, 
NullOrder.NULLS_FIRST)
+        .sortBy(Expressions.bucket("uuid_field", 16), SortDirection.ASC, 
NullOrder.NULLS_FIRST)
+        .sortBy(Expressions.hour("ts_with_zone_field"), SortDirection.ASC, 
NullOrder.NULLS_FIRST)
+        .sortBy(Expressions.day("ts_without_zone_field"), SortDirection.ASC, 
NullOrder.NULLS_FIRST)
+        // can not test HeapByteBuffer due to equality test inside 
SerializerTestBase
+        // .sortBy(Expressions.truncate("binary_field", 2), SortDirection.ASC,
+        // NullOrder.NULLS_FIRST)
+        .build();
+  }
+
+  @Override
+  protected GenericRowData rowData() {
+    return generator.generateFlinkRowData();
+  }
+}

