lindong28 commented on a change in pull request #32: URL: https://github.com/apache/flink-ml/pull/32#discussion_r750878843
########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasFeatureSize.java ########## @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** Interface for the shared featureLength param. */ +public interface HasFeatureSize<T> extends WithParams<T> { Review comment: Can we remove this param and derive this information from the input data? ########## File path: flink-ml-api/pom.xml ########## @@ -65,5 +65,19 @@ under the License. <artifactId>flink-shaded-jackson</artifactId> <scope>provided</scope> </dependency> + + <dependency> + <groupId>com.github.fommil.netlib</groupId> Review comment: Do we still need `com.github.fommil.netlib`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelParams.java ########## @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.naivebayes; + +import org.apache.flink.ml.common.param.HasFeatureSize; +import org.apache.flink.ml.common.param.HasFeaturesCol; +import org.apache.flink.ml.common.param.HasPredictionCol; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; + +/** Parameters of naive bayes training process. */ Review comment: Maybe avoid introducing concepts such as `process` here and follow the same comment pattern as the existing algorithm? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java ########## @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.classification.naivebayes.NaiveBayes; +import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.table.api.Expressions.$; + +public class NaiveBayesTest { + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Schema schema; + private Row[] trainData; + private Row[] predictData; + private Row[] expectedOutput; + private String featuresCol; + private int featureSize; + private String labelCol; + private String predictCol; + private String modelType; + private 
double smoothing; + private boolean isSaveLoad; + private String errorMessage; + + @Before + public void Setup() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + schema = + Schema.newBuilder() + .column("f0", DataTypes.DOUBLE()) + .column("f1", DataTypes.of(DenseVector.class)) + .column("f2", DataTypes.DOUBLE()) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build(); + + trainData = + new Row[] { + Row.of(1., Vectors.dense(1, 1., 1., 1., 2.), 11.0), + Row.of(1., Vectors.dense(1, 1., 0., 1., 2.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0) + }; + + predictData = trainData; + + expectedOutput = + new Row[] { + Row.of(1., Vectors.dense(1, 1., 1., 1., 2.), 11.0, 11.0), + Row.of(1., Vectors.dense(1, 1., 0., 1., 2.), 11.0, 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0, 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0, 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0, 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0, 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0, 10.0) + }; + + featuresCol = "features"; + featureSize = 5; + labelCol = "label"; + predictCol = "predict"; + modelType = "categorical"; + smoothing = 1.0; + isSaveLoad = false; + } + + @Test + public void testNaiveBayes() throws Exception { + errorMessage = "normal test for Naive Bayes"; + runAndCheck(); + } + + @Test(expected = Exception.class) + public 
void testInputNotContainFeature() throws Exception { + errorMessage = + "Naive Bayes should throw exception if some feature columns are missing from train data"; + featuresCol = "feature2"; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testPredictUnseenFeature() throws Exception { + errorMessage = + "Naive Bayes should throw exception if unseen feature values are met in prediction " + + "and the model type is categorical."; + predictData = + new Row[] { + Row.of(1., Vectors.dense(5, 1., 1., 1., 2.), 11.0), + Row.of(1., Vectors.dense(5, 1., 0., 1., 2.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0) + }; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testEmptyLabel() throws Exception { Review comment: It seems a bit verbose to have a dedicated test for each param. We probably don't want to do this for every parameter of every Estimator/Transformer. And the purpose of this test seems to have already been covered by `StageTest::testValidators`. We can add unit tests there if needed. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLabelCol.java ########## @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.param; + +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Param of the name of the label column in the input table. + * + * @param <T> + */ +public interface HasLabelCol<T> extends WithParams<T> { + Param<String> LABEL_COL = + new StringParam( + "labelCol", + "Name of the label column in the input table.", + null, Review comment: How about changing the doc to `Label column name.` to stay consistent with `HasFeaturesCol`? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelParams.java ########## @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.naivebayes; + +import org.apache.flink.ml.common.param.HasFeatureSize; +import org.apache.flink.ml.common.param.HasFeaturesCol; +import org.apache.flink.ml.common.param.HasPredictionCol; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringParam; + +/** Parameters of naive bayes training process. */ +public interface NaiveBayesModelParams<T> + extends HasFeaturesCol<T>, HasFeatureSize<T>, HasPredictionCol<T> { + Param<String> MODEL_TYPE = + new StringParam( + "modelType", + "Type of the Naive Bayes model.", Review comment: It may be useful to specify the supported options here, as is done in `HasDistanceMeasure` and in Spark. For example, `The model type. Supported options: 'categorical'.` ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java ########## @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.classification.naivebayes.NaiveBayes; +import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.table.api.Expressions.$; + +public class NaiveBayesTest { + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Schema schema; + private Row[] trainData; + private Row[] predictData; + private Row[] expectedOutput; + private String featuresCol; + private int featureSize; + private String labelCol; + private String predictCol; + private String modelType; + private double smoothing; + private boolean isSaveLoad; + private String errorMessage; + + @Before + public void Setup() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + 
env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + schema = + Schema.newBuilder() + .column("f0", DataTypes.DOUBLE()) + .column("f1", DataTypes.of(DenseVector.class)) + .column("f2", DataTypes.DOUBLE()) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build(); + + trainData = + new Row[] { + Row.of(1., Vectors.dense(1, 1., 1., 1., 2.), 11.0), + Row.of(1., Vectors.dense(1, 1., 0., 1., 2.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0) + }; + + predictData = trainData; + + expectedOutput = + new Row[] { + Row.of(1., Vectors.dense(1, 1., 1., 1., 2.), 11.0, 11.0), + Row.of(1., Vectors.dense(1, 1., 0., 1., 2.), 11.0, 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0, 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0, 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0, 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0, 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0, 10.0) + }; + + featuresCol = "features"; + featureSize = 5; + labelCol = "label"; + predictCol = "predict"; + modelType = "categorical"; + smoothing = 1.0; + isSaveLoad = false; + } + + @Test + public void testNaiveBayes() throws Exception { + errorMessage = "normal test for Naive Bayes"; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testInputNotContainFeature() throws Exception { + errorMessage = + "Naive Bayes should throw exception if some feature columns are missing from train data"; + featuresCol = "feature2"; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testPredictUnseenFeature() throws Exception { + errorMessage = + "Naive Bayes should throw exception if unseen 
feature values are met in prediction " + + "and the model type is categorical."; + predictData = + new Row[] { + Row.of(1., Vectors.dense(5, 1., 1., 1., 2.), 11.0), + Row.of(1., Vectors.dense(5, 1., 0., 1., 2.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0) + }; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testEmptyLabel() throws Exception { Review comment: It seems a bit verbose to have a dedicated test for each param. We probably don't want to do this for every parameter of every Estimator/Transformer. And the purpose of this test seems to have already been covered by `StageTest::testValidators`. We can add unit tests there if needed. ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java ########## @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.classification.naivebayes.NaiveBayes; +import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.table.api.Expressions.$; + +public class NaiveBayesTest { + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Schema schema; + private Row[] trainData; + private Row[] predictData; + private Row[] expectedOutput; + private String featuresCol; + private int featureSize; + private String labelCol; + private String predictCol; + private String modelType; + private double smoothing; + private boolean isSaveLoad; + private String errorMessage; + + @Before + public void Setup() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + 
env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + schema = + Schema.newBuilder() + .column("f0", DataTypes.DOUBLE()) + .column("f1", DataTypes.of(DenseVector.class)) + .column("f2", DataTypes.DOUBLE()) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build(); + + trainData = + new Row[] { + Row.of(1., Vectors.dense(1, 1., 1., 1., 2.), 11.0), + Row.of(1., Vectors.dense(1, 1., 0., 1., 2.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0) + }; + + predictData = trainData; + + expectedOutput = + new Row[] { + Row.of(1., Vectors.dense(1, 1., 1., 1., 2.), 11.0, 11.0), + Row.of(1., Vectors.dense(1, 1., 0., 1., 2.), 11.0, 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0, 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0, 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0, 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0, 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0, 10.0) + }; + + featuresCol = "features"; + featureSize = 5; + labelCol = "label"; + predictCol = "predict"; + modelType = "categorical"; + smoothing = 1.0; + isSaveLoad = false; + } + + @Test + public void testNaiveBayes() throws Exception { + errorMessage = "normal test for Naive Bayes"; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testInputNotContainFeature() throws Exception { + errorMessage = + "Naive Bayes should throw exception if some feature columns are missing from train data"; + featuresCol = "feature2"; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testPredictUnseenFeature() throws Exception { + errorMessage = + "Naive Bayes should throw exception if unseen 
feature values are met in prediction " + + "and the model type is categorical."; + predictData = + new Row[] { + Row.of(1., Vectors.dense(5, 1., 1., 1., 2.), 11.0), + Row.of(1., Vectors.dense(5, 1., 0., 1., 2.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0) + }; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testEmptyLabel() throws Exception { + errorMessage = "Naive Bayes should throw exception if label is empty string"; + labelCol = ""; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testNullLabel() throws Exception { + errorMessage = "Naive Bayes should throw exception if label is null"; + labelCol = null; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testNullLabelValue() throws Exception { + errorMessage = + "Naive Bayes should throw exception of train or predict data contains null label."; + trainData = + new Row[] { + Row.of(1., Vectors.dense(1, 1., 1., 1., 2.), null), + Row.of(1., Vectors.dense(1, 1., 0., 1., 2.), null), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), null), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0) + }; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testInputNotContainLabel() throws Exception { + errorMessage = + "Naive Bayes should throw exception if input table schema does not contain label column."; + labelCol = "non-label"; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testNullPredict() throws Exception { + errorMessage = "Naive Bayes should throw exception if predict col is not set."; + predictCol = null; 
+ runAndCheck(); + } + + @Test(expected = Exception.class) + public void testNullModelType() throws Exception { + errorMessage = "Naive Bayes should throw exception if model type is not set."; + modelType = null; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testVectorWithDiffLen() throws Exception { + errorMessage = + "Naive Bayes should throw exception if length of feature vectors are not equal."; + trainData = + new Row[] { + Row.of(1., Vectors.dense(1, 1., 1., 1.), 11.0), + Row.of(1., Vectors.dense(1, 1., 0., 1., 2.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0) + }; + runAndCheck(); + } + + @Test + public void testSaveLoad() throws Exception { + errorMessage = "Naive Bayes should be able to save Model to filesystem and load correctly."; + isSaveLoad = true; + runAndCheck(); + } + + private void runAndCheck() throws Exception { + Table trainTable = + tEnv.fromDataStream( + env.fromElements(trainData) + .assignTimestampsAndWatermarks( + WatermarkStrategy.noWatermarks()), + schema) + .as("weight", "features", "label"); + Table predictTable = + tEnv.fromDataStream( + env.fromElements(predictData) + .assignTimestampsAndWatermarks( + WatermarkStrategy.noWatermarks()), + schema) + .as("weight", "features", "label"); + + NaiveBayes estimator = new NaiveBayes(); + estimator.setSmoothing(smoothing); + if (featuresCol != null) estimator.setFeaturesCol(featuresCol); Review comment: Can we just remove the `if` and set the param directly? ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java ########## @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification.naivebayes; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.core.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import 
org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.types.DataType; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.stream.Collectors; + +/** Naive Bayes Predictor. */ Review comment: Can you rephrase the comment here following the convention of KMeans or Spark? It is preferable not to introduce `Predictor` as a new concept here. ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java ########## @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.classification; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.classification.naivebayes.NaiveBayes; +import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.table.api.Expressions.$; + +public class NaiveBayesTest { Review comment: Can we add tests to cover `setModelData`, `getModelData`, and default param values? KMeansTest could serve as an example. ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java ########## @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.classification.naivebayes.NaiveBayes; +import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.flink.table.api.Expressions.$; + +public class NaiveBayesTest { + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Schema schema; + private Row[] trainData; + private Row[] predictData; + private Row[] expectedOutput; + private String featuresCol; + private int featureSize; + private String labelCol; + private String predictCol; + private String modelType; + private 
double smoothing; + private boolean isSaveLoad; + private String errorMessage; + + @Before + public void Setup() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + schema = + Schema.newBuilder() + .column("f0", DataTypes.DOUBLE()) + .column("f1", DataTypes.of(DenseVector.class)) + .column("f2", DataTypes.DOUBLE()) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build(); + + trainData = + new Row[] { + Row.of(1., Vectors.dense(1, 1., 1., 1., 2.), 11.0), + Row.of(1., Vectors.dense(1, 1., 0., 1., 2.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0) + }; + + predictData = trainData; + + expectedOutput = + new Row[] { + Row.of(1., Vectors.dense(1, 1., 1., 1., 2.), 11.0, 11.0), + Row.of(1., Vectors.dense(1, 1., 0., 1., 2.), 11.0, 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0, 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0, 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0, 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0, 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0, 10.0) + }; + + featuresCol = "features"; + featureSize = 5; + labelCol = "label"; + predictCol = "predict"; + modelType = "categorical"; + smoothing = 1.0; + isSaveLoad = false; + } + + @Test + public void testNaiveBayes() throws Exception { + errorMessage = "normal test for Naive Bayes"; + runAndCheck(); + } + + @Test(expected = Exception.class) + public 
void testInputNotContainFeature() throws Exception { + errorMessage = + "Naive Bayes should throw exception if some feature columns are missing from train data"; + featuresCol = "feature2"; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testPredictUnseenFeature() throws Exception { + errorMessage = + "Naive Bayes should throw exception if unseen feature values are met in prediction " + + "and the model type is categorical."; + predictData = + new Row[] { + Row.of(1., Vectors.dense(5, 1., 1., 1., 2.), 11.0), + Row.of(1., Vectors.dense(5, 1., 0., 1., 2.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0) + }; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testEmptyLabel() throws Exception { + errorMessage = "Naive Bayes should throw exception if label is empty string"; + labelCol = ""; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testNullLabel() throws Exception { + errorMessage = "Naive Bayes should throw exception if label is null"; + labelCol = null; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testNullLabelValue() throws Exception { + errorMessage = + "Naive Bayes should throw exception of train or predict data contains null label."; + trainData = + new Row[] { + Row.of(1., Vectors.dense(1, 1., 1., 1., 2.), null), + Row.of(1., Vectors.dense(1, 1., 0., 1., 2.), null), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), null), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0) + }; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void 
testInputNotContainLabel() throws Exception { + errorMessage = + "Naive Bayes should throw exception if input table schema does not contain label column."; + labelCol = "non-label"; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testNullPredict() throws Exception { + errorMessage = "Naive Bayes should throw exception if predict col is not set."; + predictCol = null; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testNullModelType() throws Exception { + errorMessage = "Naive Bayes should throw exception if model type is not set."; + modelType = null; + runAndCheck(); + } + + @Test(expected = Exception.class) + public void testVectorWithDiffLen() throws Exception { + errorMessage = + "Naive Bayes should throw exception if length of feature vectors are not equal."; + trainData = + new Row[] { + Row.of(1., Vectors.dense(1, 1., 1., 1.), 11.0), + Row.of(1., Vectors.dense(1, 1., 0., 1., 2.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1., 3.), 11.0), + Row.of(1., Vectors.dense(2, 0., 1., 1.5, 2.), 11.0), + Row.of(2., Vectors.dense(3, 1.5, 1., 0.5, 3.), 10.0), + Row.of(1., Vectors.dense(1, 1., 1.5, 0., 1.), 10.0), + Row.of(2., Vectors.dense(4, 1., 1., 0., 1.), 10.0) + }; + runAndCheck(); + } + + @Test + public void testSaveLoad() throws Exception { + errorMessage = "Naive Bayes should be able to save Model to filesystem and load correctly."; + isSaveLoad = true; + runAndCheck(); + } + + private void runAndCheck() throws Exception { + Table trainTable = + tEnv.fromDataStream( + env.fromElements(trainData) + .assignTimestampsAndWatermarks( + WatermarkStrategy.noWatermarks()), + schema) + .as("weight", "features", "label"); + Table predictTable = + tEnv.fromDataStream( + env.fromElements(predictData) + .assignTimestampsAndWatermarks( + WatermarkStrategy.noWatermarks()), + schema) + .as("weight", "features", "label"); + + NaiveBayes estimator = new NaiveBayes(); Review comment: It is preferred to follow best practice when 
writing test code, since it is what users will read and probably copy. We can use `NaiveBayes estimator = new NaiveBayes().setSmoothing(smoothing)` here. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java ########## @@ -0,0 +1,338 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.classification.naivebayes; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.core.Estimator; +import org.apache.flink.ml.common.datastream.EndOfStreamWindows; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; +import org.apache.flink.streaming.api.windowing.windows.TimeWindow; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; + +/** + * Naive Bayes classifier is a simple probability classification algorithm using Bayes theorem based + * on independent assumption. It is an independent feature model. The input feature can be continual + * or categorical. + */ +public class NaiveBayes + implements Estimator<NaiveBayes, NaiveBayesModel>, NaiveBayesParams<NaiveBayes> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public NaiveBayes() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public NaiveBayesModel fit(Table... 
inputs) { + final String featuresCol = getFeaturesCol(); + final int featureSize = getFeatureSize(); + final String labelCol = getLabelCol(); + final String predictionCol = getPredictionCol(); Review comment: Should we remove `predictionCol` since it is not used? Same for other variables that are not used. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
