lindong28 commented on a change in pull request #28: URL: https://github.com/apache/flink-ml/pull/28#discussion_r766361133
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/linear/LogisticRegressionTest.java ########## @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; 
+import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** Tests {@link LogisticRegression} and {@link LogisticRegressionModel}. */ +public class LogisticRegressionTest { + + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private StreamExecutionEnvironment env; + + private StreamTableEnvironment tEnv; + + private static final List<Row> binomialTrainData = + Arrays.asList( + Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.), + Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.), + Row.of(Vectors.dense(3, 2, 3, 4), 0., 3.), + Row.of(Vectors.dense(4, 2, 3, 4), 0., 4.), + Row.of(Vectors.dense(5, 2, 3, 4), 0., 5.), + Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.), + Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.), + Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.), + Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.), + Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.)); + + private static final List<Row> multinomialTrainData = + Arrays.asList( + Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.), + Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.), + Row.of(Vectors.dense(3, 2, 3, 4), 2., 3.), + Row.of(Vectors.dense(4, 2, 3, 4), 2., 4.), + Row.of(Vectors.dense(5, 2, 3, 4), 2., 5.), + Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.), + Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.), + Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.), + Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.), + Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.)); + + private static final double[] expectedCoefficient = + new double[] {0.528, -0.286, -0.429, -0.572}; + + private static final double TOLERANCE = 1e-7; + + private Table binomialDataTable; + + private Table multinomialDataTable; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = 
StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + Collections.shuffle(binomialTrainData); + binomialDataTable = + tEnv.fromDataStream( + env.fromCollection( + binomialTrainData, + new RowTypeInfo( + new TypeInformation[] { + TypeInformation.of(DenseVector.class), + Types.DOUBLE, + Types.DOUBLE + }, + new String[] {"features", "label", "weight"}))); + multinomialDataTable = + tEnv.fromDataStream( + env.fromCollection( + multinomialTrainData, + new RowTypeInfo( + new TypeInformation[] { + TypeInformation.of(DenseVector.class), + Types.DOUBLE, + Types.DOUBLE + }, + new String[] {"features", "label", "weight"}))); + } + + @SuppressWarnings("ConstantConditions") + private void verifyPredictionResult( + Table output, String featuresCol, String predictionCol, String rawPredictionCol) + throws Exception { + List<Row> predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row predictionRow : predResult) { + DenseVector feature = (DenseVector) predictionRow.getField(featuresCol); + double prediction = (double) predictionRow.getField(predictionCol); + DenseVector rawPrediction = (DenseVector) predictionRow.getField(rawPredictionCol); + if (feature.get(0) <= 5) { + assertEquals(0, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) > 0.5); + } else { + assertEquals(1, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) < 0.5); + } + } + } + + @Test + public void testParam() { + LogisticRegression logisticRegression = new LogisticRegression(); + assertEquals(logisticRegression.getLabelCol(), "label"); + assertNull(logisticRegression.getWeightCol()); + assertEquals(logisticRegression.getMaxIter(), 20); + assertEquals(logisticRegression.getReg(), 0, TOLERANCE); + assertEquals(logisticRegression.getLearningRate(), 0.1, TOLERANCE); + 
assertEquals(logisticRegression.getGlobalBatchSize(), 32); + assertEquals(logisticRegression.getTol(), 1e-6, TOLERANCE); + assertEquals(logisticRegression.getMultiClass(), "auto"); + assertEquals(logisticRegression.getFeaturesCol(), "features"); + assertEquals(logisticRegression.getPredictionCol(), "prediction"); + assertEquals(logisticRegression.getRawPredictionCol(), "rawPrediction"); + + logisticRegression + .setFeaturesCol("test_features") + .setLabelCol("test_label") + .setWeightCol("test_weight") + .setMaxIter(1000) + .setTol(0.001) + .setLearningRate(0.5) + .setGlobalBatchSize(1000) + .setReg(0.1) + .setMultiClass("binomial") + .setPredictionCol("test_predictionCol") + .setRawPredictionCol("test_rawPredictionCol"); + assertEquals(logisticRegression.getFeaturesCol(), "test_features"); + assertEquals(logisticRegression.getLabelCol(), "test_label"); + assertEquals(logisticRegression.getWeightCol(), "test_weight"); + assertEquals(logisticRegression.getMaxIter(), 1000); + assertEquals(logisticRegression.getTol(), 0.001, TOLERANCE); + assertEquals(logisticRegression.getLearningRate(), 0.5, TOLERANCE); + assertEquals(logisticRegression.getGlobalBatchSize(), 1000); + assertEquals(logisticRegression.getReg(), 0.1, TOLERANCE); + assertEquals(logisticRegression.getMultiClass(), "binomial"); + assertEquals(logisticRegression.getPredictionCol(), "test_predictionCol"); + assertEquals(logisticRegression.getRawPredictionCol(), "test_rawPredictionCol"); + } + + @Test + public void testFeaturePredictionParam() { + Table tempTable = binomialDataTable.as("test_features", "test_label", "test_weight"); + LogisticRegression logisticRegression = + new LogisticRegression() + .setFeaturesCol("test_features") + .setLabelCol("test_label") + .setWeightCol("test_weight") + .setPredictionCol("test_predictionCol") + .setRawPredictionCol("test_rawPredictionCol"); + Table output = logisticRegression.fit(binomialDataTable).transform(tempTable)[0]; + assertEquals( + Arrays.asList( + 
"test_features", + "test_label", + "test_weight", + "test_predictionCol", + "test_rawPredictionCol"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + public void testFitAndPredict() throws Exception { + LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); + Table output = logisticRegression.fit(binomialDataTable).transform(binomialDataTable)[0]; + verifyPredictionResult( + output, + logisticRegression.getFeaturesCol(), + logisticRegression.getPredictionCol(), + logisticRegression.getRawPredictionCol()); + } + + @Test + public void testSaveLoadAndPredict() throws Exception { + String path = tempFolder.newFolder().getAbsolutePath(); + LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); + LogisticRegressionModel model = logisticRegression.fit(binomialDataTable); + model.save(path); + env.execute(); + LogisticRegressionModel loadedModel = LogisticRegressionModel.load(env, path); + assertEquals( + Collections.singletonList("f0"), + loadedModel.getModelData()[0].getResolvedSchema().getColumnNames()); + Table output = loadedModel.transform(binomialDataTable)[0]; + verifyPredictionResult( + output, + logisticRegression.getFeaturesCol(), + logisticRegression.getPredictionCol(), + logisticRegression.getRawPredictionCol()); + } + + @Test + public void testGetModelData() throws Exception { + LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); + LogisticRegressionModel model = logisticRegression.fit(binomialDataTable); + List<Row> collectedModelData = + IteratorUtils.toList( + tEnv.toDataStream(model.getModelData()[0]).executeAndCollect()); + LogisticRegressionModelData modelData = + (LogisticRegressionModelData) collectedModelData.get(0).getField(0); + assert modelData != null; Review comment: nits: Could we use `Assert.assertNotNull(...)` for consistency. with `Assert.assertEquals(...)` used below? 
########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java ########## @@ -26,9 +26,57 @@ private static final dev.ludovic.netlib.BLAS JAVA_BLAS = dev.ludovic.netlib.JavaBLAS.getInstance(); - /** y += a * x . */ + /** + * \sum_i |x_i| . + * + * @param x x + * @return \sum_i |x_i| + */ + public static double asum(DenseVector x) { + return JAVA_BLAS.dasum(x.size(), x.values, 0, 1); + } + + /** + * y += a * x . + * + * @param a a Review comment: This `@param a a` seems to be redundant as it does not provide any useful information. Could we remove these and just keep `y += a * x`? This could be similar to Spark's BLAS Java doc. Same for other methods' Java doc. ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java ########## @@ -118,14 +119,26 @@ public static void saveMetadata(Stage<?> stage, String path) throws IOException saveMetadata(stage, path, new HashMap<>()); } - /** Returns a subdirectory of the given path for saving/loading model data. */ - public static String getDataPath(String path) { - return Paths.get(path, "data").toString(); + /** + * Returns a subdirectory of the given path for saving/loading table `tableIndex`. + * + * @param path The parent directory to save the table. + * @param tableIndex The index of the table to save. + * @return A subdirectory of the given path for saving and loading table `tableIndex`. + */ + public static String getDataPath(String path, int tableIndex) { Review comment: It's better to make this method `private` if it is only used in this class. This is similar to `private static String getPathForPipelineStage(...)`. ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java ########## @@ -329,33 +342,42 @@ public static void updateExistingParams(Stage<?> stage, Map<Param<?>, Object> pa * @param model The model data stream. * @param path The parent directory of the model data file. 
* @param modelEncoder The encoder to encode the model data. + * @param modelIndex The index of the table to save. * @param <T> The class type of the model data. */ public static <T> void saveModelData( - DataStream<T> model, String path, Encoder<T> modelEncoder) { + DataStream<T> model, String path, Encoder<T> modelEncoder, int modelIndex) { FileSink<T> sink = FileSink.forRowFormat( - new org.apache.flink.core.fs.Path(getDataPath(path)), modelEncoder) + new org.apache.flink.core.fs.Path(getDataPath(path, modelIndex)), + modelEncoder) .withRollingPolicy(OnCheckpointRollingPolicy.build()) .withBucketAssigner(new BasePathBucketAssigner<>()) .build(); model.sinkTo(sink); } /** - * Loads the model data from the given path using the model decoder. + * Loads the model table with index `modelIndex` from the given path using the model decoder. * * @param env A StreamExecutionEnvironment instance. * @param path The parent directory of the model data file. * @param modelDecoder The decoder used to decode the model data. + * @param modelIndex The index of the table to load. * @param <T> The class type of the model data. * @return The loaded model data. */ - public static <T> DataStream<T> loadModelData( - StreamExecutionEnvironment env, String path, SimpleStreamFormat<T> modelDecoder) { + public static <T> Table loadModelData( Review comment: Why do we need `modelIndex`? It seems that modelIndex is always 0 as of now. Since `saveModelData(...)` takes a `DataStream`, could we let `loadModelData(...)` return a DataStream so that they are symmetric? ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/iteration/TerminateOnMaxIterOrTol.java ########## @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.iteration; + +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.iteration.IterationListener; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +/** + * A FlatMapFunction that emits values iff the iteration's epochWatermark does not exceed a certain + * threshold and the loss exceeds a certain tolerance. + * + * <p>When the output of this FlatMapFunction is used as the termination criteria of an iteration + * body, the iteration will be executed for at most the given `maxIter` iterations. And the + * iteration will terminate once any input value is smaller than or equal to the given `tol`. 
+ */ +public class TerminateOnMaxIterOrTol + implements IterationListener<Integer>, FlatMapFunction<Double, Integer> { + + private final int maxIter; + + private final double tol; + + private double loss = Double.MAX_VALUE; + + public TerminateOnMaxIterOrTol(Integer maxIter, Double tol) { + this.maxIter = maxIter; + this.tol = tol; + } + + public TerminateOnMaxIterOrTol(Double tol) { + this.maxIter = Integer.MAX_VALUE; + this.tol = tol; + } + + @Override + public void flatMap(Double value, Collector<Integer> out) { + Preconditions.checkArgument( + Double.compare(this.loss, Double.MAX_VALUE) == 0, + "Each epoch should contain only one loss value."); + this.loss = value; + } + + @Override + public void onEpochWatermarkIncremented( + int epochWatermark, Context context, Collector<Integer> collector) { + if ((epochWatermark + 1) < maxIter && this.loss > tol) { Review comment: nits: can we replace `this.loss` with `loss` for better code consistency? In general we only use `this.variable` if the class member variable name collides with the function parameter name. Same for other usages of `this.xxx`. ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/linear/LogisticRegressionTest.java ########## @@ -0,0 +1,280 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.linear; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** Tests {@link LogisticRegression} and {@link LogisticRegressionModel}. 
*/ +public class LogisticRegressionTest { + + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + + private StreamExecutionEnvironment env; + + private StreamTableEnvironment tEnv; + + private static final List<Row> binomialTrainData = + Arrays.asList( + Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.), + Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.), + Row.of(Vectors.dense(3, 2, 3, 4), 0., 3.), + Row.of(Vectors.dense(4, 2, 3, 4), 0., 4.), + Row.of(Vectors.dense(5, 2, 3, 4), 0., 5.), + Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.), + Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.), + Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.), + Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.), + Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.)); + + private static final List<Row> multinomialTrainData = + Arrays.asList( + Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.), + Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.), + Row.of(Vectors.dense(3, 2, 3, 4), 2., 3.), + Row.of(Vectors.dense(4, 2, 3, 4), 2., 4.), + Row.of(Vectors.dense(5, 2, 3, 4), 2., 5.), + Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.), + Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.), + Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.), + Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.), + Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.)); + + private static final double[] expectedCoefficient = + new double[] {0.528, -0.286, -0.429, -0.572}; + + private static final double TOLERANCE = 1e-7; + + private Table binomialDataTable; + + private Table multinomialDataTable; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + Collections.shuffle(binomialTrainData); + binomialDataTable = + tEnv.fromDataStream( + env.fromCollection( 
+ binomialTrainData, + new RowTypeInfo( + new TypeInformation[] { + TypeInformation.of(DenseVector.class), + Types.DOUBLE, + Types.DOUBLE + }, + new String[] {"features", "label", "weight"}))); + multinomialDataTable = + tEnv.fromDataStream( + env.fromCollection( + multinomialTrainData, + new RowTypeInfo( + new TypeInformation[] { + TypeInformation.of(DenseVector.class), + Types.DOUBLE, + Types.DOUBLE + }, + new String[] {"features", "label", "weight"}))); + } + + @SuppressWarnings("ConstantConditions") + private void verifyPredictionResult( + Table output, String featuresCol, String predictionCol, String rawPredictionCol) + throws Exception { + List<Row> predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row predictionRow : predResult) { + DenseVector feature = (DenseVector) predictionRow.getField(featuresCol); + double prediction = (double) predictionRow.getField(predictionCol); + DenseVector rawPrediction = (DenseVector) predictionRow.getField(rawPredictionCol); + if (feature.get(0) <= 5) { + assertEquals(0, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) > 0.5); + } else { + assertEquals(1, prediction, TOLERANCE); + assertTrue(rawPrediction.get(0) < 0.5); + } + } + } + + @Test + public void testParam() { + LogisticRegression logisticRegression = new LogisticRegression(); + assertEquals(logisticRegression.getLabelCol(), "label"); + assertNull(logisticRegression.getWeightCol()); + assertEquals(logisticRegression.getMaxIter(), 20); + assertEquals(logisticRegression.getReg(), 0, TOLERANCE); + assertEquals(logisticRegression.getLearningRate(), 0.1, TOLERANCE); + assertEquals(logisticRegression.getGlobalBatchSize(), 32); + assertEquals(logisticRegression.getTol(), 1e-6, TOLERANCE); + assertEquals(logisticRegression.getMultiClass(), "auto"); + assertEquals(logisticRegression.getFeaturesCol(), "features"); + assertEquals(logisticRegression.getPredictionCol(), "prediction"); + 
assertEquals(logisticRegression.getRawPredictionCol(), "rawPrediction"); + + logisticRegression + .setFeaturesCol("test_features") + .setLabelCol("test_label") + .setWeightCol("test_weight") + .setMaxIter(1000) + .setTol(0.001) + .setLearningRate(0.5) + .setGlobalBatchSize(1000) + .setReg(0.1) + .setMultiClass("binomial") + .setPredictionCol("test_predictionCol") + .setRawPredictionCol("test_rawPredictionCol"); + assertEquals(logisticRegression.getFeaturesCol(), "test_features"); + assertEquals(logisticRegression.getLabelCol(), "test_label"); + assertEquals(logisticRegression.getWeightCol(), "test_weight"); + assertEquals(logisticRegression.getMaxIter(), 1000); + assertEquals(logisticRegression.getTol(), 0.001, TOLERANCE); + assertEquals(logisticRegression.getLearningRate(), 0.5, TOLERANCE); + assertEquals(logisticRegression.getGlobalBatchSize(), 1000); + assertEquals(logisticRegression.getReg(), 0.1, TOLERANCE); + assertEquals(logisticRegression.getMultiClass(), "binomial"); + assertEquals(logisticRegression.getPredictionCol(), "test_predictionCol"); + assertEquals(logisticRegression.getRawPredictionCol(), "test_rawPredictionCol"); + } + + @Test + public void testFeaturePredictionParam() { + Table tempTable = binomialDataTable.as("test_features", "test_label", "test_weight"); + LogisticRegression logisticRegression = + new LogisticRegression() + .setFeaturesCol("test_features") + .setLabelCol("test_label") + .setWeightCol("test_weight") + .setPredictionCol("test_predictionCol") + .setRawPredictionCol("test_rawPredictionCol"); + Table output = logisticRegression.fit(binomialDataTable).transform(tempTable)[0]; + assertEquals( + Arrays.asList( + "test_features", + "test_label", + "test_weight", + "test_predictionCol", + "test_rawPredictionCol"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + public void testFitAndPredict() throws Exception { + LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); + Table output = 
logisticRegression.fit(binomialDataTable).transform(binomialDataTable)[0]; + verifyPredictionResult( + output, + logisticRegression.getFeaturesCol(), + logisticRegression.getPredictionCol(), + logisticRegression.getRawPredictionCol()); + } + + @Test + public void testSaveLoadAndPredict() throws Exception { + String path = tempFolder.newFolder().getAbsolutePath(); + LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight"); + LogisticRegressionModel model = logisticRegression.fit(binomialDataTable); + model.save(path); + env.execute(); + LogisticRegressionModel loadedModel = LogisticRegressionModel.load(env, path); Review comment: Could we update the code to use `StageTestUtils.saveAndReload(...)`? This util method is provided in the NaiveBayes PR. And could we also test the save/load of the `logisticRegression` in this test, similar to `NaiveBayesTest::testSaveLoad`? ########## File path: flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java ########## @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.common.datastream; + +import org.apache.flink.api.common.functions.MapPartitionFunction; +import org.apache.flink.api.common.state.ListState; +import org.apache.flink.api.common.state.ListStateDescriptor; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.java.typeutils.TypeExtractor; +import org.apache.flink.runtime.state.StateInitializationContext; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.BoundedOneInput; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.api.operators.TimestampedCollector; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; + +/** Provides utility functions for {@link DataStream}. */ +public class DataStreamUtils { + /** + * Applies allReduceSum on the input data stream. The input data stream is supposed to contain + * one double array in each partition. The result data stream has the same parallelism as the + * input, where each partition contains one double array that sums all of the double arrays in + * the input data stream. + * + * <p>Note that we throw exception when one of the following two cases happen: + * <li>There exists one partition that contains more than one double array. + * <li>The length of the double array is not consistent among all partitions. + * + * @param input The input data stream. + * @return The result data stream. + */ + public static DataStream<double[]> allReduceSum(DataStream<double[]> input) { + return AllReduceImpl.allReduceSum(input); + } + + /** + * Applies a {@link MapPartitionFunction} on a bounded data stream. + * + * @param input The input data stream. + * @param func The user defined mapPartition function. + * @param <IN> The class type of the input element. + * @param <OUT> The class type of output element. 
+ * @return The result data stream. + */ + public static <IN, OUT> DataStream<OUT> mapPartition( + DataStream<IN> input, MapPartitionFunction<IN, OUT> func) { + TypeInformation<OUT> resultType = + TypeExtractor.getMapPartitionReturnTypes(func, input.getType(), null, true); + return input.transform("mapPartition", resultType, new MapPartitionOperator<>(func)) + .setParallelism(input.getParallelism()); + } + + /** + * Applies a {@link MapPartitionFunction} on a bounded data stream. + * + * @param input The input data stream. + * @param func The user defined mapPartition function. + * @param outputType The type information of the output element. + * @param <IN> The class type of the input element. + * @param <OUT> The class type of output element. + * @return The result data stream. + */ + public static <IN, OUT> DataStream<OUT> mapPartition( Review comment: Why do we need this method? Could we just use the `mapPartition(...)` defined above which does not require the user to explicitly provide the `outputType`? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
