This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new eeccb82 [FLINK-24845] Add allreduce utility function in FlinkML
eeccb82 is described below
commit eeccb82129dd666070a9c1c220f2f8de3b0e5aec
Author: zhangzp <[email protected]>
AuthorDate: Wed Nov 17 14:31:27 2021 +0800
[FLINK-24845] Add allreduce utility function in FlinkML
This closes #30.
---
.../flink/ml/common/datastream/AllReduceImpl.java | 300 +++++++++++++++++++++
.../ml/common/datastream/DataStreamUtils.java | 41 +++
.../ml/common/datastream/AllReduceImplTest.java | 165 ++++++++++++
3 files changed, 506 insertions(+)
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java
new file mode 100644
index 0000000..b8571a0
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Applies all-reduce on a data stream where each partition contains only one double array.
+ *
+ * <p>AllReduce is a communication primitive widely used in MPI. In this implementation, every
+ * worker reduces one part of the whole data and all workers get the final reduce result. In
+ * detail, we split each double array into chunks of a fixed-size buffer (32KB by default) and
+ * let each subtask handle several chunks.
+ *
+ * <p>There are three main stages:
+ *
+ * <ol>
+ *   <li>All workers send their partial data to other workers for reduce.
+ *   <li>Each worker reduces all the data it received and then broadcasts its partial results to
+ *       the others.
+ *   <li>All workers merge the partial results into the final result.
+ * </ol>
+ */
+class AllReduceImpl {
+
+ @VisibleForTesting static final int CHUNK_SIZE = 1024 * 4;
+
+ /**
+ * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
+ * one double array in each worker. The result data stream has the same parallelism as the
+ * input, where each worker contains one double array that sums all of the double arrays in the
+ * input data stream.
+ *
+ * <p>The pipeline consists of three operators, each set to the parallelism of the input:
+ *
+ * <ul>
+ *   <li>"all-reduce-send" splits the array of each worker into chunks.
+ *   <li>"all-reduce-sum" receives the chunks partitioned by {@code chunkId % numPartitions}, so
+ *       that all copies of one chunk are summed by the same subtask.
+ *   <li>"all-reduce-recv" receives the summed chunks partitioned by the target task id and
+ *       assembles them into the result array.
+ * </ul>
+ *
+ * <p>An exception is thrown when one of the following two cases happens:
+ *
+ * <ul>
+ *   <li>There exists one worker that contains more than one double array.
+ *   <li>The length of the double array is not consistent among all workers.
+ * </ul>
+ *
+ * @param input The input data stream.
+ * @return The result data stream.
+ */
+ static DataStream<double[]> allReduceSum(DataStream<double[]> input) {
+ // Records sent to the sum stage: (chunkId, originalArrayLength, arrayChunk).
+ DataStream<Tuple3<Integer, Integer, double[]>> allReduceSend =
+ input.flatMap(new AllReduceSend())
+ .setParallelism(input.getParallelism())
+ .name("all-reduce-send");
+
+ // Records sent to the recv stage: (taskId, chunkId, originalArrayLength, arrayChunk).
+ DataStream<Tuple4<Integer, Integer, Integer, double[]>> allReduceSum =
+ allReduceSend
+ .partitionCustom(
+ (chunkId, numPartitions) -> chunkId %
numPartitions, x -> x.f0)
+ .transform(
+ "all-reduce-sum",
+ new TupleTypeInfo<>(
+ BasicTypeInfo.INT_TYPE_INFO,
+ BasicTypeInfo.INT_TYPE_INFO,
+ BasicTypeInfo.INT_TYPE_INFO,
+
PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO),
+ new AllReduceSum())
+ .setParallelism(input.getParallelism())
+ .name("all-reduce-sum");
+
+ return allReduceSum
+ .partitionCustom((taskIdx, numPartitions) -> taskIdx %
numPartitions, x -> x.f0)
+ .transform(
+ "all-reduce-recv",
+
PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+ new AllReduceRecv())
+ .setParallelism(input.getParallelism())
+ .name("all-reduce-recv");
+ }
+
+ /**
+ * Splits the received double array into chunks of {@code CHUNK_SIZE} elements and emits each
+ * chunk as a (chunkId, originalArrayLength, arrayChunk) record, so that it can be sent to the
+ * corresponding worker.
+ *
+ * <p>A {@link RuntimeException} is thrown if this subtask receives more than one array.
+ *
+ * <p>Every emitted record carries the same {@code transferBuffer} instance, which is always
+ * {@code CHUNK_SIZE} long. Only the first {@code getLengthOfChunk(chunkId, length)} elements
+ * are overwritten per chunk, so the tail of a short last chunk may still hold values copied
+ * for an earlier chunk. NOTE(review): reusing one buffer presumably relies on the record being
+ * serialized while it is collected -- confirm.
+ */
+ private static class AllReduceSend
+ extends RichFlatMapFunction<double[], Tuple3<Integer, Integer,
double[]>> {
+
+ private boolean hasReceivedOneRecord = false;
+
+ private double[] transferBuffer = new double[CHUNK_SIZE];
+
+ @Override
+ public void flatMap(
+ double[] inputArray, Collector<Tuple3<Integer, Integer,
double[]>> out) {
+ if (hasReceivedOneRecord) {
+ throw new RuntimeException("The input cannot contain more than
one double array.");
+ }
+ hasReceivedOneRecord = true;
+ int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+
+ for (int taskId = 0; taskId < numTasks; taskId++) {
+ int startChunkId = getStartChunkId(taskId, numTasks,
inputArray.length);
+ int numChunksToHandle = getNumChunksByTaskId(taskId, numTasks,
inputArray.length);
+ for (int chunkId = startChunkId;
+ chunkId < numChunksToHandle + startChunkId;
+ chunkId++) {
+ System.arraycopy(
+ inputArray,
+ chunkId * CHUNK_SIZE,
+ transferBuffer,
+ 0,
+ getLengthOfChunk(chunkId, inputArray.length));
+ out.collect(Tuple3.of(chunkId, inputArray.length,
transferBuffer));
+ }
+ }
+ }
+ }
+
+ /**
+ * Aggregates the partitioned array chunks received from other workers by summing them
+ * element-wise per chunk id. Once the input has ended, broadcasts every aggregated chunk to
+ * each worker by emitting one (taskId, chunkId, originalArrayLength, aggregatedArrayChunk)
+ * record per task id.
+ *
+ * <p>A {@link RuntimeException} is thrown if two records with the same chunk id report
+ * different original array lengths.
+ */
+ private static class AllReduceSum
+ extends AbstractStreamOperator<Tuple4<Integer, Integer, Integer,
double[]>>
+ implements OneInputStreamOperator<
+ Tuple3<Integer, Integer, double[]>,
+ Tuple4<Integer, Integer, Integer, double[]>>,
+ BoundedOneInput {
+
+ /**
+ * A map that aggregates the received array chunks. The key is chunkId, the value is
+ * (originalArrayLength, aggregatedArrayChunk). The array of the first record received for a
+ * chunk id is stored as is, and the arrays of later records are added into it in place.
+ */
+ private Map<Integer, Tuple2<Integer, double[]>>
aggregatedArrayChunkByChunkId =
+ new HashMap<>();
+
+ @Override
+ public void endInput() {
+ int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+ for (Map.Entry<Integer, Tuple2<Integer, double[]>> entry :
+ aggregatedArrayChunkByChunkId.entrySet()) {
+ for (int taskId = 0; taskId < numTasks; taskId++) {
+ int chunkId = entry.getKey();
+ int originalArrayLength = entry.getValue().f0;
+ double[] aggregatedArrayChunk = entry.getValue().f1;
+ output.collect(
+ new StreamRecord<>(
+ Tuple4.of(
+ taskId,
+ chunkId,
+ originalArrayLength,
+ aggregatedArrayChunk)));
+ }
+ }
+ }
+
+ @Override
+ public void processElement(StreamRecord<Tuple3<Integer, Integer,
double[]>> streamRecord) {
+ Tuple3<Integer, Integer, double[]> record =
streamRecord.getValue();
+ int chunkId = record.f0;
+ int originalArrayLength = record.f1;
+ double[] arrayChunk = record.f2;
+ if (aggregatedArrayChunkByChunkId.containsKey(chunkId)) {
+ if (aggregatedArrayChunkByChunkId.get(chunkId).f0 !=
originalArrayLength) {
+ throw new RuntimeException("The input double array must
have same length.");
+ }
+ double[] curAggregatedArrayChunk =
aggregatedArrayChunkByChunkId.get(chunkId).f1;
+ for (int i = 0; i < curAggregatedArrayChunk.length; i++) {
+ curAggregatedArrayChunk[i] += arrayChunk[i];
+ }
+ } else {
+ aggregatedArrayChunkByChunkId.put(
+ chunkId, Tuple2.of(originalArrayLength, arrayChunk));
+ }
+ }
+ }
+
+ /**
+ * Organizes the received chunks into the result array. Each aggregated chunk is copied to
+ * offset {@code chunkId * CHUNK_SIZE} of the result array, and the assembled array is emitted
+ * once the input has ended.
+ */
+ private static class AllReduceRecv extends AbstractStreamOperator<double[]>
+ implements OneInputStreamOperator<
+ Tuple4<Integer, Integer, Integer, double[]>,
double[]>,
+ BoundedOneInput {
+
+ /**
+ * Stores the reduced results. Allocated with the original array length when the first chunk
+ * arrives; stays {@code null}, and nothing is emitted, if no chunk is received.
+ */
+ double[] resultArray;
+
+ @Override
+ public void endInput() {
+ if (null != resultArray) {
+ output.collect(new StreamRecord<>(resultArray));
+ }
+ }
+
+ @Override
+ public void processElement(
+ StreamRecord<Tuple4<Integer, Integer, Integer, double[]>>
streamRecord) {
+ Tuple4<Integer, Integer, Integer, double[]> ele =
streamRecord.getValue();
+ int chunkId = ele.f1;
+ int originalArrayLength = ele.f2;
+ double[] aggregatedArrayChunk = ele.f3;
+ if (null == resultArray) {
+ resultArray = new double[originalArrayLength];
+ }
+ System.arraycopy(
+ aggregatedArrayChunk,
+ 0,
+ resultArray,
+ chunkId * CHUNK_SIZE,
+ getLengthOfChunk(chunkId, resultArray.length));
+ }
+ }
+
+ /**
+ * Computes how many chunks an array of length {@code len} is going to be split into, i.e.
+ * {@code len / CHUNK_SIZE} rounded up. An array of length zero yields zero chunks.
+ *
+ * @param len Length of the array.
+ * @return Number of chunks the array is split into.
+ */
+ private static int getNumChunks(int len) {
+ int div = len / CHUNK_SIZE;
+ int mod = len % CHUNK_SIZE;
+ return mod == 0 ? div : div + 1;
+ }
+
+ /**
+ * Computes the length of the chunk with index {@code chunkId} of an array with length
+ * {@code len}. Every chunk has {@code CHUNK_SIZE} elements except the last one, which holds
+ * the remaining {@code len % CHUNK_SIZE} elements when {@code len} is not a multiple of
+ * {@code CHUNK_SIZE}.
+ *
+ * @param chunkId Index of the chunk.
+ * @param len Length of the array.
+ * @return Length of the chunk.
+ */
+ private static int getLengthOfChunk(int chunkId, int len) {
+ if (chunkId == getNumChunks(len) - 1) {
+ int mod = len % CHUNK_SIZE;
+ return mod == 0 ? CHUNK_SIZE : mod;
+ } else {
+ return CHUNK_SIZE;
+ }
+ }
+
+ /**
+ * Computes the index of the first chunk that one task needs to handle. With
+ * {@code div = numChunks / numTasks} and {@code mod = numChunks % numTasks}, the first
+ * {@code mod} tasks start at {@code (div + 1) * taskId} and the remaining tasks start at
+ * {@code div * taskId + mod}.
+ *
+ * @param taskId Index of the current task.
+ * @param numTasks Number of parallel tasks.
+ * @param len Length of the array to be reduced.
+ * @return Index of the first chunk handled by this task.
+ */
+ private static int getStartChunkId(int taskId, int numTasks, int len) {
+ int numChunks = getNumChunks(len);
+ int div = numChunks / numTasks;
+ int mod = numChunks % numTasks;
+
+ if (taskId >= mod) {
+ return div * taskId + mod;
+ } else {
+ return div * taskId + taskId;
+ }
+ }
+
+ /**
+ * Computes the number of chunks that one task needs to handle. With
+ * {@code div = numChunks / parallelism} and {@code mod = numChunks % parallelism}, the first
+ * {@code mod} tasks handle {@code div + 1} chunks and the remaining tasks handle {@code div}
+ * chunks.
+ *
+ * @param taskId Index of the current task.
+ * @param parallelism Number of parallel tasks.
+ * @param len Length of the array to be reduced.
+ * @return Number of chunks this task needs to handle.
+ */
+ private static int getNumChunksByTaskId(int taskId, int parallelism, int
len) {
+ int numChunks = getNumChunks(len);
+ int div = numChunks / parallelism;
+ int mod = numChunks % parallelism;
+
+ if (taskId >= mod) {
+ return div;
+ } else {
+ return div + 1;
+ }
+ }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
new file mode 100644
index 0000000..99a9c2f
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/** Provides utility functions for {@link DataStream}. */
+public class DataStreamUtils {
+ /**
+ * Applies allReduceSum on the input data stream. The input data stream is supposed to contain
+ * one double array in each partition. The result data stream has the same parallelism as the
+ * input, where each partition contains one double array that sums all of the double arrays in
+ * the input data stream.
+ *
+ * <p>Note that an exception is thrown when one of the following two cases happens:
+ *
+ * <ul>
+ *   <li>There exists one partition that contains more than one double array.
+ *   <li>The length of the double array is not consistent among all partitions.
+ * </ul>
+ *
+ * <p>This method delegates to {@code AllReduceImpl.allReduceSum}.
+ *
+ * @param input The input data stream.
+ * @return The result data stream.
+ */
+ public static DataStream<double[]> allReduceSum(DataStream<double[]>
input) {
+ return AllReduceImpl.allReduceSum(input);
+ }
+}
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/AllReduceImplTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/AllReduceImplTest.java
new file mode 100644
index 0000000..1ee0201
--- /dev/null
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/AllReduceImplTest.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.NumberSequenceIterator;
+
+import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Tests the {@link AllReduceImpl}. */
+@RunWith(Enclosed.class)
+public class AllReduceImplTest {
+
+ private static final int parallelism = 4;
+
+ private static final int chunkSize = AllReduceImpl.CHUNK_SIZE;
+
+ private static final double TOLERANCE = 1e-7;
+
+ /**
+ * Parameterized test for {@link AllReduceImpl}. Each of the {@code parallelism} source
+ * elements is mapped to the array {@code [0, 1, ..., numElements - 1]}, so element {@code i}
+ * of every result array is expected to equal {@code i * parallelism}. The test cases include:
+ *
+ * <ul>
+ *   <li>when there are no chunks.
+ *   <li>when the data is not enough for one chunk.
+ *   <li>when not every worker has one chunk to handle.
+ *   <li>when each worker needs to handle at least one chunk.
+ * </ul>
+ */
+ @RunWith(Parameterized.class)
+ public static class ParameterizedTest {
+
+ private static int numElements;
+
+ @Parameterized.Parameters
+ public static Collection<Object[]> params() {
+ return Arrays.asList(
+ new Object[][] {
+ {0},
+ {(int) (chunkSize * 0.5)},
+ {(int) (chunkSize * parallelism * 0.5)},
+ {(int) (chunkSize * parallelism * 1.5)}
+ });
+ }
+
+ public ParameterizedTest(int numElements) {
+ ParameterizedTest.numElements = numElements;
+ }
+
+ @Test
+ public void testAllReduce() throws Exception {
+ StreamExecutionEnvironment env =
StreamExecutionEnvironment.getExecutionEnvironment();
+ env.setParallelism(parallelism);
+ DataStream<double[]> elements =
+ env.fromParallelCollection(
+ new NumberSequenceIterator(1L,
parallelism),
+ BasicTypeInfo.LONG_TYPE_INFO)
+ .map(
+ x -> {
+ double[] res = new double[numElements];
+ for (int i = 0; i < res.length; i++) {
+ res[i] = i;
+ }
+ return res;
+ });
+
+ DataStreamUtils.allReduceSum(elements)
+ .addSink(
+ new SinkFunction<double[]>() {
+ @Override
+ public void invoke(double[] value, Context
context) {
+ assertEquals(numElements, value.length);
+ for (int i = 0; i < value.length; i++) {
+ assertEquals(i * parallelism,
value[i], TOLERANCE);
+ }
+ }
+ });
+
+ env.execute();
+ }
+ }
+
+ /**
+ * Non-parameterized test for {@link AllReduceImpl}. Covers the two failure cases: an input in
+ * which every source element is expanded into two double arrays, and an input whose double
+ * arrays have different lengths. Both tests run the job, expect it to fail, and compare the
+ * message of the cause of the cause of the caught exception with the expected error message.
+ */
+ public static class NonParameterizedTest {
+
+ @Test
+ public void testAllReduceWithMoreThanOneArray() {
+ try {
+ StreamExecutionEnvironment env =
+ StreamExecutionEnvironment.getExecutionEnvironment();
+ env.setParallelism(parallelism);
+ DataStream<double[]> elements =
+ env.fromParallelCollection(
+ new NumberSequenceIterator(1L,
parallelism),
+ BasicTypeInfo.LONG_TYPE_INFO)
+ .flatMap(
+ new FlatMapFunction<Long, double[]>() {
+ @Override
+ public void flatMap(
+ Long value,
Collector<double[]> out) {
+ out.collect(new double[100]);
+ out.collect(new double[100]);
+ }
+ });
+
+ DataStreamUtils.allReduceSum(elements).addSink(new
SinkFunction<double[]>() {});
+ env.execute();
+ fail();
+ } catch (Exception e) {
+ assertEquals(
+ "The input cannot contain more than one double array.",
+ e.getCause().getCause().getMessage());
+ }
+ }
+
+ @Test
+ public void testAllReduceWithDifferentLength() {
+ try {
+ StreamExecutionEnvironment env =
+ StreamExecutionEnvironment.getExecutionEnvironment();
+ env.setParallelism(parallelism);
+ DataStream<double[]> elements =
+ env.fromParallelCollection(
+ new NumberSequenceIterator(1L,
parallelism),
+ BasicTypeInfo.LONG_TYPE_INFO)
+ .map(x -> new double[x.intValue()]);
+
+ DataStreamUtils.allReduceSum(elements).addSink(new
SinkFunction<double[]>() {});
+ env.execute();
+ fail();
+ } catch (Exception e) {
+ assertEquals(
+ "The input double array must have same length.",
+ e.getCause().getCause().getMessage());
+ }
+ }
+ }
+}