This is an automated email from the ASF dual-hosted git repository.
gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new e4f7eed [FLINK-24817] Add Estimator and Transformer for Naive Bayes
e4f7eed is described below
commit e4f7eedb958e29817525f54ee6944d819d0de47c
Author: Yunfeng Zhou <[email protected]>
AuthorDate: Wed Nov 17 10:59:03 2021 +0800
[FLINK-24817] Add Estimator and Transformer for Naive Bayes
This closes #32.
---
flink-ml-core/pom.xml | 11 +
.../main/java/org/apache/flink/ml/linalg/BLAS.java | 34 ++
.../org/apache/flink/ml/util/ReadWriteUtils.java | 46 +++
.../ml/classification/naivebayes/NaiveBayes.java | 360 +++++++++++++++++++++
.../classification/naivebayes/NaiveBayesModel.java | 187 +++++++++++
.../naivebayes/NaiveBayesModelData.java | 164 ++++++++++
.../naivebayes/NaiveBayesModelParams.java} | 33 +-
.../naivebayes/NaiveBayesParams.java} | 33 +-
.../apache/flink/ml/clustering/kmeans/KMeans.java | 5 +-
.../flink/ml/clustering/kmeans/KMeansModel.java | 44 +--
.../ml/clustering/kmeans/KMeansModelData.java | 80 +++--
.../flink/ml/common/datastream/TableUtils.java | 3 +-
.../flink/ml/common/param/HasDistanceMeasure.java | 2 +-
.../{HasDistanceMeasure.java => HasLabelCol.java} | 22 +-
.../flink/ml/classification/NaiveBayesTest.java | 314 ++++++++++++++++++
.../org/apache/flink/ml/clustering/KMeansTest.java | 122 ++++---
.../org/apache/flink/ml/util/StageTestUtils.java | 48 +++
17 files changed, 1331 insertions(+), 177 deletions(-)
diff --git a/flink-ml-core/pom.xml b/flink-ml-core/pom.xml
index 6ed25b7..79a9654 100644
--- a/flink-ml-core/pom.xml
+++ b/flink-ml-core/pom.xml
@@ -65,5 +65,16 @@ under the License.
<artifactId>flink-shaded-jackson</artifactId>
<scope>provided</scope>
</dependency>
+
+ <dependency>
+ <groupId>dev.ludovic.netlib</groupId>
+ <artifactId>blas</artifactId>
+ <version>2.2.0</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.flink</groupId>
+ <artifactId>flink-table-api-java-bridge_2.11</artifactId>
+ <version>1.14.0</version>
+ </dependency>
</dependencies>
</project>
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
new file mode 100644
index 0000000..3b91cdb
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import org.apache.flink.util.Preconditions;
+
+/** A utility class that provides BLAS routines over matrices and vectors. */
+public class BLAS {
+ /** For level-1 function dspmv, use javaBLAS for better performance. */
+ private static final dev.ludovic.netlib.BLAS JAVA_BLAS =
+ dev.ludovic.netlib.JavaBLAS.getInstance();
+
+ /** y += a * x . */
+ public static void axpy(double a, DenseVector x, DenseVector y) {
+ Preconditions.checkArgument(x.size() == y.size(), "Vector size
mismatched.");
+ JAVA_BLAS.daxpy(x.size(), a, x.values, 1, y.values, 1);
+ }
+}
diff --git
a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
index 457cac3..800a33d 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
@@ -18,9 +18,19 @@
package org.apache.flink.ml.util;
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
import org.apache.flink.ml.api.Stage;
import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.util.InstantiationUtil;
import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
@@ -312,4 +322,40 @@ public class ReadWriteUtils {
throw new RuntimeException("Failed to load stage.", e);
}
}
+
+ /**
+ * Saves the model data stream to the given path using the model encoder.
+ *
+ * @param model The model data stream.
+ * @param path The parent directory of the model data file.
+ * @param modelEncoder The encoder to encode the model data.
+ * @param <T> The class type of the model data.
+ */
+ public static <T> void saveModelData(
+ DataStream<T> model, String path, Encoder<T> modelEncoder) {
+ FileSink<T> sink =
+ FileSink.forRowFormat(
+ new
org.apache.flink.core.fs.Path(getDataPath(path)), modelEncoder)
+ .withRollingPolicy(OnCheckpointRollingPolicy.build())
+ .withBucketAssigner(new BasePathBucketAssigner<>())
+ .build();
+ model.sinkTo(sink);
+ }
+
+ /**
+ * Loads the model data from the given path using the model decoder.
+ *
+ * @param env A StreamExecutionEnvironment instance.
+ * @param path The parent directory of the model data file.
+ * @param modelDecoder The decoder used to decode the model data.
+ * @param <T> The class type of the model data.
+ * @return The loaded model data.
+ */
+ public static <T> DataStream<T> loadModelData(
+ StreamExecutionEnvironment env, String path, SimpleStreamFormat<T>
modelDecoder) {
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+ Source<T, ?, ?> source =
+ FileSource.forRecordStreamFormat(modelDecoder,
getDataPaths(path)).build();
+ return env.fromSource(source, WatermarkStrategy.noWatermarks(),
"modelData");
+ }
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
new file mode 100644
index 0000000..104c908
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
@@ -0,0 +1,360 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the naive bayes classification algorithm.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Naive_Bayes_classifier.
+ */
+public class NaiveBayes
+ implements Estimator<NaiveBayes, NaiveBayesModel>,
NaiveBayesParams<NaiveBayes> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public NaiveBayes() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public NaiveBayesModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+
+ final String featuresCol = getFeaturesCol();
+ final String labelCol = getLabelCol();
+ final double smoothing = getSmoothing();
+
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
+ DataStream<Tuple2<Vector, Double>> input =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ new MapFunction<Row, Tuple2<Vector, Double>>()
{
+ @Override
+ public Tuple2<Vector, Double> map(Row row)
throws Exception {
+ Number number = (Number)
row.getField(labelCol);
+ Preconditions.checkNotNull(
+ number, "Input data should
contain label value.");
+ Preconditions.checkArgument(
+ number.intValue() ==
number.doubleValue(),
+ "Label value should be indexed
number.");
+ return new Tuple2<>(
+ (Vector)
row.getField(featuresCol),
+ number.doubleValue());
+ }
+ });
+
+ DataStream<NaiveBayesModelData> modelData =
+ input.flatMap(new ExtractFeatureFunction())
+ .keyBy(value -> new Tuple2<>(value.f0,
value.f1).hashCode())
+ .transform(
+ "GenerateFeatureWeightMapFunction",
+ Types.TUPLE(
+ Types.DOUBLE,
+ Types.INT,
+ Types.MAP(Types.DOUBLE, Types.DOUBLE),
+ Types.INT),
+ new MapPartitionFunctionWrapper<>(
+ new
GenerateFeatureWeightMapFunction()))
+ .keyBy(value -> value.f0)
+ .transform(
+ "AggregateIntoArrayFunction",
+ Types.TUPLE(
+ Types.DOUBLE,
+ Types.INT,
+
Types.OBJECT_ARRAY(Types.MAP(Types.DOUBLE, Types.DOUBLE))),
+ new MapPartitionFunctionWrapper<>(new
AggregateIntoArrayFunction()))
+ .transform(
+ "GenerateModelFunction",
+ TypeInformation.of(NaiveBayesModelData.class),
+ new MapPartitionFunctionWrapper<>(
+ new GenerateModelFunction(smoothing)))
+ .setParallelism(1);
+
+ NaiveBayesModel model =
+ new NaiveBayesModel()
+
.setModelData(NaiveBayesModelData.getModelDataTable(modelData));
+ ReadWriteUtils.updateExistingParams(model, paramMap);
+ return model;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static NaiveBayes load(StreamExecutionEnvironment env, String path)
throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ /**
+ * Function to extract feature values from input rows.
+ *
+ * <p>Output records are tuples with the following fields in order:
+ *
+ * <ul>
+ * <li>label value
+ * <li>feature column index
+ * <li>feature value
+ * </ul>
+ */
+ private static class ExtractFeatureFunction
+ implements FlatMapFunction<Tuple2<Vector, Double>, Tuple3<Double,
Integer, Double>> {
+ @Override
+ public void flatMap(
+ Tuple2<Vector, Double> value,
+ Collector<Tuple3<Double, Integer, Double>> collector) {
+ Preconditions.checkNotNull(value.f1);
+ for (int i = 0; i < value.f0.size(); i++) {
+ collector.collect(new Tuple3<>(value.f1, i, value.f0.get(i)));
+ }
+ }
+ }
+
+ /**
+ * Function that aggregates entries of feature value and weight into maps.
+ *
+ * <p>Input records should have the same label value and feature column
index.
+ *
+ * <p>Input records are tuples with the following fields in order:
+ *
+ * <ul>
+ * <li>label value
+ * <li>feature column index
+ * <li>feature value
+ * </ul>
+ *
+ * <p>Output records are tuples with the following fields in order:
+ *
+ * <ul>
+ * <li>label value
+ * <li>feature column index
+ * <li>map of (feature value, weight)
+ * <li>number of records
+ * </ul>
+ */
+ private static class GenerateFeatureWeightMapFunction
+ implements MapPartitionFunction<
+ Tuple3<Double, Integer, Double>,
+ Tuple4<Double, Integer, Map<Double, Double>, Integer>> {
+
+ @Override
+ public void mapPartition(
+ Iterable<Tuple3<Double, Integer, Double>> iterable,
+ Collector<Tuple4<Double, Integer, Map<Double, Double>,
Integer>> collector) {
+ List<Tuple3<Double, Integer, Double>> list = new ArrayList<>();
+ iterable.iterator().forEachRemaining(list::add);
+
+ Map<Tuple2<Double, Integer>, Map<Double, Double>> accMap = new
HashMap<>();
+ Map<Tuple2<Double, Integer>, Integer> numMap = new HashMap<>();
+ for (Tuple3<Double, Integer, Double> value : list) {
+ Tuple2<Double, Integer> key = new Tuple2<>(value.f0, value.f1);
+ Map<Double, Double> acc = accMap.computeIfAbsent(key, x -> new
HashMap<>());
+ acc.put(value.f2, acc.getOrDefault(value.f2, 0.) + 1.0);
+ numMap.put(key, numMap.getOrDefault(key, 0) + 1);
+ }
+ for (Map.Entry<Tuple2<Double, Integer>, Map<Double, Double>> entry
:
+ accMap.entrySet()) {
+ collector.collect(
+ new Tuple4<>(
+ entry.getKey().f0,
+ entry.getKey().f1,
+ entry.getValue(),
+ numMap.get(entry.getKey())));
+ }
+ }
+ }
+
+ /**
+ * Function that aggregates maps under the same label into arrays.
+ *
+ * <p>Length of the generated array equals to the number of feature
columns.
+ *
+ * <p>Input records are tuples with the following fields in order:
+ *
+ * <ul>
+ * <li>label value
+ * <li>feature column index
+ * <li>map of (feature value, weight)
+ * <li>number of records
+ * </ul>
+ *
+ * <p>Output records are tuples with the following fields in order:
+ *
+ * <ul>
+ * <li>label value
+ * <li>number of records
+ * <li>array of featureValue-weight maps of each feature
+ * </ul>
+ */
+ private static class AggregateIntoArrayFunction
+ implements MapPartitionFunction<
+ Tuple4<Double, Integer, Map<Double, Double>, Integer>,
+ Tuple3<Double, Integer, Map<Double, Double>[]>> {
+
+ @Override
+ public void mapPartition(
+ Iterable<Tuple4<Double, Integer, Map<Double, Double>,
Integer>> iterable,
+ Collector<Tuple3<Double, Integer, Map<Double, Double>[]>>
collector) {
+ Map<Double, List<Tuple4<Double, Integer, Map<Double, Double>,
Integer>>> map =
+ new HashMap<>();
+ for (Tuple4<Double, Integer, Map<Double, Double>, Integer> value :
iterable) {
+ map.computeIfAbsent(value.f0, x -> new
ArrayList<>()).add(value);
+ }
+
+ for (List<Tuple4<Double, Integer, Map<Double, Double>, Integer>>
list : map.values()) {
+ final int featureSize =
+ list.stream().map(x ->
x.f1).max(Integer::compareTo).orElse(-1) + 1;
+
+ int minDocNum =
+ list.stream()
+ .map(x -> x.f3)
+ .min(Integer::compareTo)
+ .orElse(Integer.MAX_VALUE);
+ int maxDocNum =
+ list.stream()
+ .map(x -> x.f3)
+ .max(Integer::compareTo)
+ .orElse(Integer.MIN_VALUE);
+ Preconditions.checkArgument(
+ minDocNum == maxDocNum, "Feature vectors should be of
equal length.");
+
+ Map<Double, Integer> numMap = new HashMap<>();
+ Map<Double, Map<Double, Double>[]> featureWeightMap = new
HashMap<>();
+ for (Tuple4<Double, Integer, Map<Double, Double>, Integer>
value : list) {
+ Map<Double, Double>[] featureWeight =
+ featureWeightMap.computeIfAbsent(
+ value.f0, x -> new HashMap[featureSize]);
+ numMap.put(value.f0, value.f3);
+ featureWeight[value.f1] = value.f2;
+ }
+
+ for (double key : featureWeightMap.keySet()) {
+ collector.collect(
+ new Tuple3<>(key, numMap.get(key),
featureWeightMap.get(key)));
+ }
+ }
+ }
+ }
+
+ /** Function to generate Naive Bayes model data. */
+ private static class GenerateModelFunction
+ implements MapPartitionFunction<
+ Tuple3<Double, Integer, Map<Double, Double>[]>,
NaiveBayesModelData> {
+ private final double smoothing;
+
+ private GenerateModelFunction(double smoothing) {
+ this.smoothing = smoothing;
+ }
+
+ @Override
+ public void mapPartition(
+ Iterable<Tuple3<Double, Integer, Map<Double, Double>[]>>
iterable,
+ Collector<NaiveBayesModelData> collector) {
+ ArrayList<Tuple3<Double, Integer, Map<Double, Double>[]>> list =
new ArrayList<>();
+ iterable.iterator().forEachRemaining(list::add);
+ final int featureSize = list.get(0).f2.length;
+ for (Tuple3<Double, Integer, Map<Double, Double>[]> tup : list) {
+ Preconditions.checkArgument(
+ featureSize == tup.f2.length, "Feature vectors should
be of equal length.");
+ }
+
+ double[] numDocs = new double[featureSize];
+ HashSet<Double>[] categoryNumbers = new HashSet[featureSize];
+ for (int i = 0; i < featureSize; i++) {
+ categoryNumbers[i] = new HashSet<>();
+ }
+ for (Tuple3<Double, Integer, Map<Double, Double>[]> tup : list) {
+ for (int i = 0; i < featureSize; i++) {
+ numDocs[i] += tup.f1;
+ categoryNumbers[i].addAll(tup.f2[i].keySet());
+ }
+ }
+
+ int[] categoryNumber = new int[featureSize];
+ double piLog = 0;
+ int numLabels = list.size();
+ for (int i = 0; i < featureSize; i++) {
+ categoryNumber[i] = categoryNumbers[i].size();
+ piLog += numDocs[i];
+ }
+ piLog = Math.log(piLog + numLabels * smoothing);
+
+ Map<Double, Double>[][] theta = new
HashMap[numLabels][featureSize];
+ double[] piArray = new double[numLabels];
+ double[] labels = new double[numLabels];
+
+ // Consider smoothing.
+ for (int i = 0; i < numLabels; i++) {
+ Map<Double, Double>[] param = list.get(i).f2;
+ for (int j = 0; j < featureSize; j++) {
+ Map<Double, Double> squareData = new HashMap<>();
+ double thetaLog =
+ Math.log(list.get(i).f1 * 1.0 + smoothing *
categoryNumber[j]);
+ for (Double cate : categoryNumbers[j]) {
+ double value = param[j].getOrDefault(cate, 0.0);
+ squareData.put(cate, Math.log(value + smoothing) -
thetaLog);
+ }
+ theta[i][j] = squareData;
+ }
+
+ labels[i] = list.get(i).f0;
+ double weightSum = list.get(i).f1 * featureSize;
+ piArray[i] = Math.log(weightSum + smoothing) - piLog;
+ }
+
+ NaiveBayesModelData modelData = new NaiveBayesModelData(theta,
piArray, labels);
+ collector.collect(modelData);
+ }
+ }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java
new file mode 100644
index 0000000..50ef1b1
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+/** A Model which classifies data using the model data computed by {@link
NaiveBayes}. */
+public class NaiveBayesModel
+ implements Model<NaiveBayesModel>,
NaiveBayesModelParams<NaiveBayesModel> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+ private Table modelDataTable;
+
+ public NaiveBayesModel() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+
+ final String predictionCol = getPredictionCol();
+ final String featuresCol = getFeaturesCol();
+ final String broadcastModelKey = "NaiveBayesModelStream";
+
+ RowTypeInfo inputTypeInfo =
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(
+ inputTypeInfo.getFieldTypes(),
TypeInformation.of(Double.class)),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(),
predictionCol));
+
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
modelDataTable).getTableEnvironment();
+ DataStream<NaiveBayesModelData> modelDataStream =
+ NaiveBayesModelData.getModelDataStream(modelDataTable);
+ DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+
+ Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+ broadcastMap.put(broadcastModelKey, modelDataStream);
+
+ Function<List<DataStream<?>>, DataStream<Row>> function =
+ dataStreams -> {
+ DataStream stream = dataStreams.get(0);
+ return stream.map(
+ new PredictLabelFunction(featuresCol,
broadcastModelKey),
+ outputTypeInfo);
+ };
+ DataStream<Row> output =
+ BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(input),
+ Collections.singletonMap(broadcastModelKey,
modelDataStream),
+ function);
+
+ Table outputTable = tEnv.fromDataStream(output);
+
+ return new Table[] {outputTable};
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ ReadWriteUtils.saveModelData(
+ NaiveBayesModelData.getModelDataStream(modelDataTable),
+ path,
+ new NaiveBayesModelData.ModelDataEncoder());
+ }
+
+ public static NaiveBayesModel load(StreamExecutionEnvironment env, String
path)
+ throws IOException {
+ NaiveBayesModel model = ReadWriteUtils.loadStageParam(path);
+ DataStream<NaiveBayesModelData> modelData =
+ ReadWriteUtils.loadModelData(
+ env, path, new
NaiveBayesModelData.ModelDataStreamFormat());
+ return
model.setModelData(NaiveBayesModelData.getModelDataTable(modelData));
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ @Override
+ public NaiveBayesModel setModelData(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ modelDataTable = inputs[0];
+ return this;
+ }
+
+ @Override
+ public Table[] getModelData() {
+ return new Table[] {modelDataTable};
+ }
+
+ private static class PredictLabelFunction extends RichMapFunction<Row,
Row> {
+ private final String featuresCol;
+ private final String broadcastModelKey;
+ private NaiveBayesModelData modelData = null;
+
+ public PredictLabelFunction(String featuresCol, String
broadcastModelKey) {
+ this.featuresCol = featuresCol;
+ this.broadcastModelKey = broadcastModelKey;
+ }
+
+ @Override
+ public Row map(Row row) {
+ if (modelData == null) {
+ modelData =
+ (NaiveBayesModelData)
+
getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
+ }
+ Vector vector = (Vector) row.getField(featuresCol);
+ double label = findMaxProbLabel(calculateProb(modelData, vector),
modelData.labels);
+ return Row.join(row, Row.of(label));
+ }
+ }
+
+ private static double findMaxProbLabel(DenseVector prob, Vector label) {
+ double result = 0.;
+ int probSize = prob.size();
+ double maxVal = Double.NEGATIVE_INFINITY;
+ for (int i = 0; i < probSize; ++i) {
+ if (maxVal < prob.values[i]) {
+ maxVal = prob.values[i];
+ result = label.get(i);
+ }
+ }
+ Preconditions.checkArgument(maxVal > Double.NEGATIVE_INFINITY);
+ return result;
+ }
+
+ /** Calculate probability of the input data. */
+ private static DenseVector calculateProb(NaiveBayesModelData modelData,
Vector data) {
+ int labelSize = modelData.labels.size();
+ DenseVector probs = new DenseVector(new double[labelSize]);
+ for (int i = 0; i < labelSize; i++) {
+ Map<Double, Double>[] labelData = modelData.theta[i];
+ for (int j = 0; j < data.size(); j++) {
+ probs.values[i] += labelData[j].get(data.get(j));
+ }
+ }
+ BLAS.axpy(1, modelData.piArray, probs);
+ return probs;
+ }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
new file mode 100644
index 0000000..8505cbc
--- /dev/null
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleSerializer;
+import org.apache.flink.api.common.typeutils.base.MapSerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * The model data of {@link NaiveBayesModel}.
+ *
+ * <p>This class also provides methods to convert model data between Table and
Datastream, and
+ * classes to save/load model data.
+ */
+public class NaiveBayesModelData {
+ /**
+ * Log of class conditional probabilities, whose dimension is C (number of
classes) by D (number
+ * of features).
+ */
+ public final Map<Double, Double>[][] theta;
+
+ /** Log of class priors, whose dimension is C (number of classes). */
+ public final DenseVector piArray;
+
+ /** Value of labels. */
+ public final DenseVector labels;
+
+ public NaiveBayesModelData(Map<Double, Double>[][] theta, double[]
piArray, double[] labels) {
+ this(theta, Vectors.dense(piArray), Vectors.dense(labels));
+ }
+
+ public NaiveBayesModelData(
+ Map<Double, Double>[][] theta, DenseVector piArray, DenseVector
labels) {
+ this.theta = theta;
+ this.piArray = piArray;
+ this.labels = labels;
+ }
+
+ /** Converts the provided modelData Datastream into corresponding Table. */
+ public static Table getModelDataTable(DataStream<NaiveBayesModelData>
stream) {
+ StreamTableEnvironment tEnv =
+
StreamTableEnvironment.create(stream.getExecutionEnvironment());
+ return tEnv.fromDataStream(stream);
+ }
+
+ /** Converts the provided modelData Table into corresponding DataStream. */
+ public static DataStream<NaiveBayesModelData> getModelDataStream(Table
table) {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
table).getTableEnvironment();
+ return tEnv.toDataStream(table)
+ .map(
+ (MapFunction<Row, NaiveBayesModelData>)
+ row -> (NaiveBayesModelData)
row.getField("f0"));
+ }
+
+ /** Encoder for the {@link NaiveBayesModelData}. */
+ public static class ModelDataEncoder implements
Encoder<NaiveBayesModelData> {
+ @Override
+ public void encode(NaiveBayesModelData modelData, OutputStream
outputStream)
+ throws IOException {
+ DataOutputViewStreamWrapper output = new
DataOutputViewStreamWrapper(outputStream);
+
+ DenseVectorSerializer denseVectorSerializer = new
DenseVectorSerializer();
+ MapSerializer<Double, Double> mapSerializer =
+ new MapSerializer<>(new DoubleSerializer(), new
DoubleSerializer());
+
+ denseVectorSerializer.serialize(modelData.labels, output);
+
+ denseVectorSerializer.serialize(modelData.piArray, output);
+
+ output.writeInt(modelData.theta.length);
+ output.writeInt(modelData.theta[0].length);
+ for (Map<Double, Double>[] maps : modelData.theta) {
+ for (Map<Double, Double> map : maps) {
+ mapSerializer.serialize(map, output);
+ }
+ }
+ }
+ }
+
+ /** Decoder for the {@link NaiveBayesModelData}. */
+ public static class ModelDataStreamFormat extends
SimpleStreamFormat<NaiveBayesModelData> {
+ @Override
+ public Reader<NaiveBayesModelData> createReader(
+ Configuration config, FSDataInputStream inputStream) {
+ return new Reader<NaiveBayesModelData>() {
+ private final DataInputViewStreamWrapper input =
+ new DataInputViewStreamWrapper(inputStream);
+
+ @Override
+ public NaiveBayesModelData read() throws IOException {
+ try {
+ DenseVectorSerializer denseVectorSerializer = new
DenseVectorSerializer();
+ MapSerializer<Double, Double> mapSerializer =
+ new MapSerializer<>(new DoubleSerializer(),
new DoubleSerializer());
+
+ DenseVector labels =
denseVectorSerializer.deserialize(input);
+
+ DenseVector piArray =
denseVectorSerializer.deserialize(input);
+
+ int featureSize = input.readInt();
+ int numLabels = input.readInt();
+ Map<Double, Double>[][] theta = new
HashMap[numLabels][featureSize];
+ for (int i = 0; i < featureSize; i++) {
+ for (int j = 0; j < numLabels; j++) {
+ theta[i][j] = mapSerializer.deserialize(input);
+ }
+ }
+ return new NaiveBayesModelData(theta, piArray, labels);
+ } catch (EOFException e) {
+ return null;
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ inputStream.close();
+ }
+ };
+ }
+
+ @Override
+ public TypeInformation<NaiveBayesModelData> getProducedType() {
+ return TypeInformation.of(NaiveBayesModelData.class);
+ }
+ }
+}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelParams.java
similarity index 57%
copy from
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
copy to
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelParams.java
index c4c5953..8001df0 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelParams.java
@@ -16,29 +16,32 @@
* limitations under the License.
*/
-package org.apache.flink.ml.common.param;
+package org.apache.flink.ml.classification.naivebayes;
-import org.apache.flink.ml.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.ParamValidators;
import org.apache.flink.ml.param.StringParam;
-import org.apache.flink.ml.param.WithParams;
-/** Interface for the shared distanceMeasure param. */
-public interface HasDistanceMeasure<T> extends WithParams<T> {
- Param<String> DISTANCE_MEASURE =
+/**
+ * Params of {@link NaiveBayesModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface NaiveBayesModelParams<T> extends HasFeaturesCol<T>,
HasPredictionCol<T> {
+ /** The model type. Defaults to "multinomial", currently the only value the validator accepts. */
+ Param<String> MODEL_TYPE =
new StringParam(
- "distanceMeasure",
- "The distance measure. Supported options: 'euclidean'.",
- EuclideanDistanceMeasure.NAME,
- ParamValidators.inArray(EuclideanDistanceMeasure.NAME));
+ "modelType",
+ "The model type.",
+ "multinomial",
+ ParamValidators.inArray("multinomial"));
- default String getDistanceMeasure() {
- return get(DISTANCE_MEASURE);
+ /** Returns the configured {@link #MODEL_TYPE}. */
+ default String getModelType() {
+ return get(MODEL_TYPE);
}
- default T setDistanceMeasure(String value) {
- set(DISTANCE_MEASURE, value);
- return (T) this;
+ /** Sets {@link #MODEL_TYPE}; returns T so calls can be chained. */
+ default T setModelType(String value) {
+ return set(MODEL_TYPE, value);
}
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesParams.java
similarity index 52%
copy from
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
copy to
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesParams.java
index c4c5953..8e7e089 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesParams.java
@@ -16,29 +16,28 @@
* limitations under the License.
*/
-package org.apache.flink.ml.common.param;
+package org.apache.flink.ml.classification.naivebayes;
-import org.apache.flink.ml.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.ParamValidators;
-import org.apache.flink.ml.param.StringParam;
-import org.apache.flink.ml.param.WithParams;
-/** Interface for the shared distanceMeasure param. */
-public interface HasDistanceMeasure<T> extends WithParams<T> {
- Param<String> DISTANCE_MEASURE =
- new StringParam(
- "distanceMeasure",
- "The distance measure. Supported options: 'euclidean'.",
- EuclideanDistanceMeasure.NAME,
- ParamValidators.inArray(EuclideanDistanceMeasure.NAME));
+/**
+ * Params of {@link NaiveBayes}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface NaiveBayesParams<T> extends NaiveBayesModelParams<T>,
HasLabelCol<T> {
+ /** The smoothing parameter. Defaults to 1.0; the validator requires a value >= 0.0. */
+ Param<Double> SMOOTHING =
+ new DoubleParam(
+ "smoothing", "The smoothing parameter.", 1.0,
ParamValidators.gtEq(0.0));
- default String getDistanceMeasure() {
- return get(DISTANCE_MEASURE);
+ /** Returns the configured {@link #SMOOTHING} value. */
+ default Double getSmoothing() {
+ return get(SMOOTHING);
}
- default T setDistanceMeasure(String value) {
- set(DISTANCE_MEASURE, value);
- return (T) this;
+ /** Sets {@link #SMOOTHING}; returns T so calls can be chained. */
+ default T setSmoothing(Double value) {
+ return set(SMOOTHING, value);
}
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
index 1b08b66..3966626 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
@@ -60,6 +60,7 @@ import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
import org.apache.commons.collections.IteratorUtils;
@@ -85,6 +86,8 @@ public class KMeans implements Estimator<KMeans,
KMeansModel>, KMeansParams<KMea
@Override
public KMeansModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
DataStream<DenseVector> points =
@@ -110,7 +113,7 @@ public class KMeans implements Estimator<KMeans,
KMeansModel>, KMeansParams<KMea
body)
.get(0);
- Table finalCentroidsTable = tEnv.fromDataStream(finalCentroids,
KMeansModelData.SCHEMA);
+ Table finalCentroidsTable =
KMeansModelData.getModelDataTable(finalCentroids);
KMeansModel model = new
KMeansModel().setModelData(finalCentroidsTable);
ReadWriteUtils.updateExistingParams(model, paramMap);
return model;
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
index 0d2351c..d5d94e1 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
@@ -18,17 +18,12 @@
package org.apache.flink.ml.clustering.kmeans;
-import org.apache.flink.api.common.eventtime.WatermarkStrategy;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
-import org.apache.flink.api.connector.source.Source;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.connector.file.sink.FileSink;
-import org.apache.flink.connector.file.src.FileSource;
-import org.apache.flink.core.fs.Path;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.common.datastream.TableUtils;
import org.apache.flink.ml.distance.DistanceMeasure;
@@ -40,8 +35,6 @@ import org.apache.flink.ml.util.ReadWriteUtils;
import org.apache.flink.runtime.state.StateInitializationContext;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
-import
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -49,6 +42,7 @@ import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.lang3.ArrayUtils;
@@ -69,6 +63,7 @@ public class KMeansModel implements Model<KMeansModel>,
KMeansModelParams<KMeans
@Override
public KMeansModel setModelData(Table... inputs) {
+ // Exactly one model data table, holding the centroids, is expected.
+ Preconditions.checkArgument(inputs.length == 1);
centroidsTable = inputs[0];
return this;
}
@@ -80,10 +75,11 @@ public class KMeansModel implements Model<KMeansModel>,
KMeansModelParams<KMeans
@Override
public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
- DataStream<DenseVector[]> centroids =
- tEnv.toDataStream(centroidsTable).map(row -> (DenseVector[])
row.getField("f0"));
+ DataStream<DenseVector[]> centroids =
KMeansModelData.getModelDataStream(centroidsTable);
String featureCol = getFeaturesCol();
String predictionCol = getPredictionCol();
@@ -182,33 +178,19 @@ public class KMeansModel implements Model<KMeansModel>,
KMeansModelParams<KMeans
@Override
public void save(String path) throws IOException {
- StreamTableEnvironment tEnv =
- (StreamTableEnvironment) ((TableImpl)
centroidsTable).getTableEnvironment();
-
- String dataPath = ReadWriteUtils.getDataPath(path);
- FileSink<DenseVector[]> sink =
- FileSink.forRowFormat(new Path(dataPath), new
KMeansModelData.ModelDataEncoder())
- .withRollingPolicy(OnCheckpointRollingPolicy.build())
- .withBucketAssigner(new BasePathBucketAssigner<>())
- .build();
- tEnv.toDataStream(centroidsTable)
- .map(row -> (DenseVector[]) row.getField("f0"))
- .sinkTo(sink);
-
+ // Writes the stage metadata, then the centroids (encoded by ModelDataEncoder), both
+ // under the same path.
ReadWriteUtils.saveMetadata(this, path);
+ ReadWriteUtils.saveModelData(
+ KMeansModelData.getModelDataStream(centroidsTable),
+ path,
+ new KMeansModelData.ModelDataEncoder());
}
// TODO: Add INFO level logging.
+ /**
+ * Loads a KMeansModel from path: the stage itself via loadStageParam, then the centroids via
+ * loadModelData decoded with ModelDataStreamFormat.
+ */
public static KMeansModel load(StreamExecutionEnvironment env, String
path) throws IOException {
- StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
- Source<DenseVector[], ?, ?> source =
- FileSource.forRecordStreamFormat(
- new KMeansModelData.ModelDataStreamFormat(),
- ReadWriteUtils.getDataPaths(path))
- .build();
KMeansModel model = ReadWriteUtils.loadStageParam(path);
- DataStream<DenseVector[]> modelData =
- env.fromSource(source, WatermarkStrategy.noWatermarks(),
"modelData");
- return model.setModelData(tEnv.fromDataStream(modelData,
KMeansModelData.SCHEMA));
+ DataStream<DenseVector[]> centroids =
+ ReadWriteUtils.loadModelData(
+ env, path, new
KMeansModelData.ModelDataStreamFormat());
+ return
model.setModelData(KMeansModelData.getModelDataTable(centroids));
}
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
index 779873b..7101a18 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
@@ -20,71 +20,83 @@ package org.apache.flink.ml.clustering.kmeans;
import org.apache.flink.api.common.serialization.Encoder;
import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.table.api.DataTypes;
-import org.apache.flink.table.api.Schema;
-
-import com.esotericsoftware.kryo.Kryo;
-import com.esotericsoftware.kryo.io.Input;
-import com.esotericsoftware.kryo.io.Output;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import java.io.EOFException;
import java.io.IOException;
import java.io.OutputStream;
-import java.util.ArrayList;
-import java.util.List;
/** Provides classes to save/load model data. */
public class KMeansModelData {
+ /**
+ * Converts the provided modelData DataStream into the corresponding Table.
+ *
+ * <p>A new StreamTableEnvironment is created from the stream's execution environment on
+ * every call, so the returned Table is not bound to any caller-owned table environment.
+ * NOTE(review): the single column is presumably named "f0" (which getModelDataStream
+ * reads) -- confirm.
+ */
+ public static Table getModelDataTable(DataStream<DenseVector[]> modelData)
{
+ StreamTableEnvironment tEnv =
+
StreamTableEnvironment.create(modelData.getExecutionEnvironment());
+ return tEnv.fromDataStream(modelData);
+ }
- public static final Schema SCHEMA =
- Schema.newBuilder()
- .column("f0",
DataTypes.ARRAY(DataTypes.of(DenseVector.class)))
- .build();
+ /**
+ * Converts the provided modelData Table into the corresponding DataStream, using the table's
+ * own environment and reading the centroids from column "f0" of each row.
+ */
+ public static DataStream<DenseVector[]> getModelDataStream(Table table) {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl)
table).getTableEnvironment();
+ return tEnv.toDataStream(table).map(row -> (DenseVector[])
row.getField("f0"));
+ }
/** Encoder for the KMeans model data. */
public static class ModelDataEncoder implements Encoder<DenseVector[]> {
@Override
- public void encode(DenseVector[] modelData, OutputStream outputStream)
{
- Kryo kryo = new Kryo();
- Output output = new Output(outputStream);
- List<double[]> convertedData = new ArrayList<>();
- for (int i = 0; i < modelData.length; i++) {
- convertedData.add(modelData[i].values);
+ public void encode(DenseVector[] modelData, OutputStream outputStream) throws IOException {
+ // Layout: the number of centroids as an int, then each centroid. This must stay in sync
+ // with ModelDataStreamFormat, which reads an int count followed by that many vectors.
+ IntSerializer intSerializer = new IntSerializer();
+ DenseVectorSerializer denseVectorSerializer = new DenseVectorSerializer();
+ DataOutputViewStreamWrapper outputViewStreamWrapper =
+ new DataOutputViewStreamWrapper(outputStream);
+ intSerializer.serialize(modelData.length, outputViewStreamWrapper);
+ for (DenseVector denseVector : modelData) {
+ // Reuse the single wrapper rather than allocating a new one per vector.
+ denseVectorSerializer.serialize(denseVector, outputViewStreamWrapper);
}
- kryo.writeObject(output, convertedData);
- output.flush();
}
}
/** Decoder for the KMeans model data. */
public static class ModelDataStreamFormat extends
SimpleStreamFormat<DenseVector[]> {
@Override
- public Reader<DenseVector[]> createReader(Configuration config,
FSDataInputStream stream) {
+ public Reader<DenseVector[]> createReader(
+ Configuration config, FSDataInputStream inputStream) {
return new Reader<DenseVector[]>() {
- private final Kryo kryo = new Kryo();
- private final Input input = new Input(stream);
-
@Override
- public DenseVector[] read() {
- if (input.eof()) {
+ public DenseVector[] read() throws IOException {
+ try {
+ IntSerializer intSerializer = new IntSerializer();
+ DenseVectorSerializer denseVectorSerializer = new
DenseVectorSerializer();
+ DataInputViewStreamWrapper inputViewStreamWrapper =
+ new DataInputViewStreamWrapper(inputStream);
+ int numDenseVectors =
intSerializer.deserialize(inputViewStreamWrapper);
+ DenseVector[] result = new
DenseVector[numDenseVectors];
+ for (int i = 0; i < numDenseVectors; i++) {
+ result[i] =
denseVectorSerializer.deserialize(inputViewStreamWrapper);
+ }
+ return result;
+ } catch (EOFException e) {
return null;
}
- ArrayList<double[]> row = kryo.readObject(input,
ArrayList.class);
-
- DenseVector[] result = new DenseVector[row.size()];
- for (int i = 0; i < result.length; i++) {
- result[i] = new DenseVector(row.get(i));
- }
- return result;
}
@Override
public void close() throws IOException {
- stream.close();
+ inputStream.close();
}
};
}
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
index db38451..51aa18c 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
@@ -22,7 +22,6 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.table.catalog.Column;
import org.apache.flink.table.catalog.ResolvedSchema;
-import org.apache.flink.table.runtime.typeutils.ExternalTypeInfo;
/** Utility class for table-related operations. */
public class TableUtils {
@@ -33,7 +32,7 @@ public class TableUtils {
for (int i = 0; i < schema.getColumnCount(); i++) {
Column column = schema.getColumn(i).get();
- types[i] = ExternalTypeInfo.of(column.getDataType());
+ types[i] =
TypeInformation.of(column.getDataType().getConversionClass());
names[i] = column.getName();
}
return new RowTypeInfo(types, names);
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
index c4c5953..f58d08a 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
@@ -29,7 +29,7 @@ public interface HasDistanceMeasure<T> extends WithParams<T> {
Param<String> DISTANCE_MEASURE =
new StringParam(
"distanceMeasure",
- "The distance measure. Supported options: 'euclidean'.",
+ "The distance measure.",
EuclideanDistanceMeasure.NAME,
ParamValidators.inArray(EuclideanDistanceMeasure.NAME));
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLabelCol.java
similarity index 60%
copy from
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
copy to
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLabelCol.java
index c4c5953..badd3b3 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLabelCol.java
@@ -18,27 +18,21 @@
package org.apache.flink.ml.common.param;
-import org.apache.flink.ml.distance.EuclideanDistanceMeasure;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.ParamValidators;
import org.apache.flink.ml.param.StringParam;
import org.apache.flink.ml.param.WithParams;
-/** Interface for the shared distanceMeasure param. */
-public interface HasDistanceMeasure<T> extends WithParams<T> {
- Param<String> DISTANCE_MEASURE =
- new StringParam(
- "distanceMeasure",
- "The distance measure. Supported options: 'euclidean'.",
- EuclideanDistanceMeasure.NAME,
- ParamValidators.inArray(EuclideanDistanceMeasure.NAME));
+/**
+ * Interface for the shared labelCol param.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface HasLabelCol<T> extends WithParams<T> {
+ /** Name of the label column. Defaults to "label"; the validator rejects null. */
+ Param<String> LABEL_COL =
+ new StringParam("labelCol", "Label column name.", "label",
ParamValidators.notNull());
- default String getDistanceMeasure() {
- return get(DISTANCE_MEASURE);
+ /** Returns the configured {@link #LABEL_COL}. */
+ default String getLabelCol() {
+ return get(LABEL_COL);
}
- default T setDistanceMeasure(String value) {
- set(DISTANCE_MEASURE, value);
- return (T) this;
+ /** Sets {@link #LABEL_COL}; returns T so calls can be chained. */
+ default T setLabelCol(String colName) {
+ return set(LABEL_COL, colName);
}
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
new file mode 100644
index 0000000..581242f
--- /dev/null
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.naivebayes.NaiveBayes;
+import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel;
+import org.apache.flink.ml.classification.naivebayes.NaiveBayesModelData;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link NaiveBayes} and {@link NaiveBayesModel}. */
+public class NaiveBayesTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainTable;
+    private Table predictTable;
+    private Map<Vector, Integer> expectedOutput;
+    private NaiveBayes estimator;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        List<Row> trainData =
+                Arrays.asList(
+                        Row.of(Vectors.dense(0, 0.), 11),
+                        Row.of(Vectors.dense(1, 0), 10),
+                        Row.of(Vectors.dense(1, 1.), 10));
+
+        trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features", "label");
+
+        List<Row> predictData =
+                Arrays.asList(
+                        Row.of(Vectors.dense(0, 1.)),
+                        Row.of(Vectors.dense(0, 0.)),
+                        Row.of(Vectors.dense(1, 0)),
+                        Row.of(Vectors.dense(1, 1.)));
+
+        predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("features");
+
+        expectedOutput = new HashMap<>();
+        expectedOutput.put(Vectors.dense(0, 1.), 11);
+        expectedOutput.put(Vectors.dense(0, 0.), 11);
+        expectedOutput.put(Vectors.dense(1, 0.), 10);
+        expectedOutput.put(Vectors.dense(1, 1.), 10);
+
+        estimator =
+                new NaiveBayes()
+                        .setSmoothing(1.0)
+                        .setFeaturesCol("features")
+                        .setLabelCol("label")
+                        .setPredictionCol("prediction")
+                        .setModelType("multinomial");
+    }
+
+    /**
+     * Executes a given table and collects its results. Results are returned as a map whose key is
+     * the feature and whose value is the prediction result.
+     *
+     * @param table A table to be executed and to have its result collected
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     * @return A map containing the collected results
+     * @throws Exception if the job fails or the result iterator cannot be closed
+     */
+    private static Map<Vector, Integer> executeAndCollect(
+            Table table, String featuresCol, String predictionCol) throws Exception {
+        Map<Vector, Integer> map = new HashMap<>();
+        // Close the collect iterator even if iteration fails.
+        try (CloseableIterator<Row> it = table.execute().collect()) {
+            while (it.hasNext()) {
+                Row row = it.next();
+                map.put(
+                        (Vector) row.getField(featuresCol),
+                        ((Number) row.getField(predictionCol)).intValue());
+            }
+        }
+        return map;
+    }
+
+    /**
+     * Executes the given table expecting the job to fail, asserts that the innermost cause of the
+     * failure is of the expected type, and returns that cause.
+     *
+     * @param table A table whose execution is expected to fail
+     * @param expectedType The expected type of the innermost cause
+     * @return The innermost cause of the failure
+     */
+    private static Throwable executeAndGetRootCause(
+            Table table, Class<? extends Throwable> expectedType) {
+        try (CloseableIterator<Row> it = table.execute().collect()) {
+            it.next();
+        } catch (Exception e) {
+            Throwable cause = e;
+            while (cause.getCause() != null) {
+                cause = cause.getCause();
+            }
+            assertEquals(expectedType, cause.getClass());
+            return cause;
+        }
+        Assert.fail("Expected " + expectedType.getSimpleName());
+        return null; // Unreachable: Assert.fail always throws.
+    }
+
+    @Test
+    public void testParam() {
+        NaiveBayes estimator = new NaiveBayes();
+
+        assertEquals("features", estimator.getFeaturesCol());
+        assertEquals("label", estimator.getLabelCol());
+        assertEquals("multinomial", estimator.getModelType());
+        assertEquals("prediction", estimator.getPredictionCol());
+        assertEquals(1.0, estimator.getSmoothing(), 1e-5);
+
+        estimator
+                .setFeaturesCol("test_feature")
+                .setLabelCol("test_label")
+                .setPredictionCol("test_prediction")
+                .setSmoothing(2.0);
+
+        assertEquals("test_feature", estimator.getFeaturesCol());
+        assertEquals("test_label", estimator.getLabelCol());
+        assertEquals("test_prediction", estimator.getPredictionCol());
+        assertEquals(2.0, estimator.getSmoothing(), 1e-5);
+
+        NaiveBayesModel model = new NaiveBayesModel();
+
+        assertEquals("features", model.getFeaturesCol());
+        assertEquals("multinomial", model.getModelType());
+        assertEquals("prediction", model.getPredictionCol());
+
+        model.setFeaturesCol("test_feature").setPredictionCol("test_prediction");
+
+        assertEquals("test_feature", model.getFeaturesCol());
+        assertEquals("test_prediction", model.getPredictionCol());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        NaiveBayesModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(predictTable)[0];
+        Map<Vector, Integer> actualOutput =
+                executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol());
+        assertEquals(expectedOutput, actualOutput);
+    }
+
+    @Test
+    public void testFeaturePredictionParam() throws Exception {
+        trainTable = trainTable.as("test_features", "test_label");
+        predictTable = predictTable.as("test_features");
+
+        estimator
+                .setFeaturesCol("test_features")
+                .setLabelCol("test_label")
+                .setPredictionCol("test_prediction");
+
+        NaiveBayesModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(predictTable)[0];
+        Map<Vector, Integer> actualOutput =
+                executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol());
+        assertEquals(expectedOutput, actualOutput);
+    }
+
+    @Test
+    public void testPredictUnseenFeature() {
+        predictTable =
+                tEnv.fromDataStream(env.fromElements(Row.of(Vectors.dense(2, 1.))))
+                        .as("features");
+
+        NaiveBayesModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(predictTable)[0];
+
+        // A feature value absent from the training data currently surfaces as an NPE thrown
+        // from NaiveBayesModel#calculateProb.
+        Throwable rootCause = executeAndGetRootCause(outputTable, NullPointerException.class);
+        StackTraceElement topFrame = rootCause.getStackTrace()[0];
+        assertEquals(NaiveBayesModel.class.getName(), topFrame.getClassName());
+        assertEquals("calculateProb", topFrame.getMethodName());
+    }
+
+    @Test
+    public void testVectorWithDiffLen() {
+        List<Row> trainData =
+                Arrays.asList(
+                        Row.of(Vectors.dense(0, 0.), 11.0),
+                        Row.of(Vectors.dense(1, 0), 10.0),
+                        Row.of(Vectors.dense(1), 10.0));
+
+        trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features", "label");
+
+        NaiveBayesModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(trainTable)[0];
+
+        Throwable rootCause =
+                executeAndGetRootCause(outputTable, IllegalArgumentException.class);
+        assertEquals("Feature vectors should be of equal length.", rootCause.getMessage());
+    }
+
+    @Test
+    public void testVectorWithDiffLen2() {
+        List<Row> trainData =
+                Arrays.asList(Row.of(Vectors.dense(0, 0.), 11.0), Row.of(Vectors.dense(1), 10.0));
+
+        trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features", "label");
+
+        NaiveBayesModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(trainTable)[0];
+
+        Throwable rootCause =
+                executeAndGetRootCause(outputTable, IllegalArgumentException.class);
+        assertEquals("Feature vectors should be of equal length.", rootCause.getMessage());
+    }
+
+    @Test
+    public void testSaveLoad() throws Exception {
+        estimator =
+                StageTestUtils.saveAndReload(
+                        env, estimator, tempFolder.newFolder().getAbsolutePath());
+
+        NaiveBayesModel model = estimator.fit(trainTable);
+
+        model = StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
+
+        Table outputTable = model.transform(predictTable)[0];
+
+        Map<Vector, Integer> actualOutput =
+                executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol());
+        assertEquals(expectedOutput, actualOutput);
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        List<Row> trainData =
+                Arrays.asList(
+                        Row.of(Vectors.dense(1, 1.), 11.0), Row.of(Vectors.dense(2, 1.), 11.0));
+
+        trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features", "label");
+
+        NaiveBayesModel model = estimator.fit(trainTable);
+
+        NaiveBayesModelData actual;
+        try (CloseableIterator<NaiveBayesModelData> it =
+                NaiveBayesModelData.getModelDataStream(model.getModelData()[0])
+                        .executeAndCollect()) {
+            actual = it.next();
+        }
+
+        assertArrayEquals(new double[] {11.}, actual.labels.toArray(), 1e-5);
+        assertArrayEquals(new double[] {0.0}, actual.piArray.toArray(), 1e-5);
+        assertEquals(-0.6931471805599453, actual.theta[0][0].get(1.0), 1e-5);
+        assertEquals(-0.6931471805599453, actual.theta[0][0].get(2.0), 1e-5);
+        assertEquals(0.0, actual.theta[0][1].get(1.0), 1e-5);
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        NaiveBayesModel modelA = estimator.fit(trainTable);
+
+        Table modelData = modelA.getModelData()[0];
+        NaiveBayesModel modelB = new NaiveBayesModel().setModelData(modelData);
+        ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+
+        Table outputTable = modelB.transform(predictTable)[0];
+
+        Map<Vector, Integer> actualOutput =
+                executeAndCollect(outputTable, modelB.getFeaturesCol(), modelB.getPredictionCol());
+        assertEquals(expectedOutput, actualOutput);
+    }
+}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
index 27a839c..c2331a3 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
@@ -18,9 +18,7 @@
package org.apache.flink.ml.clustering;
-import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
-import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.clustering.kmeans.KMeans;
import org.apache.flink.ml.clustering.kmeans.KMeansModel;
@@ -28,6 +26,7 @@ import org.apache.flink.ml.distance.EuclideanDistanceMeasure;
import org.apache.flink.ml.linalg.DenseVector;
import org.apache.flink.ml.linalg.Vectors;
import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -35,15 +34,18 @@ import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-import org.apache.flink.table.api.internal.TableImpl;
import org.apache.flink.test.util.AbstractTestBase;
import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.IteratorUtils;
import org.junit.Before;
+import org.junit.Rule;
import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
-import java.nio.file.Files;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
@@ -51,12 +53,16 @@ import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
+import java.util.Set;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
/** Tests KMeans and KMeansModel. */
public class KMeansTest extends AbstractTestBase {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
private static final List<DenseVector> DATA =
Arrays.asList(
Vectors.dense(0.0, 0.0),
@@ -67,6 +73,18 @@ public class KMeansTest extends AbstractTestBase {
Vectors.dense(9.6, 0.0));
private StreamExecutionEnvironment env;
private StreamTableEnvironment tEnv;
+ private static final List<Set<DenseVector>> expectedGroups =
+ Arrays.asList(
+ new HashSet<>(
+ Arrays.asList(
+ Vectors.dense(0.0, 0.0),
+ Vectors.dense(0.0, 0.3),
+ Vectors.dense(0.3, 0.0))),
+ new HashSet<>(
+ Arrays.asList(
+ Vectors.dense(9.0, 0.0),
+ Vectors.dense(9.0, 0.6),
+ Vectors.dense(9.6, 0.0))));
private Table dataTable;
@Before
@@ -83,43 +101,26 @@ public class KMeansTest extends AbstractTestBase {
dataTable = tEnv.fromDataStream(env.fromCollection(DATA),
schema).as("features");
}
- // Executes the graph and returns a map which maps points to clusterId.
- private static Map<DenseVector, Integer> executeAndCollect(
- Table output, String featureCol, String predictionCol) throws
Exception {
- StreamTableEnvironment tEnv =
- (StreamTableEnvironment) ((TableImpl)
output).getTableEnvironment();
-
- DataStream<Tuple2<DenseVector, Integer>> stream =
- tEnv.toDataStream(output)
- .map(
- new MapFunction<Row, Tuple2<DenseVector,
Integer>>() {
- @Override
- public Tuple2<DenseVector, Integer>
map(Row row) {
- return Tuple2.of(
- (DenseVector)
row.getField(featureCol),
- (Integer)
row.getField(predictionCol));
- }
- });
-
- List<Tuple2<DenseVector, Integer>> pointsWithClusterId =
- IteratorUtils.toList(stream.executeAndCollect());
-
- Map<DenseVector, Integer> clusterIdByPoints = new HashMap<>();
- for (Tuple2<DenseVector, Integer> entry : pointsWithClusterId) {
- clusterIdByPoints.put(entry.f0, entry.f1);
- }
- return clusterIdByPoints;
- }
-
- private static void verifyClusteringResult(
- Map<DenseVector, Integer> clusterIdByPoints, List<List<Integer>>
groups) {
- for (List<Integer> group : groups) {
- for (int i = 1; i < group.size(); i++) {
- assertEquals(
- clusterIdByPoints.get(DATA.get(group.get(0))),
- clusterIdByPoints.get(DATA.get(group.get(i))));
- }
+    /**
+     * Executes a table, collects all of its rows and closes the result iterator. Features are
+     * grouped into sets such that features in the same set have the same prediction result.
+     *
+     * @param table A table to be executed and to have its result collected
+     * @param featureCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     * @return A list of feature sets, one for each distinct prediction result
+     */
+    private static List<Set<DenseVector>> executeAndCollect(
+            Table table, String featureCol, String predictionCol) throws Exception {
+        Map<Integer, Set<DenseVector>> map = new HashMap<>();
+        try (CloseableIterator<Row> it = table.execute().collect()) {
+            while (it.hasNext()) {
+                Row row = it.next();
+                map.computeIfAbsent((int) row.getField(predictionCol), k -> new HashSet<>())
+                        .add((DenseVector) row.getField(featureCol));
+            }
         }
+        return new ArrayList<>(map.values());
     }
@Test
@@ -158,10 +159,9 @@ public class KMeansTest extends AbstractTestBase {
assertEquals(
Arrays.asList("test_feature", "test_prediction"),
output.getResolvedSchema().getColumnNames());
- Map<DenseVector, Integer> clusterIdByPoints =
+ List<Set<DenseVector>> actualGroups =
executeAndCollect(output, kmeans.getFeaturesCol(),
kmeans.getPredictionCol());
- verifyClusteringResult(
- clusterIdByPoints, Arrays.asList(Arrays.asList(0, 1, 2),
Arrays.asList(3, 4, 5)));
+ assertTrue(CollectionUtils.isEqualCollection(expectedGroups,
actualGroups));
}
@Test
@@ -176,10 +176,11 @@ public class KMeansTest extends AbstractTestBase {
KMeans kmeans = new KMeans().setK(2);
KMeansModel model = kmeans.fit(input);
Table output = model.transform(input)[0];
-
- Map<DenseVector, Integer> clusterIdByPoints =
+ List<Set<DenseVector>> expectedGroups =
+
Collections.singletonList(Collections.singleton(Vectors.dense(0.0, 0.1)));
+ List<Set<DenseVector>> actualGroups =
executeAndCollect(output, kmeans.getFeaturesCol(),
kmeans.getPredictionCol());
- assertEquals(Collections.singleton(0), new
HashSet<>(clusterIdByPoints.values()));
+ assertTrue(CollectionUtils.isEqualCollection(expectedGroups,
actualGroups));
}
@Test
@@ -191,23 +192,22 @@ public class KMeansTest extends AbstractTestBase {
assertEquals(
Arrays.asList("features", "prediction"),
output.getResolvedSchema().getColumnNames());
- Map<DenseVector, Integer> clusterIdByPoints =
+ List<Set<DenseVector>> actualGroups =
executeAndCollect(output, kmeans.getFeaturesCol(),
kmeans.getPredictionCol());
- verifyClusteringResult(
- clusterIdByPoints, Arrays.asList(Arrays.asList(0, 1, 2),
Arrays.asList(3, 4, 5)));
+ assertTrue(CollectionUtils.isEqualCollection(expectedGroups,
actualGroups));
}
@Test
public void testSaveLoadAndPredict() throws Exception {
- String path = Files.createTempDirectory("").toString();
-
KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
- KMeansModel model = kmeans.fit(dataTable);
- model.save(path);
- env.execute();
+ KMeans loadedKmeans =
+ StageTestUtils.saveAndReload(env, kmeans,
tempFolder.newFolder().getAbsolutePath());
+
+ KMeansModel model = loadedKmeans.fit(dataTable);
- KMeansModel loadedModel = KMeansModel.load(env, path);
+ KMeansModel loadedModel =
+ StageTestUtils.saveAndReload(env, model,
tempFolder.newFolder().getAbsolutePath());
Table output = loadedModel.transform(dataTable)[0];
assertEquals(
@@ -216,10 +216,10 @@ public class KMeansTest extends AbstractTestBase {
assertEquals(
Arrays.asList("features", "prediction"),
output.getResolvedSchema().getColumnNames());
- Map<DenseVector, Integer> clusterIdByPoints =
+
+ List<Set<DenseVector>> actualGroups =
executeAndCollect(output, kmeans.getFeaturesCol(),
kmeans.getPredictionCol());
- verifyClusteringResult(
- clusterIdByPoints, Arrays.asList(Arrays.asList(0, 1, 2),
Arrays.asList(3, 4, 5)));
+ assertTrue(CollectionUtils.isEqualCollection(expectedGroups,
actualGroups));
}
@Test
@@ -250,10 +250,8 @@ public class KMeansTest extends AbstractTestBase {
ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
Table output = modelB.transform(dataTable)[0];
- Map<DenseVector, Integer> clusterIdByPoints =
+ List<Set<DenseVector>> actualGroups =
executeAndCollect(output, kmeans.getFeaturesCol(),
kmeans.getPredictionCol());
-
- verifyClusteringResult(
- clusterIdByPoints, Arrays.asList(Arrays.asList(0, 1, 2),
Arrays.asList(3, 4, 5)));
+ assertTrue(CollectionUtils.isEqualCollection(expectedGroups,
actualGroups));
}
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/util/StageTestUtils.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/util/StageTestUtils.java
new file mode 100644
index 0000000..27283e8
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/util/StageTestUtils.java
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+
+import java.lang.reflect.Method;
+
+/** Utility methods for testing stages. */
+public class StageTestUtils {
+    /**
+     * Saves a stage to filesystem and reloads it by invoking the static load() method of the given
+     * stage.
+     *
+     * <p>If saving the stage adds no operator to {@code env}, {@code env.execute()} fails with the
+     * message "No operators defined in streaming topology. Cannot execute."; that failure is
+     * ignored. Any other {@link RuntimeException} thrown by {@code env.execute()} is rethrown.
+     *
+     * @param env The execution environment used to save and reload the stage
+     * @param stage The stage to be saved
+     * @param path The path to save the stage to and to load it back from
+     * @return The stage returned by the static {@code load(StreamExecutionEnvironment, String)}
+     *     method of the stage's class
+     * @throws Exception If saving, executing or reflectively invoking load() fails
+     */
+    public static <T extends Stage<T>> T saveAndReload(
+            StreamExecutionEnvironment env, T stage, String path) throws Exception {
+        stage.save(path);
+        try {
+            env.execute();
+        } catch (RuntimeException e) {
+            // Compare from the constant side: an exception without a message must be rethrown
+            // rather than masked by a NullPointerException.
+            if (!"No operators defined in streaming topology. Cannot execute."
+                    .equals(e.getMessage())) {
+                throw e;
+            }
+        }
+
+        Method method =
+                stage.getClass().getMethod("load", StreamExecutionEnvironment.class, String.class);
+        @SuppressWarnings("unchecked") // load() is assumed to return the stage's own type.
+        T reloaded = (T) method.invoke(null, env, path);
+        return reloaded;
+    }
+}