lindong28 commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r865976547


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/SGD.java:
##########
@@ -0,0 +1,402 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.IterationConfig;
+import org.apache.flink.iteration.IterationListener;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.ReplayableDataStreamList;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.iteration.TerminateOnMaxIterOrTol;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.OutputTag;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Stochastic Gradient Descent (SGD) is one of the most widely used optimizers for training machine
+ * learning models. It iteratively makes small adjustments to the machine 
learning model according
+ * to the gradient at each step, to decrease the error of the model.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Stochastic_gradient_descent.
+ */
+@Internal
+public class SGD implements Optimizer {
+    /** Params for SGD optimizer. */
+    private final SGDParams params;
+
+    public SGD(
+            int maxIter,
+            double learningRate,
+            int globalBatchSize,
+            double tol,
+            double reg,
+            double elasticNet) {
+        this.params = new SGDParams(maxIter, learningRate, globalBatchSize, 
tol, reg, elasticNet);
+    }
+
+    @Override
+    public DataStream<DenseVector> optimize(
+            DataStream<DenseVector> bcInitModel,
+            DataStream<LabeledPointWithWeight> trainData,
+            LossFunc lossFunc) {
+        DataStreamList resultList =
+                Iterations.iterateBoundedStreamsUntilTermination(
+                        DataStreamList.of(bcInitModel.map(modelVec -> 
modelVec.values)),
+                        
ReplayableDataStreamList.notReplay(trainData.rebalance()),
+                        IterationConfig.newBuilder().build(),
+                        new TrainIterationBody(lossFunc, params));
+        return resultList.get(0);
+    }
+
+    /** The iteration implementation for training process. */
+    private static class TrainIterationBody implements IterationBody {
+        private final LossFunc lossFunc;
+        private final SGDParams params;
+
+        public TrainIterationBody(LossFunc lossFunc, SGDParams params) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            // The variable stream at the first iteration is the initialized 
model data.
+            // In the following iterations, it contains: the model update, 
weightSum and
+            // lossSum.
+            DataStream<double[]> variableStream = variableStreams.get(0);
+            DataStream<LabeledPointWithWeight> trainData = dataStreams.get(0);
+            final OutputTag<DenseVector> modelDataOutputTag =
+                    new OutputTag<DenseVector>("MODEL_OUTPUT") {};
+
+            SingleOutputStreamOperator<double[]> modelUpdateAndWeightAndLoss =
+                    trainData
+                            .connect(variableStream)
+                            .transform(
+                                    "CacheDataAndDoTrain",
+                                    
PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO,
+                                    new CacheDataAndDoTrain(lossFunc, params, 
modelDataOutputTag));
+
+            DataStreamList feedbackVariableStream =
+                    IterationBody.forEachRound(
+                            DataStreamList.of(modelUpdateAndWeightAndLoss),
+                            input -> {
+                                DataStream<double[]> feedback =
+                                        
DataStreamUtils.allReduceSum(input.get(0));
+                                return DataStreamList.of(feedback);
+                            });
+
+            DataStream<Integer> terminationCriteria =
+                    feedbackVariableStream
+                            .get(0)
+                            .map(
+                                    reducedUpdateAndWeightAndLoss -> {
+                                        double[] value = (double[]) 
reducedUpdateAndWeightAndLoss;
+                                        return value[value.length - 1] / 
value[value.length - 2];
+                                    })
+                            .flatMap(new 
TerminateOnMaxIterOrTol(params.maxIter, params.tol));
+
+            return new IterationBodyResult(
+                    DataStreamList.of(feedbackVariableStream.get(0)),
+                    DataStreamList.of(
+                            
modelUpdateAndWeightAndLoss.getSideOutput(modelDataOutputTag)),
+                    terminationCriteria);
+        }
+    }
+
+    /**
+     * A stream operator that caches the training data in the first iteration 
and updates the model
+     * iteratively. The first input is the training data, and the second input 
is the initialized
+     * model data or feedback of model update, weight and loss.
+     */
+    private static class CacheDataAndDoTrain extends 
AbstractStreamOperator<double[]>
+            implements TwoInputStreamOperator<LabeledPointWithWeight, 
double[], double[]>,
+                    IterationListener<double[]> {
+        /** The cached training data. */
+        private List<LabeledPointWithWeight> trainData;
+
+        private ListState<LabeledPointWithWeight> trainDataState;
+
+        /** The start index (offset) of the next mini-batch data for training. 
*/
+        private int nextBatchOffset = 0;
+
+        private ListState<Integer> nextBatchOffsetState;
+
+        /** The model coefficient. */
+        private DenseVector coefficient;
+
+        private ListState<DenseVector> coefficientState;
+        /** The dimension of the coefficient. */
+        private int coeffiDim;
+
+        /**
+         * The double array to sync among all workers. For example, when 
training {@link
+         * LinearRegression}, this double array consists of {modelUpdate, weightSum, lossSum}.
+         */
+        private double[] feedbackArray;
+
+        private ListState<double[]> feedbackArrayState;
+        /** The outputTag to output the model data when iteration ends. */
+        private final OutputTag<DenseVector> modelDataOutputTag;
+        /** The batch size on this task. */
+        private int localBatchSize;
+        /** Optimizer-related parameters. */
+        private final SGDParams params;
+        /** The loss function to optimize. */
+        private final LossFunc lossFunc;
+
+        private CacheDataAndDoTrain(
+                LossFunc lossFunc, SGDParams params, OutputTag<DenseVector> 
modelDataOutputTag) {
+            this.lossFunc = lossFunc;
+            this.params = params;
+            this.modelDataOutputTag = modelDataOutputTag;
+        }
+
+        @Override
+        public void open() {
+            int numTasks = getRuntimeContext().getNumberOfParallelSubtasks();
+            int taskId = getRuntimeContext().getIndexOfThisSubtask();
+            localBatchSize = params.globalBatchSize / numTasks;
+            if (params.globalBatchSize % numTasks > taskId) {
+                localBatchSize++;
+            }
+        }
+
+        /**
+         * Gets the weight sum of the processed elements.
+         *
+         * @return The weight sum.
+         */
+        private double getWeightSum() {
+            return feedbackArray[coeffiDim];
+        }
+
+        /**
+         * Sets the weight sum of the processed elements.
+         *
+         * @param weightSum The weight sum.
+         */
+        private void setWeightSum(double weightSum) {
+            feedbackArray[coeffiDim] = weightSum;
+        }
+
+        /**
+         * Gets the loss sum of the processed elements.
+         *
+         * @return The loss sum.
+         */
+        private double getLoss() {
+            return feedbackArray[coeffiDim + 1];
+        }
+
+        /**
+         * Sets the loss sum of the processed elements.
+         *
+         * @param loss The loss sum.
+         */
+        private void setLoss(double loss) {
+            feedbackArray[coeffiDim + 1] = loss;
+        }
+
+        @Override
+        public void onEpochWatermarkIncremented(
+                int epochWatermark, Context context, Collector<double[]> 
collector)
+                throws Exception {
+            if (epochWatermark == 0) {
+                coefficient = new DenseVector(feedbackArray);
+                coeffiDim = coefficient.size();
+                feedbackArray = new double[coefficient.size() + 2];
+            } else {
+                if (getWeightSum() > 0) {
+                    BLAS.axpy(
+                            -params.learningRate / getWeightSum(),
+                            new DenseVector(feedbackArray),
+                            coefficient,
+                            coeffiDim);
+                    double regLoss =
+                            RegularizationUtils.regularize(
+                                    coefficient,
+                                    params.reg,
+                                    params.elasticNet,
+                                    params.learningRate);
+                    setLoss(getLoss() + regLoss);
+                }
+            }
+
+            if (trainData == null) {
+                trainData = 
IteratorUtils.toList(trainDataState.get().iterator());
+            }
+
+            // TODO: support efficient shuffling of the training set on each partition.
+            if (trainData.size() > 0) {
+                List<LabeledPointWithWeight> miniBatchData =
+                        trainData.subList(
+                                nextBatchOffset,
+                                Math.min(nextBatchOffset + localBatchSize, 
trainData.size()));

Review Comment:
   Sounds good. Thanks for the explanation.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to