This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new eeccb82  [FLINK-24845] Add allreduce utility function in FlinkML
eeccb82 is described below

commit eeccb82129dd666070a9c1c220f2f8de3b0e5aec
Author: zhangzp <[email protected]>
AuthorDate: Wed Nov 17 14:31:27 2021 +0800

    [FLINK-24845] Add allreduce utility function in FlinkML
    
    This closes #30.
---
 .../flink/ml/common/datastream/AllReduceImpl.java  | 300 +++++++++++++++++++++
 .../ml/common/datastream/DataStreamUtils.java      |  41 +++
 .../ml/common/datastream/AllReduceImplTest.java    | 165 ++++++++++++
 3 files changed, 506 insertions(+)

diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java
 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java
new file mode 100644
index 0000000..b8571a0
--- /dev/null
+++ 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/AllReduceImpl.java
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.api.java.typeutils.TupleTypeInfo;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Applies all-reduce on a data stream where each partition contains only one 
double array.
+ *
+ * <p>AllReduce is a communication primitive widely used in MPI. In this 
implementation, all workers
+ * do reduce on a partition of the whole data and they all get the final 
reduce result. In detail,
+ * we split each double array into chunks of fixed size buffer (32KB by 
default) and let each
+ * subtask handle several chunks.
+ *
+ * <p>There're mainly three stages:
+ * <li>All workers send their partial data to other workers for reduce.
+ * <li>All workers do reduce on all data it received and then broadcast 
partial results to others.
+ * <li>All workers merge partial results into final result.
+ */
+class AllReduceImpl {
+
+    @VisibleForTesting static final int CHUNK_SIZE = 1024 * 4;
+
+    /**
+     * Applies allReduceSum on the input data stream. The input data stream is 
supposed to contain
+     * one double array in each worker. The result data stream has the same 
parallelism as the
+     * input, where each worker contains one double array that sums all of the 
double arrays in the
+     * input data stream.
+     *
+     * <p>We throw exception when one of the following two cases happen:
+     * <li>There exists one worker that contains more than one double array.
+     * <li>The length of double array is not consistent among all workers.
+     *
+     * @param input The input data stream.
+     * @return The result data stream.
+     */
+    static DataStream<double[]> allReduceSum(DataStream<double[]> input) {
+        // chunkId, originalArrayLength, arrayChunk
+        DataStream<Tuple3<Integer, Integer, double[]>> allReduceSend =
+                input.flatMap(new AllReduceSend())
+                        .setParallelism(input.getParallelism())
+                        .name("all-reduce-send");
+
+        // taskId, chunkId, originalArrayLength, arrayChunk
+        DataStream<Tuple4<Integer, Integer, Integer, double[]>> allReduceSum =
+                allReduceSend
+                        .partitionCustom(
+                                (chunkId, numPartitions) -> chunkId % 
numPartitions, x -> x.f0)
+                        .transform(
+                                "all-reduce-sum",
+                                new TupleTypeInfo<>(
+                                        BasicTypeInfo.INT_TYPE_INFO,
+                                        BasicTypeInfo.INT_TYPE_INFO,
+                                        BasicTypeInfo.INT_TYPE_INFO,
+                                        
PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO),
+                                new AllReduceSum())
+                        .setParallelism(input.getParallelism())
+                        .name("all-reduce-sum");
+
+        return allReduceSum
+                .partitionCustom((taskIdx, numPartitions) -> taskIdx % 
numPartitions, x -> x.f0)
+                .transform(
+                        "all-reduce-recv",
+                        
PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                        new AllReduceRecv())
+                .setParallelism(input.getParallelism())
+                .name("all-reduce-recv");
+    }
+
+    /**
+     * Splits each double array into multiple chunks and sends each chunk to 
the corresponding
+     * worker.
+     */
+    private static class AllReduceSend
+            extends RichFlatMapFunction<double[], Tuple3<Integer, Integer, 
double[]>> {
+
+        private boolean hasReceivedOneRecord = false;
+
+        private double[] transferBuffer = new double[CHUNK_SIZE];
+
+        @Override
+        public void flatMap(
+                double[] inputArray, Collector<Tuple3<Integer, Integer, 
double[]>> out) {
+            if (hasReceivedOneRecord) {
+                throw new RuntimeException("The input cannot contain more than 
one double array.");
+            }
+            hasReceivedOneRecord = true;
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+
+            for (int taskId = 0; taskId < numTasks; taskId++) {
+                int startChunkId = getStartChunkId(taskId, numTasks, 
inputArray.length);
+                int numChunksToHandle = getNumChunksByTaskId(taskId, numTasks, 
inputArray.length);
+                for (int chunkId = startChunkId;
+                        chunkId < numChunksToHandle + startChunkId;
+                        chunkId++) {
+                    System.arraycopy(
+                            inputArray,
+                            chunkId * CHUNK_SIZE,
+                            transferBuffer,
+                            0,
+                            getLengthOfChunk(chunkId, inputArray.length));
+                    out.collect(Tuple3.of(chunkId, inputArray.length, 
transferBuffer));
+                }
+            }
+        }
+    }
+
+    /**
+     * Aggregates partitioned array chunks from other workers and broadcast 
the aggregated array
+     * chunk to each worker.
+     */
+    private static class AllReduceSum
+            extends AbstractStreamOperator<Tuple4<Integer, Integer, Integer, 
double[]>>
+            implements OneInputStreamOperator<
+                            Tuple3<Integer, Integer, double[]>,
+                            Tuple4<Integer, Integer, Integer, double[]>>,
+                    BoundedOneInput {
+
+        /**
+         * A map that aggregates the received array chunks. The key is 
chunkId, the value is
+         * (originalArrayLength, aggregatedArrayChunk).
+         */
+        private Map<Integer, Tuple2<Integer, double[]>> 
aggregatedArrayChunkByChunkId =
+                new HashMap<>();
+
+        @Override
+        public void endInput() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            for (Map.Entry<Integer, Tuple2<Integer, double[]>> entry :
+                    aggregatedArrayChunkByChunkId.entrySet()) {
+                for (int taskId = 0; taskId < numTasks; taskId++) {
+                    int chunkId = entry.getKey();
+                    int originalArrayLength = entry.getValue().f0;
+                    double[] aggregatedArrayChunk = entry.getValue().f1;
+                    output.collect(
+                            new StreamRecord<>(
+                                    Tuple4.of(
+                                            taskId,
+                                            chunkId,
+                                            originalArrayLength,
+                                            aggregatedArrayChunk)));
+                }
+            }
+        }
+
+        @Override
+        public void processElement(StreamRecord<Tuple3<Integer, Integer, 
double[]>> streamRecord) {
+            Tuple3<Integer, Integer, double[]> record = 
streamRecord.getValue();
+            int chunkId = record.f0;
+            int originalArrayLength = record.f1;
+            double[] arrayChunk = record.f2;
+            if (aggregatedArrayChunkByChunkId.containsKey(chunkId)) {
+                if (aggregatedArrayChunkByChunkId.get(chunkId).f0 != 
originalArrayLength) {
+                    throw new RuntimeException("The input double array must 
have same length.");
+                }
+                double[] curAggregatedArrayChunk = 
aggregatedArrayChunkByChunkId.get(chunkId).f1;
+                for (int i = 0; i < curAggregatedArrayChunk.length; i++) {
+                    curAggregatedArrayChunk[i] += arrayChunk[i];
+                }
+            } else {
+                aggregatedArrayChunkByChunkId.put(
+                        chunkId, Tuple2.of(originalArrayLength, arrayChunk));
+            }
+        }
+    }
+
+    /** Organizes the received chunks into the result array. */
+    private static class AllReduceRecv extends AbstractStreamOperator<double[]>
+            implements OneInputStreamOperator<
+                            Tuple4<Integer, Integer, Integer, double[]>, 
double[]>,
+                    BoundedOneInput {
+
+        /** Stores the reduced results. */
+        double[] resultArray;
+
+        @Override
+        public void endInput() {
+            if (null != resultArray) {
+                output.collect(new StreamRecord<>(resultArray));
+            }
+        }
+
+        @Override
+        public void processElement(
+                StreamRecord<Tuple4<Integer, Integer, Integer, double[]>> 
streamRecord) {
+            Tuple4<Integer, Integer, Integer, double[]> ele = 
streamRecord.getValue();
+            int chunkId = ele.f1;
+            int originalArrayLength = ele.f2;
+            double[] aggregatedArrayChunk = ele.f3;
+            if (null == resultArray) {
+                resultArray = new double[originalArrayLength];
+            }
+            System.arraycopy(
+                    aggregatedArrayChunk,
+                    0,
+                    resultArray,
+                    chunkId * CHUNK_SIZE,
+                    getLengthOfChunk(chunkId, resultArray.length));
+        }
+    }
+
+    /**
+     * Computes how many chunks is an array with length ${len} going to be 
split into.
+     *
+     * @param len Length of the array.
+     * @return Number of chunks the array is split into.
+     */
+    private static int getNumChunks(int len) {
+        int div = len / CHUNK_SIZE;
+        int mod = len % CHUNK_SIZE;
+        return mod == 0 ? div : div + 1;
+    }
+
+    /**
+     * Computes the length of the last chunk of an array with length ${len}.
+     *
+     * @param len Length of the array.
+     * @return Length of the last chunk.
+     */
+    private static int getLengthOfChunk(int chunkId, int len) {
+        if (chunkId == getNumChunks(len) - 1) {
+            int mod = len % CHUNK_SIZE;
+            return mod == 0 ? CHUNK_SIZE : mod;
+        } else {
+            return CHUNK_SIZE;
+        }
+    }
+
+    /**
+     * Computes the index of the first chunk that one task needs to handle.
+     *
+     * @param taskId Index of the current task.
+     * @param numTasks Number of parallel tasks.
+     * @param len Length of the array to be reduced.
+     * @return Start position of this task.
+     */
+    private static int getStartChunkId(int taskId, int numTasks, int len) {
+        int numChunks = getNumChunks(len);
+        int div = numChunks / numTasks;
+        int mod = numChunks % numTasks;
+
+        if (taskId >= mod) {
+            return div * taskId + mod;
+        } else {
+            return div * taskId + taskId;
+        }
+    }
+
+    /**
+     * Computes the number of chunks that one task needs to handle.
+     *
+     * @param taskId Index of the current task.
+     * @param parallelism Number of parallel tasks.
+     * @param len Length of the array to be reduced.
+     * @return Number of chunks this task needs to handle.
+     */
+    private static int getNumChunksByTaskId(int taskId, int parallelism, int 
len) {
+        int numChunks = getNumChunks(len);
+        int div = numChunks / parallelism;
+        int mod = numChunks % parallelism;
+
+        if (taskId >= mod) {
+            return div;
+        } else {
+            return div + 1;
+        }
+    }
+}
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
new file mode 100644
index 0000000..99a9c2f
--- /dev/null
+++ 
b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/** Provides utility functions for {@link DataStream}. */
+public class DataStreamUtils {
+    /**
+     * Applies allReduceSum on the input data stream. The input data stream is 
supposed to contain
+     * one double array in each partition. The result data stream has the same 
parallelism as the
+     * input, where each partition contains one double array that sums all of 
the double arrays in
+     * the input data stream.
+     *
+     * <p>Note that we throw exception when one of the following two cases 
happen:
+     * <li>There exists one partition that contains more than one double array.
+     * <li>The length of the double array is not consistent among all 
partitions.
+     *
+     * @param input The input data stream.
+     * @return The result data stream.
+     */
+    public static DataStream<double[]> allReduceSum(DataStream<double[]> 
input) {
+        return AllReduceImpl.allReduceSum(input);
+    }
+}
diff --git 
a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/AllReduceImplTest.java
 
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/AllReduceImplTest.java
new file mode 100644
index 0000000..1ee0201
--- /dev/null
+++ 
b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/AllReduceImplTest.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.NumberSequenceIterator;
+
+import org.junit.Test;
+import org.junit.experimental.runners.Enclosed;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.util.Arrays;
+import java.util.Collection;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Tests the {@link AllReduceImpl}. */
+@RunWith(Enclosed.class)
+public class AllReduceImplTest {
+
+    private static final int parallelism = 4;
+
+    private static final int chunkSize = AllReduceImpl.CHUNK_SIZE;
+
+    private static final double TOLERANCE = 1e-7;
+
+    /**
+     * Parameterized test for {@link AllReduceImpl}. The test cases include:
+     * <li>when there are no chunks.
+     * <li>when the data is not enough for one chunk.
+     * <li>when not every worker has one chunk to handle.
+     * <li>when each worker needs to handle at least one chunk.
+     */
+    @RunWith(Parameterized.class)
+    public static class ParameterizedTest {
+
+        private static int numElements;
+
+        @Parameterized.Parameters
+        public static Collection<Object[]> params() {
+            return Arrays.asList(
+                    new Object[][] {
+                        {0},
+                        {(int) (chunkSize * 0.5)},
+                        {(int) (chunkSize * parallelism * 0.5)},
+                        {(int) (chunkSize * parallelism * 1.5)}
+                    });
+        }
+
+        public ParameterizedTest(int numElements) {
+            ParameterizedTest.numElements = numElements;
+        }
+
+        @Test
+        public void testAllReduce() throws Exception {
+            StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+            env.setParallelism(parallelism);
+            DataStream<double[]> elements =
+                    env.fromParallelCollection(
+                                    new NumberSequenceIterator(1L, 
parallelism),
+                                    BasicTypeInfo.LONG_TYPE_INFO)
+                            .map(
+                                    x -> {
+                                        double[] res = new double[numElements];
+                                        for (int i = 0; i < res.length; i++) {
+                                            res[i] = i;
+                                        }
+                                        return res;
+                                    });
+
+            DataStreamUtils.allReduceSum(elements)
+                    .addSink(
+                            new SinkFunction<double[]>() {
+                                @Override
+                                public void invoke(double[] value, Context 
context) {
+                                    assertEquals(numElements, value.length);
+                                    for (int i = 0; i < value.length; i++) {
+                                        assertEquals(i * parallelism, 
value[i], TOLERANCE);
+                                    }
+                                }
+                            });
+
+            env.execute();
+        }
+    }
+
+    /** Non-parameterized test for {@link AllReduceImpl}. */
+    public static class NonParameterizedTest {
+
+        @Test
+        public void testAllReduceWithMoreThanOneArray() {
+            try {
+                StreamExecutionEnvironment env =
+                        StreamExecutionEnvironment.getExecutionEnvironment();
+                env.setParallelism(parallelism);
+                DataStream<double[]> elements =
+                        env.fromParallelCollection(
+                                        new NumberSequenceIterator(1L, 
parallelism),
+                                        BasicTypeInfo.LONG_TYPE_INFO)
+                                .flatMap(
+                                        new FlatMapFunction<Long, double[]>() {
+                                            @Override
+                                            public void flatMap(
+                                                    Long value, 
Collector<double[]> out) {
+                                                out.collect(new double[100]);
+                                                out.collect(new double[100]);
+                                            }
+                                        });
+
+                DataStreamUtils.allReduceSum(elements).addSink(new 
SinkFunction<double[]>() {});
+                env.execute();
+                fail();
+            } catch (Exception e) {
+                assertEquals(
+                        "The input cannot contain more than one double array.",
+                        e.getCause().getCause().getMessage());
+            }
+        }
+
+        @Test
+        public void testAllReduceWithDifferentLength() {
+            try {
+                StreamExecutionEnvironment env =
+                        StreamExecutionEnvironment.getExecutionEnvironment();
+                env.setParallelism(parallelism);
+                DataStream<double[]> elements =
+                        env.fromParallelCollection(
+                                        new NumberSequenceIterator(1L, 
parallelism),
+                                        BasicTypeInfo.LONG_TYPE_INFO)
+                                .map(x -> new double[x.intValue()]);
+
+                DataStreamUtils.allReduceSum(elements).addSink(new 
SinkFunction<double[]>() {});
+                env.execute();
+                fail();
+            } catch (Exception e) {
+                assertEquals(
+                        "The input double array must have same length.",
+                        e.getCause().getCause().getMessage());
+            }
+        }
+    }
+}

Reply via email to