yunfengzhou-hub commented on code in PR #90:
URL: https://github.com/apache/flink-ml/pull/90#discussion_r860433041
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/optimizer/Optimizer.java:
##########
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.common.lossfunc.LossFunc;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+
+/**
+ * An optimizer is a function to modify the weight of a machine learning model, which aims to find
+ * the optimal parameter configuration for a machine learning model. Examples of optimizers could be
+ * stochastic gradient descent (SGD), L-BFGS, etc.
+ *
+ * @param <ParamType> Type of the optimizer-related parameter.
+ */
+@Internal
+public abstract class Optimizer<ParamType> {
+    /**
+     * Optimize the given loss function using the init model and the training data.
+     *
+     * @param bcInitModel The broadcasted init model. Note that each task contains one DenseVector

Review Comment:
   nit: broadcast

##########
flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java:
##########
@@ -34,10 +34,16 @@ public static double asum(DenseVector x) {
     /** y += a * x . */
     public static void axpy(double a, Vector x, DenseVector y) {
         Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+        axpy(a, x, y, x.size());
+    }
+
+    /** y += a * x for the first k dimensions, with the other dimensions unchanged. */
+    public static void axpy(double a, Vector x, DenseVector y, int k) {

Review Comment:
   In which cases do we need to do computation on only the first `k` dimensions?
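One plausible answer, not confirmed by this hunk alone: optimizers sometimes pack the model coefficient into a longer buffer that also carries bookkeeping slots such as an accumulated loss, and only the leading `k` model dimensions should receive the update. A minimal sketch, assuming the `k`-dimensional variant tolerates `x.size() < y.size()` as its Javadoc implies (the class name is made up for illustration):

```java
import java.util.Arrays;

import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;

public class AxpyFirstKSketch {
    public static void main(String[] args) {
        // A 2-dimensional coefficient packed into a 3-slot buffer; the last
        // slot carries an accumulated loss and must not be touched.
        DenseVector coefficientAndLoss = Vectors.dense(0.5, -1.0, 42.0);
        DenseVector gradient = Vectors.dense(0.1, 0.2);

        // coefficient -= 0.01 * gradient, applied to the first k = 2 slots only.
        BLAS.axpy(-0.01, gradient, coefficientAndLoss, 2);

        // Prints [0.499, -1.002, 42.0]; the loss slot is unchanged.
        System.out.println(Arrays.toString(coefficientAndLoss.values));
    }
}
```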
##########
flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java:
##########
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModelData;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link LinearRegression} and {@link LinearRegressionModel}. */
+public class LinearRegressionTest {
+
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+
+    private StreamTableEnvironment tEnv;
+
+    private static final List<Row> trainData =
+            Arrays.asList(
+                    Row.of(Vectors.dense(2, 1), 4.0, 1.0),
+                    Row.of(Vectors.dense(3, 2), 7.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 4), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 2), 6.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(1, 2), 5.0, 1.0),
+                    Row.of(Vectors.dense(5, 3), 11.0, 1.0));
+
+    private static final double[] expectedCoefficient = new double[] {1.0, 2.0};
+
+    private static final double TOLERANCE = 1e-7;
+
+    private Table trainDataTable;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Collections.shuffle(trainData);
+        trainDataTable =
+                tEnv.fromDataStream(
+                        env.fromCollection(
+                                trainData,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class),

Review Comment:
   nit: `DenseVectorTypeInfo.INSTANCE`
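For reference, the suggested change would read as below; this assumes `DenseVectorTypeInfo` lives at `org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo`:

```java
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;

public class TrainDataTypeInfoSketch {
    // Dedicated type info instead of TypeInformation.of(DenseVector.class),
    // avoiding the generic (Kryo) serialization fallback for DenseVector.
    static final RowTypeInfo TRAIN_DATA_TYPE_INFO =
            new RowTypeInfo(
                    new TypeInformation[] {
                        DenseVectorTypeInfo.INSTANCE, Types.DOUBLE, Types.DOUBLE
                    },
                    new String[] {"features", "label", "weight"});
}
```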
##########
flink-ml-lib/src/test/java/org/apache/flink/ml/regression/LinearRegressionTest.java:
##########
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.regression;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModel;
+import org.apache.flink.ml.regression.linearregression.LinearRegressionModelData;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link LinearRegression} and {@link LinearRegressionModel}. */
+public class LinearRegressionTest {
+
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+
+    private StreamTableEnvironment tEnv;
+
+    private static final List<Row> trainData =
+            Arrays.asList(
+                    Row.of(Vectors.dense(2, 1), 4.0, 1.0),
+                    Row.of(Vectors.dense(3, 2), 7.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 4), 10.0, 1.0),
+                    Row.of(Vectors.dense(2, 2), 6.0, 1.0),
+                    Row.of(Vectors.dense(4, 3), 10.0, 1.0),
+                    Row.of(Vectors.dense(1, 2), 5.0, 1.0),
+                    Row.of(Vectors.dense(5, 3), 11.0, 1.0));
+
+    private static final double[] expectedCoefficient = new double[] {1.0, 2.0};
+
+    private static final double TOLERANCE = 1e-7;
+
+    private Table trainDataTable;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Collections.shuffle(trainData);
+        trainDataTable =
+                tEnv.fromDataStream(
+                        env.fromCollection(
+                                trainData,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class),
+                                            Types.DOUBLE,
+                                            Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label", "weight"})));
+    }
+
+    @SuppressWarnings("unchecked")
+    private void verifyPredictionResult(
+            Table output, String labelCol, String weightCol, String predictionCol)
+            throws Exception {
+        List<Row> predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        double lossSum = 0;
+        for (Row predictionRow : predResult) {
+            double label = (double) predictionRow.getField(labelCol);
+            double prediction = (double) predictionRow.getField(predictionCol);
+            double weight = (double) predictionRow.getField(weightCol);
+            lossSum += weight * Math.pow(label - prediction, 2);
+        }
+        assertTrue(lossSum < 1.0);

Review Comment:
   I'm not sure whether this criterion can guarantee the correctness of the algorithm. Is it possible to provide an estimate of the expected prediction result and test that the actual prediction is close enough to the expected value, for example using `expectedCoefficient`?
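A sketch of what the reviewer asks for, meant to slot into `LinearRegressionTest` and reuse its fields; the method name and the `0.1` tolerance are placeholders, and it assumes a `BLAS.dot` is available:

```java
// Hypothetical companion to verifyPredictionResult: checks each prediction
// against the value implied by the known ground-truth coefficient rather
// than only bounding the summed squared loss.
@SuppressWarnings("unchecked")
private void verifyPredictionAgainstCoefficient(
        Table output, String featuresCol, String predictionCol) throws Exception {
    List<Row> predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
    DenseVector expected = new DenseVector(expectedCoefficient);
    for (Row row : predResult) {
        DenseVector features = (DenseVector) row.getField(featuresCol);
        double prediction = (double) row.getField(predictionCol);
        // 0.1 is a placeholder tolerance; an SGD fit will not reach TOLERANCE (1e-7).
        assertEquals(BLAS.dot(expected, features), prediction, 0.1);
    }
}
```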
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasElasticNetParam.java:
##########
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared elasticNet param, which specifies the mixing of L1 and L2 penalty:
+ * <li>If the value is zero, it is L2 penalty.
+ * <li>If the value is one, it is L1 penalty.
+ * <li>For value in (0,1), it is a combination of L1 and L2 penalty.

Review Comment:
   nit: adding `<ul></ul>` around the `<li>` tags could make the JavaDoc render better.
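Concretely, the suggestion would make the Javadoc read:

```java
/**
 * Interface for the shared elasticNet param, which specifies the mixing of L1 and L2 penalty:
 *
 * <ul>
 *   <li>If the value is zero, it is L2 penalty.
 *   <li>If the value is one, it is L1 penalty.
 *   <li>For value in (0,1), it is a combination of L1 and L2 penalty.
 * </ul>
 */
```

Without the wrapper, the standard doclet tends to run the `<li>` items together with the introductory sentence instead of rendering a proper bulleted list.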
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LossFunc.java:
##########
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.DenseVector;
+
+import java.io.Serializable;
+
+/**
+ * A loss function is to compute the loss and gradient with the given coefficient and training data.
+ */
+@Internal
+public interface LossFunc extends Serializable {

Review Comment:
   Would it be better to merge `LossFunc` and its implementations into the `optimizer` package? Would `LossFunc` be used in places other than `Optimizer` subclasses?
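For orientation, the quoted hunk ends before the interface's methods, so their shape is an assumption here; a toy implementation under a hypothetical `computeLoss`/`computeGradient` pair might look like:

```java
import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
import org.apache.flink.ml.linalg.BLAS;
import org.apache.flink.ml.linalg.DenseVector;

// Illustrative only: the method names, parameters and LabeledPointWithWeight
// getters are assumptions, since the hunk above cuts off at the interface body.
public class AbsoluteErrorLoss implements LossFunc {
    @Override
    public double computeLoss(LabeledPointWithWeight dataPoint, DenseVector coefficient) {
        double error = BLAS.dot(dataPoint.getFeatures(), coefficient) - dataPoint.getLabel();
        return dataPoint.getWeight() * Math.abs(error);
    }

    @Override
    public void computeGradient(
            LabeledPointWithWeight dataPoint, DenseVector coefficient, DenseVector cumGradient) {
        double error = BLAS.dot(dataPoint.getFeatures(), coefficient) - dataPoint.getLabel();
        // d|e|/dw = sign(e) * x, scaled by the sample weight.
        BLAS.axpy(dataPoint.getWeight() * Math.signum(error), dataPoint.getFeatures(), cumGradient);
    }
}
```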
##########
flink-ml-lib/src/test/java/org/apache/flink/ml/common/optimizer/RegularizationUtilsTest.java:
##########
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.optimizer;
+
+import org.apache.flink.ml.linalg.DenseVector;
+
+import org.apache.commons.lang3.RandomUtils;
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/** Tests {@link RegularizationUtils}. */
+public class RegularizationUtilsTest {
+
+    private static final double learningRate = 0.1;
+    private static final double TOLERANCE = 1e-7;
+    private static final DenseVector coefficient = new DenseVector(new double[] {1.0, -2.0, 0});
+
+    @Test
+    public void testReg0() {

Review Comment:
   The test cases in this class seem to differ only in their input values. Merging them into a single test case might improve readability.
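A rough shape for that consolidation; the `regularize` call and the expected values are placeholders, since neither the real method signature nor the per-case expectations appear in the quoted hunk:

```java
@Test
public void testRegularization() {
    // {reg, elasticNet} pairs covering L2-only, L1-only and mixed penalties.
    double[][] params = {{0.1, 0.0}, {0.1, 1.0}, {0.1, 0.5}};
    double[][] expected = new double[params.length][]; // fill from the original testReg* cases

    for (int i = 0; i < params.length; i++) {
        DenseVector coefficient = new DenseVector(new double[] {1.0, -2.0, 0});
        // Placeholder call shape -- adapt to RegularizationUtils' actual API:
        // RegularizationUtils.regularize(coefficient, params[i][0], params[i][1], learningRate);
        // assertArrayEquals(expected[i], coefficient.values, TOLERANCE);
    }
}
```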
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasElasticNetParam.java:
##########
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared elasticNet param, which specifies the mixing of L1 and L2 penalty:
+ * <li>If the value is zero, it is L2 penalty.
+ * <li>If the value is one, it is L1 penalty.
+ * <li>For value in (0,1), it is a combination of L1 and L2 penalty.
+ */
+public interface HasElasticNetParam<T> extends WithParams<T> {

Review Comment:
   nit: `HasElasticNet` seems better in accordance with other class names.

##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasElasticNetParam.java:
##########
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/**
+ * Interface for the shared elasticNet param, which specifies the mixing of L1 and L2 penalty:
+ * <li>If the value is zero, it is L2 penalty.
+ * <li>If the value is one, it is L1 penalty.
+ * <li>For value in (0,1), it is a combination of L1 and L2 penalty.
+ */
+public interface HasElasticNetParam<T> extends WithParams<T> {
+    Param<Double> ELASTICNET =

Review Comment:
   nit: `ELASTIC_NET` seems better in accordance with others.

##########
flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java:
##########
@@ -67,6 +72,28 @@ public static <IN, OUT> DataStream<OUT> mapPartition(
                 .setParallelism(input.getParallelism());
     }
 
+    /**
+     * Applies a {@link ReduceFunction} on a bounded data stream. The output stream contains at most
+     * one stream record and its parallelism is one.
+     *
+     * @param input The input data stream.
+     * @param func The user defined reduce function.
+     * @param <T> The class type of the input.
+     * @return The result data stream.
+     */
+    public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) {
+        DataStream<T> partialReducedStream =
+                input.transform("partialReduce", input.getType(), new ReduceOperator<>(func))
+                        .setParallelism(input.getParallelism());

Review Comment:
   It might be better to use `forward()` and `colocate()` here so that the reduce performs better.
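For what it's worth, vanilla `DataStream` does not seem to expose a `colocate()` method; one way to express the intent with stock APIs is `forward()` plus a shared co-location group key on the underlying transformations. A sketch, assuming the PR's `ReduceOperator` is accessible here:

```java
import java.util.UUID;

import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;

public class ColocatedReduceSketch {
    public static <T> DataStream<T> reduce(DataStream<T> input, ReduceFunction<T> func) {
        // A fresh key so that unrelated reduce() calls are not co-located together.
        String coLocationKey = "reduce-" + UUID.randomUUID();
        input.getTransformation().setCoLocationGroupKey(coLocationKey);

        // forward() keeps each record in its producing subtask, so the partial
        // reduce involves no network shuffle.
        SingleOutputStreamOperator<T> partiallyReduced =
                input.forward()
                        .transform("partialReduce", input.getType(), new ReduceOperator<>(func))
                        .setParallelism(input.getParallelism());
        partiallyReduced.getTransformation().setCoLocationGroupKey(coLocationKey);

        return partiallyReduced
                .transform("finalReduce", input.getType(), new ReduceOperator<>(func))
                .setParallelism(1);
    }
}
```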
##########
flink-ml-lib/src/main/java/org/apache/flink/ml/common/lossfunc/LeastSquareLoss.java:
##########
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.lossfunc;
+
+import org.apache.flink.annotation.Internal;
+import org.apache.flink.ml.common.feature.LabeledPointWithWeight;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.regression.linearregression.LinearRegression;
+
+/** The loss function for linear regression. See {@link LinearRegression} */
+@Internal
+public class LeastSquareLoss implements LossFunc {

Review Comment:
   It seems that this class can be used in any algorithm that uses the least square method. Maybe we can refactor its Javadoc to reflect this.


--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]