yunfengzhou-hub commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r847844186


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*

Review Comment:
   Jira tickets for all the algorithms planned to be added have already been created. You can find them at https://issues.apache.org/jira/secure/RapidBoard.jspa?rapidView=541; the ticket for FTRL is FLINK-20790.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.ftrl.Ftrl;
+import org.apache.flink.ml.classification.ftrl.FtrlModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {

Review Comment:
   The test cases are organized differently from the existing tests. Let's add tests that each cover one of the following situations:
   - getting/setting parameters
   - the most common fit/transform process
   - save/load
   - getting/setting model data
   - invalid inputs/corner cases



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LinearModelData.java:
##########
@@ -37,60 +38,68 @@
 import java.io.OutputStream;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel}, {@link FtrlModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
-public class LogisticRegressionModelData {
+public class LinearModelData {

Review Comment:
   Could you please explain the relationship between FTRL and LogisticRegression, and between FTRL and other algorithms like LinearRegression? I'm not sure why we want to rename `LogisticRegressionModelData` to `LinearModelData`.
   
   If, after discussion, we still agree that this renaming is reasonable, then the model data class belongs to neither the `logisticregression` nor the `ftrl` package. We would need to move classes like this to a neutral package, such as one named `common`.
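   For context on that question: FTRL (Follow-The-Regularized-Leader) is an online optimization method, and in its widely used FTRL-Proximal form (McMahan et al.) it trains the same model as logistic regression, namely a coefficient vector `w` scored with `sigmoid(w . x)`. Below is a sketch of that textbook per-coordinate update, using the same `alpha`, `beta`, `l1` and `l2` hyperparameters that `FtrlIterationBody` takes. It assumes labels in {0, 1} and plain `double[]` features; the class and method names are hypothetical, and it is not the PR's implementation.

```java
/**
 * Minimal sketch of the textbook FTRL-Proximal update for logistic regression.
 * Hypothetical helper for illustration only; not the implementation in this PR.
 */
class FtrlProximalSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] z; // per-coordinate accumulated (adjusted) gradients
    private final double[] n; // per-coordinate accumulated squared gradients

    FtrlProximalSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.z = new double[dim];
        this.n = new double[dim];
    }

    /** Derives the coefficient vector from z and n; this is all a linear model needs. */
    double[] coefficient() {
        double[] w = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            if (Math.abs(z[i]) <= l1) {
                w[i] = 0.0; // L1 keeps weakly supported coordinates exactly at zero
            } else {
                w[i] = -(z[i] - Math.signum(z[i]) * l1) / ((beta + Math.sqrt(n[i])) / alpha + l2);
            }
        }
        return w;
    }

    /** Logistic-regression scoring: sigmoid(w . x). */
    static double predict(double[] w, double[] x) {
        double dot = 0.0;
        for (int i = 0; i < w.length; i++) {
            dot += w[i] * x[i];
        }
        return 1.0 / (1.0 + Math.exp(-dot));
    }

    /** One online step on a single sample with a label in {0, 1}. */
    void update(double[] x, double label) {
        double[] w = coefficient();
        double p = predict(w, x);
        for (int i = 0; i < x.length; i++) {
            if (x[i] == 0.0) {
                continue; // a zero feature has zero gradient, so its z and n stay unchanged
            }
            double g = (p - label) * x[i]; // gradient of the log loss w.r.t. w[i]
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * w[i];
            n[i] += g * g;
        }
    }
}
```

   Since the learned state reduces to a coefficient vector scored the same way as logistic regression, this may be what motivated sharing the model data class; whether LinearRegression fits the same abstraction depends on its loss and prediction function.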



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LinearModelData.java:
##########
@@ -37,60 +38,68 @@
 import java.io.OutputStream;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel}, {@link FtrlModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
-public class LogisticRegressionModelData {
+public class LinearModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;

Review Comment:
   The versioning mechanism here differs from the one in `OnlineKMeans`. Shall we adopt the same practice in both online algorithms?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlParams.java:
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link Ftrl}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlParams<T>
+        extends HasLabelCol<T>, HasBatchStrategy<T>, HasGlobalBatchSize<T>, HasFeaturesCol<T> {

Review Comment:
   An `Estimator`'s param class should extend the corresponding `Model`'s param class. In the current implementation, `Ftrl` cannot set the `rawPredictionCol` and `predictionCol` of the generated `FtrlModel`.
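   A minimal, dependency-free sketch of the suggested hierarchy. It assumes simplified stand-ins for Flink ML's `WithParams` and `Has*Col` param interfaces, and all names here (`FtrlModelParams`, `FtrlEstimatorSketch`, `FtrlModelSketch`) are hypothetical. Because `FtrlParams` extends `FtrlModelParams`, `setPredictionCol` and `setRawPredictionCol` become available on the estimator, and `fit()` can forward them to the model.

```java
import java.util.HashMap;
import java.util.Map;

// Simplified, hypothetical stand-ins for Flink ML's WithParams / Has*Col interfaces.
interface WithParams<T> {
    Map<String, Object> getParamMap();

    @SuppressWarnings("unchecked")
    default T set(String name, Object value) {
        getParamMap().put(name, value);
        return (T) this;
    }

    default Object get(String name) {
        return getParamMap().get(name);
    }
}

interface HasPredictionCol<T> extends WithParams<T> {
    default T setPredictionCol(String value) {
        return set("predictionCol", value);
    }

    default String getPredictionCol() {
        return (String) get("predictionCol");
    }
}

interface HasRawPredictionCol<T> extends WithParams<T> {
    default T setRawPredictionCol(String value) {
        return set("rawPredictionCol", value);
    }

    default String getRawPredictionCol() {
        return (String) get("rawPredictionCol");
    }
}

interface HasLabelCol<T> extends WithParams<T> {
    default T setLabelCol(String value) {
        return set("labelCol", value);
    }
}

// Hypothetical: params the model needs at transform() time.
interface FtrlModelParams<T> extends HasPredictionCol<T>, HasRawPredictionCol<T> {}

// The estimator's params extend the model's params and add training-only params.
interface FtrlParams<T> extends FtrlModelParams<T>, HasLabelCol<T> {}

class FtrlModelSketch implements FtrlModelParams<FtrlModelSketch> {
    private final Map<String, Object> paramMap = new HashMap<>();

    @Override
    public Map<String, Object> getParamMap() {
        return paramMap;
    }
}

class FtrlEstimatorSketch implements FtrlParams<FtrlEstimatorSketch> {
    private static final String[] MODEL_PARAM_NAMES = {"predictionCol", "rawPredictionCol"};
    private final Map<String, Object> paramMap = new HashMap<>();

    @Override
    public Map<String, Object> getParamMap() {
        return paramMap;
    }

    FtrlModelSketch fit() {
        FtrlModelSketch model = new FtrlModelSketch();
        // Forward only the model's params, similar in spirit to
        // ReadWriteUtils.updateExistingParams(model, paramMap) in the PR.
        for (String name : MODEL_PARAM_NAMES) {
            if (paramMap.containsKey(name)) {
                model.set(name, paramMap.get(name));
            }
        }
        return model;
    }
}
```

   With the PR's current `FtrlParams`, which does not include the model's params, the equivalent of `setPredictionCol` simply does not exist on `Ftrl`.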



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the Ftrl algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })

Review Comment:
   Shall we move logic like this into a separate class or method? That would make the code easier to read.
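   A minimal sketch of that extraction, assuming the anonymous mapper becomes a named class (the name `ModelVersionAssigner` is hypothetical). To keep it compilable without Flink, `java.util.function.Function` stands in for Flink's `MapFunction<DenseVector[], LinearModelData>`, `double[][]` for `DenseVector[]`, and a hypothetical `VersionedModel` for `LinearModelData`.

```java
import java.util.function.Function;

// Hypothetical stand-in for LinearModelData (coefficient + modelVersion).
class VersionedModel {
    final double[] coefficient;
    final long modelVersion;

    VersionedModel(double[] coefficient, long modelVersion) {
        this.coefficient = coefficient;
        this.modelVersion = modelVersion;
    }
}

/**
 * Named replacement for the anonymous mapper: wraps the first vector of each update
 * and tags it with an increasing version, exactly like the inline `iter++` logic.
 */
class ModelVersionAssigner implements Function<double[][], VersionedModel> {
    private long version = 0L; // per-instance counter, same as the anonymous class's `iter`

    @Override
    public VersionedModel apply(double[][] value) {
        return new VersionedModel(value[0], version++);
    }
}
```

   With a class like this, the call site in `FtrlIterationBody#process` would shrink to `modelData.map(new ModelVersionAssigner())`. Note that, as in the anonymous class, the counter is a plain field rather than Flink managed state, so it would restart from 0 after a failover; that may be relevant to the versioning question raised earlier.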



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.ftrl.Ftrl;
+import org.apache.flink.ml.classification.ftrl.FtrlModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private Table trainDenseTable;
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),

Review Comment:
   Shall we extract `new double[] {1.0, 1.0, 1.0}` into a variable? That would make the code more concise.
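   One way to do that, sketched as a hypothetical test helper (the names `SparseTestData` and `ones` are illustrative only) that builds the all-ones value array from a length, so it also covers the four-element rows:

```java
import java.util.Arrays;

/** Hypothetical helper for FtrlTest's sparse fixtures. */
final class SparseTestData {
    private SparseTestData() {}

    /** Returns a fresh array of {@code size} values, all equal to 1.0. */
    static double[] ones(int size) {
        double[] values = new double[size];
        Arrays.fill(values, 1.0);
        return values;
    }
}
```

   Each row could then read `Vectors.sparse(10, new int[] {1, 3, 4}, ones(3))`. Returning a fresh array per call is a deliberate choice: a single shared `static final` array would also work, but only as long as no vector ever mutates its values.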



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.ftrl.Ftrl;
+import org.apache.flink.ml.classification.ftrl.FtrlModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private Table trainDenseTable;
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {0, 2, 3}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 3, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 2, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 2, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {7, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.of(DenseVector.class))
+                        .column("f1", DataTypes.DOUBLE())
+                        .build();
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream, schema).as(FEATURE_COL, LABEL_COL);
+    }
+
+    @Test
+    public void testFtrlWithInitLrModel() throws Exception {
+        Table initModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrl() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrlModel() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+
+        Table onlineTrainTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, PREDICT_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        FtrlModel model =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable);
+    }
+
+    @Test
+    public void testFtrlModelSparse() throws Exception {
+        Table initModelSparse =
+                tEnv.fromDataStream(
+                        env.fromElements(
+                                Row.of(
+                                        new DenseVector(
+                                                new double[] {
+                                                    0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1
+                                                }),
+                                        0L)));
+
+        Table onlineTrainTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, TRAIN_SPARSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, PREDICT_SPARSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        FtrlModel model =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModelSparse)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        tEnv.toDataStream(model.getModelData()[0]).print();
+        verifyPredictionResult(resultTable);
+    }
+
+    private static void verifyPredictionResult(Table output) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<Row> stream = tEnv.toDataStream(output);
+        List<Row> result = IteratorUtils.toList(stream.executeAndCollect());
+        Map<Long, Tuple2<Double, Double>> correctRatio = new HashMap<>();
+
+        for (Row row : result) {
+            long modelVersion = row.getFieldAs(MODEL_VERSION_COL);
+            Double pred = row.getFieldAs(PREDICT_COL);
+            Double label = row.getFieldAs(LABEL_COL);
+            if (correctRatio.containsKey(modelVersion)) {
+                Tuple2<Double, Double> t2 = correctRatio.get(modelVersion);
+                if (pred.equals(label)) {
+                    t2.f0 += 1.0;
+                }
+                t2.f1 += 1.0;
+            } else {
+                correctRatio.put(modelVersion, Tuple2.of(pred.equals(label) ? 1.0 : 0.0, 1.0));
+            }
+        }
+        for (Long id : correctRatio.keySet()) {
+            System.out.println(
+                    id
+                            + " : "
+                            + correctRatio.get(id).f0 / correctRatio.get(id).f1
+                            + " total sample num : "
+                            + correctRatio.get(id).f1);
+            if (id > 0L) {
+                assertEquals(1.0, correctRatio.get(id).f0 / correctRatio.get(id).f1, 1.0e-5);
+            }
+        }
+    }
+
+    /** Generates random data for ftrl train and predict. */
+    public static class RandomSourceFunction implements SourceFunction<Row> {
+        private volatile boolean isRunning = true;
+        private final long timeInterval;
+        private final long maxSize;
+        private final List<Row> data;
+
+        public RandomSourceFunction(long timeInterval, long maxSize, List<Row> data)
+                throws InterruptedException {
+            this.timeInterval = timeInterval;
+            this.maxSize = maxSize;
+            this.data = data;
+        }
+
+        @Override
+        public void run(SourceContext<Row> ctx) throws Exception {
+            int size = data.size();
+            for (int i = 0; i < maxSize; ++i) {
+                if (i == 0) {
+                    Thread.sleep(5000);

Review Comment:
   Shall we avoid using `Thread.sleep()` in test cases? If every unit test adopted this practice, the total time for `mvn install` would become very long.
   
   What we actually want here is to make sure the job has been initialized before any input is provided. `OnlineKMeansTest` has established a best practice for this kind of problem using `InMemorySourceFunction`; you can refer to these classes to see how to solve it.
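   To illustrate the idea in plain Java (no Flink dependency): instead of sleeping a fixed 5 seconds before emitting, the source signals when it is ready, and the test hands over the records only after that signal. `QueueFedSource`, `feedWhenReady` and the `Consumer`-based emit callback are hypothetical stand-ins for `SourceFunction`/`SourceContext`, and the "thread has started" latch only approximates "the job has been initialized". This is a sketch of the pattern, not the API of `InMemorySourceFunction`, whose actual methods should be checked in `OnlineKMeansTest`.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

class QueueFedSourceDemo {

    /** Hypothetical test source: emits whatever the test enqueues, with no fixed sleep. */
    static class QueueFedSource<T> {
        private final BlockingQueue<T> queue = new LinkedBlockingQueue<>();
        private final CountDownLatch started = new CountDownLatch(1);
        private final int expectedCount;
        private volatile boolean isRunning = true;

        QueueFedSource(int expectedCount) {
            this.expectedCount = expectedCount;
        }

        /** Plays the role of SourceFunction#run; `emit` plays the role of SourceContext#collect. */
        void run(Consumer<T> emit) throws InterruptedException {
            started.countDown(); // tell the feeding side that the consumer is ready
            int emitted = 0;
            while (isRunning && emitted < expectedCount) {
                // Poll with a short timeout so that cancel() is noticed promptly.
                T next = queue.poll(100, TimeUnit.MILLISECONDS);
                if (next != null) {
                    emit.accept(next);
                    emitted++;
                }
            }
        }

        /** Plays the role of SourceFunction#cancel. */
        void cancel() {
            isRunning = false;
        }

        /** Waits on a readiness signal instead of a fixed sleep, then hands over the data. */
        void feedWhenReady(List<T> data) throws InterruptedException {
            started.await();
            queue.addAll(data);
        }
    }

    /** Runs the source on its own thread, feeds it, and returns everything it emitted. */
    static <T> List<T> collect(List<T> input) {
        QueueFedSource<T> source = new QueueFedSource<>(input.size());
        List<T> output = new ArrayList<>();
        Thread job =
                new Thread(
                        () -> {
                            try {
                                source.run(output::add);
                            } catch (InterruptedException e) {
                                Thread.currentThread().interrupt();
                            }
                        });
        job.start();
        try {
            source.feedWhenReady(input);
            job.join(); // join() makes the job thread's writes to `output` visible here
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
        return output;
    }

    public static void main(String[] args) {
        System.out.println(collect(Arrays.asList(1, 2, 3)));
    }
}
```

   The test thread blocks only as long as it takes for the consumer to become ready, so the cost per test no longer includes a fixed 5-second delay.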



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.ftrl.Ftrl;
+import org.apache.flink.ml.classification.ftrl.FtrlModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private Table trainDenseTable;
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {0, 2, 3}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 3, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 2, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 2, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {7, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.of(DenseVector.class))
+                        .column("f1", DataTypes.DOUBLE())
+                        .build();
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream, schema).as(FEATURE_COL, LABEL_COL);
+    }
+
+    @Test
+    public void testFtrlWithInitLrModel() throws Exception {
+        Table initModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrl() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();

Review Comment:
   Let's verify the correctness of the output in all test cases. `env.execute()` only guarantees that the job does not throw an exception during execution; it does not show that the computed result is correct.
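   A minimal sketch of the kind of assertion meant here, assuming the FTRL model data exposes a coefficient vector and that a row is predicted as `1.0` when its dot product with the coefficients is non-negative (that is, sigmoid >= 0.5). `ModelCheck.accuracy` is a hypothetical helper, and the coefficient in `main` is hand-picked for illustration, not an output of `Ftrl`. In `FtrlTest`, the model rows could be collected with `executeAndCollect()`, as `verifyPredictionResult` already does, and the check applied to the last coefficient vector.

```java
/** Hypothetical helper: checks a linear model's coefficients against labeled rows. */
class ModelCheck {

    /** Fraction of rows whose 0/1 label matches the prediction (dot >= 0 means label 1.0). */
    static double accuracy(double[] coefficient, double[][] features, double[] labels) {
        int correct = 0;
        for (int i = 0; i < features.length; i++) {
            double dot = 0.0;
            for (int j = 0; j < coefficient.length; j++) {
                dot += coefficient[j] * features[i][j];
            }
            double prediction = dot >= 0.0 ? 1.0 : 0.0;
            if (prediction == labels[i]) {
                correct++;
            }
        }
        return (double) correct / features.length;
    }

    public static void main(String[] args) {
        // Feature and label values copied from TRAIN_DENSE_ROWS in the test above.
        double[][] x = {
            {0.1, 2.}, {0.2, 2.}, {0.3, 2.}, {0.4, 2.}, {0.5, 2.},
            {11., 12.}, {12., 11.}, {13., 12.}, {14., 12.}, {15., 12.}
        };
        double[] y = {0., 0., 0., 0., 0., 1., 1., 1., 1., 1.};
        // Illustrative coefficient chosen by hand, NOT a result produced by Ftrl.
        double[] w = {1.0, -0.5};
        System.out.println(accuracy(w, x, y));
    }
}
```

   A real test would then assert that the accuracy of the trained model on these rows is close to 1.0, rather than only printing the model.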



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlModelParams.java:
##########
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
+import org.apache.flink.ml.common.param.HasRawPredictionCol;
+
+/**
+ * Params for {@link FtrlModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlModelParams<T>
+        extends HasFeaturesCol<T>, HasPredictionCol<T>, HasRawPredictionCol<T> {}

Review Comment:
   This interface is identical to `LogisticRegressionModelParams`. If FTRL is an online version of LogisticRegression, shall we consider reorganizing the code to avoid duplicate classes like this?
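   A minimal sketch of the reuse, using hypothetical stand-in interfaces so that it runs without Flink: `FtrlModelParams` could simply extend `LogisticRegressionModelParams` (or `FtrlModel` could implement the latter directly), so the three mixins are declared only once. The real Flink ML mixins are built on `Param`/`WithParams`; the default getters and column names below are placeholders chosen for the sketch.

```java
class ParamsReuseDemo {

    // Hypothetical stand-ins for the param mixins, mimicked with default getters.
    interface HasFeaturesCol<T> {
        default String getFeaturesCol() {
            return "features";
        }
    }

    interface HasPredictionCol<T> {
        default String getPredictionCol() {
            return "prediction";
        }
    }

    interface HasRawPredictionCol<T> {
        default String getRawPredictionCol() {
            return "rawPrediction";
        }
    }

    /** Declared once, as in the logistic regression package. */
    interface LogisticRegressionModelParams<T>
            extends HasFeaturesCol<T>, HasPredictionCol<T>, HasRawPredictionCol<T> {}

    /** Keeps an FTRL-specific name but inherits the shared param set instead of redeclaring it. */
    interface FtrlModelParams<T> extends LogisticRegressionModelParams<T> {}

    /** Stand-in for FtrlModel; it gets every LR model param through FtrlModelParams. */
    static class FtrlModelStub implements FtrlModelParams<FtrlModelStub> {}

    public static void main(String[] args) {
        FtrlModelStub model = new FtrlModelStub();
        System.out.println(
                model.getFeaturesCol()
                        + ", "
                        + model.getPredictionCol()
                        + ", "
                        + model.getRawPredictionCol());
    }
}
```

   Keeping a named `FtrlModelParams` leaves room for FTRL-only model params later; implementing `LogisticRegressionModelParams` directly removes the extra interface entirely.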



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/classification/FtrlTest.java:
##########
@@ -0,0 +1,384 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.ftrl.Ftrl;
+import org.apache.flink.ml.classification.ftrl.FtrlModel;
+import org.apache.flink.ml.classification.logisticregression.LogisticRegression;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Schema;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static junit.framework.TestCase.assertEquals;
+
+/** Tests {@link Ftrl} and {@link FtrlModel}. */
+public class FtrlTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private static final String LABEL_COL = "label";
+    private static final String PREDICT_COL = "prediction";
+    private static final String FEATURE_COL = "features";
+    private static final String MODEL_VERSION_COL = "modelVersion";
+    private Table trainDenseTable;
+    private static final List<Row> TRAIN_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.1, 2.), 0.),
+                    Row.of(Vectors.dense(0.2, 2.), 0.),
+                    Row.of(Vectors.dense(0.3, 2.), 0.),
+                    Row.of(Vectors.dense(0.4, 2.), 0.),
+                    Row.of(Vectors.dense(0.5, 2.), 0.),
+                    Row.of(Vectors.dense(11., 12.), 1.),
+                    Row.of(Vectors.dense(12., 11.), 1.),
+                    Row.of(Vectors.dense(13., 12.), 1.),
+                    Row.of(Vectors.dense(14., 12.), 1.),
+                    Row.of(Vectors.dense(15., 12.), 1.));
+
+    private static final List<Row> TRAIN_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {0, 2, 3}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 3, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 8}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    private static final List<Row> PREDICT_DENSE_ROWS =
+            Arrays.asList(
+                    Row.of(Vectors.dense(0.8, 2.7), 0.),
+                    Row.of(Vectors.dense(0.8, 2.4), 0.),
+                    Row.of(Vectors.dense(0.7, 2.3), 0.),
+                    Row.of(Vectors.dense(0.4, 2.7), 0.),
+                    Row.of(Vectors.dense(0.5, 2.8), 0.),
+                    Row.of(Vectors.dense(10.2, 12.1), 1.),
+                    Row.of(Vectors.dense(13.3, 13.1), 1.),
+                    Row.of(Vectors.dense(13.5, 12.2), 1.),
+                    Row.of(Vectors.dense(14.9, 12.5), 1.),
+                    Row.of(Vectors.dense(15.5, 11.2), 1.));
+
+    private static final List<Row> PREDICT_SPARSE_ROWS =
+            Arrays.asList(
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 2, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {2, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(
+                                    10, new int[] {0, 1, 2, 4}, new double[] {1.0, 1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {1, 3, 4}, new double[] {1.0, 1.0, 1.0}),
+                            0.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {6, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {7, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 7, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 6, 7}, new double[] {1.0, 1.0, 1.0}),
+                            1.),
+                    Row.of(
+                            Vectors.sparse(10, new int[] {5, 8, 9}, new double[] {1.0, 1.0, 1.0}),
+                            1.));
+
+    @Before
+    public void before() throws Exception {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Schema schema =
+                Schema.newBuilder()
+                        .column("f0", DataTypes.of(DenseVector.class))
+                        .column("f1", DataTypes.DOUBLE())
+                        .build();
+        DataStream<Row> dataStream = env.fromCollection(TRAIN_DENSE_ROWS);
+        trainDenseTable = tEnv.fromDataStream(dataStream, schema).as(FEATURE_COL, LABEL_COL);
+    }
+
+    @Test
+    public void testFtrlWithInitLrModel() throws Exception {
+        Table initModel =
+                new LogisticRegression()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setLabelCol(LABEL_COL)
+                        .fit(trainDenseTable)
+                        .getModelData()[0];
+
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrl() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(20, 20, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        Table models =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(10)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlinePredictTable)
+                        .getModelData()[0];
+        tEnv.toDataStream(models).print();
+        env.execute();
+    }
+
+    @Test
+    public void testFtrlModel() throws Exception {
+        Table initModel =
+                tEnv.fromDataStream(
+                        env.fromElements(Row.of(new DenseVector(new double[] {0.0, 0.0}), 0L)));
+
+        Table onlineTrainTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, TRAIN_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, PREDICT_DENSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        FtrlModel model =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModel)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        verifyPredictionResult(resultTable);
+    }
+
+    @Test
+    public void testFtrlModelSparse() throws Exception {
+        Table initModelSparse =
+                tEnv.fromDataStream(
+                        env.fromElements(
+                                Row.of(
+                                        new DenseVector(
+                                                new double[] {
+                                                    0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1
+                                                }),
+                                        0L)));
+
+        Table onlineTrainTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, TRAIN_SPARSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+        Table onlinePredictTable =
+                tEnv.fromDataStream(
+                        env.addSource(
+                                new RandomSourceFunction(5, 2000, PREDICT_SPARSE_ROWS),
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(Vector.class), Types.DOUBLE
+                                        },
+                                        new String[] {FEATURE_COL, LABEL_COL})));
+
+        FtrlModel model =
+                new Ftrl()
+                        .setFeaturesCol(FEATURE_COL)
+                        .setInitialModelData(initModelSparse)
+                        .setGlobalBatchSize(500)
+                        .setLabelCol(LABEL_COL)
+                        .fit(onlineTrainTable);
+        Table resultTable = model.setPredictionCol(PREDICT_COL).transform(onlinePredictTable)[0];
+        tEnv.toDataStream(model.getModelData()[0]).print();
+        verifyPredictionResult(resultTable);
+    }
+
+    private static void verifyPredictionResult(Table output) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<Row> stream = tEnv.toDataStream(output);
+        List<Row> result = IteratorUtils.toList(stream.executeAndCollect());
+        Map<Long, Tuple2<Double, Double>> correctRatio = new HashMap<>();
+
+        for (Row row : result) {
+            long modelVersion = row.getFieldAs(MODEL_VERSION_COL);
+            Double pred = row.getFieldAs(PREDICT_COL);
+            Double label = row.getFieldAs(LABEL_COL);
+            if (correctRatio.containsKey(modelVersion)) {
+                Tuple2<Double, Double> t2 = correctRatio.get(modelVersion);
+                if (pred.equals(label)) {
+                    t2.f0 += 1.0;
+                }
+                t2.f1 += 1.0;
+            } else {
+                correctRatio.put(modelVersion, Tuple2.of(pred.equals(label) ? 1.0 : 0.0, 1.0));
+            }
+        }
+        for (Long id : correctRatio.keySet()) {
+            System.out.println(

Review Comment:
   Let's avoid printing debugging information in test cases.
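   As a rough sketch of an alternative, assuming the goal of the print loop is to inspect accuracy per model version, the test could compute those ratios and assert on them instead. The helper names (`PredictionChecks`, `accuracyByVersion`, `assertAccuracyAtLeast`) are hypothetical, and the sketch uses a plain holder class rather than Flink `Row`s so it stands alone.

   ```java
   import java.util.HashMap;
   import java.util.List;
   import java.util.Map;

   /** Hypothetical test helper: checks per-version accuracy with assertions instead of printing. */
   final class PredictionChecks {
       private PredictionChecks() {}

       /** Stand-in for one output row: model version, predicted label, true label. */
       static final class Prediction {
           final long modelVersion;
           final double prediction;
           final double label;

           Prediction(long modelVersion, double prediction, double label) {
               this.modelVersion = modelVersion;
               this.prediction = prediction;
               this.label = label;
           }
       }

       /** Returns correct / total for each model version. */
       static Map<Long, Double> accuracyByVersion(List<Prediction> results) {
           Map<Long, long[]> counts = new HashMap<>(); // version -> {correct, total}
           for (Prediction r : results) {
               long[] c = counts.computeIfAbsent(r.modelVersion, v -> new long[2]);
               if (r.prediction == r.label) {
                   c[0]++;
               }
               c[1]++;
           }
           Map<Long, Double> accuracy = new HashMap<>();
           counts.forEach((version, c) -> accuracy.put(version, (double) c[0] / c[1]));
           return accuracy;
       }

       /** Fails with a descriptive message instead of printing, so test logs stay clean. */
       static void assertAccuracyAtLeast(Map<Long, Double> accuracy, double threshold) {
           for (Map.Entry<Long, Double> e : accuracy.entrySet()) {
               if (e.getValue() < threshold) {
                   throw new AssertionError(
                           "Model version " + e.getKey() + " has accuracy " + e.getValue()
                                   + ", expected at least " + threshold);
               }
           }
       }
   }
   ```

   In `FtrlTest` the list would be built from the collected rows (`MODEL_VERSION_COL`, `PREDICT_COL`, `LABEL_COL`), a JUnit `assertTrue` could replace the manual `AssertionError`, and the threshold would be a per-test choice.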



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlParams.java:
##########
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link Ftrl}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlParams<T>
+        extends HasLabelCol<T>, HasBatchStrategy<T>, HasGlobalBatchSize<T>, HasFeaturesCol<T> {
+
+    Param<Integer> VECTOR_SIZE =
+            new IntParam("vectorSize", "The size of vector.", -1, ParamValidators.gt(-2));
+
+    default Integer getVectorSize() {

Review Comment:
   This param seems unused.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlParams.java:
##########
@@ -0,0 +1,92 @@
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link Ftrl}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlParams<T>
+        extends HasLabelCol<T>, HasBatchStrategy<T>, HasGlobalBatchSize<T>, HasFeaturesCol<T> {
+
+    Param<Integer> VECTOR_SIZE =
+            new IntParam("vectorSize", "The size of vector.", -1, ParamValidators.gt(-2));
+
+    default Integer getVectorSize() {
+        return get(VECTOR_SIZE);
+    }
+
+    default T setVectorSize(Integer value) {
+        return set(VECTOR_SIZE, value);
+    }
+
+    Param<Double> L_1 =
+            new DoubleParam("l1", "The parameter l1 of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getL1() {
+        return get(L_1);
+    }
+
+    default T setL1(Double value) {
+        return set(L_1, value);
+    }
+
+    Param<Double> L_2 =
+            new DoubleParam("l2", "The parameter l2 of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getL2() {

Review Comment:
   It seems that these parameters are also used in other algorithms, such as softmax and multi-layer perceptron. Let's define them as common parameters (e.g. alongside `HasLabelCol` in `org.apache.flink.ml.common.param`) so they can be reused.
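   A minimal sketch of the suggestion, assuming the shared params would follow the same mixin-interface pattern as the existing `HasLabelCol`/`HasFeaturesCol`. To keep it self-contained it uses a hypothetical `SimpleParams` stand-in for Flink ML's `WithParams`; `HasL1`, `HasL2`, `FtrlSketch` and `SoftmaxSketch` are hypothetical names too.

   ```java
   import java.util.HashMap;
   import java.util.Map;

   /** Hypothetical stand-in for Flink ML's WithParams: typed get/set over a param map. */
   interface SimpleParams<T> {
       Map<String, Object> getParamMap();

       @SuppressWarnings("unchecked")
       default <V> V get(String name, V defaultValue) {
           return (V) getParamMap().getOrDefault(name, defaultValue);
       }

       @SuppressWarnings("unchecked")
       default T set(String name, Object value) {
           getParamMap().put(name, value);
           return (T) this;
       }
   }

   /** Hypothetical shared mixin for the L1 regularization strength. */
   interface HasL1<T> extends SimpleParams<T> {
       String L_1 = "l1";

       default double getL1() {
           return get(L_1, 0.1);
       }

       default T setL1(double value) {
           // Mirrors ParamValidators.gt(0.0) used in FtrlParams.
           if (!(value > 0.0)) {
               throw new IllegalArgumentException("l1 must be > 0, got " + value);
           }
           return set(L_1, value);
       }
   }

   /** Hypothetical shared mixin for the L2 regularization strength. */
   interface HasL2<T> extends SimpleParams<T> {
       String L_2 = "l2";

       default double getL2() {
           return get(L_2, 0.1);
       }

       default T setL2(double value) {
           if (!(value > 0.0)) {
               throw new IllegalArgumentException("l2 must be > 0, got " + value);
           }
           return set(L_2, value);
       }
   }

   /** Each algorithm reuses the mixins instead of redefining l1/l2 itself. */
   class FtrlSketch implements HasL1<FtrlSketch>, HasL2<FtrlSketch> {
       private final Map<String, Object> paramMap = new HashMap<>();

       @Override
       public Map<String, Object> getParamMap() {
           return paramMap;
       }
   }

   class SoftmaxSketch implements HasL1<SoftmaxSketch>, HasL2<SoftmaxSketch> {
       private final Map<String, Object> paramMap = new HashMap<>();

       @Override
       public Map<String, Object> getParamMap() {
           return paramMap;
       }
   }
   ```

   With this layout, `FtrlParams` would simply extend the shared interfaces, and the name, default and validator of each param would live in one place.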



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlParams.java:
##########
@@ -0,0 +1,92 @@
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link Ftrl}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FtrlParams<T>
+        extends HasLabelCol<T>, HasBatchStrategy<T>, HasGlobalBatchSize<T>, HasFeaturesCol<T> {

Review Comment:
   It seems that some parameters provided by LogisticRegression, like `weightCol`, `reg` and `multiClass`, are not yet supported by FTRL. Is FTRL supposed to offer less functionality than LogisticRegression when both are used for offline training? If not, do we plan to add support for these parameters in the future?
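   For reference on the `weightCol` point, here is a rough sketch of how a sample weight could enter the standard per-coordinate FTRL-Proximal update for logistic regression: the weight scales the log-loss gradient. This is a standalone illustration on dense `double[]` inputs, not the PR's `FtrlLocalUpdater`; `WeightedFtrl` is a hypothetical name, and scaling the gradient is one common choice rather than something this PR defines.

   ```java
   /**
    * Hypothetical FTRL-Proximal learner for binary logistic regression that
    * accepts a per-sample weight (the kind of value a weightCol would supply).
    */
   final class WeightedFtrl {
       private final double alpha;
       private final double beta;
       private final double l1;
       private final double l2;
       private final double[] z; // per-coordinate accumulated (adjusted) gradients
       private final double[] n; // per-coordinate accumulated squared gradients

       WeightedFtrl(int dim, double alpha, double beta, double l1, double l2) {
           this.alpha = alpha;
           this.beta = beta;
           this.l1 = l1;
           this.l2 = l2;
           this.z = new double[dim];
           this.n = new double[dim];
       }

       /** Closed-form weight; exactly zero while |z_i| <= l1, which yields sparse models. */
       double weight(int i) {
           if (Math.abs(z[i]) <= l1) {
               return 0.0;
           }
           return -(z[i] - Math.signum(z[i]) * l1) / ((beta + Math.sqrt(n[i])) / alpha + l2);
       }

       /** Predicted probability of label 1. */
       double predict(double[] x) {
           double margin = 0.0;
           for (int i = 0; i < x.length; i++) {
               if (x[i] != 0.0) {
                   margin += weight(i) * x[i];
               }
           }
           return 1.0 / (1.0 + Math.exp(-margin));
       }

       /** One online step; the sample weight scales the logistic-loss gradient. */
       void update(double[] x, double label, double sampleWeight) {
           double p = predict(x);
           for (int i = 0; i < x.length; i++) {
               if (x[i] == 0.0) {
                   continue; // only non-zero features are touched
               }
               double w = weight(i); // read before n[i] changes
               double g = sampleWeight * (p - label) * x[i];
               double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
               z[i] += g - sigma * w;
               n[i] += g * g;
           }
       }
   }
   ```

   A sample weight of 0 leaves `z` and `n` untouched, which is the behavior one would expect from `weightCol`.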



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the Ftrl algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })
+                                    .setParallelism(1)));
+        }
+    }
+
+    /** Gets vector data. */
+    public static class GetVectorData implements MapFunction<LinearModelData, DenseVector[]> {
+        @Override
+        public DenseVector[] map(LinearModelData value) throws Exception {
+            return new DenseVector[] {value.coefficient};
+        }
+    }
+
+    /**
+     * Operator that collects a LogisticRegressionModelData from each upstream subtask, and outputs
+     * the weight average of collected model data.
+     */
+    public static class FtrlGlobalReducer implements ReduceFunction<DenseVector[]> {
+        @Override
+        public DenseVector[] reduce(DenseVector[] modelData, DenseVector[] newModelData) {
+            for (int i = 0; i < newModelData[0].size(); ++i) {
+                if ((modelData[1].values[i] + newModelData[1].values[i]) > 0.0) {
+                    newModelData[0].values[i] =
+                            (modelData[0].values[i] * modelData[1].values[i]
+                                            + newModelData[0].values[i] * newModelData[1].values[i])
+                                    / (modelData[1].values[i] + newModelData[1].values[i]);
+                }
+                newModelData[1].values[i] = modelData[1].values[i] + newModelData[1].values[i];
+            }
+            return newModelData;
+        }
+    }
+
+    /** Updates local ftrl model. */
+    public static class FtrlLocalUpdater extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<
+                    Tuple2<Vector, Double>[], DenseVector[], DenseVector[]> {
+        private ListState<Tuple2<Vector, Double>[]> localBatchDataState;
+        private ListState<DenseVector[]> modelDataState;
+        private double[] n;
+        private double[] z;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private DenseVector weights;
+
+        public FtrlLocalUpdater(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<Tuple2<Vector, Double>[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector[].class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Tuple2<Vector, Double>[]> pointsRecord)
+                throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector[]> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {

Review Comment:
   Would it be better to add some comments to this method, or to split it into several smaller methods? Either would improve readability.
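   One possible shape for the split, sketched under the assumption (not confirmed by the visible diff) that the method waits until both a model version and a buffered local batch are available, pairs them one-to-one, and emits a local update. It uses plain Java queues instead of Flink's `ListState`, and all names (`AlignedUpdater`, `readyToCompute`, `computeLocalUpdate`, `emit`) are hypothetical.

   ```java
   import java.util.ArrayDeque;
   import java.util.ArrayList;
   import java.util.Deque;
   import java.util.List;
   import java.util.function.BiFunction;

   /** Hypothetical two-input aligner: pairs each local batch with a pending model version. */
   final class AlignedUpdater<B, M> {
       private final Deque<B> pendingBatches = new ArrayDeque<>();
       private final Deque<M> pendingModels = new ArrayDeque<>();
       private final BiFunction<M, B, M> localUpdate;
       private final List<M> emitted = new ArrayList<>();

       AlignedUpdater(BiFunction<M, B, M> localUpdate) {
           this.localUpdate = localUpdate;
       }

       void onBatch(B batch) {
           pendingBatches.addLast(batch);
           alignAndCompute();
       }

       void onModel(M model) {
           pendingModels.addLast(model);
           alignAndCompute();
       }

       /** The top-level method now reads as a summary of the steps. */
       private void alignAndCompute() {
           while (readyToCompute()) {
               M model = pendingModels.pollFirst();
               B batch = pendingBatches.pollFirst();
               emit(computeLocalUpdate(model, batch));
           }
       }

       /** Step 1: an update needs one model version and one batch. */
       private boolean readyToCompute() {
           return !pendingModels.isEmpty() && !pendingBatches.isEmpty();
       }

       /** Step 2: the algorithm-specific part (e.g. the FTRL update), kept separate. */
       private M computeLocalUpdate(M model, B batch) {
           return localUpdate.apply(model, batch);
       }

       /** Step 3: in a Flink operator this would be output.collect(...). */
       private void emit(M update) {
           emitted.add(update);
       }

       List<M> emittedUpdates() {
           return emitted;
       }
   }
   ```

   Keeping the alignment logic generic also makes it easy to unit-test without a Flink runtime.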



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlModel.java:
##########
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which classifies data using the model data computed by {@link Ftrl}. */
+public class FtrlModel implements Model<FtrlModel>, FtrlModelParams<FtrlModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public FtrlModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE, Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(), getPredictionCol(), "modelVersion"));
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(LinearModelData.getModelDataStream(modelDataTable).broadcast())
+                        .transform(
+                                "PredictLabelOperator",
+                                outputTypeInfo,
+                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictLabelOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, LinearModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String featuresCol;
+        private ListState<Row> bufferedPointsState;
+        private DenseVector coefficient;
+        private long modelDataVersion = 0;
+
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("bufferedPoints", inputTypeInfo));
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            "MODEL_DATA_VERSION_GAUGE_KEY",
+                            (Gauge<String>) () -> Long.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
+            Row dataPoint = streamRecord.getValue();
+            // todo : predict data
+            Vector features = (Vector) dataPoint.getField(featuresCol);
+            Tuple2<Double, DenseVector> predictionResult = predictRaw(features, coefficient);
+            output.collect(
+                    new StreamRecord<>(
+                            Row.join(dataPoint, Row.of(predictionResult.f0, modelDataVersion))));
+        }
+
+        @Override
+        public void processElement2(StreamRecord<LinearModelData> streamRecord) throws Exception {
+            LinearModelData modelData = streamRecord.getValue();
+
+            coefficient = modelData.coefficient;
+            modelDataVersion = modelData.modelVersion;
+            // System.out.println("update model...");
+            // todo : receive model data.
+            // Preconditions.checkArgument(modelData.centroids.length <= k);
+            // centroids = modelData.centroids;
+            // modelDataVersion++;
+            // for (Row dataPoint : bufferedPointsState.get()) {
+            // processElement1(new StreamRecord<>(dataPoint));
+            // }
+            // bufferedPointsState.clear();

Review Comment:
   Let's clean up this method and remove the unused, commented-out code.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the Ftrl algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })
+                                    .setParallelism(1)));
+        }
+    }
+
+    /** Gets vector data. */
+    public static class GetVectorData implements MapFunction<LinearModelData, DenseVector[]> {
+        @Override
+        public DenseVector[] map(LinearModelData value) throws Exception {
+            return new DenseVector[] {value.coefficient};
+        }
+    }
+
+    /**
+     * Operator that collects a LogisticRegressionModelData from each upstream subtask, and outputs
+     * the weight average of collected model data.
+     */
+    public static class FtrlGlobalReducer implements ReduceFunction<DenseVector[]> {
+        @Override
+        public DenseVector[] reduce(DenseVector[] modelData, DenseVector[] newModelData) {
+            for (int i = 0; i < newModelData[0].size(); ++i) {
+                if ((modelData[1].values[i] + newModelData[1].values[i]) > 0.0) {
+                    newModelData[0].values[i] =
+                            (modelData[0].values[i] * modelData[1].values[i]
+                                            + newModelData[0].values[i] * newModelData[1].values[i])
+                                    / (modelData[1].values[i] + newModelData[1].values[i]);
+                }
+                newModelData[1].values[i] = modelData[1].values[i] + newModelData[1].values[i];
+            }
+            return newModelData;
+        }
+    }
+
+    /** Updates local ftrl model. */
+    public static class FtrlLocalUpdater extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<
+                    Tuple2<Vector, Double>[], DenseVector[], DenseVector[]> {
+        private ListState<Tuple2<Vector, Double>[]> localBatchDataState;
+        private ListState<DenseVector[]> modelDataState;
+        private double[] n;
+        private double[] z;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private DenseVector weights;
+
+        public FtrlLocalUpdater(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<Tuple2<Vector, Double>[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector[].class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Tuple2<Vector, Double>[]> pointsRecord)
+                throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector[]> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Tuple2<Vector, Double>[]> pointsList =
+                    IteratorUtils.toList(localBatchDataState.get().iterator());
+            Tuple2<Vector, Double>[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Tuple2<Vector, Double> point : points) {
+                if (n == null) {
+                    n = new double[point.f0.size()];
+                    z = new double[n.length];
+                    weights = new DenseVector(n.length);
+                }
+
+                double p = 0.0;
+                Arrays.fill(weights.values, 0.0);
+                if (point.f0 instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) point.f0;
+                    for (int i = 0; i < denseVector.size(); ++i) {
+                        if (Math.abs(z[i]) <= l1) {
+                            modelData[0].values[i] = 0.0;
+                        } else {
+                            modelData[0].values[i] =
+                                    ((z[i] < 0 ? -1 : 1) * l1 - z[i])
+                                            / ((beta + Math.sqrt(n[i])) / alpha + l2);
+                        }
+                        p += modelData[0].values[i] * denseVector.values[i];
+                    }
+                    p = 1 / (1 + Math.exp(-p));
+                    for (int i = 0; i < denseVector.size(); ++i) {
+                        double g = (p - point.f1) * denseVector.values[i];
+                        double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
+                        z[i] += g - sigma * modelData[0].values[i];
+                        n[i] += g * g;
+                        weights.values[i] += 1.0;
+                    }
+                } else {
+                    SparseVector sparseVector = (SparseVector) point.f0;
+                    for (int i = 0; i < sparseVector.indices.length; ++i) {
+                        int idx = sparseVector.indices[i];
+                        if (Math.abs(z[idx]) <= l1) {
+                            modelData[0].values[idx] = 0.0;
+                        } else {
+                            modelData[0].values[idx] =
+                                    ((z[idx] < 0 ? -1 : 1) * l1 - z[idx])
+                                            / ((beta + Math.sqrt(n[idx])) / alpha + l2);
+                        }
+                        p += modelData[0].values[idx] * sparseVector.values[i];
+                    }
+                    p = 1 / (1 + Math.exp(-p));
+                    for (int i = 0; i < sparseVector.indices.length; ++i) {
+                        int idx = sparseVector.indices[i];
+                        double g = (p - point.f1) * sparseVector.values[i];
+                        double sigma = (Math.sqrt(n[idx] + g * g) - Math.sqrt(n[idx])) / alpha;
+                        z[idx] += g - sigma * modelData[0].values[idx];
+                        n[idx] += g * g;
+                        weights.values[idx] += 1.0;

Review Comment:
   It seems that weights are only updated but not used.
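   If `weights` is meant to count how many samples updated each coordinate, it could be emitted next to the coefficients so that `FtrlGlobalReducer` weights its average by it. Note that `Arrays.fill(weights.values, 0.0)` appears to run once per point inside the `for` loop, so it would presumably need to move to once per batch for the counts to be meaningful. A minimal sketch of the merge step under that assumption, using plain `double[]` instead of `DenseVector`; `WeightedModelMerge` is a hypothetical name:

```java
import java.util.Arrays;

/**
 * Hypothetical helper (not part of the PR): merges two partial models by a per-coordinate
 * weighted average, where each weight counts the samples that touched that coordinate.
 */
final class WeightedModelMerge {
    private WeightedModelMerge() {}

    /** Returns {mergedCoefficients, summedCounts}. */
    static double[][] merge(double[] coefA, double[] countA, double[] coefB, double[] countB) {
        int dim = coefA.length;
        // Default for coordinates no subtask touched: keep the first model's value.
        double[] coef = Arrays.copyOf(coefA, dim);
        double[] count = new double[dim];
        for (int i = 0; i < dim; i++) {
            count[i] = countA[i] + countB[i];
            if (count[i] > 0.0) {
                coef[i] = (coefA[i] * countA[i] + coefB[i] * countB[i]) / count[i];
            }
        }
        return new double[][] {coef, count};
    }
}
```

   Because the summed counts are returned too, the merge can be applied pairwise in a `reduce`, as `FtrlGlobalReducer` already does.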



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/FtrlModel.java:
##########
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.metrics.Gauge;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which classifies data using the model data computed by {@link Ftrl}. */
+public class FtrlModel implements Model<FtrlModel>, FtrlModelParams<FtrlModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public FtrlModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.DOUBLE, Types.LONG),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(), getPredictionCol(), "modelVersion"));
+
+        DataStream<Row> predictionResult =
+                tEnv.toDataStream(inputs[0])
+                        .connect(LinearModelData.getModelDataStream(modelDataTable).broadcast())
+                        .transform(
+                                "PredictLabelOperator",
+                                outputTypeInfo,
+                                new PredictLabelOperator(inputTypeInfo, getFeaturesCol()));
+
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility operator used for prediction. */
+    private static class PredictLabelOperator extends AbstractStreamOperator<Row>
+            implements TwoInputStreamOperator<Row, LinearModelData, Row> {
+        private final RowTypeInfo inputTypeInfo;
+
+        private final String featuresCol;
+        private ListState<Row> bufferedPointsState;
+        private DenseVector coefficient;
+        private long modelDataVersion = 0;
+
+        public PredictLabelOperator(RowTypeInfo inputTypeInfo, String featuresCol) {
+            this.inputTypeInfo = inputTypeInfo;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            bufferedPointsState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("bufferedPoints", inputTypeInfo));
+        }
+
+        @Override
+        public void open() throws Exception {
+            super.open();
+
+            getRuntimeContext()
+                    .getMetricGroup()
+                    .gauge(
+                            "MODEL_DATA_VERSION_GAUGE_KEY",
+                            (Gauge<String>) () -> Long.toString(modelDataVersion));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Row> streamRecord) throws Exception {
+            Row dataPoint = streamRecord.getValue();
+            // todo : predict data
+            Vector features = (Vector) dataPoint.getField(featuresCol);
+            Tuple2<Double, DenseVector> predictionResult = predictRaw(features, coefficient);
+            output.collect(
+                    new StreamRecord<>(
+                            Row.join(dataPoint, Row.of(predictionResult.f0, modelDataVersion))));
+        }
+
+        @Override
+        public void processElement2(StreamRecord<LinearModelData> streamRecord) throws Exception {
+            LinearModelData modelData = streamRecord.getValue();
+
+            coefficient = modelData.coefficient;
+            modelDataVersion = modelData.modelVersion;
+            // System.out.println("update model...");
+            // todo : receive model data.
+            // Preconditions.checkArgument(modelData.centroids.length <= k);
+            // centroids = modelData.centroids;
+            // modelDataVersion++;
+            // for (Row dataPoint : bufferedPointsState.get()) {
+            // processElement1(new StreamRecord<>(dataPoint));
+            // }
+            // bufferedPointsState.clear();
+        }
+    }
+
+    /**
+     * The main logic that predicts one input record.
+     *
+     * @param feature The input feature.
+     * @param coefficient The model parameters.
+     * @return The prediction label and the raw probabilities.
+     */
+    public static Tuple2<Double, DenseVector> predictRaw(Vector feature, DenseVector coefficient)

Review Comment:
   Methods like this are almost identical to those in `LogisticRegressionModel`. It would be better to reuse the logic already defined in `LogisticRegression` and `LogisticRegressionModel`.
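   As an illustration of the kind of sharing this could enable, here is a minimal sketch that puts the dense/sparse dot product and the sigmoid step in one place. It assumes both models predict via sigmoid(w·x) with a 0.5 threshold and uses plain arrays instead of `Vector`; `LinearPredictionUtils` is hypothetical, so the actual method to reuse should be taken from `LogisticRegressionModel` itself:

```java
/**
 * Hypothetical shared helper (not an existing Flink ML class) for linear binary
 * classifiers such as logistic regression and FTRL.
 */
final class LinearPredictionUtils {
    private LinearPredictionUtils() {}

    /** Margin for a dense feature vector: sum over i of x_i * w_i. */
    static double dot(double[] features, double[] coefficient) {
        double margin = 0.0;
        for (int i = 0; i < features.length; i++) {
            margin += features[i] * coefficient[i];
        }
        return margin;
    }

    /** Margin for a sparse feature vector given as parallel index/value arrays. */
    static double dot(int[] indices, double[] values, double[] coefficient) {
        double margin = 0.0;
        for (int k = 0; k < indices.length; k++) {
            margin += values[k] * coefficient[indices[k]];
        }
        return margin;
    }

    /** Returns {label, P(label = 0), P(label = 1)} using a sigmoid and a 0.5 threshold. */
    static double[] predictFromMargin(double margin) {
        double p = 1.0 / (1.0 + Math.exp(-margin));
        return new double[] {p >= 0.5 ? 1.0 : 0.0, 1.0 - p, p};
    }
}
```

   With a helper like this, `FtrlModel` would only differ from `LogisticRegressionModel` in how it receives model updates, not in how it predicts.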



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the Ftrl algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** Iteration body of ftrl. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })
+                                    .setParallelism(1)));
+        }
+    }
+
+    /** Gets vector data. */
+    public static class GetVectorData implements MapFunction<LinearModelData, DenseVector[]> {
+        @Override
+        public DenseVector[] map(LinearModelData value) throws Exception {
+            return new DenseVector[] {value.coefficient};
+        }
+    }
+
+    /**
+     * Operator that collects a LogisticRegressionModelData from each upstream subtask, and outputs
+     * the weight average of collected model data.
+     */
+    public static class FtrlGlobalReducer implements ReduceFunction<DenseVector[]> {
+        @Override
+        public DenseVector[] reduce(DenseVector[] modelData, DenseVector[] newModelData) {
+            for (int i = 0; i < newModelData[0].size(); ++i) {
+                if ((modelData[1].values[i] + newModelData[1].values[i]) > 0.0) {
+                    newModelData[0].values[i] =
+                            (modelData[0].values[i] * modelData[1].values[i]
+                                            + newModelData[0].values[i] * newModelData[1].values[i])
+                                    / (modelData[1].values[i] + newModelData[1].values[i]);
+                }
+                newModelData[1].values[i] = modelData[1].values[i] + newModelData[1].values[i];
+            }
+            return newModelData;
+        }
+    }
+
+    /** Updates local ftrl model. */
+    public static class FtrlLocalUpdater extends AbstractStreamOperator<DenseVector[]>
+            implements TwoInputStreamOperator<
+                    Tuple2<Vector, Double>[], DenseVector[], DenseVector[]> {
+        private ListState<Tuple2<Vector, Double>[]> localBatchDataState;
+        private ListState<DenseVector[]> modelDataState;
+        private double[] n;
+        private double[] z;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private DenseVector weights;
+
+        public FtrlLocalUpdater(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<Tuple2<Vector, Double>[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector[].class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Tuple2<Vector, Double>[]> pointsRecord)
+                throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector[]> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Tuple2<Vector, Double>[]> pointsList =
+                    IteratorUtils.toList(localBatchDataState.get().iterator());
+            Tuple2<Vector, Double>[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
+            for (Tuple2<Vector, Double> point : points) {
+                if (n == null) {
+                    n = new double[point.f0.size()];
+                    z = new double[n.length];
+                    weights = new DenseVector(n.length);
+                }
+
+                double p = 0.0;
+                Arrays.fill(weights.values, 0.0);
+                if (point.f0 instanceof DenseVector) {

Review Comment:
   Some of the logic in the `if` and `else` branches is the same. Shall we move that shared code out of the if-else block?
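   One way to do this is to normalize both vector types to (indices, values) before the update, so the FTRL-Proximal logic exists only once. A minimal, self-contained sketch, assuming the same formulas as `FtrlLocalUpdater` above and using plain arrays; `FtrlState` and `denseIndices` are hypothetical names, not existing Flink ML APIs:

```java
import java.util.Arrays;

/**
 * Hypothetical sketch (not the PR's code): per-coordinate FTRL-Proximal state whose update
 * runs one code path for dense and sparse inputs, using the same formulas as FtrlLocalUpdater.
 */
final class FtrlState {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] n;
    private final double[] z;
    private final double[] w;

    FtrlState(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.n = new double[dim];
        this.z = new double[dim];
        this.w = new double[dim];
    }

    /** A dense vector of the given size is treated as indices 0..size-1. */
    static int[] denseIndices(int size) {
        int[] indices = new int[size];
        for (int i = 0; i < size; i++) {
            indices[i] = i;
        }
        return indices;
    }

    /** Updates the state with one labeled sample; returns the probability predicted before the update. */
    double update(int[] indices, double[] values, double label) {
        double margin = 0.0;
        for (int k = 0; k < indices.length; k++) {
            int i = indices[k];
            // Closed-form FTRL-Proximal weight for coordinate i (zero when |z_i| <= l1).
            w[i] = Math.abs(z[i]) <= l1
                    ? 0.0
                    : ((z[i] < 0 ? -1.0 : 1.0) * l1 - z[i]) / ((beta + Math.sqrt(n[i])) / alpha + l2);
            margin += w[i] * values[k];
        }
        double p = 1.0 / (1.0 + Math.exp(-margin));
        for (int k = 0; k < indices.length; k++) {
            int i = indices[k];
            double g = (p - label) * values[k]; // logistic-loss gradient for coordinate i
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * w[i];
            n[i] += g * g;
        }
        return p;
    }

    double[] weights() {
        return Arrays.copyOf(w, w.length);
    }
}
```

   With this shape, the `DenseVector` branch reduces to `update(FtrlState.denseIndices(v.size()), v.values, label)` and the `SparseVector` branch to `update(v.indices, v.values, label)`. The dense index array could be cached per dimension to avoid allocating it for every point.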



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/ftrl/Ftrl.java:
##########
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.ftrl;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.classification.logisticregression.LinearModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the FTRL (Follow-the-Regularized-Leader) algorithm. */
+public class Ftrl implements Estimator<Ftrl, FtrlModel>, FtrlParams<Ftrl> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table initModelDataTable;
+
+    public Ftrl() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public FtrlModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Tuple2<Vector, Double>> points =
+                tEnv.toDataStream(inputs[0]).map(new ParseSample(getFeaturesCol(), getLabelCol()));
+
+        DataStream<LinearModelData> modelDataStream =
+                LinearModelData.getModelDataStream(initModelDataTable);
+        DataStream<DenseVector[]> initModelData = modelDataStream.map(new GetVectorData());
+        initModelData.getTransformation().setParallelism(1);
+
+        IterationBody body =
+                new FtrlIterationBody(
+                        getGlobalBatchSize(), getAlpha(), getBETA(), getL1(), getL2());
+
+        DataStream<LinearModelData> onlineModelData =
+                Iterations.iterateUnboundedStreams(
+                                DataStreamList.of(initModelData), DataStreamList.of(points), body)
+                        .get(0);
+
+        Table onlineModelDataTable = tEnv.fromDataStream(onlineModelData);
+        FtrlModel model = new FtrlModel().setModelData(onlineModelDataTable);
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    /** The iteration body of the FTRL algorithm. */
+    public static class FtrlIterationBody implements IterationBody {
+        private final int batchSize;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+
+        public FtrlIterationBody(int batchSize, double alpha, double beta, double l1, double l2) {
+            this.batchSize = batchSize;
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public IterationBodyResult process(
+                DataStreamList variableStreams, DataStreamList dataStreams) {
+            DataStream<DenseVector[]> modelData = variableStreams.get(0);
+
+            DataStream<Tuple2<Vector, Double>> points = dataStreams.get(0);
+
+            int parallelism = points.getParallelism();
+            DataStream<DenseVector[]> newModelData =
+                    points.countWindowAll(batchSize)
+                            .apply(new GlobalBatchCreator())
+                            .flatMap(new GlobalBatchSplitter(parallelism))
+                            .rebalance()
+                            .connect(modelData.broadcast())
+                            .transform(
+                                    "ModelDataLocalUpdater",
+                                    TypeInformation.of(DenseVector[].class),
+                                    new FtrlLocalUpdater(alpha, beta, l1, l2))
+                            .setParallelism(parallelism)
+                            .countWindowAll(parallelism)
+                            .reduce(new FtrlGlobalReducer())
+                            .map(
+                                    (MapFunction<DenseVector[], DenseVector[]>)
+                                            value -> new DenseVector[] {value[0]})
+                            .setParallelism(1);
+
+            return new IterationBodyResult(
+                    DataStreamList.of(newModelData),
+                    DataStreamList.of(
+                            modelData
+                                    .map(
+                                            new MapFunction<DenseVector[], LinearModelData>() {
+                                                long iter = 0L;
+
+                                                @Override
+                                                public LinearModelData map(DenseVector[] value) {
+                                                    return new LinearModelData(value[0], iter++);
+                                                }
+                                            })
+                                    .setParallelism(1)));
+        }
+    }
+
+    /** Wraps the coefficient of a LinearModelData into a single-element DenseVector array. */
+    public static class GetVectorData implements MapFunction<LinearModelData, DenseVector[]> {
+        @Override
+        public DenseVector[] map(LinearModelData value) throws Exception {
+            return new DenseVector[] {value.coefficient};
+        }
+    }
+
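+    /**
+     * Reduce function that merges the local model data of the upstream subtasks. Each input holds
+     * the coefficients and the per-feature weights of one subtask; the output coefficients are the
+     * weighted average of the inputs, and the output weights are their sum.
+     */
+    public static class FtrlGlobalReducer implements ReduceFunction<DenseVector[]> {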
+        @Override
+        public DenseVector[] reduce(DenseVector[] modelData, DenseVector[] newModelData) {
+            for (int i = 0; i < newModelData[0].size(); ++i) {
+                if ((modelData[1].values[i] + newModelData[1].values[i]) > 0.0) {
+                    newModelData[0].values[i] =
+                            (modelData[0].values[i] * modelData[1].values[i]
+                                            + newModelData[0].values[i] * newModelData[1].values[i])
+                                    / (modelData[1].values[i] + newModelData[1].values[i]);
+                }
+                newModelData[1].values[i] = modelData[1].values[i] + newModelData[1].values[i];
+            }
+            return newModelData;
+        }
+    }
+
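+    /**
+     * Operator that updates the local FTRL model with each local batch of samples and emits the
+     * updated coefficients together with the per-feature weights used for global averaging.
+     */
+    public static class FtrlLocalUpdater extends AbstractStreamOperator<DenseVector[]>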
+            implements TwoInputStreamOperator<
+                    Tuple2<Vector, Double>[], DenseVector[], DenseVector[]> {
+        private ListState<Tuple2<Vector, Double>[]> localBatchDataState;
+        private ListState<DenseVector[]> modelDataState;
+        private double[] n;
+        private double[] z;
+        private final double alpha;
+        private final double beta;
+        private final double l1;
+        private final double l2;
+        private DenseVector weights;
+
+        public FtrlLocalUpdater(double alpha, double beta, double l1, double l2) {
+            this.alpha = alpha;
+            this.beta = beta;
+            this.l1 = l1;
+            this.l2 = l2;
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+
+            TypeInformation<Tuple2<Vector, Double>[]> type =
+                    ObjectArrayTypeInfo.getInfoFor(DenseVectorTypeInfo.INSTANCE);
+            localBatchDataState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("localBatch", type));
+            modelDataState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>("modelData", DenseVector[].class));
+        }
+
+        @Override
+        public void processElement1(StreamRecord<Tuple2<Vector, Double>[]> pointsRecord)
+                throws Exception {
+            localBatchDataState.add(pointsRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        @Override
+        public void processElement2(StreamRecord<DenseVector[]> modelDataRecord) throws Exception {
+            modelDataState.add(modelDataRecord.getValue());
+            alignAndComputeModelData();
+        }
+
+        private void alignAndComputeModelData() throws Exception {
+            if (!modelDataState.get().iterator().hasNext()
+                    || !localBatchDataState.get().iterator().hasNext()) {
+                return;
+            }
+
+            DenseVector[] modelData =
+                    OperatorStateUtils.getUniqueElement(modelDataState, "modelData").get();
+            modelDataState.clear();
+
+            List<Tuple2<Vector, Double>[]> pointsList =
+                    IteratorUtils.toList(localBatchDataState.get().iterator());
+            Tuple2<Vector, Double>[] points = pointsList.remove(0);
+            localBatchDataState.update(pointsList);
+
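+            // Resets the per-feature weights once per batch rather than once per point, so that
+            // they count the updates of every point in this batch.
+            if (weights != null) {
+                Arrays.fill(weights.values, 0.0);
+            }
+            for (Tuple2<Vector, Double> point : points) {
+                if (n == null) {
+                    n = new double[point.f0.size()];
+                    z = new double[n.length];
+                    weights = new DenseVector(n.length);
+                }
+
+                double p = 0.0;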
+                if (point.f0 instanceof DenseVector) {
+                    DenseVector denseVector = (DenseVector) point.f0;
+                    for (int i = 0; i < denseVector.size(); ++i) {
+                        if (Math.abs(z[i]) <= l1) {
+                            modelData[0].values[i] = 0.0;
+                        } else {
+                            modelData[0].values[i] =
+                                    ((z[i] < 0 ? -1 : 1) * l1 - z[i])
+                                            / ((beta + Math.sqrt(n[i])) / alpha + l2);
+                        }
+                        p += modelData[0].values[i] * denseVector.values[i];
+                    }
+                    p = 1 / (1 + Math.exp(-p));
+                    for (int i = 0; i < denseVector.size(); ++i) {
+                        double g = (p - point.f1) * denseVector.values[i];
+                        double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
+                        z[i] += g - sigma * modelData[0].values[i];
+                        n[i] += g * g;
+                        weights.values[i] += 1.0;
+                    }
+                } else {
+                    SparseVector sparseVector = (SparseVector) point.f0;
+                    for (int i = 0; i < sparseVector.indices.length; ++i) {
+                        int idx = sparseVector.indices[i];
+                        if (Math.abs(z[idx]) <= l1) {
+                            modelData[0].values[idx] = 0.0;
+                        } else {
+                            modelData[0].values[idx] =
+                                    ((z[idx] < 0 ? -1 : 1) * l1 - z[idx])
+                                            / ((beta + Math.sqrt(n[idx])) / alpha + l2);
+                        }
+                        p += modelData[0].values[idx] * sparseVector.values[i];
+                    }
+                    p = 1 / (1 + Math.exp(-p));
+                    for (int i = 0; i < sparseVector.indices.length; ++i) {
+                        int idx = sparseVector.indices[i];
+                        double g = (p - point.f1) * sparseVector.values[i];
+                        double sigma = (Math.sqrt(n[idx] + g * g) - Math.sqrt(n[idx])) / alpha;
+                        z[idx] += g - sigma * modelData[0].values[idx];
+                        n[idx] += g * g;
+                        weights.values[idx] += 1.0;
+                    }
+                }
+            }
+            output.collect(new StreamRecord<>(new DenseVector[] {modelData[0], weights}));
+        }
+    }
+
+    /** Parses samples of input data. */
+    public static class ParseSample extends RichMapFunction<Row, Tuple2<Vector, Double>> {
+        private static final long serialVersionUID = 3738888745125082777L;

Review Comment:
   Are variables like this required?
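To sanity-check the arithmetic in FtrlLocalUpdater (dense branch) and FtrlGlobalReducer outside of Flink, the same per-coordinate FTRL-Proximal update and weighted merge can be reproduced in plain Java. This is a minimal sketch under stated assumptions: the class `FtrlSketch` and its methods are hypothetical (not part of Flink ML), labels are assumed to be 0/1, only dense features are handled, and the formulas are copied from the code in this diff rather than taken from a reference implementation.

```java
import java.util.Arrays;

/**
 * Hypothetical, Flink-independent sketch: update() mirrors the dense branch of FtrlLocalUpdater,
 * weightedMerge() mirrors FtrlGlobalReducer. Not a Flink ML API.
 */
public class FtrlSketch {
    private final double alpha;
    private final double beta;
    private final double l1;
    private final double l2;
    private final double[] n; // per-coordinate sum of squared gradients
    private final double[] z; // per-coordinate FTRL accumulator
    private final double[] w; // coefficients derived lazily from z and n

    public FtrlSketch(int dim, double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
        this.n = new double[dim];
        this.z = new double[dim];
        this.w = new double[dim];
    }

    /** Recomputes w from z and n; coordinates with |z[i]| <= l1 are set to exactly zero. */
    private void computeCoefficients() {
        for (int i = 0; i < w.length; ++i) {
            w[i] = Math.abs(z[i]) <= l1
                    ? 0.0
                    : ((z[i] < 0 ? -1 : 1) * l1 - z[i]) / ((beta + Math.sqrt(n[i])) / alpha + l2);
        }
    }

    /** Applies one update for a dense sample with label in {0, 1}; returns the prior prediction. */
    public double update(double[] x, double label) {
        computeCoefficients();
        double margin = 0.0;
        for (int i = 0; i < x.length; ++i) {
            margin += w[i] * x[i];
        }
        double p = 1.0 / (1.0 + Math.exp(-margin)); // sigmoid of the margin
        for (int i = 0; i < x.length; ++i) {
            double g = (p - label) * x[i]; // log-loss gradient with respect to w[i]
            double sigma = (Math.sqrt(n[i] + g * g) - Math.sqrt(n[i])) / alpha;
            z[i] += g - sigma * w[i];
            n[i] += g * g;
        }
        return p;
    }

    /** Returns a copy of the current coefficients. */
    public double[] coefficients() {
        computeCoefficients();
        return w.clone();
    }

    /**
     * Merges two (coefficients, per-feature weights) pairs like FtrlGlobalReducer: a per-coordinate
     * weighted average (keeping b's coefficient when both weights are zero) and the summed weights.
     */
    public static double[][] weightedMerge(
            double[] coefA, double[] weightA, double[] coefB, double[] weightB) {
        double[] coef = coefB.clone();
        double[] weight = new double[weightB.length];
        for (int i = 0; i < coef.length; ++i) {
            double total = weightA[i] + weightB[i];
            if (total > 0.0) {
                coef[i] = (coefA[i] * weightA[i] + coefB[i] * weightB[i]) / total;
            }
            weight[i] = total;
        }
        return new double[][] {coef, weight};
    }

    public static void main(String[] args) {
        // Two one-hot features: feature 0 always has label 1, feature 1 always has label 0.
        FtrlSketch ftrl = new FtrlSketch(2, 0.1, 0.1, 0.0, 0.0);
        for (int epoch = 0; epoch < 100; ++epoch) {
            ftrl.update(new double[] {1.0, 0.0}, 1.0);
            ftrl.update(new double[] {0.0, 1.0}, 0.0);
        }
        System.out.println(Arrays.toString(ftrl.coefficients()));
    }
}
```

With `l1 = 0` the coefficient of the feature that always co-occurs with label 1 becomes positive and the other becomes negative; raising `l1` above the accumulated `|z|` keeps both coefficients at exactly zero, which is the sparsity effect of the L1 term in the closed-form weight computation.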



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.