lindong28 commented on a change in pull request #32: URL: https://github.com/apache/flink-ml/pull/32#discussion_r756947787
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/StageTestUtils.java ########## @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.util; + +import org.apache.flink.ml.api.core.Stage; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; + +import org.junit.rules.TemporaryFolder; + +import java.io.IOException; +import java.lang.reflect.Method; + +/** Utility methods for testing stages. */ +public class StageTestUtils { + private static final TemporaryFolder tempFolder = new TemporaryFolder(); + + static { + try { + tempFolder.create(); + } catch (IOException e) { + e.printStackTrace(); + } + } + + /** + * Saves a stage to filesystem and reloads it with the static load() method a stage must Review comment: nits: "a stage must implement" seems irrelevant here. Maybe change it to `Saves a stage to filesystem and reloads it by invoking the static load() method of the given stage`. ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.classification.naivebayes.NaiveBayes; +import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel; +import org.apache.flink.ml.classification.naivebayes.NaiveBayesModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import org.junit.Before; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertArrayEquals; 
+import static org.junit.Assert.assertEquals; + +/** Tests {@link NaiveBayes} and {@link NaiveBayesModel}. */ +public class NaiveBayesTest { + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainTable; + private Table predictTable; + private Map<Vector, Integer> expectedOutput; + private NaiveBayes estimator; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema trainSchema = + Schema.newBuilder() + .column("f0", DataTypes.of(DenseVector.class)) + .column("f1", DataTypes.INT()) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build(); + + trainTable = + tEnv.fromDataStream( + env.fromElements( + Row.of(Vectors.dense(0, 0.), 11), + Row.of(Vectors.dense(1, 0), 10), + Row.of(Vectors.dense(1, 1.), 10)) + .assignTimestampsAndWatermarks( + WatermarkStrategy.noWatermarks()), + trainSchema) + .as("features", "label"); + + Schema predictSchema = + Schema.newBuilder() + .column("f0", DataTypes.of(DenseVector.class)) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build(); + + predictTable = + tEnv.fromDataStream( + env.fromElements( + Row.of(Vectors.dense(0, 1.)), + Row.of(Vectors.dense(0, 0.)), + Row.of(Vectors.dense(1, 0)), + Row.of(Vectors.dense(1, 1.))) + .assignTimestampsAndWatermarks( + WatermarkStrategy.noWatermarks()), + predictSchema) + .as("features"); + + expectedOutput = + new HashMap<Vector, Integer>() { + { + put(Vectors.dense(0, 1.), 11); + put(Vectors.dense(0, 0.), 11); + put(Vectors.dense(1, 0.), 10); + put(Vectors.dense(1, 1.), 10); + } + }; + + estimator = + 
new NaiveBayes() + .setSmoothing(1.0) + .setFeaturesCol("features") + .setLabelCol("label") + .setPredictionCol("predict") + .setModelType("multinomial"); + } + + /** + * Executes a given table and collect its results. Results are returned as a map whose key is + * the feature, value is the prediction result. + * + * @param table A table to be executed and to have its result collected + * @param featuresCol Name of the column in the table that contains the features + * @param predictionCol Name of the column in the table that contains the prediction result + * @return A map containing the collected results + */ + private static Map<Vector, Integer> executeAndCollect( + Table table, String featuresCol, String predictionCol) { + Map<Vector, Integer> map = new HashMap<>(); + for (CloseableIterator<Row> it = table.execute().collect(); it.hasNext(); ) { + Row row = it.next(); + map.put( + (Vector) row.getField(featuresCol), + ((Number) row.getField(predictionCol)).intValue()); + } + + return map; + } + + @Test + public void testParam() { + NaiveBayes estimator = new NaiveBayes(); + + assertEquals("features", estimator.getFeaturesCol()); + assertEquals("label", estimator.getLabelCol()); + assertEquals("multinomial", estimator.getModelType()); + assertEquals("prediction", estimator.getPredictionCol()); + assertEquals(1.0, estimator.getSmoothing(), 1e-5); + + estimator + .setFeaturesCol("test_feature") + .setLabelCol("test_label") + .setPredictionCol("test_prediction") + .setSmoothing(2.0); + + assertEquals("test_feature", estimator.getFeaturesCol()); + assertEquals("test_label", estimator.getLabelCol()); + assertEquals("test_prediction", estimator.getPredictionCol()); + assertEquals(2.0, estimator.getSmoothing(), 1e-5); + + NaiveBayesModel model = new NaiveBayesModel(); + + assertEquals("features", model.getFeaturesCol()); + assertEquals("multinomial", model.getModelType()); + assertEquals("prediction", model.getPredictionCol()); + + 
model.setFeaturesCol("test_feature").setPredictionCol("test_prediction"); + + assertEquals("test_feature", model.getFeaturesCol()); + assertEquals("test_prediction", model.getPredictionCol()); + } + + @Test + public void testFitAndPredict() throws Exception { + NaiveBayesModel model = estimator.fit(trainTable); + Table outputTable = model.transform(predictTable)[0]; + Map<Vector, Integer> actualOutput = + executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol()); + assertEquals(expectedOutput, actualOutput); + } + + @Test(expected = Exception.class) + public void testPredictUnseenFeature() throws Exception { Review comment: Do we expect the Exception message to provide information regarding the "unseen feature"? If so, would it be better to also check the message, similar to existing tests such as `StageTest::testParamWithNullDefault`? Same for other tests. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java ########## @@ -0,0 +1,377 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.classification.naivebayes; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.functions.FlatMapFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.functions.ReduceFunction; +import org.apache.flink.api.java.functions.KeySelector; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.api.java.tuple.Tuple3; +import org.apache.flink.api.java.tuple.Tuple4; +import org.apache.flink.ml.api.core.Estimator; +import org.apache.flink.ml.common.datastream.EndOfStreamWindows; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.windowing.AllWindowFunction; +import org.apache.flink.streaming.api.windowing.windows.TimeWindow; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Collector; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; + +/** + * An Estimator which implements the naive bayes classification algorithm. + * + * <p>See https://en.wikipedia.org/wiki/Naive_Bayes_classifier. + */ +public class NaiveBayes + implements Estimator<NaiveBayes, NaiveBayesModel>, NaiveBayesParams<NaiveBayes> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public NaiveBayes() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public NaiveBayesModel fit(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + + final String featuresCol = getFeaturesCol(); + final String labelCol = getLabelCol(); + final double smoothing = getSmoothing(); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Tuple2<Vector, Integer>> input = + tEnv.toDataStream(inputs[0]) + .map( + new MapFunction<Row, Tuple2<Vector, Integer>>() { + @Override + public Tuple2<Vector, Integer> map(Row row) throws Exception { + Number number = (Number) row.getField(labelCol); + Preconditions.checkArgument( + number.intValue() == number.doubleValue()); + return new Tuple2<>( + (Vector) row.getField(featuresCol), + number.intValue()); + } + }); + + DataStream<NaiveBayesModelData> modelData = + input.flatMap(new ExtractFeatureFunction()) + .keyBy( + (KeySelector<Tuple4<Integer, Integer, Double, Double>, Object>) + value -> new Tuple3<>(value.f0, value.f1, value.f2)) + .window(EndOfStreamWindows.get()) + .reduce( + (ReduceFunction<Tuple4<Integer, Integer, Double, Double>>) + (t0, t1) -> + new Tuple4<>(t0.f0, t0.f1, t0.f2, t0.f3 + t1.f3)) + .keyBy( + (KeySelector<Tuple4<Integer, Integer, Double, Double>, Object>) + value -> new Tuple2<>(value.f0, value.f1)) + .window(EndOfStreamWindows.get()) + .aggregate(new GenerateFeatureWeightMapFunction()) + .keyBy( + (KeySelector< + Tuple4< + Integer, + Integer, + Map<Double, Double>, + Double>, + Object>) + value -> value.f0) + .window(EndOfStreamWindows.get()) + .aggregate(new AggregateIntoArrayFunction()) + .windowAll(EndOfStreamWindows.get()) + .apply(new GenerateModelFunction(smoothing)); + + NaiveBayesModel model = + new NaiveBayesModel() + .setModelData(NaiveBayesModelData.getModelDataTable(tEnv, modelData)); + ReadWriteUtils.updateExistingParams(model, paramMap); + return model; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static NaiveBayes 
load(StreamExecutionEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + /** + * Function to extract feature values from input rows. + * + * <p>Output records are tuples with the following fields in order: + * + * <ul> + * <li>label value + * <li>feature column index + * <li>feature value + * <li>weight + * </ul> + */ + private static class ExtractFeatureFunction + implements FlatMapFunction< + Tuple2<Vector, Integer>, Tuple4<Integer, Integer, Double, Double>> { + @Override + public void flatMap( + Tuple2<Vector, Integer> value, + Collector<Tuple4<Integer, Integer, Double, Double>> collector) { + Preconditions.checkNotNull(value.f1); + for (int i = 0; i < value.f0.size(); i++) { + collector.collect(new Tuple4<>(value.f1, i, value.f0.get(i), 1.0)); + } + } + } + + /** + * Function that aggregates entries of feature value and weight into maps. + * + * <p>Input records should have the same label value and feature column index. 
+ * + * <p>Input records are tuples with the following fields in order: + * + * <ul> + * <li>label value + * <li>feature column index + * <li>feature value + * <li>weight + * </ul> + * + * <p>Output records are tuples with the following fields in order: + * + * <ul> + * <li>label value + * <li>feature column index + * <li>map of (feature value, weight) + * </ul> + */ + private static class GenerateFeatureWeightMapFunction + implements AggregateFunction< + Tuple4<Integer, Integer, Double, Double>, + Tuple3<Integer, Integer, Map<Double, Double>>, + Tuple4<Integer, Integer, Map<Double, Double>, Double>> { + + @Override + public Tuple3<Integer, Integer, Map<Double, Double>> createAccumulator() { + return new Tuple3<>(0, -1, new HashMap<>()); + } + + @Override + public Tuple3<Integer, Integer, Map<Double, Double>> add( + Tuple4<Integer, Integer, Double, Double> value, + Tuple3<Integer, Integer, Map<Double, Double>> acc) { + acc.f0 = value.f0; + acc.f1 = value.f1; + acc.f2.put(value.f2, value.f3); + return acc; + } + + @Override + public Tuple4<Integer, Integer, Map<Double, Double>, Double> getResult( + Tuple3<Integer, Integer, Map<Double, Double>> acc) { + double weightSum = acc.f2.values().stream().mapToDouble(Double::doubleValue).sum(); + return new Tuple4<>(acc.f0, acc.f1, acc.f2, weightSum); + } + + @Override + public Tuple3<Integer, Integer, Map<Double, Double>> merge( + Tuple3<Integer, Integer, Map<Double, Double>> acc0, + Tuple3<Integer, Integer, Map<Double, Double>> acc1) { + Preconditions.checkArgument(acc0.f1 != -1); + acc0.f2.putAll(acc1.f2); + return acc0; + } + } + + /** + * Function that aggregates maps under the same label into arrays. + * + * <p>Length of the generated array equals to the number of feature columns. + * + * <p>Input records are tuples with the following fields in order: + * + * <ul> + * <li>label value + * <li>feature column index + * <li>map of (feature value, weight) Review comment: It seems that the 4th field of the input is missing? 
########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java ########## @@ -0,0 +1,327 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.classification; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.classification.naivebayes.NaiveBayes; +import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel; +import org.apache.flink.ml.classification.naivebayes.NaiveBayesModelData; +import org.apache.flink.ml.linalg.DenseVector; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.ml.util.StageTestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; 
+import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import org.junit.Before; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; + +/** Tests {@link NaiveBayes} and {@link NaiveBayesModel}. */ +public class NaiveBayesTest { + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainTable; + private Table predictTable; + private Map<Vector, Integer> expectedOutput; + private NaiveBayes estimator; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + Schema trainSchema = + Schema.newBuilder() + .column("f0", DataTypes.of(DenseVector.class)) + .column("f1", DataTypes.INT()) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build(); + + trainTable = + tEnv.fromDataStream( + env.fromElements( + Row.of(Vectors.dense(0, 0.), 11), + Row.of(Vectors.dense(1, 0), 10), + Row.of(Vectors.dense(1, 1.), 10)) + .assignTimestampsAndWatermarks( + WatermarkStrategy.noWatermarks()), + trainSchema) + .as("features", "label"); + + Schema predictSchema = + Schema.newBuilder() + .column("f0", DataTypes.of(DenseVector.class)) + .columnByMetadata("rowtime", "TIMESTAMP_LTZ(3)") + .watermark("rowtime", "SOURCE_WATERMARK()") + .build(); + + predictTable = + tEnv.fromDataStream( + env.fromElements( + Row.of(Vectors.dense(0, 1.)), + Row.of(Vectors.dense(0, 0.)), + Row.of(Vectors.dense(1, 0)), + Row.of(Vectors.dense(1, 1.))) + .assignTimestampsAndWatermarks( + WatermarkStrategy.noWatermarks()), + 
predictSchema) + .as("features"); + + expectedOutput = + new HashMap<Vector, Integer>() { + { + put(Vectors.dense(0, 1.), 11); + put(Vectors.dense(0, 0.), 11); + put(Vectors.dense(1, 0.), 10); + put(Vectors.dense(1, 1.), 10); + } + }; + + estimator = + new NaiveBayes() + .setSmoothing(1.0) + .setFeaturesCol("features") + .setLabelCol("label") + .setPredictionCol("predict") + .setModelType("multinomial"); + } + + /** + * Executes a given table and collect its results. Results are returned as a map whose key is + * the feature, value is the prediction result. + * + * @param table A table to be executed and to have its result collected + * @param featuresCol Name of the column in the table that contains the features + * @param predictionCol Name of the column in the table that contains the prediction result + * @return A map containing the collected results + */ + private static Map<Vector, Integer> executeAndCollect( + Table table, String featuresCol, String predictionCol) { + Map<Vector, Integer> map = new HashMap<>(); + for (CloseableIterator<Row> it = table.execute().collect(); it.hasNext(); ) { + Row row = it.next(); + map.put( + (Vector) row.getField(featuresCol), + ((Number) row.getField(predictionCol)).intValue()); + } + + return map; + } + + @Test + public void testParam() { + NaiveBayes estimator = new NaiveBayes(); + + assertEquals("features", estimator.getFeaturesCol()); + assertEquals("label", estimator.getLabelCol()); + assertEquals("multinomial", estimator.getModelType()); + assertEquals("prediction", estimator.getPredictionCol()); + assertEquals(1.0, estimator.getSmoothing(), 1e-5); + + estimator + .setFeaturesCol("test_feature") + .setLabelCol("test_label") + .setPredictionCol("test_prediction") + .setSmoothing(2.0); + + assertEquals("test_feature", estimator.getFeaturesCol()); + assertEquals("test_label", estimator.getLabelCol()); + assertEquals("test_prediction", estimator.getPredictionCol()); + assertEquals(2.0, estimator.getSmoothing(), 1e-5); + + 
NaiveBayesModel model = new NaiveBayesModel(); + + assertEquals("features", model.getFeaturesCol()); + assertEquals("multinomial", model.getModelType()); + assertEquals("prediction", model.getPredictionCol()); + + model.setFeaturesCol("test_feature").setPredictionCol("test_prediction"); + + assertEquals("test_feature", model.getFeaturesCol()); + assertEquals("test_prediction", model.getPredictionCol()); + } + + @Test + public void testFitAndPredict() throws Exception { Review comment: Could you help update this test to use non-default column names for features, labels and predictions? I'd like to have such a test for every algorithm to make sure the user-specified column names are respected. This could be similar to `KMeansTest::testFeaturePredictionParam`. ########## File path: flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java ########## @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.flink.ml.classification.naivebayes; + +import org.apache.flink.api.common.eventtime.WatermarkStrategy; +import org.apache.flink.api.common.functions.AbstractRichFunction; +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.connector.source.Source; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.connector.file.sink.FileSink; +import org.apache.flink.connector.file.src.FileSource; +import org.apache.flink.core.fs.Path; +import org.apache.flink.ml.api.core.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.BLAS; +import org.apache.flink.ml.linalg.Vector; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner; +import org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy; +import org.apache.flink.streaming.api.operators.AbstractUdfStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +/** A Model which classifies data using the model 
data computed by {@link NaiveBayes}. */ +public class NaiveBayesModel + implements Model<NaiveBayesModel>, NaiveBayesModelParams<NaiveBayesModel> { + private static final long serialVersionUID = -4673084154965905629L; Review comment: Do you know what could go wrong if we don't specify `serialVersionUID` for `NaiveBayesModel` and `NaiveBayesModelData`? Spark does not seem to explicitly specify UID for its Estimator/Model subclasses. If there is no known issue yet, maybe we should remove this UID for simplicity. If there is known impact, we will need to add this UID for `NaiveBayes`, `KMeans` and `KMeansModel` etc. ########## File path: flink-ml-api/src/main/java/org/apache/flink/ml/linalg/BLAS.java ########## @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.linalg; + +import org.apache.flink.util.Preconditions; + +/** A utility class that provides BLAS routines over matrices and vectors. */ +public class BLAS { + /** For level-1 routines, we use Java implementation. */ + private static final dev.ludovic.netlib.BLAS F2J_BLAS = + dev.ludovic.netlib.JavaBLAS.getInstance(); + + /** y += a * x . 
*/ + public static void axpy(double a, double[] x, double[] y) { Review comment: In the future, will we need to support `axpy(a: Double, x: SparseVector, y: DenseVector)`? If so, should we make this `axpy(a: Double, x: DenseVector, y: DenseVector)`? ########## File path: flink-ml-lib/src/test/java/org/apache/flink/ml/util/StageTestUtils.java ########## @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.util; + +import org.apache.flink.ml.api.core.Stage; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; + +import org.junit.rules.TemporaryFolder; + +import java.io.IOException; +import java.lang.reflect.Method; + +/** Utility methods for testing stages. */ +public class StageTestUtils { + private static final TemporaryFolder tempFolder = new TemporaryFolder(); + + static { + try { + tempFolder.create(); Review comment: Why do we need to call `tempFolder.create`? The Javadoc of `tempFolder.create` says: `for testing purposes only. Do not use` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
