This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new d7c9c8b5 [FLINK-31753] Support DataStream CoGroup in stream mode with
similar performance as DataSet CoGroup
d7c9c8b5 is described below
commit d7c9c8b5242a3c161d430a03fc4e4c3b0d1d78ff
Author: Dong Lin <[email protected]>
AuthorDate: Wed Apr 12 16:43:12 2023 +0800
[FLINK-31753] Support DataStream CoGroup in stream mode with similar
performance as DataSet CoGroup
This closes #230.
---
.../ml/common/datastream/DataStreamUtils.java | 65 ++++-
.../datastream/sort/BytesKeyNormalizationUtil.java | 84 ++++++
.../ml/common/datastream/sort/CoGroupOperator.java | 314 +++++++++++++++++++++
.../sort/FixedLengthByteKeyComparator.java | 188 ++++++++++++
.../datastream/sort/KeyAndValueSerializer.java | 189 +++++++++++++
.../sort/VariableLengthByteKeyComparator.java | 193 +++++++++++++
.../ml/common/datastream/DataStreamUtilsTest.java | 94 ++++++
.../apache/flink/ml/clustering/kmeans/KMeans.java | 2 +-
8 files changed, 1123 insertions(+), 6 deletions(-)
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
index eb4ec6ca..e4cbcd52 100644
---
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -20,6 +20,7 @@ package org.apache.flink.ml.common.datastream;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
@@ -32,12 +33,13 @@ import org.apache.flink.api.common.time.Time;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
-import org.apache.flink.api.dag.Transformation;
+import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.TypeExtractor;
import org.apache.flink.core.memory.ManagedMemoryUseCase;
import org.apache.flink.iteration.datacache.nonkeyed.ListStateWithCache;
import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.sort.CoGroupOperator;
import org.apache.flink.ml.common.window.CountTumblingWindows;
import org.apache.flink.ml.common.window.EventTimeSessionWindows;
import org.apache.flink.ml.common.window.EventTimeTumblingWindows;
@@ -75,6 +77,7 @@ import org.apache.flink.util.Collector;
import org.apache.commons.collections.IteratorUtils;
+import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
@@ -133,6 +136,7 @@ public class DataStreamUtils {
DataStream<IN> input,
MapPartitionFunction<IN, OUT> func,
TypeInformation<OUT> outType) {
+ func = input.getExecutionEnvironment().clean(func);
return input.transform("mapPartition", outType, new
MapPartitionOperator<>(func))
.setParallelism(input.getParallelism());
}
@@ -162,6 +166,7 @@ public class DataStreamUtils {
*/
public static <T> DataStream<T> reduce(
DataStream<T> input, ReduceFunction<T> func, TypeInformation<T>
outType) {
+ func = input.getExecutionEnvironment().clean(func);
DataStream<T> partialReducedStream =
input.transform("reduce", outType, new ReduceOperator<>(func))
.setParallelism(input.getParallelism());
@@ -201,6 +206,7 @@ public class DataStreamUtils {
*/
public static <T, K> DataStream<T> reduce(
KeyedStream<T, K> input, ReduceFunction<T> func,
TypeInformation<T> outType) {
+ func = input.getExecutionEnvironment().clean(func);
return input.transform(
"Keyed Reduce",
outType,
@@ -263,6 +269,7 @@ public class DataStreamUtils {
AggregateFunction<IN, ACC, OUT> func,
TypeInformation<ACC> accType,
TypeInformation<OUT> outType) {
+ func = input.getExecutionEnvironment().clean(func);
DataStream<ACC> partialAggregatedStream =
input.transform(
"partialAggregate", accType, new
PartialAggregateOperator<>(func, accType));
@@ -319,13 +326,14 @@ public class DataStreamUtils {
* bytes should be in the same scale as existing usage in Flink, for
example,
* StreamExecWindowAggregate.WINDOW_AGG_MEMORY_RATIO.
*/
- public static <T> void setManagedMemoryWeight(
- Transformation<T> transformation, long memoryBytes) {
+ public static <T> void setManagedMemoryWeight(DataStream<T> dataStream,
long memoryBytes) {
if (memoryBytes > 0) {
final int weightInMebibyte = Math.max(1, (int) (memoryBytes >>
20));
final Optional<Integer> previousWeight =
- transformation.declareManagedMemoryUseCaseAtOperatorScope(
- ManagedMemoryUseCase.OPERATOR, weightInMebibyte);
+ dataStream
+ .getTransformation()
+ .declareManagedMemoryUseCaseAtOperatorScope(
+ ManagedMemoryUseCase.OPERATOR,
weightInMebibyte);
if (previousWeight.isPresent()) {
throw new TableException(
"Managed memory weight has been set, this should not
happen.");
@@ -345,6 +353,7 @@ public class DataStreamUtils {
*/
public static <IN, OUT, W extends Window> SingleOutputStreamOperator<OUT>
windowAllAndProcess(
DataStream<IN> input, Windows windows,
ProcessAllWindowFunction<IN, OUT, W> function) {
+ function = input.getExecutionEnvironment().clean(function);
AllWindowedStream<IN, W> allWindowedStream =
getAllWindowedStream(input, windows);
return allWindowedStream.process(function);
}
@@ -365,10 +374,56 @@ public class DataStreamUtils {
Windows windows,
ProcessAllWindowFunction<IN, OUT, W> function,
TypeInformation<OUT> outType) {
+ function = input.getExecutionEnvironment().clean(function);
AllWindowedStream<IN, W> allWindowedStream =
getAllWindowedStream(input, windows);
return allWindowedStream.process(function, outType);
}
+    /**
+     * Co-groups two bounded {@link DataStream DataStreams} by key.
+     *
+     * <p>Both inputs are connected, keyed with the given selectors and handed to a {@link
+     * CoGroupOperator}, which passes the groups of both inputs that share a key to the supplied
+     * {@link org.apache.flink.api.common.functions.CoGroupFunction}. If a key occurs in only one
+     * of the inputs, the function is called with an empty group for the other input.
+     *
+     * <p>The CoGroupFunction may iterate over the elements of both groups and emit any number of
+     * elements, including none.
+     *
+     * <p>NOTE: This method assumes both inputs are bounded.
+     *
+     * @param input1 The first data stream.
+     * @param input2 The second data stream.
+     * @param keySelector1 The KeySelector used to extract the first input's key for partitioning.
+     * @param keySelector2 The KeySelector used to extract the second input's key for
+     *     partitioning.
+     * @param outTypeInformation The type information describing the output type.
+     * @param func The user-defined co-group function.
+     * @param <IN1> The class type of the first input.
+     * @param <IN2> The class type of the second input.
+     * @param <KEY> The class type of the key.
+     * @param <OUT> The class type of the output values.
+     * @return The result data stream.
+     */
+    public static <IN1, IN2, KEY extends Serializable, OUT> DataStream<OUT> coGroup(
+            DataStream<IN1> input1,
+            DataStream<IN2> input2,
+            KeySelector<IN1, KEY> keySelector1,
+            KeySelector<IN2, KEY> keySelector2,
+            TypeInformation<OUT> outTypeInformation,
+            CoGroupFunction<IN1, IN2, OUT> func) {
+        CoGroupFunction<IN1, IN2, OUT> cleanedFunc =
+                input1.getExecutionEnvironment().clean(func);
+        // The operator runs with the larger parallelism of its two inputs.
+        int parallelism = Math.max(input1.getParallelism(), input2.getParallelism());
+
+        DataStream<OUT> coGrouped =
+                input1.connect(input2)
+                        .keyBy(keySelector1, keySelector2)
+                        .transform(
+                                "CoGroupOperator",
+                                outTypeInformation,
+                                new CoGroupOperator<>(cleanedFunc))
+                        .setParallelism(parallelism);
+        // The operator sorts with managed memory, so a weight has to be declared for it.
+        setManagedMemoryWeight(coGrouped, 100);
+        return coGrouped;
+    }
+
@SuppressWarnings({"rawtypes", "unchecked"})
private static <IN, W extends Window> AllWindowedStream<IN, W>
getAllWindowedStream(
DataStream<IN> input, Windows windows) {
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/BytesKeyNormalizationUtil.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/BytesKeyNormalizationUtil.java
new file mode 100644
index 00000000..90be2616
--- /dev/null
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/BytesKeyNormalizationUtil.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream.sort;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+/**
+ * Utility class for common key normalization used both in {@link VariableLengthByteKeyComparator}
+ * and {@link FixedLengthByteKeyComparator}.
+ *
+ * <p>TODO: remove this class after making the corresponding class in Flink public.
+ */
+final class BytesKeyNormalizationUtil {
+
+    private BytesKeyNormalizationUtil() {}
+
+    /**
+     * Writes the normalized key of the given record into {@code target}. The normalized key is
+     * made of the serialized key bytes followed by the timestamp of the record, truncated or
+     * zero-padded to exactly {@code numBytes} bytes.
+     *
+     * <p>NOTE: The key does not represent a logical order. It can be used only for grouping keys!
+     *
+     * @param record the (serialized key, stream record) pair to normalize
+     * @param dataLength number of bytes of {@code record.f0} that make up the key
+     * @param target the memory segment to write to
+     * @param offset position in {@code target} at which the normalized key starts
+     * @param numBytes number of bytes reserved for the normalized key
+     */
+    static <IN> void putNormalizedKey(
+            Tuple2<byte[], StreamRecord<IN>> record,
+            int dataLength,
+            MemorySegment target,
+            int offset,
+            int numBytes) {
+        byte[] keyBytes = record.f0;
+
+        if (dataLength >= numBytes) {
+            // Only a prefix of the key fits; there is no room left for the timestamp.
+            putBytesArray(target, offset, numBytes, keyBytes);
+            return;
+        }
+
+        // The whole key fits into the normalized key.
+        putBytesArray(target, offset, dataLength, keyBytes);
+        final int end = offset + numBytes;
+        int cursor = offset + dataLength;
+        final long shiftedTimestamp = record.f1.asRecord().getTimestamp() - Long.MIN_VALUE;
+
+        if (dataLength + FixedLengthByteKeyComparator.TIMESTAMP_BYTE_SIZE > numBytes) {
+            // Only the leading (most significant) bytes of the timestamp fit.
+            for (int i = 0; cursor < end; cursor++, i++) {
+                target.put(cursor, (byte) (shiftedTimestamp >>> ((7 - i) << 3)));
+            }
+        } else {
+            // The whole timestamp fits; pad whatever remains with zeros.
+            target.putLong(cursor, shiftedTimestamp);
+            cursor += FixedLengthByteKeyComparator.TIMESTAMP_BYTE_SIZE;
+            for (; cursor < end; cursor++) {
+                target.put(cursor, (byte) 0);
+            }
+        }
+    }
+
+    /**
+     * Copies the first {@code numBytes} bytes of {@code data} into {@code target}, starting at
+     * {@code offset}, converting each byte on the way.
+     */
+    private static void putBytesArray(MemorySegment target, int offset, int numBytes, byte[] data) {
+        int position = offset;
+        for (int i = 0; i < numBytes; i++, position++) {
+            // The byte is widened to 0..255, shifted by 128 (Byte.MIN_VALUE is -128) and
+            // truncated back to a byte, which flips its sign bit. The normalized key sorter
+            // compares bytes as unsigned values, and after the flip that unsigned order matches
+            // the signed order of the original bytes.
+            int shifted = (data[i] & 0xff) - Byte.MIN_VALUE;
+            target.put(position, (byte) shifted);
+        }
+    }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/CoGroupOperator.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/CoGroupOperator.java
new file mode 100644
index 00000000..8c7e2f1f
--- /dev/null
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/CoGroupOperator.java
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream.sort;
+
+import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.functions.CoGroupFunction;
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.api.common.typeutils.TypePairComparator;
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import
org.apache.flink.api.java.typeutils.runtime.RuntimePairComparatorFactory;
+import org.apache.flink.configuration.AlgorithmOptions;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.core.memory.DataOutputSerializer;
+import org.apache.flink.core.memory.ManagedMemoryUseCase;
+import org.apache.flink.runtime.io.disk.iomanager.IOManager;
+import org.apache.flink.runtime.memory.MemoryAllocationException;
+import org.apache.flink.runtime.memory.MemoryManager;
+import org.apache.flink.runtime.operators.sort.ExternalSorter;
+import
org.apache.flink.runtime.operators.sort.NonReusingSortMergeCoGroupIterator;
+import org.apache.flink.runtime.operators.sort.PushSorter;
+import org.apache.flink.runtime.operators.sort.ReusingSortMergeCoGroupIterator;
+import org.apache.flink.runtime.operators.util.CoGroupTaskIterator;
+import org.apache.flink.streaming.api.graph.StreamConfig;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedMultiInput;
+import org.apache.flink.streaming.api.operators.Output;
+import org.apache.flink.streaming.api.operators.TimestampedCollector;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.streaming.runtime.tasks.StreamTask;
+import org.apache.flink.util.MutableObjectIterator;
+import org.apache.flink.util.TraversableOnceException;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * An operator that implements the co-group logic.
+ *
+ * <p>Each incoming record is paired with the serialized bytes of its key and written into a
+ * per-input {@link PushSorter}. Once both inputs have ended, the two sorted inputs are merged by a
+ * sort-merge co-group iterator and the user-defined {@link CoGroupFunction} is invoked for every
+ * key group the iterator produces. Nothing is emitted before both inputs have ended.
+ *
+ * @param <IN1> The class type of the first input.
+ * @param <IN2> The class type of the second input.
+ * @param <KEY> The class type of the key.
+ * @param <OUT> The class type of the output values.
+ */
+public class CoGroupOperator<IN1, IN2, KEY extends Serializable, OUT>
+        extends AbstractUdfStreamOperator<OUT, CoGroupFunction<IN1, IN2, OUT>>
+        implements TwoInputStreamOperator<IN1, IN2, OUT>, BoundedMultiInput {
+
+    private PushSorter<Tuple2<byte[], StreamRecord<IN1>>> sorterA;
+    private PushSorter<Tuple2<byte[], StreamRecord<IN2>>> sorterB;
+    private TypeComparator<Tuple2<byte[], StreamRecord<IN1>>> comparatorA;
+    private TypeComparator<Tuple2<byte[], StreamRecord<IN2>>> comparatorB;
+    private KeySelector<IN1, KEY> keySelectorA;
+    private KeySelector<IN2, KEY> keySelectorB;
+    private TypeSerializer<Tuple2<byte[], StreamRecord<IN1>>> keyAndValueSerializerA;
+    private TypeSerializer<Tuple2<byte[], StreamRecord<IN2>>> keyAndValueSerializerB;
+    private TypeSerializer<KEY> keySerializer;
+    private DataOutputSerializer dataOutputSerializer;
+    private long lastWatermarkTimestamp = Long.MIN_VALUE;
+    private int remainingInputNum = 2;
+
+    public CoGroupOperator(CoGroupFunction<IN1, IN2, OUT> function) {
+        super(function);
+    }
+
+    /**
+     * Reads key selectors and serializers from the stream config, picks the comparator matching
+     * the key serializer's length (fixed or variable) and creates one sorter per input. Each
+     * sorter receives half of the operator's managed memory fraction.
+     */
+    @Override
+    public void setup(
+            StreamTask<?, ?> containingTask,
+            StreamConfig config,
+            Output<StreamRecord<OUT>> output) {
+        super.setup(containingTask, config, output);
+        ClassLoader userCodeClassLoader = containingTask.getUserCodeClassLoader();
+
+        keySelectorA = config.getStatePartitioner(0, userCodeClassLoader);
+        keySelectorB = config.getStatePartitioner(1, userCodeClassLoader);
+        keySerializer = config.getStateKeySerializer(userCodeClassLoader);
+        int keyLength = keySerializer.getLength();
+
+        TypeSerializer<IN1> typeSerializerA = config.getTypeSerializerIn(0, userCodeClassLoader);
+        TypeSerializer<IN2> typeSerializerB = config.getTypeSerializerIn(1, userCodeClassLoader);
+        keyAndValueSerializerA = new KeyAndValueSerializer<>(typeSerializerA, keyLength);
+        keyAndValueSerializerB = new KeyAndValueSerializer<>(typeSerializerB, keyLength);
+
+        if (keyLength > 0) {
+            // Fixed-length keys: the buffer can be sized exactly.
+            dataOutputSerializer = new DataOutputSerializer(keyLength);
+            comparatorA = new FixedLengthByteKeyComparator<>(keyLength);
+            comparatorB = new FixedLengthByteKeyComparator<>(keyLength);
+        } else {
+            // Variable-length keys: start with a small buffer.
+            dataOutputSerializer = new DataOutputSerializer(64);
+            comparatorA = new VariableLengthByteKeyComparator<>();
+            comparatorB = new VariableLengthByteKeyComparator<>();
+        }
+
+        // The operator's managed memory is split evenly between the two sorters.
+        double managedMemoryFraction =
+                config.getManagedMemoryFractionOperatorUseCaseOfSlot(
+                                ManagedMemoryUseCase.OPERATOR,
+                                containingTask.getEnvironment().getTaskConfiguration(),
+                                userCodeClassLoader)
+                        / 2;
+
+        try {
+            sorterA =
+                    createSorter(
+                            containingTask,
+                            keyAndValueSerializerA,
+                            comparatorA,
+                            managedMemoryFraction);
+            sorterB =
+                    createSorter(
+                            containingTask,
+                            keyAndValueSerializerB,
+                            comparatorB,
+                            managedMemoryFraction);
+        } catch (MemoryAllocationException e) {
+            throw new RuntimeException(e);
+        }
+    }
+
+    /** Builds a spilling push sorter for one input, configured from the job configuration. */
+    private static <T> PushSorter<Tuple2<byte[], StreamRecord<T>>> createSorter(
+            StreamTask<?, ?> containingTask,
+            TypeSerializer<Tuple2<byte[], StreamRecord<T>>> serializer,
+            TypeComparator<Tuple2<byte[], StreamRecord<T>>> comparator,
+            double memoryFraction)
+            throws MemoryAllocationException {
+        MemoryManager memoryManager = containingTask.getEnvironment().getMemoryManager();
+        IOManager ioManager = containingTask.getEnvironment().getIOManager();
+        ExecutionConfig executionConfig = containingTask.getEnvironment().getExecutionConfig();
+        Configuration jobConfiguration = containingTask.getEnvironment().getJobConfiguration();
+
+        return ExternalSorter.newBuilder(
+                        memoryManager, containingTask, serializer, comparator, executionConfig)
+                .memoryFraction(memoryFraction)
+                .enableSpilling(
+                        ioManager, jobConfiguration.get(AlgorithmOptions.SORT_SPILLING_THRESHOLD))
+                .maxNumFileHandles(jobConfiguration.get(AlgorithmOptions.SPILLING_MAX_FAN))
+                .objectReuse(executionConfig.isObjectReuseEnabled())
+                .largeRecords(jobConfiguration.get(AlgorithmOptions.USE_LARGE_RECORDS_HANDLER))
+                .build();
+    }
+
+    /**
+     * Marks the given input as finished. When the second of the two inputs finishes, the sorted
+     * inputs are co-grouped, the user function is invoked for each group and the last received
+     * watermark is emitted.
+     *
+     * @param inputId 1 for the first input, 2 for the second input
+     * @throws RuntimeException if {@code inputId} is neither 1 nor 2
+     */
+    @Override
+    public void endInput(int inputId) throws Exception {
+        if (inputId == 1) {
+            sorterA.finishReading();
+        } else if (inputId == 2) {
+            sorterB.finishReading();
+        } else {
+            throw new RuntimeException("Unknown inputId " + inputId);
+        }
+
+        remainingInputNum--;
+        if (remainingInputNum > 0) {
+            // The other input is still running; wait for it.
+            return;
+        }
+
+        MutableObjectIterator<Tuple2<byte[], StreamRecord<IN1>>> iteratorA = sorterA.getIterator();
+        MutableObjectIterator<Tuple2<byte[], StreamRecord<IN2>>> iteratorB = sorterB.getIterator();
+        TypePairComparator<Tuple2<byte[], StreamRecord<IN1>>, Tuple2<byte[], StreamRecord<IN2>>>
+                pairComparator =
+                        (new RuntimePairComparatorFactory<
+                                        Tuple2<byte[], StreamRecord<IN1>>,
+                                        Tuple2<byte[], StreamRecord<IN2>>>())
+                                .createComparator12(comparatorA, comparatorB);
+
+        CoGroupTaskIterator<Tuple2<byte[], StreamRecord<IN1>>, Tuple2<byte[], StreamRecord<IN2>>>
+                coGroupIterator;
+        if (getExecutionConfig().isObjectReuseEnabled()) {
+            coGroupIterator =
+                    new ReusingSortMergeCoGroupIterator<
+                            Tuple2<byte[], StreamRecord<IN1>>, Tuple2<byte[], StreamRecord<IN2>>>(
+                            iteratorA,
+                            iteratorB,
+                            keyAndValueSerializerA,
+                            comparatorA,
+                            keyAndValueSerializerB,
+                            comparatorB,
+                            pairComparator);
+        } else {
+            coGroupIterator =
+                    new NonReusingSortMergeCoGroupIterator<
+                            Tuple2<byte[], StreamRecord<IN1>>, Tuple2<byte[], StreamRecord<IN2>>>(
+                            iteratorA,
+                            iteratorB,
+                            keyAndValueSerializerA,
+                            comparatorA,
+                            keyAndValueSerializerB,
+                            comparatorB,
+                            pairComparator);
+        }
+
+        coGroupIterator.open();
+        try {
+            TupleUnwrappingIterator<IN1, byte[]> unwrappedIteratorA =
+                    new TupleUnwrappingIterator<>();
+            TupleUnwrappingIterator<IN2, byte[]> unwrappedIteratorB =
+                    new TupleUnwrappingIterator<>();
+            TimestampedCollector<OUT> timestampedCollector = new TimestampedCollector<>(output);
+
+            while (coGroupIterator.next()) {
+                unwrappedIteratorA.set(coGroupIterator.getValues1().iterator());
+                unwrappedIteratorB.set(coGroupIterator.getValues2().iterator());
+                userFunction.coGroup(unwrappedIteratorA, unwrappedIteratorB, timestampedCollector);
+            }
+        } finally {
+            // Close the iterator even if the user function throws.
+            coGroupIterator.close();
+        }
+
+        // Watermarks were held back while buffering; forward the latest one now.
+        Watermark watermark = new Watermark(lastWatermarkTimestamp);
+        if (getTimeServiceManager().isPresent()) {
+            getTimeServiceManager().get().advanceWatermark(watermark);
+        }
+        output.emitWatermark(watermark);
+    }
+
+    /**
+     * Remembers the timestamp of the given watermark instead of forwarding it; the remembered
+     * watermark is emitted once both inputs have ended.
+     *
+     * @throws RuntimeException if the watermark is older than the previously seen one
+     */
+    @Override
+    public void processWatermark(Watermark watermark) throws Exception {
+        if (lastWatermarkTimestamp > watermark.getTimestamp()) {
+            throw new RuntimeException("Invalid watermark");
+        }
+        lastWatermarkTimestamp = watermark.getTimestamp();
+    }
+
+    /**
+     * Closes the operator and both sorters. Each sorter is closed even if closing the superclass
+     * or the other sorter fails, and sorters that were never created are skipped.
+     */
+    @Override
+    public void close() throws Exception {
+        try {
+            super.close();
+        } finally {
+            try {
+                if (sorterA != null) {
+                    sorterA.close();
+                }
+            } finally {
+                if (sorterB != null) {
+                    sorterB.close();
+                }
+            }
+        }
+    }
+
+    @Override
+    public void processElement1(StreamRecord<IN1> streamRecord) throws Exception {
+        byte[] serializedKey = serializeKey(keySelectorA.getKey(streamRecord.getValue()));
+        sorterA.writeRecord(Tuple2.of(serializedKey, streamRecord));
+    }
+
+    @Override
+    public void processElement2(StreamRecord<IN2> streamRecord) throws Exception {
+        byte[] serializedKey = serializeKey(keySelectorB.getKey(streamRecord.getValue()));
+        sorterB.writeRecord(Tuple2.of(serializedKey, streamRecord));
+    }
+
+    /** Serializes the key into the shared buffer and returns a copy of the written bytes. */
+    private byte[] serializeKey(KEY key) throws Exception {
+        keySerializer.serialize(key, dataOutputSerializer);
+        byte[] serializedKey = dataOutputSerializer.getCopyOfBuffer();
+        // Reset the shared buffer so the next key starts at position zero.
+        dataOutputSerializer.clear();
+        return serializedKey;
+    }
+
+    /**
+     * An iterator over (key, stream record) tuples that exposes only the record values. It also
+     * acts as its own {@link Iterable}, which can be traversed once per call to {@link
+     * #set(Iterator)}.
+     */
+    private static class TupleUnwrappingIterator<T, K>
+            implements Iterator<T>, Iterable<T>, java.io.Serializable {
+
+        private static final long serialVersionUID = 1L;
+
+        private K lastKey;
+        private Iterator<Tuple2<K, StreamRecord<T>>> iterator;
+        private boolean iteratorAvailable;
+
+        /** Installs a new underlying iterator and allows one more traversal. */
+        public void set(Iterator<Tuple2<K, StreamRecord<T>>> iterator) {
+            this.iterator = iterator;
+            this.iteratorAvailable = true;
+        }
+
+        /** Returns the key of the element most recently returned by {@link #next()}. */
+        public K getLastKey() {
+            return lastKey;
+        }
+
+        @Override
+        public boolean hasNext() {
+            return iterator.hasNext();
+        }
+
+        @Override
+        public T next() {
+            Tuple2<K, StreamRecord<T>> t = iterator.next();
+            this.lastKey = t.f0;
+            return t.f1.getValue();
+        }
+
+        @Override
+        public void remove() {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public Iterator<T> iterator() {
+            if (iteratorAvailable) {
+                iteratorAvailable = false;
+                return this;
+            } else {
+                throw new TraversableOnceException();
+            }
+        }
+    }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/FixedLengthByteKeyComparator.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/FixedLengthByteKeyComparator.java
new file mode 100644
index 00000000..859e193f
--- /dev/null
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/FixedLengthByteKeyComparator.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream.sort;
+
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.streaming.api.operators.sort.SortingDataInput;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * A comparator used in {@link SortingDataInput} which compares records keys and timestamps. It
+ * uses binary format produced by the {@link KeyAndValueSerializer}.
+ *
+ * <p>It assumes keys are always of a fixed length and thus the length of the key is not
+ * serialized.
+ *
+ * <p>TODO: remove this class after making the corresponding class in Flink public.
+ */
+public final class FixedLengthByteKeyComparator<IN>
+        extends TypeComparator<Tuple2<byte[], StreamRecord<IN>>> {
+    static final int TIMESTAMP_BYTE_SIZE = 8;
+    private final int keyLength;
+    private byte[] keyReference;
+    private long timestampReference;
+
+    /**
+     * Creates a comparator for serialized keys of the given length.
+     *
+     * @param keyLength the number of key bytes that take part in every comparison
+     */
+    public FixedLengthByteKeyComparator(int keyLength) {
+        this.keyLength = keyLength;
+    }
+
+    /**
+     * Hashes the fields that take part in the comparison: the contents of the serialized key and
+     * the record timestamp. Records that {@link #equalToReference} treats as equal therefore get
+     * the same hash value.
+     */
+    @Override
+    public int hash(Tuple2<byte[], StreamRecord<IN>> record) {
+        return 31 * Arrays.hashCode(record.f0)
+                + Long.hashCode(record.f1.asRecord().getTimestamp());
+    }
+
+    @Override
+    public void setReference(Tuple2<byte[], StreamRecord<IN>> toCompare) {
+        this.keyReference = toCompare.f0;
+        this.timestampReference = toCompare.f1.asRecord().getTimestamp();
+    }
+
+    @Override
+    public boolean equalToReference(Tuple2<byte[], StreamRecord<IN>> candidate) {
+        return Arrays.equals(keyReference, candidate.f0)
+                && timestampReference == candidate.f1.asRecord().getTimestamp();
+    }
+
+    /**
+     * Compares the reference of the given comparator with the reference of this comparator, by
+     * key first and by timestamp second. The result is negative when the other comparator's
+     * reference sorts before this comparator's reference.
+     */
+    @Override
+    public int compareToReference(
+            TypeComparator<Tuple2<byte[], StreamRecord<IN>>> referencedComparator) {
+        FixedLengthByteKeyComparator<IN> other =
+                (FixedLengthByteKeyComparator<IN>) referencedComparator;
+
+        int keyCmp = compare(other.keyReference, this.keyReference);
+        if (keyCmp != 0) {
+            return keyCmp;
+        }
+        return Long.compare(other.timestampReference, this.timestampReference);
+    }
+
+    @Override
+    public int compare(
+            Tuple2<byte[], StreamRecord<IN>> first, Tuple2<byte[], StreamRecord<IN>> second) {
+        int keyCmp = compare(first.f0, second.f0);
+        if (keyCmp != 0) {
+            return keyCmp;
+        }
+        return Long.compare(
+                first.f1.asRecord().getTimestamp(), second.f1.asRecord().getTimestamp());
+    }
+
+    /** Compares the first {@code keyLength} bytes of both arrays as signed bytes. */
+    private int compare(byte[] first, byte[] second) {
+        for (int i = 0; i < keyLength; i++) {
+            int cmp = Byte.compare(first[i], second[i]);
+
+            if (cmp != 0) {
+                // Normalize the result to -1 / 1.
+                return cmp < 0 ? -1 : 1;
+            }
+        }
+
+        return 0;
+    }
+
+    /**
+     * Compares two serialized records by reading {@code keyLength} key bytes from each source,
+     * and, if all of them are equal, the long value that follows the key in each source.
+     */
+    @Override
+    public int compareSerialized(DataInputView firstSource, DataInputView secondSource)
+            throws IOException {
+        for (int remaining = keyLength; remaining > 0; remaining--) {
+            byte firstValue = firstSource.readByte();
+            byte secondValue = secondSource.readByte();
+
+            int cmp = Byte.compare(firstValue, secondValue);
+            if (cmp != 0) {
+                return cmp < 0 ? -1 : 1;
+            }
+        }
+
+        return Long.compare(firstSource.readLong(), secondSource.readLong());
+    }
+
+    @Override
+    public boolean supportsNormalizedKey() {
+        return true;
+    }
+
+    @Override
+    public int getNormalizeKeyLen() {
+        return keyLength + TIMESTAMP_BYTE_SIZE;
+    }
+
+    @Override
+    public boolean isNormalizedKeyPrefixOnly(int keyBytes) {
+        return keyBytes < getNormalizeKeyLen();
+    }
+
+    @Override
+    public void putNormalizedKey(
+            Tuple2<byte[], StreamRecord<IN>> record,
+            MemorySegment target,
+            int offset,
+            int numBytes) {
+        BytesKeyNormalizationUtil.putNormalizedKey(record, keyLength, target, offset, numBytes);
+    }
+
+    @Override
+    public boolean invertNormalizedKey() {
+        return false;
+    }
+
+    /** Returns a new comparator with the same key length; the reference is not copied. */
+    @Override
+    public TypeComparator<Tuple2<byte[], StreamRecord<IN>>> duplicate() {
+        return new FixedLengthByteKeyComparator<>(this.keyLength);
+    }
+
+    @Override
+    public int extractKeys(Object record, Object[] target, int index) {
+        target[index] = record;
+        return 1;
+    }
+
+    @Override
+    public TypeComparator<?>[] getFlatComparators() {
+        return new TypeComparator[] {this};
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // unsupported normalization
+    // --------------------------------------------------------------------------------------------
+
+    @Override
+    public boolean supportsSerializationWithKeyNormalization() {
+        return false;
+    }
+
+    @Override
+    public void writeWithKeyNormalization(
+            Tuple2<byte[], StreamRecord<IN>> record, DataOutputView target) throws IOException {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public Tuple2<byte[], StreamRecord<IN>> readWithKeyDenormalization(
+            Tuple2<byte[], StreamRecord<IN>> reuse, DataInputView source) throws IOException {
+        throw new UnsupportedOperationException();
+    }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/KeyAndValueSerializer.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/KeyAndValueSerializer.java
new file mode 100644
index 00000000..f2d82be8
--- /dev/null
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/KeyAndValueSerializer.java
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream.sort;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.common.typeutils.TypeSerializerSnapshot;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.streaming.api.operators.sort.SortingDataInput;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Objects;
+
+/**
+ * A serializer used in {@link SortingDataInput} for serializing elements alongside their key and
+ * timestamp. It serializes the record in a format known by the {@link FixedLengthByteKeyComparator}
+ * and {@link VariableLengthByteKeyComparator}.
+ *
+ * <p>The serialized data is laid out as follows:
+ *
+ * <pre>
+ * [key-length] | &lt;key&gt; | &lt;timestamp&gt; | &lt;record&gt;
+ * </pre>
+ *
+ * <p>The {@code key-length} field (an {@code int}) is written only for variable-length keys, i.e.
+ * when {@code serializedKeyLength} is negative. If the key is of known constant length, the length
+ * is not serialized with the data.
+ *
+ * <p>TODO: remove this class after making the corresponding class in Flink public.
+ */
+public final class KeyAndValueSerializer<IN>
+        extends TypeSerializer<Tuple2<byte[], StreamRecord<IN>>> {
+    /** Size in bytes of the timestamp, which is written with {@code writeLong}. */
+    private static final int TIMESTAMP_LENGTH = 8;
+
+    private final TypeSerializer<IN> valueSerializer;
+
+    // This represents either a variable length (-1) or a fixed one (>= 0).
+    private final int serializedKeyLength;
+
+    /**
+     * @param valueSerializer serializer for the value carried by each {@link StreamRecord}
+     * @param serializedKeyLength length of the key in bytes, or a negative number if the key
+     *     length varies and therefore has to be written in front of every key
+     */
+    public KeyAndValueSerializer(TypeSerializer<IN> valueSerializer, int serializedKeyLength) {
+        this.valueSerializer = valueSerializer;
+        this.serializedKeyLength = serializedKeyLength;
+    }
+
+    @Override
+    public boolean isImmutableType() {
+        return false;
+    }
+
+    /** Creates a serializer with a duplicated value serializer and the same key length. */
+    @Override
+    public TypeSerializer<Tuple2<byte[], StreamRecord<IN>>> duplicate() {
+        return new KeyAndValueSerializer<>(valueSerializer.duplicate(), this.serializedKeyLength);
+    }
+
+    /** Returns a deep copy: the key array is cloned and the value is copied by its serializer. */
+    @Override
+    public Tuple2<byte[], StreamRecord<IN>> copy(Tuple2<byte[], StreamRecord<IN>> from) {
+        StreamRecord<IN> fromRecord = from.f1;
+        return Tuple2.of(
+                Arrays.copyOf(from.f0, from.f0.length),
+                fromRecord.copy(valueSerializer.copy(fromRecord.getValue())));
+    }
+
+    /** Creates an instance with an empty key and a value from the value serializer. */
+    @Override
+    public Tuple2<byte[], StreamRecord<IN>> createInstance() {
+        return Tuple2.of(new byte[0], new StreamRecord<>(valueSerializer.createInstance()));
+    }
+
+    /**
+     * Copies {@code from} into {@code reuse}. The {@link StreamRecord} held by {@code reuse} is
+     * updated in place; the key array is always freshly allocated.
+     */
+    @Override
+    public Tuple2<byte[], StreamRecord<IN>> copy(
+            Tuple2<byte[], StreamRecord<IN>> from, Tuple2<byte[], StreamRecord<IN>> reuse) {
+        StreamRecord<IN> fromRecord = from.f1;
+        StreamRecord<IN> reuseRecord = reuse.f1;
+
+        IN valueCopy = valueSerializer.copy(fromRecord.getValue(), reuseRecord.getValue());
+        fromRecord.copyTo(valueCopy, reuseRecord);
+        reuse.f0 = Arrays.copyOf(from.f0, from.f0.length);
+        return reuse;
+    }
+
+    /**
+     * Returns the fixed serialized length, or {@code -1} if either the key or the value has a
+     * variable length.
+     */
+    @Override
+    public int getLength() {
+        if (valueSerializer.getLength() < 0 || serializedKeyLength < 0) {
+            return -1;
+        }
+        return valueSerializer.getLength() + serializedKeyLength + TIMESTAMP_LENGTH;
+    }
+
+    /**
+     * Writes the key length (variable-length keys only), the key bytes, the timestamp and finally
+     * the value.
+     */
+    @Override
+    public void serialize(Tuple2<byte[], StreamRecord<IN>> record, DataOutputView target)
+            throws IOException {
+        if (serializedKeyLength < 0) {
+            target.writeInt(record.f0.length);
+        }
+        target.write(record.f0);
+        StreamRecord<IN> toSerialize = record.f1;
+        target.writeLong(toSerialize.getTimestamp());
+        valueSerializer.serialize(toSerialize.getValue(), target);
+    }
+
+    /** Reads one element in the layout written by {@link #serialize}. */
+    @Override
+    public Tuple2<byte[], StreamRecord<IN>> deserialize(DataInputView source) throws IOException {
+        final int length = getKeyLength(source);
+        byte[] bytes = new byte[length];
+        // readFully either fills the whole key or fails; read(byte[]) may return after a partial
+        // read, which would corrupt the key and misalign the timestamp/value that follow.
+        source.readFully(bytes);
+        long timestamp = source.readLong();
+        IN value = valueSerializer.deserialize(source);
+        return Tuple2.of(bytes, new StreamRecord<>(value, timestamp));
+    }
+
+    /**
+     * Reads one element in the layout written by {@link #serialize}, storing the result in {@code
+     * reuse}. The {@link StreamRecord} held by {@code reuse} is updated in place.
+     */
+    @Override
+    public Tuple2<byte[], StreamRecord<IN>> deserialize(
+            Tuple2<byte[], StreamRecord<IN>> reuse, DataInputView source) throws IOException {
+        final int length = getKeyLength(source);
+        byte[] bytes = new byte[length];
+        // See deserialize(DataInputView): the key must be read completely.
+        source.readFully(bytes);
+        long timestamp = source.readLong();
+        IN value = valueSerializer.deserialize(source);
+        reuse.f1.replace(value, timestamp);
+        reuse.f0 = bytes;
+        return reuse;
+    }
+
+    /** Returns the fixed key length, or reads the length prefix for variable-length keys. */
+    private int getKeyLength(DataInputView source) throws IOException {
+        return serializedKeyLength < 0 ? source.readInt() : serializedKeyLength;
+    }
+
+    /** Copies one serialized element from {@code source} to {@code target}. */
+    @Override
+    public void copy(DataInputView source, DataOutputView target) throws IOException {
+        final int length = getKeyLength(source);
+        if (serializedKeyLength < 0) {
+            // Variable-length keys carry their length in front of the key bytes.
+            target.writeInt(length);
+        }
+        // Bulk copy of the key bytes instead of one readByte/writeByte pair per byte.
+        target.write(source, length);
+        target.writeLong(source.readLong());
+        valueSerializer.copy(source, target);
+    }
+
+    /**
+     * Two serializers are equal only if both the value serializer and the key length agree; a
+     * different key length means a different binary layout.
+     */
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        KeyAndValueSerializer<?> that = (KeyAndValueSerializer<?>) o;
+        return serializedKeyLength == that.serializedKeyLength
+                && Objects.equals(valueSerializer, that.valueSerializer);
+    }
+
+    @Override
+    public int hashCode() {
+        return Objects.hash(valueSerializer, serializedKeyLength);
+    }
+
+    /**
+     * Not supported.
+     *
+     * @throws UnsupportedOperationException always
+     */
+    @Override
+    public TypeSerializerSnapshot<Tuple2<byte[], StreamRecord<IN>>> snapshotConfiguration() {
+        throw new UnsupportedOperationException(
+                "The KeyAndValueSerializer should not be used for persisting into State!");
+    }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/VariableLengthByteKeyComparator.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/VariableLengthByteKeyComparator.java
new file mode 100644
index 00000000..7ae89caa
--- /dev/null
+++
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/sort/VariableLengthByteKeyComparator.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream.sort;
+
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.MemorySegment;
+import org.apache.flink.streaming.api.operators.sort.SortingDataInput;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import java.io.IOException;
+import java.util.Arrays;
+
+/**
+ * A comparator used in {@link SortingDataInput} which compares record keys and timestamps. It
+ * uses the binary format produced by the {@link KeyAndValueSerializer}.
+ *
+ * <p>It assumes keys are of a variable length and thus expects the length of the key to be
+ * serialized in front of the key bytes.
+ *
+ * <p>TODO: remove this class after making the corresponding class in Flink public.
+ */
+public final class VariableLengthByteKeyComparator<IN>
+        extends TypeComparator<Tuple2<byte[], StreamRecord<IN>>> {
+    // Key and timestamp captured by setReference().
+    private byte[] keyReference;
+    private long timestampReference;
+
+    /**
+     * Hashes exactly the fields that take part in the comparison: the content of the key bytes
+     * and the timestamp. The array's identity hash and the record value are deliberately left
+     * out, so records that compare as equal also hash equally.
+     */
+    @Override
+    public int hash(Tuple2<byte[], StreamRecord<IN>> record) {
+        return 31 * Arrays.hashCode(record.f0)
+                + Long.hashCode(record.f1.asRecord().getTimestamp());
+    }
+
+    /** Stores a copy of the key and the timestamp of {@code toCompare} as the reference. */
+    @Override
+    public void setReference(Tuple2<byte[], StreamRecord<IN>> toCompare) {
+        this.keyReference = Arrays.copyOf(toCompare.f0, toCompare.f0.length);
+        this.timestampReference = toCompare.f1.asRecord().getTimestamp();
+    }
+
+    /** Checks whether key bytes and timestamp of {@code candidate} match the reference. */
+    @Override
+    public boolean equalToReference(Tuple2<byte[], StreamRecord<IN>> candidate) {
+        return Arrays.equals(keyReference, candidate.f0)
+                && timestampReference == candidate.f1.asRecord().getTimestamp();
+    }
+
+    /**
+     * Compares the reference of {@code referencedComparator} (first operand) with the reference
+     * of this comparator (second operand), by key first and by timestamp second.
+     */
+    @Override
+    public int compareToReference(
+            TypeComparator<Tuple2<byte[], StreamRecord<IN>>> referencedComparator) {
+        VariableLengthByteKeyComparator<IN> other =
+                (VariableLengthByteKeyComparator<IN>) referencedComparator;
+
+        int keyCmp = compare(other.keyReference, this.keyReference);
+        if (keyCmp != 0) {
+            return keyCmp;
+        }
+        return Long.compare(other.timestampReference, this.timestampReference);
+    }
+
+    /** Orders two records by key bytes first and by timestamp second. */
+    @Override
+    public int compare(
+            Tuple2<byte[], StreamRecord<IN>> first, Tuple2<byte[], StreamRecord<IN>> second) {
+        int keyCmp = compare(first.f0, second.f0);
+        if (keyCmp != 0) {
+            return keyCmp;
+        }
+        return Long.compare(
+                first.f1.asRecord().getTimestamp(), second.f1.asRecord().getTimestamp());
+    }
+
+    /**
+     * Compares two keys byte by byte (signed comparison) over their common prefix; if the prefix
+     * is identical the shorter key sorts first.
+     */
+    private int compare(byte[] first, byte[] second) {
+        int firstLength = first.length;
+        int secondLength = second.length;
+        int minLength = Math.min(firstLength, secondLength);
+        for (int i = 0; i < minLength; i++) {
+            int cmp = Byte.compare(first[i], second[i]);
+
+            if (cmp != 0) {
+                return cmp;
+            }
+        }
+
+        return Integer.compare(firstLength, secondLength);
+    }
+
+    /**
+     * Compares two serialized records using the same ordering as {@link #compare(Tuple2, Tuple2)}.
+     * Each source must be positioned at the {@code int} key length. The method returns as soon as
+     * the order is decided, so the sources are generally left in the middle of a record.
+     */
+    @Override
+    public int compareSerialized(DataInputView firstSource, DataInputView secondSource)
+            throws IOException {
+        int firstLength = firstSource.readInt();
+        int secondLength = secondSource.readInt();
+        int minLength = Math.min(firstLength, secondLength);
+        while (minLength-- > 0) {
+            byte firstValue = firstSource.readByte();
+            byte secondValue = secondSource.readByte();
+
+            int cmp = Byte.compare(firstValue, secondValue);
+            if (cmp != 0) {
+                return cmp;
+            }
+        }
+
+        int lengthCompare = Integer.compare(firstLength, secondLength);
+        if (lengthCompare != 0) {
+            return lengthCompare;
+        } else {
+            // Equal lengths: both sources now point at the timestamp.
+            return Long.compare(firstSource.readLong(), secondSource.readLong());
+        }
+    }
+
+    @Override
+    public boolean supportsNormalizedKey() {
+        return true;
+    }
+
+    @Override
+    public int getNormalizeKeyLen() {
+        // The key length is not bounded, so no finite normalized-key length can be reported.
+        return Integer.MAX_VALUE;
+    }
+
+    @Override
+    public boolean isNormalizedKeyPrefixOnly(int keyBytes) {
+        return true;
+    }
+
+    /**
+     * Delegates to {@link BytesKeyNormalizationUtil}, using the actual length of the record's key
+     * as the key length.
+     */
+    @Override
+    public void putNormalizedKey(
+            Tuple2<byte[], StreamRecord<IN>> record,
+            MemorySegment target,
+            int offset,
+            int numBytes) {
+        BytesKeyNormalizationUtil.putNormalizedKey(
+                record, record.f0.length, target, offset, numBytes);
+    }
+
+    @Override
+    public boolean invertNormalizedKey() {
+        return false;
+    }
+
+    /** Returns a new comparator; the reference key and timestamp are not carried over. */
+    @Override
+    public TypeComparator<Tuple2<byte[], StreamRecord<IN>>> duplicate() {
+        return new VariableLengthByteKeyComparator<>();
+    }
+
+    @Override
+    public int extractKeys(Object record, Object[] target, int index) {
+        target[index] = record;
+        return 1;
+    }
+
+    @Override
+    public TypeComparator<?>[] getFlatComparators() {
+        return new TypeComparator[] {this};
+    }
+
+    // --------------------------------------------------------------------------------------------
+    // unsupported normalization
+    // --------------------------------------------------------------------------------------------
+
+    @Override
+    public boolean supportsSerializationWithKeyNormalization() {
+        return false;
+    }
+
+    /**
+     * Not supported.
+     *
+     * @throws UnsupportedOperationException always
+     */
+    @Override
+    public void writeWithKeyNormalization(
+            Tuple2<byte[], StreamRecord<IN>> record, DataOutputView target) throws IOException {
+        throw new UnsupportedOperationException(
+                "VariableLengthByteKeyComparator does not support serialization with key normalization.");
+    }
+
+    /**
+     * Not supported.
+     *
+     * @throws UnsupportedOperationException always
+     */
+    @Override
+    public Tuple2<byte[], StreamRecord<IN>> readWithKeyDenormalization(
+            Tuple2<byte[], StreamRecord<IN>> reuse, DataInputView source) throws IOException {
+        throw new UnsupportedOperationException(
+                "VariableLengthByteKeyComparator does not support deserialization with key denormalization.");
+    }
+}
diff --git
a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
index 7b3e8b3a..e72482f5 100644
---
a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
+++
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
@@ -19,10 +19,14 @@
package org.apache.flink.ml.common.datastream;
import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.CoGroupFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -33,6 +37,7 @@ import org.apache.commons.collections.IteratorUtils;
import org.junit.Before;
import org.junit.Test;
+import java.util.Arrays;
import java.util.List;
import static org.junit.Assert.assertArrayEquals;
@@ -48,6 +53,95 @@ public class DataStreamUtilsTest {
env = TestUtils.getExecutionEnvironment();
}
+ @Test
+ public void testCoGroupWithSingleParallelism() throws Exception {
+     // First input: (key, int) pairs with keys 1, 2 and 3.
+     DataStream<Tuple2<Integer, Integer>> data1 =
+             env.fromCollection(
+                     Arrays.asList(Tuple2.of(1, 1), Tuple2.of(2, 2), Tuple2.of(3, 3)));
+     // Second input: (key, double) pairs; key 1 occurs twice and key 5 has no partner in data1.
+     DataStream<Tuple2<Integer, Double>> data2 =
+             env.fromCollection(
+                     Arrays.asList(
+                             Tuple2.of(1, 1.5),
+                             Tuple2.of(5, 5.5),
+                             Tuple2.of(3, 3.5),
+                             Tuple2.of(1, 2.5)));
+
+     // Emits, for every key, the sum of all second fields found on both sides.
+     CoGroupFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Double>, Double> sumPerKey =
+             new CoGroupFunction<
+                     Tuple2<Integer, Integer>, Tuple2<Integer, Double>, Double>() {
+                 @Override
+                 public void coGroup(
+                         Iterable<Tuple2<Integer, Integer>> iterableA,
+                         Iterable<Tuple2<Integer, Double>> iterableB,
+                         Collector<Double> collector) {
+                     double total = 0;
+                     for (Tuple2<Integer, Integer> left : iterableA) {
+                         total += left.f1;
+                     }
+                     for (Tuple2<Integer, Double> right : iterableB) {
+                         total += right.f1;
+                     }
+                     collector.collect(total);
+                 }
+             };
+
+     DataStream<Double> result =
+             DataStreamUtils.coGroup(
+                     data1,
+                     data2,
+                     (KeySelector<Tuple2<Integer, Integer>, Integer>) tuple -> tuple.f0,
+                     (KeySelector<Tuple2<Integer, Double>, Integer>) tuple -> tuple.f0,
+                     BasicTypeInfo.DOUBLE_TYPE_INFO,
+                     sumPerKey);
+
+     List<Double> resultValues = IteratorUtils.toList(result.executeAndCollect());
+     double[] actual = new double[resultValues.size()];
+     for (int i = 0; i < actual.length; i++) {
+         actual[i] = resultValues.get(i);
+     }
+     // Per-key sums in the order the results are expected to arrive (keys 1, 2, 3, 5).
+     double[] expected = new double[] {5.0, 2.0, 6.5, 5.5};
+     assertArrayEquals(expected, actual, 1e-5);
+ }
+
+ @Test
+ public void testCoGroupWithMultiParallelism() throws Exception {
+     // Two parallel number sequences, 0..10 and 6..16, which overlap on 6..10.
+     DataStream<Long> data1 =
+             env.fromParallelCollection(new NumberSequenceIterator(0L, 10L), Types.LONG);
+     DataStream<Long> data2 =
+             env.fromParallelCollection(new NumberSequenceIterator(6L, 16L), Types.LONG);
+
+     // Emits, for every key, the sum of all values found on both sides.
+     CoGroupFunction<Long, Long, Long> sumPerKey =
+             new CoGroupFunction<Long, Long, Long>() {
+                 @Override
+                 public void coGroup(
+                         Iterable<Long> iterableA,
+                         Iterable<Long> iterableB,
+                         Collector<Long> collector) {
+                     long total = 0;
+                     for (Long left : iterableA) {
+                         total += left;
+                     }
+                     for (Long right : iterableB) {
+                         total += right;
+                     }
+                     collector.collect(total);
+                 }
+             };
+
+     // Both inputs are keyed by value / 2, so consecutive pairs of numbers share a key.
+     DataStream<Long> result =
+             DataStreamUtils.coGroup(
+                     data1,
+                     data2,
+                     (KeySelector<Long, Long>) v -> v / 2,
+                     (KeySelector<Long, Long>) v -> v / 2,
+                     BasicTypeInfo.LONG_TYPE_INFO,
+                     sumPerKey);
+
+     List<Long> resultValues = IteratorUtils.toList(result.executeAndCollect());
+     // Sort before comparing: the arrival order of the results is not asserted here.
+     long[] actual = resultValues.stream().mapToLong(Long::longValue).sorted().toArray();
+     long[] expected = new long[] {1, 5, 9, 16, 25, 26, 29, 31, 34};
+     assertArrayEquals(expected, actual);
+ }
+
@Test
public void testMapPartition() throws Exception {
DataStream<Long> dataStream =
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
index 56330742..6fbf39d5 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
@@ -160,7 +160,7 @@ public class KMeans implements Estimator<KMeans,
KMeansModel>, KMeansParams<KMea
DenseVectorTypeInfo.INSTANCE)),
new
CentroidsUpdateAccumulator(distanceMeasure));
-
DataStreamUtils.setManagedMemoryWeight(centroidIdAndPoints.getTransformation(),
100);
+ DataStreamUtils.setManagedMemoryWeight(centroidIdAndPoints, 100);
int parallelism = centroidIdAndPoints.getParallelism();
DataStream<KMeansModelData> newModelData =