lindong28 commented on code in PR #83:
URL: https://github.com/apache/flink-ml/pull/83#discussion_r883628874


##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -182,4 +193,63 @@ public void snapshotState(StateSnapshotContext context) throws Exception {
             }
         }
     }
+
+    /**
+     * A function that generate several data batches and distribute them to downstream operator.
+     *
+     * @param <T> Data type of batch data.
+     */
+    public static <T> DataStream<T[]> generateBatchData(

Review Comment:
   The current function name and the Java doc do not seem to capture the key functionality of this method, i.e. splitting the input data into global batches of `batchSize`, where each global batch is further split into `downStreamParallelism` local batches for the downstream operators.
   
   Also, the previous code that used `GlobalBatchSplitter` seems a bit more readable than the current version, which puts everything into one method with deeper indentation.
   
   Could you preserve the classes `GlobalBatchSplitter` and `GlobalBatchCreator`, and update the method name and its Java doc to make it a bit more self-explanatory?
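   To make the intended behavior concrete, here is a Flink-free sketch of the two-level split described above: the input is cut into global batches of `batchSize` elements, and each global batch is cut into `downStreamParallelism` local batches, one per downstream subtask. The class and method names are hypothetical, and spreading the remainder over the first local batches is an assumption, not necessarily what the PR does.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical illustration of the global-batch / local-batch split; not the PR's code. */
public class GlobalBatchSplitSketch {

    /** Cuts the input into consecutive global batches of at most batchSize elements. */
    public static <T> List<List<T>> toGlobalBatches(List<T> input, int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive");
        }
        List<List<T>> globalBatches = new ArrayList<>();
        for (int start = 0; start < input.size(); start += batchSize) {
            int end = Math.min(start + batchSize, input.size());
            globalBatches.add(new ArrayList<>(input.subList(start, end)));
        }
        return globalBatches;
    }

    /**
     * Splits one global batch into parallelism local batches whose sizes differ by at most one;
     * the first (size % parallelism) local batches receive one extra element.
     */
    public static <T> List<List<T>> toLocalBatches(List<T> globalBatch, int parallelism) {
        if (parallelism <= 0) {
            throw new IllegalArgumentException("parallelism must be positive");
        }
        List<List<T>> localBatches = new ArrayList<>();
        int base = globalBatch.size() / parallelism;
        int remainder = globalBatch.size() % parallelism;
        int start = 0;
        for (int i = 0; i < parallelism; i++) {
            int size = base + (i < remainder ? 1 : 0);
            localBatches.add(new ArrayList<>(globalBatch.subList(start, start + size)));
            start += size;
        }
        return localBatches;
    }
}
```

   Keeping the two steps in separate, named units (as `GlobalBatchSplitter` and `GlobalBatchCreator` did) lets each one be documented and tested on its own.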



##########
flink-ml-core/src/test/java/org/apache/flink/ml/linalg/BLASTest.java:
##########
@@ -70,7 +70,18 @@ public void testAxpyK() {
     @Test
     public void testDot() {
         DenseVector anotherDenseVec = Vectors.dense(1, 2, 3, 4, 5);
+        SparseVector sparseVector1 =
+                Vectors.sparse(5, new int[] {1, 2, 4}, new double[] {1., 1., 4.});
+        SparseVector sparseVector2 =
+                Vectors.sparse(5, new int[] {1, 3, 4}, new double[] {1., 2., 1.});
+        // Tests Dot(dense, dense).

Review Comment:
   nit: Since the method name is `dot(...)`, would it be more intuitive to use `dot(dense, dense)` here?
   
   Same for the lines below.
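   For reference, the expected values of the new test cases can be checked by hand. The sketch below computes `dot(dense, sparse)` and `dot(sparse, sparse)` on plain arrays using the same vectors as the test, with each sparse vector given as sorted index/value arrays that mirror the `Vectors.sparse(...)` arguments. The helper names are hypothetical; this is not Flink ML's `BLAS` code.

```java
/** Hypothetical plain-array dot products used to double-check the expected test values. */
public class DotCheckSketch {

    /** dot(dense, sparse): only the sparse vector's stored entries contribute. */
    public static double denseSparseDot(double[] dense, int[] indices, double[] values) {
        double sum = 0.0;
        for (int k = 0; k < indices.length; k++) {
            sum += dense[indices[k]] * values[k];
        }
        return sum;
    }

    /** dot(sparse, sparse): walks both sorted index arrays and multiplies matching entries. */
    public static double sparseSparseDot(int[] xIdx, double[] xVal, int[] yIdx, double[] yVal) {
        double sum = 0.0;
        int i = 0;
        int j = 0;
        while (i < xIdx.length && j < yIdx.length) {
            if (xIdx[i] == yIdx[j]) {
                sum += xVal[i++] * yVal[j++];
            } else if (xIdx[i] < yIdx[j]) {
                i++;
            } else {
                j++;
            }
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] dense = {1, 2, 3, 4, 5};
        int[] idx1 = {1, 2, 4};
        double[] val1 = {1., 1., 4.};
        int[] idx2 = {1, 3, 4};
        double[] val2 = {1., 2., 1.};
        System.out.println(denseSparseDot(dense, idx1, val1)); // prints 25.0
        System.out.println(sparseSparseDot(idx1, val1, idx2, val2)); // prints 5.0
    }
}
```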



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java:
##########
@@ -161,7 +162,8 @@ public Row map(Row dataPoint) {
      * @param coefficient The model parameters.
      * @return The prediction label and the raw probabilities.
      */
-    private static Row predictOneDataPoint(DenseVector feature, DenseVector coefficient) {
+    protected static Row predictOneDataPoint(Vector feature, DenseVector coefficient) {
+

Review Comment:
   Is this empty line necessary?



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasWeightCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasReg<T>,
+                HasElasticNet<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    Param<Double> ALPHA =
+            new DoubleParam("alpha", "The parameter alpha of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getAlpha() {
+        return get(ALPHA);
+    }
+
+    default T setAlpha(Double value) {
+        return set(ALPHA, value);
+    }
+
+    Param<Double> BETA =
+            new DoubleParam("alpha", "The parameter beta of ftrl.", 0.1, ParamValidators.gt(0.0));

Review Comment:
   Hmm... how would users know what FTRL is when they read this doc?
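   One option is for the doc to state where `alpha` and `beta` enter. Assuming the PR implements the standard FTRL-Proximal algorithm (McMahan et al., "Ad Click Prediction: a View from the Trenches", KDD 2013), they appear only in the per-coordinate learning rate and in the closed-form weight below, where g are the gradients, n the running sum of squared gradients, z the accumulated adjusted gradients, and λ1, λ2 the L1/L2 strengths. How λ1 and λ2 are derived from `reg` and `elasticNet` in this PR is an assumption to confirm.

```latex
\eta_{t,i} = \frac{\alpha}{\beta + \sqrt{n_{t,i}}}, \qquad n_{t,i} = \sum_{s=1}^{t} g_{s,i}^2

w_{t+1,i} =
\begin{cases}
0 & \text{if } |z_{t,i}| \le \lambda_1 \\
-\left(\dfrac{\beta + \sqrt{n_{t,i}}}{\alpha} + \lambda_2\right)^{-1}
  \left(z_{t,i} - \operatorname{sgn}(z_{t,i})\,\lambda_1\right) & \text{otherwise}
\end{cases}

z_{t,i} = \sum_{s=1}^{t} \left(g_{s,i} - \sigma_{s,i}\, w_{s,i}\right), \qquad
\sigma_{s,i} = \frac{\sqrt{n_{s,i}} - \sqrt{n_{s-1,i}}}{\alpha}
```

   In words: `alpha` scales the learning rate and `beta` keeps it bounded while few gradients have been seen for a coordinate.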



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -18,40 +18,86 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.serialization.Encoder;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.Random;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
 public class LogisticRegressionModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;
 
-    public LogisticRegressionModelData(DenseVector coefficient) {
+    /* todo : modelVersion is a new idea to manage model, other algorithm if needed should add this parameter later. */
+    public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) {
         this.coefficient = coefficient;
+        this.modelVersion = modelVersion;
     }
 
     public LogisticRegressionModelData() {}
 
+    public LogisticRegressionModelData(DenseVector coefficient) {

Review Comment:
   There is only one place that calls this constructor.
   
   Instead of adding this constructor for the specific case where modelVersion=0, would it be simpler to update the caller code as follows, so that this class is simpler and more consistent with other model classes?
   
   ```
   DataStream<LogisticRegressionModelData> modelData =
           rawModelData.map(vector -> new LogisticRegressionModelData(vector, 0));
   ```



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -18,40 +18,86 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.serialization.Encoder;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.Random;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
 public class LogisticRegressionModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;
 
-    public LogisticRegressionModelData(DenseVector coefficient) {
+    /* todo : modelVersion is a new idea to manage model, other algorithm if needed should add this parameter later. */

Review Comment:
   It is not clear what the actionable item for this TODO is.
   
   Since `LogisticRegressionModelData` already has `modelVersion`, is this TODO still needed?



##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java:
##########
@@ -65,12 +65,54 @@ public static void hDot(Vector x, Vector y) {
         }
     }
 
-    /** x \cdot y . */
-    public static double dot(DenseVector x, DenseVector y) {
+    /** Computes the dot of the two vectors (y = y \dot x). */

Review Comment:
   The Java doc seems incorrect since this method actually does not update `y`.
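   A Java doc that matches the behavior could simply say what is returned and that the inputs are only read. A minimal plain-array sketch of that contract, assuming the method has no side effects (the class name and the exception message are hypothetical; this is not Flink ML's `BLAS`):

```java
/** Hypothetical stand-in for BLAS.dot on plain arrays. */
public class BlasDotSketch {

    /**
     * Computes the dot product of x and y, i.e. the sum of x[i] * y[i]. Neither x nor y is
     * modified.
     *
     * @throws IllegalArgumentException if x and y have different sizes
     */
    public static double dot(double[] x, double[] y) {
        if (x.length != y.length) {
            throw new IllegalArgumentException("x and y must have the same size");
        }
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            sum += x[i] * y[i];
        }
        return sum;
    }
}
```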



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegression.java:
##########
@@ -0,0 +1,418 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
+import org.apache.flink.iteration.DataStreamList;
+import org.apache.flink.iteration.IterationBody;
+import org.apache.flink.iteration.IterationBodyResult;
+import org.apache.flink.iteration.Iterations;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+/** An Estimator which implements the online logistic regression algorithm. */

Review Comment:
   Should we provide a reference/link to the original paper, so that users could know what the algorithm is and why it is useful?
   
   Feel free to see KMeans.scala in Spark ML for an example doc.
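   As a possible starting point for such a doc, here is a minimal, hypothetical sketch of the per-coordinate state that FTRL-Proximal (McMahan et al., KDD 2013) maintains, assuming that is the algorithm the PR implements: two accumulators `z` and `n`, from which a weight that is exactly zero while |z| <= l1 is computed in closed form. The class, field names, and the `l1`/`l2` inputs are illustrative and not taken from the PR.

```java
/**
 * Hypothetical single-coordinate FTRL-Proximal state, following McMahan et al. (KDD 2013).
 * Illustrative only; not the PR's implementation.
 */
public class FtrlCoordinateSketch {
    private final double alpha; // scales the per-coordinate learning rate
    private final double beta; // smooths the learning rate while few gradients have been seen
    private final double l1; // L1 strength: the weight stays exactly 0 while |z| <= l1
    private final double l2; // L2 strength
    private double z = 0.0; // accumulated gradients, adjusted by sigma * w
    private double n = 0.0; // accumulated squared gradients

    public FtrlCoordinateSketch(double alpha, double beta, double l1, double l2) {
        this.alpha = alpha;
        this.beta = beta;
        this.l1 = l1;
        this.l2 = l2;
    }

    /** Closed-form weight computed from the current (z, n). */
    public double weight() {
        if (Math.abs(z) <= l1) {
            return 0.0;
        }
        return -(z - Math.signum(z) * l1) / ((beta + Math.sqrt(n)) / alpha + l2);
    }

    /** Current per-coordinate learning rate alpha / (beta + sqrt(n)). */
    public double learningRate() {
        return alpha / (beta + Math.sqrt(n));
    }

    /** Folds the gradient g observed for this coordinate into (z, n). */
    public void update(double g) {
        double w = weight();
        double sigma = (Math.sqrt(n + g * g) - Math.sqrt(n)) / alpha;
        z += g - sigma * w;
        n += g * g;
    }
}
```

   Because the weight is recomputed from `z` and `n` on demand, a sparse input only needs to touch the coordinates it contains, which is a large part of why the paper proposes it for high-dimensional online training.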



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java:
##########
@@ -18,40 +18,86 @@
 
 package org.apache.flink.ml.classification.logisticregression;
 
+import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.serialization.Encoder;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.core.fs.FSDataInputStream;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
 import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.OutputStream;
+import java.util.Random;
 
 /**
- * Model data of {@link LogisticRegressionModel}.
+ * Model data of {@link LogisticRegressionModel} and {@link OnlineLogisticRegressionModel}.
  *
  * <p>This class also provides methods to convert model data from Table to Datastream, and classes
  * to save/load model data.
  */
 public class LogisticRegressionModelData {
 
     public DenseVector coefficient;
+    public long modelVersion;
 
-    public LogisticRegressionModelData(DenseVector coefficient) {
+    /* todo : modelVersion is a new idea to manage model, other algorithm if needed should add this parameter later. */
+    public LogisticRegressionModelData(DenseVector coefficient, long modelVersion) {
         this.coefficient = coefficient;
+        this.modelVersion = modelVersion;
     }
 
     public LogisticRegressionModelData() {}
 
+    public LogisticRegressionModelData(DenseVector coefficient) {
+        this.coefficient = coefficient;
+        this.modelVersion = 0L;
+    }
+
+    /**
+     * Generates a Table containing a {@link LogisticRegressionModelData} instance with randomly
+     * generated coefficient.
+     *
+     * @param tEnv The environment where to create the table.
+     * @param dim The size of generated coefficient.
+     * @param seed Random seed.
+     */
+    public static Table generateRandomModelData(StreamTableEnvironment tEnv, int dim, int seed) {
+        StreamExecutionEnvironment env = TableUtils.getExecutionEnvironment(tEnv);
+        return tEnv.fromDataStream(env.fromElements(1).map(new GenerateRandomModel(dim, seed)));
+    }
+
+    private static class GenerateRandomModel

Review Comment:
   We typically use noun instead of verb as the class name. How about 
`RandomModelDataGenerator`?
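   For illustration, a Flink-free sketch of the renamed class, assuming it only needs to produce a coefficient of size `dim` from a fixed `seed`. In the PR it appears to be a `MapFunction` applied to a single-element stream and wrapping the result in `LogisticRegressionModelData`; that wiring is omitted so the sketch runs on its own, and drawing values with `Random.nextDouble()` is an assumption.

```java
import java.util.Random;

/** Hypothetical, Flink-free sketch of a noun-named generator class; not the PR's code. */
public class RandomModelDataGenerator {
    private final int dim;
    private final int seed;

    public RandomModelDataGenerator(int dim, int seed) {
        this.dim = dim;
        this.seed = seed;
    }

    /** Returns a coefficient of size dim; the same seed always yields the same values. */
    public double[] generateCoefficient() {
        Random random = new Random(seed);
        double[] coefficient = new double[dim];
        for (int i = 0; i < dim; i++) {
            coefficient[i] = random.nextDouble();
        }
        return coefficient;
    }
}
```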



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/OnlineLogisticRegressionParams.java:
##########
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.ml.common.param.HasBatchStrategy;
+import org.apache.flink.ml.common.param.HasElasticNet;
+import org.apache.flink.ml.common.param.HasGlobalBatchSize;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.common.param.HasReg;
+import org.apache.flink.ml.common.param.HasWeightCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link OnlineLogisticRegression}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface OnlineLogisticRegressionParams<T>
+        extends HasLabelCol<T>,
+                HasWeightCol<T>,
+                HasBatchStrategy<T>,
+                HasGlobalBatchSize<T>,
+                HasReg<T>,
+                HasElasticNet<T>,
+                OnlineLogisticRegressionModelParams<T> {
+
+    Param<Double> ALPHA =
+            new DoubleParam("alpha", "The parameter alpha of ftrl.", 0.1, ParamValidators.gt(0.0));
+
+    default Double getAlpha() {
+        return get(ALPHA);
+    }
+
+    default T setAlpha(Double value) {
+        return set(ALPHA, value);
+    }
+
+    Param<Double> BETA =

Review Comment:
   Since we typically declare variables before methods, could this variable be moved above `getAlpha()`?
   
   Feel free to see `StandardScalerParams` for an example.
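   The suggested ordering, all param constants first and then the getters and setters, is sketched below. To keep it self-contained it uses a hypothetical minimal param holder (`SimpleParam`, `SimpleWithParams`) instead of Flink ML's `Param`/`WithParams` API, so only the layout carries over. The sketch also names `BETA` "beta", since the quoted diff passes "alpha" as the name of `BETA`, which looks like a copy-paste slip.

```java
import java.util.HashMap;
import java.util.Map;

/** Hypothetical, Flink-free sketch of the suggested declaration order; not Flink ML's API. */
public class ParamsLayoutSketch {

    /** Minimal stand-in for a typed param: a name, a description, and a default value. */
    public static final class SimpleParam<V> {
        public final String name;
        public final String description;
        public final V defaultValue;

        public SimpleParam(String name, String description, V defaultValue) {
            this.name = name;
            this.description = description;
            this.defaultValue = defaultValue;
        }
    }

    /** Minimal stand-in for a params holder whose setters return T for chaining. */
    public interface SimpleWithParams<T> {
        Map<String, Object> paramMap();

        @SuppressWarnings("unchecked")
        default <V> V get(SimpleParam<V> param) {
            Object value = paramMap().get(param.name);
            return value == null ? param.defaultValue : (V) value;
        }

        @SuppressWarnings("unchecked")
        default <V> T set(SimpleParam<V> param, V value) {
            paramMap().put(param.name, value);
            return (T) this;
        }
    }

    /** Param constants are declared first, followed by the getters and setters. */
    public interface OnlineLrParamsSketch<T> extends SimpleWithParams<T> {
        SimpleParam<Double> ALPHA = new SimpleParam<>("alpha", "The alpha parameter of FTRL.", 0.1);
        // Named "beta" here; the quoted diff passes "alpha" as the name of BETA.
        SimpleParam<Double> BETA = new SimpleParam<>("beta", "The beta parameter of FTRL.", 0.1);

        default Double getAlpha() { return get(ALPHA); }

        default T setAlpha(Double value) { return set(ALPHA, value); }

        default Double getBeta() { return get(BETA); }

        default T setBeta(Double value) { return set(BETA, value); }
    }

    /** Hypothetical estimator that only carries the params. */
    public static final class EstimatorSketch implements OnlineLrParamsSketch<EstimatorSketch> {
        private final Map<String, Object> params = new HashMap<>();

        @Override
        public Map<String, Object> paramMap() {
            return params;
        }
    }
}
```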


