lindong28 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r757008784
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasMaxIter.java
##########
@@ -26,7 +26,7 @@
/** Interface for the shared maxIter param. */
public interface HasMaxIter<T> extends WithParams<T> {
Param<Integer> MAX_ITER =
- new IntParam("maxIter", "Maximum number of iterations.", 20, ParamValidators.gtEq(0));
+ new IntParam("maxIter", "Maximum number of iterations.", 20, ParamValidators.gt(0));
Review comment:
Could we just update all algorithms to support `maxIter=0` instead of
disallowing it?
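To illustrate what supporting `maxIter=0` could mean, here is a minimal plain-Java sketch, not Flink ML code. `IterativeTrainer` and `train` are hypothetical names, and the assumed semantics are that zero iterations simply returns the initial model:

```java
import java.util.function.UnaryOperator;

/** Hypothetical stand-in for an iterative algorithm; not part of Flink ML. */
final class IterativeTrainer {
    private IterativeTrainer() {}

    /**
     * Applies {@code step} to the model {@code maxIter} times.
     * Assumption: with maxIter == 0 the loop body never runs, so the
     * initial model is returned unchanged instead of being rejected.
     */
    static <M> M train(M initialModel, UnaryOperator<M> step, int maxIter) {
        if (maxIter < 0) {
            throw new IllegalArgumentException("maxIter must be >= 0, got " + maxIter);
        }
        M model = initialModel;
        for (int i = 0; i < maxIter; i++) {
            model = step.apply(model);
        }
        return model;
    }
}
```

Under this assumption the validator could stay at `ParamValidators.gtEq(0)`, and each algorithm would only need to make sure its iteration body tolerates running zero times.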
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/SortPartitionImpl.java
##########
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.datastream;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeutils.TypeComparator;
+import org.apache.flink.ml.common.utils.ComparatorAdapter;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.util.List;
+
+/** Applies sortPartition to a bounded data stream. */
+class SortPartitionImpl {
Review comment:
Having a dedicated file for a package-private static class seems like overkill,
and it is a rare pattern in Flink. The typical approach is to put such classes
in the same file as the public static method that uses them (e.g.
DataStreamUtils).
A dedicated file might make sense if the number of static classes/methods is
more than 4 (e.g. AllReduce).
It is probably simpler to move
`SortPartitionImpl/DistinctImpl/MapPartitionImpl` into DataStreamUtils.
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
##########
@@ -146,7 +145,7 @@ public IterationBodyResult process(
DataStream<DenseVector> points = dataStreams.get(0);
DataStream<Integer> terminationCriteria =
- centroids.flatMap(new TerminateOnMaxIterationNum<>(maxIterationNum));
+ centroids.map(x -> 0.).flatMap(new TerminationCriteria(maxIterationNum));
Review comment:
Would it be better to still allow `TerminationCriteria` to accept an
arbitrary input type, so that the caller does not have to explicitly
convert the input type to Integer?
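A minimal sketch of the idea, in plain Java rather than against the Flink interfaces. `MaxIterTermination`, `onRecord`, and `onEpochFinished` are hypothetical names, and the convention that emitting nothing ends the iteration is an assumption made for illustration:

```java
import java.util.function.Consumer;

/**
 * Plain-Java stand-in for a termination criteria function. The names are
 * hypothetical; a Flink version would implement the relevant Flink interfaces.
 */
final class MaxIterTermination<T> {
    private final int maxIter;

    MaxIterTermination(int maxIter) {
        this.maxIter = maxIter;
    }

    /** Records of any type T are accepted and ignored. */
    void onRecord(T ignored) {}

    /**
     * Called when an epoch (0-based) finishes. Emits a dummy value while more
     * iterations should run. Assumption: emitting nothing signals termination.
     */
    void onEpochFinished(int epoch, Consumer<Integer> out) {
        if (epoch + 1 < maxIter) {
            out.accept(0);
        }
    }
}
```

Because the class is generic in `T`, a caller could pass the centroids stream directly and would not need a preceding `map` that exists only to change the element type.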
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/EndOfStreamWindows.java
##########
@@ -37,6 +37,8 @@
private static final EndOfStreamWindows INSTANCE = new EndOfStreamWindows();
+ private static final TimeWindow FOREVER_WINDOW = new TimeWindow(Long.MIN_VALUE, Long.MAX_VALUE);
Review comment:
The word `forever` seems a bit `rich`, and it is rarely used in a variable
name.
Can we use a more neutral name such as `TIME_WINDOW_INSTANCE`? The variable
name does not need to be very descriptive here since its role is obvious from
its constructor parameter values.
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticGradient.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.ml.common.linalg.BLAS;
+
+import java.io.Serializable;
+
+/** Utility class to compute gradient and loss for logistic loss. */
Review comment:
Spark has fairly detailed comments for `LogisticGradient`, and it also
provides a reference link to the detailed mathematical derivation. Do we need
to provide more information here as well?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModelData.java
##########
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.List;
+
+/** Model data of {@link LogisticRegressionModel}. */
+public class LogisticRegressionModelData {
+
+ public final double[] coefficient;
+
Review comment:
nit: could we remove the empty line between those variables?
##########
File path:
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/linear/LogisticRegressionModel.java
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.linear;
+
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.ml.api.core.Model;
+import org.apache.flink.ml.classification.linear.LogisticRegressionModelData.LogisticRegressionModelDataEncoder;
+import org.apache.flink.ml.classification.linear.LogisticRegressionModelData.LogisticRegressionModelDataStreamFormat;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.linalg.BLAS;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** This class implements {@link Model} for {@link LogisticRegression}. */
+public class LogisticRegressionModel
+ implements Model<LogisticRegressionModel>,
+ LogisticRegressionModelParams<LogisticRegressionModel> {
+
+ private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ private Table model;
+
+ public LogisticRegressionModel() {
+ ParamUtils.initializeMapWithDefaultValues(this.paramMap, this);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) model).getTableEnvironment();
+ String dataPath = ReadWriteUtils.getDataPath(path);
+ FileSink<LogisticRegressionModelData> sink =
+ FileSink.forRowFormat(new Path(dataPath), new LogisticRegressionModelDataEncoder())
+ .withRollingPolicy(OnCheckpointRollingPolicy.build())
+ .withBucketAssigner(new BasePathBucketAssigner<>())
+ .build();
+ ReadWriteUtils.saveMetadata(this, path);
+ tEnv.toDataStream(model)
+ .map(x -> (LogisticRegressionModelData) x.getField(0))
+ .sinkTo(sink)
+ .setParallelism(1);
+ }
+
+ public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+ throws IOException {
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+ Source<LogisticRegressionModelData, ?, ?> source =
+ FileSource.forRecordStreamFormat(
+ new LogisticRegressionModelDataStreamFormat(),
+ ReadWriteUtils.getDataPaths(path))
+ .build();
+ LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+ DataStream<LogisticRegressionModelData> modelData =
+ env.fromSource(source, WatermarkStrategy.noWatermarks(), "modelData");
+ model.setModelData(tEnv.fromDataStream(modelData));
+ return model;
+ }
+
+ @Override
+ public LogisticRegressionModel setModelData(Table... inputs) {
+ model = inputs[0];
+ return this;
+ }
+
+ @Override
+ public Table[] getModelData() {
+ return new Table[] {model};
+ }
+
+ @Override
+ public Table[] transform(Table... inputs) {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+ DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+ final String broadcastModelName = "broadcastModel";
+ DataStream<LogisticRegressionModelData> modelData =
+ tEnv.toDataStream(model).map(x -> (LogisticRegressionModelData) x.getField(0));
Review comment:
Could you add a static
`LogisticRegressionModelData::getModelDataStream` method to do this conversion?
The NaiveBayes PR does this and I find it useful. The reason is that this
approach centralizes the knowledge of the model data schema in one class. In
the future, when we provide util methods for users to manually construct and
set model data, those methods can easily be added in e.g.
`LogisticRegressionModelData`.
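As a rough illustration of the centralization argument, here is a plain-Java sketch. `ModelDataSketch` and `getModelData` are hypothetical stand-ins, and each row is modeled as an `Object[]` instead of a Flink `Row`; a real `getModelDataStream` would take a `Table` and return a `DataStream`:

```java
import java.util.List;
import java.util.stream.Collectors;

/** Hypothetical stand-in for a model data class; not the class in this PR. */
final class ModelDataSketch {
    final double[] coefficient;

    ModelDataSketch(double[] coefficient) {
        this.coefficient = coefficient;
    }

    /**
     * The single place that knows the model data sits in field 0 of each row.
     * Callers such as save() and transform() would use this helper instead of
     * repeating the cast themselves.
     */
    static List<ModelDataSketch> getModelData(List<Object[]> rows) {
        return rows.stream()
                .map(row -> (ModelDataSketch) row[0])
                .collect(Collectors.toList());
    }
}
```

If the schema later changes, for example by adding a second field, only this helper would need to be updated.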