lindong28 commented on a change in pull request #28:
URL: https://github.com/apache/flink-ml/pull/28#discussion_r767506178
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModelData.java
##########
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link LogisticRegressionModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to Datastream, and classes
+ * to save/load model data.
+ */
+public class LogisticRegressionModelData {
+
+    public final DenseVector coefficient;
+
+    public LogisticRegressionModelData(DenseVector coefficient) {
+        this.coefficient = coefficient;
+    }
+
+    /**
+     * Converts the table model to a data stream.
+     *
+     * @param modelData The table model data.
+     * @return The data stream model data.
+     */
+    public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {

Review comment:
   @gaoyunhaii is it OK to use `LogisticRegressionModelData` as the DataStream element type when `LogisticRegressionModelData` is not serializable? If we use `LogisticRegressionModelData` as the DataStream element type, would Flink automatically use `DenseVectorSerializer` to serialize/de-serialize `coefficient`?
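   For context, DataStream element types generally do not need to implement `java.io.Serializable`; Flink serializes records with a `TypeSerializer` derived from the element's `TypeInformation`. Below is a minimal sketch, not the PR's implementation, of making that element type explicit in `getModelDataStream`. It assumes the model table stores a single `coefficient` column holding a `DenseVector`, and that `DenseVector` resolves to a `TypeInformation` backed by `DenseVectorSerializer` (e.g. via a `@TypeInfo` factory); neither assumption is confirmed by the diff above. If `DenseVector` did not expose its own type information, type extraction would likely fall back to generic Kryo serialization for that field.

```java
// Sketch only (not the PR's implementation). Assumes a "coefficient" column and a
// DenseVector TypeInformation wired to DenseVectorSerializer; both are assumptions.
public static DataStream<LogisticRegressionModelData> getModelDataStream(Table modelData) {
    StreamTableEnvironment tEnv =
            (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
    return tEnv.toDataStream(modelData)
            .map(row -> new LogisticRegressionModelData((DenseVector) row.getField("coefficient")))
            // Declaring the element type explicitly avoids relying on lambda type
            // extraction to figure out LogisticRegressionModelData on its own.
            .returns(TypeInformation.of(LogisticRegressionModelData.class));
}
```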
##########
File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionModel.java
##########
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/** A Model which classifies data using the model data computed by {@link LogisticRegression}. */
+public class LogisticRegressionModel
+        implements Model<LogisticRegressionModel>,
+                LogisticRegressionModelParams<LogisticRegressionModel> {
+
+    private Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    private Table modelDataTable;
+
+    public LogisticRegressionModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                LogisticRegressionModelData.getModelDataStream(modelDataTable),
+                path,
+                new LogisticRegressionModelData.ModelDataEncoder());
+    }
+
+    public static LogisticRegressionModel load(StreamExecutionEnvironment env, String path)
+            throws IOException {
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+        LogisticRegressionModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<LogisticRegressionModelData> modelData =
+                ReadWriteUtils.loadModelData(
+                        env, path, new LogisticRegressionModelData.ModelDataDecoder());
+        return model.setModelData(tEnv.fromDataStream(modelData));
+    }
+
+    @Override
+    public LogisticRegressionModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> inputStream = tEnv.toDataStream(inputs[0]);
+        final String broadcastModelKey = "broadcastModelKey";
+        DataStream<LogisticRegressionModelData> modelDataStream =
+                LogisticRegressionModelData.getModelDataStream(modelDataTable);
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                BasicTypeInfo.DOUBLE_TYPE_INFO,
+                                TypeInformation.of(DenseVector.class)),
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldNames(),
+                                getPredictionCol(),
+                                getRawPredictionCol()));
+        DataStream<Row> predictionResult =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(inputStream),
+                        Collections.singletonMap(broadcastModelKey, modelDataStream),
+                        inputList -> {
+                            DataStream inputData = inputList.get(0);
+                            return inputData.map(
+                                    new PredictLabelFunction(broadcastModelKey, getFeaturesCol()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(predictionResult)};
+    }
+
+    /** A utility function used for prediction. */
+    private static class PredictLabelFunction extends RichMapFunction<Row, Row> {
+
+        private final String broadcastModelKey;
+
+        private final String featuresCol;
+
+        private DenseVector coefficient;
+
+        public PredictLabelFunction(String broadcastModelKey, String featuresCol) {
+            this.broadcastModelKey = broadcastModelKey;
+            this.featuresCol = featuresCol;
+        }
+
+        @Override
+        public Row map(Row dataPoint) {
+            if (coefficient == null) {
+                LogisticRegressionModelData modelData =
+                        (LogisticRegressionModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
+                coefficient = modelData.coefficient;
+            }
+            DenseVector features = (DenseVector) dataPoint.getField(featuresCol);
+            Tuple2<Double, DenseVector> predictionResult = predictRaw(features, coefficient);
+            return Row.join(dataPoint, Row.of(predictionResult.f0), Row.of(predictionResult.f1));

Review comment:
   nits: it seems simpler to use `Row.join(dataPoint, Row.of(predictionResult.f0, predictionResult.f1))`?
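   A tiny illustration of the suggested simplification (a sketch, not the PR's code): both calls below build the same output row, so the single `Row.of(...)` just avoids one intermediate `Row`.

```java
// Sample values for illustration only.
Row dataPoint = Row.of(Vectors.dense(1, 2, 3, 4), 1.0);
Tuple2<Double, DenseVector> predictionResult = Tuple2.of(1.0, Vectors.dense(0.1, 0.9));

Row current = Row.join(dataPoint, Row.of(predictionResult.f0), Row.of(predictionResult.f1));
Row suggested = Row.join(dataPoint, Row.of(predictionResult.f0, predictionResult.f1));
// current.equals(suggested) is expected to hold: Row.join concatenates fields by position.
```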
##########
File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/logisticregression/LogisticRegressionTest.java
##########
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.logisticregression;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/** Tests {@link LogisticRegression} and {@link LogisticRegressionModel}. */
+public class LogisticRegressionTest {
+
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+
+    private StreamTableEnvironment tEnv;
+
+    private static final List<Row> binomialTrainData =
+            Arrays.asList(
+                    Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
+                    Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.),
+                    Row.of(Vectors.dense(3, 2, 3, 4), 0., 3.),
+                    Row.of(Vectors.dense(4, 2, 3, 4), 0., 4.),
+                    Row.of(Vectors.dense(5, 2, 3, 4), 0., 5.),
+                    Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.),
+                    Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.),
+                    Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.),
+                    Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.),
+                    Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));
+
+    private static final List<Row> multinomialTrainData =
+            Arrays.asList(
+                    Row.of(Vectors.dense(1, 2, 3, 4), 0., 1.),
+                    Row.of(Vectors.dense(2, 2, 3, 4), 0., 2.),
+                    Row.of(Vectors.dense(3, 2, 3, 4), 2., 3.),
+                    Row.of(Vectors.dense(4, 2, 3, 4), 2., 4.),
+                    Row.of(Vectors.dense(5, 2, 3, 4), 2., 5.),
+                    Row.of(Vectors.dense(11, 2, 3, 4), 1., 1.),
+                    Row.of(Vectors.dense(12, 2, 3, 4), 1., 2.),
+                    Row.of(Vectors.dense(13, 2, 3, 4), 1., 3.),
+                    Row.of(Vectors.dense(14, 2, 3, 4), 1., 4.),
+                    Row.of(Vectors.dense(15, 2, 3, 4), 1., 5.));
+
+    private static final double[] expectedCoefficient =
+            new double[] {0.528, -0.286, -0.429, -0.572};
+
+    private static final double TOLERANCE = 1e-7;
+
+    private Table binomialDataTable;
+
+    private Table multinomialDataTable;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        Collections.shuffle(binomialTrainData);
+        binomialDataTable =
+                tEnv.fromDataStream(
+                        env.fromCollection(
+                                binomialTrainData,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class),
+                                            Types.DOUBLE,
+                                            Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label", "weight"})));
+        multinomialDataTable =
+                tEnv.fromDataStream(
+                        env.fromCollection(
+                                multinomialTrainData,
+                                new RowTypeInfo(
+                                        new TypeInformation[] {
+                                            TypeInformation.of(DenseVector.class),
+                                            Types.DOUBLE,
+                                            Types.DOUBLE
+                                        },
+                                        new String[] {"features", "label", "weight"})));
+    }
+
+    @SuppressWarnings("ConstantConditions")
+    private void verifyPredictionResult(
+            Table output, String featuresCol, String predictionCol, String rawPredictionCol)
+            throws Exception {
+        List<Row> predResult = IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row predictionRow : predResult) {
+            DenseVector feature = (DenseVector) predictionRow.getField(featuresCol);
+            double prediction = (double) predictionRow.getField(predictionCol);
+            DenseVector rawPrediction = (DenseVector) predictionRow.getField(rawPredictionCol);
+            if (feature.get(0) <= 5) {
+                assertEquals(0, prediction, TOLERANCE);
+                assertTrue(rawPrediction.get(0) > 0.5);
+            } else {
+                assertEquals(1, prediction, TOLERANCE);
+                assertTrue(rawPrediction.get(0) < 0.5);
+            }
+        }
+    }
+
+    @Test
+    public void testParam() {
+        LogisticRegression logisticRegression = new LogisticRegression();
+        assertEquals(logisticRegression.getLabelCol(), "label");
+        assertNull(logisticRegression.getWeightCol());
+        assertEquals(logisticRegression.getMaxIter(), 20);
+        assertEquals(logisticRegression.getReg(), 0, TOLERANCE);
+        assertEquals(logisticRegression.getLearningRate(), 0.1, TOLERANCE);
+        assertEquals(logisticRegression.getGlobalBatchSize(), 32);
+        assertEquals(logisticRegression.getTol(), 1e-6, TOLERANCE);
+        assertEquals(logisticRegression.getMultiClass(), "auto");
+        assertEquals(logisticRegression.getFeaturesCol(), "features");
+        assertEquals(logisticRegression.getPredictionCol(), "prediction");
+        assertEquals(logisticRegression.getRawPredictionCol(), "rawPrediction");
+
+        logisticRegression
+                .setFeaturesCol("test_features")
+                .setLabelCol("test_label")
+                .setWeightCol("test_weight")
+                .setMaxIter(1000)
+                .setTol(0.001)
+                .setLearningRate(0.5)
+                .setGlobalBatchSize(1000)
+                .setReg(0.1)
+                .setMultiClass("binomial")
+                .setPredictionCol("test_predictionCol")
+                .setRawPredictionCol("test_rawPredictionCol");
+        assertEquals(logisticRegression.getFeaturesCol(), "test_features");
+        assertEquals(logisticRegression.getLabelCol(), "test_label");
+        assertEquals(logisticRegression.getWeightCol(), "test_weight");
+        assertEquals(logisticRegression.getMaxIter(), 1000);
+        assertEquals(logisticRegression.getTol(), 0.001, TOLERANCE);
+        assertEquals(logisticRegression.getLearningRate(), 0.5, TOLERANCE);
+        assertEquals(logisticRegression.getGlobalBatchSize(), 1000);
+        assertEquals(logisticRegression.getReg(), 0.1, TOLERANCE);
+        assertEquals(logisticRegression.getMultiClass(), "binomial");
+        assertEquals(logisticRegression.getPredictionCol(), "test_predictionCol");
+        assertEquals(logisticRegression.getRawPredictionCol(), "test_rawPredictionCol");
+    }
+
+    @Test
+    public void testFeaturePredictionParam() {
+        Table tempTable = binomialDataTable.as("test_features", "test_label", "test_weight");
+        LogisticRegression logisticRegression =
+                new LogisticRegression()
+                        .setFeaturesCol("test_features")
+                        .setLabelCol("test_label")
+                        .setWeightCol("test_weight")
+                        .setPredictionCol("test_predictionCol")
+                        .setRawPredictionCol("test_rawPredictionCol");
+        Table output = logisticRegression.fit(binomialDataTable).transform(tempTable)[0];
+        assertEquals(
+                Arrays.asList(
+                        "test_features",
+                        "test_label",
+                        "test_weight",
+                        "test_predictionCol",
+                        "test_rawPredictionCol"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
+        Table output = logisticRegression.fit(binomialDataTable).transform(binomialDataTable)[0];
+        verifyPredictionResult(
+                output,
+                logisticRegression.getFeaturesCol(),
+                logisticRegression.getPredictionCol(),
+                logisticRegression.getRawPredictionCol());
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");

Review comment:
   Could we also save/load the estimator in this test, similar to `NaiveBayesTest::testSaveLoad`?
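   For reference, a minimal sketch of what that could look like. It assumes `LogisticRegression` exposes `save(String)` and a static `load(StreamExecutionEnvironment, String)` analogous to `LogisticRegressionModel.load` in this PR, and that an `env.execute()` is needed to materialize the saved model data; names such as `estimatorPath` and `loadedEstimator` are illustrative only.

```java
@Test
public void testSaveLoadAndPredict() throws Exception {
    // Round-trip the estimator through save/load before fitting (the part the review
    // asks for), then round-trip the fitted model as the test already does.
    String estimatorPath = tempFolder.newFolder().getAbsolutePath();
    LogisticRegression logisticRegression = new LogisticRegression().setWeightCol("weight");
    logisticRegression.save(estimatorPath);
    LogisticRegression loadedEstimator = LogisticRegression.load(env, estimatorPath);

    LogisticRegressionModel model = loadedEstimator.fit(binomialDataTable);
    String modelPath = tempFolder.newFolder().getAbsolutePath();
    model.save(modelPath);
    env.execute(); // assumed to be required to write out the model data
    LogisticRegressionModel loadedModel = LogisticRegressionModel.load(env, modelPath);

    Table output = loadedModel.transform(binomialDataTable)[0];
    verifyPredictionResult(
            output,
            loadedEstimator.getFeaturesCol(),
            loadedEstimator.getPredictionCol(),
            loadedEstimator.getRawPredictionCol());
}
```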
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]