This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new d7c9c8b5 [FLINK-31753] Support DataStream CoGroup in stream mode with similar performance as DataSet CoGroup
d7c9c8b5 is described below

commit d7c9c8b5242a3c161d430a03fc4e4c3b0d1d78ff
Author: Dong Lin <[email protected]>
AuthorDate: Wed Apr 12 16:43:12 2023 +0800

    [FLINK-31753] Support DataStream CoGroup in stream mode with similar performance as DataSet CoGroup
    
    This closes #230.
---
 .../ml/common/datastream/DataStreamUtils.java      |  65 ++++-
 .../datastream/sort/BytesKeyNormalizationUtil.java |  84 ++++++
 .../ml/common/datastream/sort/CoGroupOperator.java | 314 +++++++++++++++++++++
 .../sort/FixedLengthByteKeyComparator.java         | 188 ++++++++++++
 .../datastream/sort/KeyAndValueSerializer.java     | 189 +++++++++++++
 .../sort/VariableLengthByteKeyComparator.java      | 193 +++++++++++++
 .../ml/common/datastream/DataStreamUtilsTest.java  |  94 ++++++
 .../apache/flink/ml/clustering/kmeans/KMeans.java  |   2 +-
 8 files changed, 1123 insertions(+), 6 deletions(-)

diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
index eb4ec6ca..e4cbcd52 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -20,6 +20,7 @@ package org.apache.flink.ml.common.datastream;
 
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.CoGroupFunction;
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.MapPartitionFunction;
@@ -32,12 +33,13 @@ import org.apache.flink.api.common.time.Time;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
-import org.apache.flink.api.dag.Transformation;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.core.memory.ManagedMemoryUseCase;
 import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache;
 import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.sort.CoGroupOperator;
 import org.apache.flink.ml.common.window.CountTumblingWindows;
 import org.apache.flink.ml.common.window.EventTimeSessionWindows;
 import org.apache.flink.ml.common.window.EventTimeTumblingWindows;
@@ -75,6 +77,7 @@ import org.apache.flink.util.Collector;
 
 import org.apache.commons.collections.IteratorUtils;
 
+import java.io.Serializable;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -133,6 +136,7 @@ public class DataStreamUtils {
             DataStream<IN> input,
             MapPartitionFunction<IN, OUT> func,
             TypeInformation<OUT> outType) {
+        func = input.getExecutionEnvironment().clean(func);
         return input.transform("mapPartition", outType, new 
MapPartitionOperator<>(func))
                 .setParallelism(input.getParallelism());
     }
@@ -162,6 +166,7 @@ public class DataStreamUtils {
      */
     public static <T> DataStream<T> reduce(
             DataStream<T> input, ReduceFunction<T> func, TypeInformation<T> 
outType) {
+        func = input.getExecutionEnvironment().clean(func);
         DataStream<T> partialReducedStream =
                 input.transform("reduce", outType, new ReduceOperator<>(func))
                         .setParallelism(input.getParallelism());
@@ -201,6 +206,7 @@ public class DataStreamUtils {
      */
     public static <T, K> DataStream<T> reduce(
             KeyedStream<T, K> input, ReduceFunction<T> func, 
TypeInformation<T> outType) {
+        func = input.getExecutionEnvironment().clean(func);
         return input.transform(
                         "Keyed Reduce",
                         outType,
@@ -263,6 +269,7 @@ public class DataStreamUtils {
             AggregateFunction<IN, ACC, OUT> func,
             TypeInformation<ACC> accType,
             TypeInformation<OUT> outType) {
+        func = input.getExecutionEnvironment().clean(func);
         DataStream<ACC> partialAggregatedStream =
                 input.transform(
                         "partialAggregate", accType, new 
PartialAggregateOperator<>(func, accType));
@@ -319,13 +326,14 @@ public class DataStreamUtils {
      * bytes should be in the same scale as existing usage in Flink, for 
example,
      * StreamExecWindowAggregate.WINDOW_AGG_MEMORY_RATIO.
      */
-    public static <T> void setManagedMemoryWeight(
-            Transformation<T> transformation, long memoryBytes) {
+    public static <T> void setManagedMemoryWeight(DataStream<T> dataStream, 
long memoryBytes) {
         if (memoryBytes > 0) {
             final int weightInMebibyte = Math.max(1, (int) (memoryBytes >> 
20));
             final Optional<Integer> previousWeight =
-                    transformation.declareManagedMemoryUseCaseAtOperatorScope(
-                            ManagedMemoryUseCase.OPERATOR, weightInMebibyte);
+                    dataStream
+                            .getTransformation()
+                            .declareManagedMemoryUseCaseAtOperatorScope(
+                                    ManagedMemoryUseCase.OPERATOR, 
weightInMebibyte);
             if (previousWeight.isPresent()) {
                 throw new TableException(
                         "Managed memory weight has been set, this should not 
happen.");
@@ -345,6 +353,7 @@ public class DataStreamUtils {
      */
     public static <IN, OUT, W extends Window> SingleOutputStreamOperator<OUT> 
windowAllAndProcess(
             DataStream<IN> input, Windows windows, 
ProcessAllWindowFunction<IN, OUT, W> function) {
+        function = input.getExecutionEnvironment().clean(function);
         AllWindowedStream<IN, W> allWindowedStream = 
getAllWindowedStream(input, windows);
         return allWindowedStream.process(function);
     }
@@ -365,10 +374,56 @@ public class DataStreamUtils {
             Windows windows,
             ProcessAllWindowFunction<IN, OUT, W> function,
             TypeInformation<OUT> outType) {
+        function = input.getExecutionEnvironment().clean(function);
         AllWindowedStream<IN, W> allWindowedStream = 
getAllWindowedStream(input, windows);
         return allWindowedStream.process(function, outType);
     }
 
+    /**
+     * A CoGroup transformation combines the elements of two {@link DataStream 
DataStreams} into one
+     * DataStream. It groups each DataStream individually on a key and gives 
groups of both
+     * DataStreams with equal keys together into a {@link
+     * org.apache.flink.api.common.functions.CoGroupFunction}. If a DataStream 
has a group with no
+     * matching key in the other DataStream, the CoGroupFunction is called 
with an empty group for
+     * the non-existing group.
+     *
+     * <p>The CoGroupFunction can iterate over the elements of both groups and 
return any number of
+     * elements including none.
+     *
+     * <p>NOTE: This method assumes both inputs are bounded.
+     *
+     * @param input1 The first data stream.
+     * @param input2 The second data stream.
+     * @param keySelector1 The KeySelector to be used for extracting the first 
input's key for
+     *     partitioning.
+     * @param keySelector2 The KeySelector to be used for extracting the 
second input's key for
+     *     partitioning.
+     * @param outTypeInformation The type information describing the output 
type.
+     * @param func The user-defined co-group function.
+     * @param <IN1> The class type of the first input.
+     * @param <IN2> The class type of the second input.
+     * @param <KEY> The class type of the key.
+     * @param <OUT> The class type of the output values.
+     * @return The result data stream.
+     */
+    public static <IN1, IN2, KEY extends Serializable, OUT> DataStream<OUT> 
coGroup(
+            DataStream<IN1> input1,
+            DataStream<IN2> input2,
+            KeySelector<IN1, KEY> keySelector1,
+            KeySelector<IN2, KEY> keySelector2,
+            TypeInformation<OUT> outTypeInformation,
+            CoGroupFunction<IN1, IN2, OUT> func) {
+        func = input1.getExecutionEnvironment().clean(func);
+        DataStream<OUT> result =
+                input1.connect(input2)
+                        .keyBy(keySelector1, keySelector2)
+                        .transform(
+                                "CoGroupOperator", outTypeInformation, new 
CoGroupOperator<>(func))
+                        .setParallelism(Math.max(input1.getParallelism(), 
input2.getParallelism()));
+        setManagedMemoryWeight(result, 100);
+        return result;
+    }
+
     @SuppressWarnings({"rawtypes", "unchecked"})
     private static <IN, W extends Window> AllWindowedStream<IN, W> 
getAllWindowedStream(
             DataStream<IN> input, Windows windows) {
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/BytesKeyNormalizationUtil.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/BytesKeyNormalizationUtil.java
new file mode 100644
index 00000000..90be2616
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/BytesKeyNormalizationUtil.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream.sort;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+/**
+ * Utility class for common key normalization used both in {@link VariableLengthByteKeyComparator}
+ * and {@link FixedLengthByteKeyComparator}.
+ *
+ * <p>TODO: remove this class after making the corresponding class in Flink public.
+ */
+final class BytesKeyNormalizationUtil {
+    /**
+     * Writes the normalized key of given record. The normalized key consists of the key serialized
+     * as bytes and the timestamp of the record.
+     *
+     * <p>The bytes are written such that comparing them as unsigned values follows the signed order
+     * of the key bytes and then the numeric order of the timestamp. If {@code numBytes} is too
+     * small, only a prefix of this sequence is written; any space left over is filled with zeros.
+     *
+     * <p>NOTE: The key does not represent a logical order. It can be used only for grouping keys!
+     *
+     * @param record the record whose serialized key ({@code f0}) and timestamp ({@code f1}) are
+     *     normalized
+     * @param dataLength the number of key bytes of {@code record.f0} to write
+     * @param target the memory segment to write the normalized key into
+     * @param offset the position in {@code target} at which the normalized key starts
+     * @param numBytes the number of bytes reserved for the normalized key in {@code target}
+     */
+    static <IN> void putNormalizedKey(
+            Tuple2<byte[], StreamRecord<IN>> record,
+            int dataLength,
+            MemorySegment target,
+            int offset,
+            int numBytes) {
+        byte[] data = record.f0;
+
+        if (dataLength >= numBytes) {
+            // only a prefix of the key fits into the normalized key
+            putBytesArray(target, offset, numBytes, data);
+        } else {
+            // whole key fits into the normalized key
+            putBytesArray(target, offset, dataLength, data);
+            int lastOffset = offset + numBytes;
+            offset += dataLength;
+            // Subtracting Long.MIN_VALUE flips the sign bit, so that comparing the bytes as
+            // unsigned values orders timestamps numerically (negative before positive).
+            long valueOfTimestamp = record.f1.asRecord().getTimestamp() - Long.MIN_VALUE;
+            if (dataLength + FixedLengthByteKeyComparator.TIMESTAMP_BYTE_SIZE <= numBytes) {
+                // Whole timestamp fits into the normalized key. Write it most significant byte
+                // first, exactly like the partial case below. MemorySegment#putLong would use the
+                // platform's native byte order, which breaks byte-wise ordering on little-endian
+                // machines.
+                target.putLongBigEndian(offset, valueOfTimestamp);
+                offset += FixedLengthByteKeyComparator.TIMESTAMP_BYTE_SIZE;
+                // fill in the remaining space with zeros
+                while (offset < lastOffset) {
+                    target.put(offset++, (byte) 0);
+                }
+            } else {
+                // only the most significant bytes of the timestamp fit into the normalized key
+                for (int i = 0; offset < lastOffset; offset++, i++) {
+                    target.put(offset, (byte) (valueOfTimestamp >>> ((7 - i) << 3)));
+                }
+            }
+        }
+    }
+
+    /**
+     * Copies the first {@code numBytes} bytes of {@code data} into {@code target} at {@code
+     * offset}, flipping the sign bit of each byte so that unsigned comparison follows signed order.
+     */
+    private static void putBytesArray(MemorySegment target, int offset, int numBytes, byte[] data) {
+        for (int i = 0; i < numBytes; i++) {
+            // (b & 0xff) - Byte.MIN_VALUE equals b + 128 modulo 256, i.e. the signed range
+            // [-128, 127] is mapped onto the unsigned range [0, 255]. The normalized key sorter
+            // compares bytes as unsigned, so this preserves the signed ordering of the key bytes.
+            int shifted = (data[i] & 0xff) - Byte.MIN_VALUE;
+            target.put(offset + i, (byte) shifted);
+        }
+    }
+
+    private BytesKeyNormalizationUtil() {}
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/CoGroupOperator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/CoGroupOperator.java
new file mode 100644
index 00000000..8c7e2f1f
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/CoGroupOperator.java
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream.sort;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.CoGroupFunction;
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.api.common.typeutils.TypePairComparator;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import 
org.apache.flink.api.java.typeutils.runtime.RuntimePairComparatorFactory;
+import org.apache.flink.configuration.AlgorithmOptions;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.operators.sort.ExternalSorter;
+import 
org.apache.flink.runtime.operators.sort.NonReusingSortMergeCoGroupIterator;
+import org.apache.flink.runtime.operators.sort.PushSorter;
+import org.apache.flink.runtime.operators.sort.ReusingSortMergeCoGroupIterator;
+import org.apache.flink.runtime.operators.util.CoGroupTaskIterator;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.util.MutableObjectIterator;
+import org.apache.flink.util.TraversableOnceException;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * An operator that implements the co-group logic.
+ *
+ * <p>Records of both inputs are written, together with their serialized keys, into two external
+ * sorters that use the operator's managed memory and may spill to disk. Only after both inputs
+ * have ended are the sorted inputs merged group by group and handed to the user-defined {@link
+ * CoGroupFunction}, so this operator is meant for bounded inputs.
+ *
+ * @param <IN1> The class type of the first input.
+ * @param <IN2> The class type of the second input.
+ * @param <KEY> The class type of the key.
+ * @param <OUT> The class type of the output values.
+ */
+public class CoGroupOperator<IN1, IN2, KEY extends Serializable, OUT>
+        extends AbstractUdfStreamOperator<OUT, CoGroupFunction<IN1, IN2, OUT>>
+        implements TwoInputStreamOperator<IN1, IN2, OUT>, BoundedMultiInput {
+
+    private PushSorter<Tuple2<byte[], StreamRecord<IN1>>> sorterA;
+    private PushSorter<Tuple2<byte[], StreamRecord<IN2>>> sorterB;
+    private TypeComparator<Tuple2<byte[], StreamRecord<IN1>>> comparatorA;
+    private TypeComparator<Tuple2<byte[], StreamRecord<IN2>>> comparatorB;
+    private KeySelector<IN1, KEY> keySelectorA;
+    private KeySelector<IN2, KEY> keySelectorB;
+    private TypeSerializer<Tuple2<byte[], StreamRecord<IN1>>> keyAndValueSerializerA;
+    private TypeSerializer<Tuple2<byte[], StreamRecord<IN2>>> keyAndValueSerializerB;
+    private TypeSerializer<KEY> keySerializer;
+    /** Reusable buffer into which the key of each incoming record is serialized. */
+    private DataOutputSerializer dataOutputSerializer;
+    /** Timestamp of the latest received watermark; emitted once the co-group has finished. */
+    private long lastWatermarkTimestamp = Long.MIN_VALUE;
+    /** Number of inputs that have not yet reached their end. */
+    private int remainingInputNum = 2;
+
+    public CoGroupOperator(CoGroupFunction<IN1, IN2, OUT> function) {
+        super(function);
+    }
+
+    @Override
+    public void setup(
+            StreamTask<?, ?> containingTask,
+            StreamConfig config,
+            Output<StreamRecord<OUT>> output) {
+        super.setup(containingTask, config, output);
+        ClassLoader userCodeClassLoader = containingTask.getUserCodeClassLoader();
+        MemoryManager memoryManager = containingTask.getEnvironment().getMemoryManager();
+        IOManager ioManager = containingTask.getEnvironment().getIOManager();
+
+        keySelectorA = config.getStatePartitioner(0, userCodeClassLoader);
+        keySelectorB = config.getStatePartitioner(1, userCodeClassLoader);
+        keySerializer = config.getStateKeySerializer(userCodeClassLoader);
+        // A positive length means that every key serializes to the same number of bytes.
+        int keyLength = keySerializer.getLength();
+
+        TypeSerializer<IN1> typeSerializerA = config.getTypeSerializerIn(0, userCodeClassLoader);
+        TypeSerializer<IN2> typeSerializerB = config.getTypeSerializerIn(1, userCodeClassLoader);
+        keyAndValueSerializerA = new KeyAndValueSerializer<>(typeSerializerA, keyLength);
+        keyAndValueSerializerB = new KeyAndValueSerializer<>(typeSerializerB, keyLength);
+
+        if (keyLength > 0) {
+            dataOutputSerializer = new DataOutputSerializer(keyLength);
+            comparatorA = new FixedLengthByteKeyComparator<>(keyLength);
+            comparatorB = new FixedLengthByteKeyComparator<>(keyLength);
+        } else {
+            dataOutputSerializer = new DataOutputSerializer(64);
+            comparatorA = new VariableLengthByteKeyComparator<>();
+            comparatorB = new VariableLengthByteKeyComparator<>();
+        }
+
+        ExecutionConfig executionConfig = containingTask.getEnvironment().getExecutionConfig();
+        // The operator's managed memory is split evenly between the two sorters.
+        double managedMemoryFraction =
+                config.getManagedMemoryFractionOperatorUseCaseOfSlot(
+                                ManagedMemoryUseCase.OPERATOR,
+                                containingTask.getEnvironment().getTaskConfiguration(),
+                                userCodeClassLoader)
+                        / 2;
+        Configuration jobConfiguration = containingTask.getEnvironment().getJobConfiguration();
+
+        try {
+            sorterA =
+                    createSorter(
+                            containingTask,
+                            memoryManager,
+                            ioManager,
+                            keyAndValueSerializerA,
+                            comparatorA,
+                            executionConfig,
+                            jobConfiguration,
+                            managedMemoryFraction);
+            sorterB =
+                    createSorter(
+                            containingTask,
+                            memoryManager,
+                            ioManager,
+                            keyAndValueSerializerB,
+                            comparatorB,
+                            executionConfig,
+                            jobConfiguration,
+                            managedMemoryFraction);
+        } catch (MemoryAllocationException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /**
+     * Creates a push-based external sorter configured with the job's sorting and spilling options.
+     */
+    private static <T> PushSorter<T> createSorter(
+            StreamTask<?, ?> containingTask,
+            MemoryManager memoryManager,
+            IOManager ioManager,
+            TypeSerializer<T> serializer,
+            TypeComparator<T> comparator,
+            ExecutionConfig executionConfig,
+            Configuration jobConfiguration,
+            double managedMemoryFraction)
+            throws MemoryAllocationException {
+        return ExternalSorter.newBuilder(
+                        memoryManager, containingTask, serializer, comparator, executionConfig)
+                .memoryFraction(managedMemoryFraction)
+                .enableSpilling(
+                        ioManager, jobConfiguration.get(AlgorithmOptions.SORT_SPILLING_THRESHOLD))
+                .maxNumFileHandles(jobConfiguration.get(AlgorithmOptions.SPILLING_MAX_FAN))
+                .objectReuse(executionConfig.isObjectReuseEnabled())
+                .largeRecords(jobConfiguration.get(AlgorithmOptions.USE_LARGE_RECORDS_HANDLER))
+                .build();
+    }
+
+    @Override
+    public void endInput(int inputId) throws Exception {
+        if (inputId == 1) {
+            sorterA.finishReading();
+            remainingInputNum--;
+        } else if (inputId == 2) {
+            sorterB.finishReading();
+            remainingInputNum--;
+        } else {
+            throw new RuntimeException("Unknown inputId " + inputId);
+        }
+
+        // The co-group can only start once both inputs have been fully sorted.
+        if (remainingInputNum > 0) {
+            return;
+        }
+
+        CoGroupTaskIterator<Tuple2<byte[], StreamRecord<IN1>>, Tuple2<byte[], StreamRecord<IN2>>>
+                coGroupIterator = createCoGroupIterator();
+        coGroupIterator.open();
+        try {
+            TupleUnwrappingIterator<IN1, byte[]> unwrappedIteratorA =
+                    new TupleUnwrappingIterator<>();
+            TupleUnwrappingIterator<IN2, byte[]> unwrappedIteratorB =
+                    new TupleUnwrappingIterator<>();
+            Output<OUT> timestampedCollector = new TimestampedCollector<>(output);
+            while (coGroupIterator.next()) {
+                unwrappedIteratorA.set(coGroupIterator.getValues1().iterator());
+                unwrappedIteratorB.set(coGroupIterator.getValues2().iterator());
+                userFunction.coGroup(
+                        unwrappedIteratorA, unwrappedIteratorB, timestampedCollector);
+            }
+        } finally {
+            // Close the iterator even if the user function throws.
+            coGroupIterator.close();
+        }
+
+        // Forward the buffered watermark only after all co-grouped results have been emitted.
+        Watermark watermark = new Watermark(lastWatermarkTimestamp);
+        if (getTimeServiceManager().isPresent()) {
+            getTimeServiceManager().get().advanceWatermark(watermark);
+        }
+        output.emitWatermark(watermark);
+    }
+
+    /** Creates the iterator that walks both sorted inputs and yields one key group at a time. */
+    private CoGroupTaskIterator<Tuple2<byte[], StreamRecord<IN1>>, Tuple2<byte[], StreamRecord<IN2>>>
+            createCoGroupIterator() throws InterruptedException {
+        MutableObjectIterator<Tuple2<byte[], StreamRecord<IN1>>> iteratorA = sorterA.getIterator();
+        MutableObjectIterator<Tuple2<byte[], StreamRecord<IN2>>> iteratorB = sorterB.getIterator();
+        TypePairComparator<Tuple2<byte[], StreamRecord<IN1>>, Tuple2<byte[], StreamRecord<IN2>>>
+                pairComparator =
+                        new RuntimePairComparatorFactory<
+                                        Tuple2<byte[], StreamRecord<IN1>>,
+                                        Tuple2<byte[], StreamRecord<IN2>>>()
+                                .createComparator12(comparatorA, comparatorB);
+
+        if (getExecutionConfig().isObjectReuseEnabled()) {
+            return new ReusingSortMergeCoGroupIterator<>(
+                    iteratorA,
+                    iteratorB,
+                    keyAndValueSerializerA,
+                    comparatorA,
+                    keyAndValueSerializerB,
+                    comparatorB,
+                    pairComparator);
+        }
+        return new NonReusingSortMergeCoGroupIterator<>(
+                iteratorA,
+                iteratorB,
+                keyAndValueSerializerA,
+                comparatorA,
+                keyAndValueSerializerB,
+                comparatorB,
+                pairComparator);
+    }
+
+    @Override
+    public void processWatermark(Watermark watermark) throws Exception {
+        // Watermarks are only recorded here, not forwarded: no output exists before both inputs
+        // have ended, and the last watermark is emitted at that point in endInput().
+        if (lastWatermarkTimestamp > watermark.getTimestamp()) {
+            throw new RuntimeException("Invalid watermark");
+        }
+        lastWatermarkTimestamp = watermark.getTimestamp();
+    }
+
+    @Override
+    public void close() throws Exception {
+        try {
+            super.close();
+        } finally {
+            // Release the sorters even if super.close() fails. They may be null if setup() failed
+            // before they were created.
+            try {
+                if (sorterA != null) {
+                    sorterA.close();
+                }
+            } finally {
+                if (sorterB != null) {
+                    sorterB.close();
+                }
+            }
+        }
+    }
+
+    @Override
+    public void processElement1(StreamRecord<IN1> streamRecord) throws Exception {
+        byte[] serializedKey = serializeKey(keySelectorA.getKey(streamRecord.getValue()));
+        sorterA.writeRecord(Tuple2.of(serializedKey, streamRecord));
+    }
+
+    @Override
+    public void processElement2(StreamRecord<IN2> streamRecord) throws Exception {
+        byte[] serializedKey = serializeKey(keySelectorB.getKey(streamRecord.getValue()));
+        sorterB.writeRecord(Tuple2.of(serializedKey, streamRecord));
+    }
+
+    /** Serializes the given key into a newly allocated byte array. */
+    private byte[] serializeKey(KEY key) throws java.io.IOException {
+        try {
+            keySerializer.serialize(key, dataOutputSerializer);
+            return dataOutputSerializer.getCopyOfBuffer();
+        } finally {
+            // Reset the shared buffer so that the next key starts from an empty buffer.
+            dataOutputSerializer.clear();
+        }
+    }
+
+    /**
+     * Adapts an iterator over (serialized key, record) pairs into an iterator over the record
+     * values. The {@link Iterable} view may be obtained only once per call to {@link #set}.
+     */
+    private static class TupleUnwrappingIterator<T, K> implements Iterator<T>, Iterable<T> {
+
+        private Iterator<Tuple2<K, StreamRecord<T>>> iterator;
+        private boolean iteratorAvailable;
+
+        public void set(Iterator<Tuple2<K, StreamRecord<T>>> iterator) {
+            this.iterator = iterator;
+            this.iteratorAvailable = true;
+        }
+
+        @Override
+        public boolean hasNext() {
+            return iterator.hasNext();
+        }
+
+        @Override
+        public T next() {
+            return iterator.next().f1.getValue();
+        }
+
+        @Override
+        public void remove() {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public Iterator<T> iterator() {
+            if (iteratorAvailable) {
+                iteratorAvailable = false;
+                return this;
+            } else {
+                throw new TraversableOnceException();
+            }
+        }
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/FixedLengthByteKeyComparator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/FixedLengthByteKeyComparator.java
new file mode 100644
index 00000000..859e193f
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/FixedLengthByteKeyComparator.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream.sort;
+
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.streaming.api.operators.sort.SortingDataInput;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * A comparator used in {@link SortingDataInput} which compares records by their serialized keys
+ * and then by their timestamps. It operates on the binary format produced by the {@link
+ * KeyAndValueSerializer}.
+ *
+ * <p>It assumes keys are always of a fixed length and thus the length of the key is not
+ * serialized. Key bytes are compared one by one as signed {@code byte} values.
+ *
+ * <p>TODO: remove this class after making the corresponding class in Flink public.
+ */
+public final class FixedLengthByteKeyComparator<IN>
+        extends TypeComparator<Tuple2<byte[], StreamRecord<IN>>> {
+    /** Number of bytes a serialized {@code long} timestamp occupies. */
+    static final int TIMESTAMP_BYTE_SIZE = 8;
+
+    /** Length in bytes of every key handled by this comparator. */
+    private final int keyLength;
+
+    /** Key of the record last passed to {@link #setReference}. */
+    private byte[] keyReference;
+
+    /** Timestamp of the record last passed to {@link #setReference}. */
+    private long timestampReference;
+
+    public FixedLengthByteKeyComparator(int keyLength) {
+        this.keyLength = keyLength;
+    }
+
+    /**
+     * Hashes exactly the fields that take part in the comparison: the key bytes and the timestamp.
+     *
+     * <p>{@link Arrays#hashCode(byte[])} hashes the array contents (a {@code byte[]}'s own {@code
+     * hashCode()} is identity-based), so records considered equal by {@link #equalToReference} also
+     * get equal hash codes.
+     */
+    @Override
+    public int hash(Tuple2<byte[], StreamRecord<IN>> record) {
+        return 31 * Arrays.hashCode(record.f0)
+                + Long.hashCode(record.f1.asRecord().getTimestamp());
+    }
+
+    @Override
+    public void setReference(Tuple2<byte[], StreamRecord<IN>> toCompare) {
+        this.keyReference = toCompare.f0;
+        this.timestampReference = toCompare.f1.asRecord().getTimestamp();
+    }
+
+    @Override
+    public boolean equalToReference(Tuple2<byte[], StreamRecord<IN>> candidate) {
+        return Arrays.equals(keyReference, candidate.f0)
+                && timestampReference == candidate.f1.asRecord().getTimestamp();
+    }
+
+    /**
+     * Compares the reference held by {@code referencedComparator} against this comparator's
+     * reference, first by key and then by timestamp.
+     */
+    @Override
+    public int compareToReference(
+            TypeComparator<Tuple2<byte[], StreamRecord<IN>>> referencedComparator) {
+        FixedLengthByteKeyComparator<IN> other =
+                (FixedLengthByteKeyComparator<IN>) referencedComparator;
+
+        int keyCmp = compare(other.keyReference, this.keyReference);
+        if (keyCmp != 0) {
+            return keyCmp;
+        }
+        return Long.compare(other.timestampReference, this.timestampReference);
+    }
+
+    @Override
+    public int compare(
+            Tuple2<byte[], StreamRecord<IN>> first, Tuple2<byte[], StreamRecord<IN>> second) {
+        int keyCmp = compare(first.f0, second.f0);
+        if (keyCmp != 0) {
+            return keyCmp;
+        }
+        return Long.compare(
+                first.f1.asRecord().getTimestamp(), second.f1.asRecord().getTimestamp());
+    }
+
+    /** Compares the first {@code keyLength} bytes of both keys as signed bytes; returns -1, 0 or 1. */
+    private int compare(byte[] first, byte[] second) {
+        for (int i = 0; i < keyLength; i++) {
+            int cmp = Byte.compare(first[i], second[i]);
+
+            if (cmp != 0) {
+                return cmp < 0 ? -1 : 1;
+            }
+        }
+
+        return 0;
+    }
+
+    /**
+     * Compares two records in the {@link KeyAndValueSerializer} binary layout: {@code keyLength}
+     * key bytes followed by the timestamp as a {@code long}. The timestamps are read only when the
+     * keys are equal.
+     */
+    @Override
+    public int compareSerialized(DataInputView firstSource, DataInputView secondSource)
+            throws IOException {
+        for (int i = 0; i < keyLength; i++) {
+            int cmp = Byte.compare(firstSource.readByte(), secondSource.readByte());
+            if (cmp != 0) {
+                return cmp < 0 ? -1 : 1;
+            }
+        }
+
+        return Long.compare(firstSource.readLong(), secondSource.readLong());
+    }
+
+    @Override
+    public boolean supportsNormalizedKey() {
+        return true;
+    }
+
+    /** The normalized key covers the whole key followed by the timestamp. */
+    @Override
+    public int getNormalizeKeyLen() {
+        return keyLength + TIMESTAMP_BYTE_SIZE;
+    }
+
+    @Override
+    public boolean isNormalizedKeyPrefixOnly(int keyBytes) {
+        return keyBytes < getNormalizeKeyLen();
+    }
+
+    @Override
+    public void putNormalizedKey(
+            Tuple2<byte[], StreamRecord<IN>> record,
+            MemorySegment target,
+            int offset,
+            int numBytes) {
+        BytesKeyNormalizationUtil.putNormalizedKey(record, keyLength, target, offset, numBytes);
+    }
+
+    @Override
+    public boolean invertNormalizedKey() {
+        return false;
+    }
+
+    @Override
+    public TypeComparator<Tuple2<byte[], StreamRecord<IN>>> duplicate() {
+        return new FixedLengthByteKeyComparator<>(this.keyLength);
+    }
+
+    @Override
+    public int extractKeys(Object record, Object[] target, int index) {
+        target[index] = record;
+        return 1;
+    }
+
+    @Override
+    public TypeComparator<?>[] getFlatComparators() {
+        return new TypeComparator[] {this};
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // unsupported normalization
+    // --------------------------------------------------------------------------------------------
+
+    @Override
+    public boolean supportsSerializationWithKeyNormalization() {
+        return false;
+    }
+
+    @Override
+    public void writeWithKeyNormalization(
+            Tuple2<byte[], StreamRecord<IN>> record, DataOutputView target) throws IOException {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public Tuple2<byte[], StreamRecord<IN>> readWithKeyDenormalization(
+            Tuple2<byte[], StreamRecord<IN>> reuse, DataInputView source) throws IOException {
+        throw new UnsupportedOperationException();
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/KeyAndValueSerializer.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/KeyAndValueSerializer.java
new file mode 100644
index 00000000..f2d82be8
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/KeyAndValueSerializer.java
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream.sort;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.streaming.api.operators.sort.SortingDataInput;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Objects;
+
+/**
+ * A serializer used in {@link SortingDataInput} for serializing elements alongside their key and
+ * timestamp. It serializes the record in a format known by the {@link
+ * FixedLengthByteKeyComparator} and {@link VariableLengthByteKeyComparator}.
+ *
+ * <p>If the key is of known constant length, the length is not serialized with the data. Therefore
+ * the serialized data is as follows:
+ *
+ * <pre>
+ *      [key-length] | &lt;key&gt; | &lt;timestamp&gt; | &lt;record&gt;
+ * </pre>
+ *
+ * <p>where {@code [key-length]} is an {@code int} written only for variable-length keys and the
+ * timestamp is written as a {@code long}.
+ *
+ * <p>TODO: remove this class after making the corresponding class in Flink public.
+ */
+public final class KeyAndValueSerializer<IN>
+        extends TypeSerializer<Tuple2<byte[], StreamRecord<IN>>> {
+    private static final int TIMESTAMP_LENGTH = 8;
+    private final TypeSerializer<IN> valueSerializer;
+
+    // A negative value (conventionally -1) means variable-length keys whose length is written in
+    // front of every key; a value >= 0 is the fixed length of every key.
+    private final int serializedKeyLength;
+
+    public KeyAndValueSerializer(TypeSerializer<IN> valueSerializer, int serializedKeyLength) {
+        this.valueSerializer = valueSerializer;
+        this.serializedKeyLength = serializedKeyLength;
+    }
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    @Override
+    public TypeSerializer<Tuple2<byte[], StreamRecord<IN>>> duplicate() {
+        return new KeyAndValueSerializer<>(valueSerializer.duplicate(), this.serializedKeyLength);
+    }
+
+    @Override
+    public Tuple2<byte[], StreamRecord<IN>> copy(Tuple2<byte[], StreamRecord<IN>> from) {
+        StreamRecord<IN> fromRecord = from.f1;
+        return Tuple2.of(
+                Arrays.copyOf(from.f0, from.f0.length),
+                fromRecord.copy(valueSerializer.copy(fromRecord.getValue())));
+    }
+
+    @Override
+    public Tuple2<byte[], StreamRecord<IN>> createInstance() {
+        return Tuple2.of(new byte[0], new StreamRecord<>(valueSerializer.createInstance()));
+    }
+
+    @Override
+    public Tuple2<byte[], StreamRecord<IN>> copy(
+            Tuple2<byte[], StreamRecord<IN>> from, Tuple2<byte[], StreamRecord<IN>> reuse) {
+        StreamRecord<IN> fromRecord = from.f1;
+        StreamRecord<IN> reuseRecord = reuse.f1;
+
+        IN valueCopy = valueSerializer.copy(fromRecord.getValue(), reuseRecord.getValue());
+        fromRecord.copyTo(valueCopy, reuseRecord);
+        reuse.f0 = Arrays.copyOf(from.f0, from.f0.length);
+        return reuse;
+    }
+
+    /** Returns the total serialized length, or -1 if either the key or the value is variable. */
+    @Override
+    public int getLength() {
+        if (valueSerializer.getLength() < 0 || serializedKeyLength < 0) {
+            return -1;
+        }
+        return valueSerializer.getLength() + serializedKeyLength + TIMESTAMP_LENGTH;
+    }
+
+    @Override
+    public void serialize(Tuple2<byte[], StreamRecord<IN>> record, DataOutputView target)
+            throws IOException {
+        if (serializedKeyLength < 0) {
+            target.writeInt(record.f0.length);
+        }
+        target.write(record.f0);
+        StreamRecord<IN> toSerialize = record.f1;
+        target.writeLong(toSerialize.getTimestamp());
+        valueSerializer.serialize(toSerialize.getValue(), target);
+    }
+
+    @Override
+    public Tuple2<byte[], StreamRecord<IN>> deserialize(DataInputView source) throws IOException {
+        byte[] key = readKey(source);
+        long timestamp = source.readLong();
+        IN value = valueSerializer.deserialize(source);
+        return Tuple2.of(key, new StreamRecord<>(value, timestamp));
+    }
+
+    @Override
+    public Tuple2<byte[], StreamRecord<IN>> deserialize(
+            Tuple2<byte[], StreamRecord<IN>> reuse, DataInputView source) throws IOException {
+        byte[] key = readKey(source);
+        long timestamp = source.readLong();
+        IN value = valueSerializer.deserialize(source);
+        reuse.f1.replace(value, timestamp);
+        reuse.f0 = key;
+        return reuse;
+    }
+
+    /** Reads the (optional) key length and then exactly that many key bytes from the source. */
+    private byte[] readKey(DataInputView source) throws IOException {
+        byte[] key = new byte[getKeyLength(source)];
+        // DataInput#read(byte[]) may return fewer bytes than requested; readFully guarantees the
+        // whole key is consumed, so the timestamp and value are read from the right position.
+        source.readFully(key);
+        return key;
+    }
+
+    /** Reads the key length from the source for variable-length keys, else the fixed length. */
+    private int getKeyLength(DataInputView source) throws IOException {
+        final int length;
+        if (serializedKeyLength < 0) {
+            length = source.readInt();
+        } else {
+            length = serializedKeyLength;
+        }
+        return length;
+    }
+
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws IOException {
+        final int length = getKeyLength(source);
+        if (serializedKeyLength < 0) {
+            target.writeInt(length);
+        }
+        // Copy the key bytes in bulk rather than one byte at a time.
+        target.write(source, length);
+        target.writeLong(source.readLong());
+        valueSerializer.copy(source, target);
+    }
+
+    /**
+     * Two serializers are equal only if they use equal value serializers and the same key length
+     * mode, since the key length determines the binary layout.
+     */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        KeyAndValueSerializer<?> that = (KeyAndValueSerializer<?>) o;
+        return serializedKeyLength == that.serializedKeyLength
+                && Objects.equals(valueSerializer, that.valueSerializer);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(valueSerializer, serializedKeyLength);
+    }
+
+    @Override
+    public TypeSerializerSnapshot<Tuple2<byte[], StreamRecord<IN>>> snapshotConfiguration() {
+        throw new UnsupportedOperationException(
+                "The KeyAndValueSerializer should not be used for persisting into State!");
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/VariableLengthByteKeyComparator.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/VariableLengthByteKeyComparator.java
new file mode 100644
index 00000000..7ae89caa
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/VariableLengthByteKeyComparator.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream.sort;
+
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.streaming.api.operators.sort.SortingDataInput;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * A comparator used in {@link SortingDataInput} which compares records by their serialized keys
+ * and then by their timestamps. It operates on the binary format produced by the {@link
+ * KeyAndValueSerializer}.
+ *
+ * <p>It assumes keys are of a variable length and thus expects the length of the key to be
+ * serialized in front of it. Keys are compared byte by byte as signed {@code byte} values; when one
+ * key is a prefix of the other, the shorter key sorts first.
+ *
+ * <p>TODO: remove this class after making the corresponding class in Flink public.
+ */
+public final class VariableLengthByteKeyComparator<IN>
+        extends TypeComparator<Tuple2<byte[], StreamRecord<IN>>> {
+    /** Copy of the key of the record last passed to {@link #setReference}. */
+    private byte[] keyReference;
+
+    /** Timestamp of the record last passed to {@link #setReference}. */
+    private long timestampReference;
+
+    /**
+     * Hashes exactly the fields that take part in the comparison: the key bytes and the timestamp.
+     *
+     * <p>{@link Arrays#hashCode(byte[])} hashes the array contents (a {@code byte[]}'s own {@code
+     * hashCode()} is identity-based), so records considered equal by {@link #equalToReference} also
+     * get equal hash codes.
+     */
+    @Override
+    public int hash(Tuple2<byte[], StreamRecord<IN>> record) {
+        return 31 * Arrays.hashCode(record.f0)
+                + Long.hashCode(record.f1.asRecord().getTimestamp());
+    }
+
+    @Override
+    public void setReference(Tuple2<byte[], StreamRecord<IN>> toCompare) {
+        this.keyReference = Arrays.copyOf(toCompare.f0, toCompare.f0.length);
+        this.timestampReference = toCompare.f1.asRecord().getTimestamp();
+    }
+
+    @Override
+    public boolean equalToReference(Tuple2<byte[], StreamRecord<IN>> candidate) {
+        return Arrays.equals(keyReference, candidate.f0)
+                && timestampReference == candidate.f1.asRecord().getTimestamp();
+    }
+
+    /**
+     * Compares the reference held by {@code referencedComparator} against this comparator's
+     * reference, first by key and then by timestamp.
+     */
+    @Override
+    public int compareToReference(
+            TypeComparator<Tuple2<byte[], StreamRecord<IN>>> referencedComparator) {
+        VariableLengthByteKeyComparator<IN> other =
+                (VariableLengthByteKeyComparator<IN>) referencedComparator;
+
+        int keyCmp = compare(other.keyReference, this.keyReference);
+        if (keyCmp != 0) {
+            return keyCmp;
+        }
+        return Long.compare(other.timestampReference, this.timestampReference);
+    }
+
+    @Override
+    public int compare(
+            Tuple2<byte[], StreamRecord<IN>> first, Tuple2<byte[], StreamRecord<IN>> second) {
+        int keyCmp = compare(first.f0, second.f0);
+        if (keyCmp != 0) {
+            return keyCmp;
+        }
+        return Long.compare(
+                first.f1.asRecord().getTimestamp(), second.f1.asRecord().getTimestamp());
+    }
+
+    /** Compares keys lexicographically as signed bytes; a proper prefix sorts first. */
+    private int compare(byte[] first, byte[] second) {
+        int firstLength = first.length;
+        int secondLength = second.length;
+        int minLength = Math.min(firstLength, secondLength);
+        for (int i = 0; i < minLength; i++) {
+            int cmp = Byte.compare(first[i], second[i]);
+
+            if (cmp != 0) {
+                return cmp;
+            }
+        }
+
+        return Integer.compare(firstLength, secondLength);
+    }
+
+    /**
+     * Compares two records in the {@link KeyAndValueSerializer} binary layout: an {@code int} key
+     * length, the key bytes, then the timestamp as a {@code long}. The timestamps are read only when
+     * the keys are equal.
+     */
+    @Override
+    public int compareSerialized(DataInputView firstSource, DataInputView secondSource)
+            throws IOException {
+        int firstLength = firstSource.readInt();
+        int secondLength = secondSource.readInt();
+        int minLength = Math.min(firstLength, secondLength);
+        while (minLength-- > 0) {
+            byte firstValue = firstSource.readByte();
+            byte secondValue = secondSource.readByte();
+
+            int cmp = Byte.compare(firstValue, secondValue);
+            if (cmp != 0) {
+                return cmp;
+            }
+        }
+
+        int lengthCompare = Integer.compare(firstLength, secondLength);
+        if (lengthCompare != 0) {
+            return lengthCompare;
+        } else {
+            return Long.compare(firstSource.readLong(), secondSource.readLong());
+        }
+    }
+
+    @Override
+    public boolean supportsNormalizedKey() {
+        return true;
+    }
+
+    @Override
+    public int getNormalizeKeyLen() {
+        return Integer.MAX_VALUE;
+    }
+
+    @Override
+    public boolean isNormalizedKeyPrefixOnly(int keyBytes) {
+        return true;
+    }
+
+    @Override
+    public void putNormalizedKey(
+            Tuple2<byte[], StreamRecord<IN>> record,
+            MemorySegment target,
+            int offset,
+            int numBytes) {
+        BytesKeyNormalizationUtil.putNormalizedKey(
+                record, record.f0.length, target, offset, numBytes);
+    }
+
+    @Override
+    public boolean invertNormalizedKey() {
+        return false;
+    }
+
+    @Override
+    public TypeComparator<Tuple2<byte[], StreamRecord<IN>>> duplicate() {
+        return new VariableLengthByteKeyComparator<>();
+    }
+
+    @Override
+    public int extractKeys(Object record, Object[] target, int index) {
+        target[index] = record;
+        return 1;
+    }
+
+    @Override
+    public TypeComparator<?>[] getFlatComparators() {
+        return new TypeComparator[] {this};
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // unsupported normalization
+    // --------------------------------------------------------------------------------------------
+
+    @Override
+    public boolean supportsSerializationWithKeyNormalization() {
+        return false;
+    }
+
+    @Override
+    public void writeWithKeyNormalization(
+            Tuple2<byte[], StreamRecord<IN>> record, DataOutputView target) throws IOException {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public Tuple2<byte[], StreamRecord<IN>> readWithKeyDenormalization(
+            Tuple2<byte[], StreamRecord<IN>> reuse, DataInputView source) throws IOException {
+        throw new UnsupportedOperationException();
+    }
+}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
index 7b3e8b3a..e72482f5 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
@@ -19,10 +19,14 @@
 package org.apache.flink.ml.common.datastream;
 
 import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.CoGroupFunction;
 import org.apache.flink.api.common.functions.MapPartitionFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.ml.util.TestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -33,6 +37,7 @@ import org.apache.commons.collections.IteratorUtils;
 import org.junit.Before;
 import org.junit.Test;
 
+import java.util.Arrays;
 import java.util.List;
 
 import static org.junit.Assert.assertArrayEquals;
@@ -48,6 +53,95 @@ public class DataStreamUtilsTest {
         env = TestUtils.getExecutionEnvironment();
     }
 
+    @Test
+    public void testCoGroupWithSingleParallelism() throws Exception {
+        DataStream<Tuple2<Integer, Integer>> inputA =
+                env.fromCollection(
+                        Arrays.asList(Tuple2.of(1, 1), Tuple2.of(2, 2), Tuple2.of(3, 3)));
+        DataStream<Tuple2<Integer, Double>> inputB =
+                env.fromCollection(
+                        Arrays.asList(
+                                Tuple2.of(1, 1.5),
+                                Tuple2.of(5, 5.5),
+                                Tuple2.of(3, 3.5),
+                                Tuple2.of(1, 2.5)));
+        // Both inputs are grouped by their first field.
+        KeySelector<Tuple2<Integer, Integer>, Integer> keyA = tuple -> tuple.f0;
+        KeySelector<Tuple2<Integer, Double>, Integer> keyB = tuple -> tuple.f0;
+
+        DataStream<Double> result =
+                DataStreamUtils.coGroup(
+                        inputA,
+                        inputB,
+                        keyA,
+                        keyB,
+                        BasicTypeInfo.DOUBLE_TYPE_INFO,
+                        new CoGroupFunction<
+                                Tuple2<Integer, Integer>, Tuple2<Integer, Double>, Double>() {
+                            @Override
+                            public void coGroup(
+                                    Iterable<Tuple2<Integer, Integer>> iterableA,
+                                    Iterable<Tuple2<Integer, Double>> iterableB,
+                                    Collector<Double> collector) {
+                                // Emit the sum of the second field over both sides of the group.
+                                double total = 0;
+                                for (Tuple2<Integer, Integer> left : iterableA) {
+                                    total += left.f1;
+                                }
+                                for (Tuple2<Integer, Double> right : iterableB) {
+                                    total += right.f1;
+                                }
+                                collector.collect(total);
+                            }
+                        });
+
+        List<Double> collected = IteratorUtils.toList(result.executeAndCollect());
+        double[] actual = new double[collected.size()];
+        for (int i = 0; i < actual.length; i++) {
+            actual[i] = collected.get(i);
+        }
+        assertArrayEquals(new double[] {5.0, 2.0, 6.5, 5.5}, actual, 1e-5);
+    }
+
+    @Test
+    public void testCoGroupWithMultiParallelism() throws Exception {
+        DataStream<Long> inputA =
+                env.fromParallelCollection(new NumberSequenceIterator(0L, 10L), Types.LONG);
+        DataStream<Long> inputB =
+                env.fromParallelCollection(new NumberSequenceIterator(6L, 16L), Types.LONG);
+        // For even v, the values v and v + 1 share the key v / 2.
+        KeySelector<Long, Long> keyA = v -> v / 2;
+        KeySelector<Long, Long> keyB = v -> v / 2;
+
+        DataStream<Long> result =
+                DataStreamUtils.coGroup(
+                        inputA,
+                        inputB,
+                        keyA,
+                        keyB,
+                        BasicTypeInfo.LONG_TYPE_INFO,
+                        new CoGroupFunction<Long, Long, Long>() {
+                            @Override
+                            public void coGroup(
+                                    Iterable<Long> iterableA,
+                                    Iterable<Long> iterableB,
+                                    Collector<Long> collector) {
+                                long total = 0;
+                                for (Long left : iterableA) {
+                                    total += left;
+                                }
+                                for (Long right : iterableB) {
+                                    total += right;
+                                }
+                                collector.collect(total);
+                            }
+                        });
+
+        List<Long> collected = IteratorUtils.toList(result.executeAndCollect());
+        // Parallel subtasks emit in no particular order, so compare the sorted sums.
+        long[] actual = collected.stream().mapToLong(Long::longValue).sorted().toArray();
+        assertArrayEquals(new long[] {1, 5, 9, 16, 25, 26, 29, 31, 34}, actual);
+    }
+
     @Test
     public void testMapPartition() throws Exception {
         DataStream<Long> dataStream =
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
index 56330742..6fbf39d5 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
@@ -160,7 +160,7 @@ public class KMeans implements Estimator<KMeans, KMeansModel>, KMeansParams<KMea
                                                     
DenseVectorTypeInfo.INSTANCE)),
                                     new 
CentroidsUpdateAccumulator(distanceMeasure));
 
-            DataStreamUtils.setManagedMemoryWeight(centroidIdAndPoints.getTransformation(), 100);
+            DataStreamUtils.setManagedMemoryWeight(centroidIdAndPoints, 100);
 
             int parallelism = centroidIdAndPoints.getParallelism();
             DataStream<KMeansModelData> newModelData =

Reply via email to