This is an automated email from the ASF dual-hosted git repository.

gaoyunhaii pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new e4f7eed  [FLINK-24817] Add Estimator and Transformer for Naive Bayes
e4f7eed is described below

commit e4f7eedb958e29817525f54ee6944d819d0de47c
Author: Yunfeng Zhou <[email protected]>
AuthorDate: Wed Nov 17 10:59:03 2021 +0800

    [FLINK-24817] Add Estimator and Transformer for Naive Bayes
    
    This closes #32.
---
 flink-ml-core/pom.xml                              |  11 +
 .../main/java/org/apache/flink/ml/linalg/BLAS.java |  34 ++
 .../org/apache/flink/ml/util/ReadWriteUtils.java   |  46 +++
 .../ml/classification/naivebayes/NaiveBayes.java   | 360 +++++++++++++++++++++
 .../classification/naivebayes/NaiveBayesModel.java | 187 +++++++++++
 .../naivebayes/NaiveBayesModelData.java            | 164 ++++++++++
 .../naivebayes/NaiveBayesModelParams.java}         |  33 +-
 .../naivebayes/NaiveBayesParams.java}              |  33 +-
 .../apache/flink/ml/clustering/kmeans/KMeans.java  |   5 +-
 .../flink/ml/clustering/kmeans/KMeansModel.java    |  44 +--
 .../ml/clustering/kmeans/KMeansModelData.java      |  80 +++--
 .../flink/ml/common/datastream/TableUtils.java     |   3 +-
 .../flink/ml/common/param/HasDistanceMeasure.java  |   2 +-
 .../{HasDistanceMeasure.java => HasLabelCol.java}  |  22 +-
 .../flink/ml/classification/NaiveBayesTest.java    | 314 ++++++++++++++++++
 .../org/apache/flink/ml/clustering/KMeansTest.java | 122 ++++---
 .../org/apache/flink/ml/util/StageTestUtils.java   |  48 +++
 17 files changed, 1331 insertions(+), 177 deletions(-)

diff --git a/flink-ml-core/pom.xml b/flink-ml-core/pom.xml
index 6ed25b7..79a9654 100644
--- a/flink-ml-core/pom.xml
+++ b/flink-ml-core/pom.xml
@@ -65,5 +65,16 @@ under the License.
       <artifactId>flink-shaded-jackson</artifactId>
       <scope>provided</scope>
     </dependency>
+
+    <!-- netlib BLAS; its pure-Java implementation backs org.apache.flink.ml.linalg.BLAS. -->
+    <dependency>
+      <groupId>dev.ludovic.netlib</groupId>
+      <artifactId>blas</artifactId>
+      <version>2.2.0</version>
+    </dependency>
+
+    <!-- Provides StreamTableEnvironment, used by ReadWriteUtils. -->
+    <!-- NOTE(review): the Flink version and Scala suffix are hard-coded here; consider the
+         project's Flink/Scala version properties if the parent POM defines them. -->
+    <dependency>
+      <groupId>org.apache.flink</groupId>
+      <artifactId>flink-table-api-java-bridge_2.11</artifactId>
+      <version>1.14.0</version>
+    </dependency>
   </dependencies>
 </project>
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java 
b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
new file mode 100644
index 0000000..3b91cdb
--- /dev/null
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/BLAS.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.linalg;
+
+import org.apache.flink.util.Preconditions;
+
+/** A utility class that provides BLAS routines over matrices and vectors. */
+public class BLAS {
+    /**
+     * Pure-Java netlib BLAS implementation, used for level-1 routines such as {@code daxpy} for
+     * better performance.
+     */
+    private static final dev.ludovic.netlib.BLAS JAVA_BLAS =
+            dev.ludovic.netlib.JavaBLAS.getInstance();
+
+    /**
+     * Computes {@code y += a * x} in place.
+     *
+     * @param a The scalar multiplier applied to {@code x}.
+     * @param x The input vector; it is only read.
+     * @param y The vector accumulated into; it is updated in place.
+     * @throws IllegalArgumentException If {@code x} and {@code y} have different sizes.
+     */
+    public static void axpy(double a, DenseVector x, DenseVector y) {
+        Preconditions.checkArgument(x.size() == y.size(), "Vector size mismatched.");
+        JAVA_BLAS.daxpy(x.size(), a, x.values, 1, y.values, 1);
+    }
+}
diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java 
b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
index 457cac3..800a33d 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/util/ReadWriteUtils.java
@@ -18,9 +18,19 @@
 
 package org.apache.flink.ml.util;
 
+import org.apache.flink.api.common.eventtime.WatermarkStrategy;
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.connector.source.Source;
+import org.apache.flink.connector.file.sink.FileSink;
+import org.apache.flink.connector.file.src.FileSource;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.ml.api.Stage;
 import org.apache.flink.ml.param.Param;
+import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
+import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.util.InstantiationUtil;
 
 import 
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper;
@@ -312,4 +322,40 @@ public class ReadWriteUtils {
             throw new RuntimeException("Failed to load stage.", e);
         }
     }
+
+    /**
+     * Writes the model data stream as row-format files into the data directory of {@code path}.
+     *
+     * <p>A file sink is attached to {@code model}: part files are rolled on checkpoints and every
+     * record is written directly under the base path (no bucket sub-directories).
+     *
+     * @param model The model data stream.
+     * @param path The parent directory of the model data file.
+     * @param modelEncoder The encoder to encode the model data.
+     * @param <T> The class type of the model data.
+     */
+    public static <T> void saveModelData(
+            DataStream<T> model, String path, Encoder<T> modelEncoder) {
+        org.apache.flink.core.fs.Path dataPath =
+                new org.apache.flink.core.fs.Path(getDataPath(path));
+        model.sinkTo(
+                FileSink.forRowFormat(dataPath, modelEncoder)
+                        .withRollingPolicy(OnCheckpointRollingPolicy.build())
+                        .withBucketAssigner(new BasePathBucketAssigner<>())
+                        .build());
+    }
+
+    /**
+     * Reads model data back from the data files under {@code path}, decoding records with the
+     * given stream format. The returned stream carries no watermarks.
+     *
+     * @param env A StreamExecutionEnvironment instance.
+     * @param path The parent directory of the model data file.
+     * @param modelDecoder The decoder used to decode the model data.
+     * @param <T> The class type of the model data.
+     * @return The loaded model data.
+     */
+    public static <T> DataStream<T> loadModelData(
+            StreamExecutionEnvironment env, String path, SimpleStreamFormat<T> modelDecoder) {
+        // NOTE(review): the table environment created here is never used. Dropping this call
+        // also requires removing the then-unused StreamTableEnvironment import.
+        StreamTableEnvironment.create(env);
+        Source<T, ?, ?> fileSource =
+                FileSource.forRecordStreamFormat(modelDecoder, getDataPaths(path)).build();
+        return env.fromSource(fileSource, WatermarkStrategy.noWatermarks(), "modelData");
+    }
 }
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
new file mode 100644
index 0000000..104c908
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayes.java
@@ -0,0 +1,360 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.MapPartitionFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.api.java.tuple.Tuple4;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.MapPartitionFunctionWrapper;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the naive bayes classification algorithm.
+ *
+ * <p>Every feature column is treated as categorical: for each label, the occurrences of each
+ * distinct value of each feature column are counted, and the counts are turned into additively
+ * smoothed log-probabilities (see {@link #getSmoothing()}).
+ *
+ * <p>See https://en.wikipedia.org/wiki/Naive_Bayes_classifier.
+ */
+public class NaiveBayes
+        implements Estimator<NaiveBayes, NaiveBayesModel>, NaiveBayesParams<NaiveBayes> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public NaiveBayes() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /**
+     * Trains a {@link NaiveBayesModel} on the given table.
+     *
+     * @param inputs Exactly one table, holding a {@link Vector} features column and a numeric
+     *     label column whose values are integral.
+     * @return A model carrying the computed model data, with this estimator's parameter values
+     *     applied through {@code ReadWriteUtils.updateExistingParams}.
+     */
+    @Override
+    public NaiveBayesModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String featuresCol = getFeaturesCol();
+        final String labelCol = getLabelCol();
+        final double smoothing = getSmoothing();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Tuple2<Vector, Double>> input =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                new MapFunction<Row, Tuple2<Vector, Double>>() {
+                                    @Override
+                                    public Tuple2<Vector, Double> map(Row row) throws Exception {
+                                        Number number = (Number) row.getField(labelCol);
+                                        Preconditions.checkNotNull(
+                                                number, "Input data should contain label value.");
+                                        Preconditions.checkArgument(
+                                                number.intValue() == number.doubleValue(),
+                                                "Label value should be indexed number.");
+                                        Vector features = (Vector) row.getField(featuresCol);
+                                        Preconditions.checkNotNull(
+                                                features, "Input data should contain features.");
+                                        return new Tuple2<>(features, number.doubleValue());
+                                    }
+                                });
+
+        // Pipeline: one (label, column, value) record per feature -> value counts per
+        // (label, column) -> one count array per label -> a single model record.
+        DataStream<NaiveBayesModelData> modelData =
+                input.flatMap(new ExtractFeatureFunction())
+                        .keyBy(value -> new Tuple2<>(value.f0, value.f1).hashCode())
+                        .transform(
+                                "GenerateFeatureWeightMapFunction",
+                                Types.TUPLE(
+                                        Types.DOUBLE,
+                                        Types.INT,
+                                        Types.MAP(Types.DOUBLE, Types.DOUBLE),
+                                        Types.INT),
+                                new MapPartitionFunctionWrapper<>(
+                                        new GenerateFeatureWeightMapFunction()))
+                        .keyBy(value -> value.f0)
+                        .transform(
+                                "AggregateIntoArrayFunction",
+                                Types.TUPLE(
+                                        Types.DOUBLE,
+                                        Types.INT,
+                                        Types.OBJECT_ARRAY(Types.MAP(Types.DOUBLE, Types.DOUBLE))),
+                                new MapPartitionFunctionWrapper<>(
+                                        new AggregateIntoArrayFunction()))
+                        .transform(
+                                "GenerateModelFunction",
+                                TypeInformation.of(NaiveBayesModelData.class),
+                                new MapPartitionFunctionWrapper<>(
+                                        new GenerateModelFunction(smoothing)))
+                        // Priors are normalised over all labels, so one task must see them all.
+                        .setParallelism(1);
+
+        NaiveBayesModel model =
+                new NaiveBayesModel()
+                        .setModelData(NaiveBayesModelData.getModelDataTable(modelData));
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static NaiveBayes load(StreamExecutionEnvironment env, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Function to extract feature values from input rows.
+     *
+     * <p>Output records are tuples with the following fields in order:
+     *
+     * <ul>
+     *   <li>label value
+     *   <li>feature column index
+     *   <li>feature value
+     * </ul>
+     */
+    private static class ExtractFeatureFunction
+            implements FlatMapFunction<Tuple2<Vector, Double>, Tuple3<Double, Integer, Double>> {
+        @Override
+        public void flatMap(
+                Tuple2<Vector, Double> value,
+                Collector<Tuple3<Double, Integer, Double>> collector) {
+            Preconditions.checkNotNull(value.f1);
+            for (int i = 0; i < value.f0.size(); i++) {
+                collector.collect(new Tuple3<>(value.f1, i, value.f0.get(i)));
+            }
+        }
+    }
+
+    /**
+     * Function that counts, per (label, feature column index), how often each feature value
+     * occurs.
+     *
+     * <p>Input records are tuples with the following fields in order:
+     *
+     * <ul>
+     *   <li>label value
+     *   <li>feature column index
+     *   <li>feature value
+     * </ul>
+     *
+     * <p>Output records are tuples with the following fields in order:
+     *
+     * <ul>
+     *   <li>label value
+     *   <li>feature column index
+     *   <li>map of (feature value, occurrence count)
+     *   <li>number of records
+     * </ul>
+     */
+    private static class GenerateFeatureWeightMapFunction
+            implements MapPartitionFunction<
+                    Tuple3<Double, Integer, Double>,
+                    Tuple4<Double, Integer, Map<Double, Double>, Integer>> {
+
+        @Override
+        public void mapPartition(
+                Iterable<Tuple3<Double, Integer, Double>> iterable,
+                Collector<Tuple4<Double, Integer, Map<Double, Double>, Integer>> collector) {
+            Map<Tuple2<Double, Integer>, Map<Double, Double>> accMap = new HashMap<>();
+            Map<Tuple2<Double, Integer>, Integer> numMap = new HashMap<>();
+            // Fields are copied into fresh keys immediately, so input objects are never retained.
+            for (Tuple3<Double, Integer, Double> value : iterable) {
+                Tuple2<Double, Integer> key = new Tuple2<>(value.f0, value.f1);
+                accMap.computeIfAbsent(key, x -> new HashMap<>())
+                        .merge(value.f2, 1.0, Double::sum);
+                numMap.merge(key, 1, Integer::sum);
+            }
+            for (Map.Entry<Tuple2<Double, Integer>, Map<Double, Double>> entry :
+                    accMap.entrySet()) {
+                collector.collect(
+                        new Tuple4<>(
+                                entry.getKey().f0,
+                                entry.getKey().f1,
+                                entry.getValue(),
+                                numMap.get(entry.getKey())));
+            }
+        }
+    }
+
+    /**
+     * Function that aggregates the per-column count maps of each label into one array.
+     *
+     * <p>Length of the generated array equals the number of feature columns.
+     *
+     * <p>Input records are tuples with the following fields in order:
+     *
+     * <ul>
+     *   <li>label value
+     *   <li>feature column index
+     *   <li>map of (feature value, occurrence count)
+     *   <li>number of records
+     * </ul>
+     *
+     * <p>Output records are tuples with the following fields in order:
+     *
+     * <ul>
+     *   <li>label value
+     *   <li>number of records
+     *   <li>array of featureValue-count maps of each feature
+     * </ul>
+     */
+    private static class AggregateIntoArrayFunction
+            implements MapPartitionFunction<
+                    Tuple4<Double, Integer, Map<Double, Double>, Integer>,
+                    Tuple3<Double, Integer, Map<Double, Double>[]>> {
+
+        @Override
+        public void mapPartition(
+                Iterable<Tuple4<Double, Integer, Map<Double, Double>, Integer>> iterable,
+                Collector<Tuple3<Double, Integer, Map<Double, Double>[]>> collector) {
+            Map<Double, List<Tuple4<Double, Integer, Map<Double, Double>, Integer>>> byLabel =
+                    new HashMap<>();
+            for (Tuple4<Double, Integer, Map<Double, Double>, Integer> value : iterable) {
+                byLabel.computeIfAbsent(value.f0, x -> new ArrayList<>()).add(value);
+            }
+
+            for (List<Tuple4<Double, Integer, Map<Double, Double>, Integer>> list :
+                    byLabel.values()) {
+                int featureSize = 0;
+                int minDocNum = Integer.MAX_VALUE;
+                int maxDocNum = Integer.MIN_VALUE;
+                for (Tuple4<Double, Integer, Map<Double, Double>, Integer> value : list) {
+                    featureSize = Math.max(featureSize, value.f1 + 1);
+                    minDocNum = Math.min(minDocNum, value.f3);
+                    maxDocNum = Math.max(maxDocNum, value.f3);
+                }
+                // Every column of a label is counted over the same records, so differing record
+                // counts mean the input vectors had different lengths.
+                Preconditions.checkArgument(
+                        minDocNum == maxDocNum, "Feature vectors should be of equal length.");
+
+                @SuppressWarnings("unchecked")
+                Map<Double, Double>[] featureWeights = new HashMap[featureSize];
+                for (Tuple4<Double, Integer, Map<Double, Double>, Integer> value : list) {
+                    featureWeights[value.f1] = value.f2;
+                }
+                collector.collect(new Tuple3<>(list.get(0).f0, maxDocNum, featureWeights));
+            }
+        }
+    }
+
+    /**
+     * Function to generate Naive Bayes model data from the per-label count arrays.
+     *
+     * <p>With {@code s} the smoothing, {@code n_l} the record count of label {@code l}, {@code N}
+     * the total record count, {@code F} the number of feature columns, {@code L} the number of
+     * labels, {@code V_j} the distinct values seen in column {@code j} across all labels and
+     * {@code c} the count of value {@code v} in column {@code j} for label {@code l}:
+     *
+     * <ul>
+     *   <li>{@code theta[l][j](v) = log(c + s) - log(n_l + s * |V_j|)}
+     *   <li>{@code pi[l] = log(n_l * F + s) - log(N * F + L * s)}
+     * </ul>
+     */
+    private static class GenerateModelFunction
+            implements MapPartitionFunction<
+                    Tuple3<Double, Integer, Map<Double, Double>[]>, NaiveBayesModelData> {
+        private final double smoothing;
+
+        private GenerateModelFunction(double smoothing) {
+            this.smoothing = smoothing;
+        }
+
+        @Override
+        public void mapPartition(
+                Iterable<Tuple3<Double, Integer, Map<Double, Double>[]>> iterable,
+                Collector<NaiveBayesModelData> collector) {
+            List<Tuple3<Double, Integer, Map<Double, Double>[]>> list = new ArrayList<>();
+            iterable.iterator().forEachRemaining(list::add);
+            // The model is sized from the first label below, so at least one label is required.
+            Preconditions.checkArgument(
+                    !list.isEmpty(), "No training data with non-empty feature vectors was found.");
+
+            final int featureSize = list.get(0).f2.length;
+            for (Tuple3<Double, Integer, Map<Double, Double>[]> tup : list) {
+                Preconditions.checkArgument(
+                        featureSize == tup.f2.length,
+                        "Feature vectors should be of equal length.");
+            }
+
+            // numDocs[i] sums the record counts of all labels for column i; categoryNumbers[i]
+            // collects the distinct values observed in column i across all labels.
+            double[] numDocs = new double[featureSize];
+            @SuppressWarnings("unchecked")
+            HashSet<Double>[] categoryNumbers = new HashSet[featureSize];
+            for (int i = 0; i < featureSize; i++) {
+                categoryNumbers[i] = new HashSet<>();
+            }
+            for (Tuple3<Double, Integer, Map<Double, Double>[]> tup : list) {
+                for (int i = 0; i < featureSize; i++) {
+                    numDocs[i] += tup.f1;
+                    categoryNumbers[i].addAll(tup.f2[i].keySet());
+                }
+            }
+
+            int[] categoryNumber = new int[featureSize];
+            double piLog = 0;
+            int numLabels = list.size();
+            for (int i = 0; i < featureSize; i++) {
+                categoryNumber[i] = categoryNumbers[i].size();
+                piLog += numDocs[i];
+            }
+            // Log of the smoothed normaliser shared by all label priors.
+            piLog = Math.log(piLog + numLabels * smoothing);
+
+            @SuppressWarnings("unchecked")
+            Map<Double, Double>[][] theta = new HashMap[numLabels][featureSize];
+            double[] piArray = new double[numLabels];
+            double[] labels = new double[numLabels];
+
+            for (int i = 0; i < numLabels; i++) {
+                Tuple3<Double, Integer, Map<Double, Double>[]> labelRecord = list.get(i);
+                Map<Double, Double>[] param = labelRecord.f2;
+                for (int j = 0; j < featureSize; j++) {
+                    Map<Double, Double> squareData = new HashMap<>();
+                    // Additive smoothing over the distinct values of feature column j.
+                    double thetaLog =
+                            Math.log(labelRecord.f1 * 1.0 + smoothing * categoryNumber[j]);
+                    for (Double cate : categoryNumbers[j]) {
+                        double value = param[j].getOrDefault(cate, 0.0);
+                        squareData.put(cate, Math.log(value + smoothing) - thetaLog);
+                    }
+                    theta[i][j] = squareData;
+                }
+
+                labels[i] = labelRecord.f0;
+                // Multiply in double: an int product of record count and column count can
+                // overflow to a negative value, which would make the log prior NaN.
+                double weightSum = (double) labelRecord.f1 * featureSize;
+                piArray[i] = Math.log(weightSum + smoothing) - piLog;
+            }
+
+            NaiveBayesModelData modelData = new NaiveBayesModelData(theta, piArray, labels);
+            collector.collect(modelData);
+        }
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java
new file mode 100644
index 0000000..50ef1b1
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModel.java
@@ -0,0 +1,187 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+/** A Model which classifies data using the model data computed by {@link 
NaiveBayes}. */
+public class NaiveBayesModel
+        implements Model<NaiveBayesModel>, 
NaiveBayesModelParams<NaiveBayesModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public NaiveBayesModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String predictionCol = getPredictionCol();
+        final String featuresCol = getFeaturesCol();
+        final String broadcastModelKey = "NaiveBayesModelStream";
+
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), 
TypeInformation.of(Double.class)),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
predictionCol));
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
modelDataTable).getTableEnvironment();
+        DataStream<NaiveBayesModelData> modelDataStream =
+                NaiveBayesModelData.getModelDataStream(modelDataTable);
+        DataStream<Row> input = tEnv.toDataStream(inputs[0]);
+
+        Map<String, DataStream<?>> broadcastMap = new HashMap<>();
+        broadcastMap.put(broadcastModelKey, modelDataStream);
+
+        Function<List<DataStream<?>>, DataStream<Row>> function =
+                dataStreams -> {
+                    DataStream stream = dataStreams.get(0);
+                    return stream.map(
+                            new PredictLabelFunction(featuresCol, 
broadcastModelKey),
+                            outputTypeInfo);
+                };
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(input),
+                        Collections.singletonMap(broadcastModelKey, 
modelDataStream),
+                        function);
+
+        Table outputTable = tEnv.fromDataStream(output);
+
+        return new Table[] {outputTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                NaiveBayesModelData.getModelDataStream(modelDataTable),
+                path,
+                new NaiveBayesModelData.ModelDataEncoder());
+    }
+
+    public static NaiveBayesModel load(StreamExecutionEnvironment env, String 
path)
+            throws IOException {
+        NaiveBayesModel model = ReadWriteUtils.loadStageParam(path);
+        DataStream<NaiveBayesModelData> modelData =
+                ReadWriteUtils.loadModelData(
+                        env, path, new 
NaiveBayesModelData.ModelDataStreamFormat());
+        return 
model.setModelData(NaiveBayesModelData.getModelDataTable(modelData));
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public NaiveBayesModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    private static class PredictLabelFunction extends RichMapFunction<Row, 
Row> {
+        private final String featuresCol;
+        private final String broadcastModelKey;
+        private NaiveBayesModelData modelData = null;
+
+        public PredictLabelFunction(String featuresCol, String 
broadcastModelKey) {
+            this.featuresCol = featuresCol;
+            this.broadcastModelKey = broadcastModelKey;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (modelData == null) {
+                modelData =
+                        (NaiveBayesModelData)
+                                
getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
+            }
+            Vector vector = (Vector) row.getField(featuresCol);
+            double label = findMaxProbLabel(calculateProb(modelData, vector), 
modelData.labels);
+            return Row.join(row, Row.of(label));
+        }
+    }
+
+    private static double findMaxProbLabel(DenseVector prob, Vector label) {
+        double result = 0.;
+        int probSize = prob.size();
+        double maxVal = Double.NEGATIVE_INFINITY;
+        for (int i = 0; i < probSize; ++i) {
+            if (maxVal < prob.values[i]) {
+                maxVal = prob.values[i];
+                result = label.get(i);
+            }
+        }
+        Preconditions.checkArgument(maxVal > Double.NEGATIVE_INFINITY);
+        return result;
+    }
+
+    /** Calculate probability of the input data. */
+    private static DenseVector calculateProb(NaiveBayesModelData modelData, 
Vector data) {
+        int labelSize = modelData.labels.size();
+        DenseVector probs = new DenseVector(new double[labelSize]);
+        for (int i = 0; i < labelSize; i++) {
+            Map<Double, Double>[] labelData = modelData.theta[i];
+            for (int j = 0; j < data.size(); j++) {
+                probs.values[i] += labelData[j].get(data.get(j));
+            }
+        }
+        BLAS.axpy(1, modelData.piArray, probs);
+        return probs;
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
new file mode 100644
index 0000000..8505cbc
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelData.java
@@ -0,0 +1,164 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification.naivebayes;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleSerializer;
+import org.apache.flink.api.common.typeutils.base.MapSerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * The model data of {@link NaiveBayesModel}.
+ *
+ * <p>This class also provides methods to convert model data between Table and DataStream, and
+ * classes to save/load model data.
+ */
+public class NaiveBayesModelData {
+    /**
+     * Log of class conditional probabilities, whose dimension is C (number of classes) by D (number
+     * of features). Each entry maps a feature value to its log conditional probability given the
+     * class.
+     */
+    public final Map<Double, Double>[][] theta;
+
+    /** Log of class priors, whose dimension is C (number of classes). */
+    public final DenseVector piArray;
+
+    /** Value of labels. */
+    public final DenseVector labels;
+
+    public NaiveBayesModelData(Map<Double, Double>[][] theta, double[] piArray, double[] labels) {
+        this(theta, Vectors.dense(piArray), Vectors.dense(labels));
+    }
+
+    public NaiveBayesModelData(
+            Map<Double, Double>[][] theta, DenseVector piArray, DenseVector labels) {
+        this.theta = theta;
+        this.piArray = piArray;
+        this.labels = labels;
+    }
+
+    /** Converts the provided modelData DataStream into corresponding Table. */
+    public static Table getModelDataTable(DataStream<NaiveBayesModelData> stream) {
+        StreamTableEnvironment tEnv =
+                StreamTableEnvironment.create(stream.getExecutionEnvironment());
+        return tEnv.fromDataStream(stream);
+    }
+
+    /** Converts the provided modelData Table into corresponding DataStream. */
+    public static DataStream<NaiveBayesModelData> getModelDataStream(Table table) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) table).getTableEnvironment();
+        return tEnv.toDataStream(table)
+                .map(
+                        (MapFunction<Row, NaiveBayesModelData>)
+                                row -> (NaiveBayesModelData) row.getField("f0"));
+    }
+
+    /**
+     * Encoder for the {@link NaiveBayesModelData}.
+     *
+     * <p>Layout: labels, piArray, the number of labels C, the number of features D, then the C * D
+     * maps of {@code theta} in row-major order.
+     */
+    public static class ModelDataEncoder implements Encoder<NaiveBayesModelData> {
+        @Override
+        public void encode(NaiveBayesModelData modelData, OutputStream outputStream)
+                throws IOException {
+            DataOutputViewStreamWrapper output = new DataOutputViewStreamWrapper(outputStream);
+
+            DenseVectorSerializer denseVectorSerializer = new DenseVectorSerializer();
+            MapSerializer<Double, Double> mapSerializer =
+                    new MapSerializer<>(new DoubleSerializer(), new DoubleSerializer());
+
+            denseVectorSerializer.serialize(modelData.labels, output);
+
+            denseVectorSerializer.serialize(modelData.piArray, output);
+
+            int numLabels = modelData.theta.length;
+            // Avoid indexing theta[0] when the model has no labels.
+            int numFeatures = numLabels == 0 ? 0 : modelData.theta[0].length;
+            output.writeInt(numLabels);
+            output.writeInt(numFeatures);
+            for (Map<Double, Double>[] maps : modelData.theta) {
+                for (Map<Double, Double> map : maps) {
+                    mapSerializer.serialize(map, output);
+                }
+            }
+        }
+    }
+
+    /** Decoder for the {@link NaiveBayesModelData}. */
+    public static class ModelDataStreamFormat extends SimpleStreamFormat<NaiveBayesModelData> {
+        @Override
+        public Reader<NaiveBayesModelData> createReader(
+                Configuration config, FSDataInputStream inputStream) {
+            return new Reader<NaiveBayesModelData>() {
+                private final DataInputViewStreamWrapper input =
+                        new DataInputViewStreamWrapper(inputStream);
+
+                @Override
+                public NaiveBayesModelData read() throws IOException {
+                    try {
+                        DenseVectorSerializer denseVectorSerializer = new DenseVectorSerializer();
+                        MapSerializer<Double, Double> mapSerializer =
+                                new MapSerializer<>(new DoubleSerializer(), new DoubleSerializer());
+
+                        DenseVector labels = denseVectorSerializer.deserialize(input);
+
+                        DenseVector piArray = denseVectorSerializer.deserialize(input);
+
+                        // Must mirror ModelDataEncoder: the label count C is written before the
+                        // feature count D, and theta is laid out as C rows of D maps. Swapping the
+                        // two dimensions here breaks decoding whenever C != D.
+                        int numLabels = input.readInt();
+                        int numFeatures = input.readInt();
+                        @SuppressWarnings("unchecked")
+                        Map<Double, Double>[][] theta = new HashMap[numLabels][numFeatures];
+                        for (int i = 0; i < numLabels; i++) {
+                            for (int j = 0; j < numFeatures; j++) {
+                                theta[i][j] = mapSerializer.deserialize(input);
+                            }
+                        }
+                        return new NaiveBayesModelData(theta, piArray, labels);
+                    } catch (EOFException e) {
+                        // Returning null signals the end of the input to the reader's caller.
+                        // NOTE(review): an EOF in the middle of a record is treated the same way.
+                        return null;
+                    }
+                }
+
+                @Override
+                public void close() throws IOException {
+                    inputStream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<NaiveBayesModelData> getProducedType() {
+            return TypeInformation.of(NaiveBayesModelData.class);
+        }
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelParams.java
similarity index 57%
copy from 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
copy to 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelParams.java
index c4c5953..8001df0 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesModelParams.java
@@ -16,29 +16,32 @@
  * limitations under the License.
  */
 
-package org.apache.flink.ml.common.param;
+package org.apache.flink.ml.classification.naivebayes;
 
-import org.apache.flink.ml.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasPredictionCol;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.param.ParamValidators;
 import org.apache.flink.ml.param.StringParam;
-import org.apache.flink.ml.param.WithParams;
 
-/** Interface for the shared distanceMeasure param. */
-public interface HasDistanceMeasure<T> extends WithParams<T> {
-    Param<String> DISTANCE_MEASURE =
+/**
+ * Params of {@link NaiveBayesModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface NaiveBayesModelParams<T> extends HasFeaturesCol<T>, 
HasPredictionCol<T> {
+    /** The model type. The validator only accepts "multinomial", which is also the default. */
+    Param<String> MODEL_TYPE =
             new StringParam(
-                    "distanceMeasure",
-                    "The distance measure. Supported options: 'euclidean'.",
-                    EuclideanDistanceMeasure.NAME,
-                    ParamValidators.inArray(EuclideanDistanceMeasure.NAME));
+                    "modelType",
+                    "The model type.",
+                    "multinomial",
+                    ParamValidators.inArray("multinomial"));
 
-    default String getDistanceMeasure() {
-        return get(DISTANCE_MEASURE);
+    /** Returns the value of {@link #MODEL_TYPE}. */
+    default String getModelType() {
+        return get(MODEL_TYPE);
     }
 
-    default T setDistanceMeasure(String value) {
-        set(DISTANCE_MEASURE, value);
-        return (T) this;
+    /** Sets {@link #MODEL_TYPE} to {@code value}. */
+    default T setModelType(String value) {
+        return set(MODEL_TYPE, value);
     }
 }
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesParams.java
similarity index 52%
copy from 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
copy to 
flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesParams.java
index c4c5953..8e7e089 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/classification/naivebayes/NaiveBayesParams.java
@@ -16,29 +16,28 @@
  * limitations under the License.
  */
 
-package org.apache.flink.ml.common.param;
+package org.apache.flink.ml.classification.naivebayes;
 
-import org.apache.flink.ml.distance.EuclideanDistanceMeasure;
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.param.ParamValidators;
-import org.apache.flink.ml.param.StringParam;
-import org.apache.flink.ml.param.WithParams;
 
-/** Interface for the shared distanceMeasure param. */
-public interface HasDistanceMeasure<T> extends WithParams<T> {
-    Param<String> DISTANCE_MEASURE =
-            new StringParam(
-                    "distanceMeasure",
-                    "The distance measure. Supported options: 'euclidean'.",
-                    EuclideanDistanceMeasure.NAME,
-                    ParamValidators.inArray(EuclideanDistanceMeasure.NAME));
+/**
+ * Params of {@link NaiveBayes}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface NaiveBayesParams<T> extends NaiveBayesModelParams<T>, 
HasLabelCol<T> {
+    /** The smoothing parameter. Defaults to 1.0; the validator rejects negative values. */
+    Param<Double> SMOOTHING =
+            new DoubleParam(
+                    "smoothing", "The smoothing parameter.", 1.0, 
ParamValidators.gtEq(0.0));
 
-    default String getDistanceMeasure() {
-        return get(DISTANCE_MEASURE);
+    /** Returns the value of {@link #SMOOTHING}. */
+    default Double getSmoothing() {
+        return get(SMOOTHING);
     }
 
-    default T setDistanceMeasure(String value) {
-        set(DISTANCE_MEASURE, value);
-        return (T) this;
+    /** Sets {@link #SMOOTHING} to {@code value}. */
+    default T setSmoothing(Double value) {
+        return set(SMOOTHING, value);
     }
 }
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
index 1b08b66..3966626 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeans.java
@@ -60,6 +60,7 @@ import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
 
 import org.apache.commons.collections.IteratorUtils;
 
@@ -85,6 +86,8 @@ public class KMeans implements Estimator<KMeans, 
KMeansModel>, KMeansParams<KMea
 
     @Override
     public KMeansModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
         DataStream<DenseVector> points =
@@ -110,7 +113,7 @@ public class KMeans implements Estimator<KMeans, 
KMeansModel>, KMeansParams<KMea
                                 body)
                         .get(0);
 
-        Table finalCentroidsTable = tEnv.fromDataStream(finalCentroids, 
KMeansModelData.SCHEMA);
+        Table finalCentroidsTable = 
KMeansModelData.getModelDataTable(finalCentroids);
         KMeansModel model = new 
KMeansModel().setModelData(finalCentroidsTable);
         ReadWriteUtils.updateExistingParams(model, paramMap);
         return model;
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
index 0d2351c..d5d94e1 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModel.java
@@ -18,17 +18,12 @@
 
 package org.apache.flink.ml.clustering.kmeans;
 
-import org.apache.flink.api.common.eventtime.WatermarkStrategy;
 import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ListStateDescriptor;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.common.typeinfo.Types;
-import org.apache.flink.api.connector.source.Source;
 import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.connector.file.sink.FileSink;
-import org.apache.flink.connector.file.src.FileSource;
-import org.apache.flink.core.fs.Path;
 import org.apache.flink.ml.api.Model;
 import org.apache.flink.ml.common.datastream.TableUtils;
 import org.apache.flink.ml.distance.DistanceMeasure;
@@ -40,8 +35,6 @@ import org.apache.flink.ml.util.ReadWriteUtils;
 import org.apache.flink.runtime.state.StateInitializationContext;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import 
org.apache.flink.streaming.api.functions.sink.filesystem.bucketassigners.BasePathBucketAssigner;
-import 
org.apache.flink.streaming.api.functions.sink.filesystem.rollingpolicies.OnCheckpointRollingPolicy;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -49,6 +42,7 @@ import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.internal.TableImpl;
 import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
 
 import org.apache.commons.collections.IteratorUtils;
 import org.apache.commons.lang3.ArrayUtils;
@@ -69,6 +63,7 @@ public class KMeansModel implements Model<KMeansModel>, 
KMeansModelParams<KMeans
 
     @Override
     public KMeansModel setModelData(Table... inputs) {
+        // Exactly one model data table is accepted; it is stored as the centroids table.
+        Preconditions.checkArgument(inputs.length == 1);
         centroidsTable = inputs[0];
         return this;
     }
@@ -80,10 +75,11 @@ public class KMeansModel implements Model<KMeansModel>, 
KMeansModelParams<KMeans
 
     @Override
     public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
-        DataStream<DenseVector[]> centroids =
-                tEnv.toDataStream(centroidsTable).map(row -> (DenseVector[]) 
row.getField("f0"));
+        DataStream<DenseVector[]> centroids = 
KMeansModelData.getModelDataStream(centroidsTable);
 
         String featureCol = getFeaturesCol();
         String predictionCol = getPredictionCol();
@@ -182,33 +178,19 @@ public class KMeansModel implements Model<KMeansModel>, 
KMeansModelParams<KMeans
 
     @Override
     public void save(String path) throws IOException {
-        StreamTableEnvironment tEnv =
-                (StreamTableEnvironment) ((TableImpl) 
centroidsTable).getTableEnvironment();
-
-        String dataPath = ReadWriteUtils.getDataPath(path);
-        FileSink<DenseVector[]> sink =
-                FileSink.forRowFormat(new Path(dataPath), new 
KMeansModelData.ModelDataEncoder())
-                        .withRollingPolicy(OnCheckpointRollingPolicy.build())
-                        .withBucketAssigner(new BasePathBucketAssigner<>())
-                        .build();
-        tEnv.toDataStream(centroidsTable)
-                .map(row -> (DenseVector[]) row.getField("f0"))
-                .sinkTo(sink);
-
+        // Saves the metadata of this stage, then the centroids encoded with
+        // KMeansModelData.ModelDataEncoder, both under the given path.
         ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                KMeansModelData.getModelDataStream(centroidsTable),
+                path,
+                new KMeansModelData.ModelDataEncoder());
     }
 
     // TODO: Add INFO level logging.
     public static KMeansModel load(StreamExecutionEnvironment env, String 
path) throws IOException {
-        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
-        Source<DenseVector[], ?, ?> source =
-                FileSource.forRecordStreamFormat(
-                                new KMeansModelData.ModelDataStreamFormat(),
-                                ReadWriteUtils.getDataPaths(path))
-                        .build();
+        // Restores the params via ReadWriteUtils.loadStageParam, then attaches the centroids read
+        // via ReadWriteUtils.loadModelData with KMeansModelData.ModelDataStreamFormat.
         KMeansModel model = ReadWriteUtils.loadStageParam(path);
-        DataStream<DenseVector[]> modelData =
-                env.fromSource(source, WatermarkStrategy.noWatermarks(), 
"modelData");
-        return model.setModelData(tEnv.fromDataStream(modelData, 
KMeansModelData.SCHEMA));
+        DataStream<DenseVector[]> centroids =
+                ReadWriteUtils.loadModelData(
+                        env, path, new 
KMeansModelData.ModelDataStreamFormat());
+        return 
model.setModelData(KMeansModelData.getModelDataTable(centroids));
     }
 }
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
index 779873b..7101a18 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/clustering/kmeans/KMeansModelData.java
@@ -20,71 +20,83 @@ package org.apache.flink.ml.clustering.kmeans;
 
 import org.apache.flink.api.common.serialization.Encoder;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
 import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
 import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.table.api.DataTypes;
-import org.apache.flink.table.api.Schema;
-
-import com.esotericsoftware.kryo.Kryo;
-import com.esotericsoftware.kryo.io.Input;
-import com.esotericsoftware.kryo.io.Output;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
 
+import java.io.EOFException;
 import java.io.IOException;
 import java.io.OutputStream;
-import java.util.ArrayList;
-import java.util.List;
 
 /** Provides classes to save/load model data. */
 public class KMeansModelData {
+    /**
+     * Wraps a stream of centroid arrays into a Table, using a table environment created on the
+     * stream's own execution environment.
+     */
+    public static Table getModelDataTable(DataStream<DenseVector[]> modelData) {
+        return StreamTableEnvironment.create(modelData.getExecutionEnvironment())
+                .fromDataStream(modelData);
+    }
 
-    public static final Schema SCHEMA =
-            Schema.newBuilder()
-                    .column("f0", 
DataTypes.ARRAY(DataTypes.of(DenseVector.class)))
-                    .build();
+    /**
+     * Extracts the centroid arrays stored in column "f0" of the given model data Table. The cast
+     * below assumes the table belongs to a {@link StreamTableEnvironment}.
+     */
+    public static DataStream<DenseVector[]> getModelDataStream(Table table) {
+        StreamTableEnvironment streamTableEnv =
+                (StreamTableEnvironment) ((TableImpl) table).getTableEnvironment();
+        return streamTableEnv
+                .toDataStream(table)
+                .map(row -> (DenseVector[]) row.getField("f0"));
+    }
 
     /** Encoder for the KMeans model data. */
     public static class ModelDataEncoder implements Encoder<DenseVector[]> {
         @Override
-        public void encode(DenseVector[] modelData, OutputStream outputStream) {
-            Kryo kryo = new Kryo();
-            Output output = new Output(outputStream);
-            List<double[]> convertedData = new ArrayList<>();
-            for (int i = 0; i < modelData.length; i++) {
-                convertedData.add(modelData[i].values);
+        public void encode(DenseVector[] modelData, OutputStream outputStream) throws IOException {
+            // Layout: the number of centroids as an int, followed by each centroid vector.
+            IntSerializer intSerializer = new IntSerializer();
+            DenseVectorSerializer denseVectorSerializer = new DenseVectorSerializer();
+            DataOutputViewStreamWrapper outputViewStreamWrapper =
+                    new DataOutputViewStreamWrapper(outputStream);
+            intSerializer.serialize(modelData.length, outputViewStreamWrapper);
+            for (DenseVector denseVector : modelData) {
+                // Reuse the single wrapper instead of allocating a new one per centroid.
+                denseVectorSerializer.serialize(denseVector, outputViewStreamWrapper);
             }
-            kryo.writeObject(output, convertedData);
-            output.flush();
         }
     }
 
     /** Decoder for the KMeans model data. */
     public static class ModelDataStreamFormat extends 
SimpleStreamFormat<DenseVector[]> {
         @Override
-        public Reader<DenseVector[]> createReader(Configuration config, 
FSDataInputStream stream) {
+        public Reader<DenseVector[]> createReader(
+                Configuration config, FSDataInputStream inputStream) {
             return new Reader<DenseVector[]>() {
-                private final Kryo kryo = new Kryo();
-                private final Input input = new Input(stream);
-
                 @Override
-                public DenseVector[] read() {
-                    if (input.eof()) {
+                public DenseVector[] read() throws IOException {
+                    // Each record is an int count followed by that many DenseVectors, matching
+                    // what ModelDataEncoder writes.
+                    try {
+                        IntSerializer intSerializer = new IntSerializer();
+                        DenseVectorSerializer denseVectorSerializer = new 
DenseVectorSerializer();
+                        DataInputViewStreamWrapper inputViewStreamWrapper =
+                                new DataInputViewStreamWrapper(inputStream);
+                        int numDenseVectors = 
intSerializer.deserialize(inputViewStreamWrapper);
+                        DenseVector[] result = new 
DenseVector[numDenseVectors];
+                        for (int i = 0; i < numDenseVectors; i++) {
+                            result[i] = 
denseVectorSerializer.deserialize(inputViewStreamWrapper);
+                        }
+                        return result;
+                    } catch (EOFException e) {
+                        // Returning null signals the end of the input to the reader's caller.
+                        // NOTE(review): an EOF in the middle of a record is treated the same way.
                         return null;
                     }
-                    ArrayList<double[]> row = kryo.readObject(input, 
ArrayList.class);
-
-                    DenseVector[] result = new DenseVector[row.size()];
-                    for (int i = 0; i < result.length; i++) {
-                        result[i] = new DenseVector(row.get(i));
-                    }
-                    return result;
                 }
 
                 @Override
                 public void close() throws IOException {
-                    stream.close();
+                    inputStream.close();
                 }
             };
         }
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
index db38451..51aa18c 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
@@ -22,7 +22,6 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.table.catalog.Column;
 import org.apache.flink.table.catalog.ResolvedSchema;
-import org.apache.flink.table.runtime.typeutils.ExternalTypeInfo;
 
 /** Utility class for table-related operations. */
 public class TableUtils {
@@ -33,7 +32,7 @@ public class TableUtils {
 
         for (int i = 0; i < schema.getColumnCount(); i++) {
             Column column = schema.getColumn(i).get();
-            types[i] = ExternalTypeInfo.of(column.getDataType());
+            types[i] = 
TypeInformation.of(column.getDataType().getConversionClass());
             names[i] = column.getName();
         }
         return new RowTypeInfo(types, names);
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
index c4c5953..f58d08a 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
@@ -29,7 +29,7 @@ public interface HasDistanceMeasure<T> extends WithParams<T> {
     Param<String> DISTANCE_MEASURE =
             new StringParam(
                     "distanceMeasure",
-                    "The distance measure. Supported options: 'euclidean'.",
+                    "The distance measure.",
                     EuclideanDistanceMeasure.NAME,
                     ParamValidators.inArray(EuclideanDistanceMeasure.NAME));
 
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
 b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLabelCol.java
similarity index 60%
copy from 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
copy to 
flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLabelCol.java
index c4c5953..badd3b3 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasDistanceMeasure.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasLabelCol.java
@@ -18,27 +18,21 @@
 
 package org.apache.flink.ml.common.param;
 
-import org.apache.flink.ml.distance.EuclideanDistanceMeasure;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.param.ParamValidators;
 import org.apache.flink.ml.param.StringParam;
 import org.apache.flink.ml.param.WithParams;
 
-/** Interface for the shared distanceMeasure param. */
-public interface HasDistanceMeasure<T> extends WithParams<T> {
-    Param<String> DISTANCE_MEASURE =
-            new StringParam(
-                    "distanceMeasure",
-                    "The distance measure. Supported options: 'euclidean'.",
-                    EuclideanDistanceMeasure.NAME,
-                    ParamValidators.inArray(EuclideanDistanceMeasure.NAME));
+/**
+ * Interface for the shared labelCol param.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface HasLabelCol<T> extends WithParams<T> {
+    /** Name of the label column. Defaults to "label"; the validator rejects null. */
+    Param<String> LABEL_COL =
+            new StringParam("labelCol", "Label column name.", "label", 
ParamValidators.notNull());
 
-    default String getDistanceMeasure() {
-        return get(DISTANCE_MEASURE);
+    /** Returns the value of {@link #LABEL_COL}. */
+    default String getLabelCol() {
+        return get(LABEL_COL);
     }
 
-    default T setDistanceMeasure(String value) {
-        set(DISTANCE_MEASURE, value);
-        return (T) this;
+    /** Sets {@link #LABEL_COL} to {@code colName}. */
+    default T setLabelCol(String colName) {
+        return set(LABEL_COL, colName);
     }
 }
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
new file mode 100644
index 0000000..581242f
--- /dev/null
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/classification/NaiveBayesTest.java
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.classification;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.classification.naivebayes.NaiveBayes;
+import org.apache.flink.ml.classification.naivebayes.NaiveBayesModel;
+import org.apache.flink.ml.classification.naivebayes.NaiveBayesModelData;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link NaiveBayes} and {@link NaiveBayesModel}. */
+public class NaiveBayesTest {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainTable;
+    private Table predictTable;
+    private Map<Vector, Integer> expectedOutput;
+    private NaiveBayes estimator;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        List<Row> trainData =
+                Arrays.asList(
+                        Row.of(Vectors.dense(0, 0.), 11),
+                        Row.of(Vectors.dense(1, 0), 10),
+                        Row.of(Vectors.dense(1, 1.), 10));
+
+        trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features", "label");
+
+        List<Row> predictData =
+                Arrays.asList(
+                        Row.of(Vectors.dense(0, 1.)),
+                        Row.of(Vectors.dense(0, 0.)),
+                        Row.of(Vectors.dense(1, 0)),
+                        Row.of(Vectors.dense(1, 1.)));
+
+        predictTable = tEnv.fromDataStream(env.fromCollection(predictData)).as("features");
+
+        // Filled with plain put() calls rather than double-brace initialization, which would
+        // create an anonymous HashMap subclass holding a reference to this test instance.
+        expectedOutput = new HashMap<>();
+        expectedOutput.put(Vectors.dense(0, 1.), 11);
+        expectedOutput.put(Vectors.dense(0, 0.), 11);
+        expectedOutput.put(Vectors.dense(1, 0.), 10);
+        expectedOutput.put(Vectors.dense(1, 1.), 10);
+
+        estimator =
+                new NaiveBayes()
+                        .setSmoothing(1.0)
+                        .setFeaturesCol("features")
+                        .setLabelCol("label")
+                        .setPredictionCol("prediction")
+                        .setModelType("multinomial");
+    }
+
+    /**
+     * Executes a given table and collects its results. Results are returned as a map whose key is
+     * the feature, value is the prediction result.
+     *
+     * <p>The result iterator is closed after it has been consumed, so that the resources held by
+     * the collect job are released.
+     *
+     * @param table A table to be executed and to have its result collected
+     * @param featuresCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     * @return A map containing the collected results
+     * @throws Exception if the job fails or the result iterator cannot be closed
+     */
+    private static Map<Vector, Integer> executeAndCollect(
+            Table table, String featuresCol, String predictionCol) throws Exception {
+        Map<Vector, Integer> map = new HashMap<>();
+        try (CloseableIterator<Row> it = table.execute().collect()) {
+            while (it.hasNext()) {
+                Row row = it.next();
+                map.put(
+                        (Vector) row.getField(featuresCol),
+                        ((Number) row.getField(predictionCol)).intValue());
+            }
+        }
+        return map;
+    }
+
+    /**
+     * Returns the innermost cause of the given throwable, or the throwable itself if it has no
+     * cause.
+     */
+    private static Throwable getRootCause(Throwable throwable) {
+        Throwable cause = throwable;
+        while (cause.getCause() != null) {
+            cause = cause.getCause();
+        }
+        return cause;
+    }
+
+    /**
+     * Fits {@code estimator} on the given training rows, transforms the same rows, and asserts
+     * that executing the job fails with an {@link IllegalArgumentException} reporting feature
+     * vectors of unequal length.
+     */
+    private void verifyFailureOnVectorsWithDiffLen(List<Row> trainData) {
+        trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features", "label");
+
+        NaiveBayesModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(trainTable)[0];
+
+        try {
+            outputTable.execute().collect().next();
+            Assert.fail("Expected IllegalArgumentException");
+        } catch (Exception e) {
+            Throwable exception = getRootCause(e);
+            assertEquals(IllegalArgumentException.class, exception.getClass());
+            assertEquals("Feature vectors should be of equal length.", exception.getMessage());
+        }
+    }
+
+    @Test
+    public void testParam() {
+        NaiveBayes naiveBayes = new NaiveBayes();
+
+        assertEquals("features", naiveBayes.getFeaturesCol());
+        assertEquals("label", naiveBayes.getLabelCol());
+        assertEquals("multinomial", naiveBayes.getModelType());
+        assertEquals("prediction", naiveBayes.getPredictionCol());
+        assertEquals(1.0, naiveBayes.getSmoothing(), 1e-5);
+
+        naiveBayes
+                .setFeaturesCol("test_feature")
+                .setLabelCol("test_label")
+                .setPredictionCol("test_prediction")
+                .setSmoothing(2.0);
+
+        assertEquals("test_feature", naiveBayes.getFeaturesCol());
+        assertEquals("test_label", naiveBayes.getLabelCol());
+        assertEquals("test_prediction", naiveBayes.getPredictionCol());
+        assertEquals(2.0, naiveBayes.getSmoothing(), 1e-5);
+
+        NaiveBayesModel model = new NaiveBayesModel();
+
+        assertEquals("features", model.getFeaturesCol());
+        assertEquals("multinomial", model.getModelType());
+        assertEquals("prediction", model.getPredictionCol());
+
+        model.setFeaturesCol("test_feature").setPredictionCol("test_prediction");
+
+        assertEquals("test_feature", model.getFeaturesCol());
+        assertEquals("test_prediction", model.getPredictionCol());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        NaiveBayesModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(predictTable)[0];
+        Map<Vector, Integer> actualOutput =
+                executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol());
+        assertEquals(expectedOutput, actualOutput);
+    }
+
+    @Test
+    public void testFeaturePredictionParam() throws Exception {
+        trainTable = trainTable.as("test_features", "test_label");
+        predictTable = predictTable.as("test_features");
+
+        estimator
+                .setFeaturesCol("test_features")
+                .setLabelCol("test_label")
+                .setPredictionCol("test_prediction");
+
+        NaiveBayesModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(predictTable)[0];
+        Map<Vector, Integer> actualOutput =
+                executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol());
+        assertEquals(expectedOutput, actualOutput);
+    }
+
+    @Test
+    public void testPredictUnseenFeature() {
+        // The value 2 never occurs in the first dimension of the training features.
+        predictTable =
+                tEnv.fromDataStream(env.fromElements(Row.of(Vectors.dense(2, 1.)))).as("features");
+
+        NaiveBayesModel model = estimator.fit(trainTable);
+        Table outputTable = model.transform(predictTable)[0];
+
+        try {
+            outputTable.execute().collect().next();
+            Assert.fail("Expected NullPointerException");
+        } catch (Exception e) {
+            Throwable exception = getRootCause(e);
+            assertEquals(
+                    NaiveBayesModel.class.getName(), exception.getStackTrace()[0].getClassName());
+            assertEquals("calculateProb", exception.getStackTrace()[0].getMethodName());
+            assertEquals(NullPointerException.class, exception.getClass());
+        }
+    }
+
+    @Test
+    public void testVectorWithDiffLen() {
+        verifyFailureOnVectorsWithDiffLen(
+                Arrays.asList(
+                        Row.of(Vectors.dense(0, 0.), 11.0),
+                        Row.of(Vectors.dense(1, 0), 10.0),
+                        Row.of(Vectors.dense(1), 10.0)));
+    }
+
+    @Test
+    public void testVectorWithDiffLen2() {
+        verifyFailureOnVectorsWithDiffLen(
+                Arrays.asList(Row.of(Vectors.dense(0, 0.), 11.0), Row.of(Vectors.dense(1), 10.0)));
+    }
+
+    @Test
+    public void testSaveLoad() throws Exception {
+        estimator =
+                StageTestUtils.saveAndReload(
+                        env, estimator, tempFolder.newFolder().getAbsolutePath());
+
+        NaiveBayesModel model = estimator.fit(trainTable);
+
+        model = StageTestUtils.saveAndReload(env, model, tempFolder.newFolder().getAbsolutePath());
+
+        Table outputTable = model.transform(predictTable)[0];
+
+        Map<Vector, Integer> actualOutput =
+                executeAndCollect(outputTable, model.getFeaturesCol(), model.getPredictionCol());
+        assertEquals(expectedOutput, actualOutput);
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        List<Row> trainData =
+                Arrays.asList(
+                        Row.of(Vectors.dense(1, 1.), 11.0), Row.of(Vectors.dense(2, 1.), 11.0));
+
+        trainTable = tEnv.fromDataStream(env.fromCollection(trainData)).as("features", "label");
+
+        NaiveBayesModel model = estimator.fit(trainTable);
+
+        // Only the first record is needed; closing the iterator releases the collect job.
+        NaiveBayesModelData actual;
+        try (CloseableIterator<NaiveBayesModelData> it =
+                NaiveBayesModelData.getModelDataStream(model.getModelData()[0])
+                        .executeAndCollect()) {
+            actual = it.next();
+        }
+
+        assertArrayEquals(new double[] {11.}, actual.labels.toArray(), 1e-5);
+        assertArrayEquals(new double[] {0.0}, actual.piArray.toArray(), 1e-5);
+        assertEquals(Math.log(0.5), actual.theta[0][0].get(1.0), 1e-5);
+        assertEquals(Math.log(0.5), actual.theta[0][0].get(2.0), 1e-5);
+        assertEquals(0.0, actual.theta[0][1].get(1.0), 1e-5);
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        NaiveBayesModel modelA = estimator.fit(trainTable);
+
+        Table modelData = modelA.getModelData()[0];
+        NaiveBayesModel modelB = new NaiveBayesModel().setModelData(modelData);
+        ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
+
+        Table outputTable = modelB.transform(predictTable)[0];
+
+        Map<Vector, Integer> actualOutput =
+                executeAndCollect(outputTable, modelB.getFeaturesCol(), modelB.getPredictionCol());
+        assertEquals(expectedOutput, actualOutput);
+    }
+}
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
index 27a839c..c2331a3 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java
@@ -18,9 +18,7 @@
 
 package org.apache.flink.ml.clustering;
 
-import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
-import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.ml.clustering.kmeans.KMeans;
 import org.apache.flink.ml.clustering.kmeans.KMeansModel;
@@ -28,6 +26,7 @@ import org.apache.flink.ml.distance.EuclideanDistanceMeasure;
 import org.apache.flink.ml.linalg.DenseVector;
 import org.apache.flink.ml.linalg.Vectors;
 import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.StageTestUtils;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
@@ -35,15 +34,18 @@ import org.apache.flink.table.api.DataTypes;
 import org.apache.flink.table.api.Schema;
 import org.apache.flink.table.api.Table;
 import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
-import org.apache.flink.table.api.internal.TableImpl;
 import org.apache.flink.test.util.AbstractTestBase;
 import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
 
+import org.apache.commons.collections.CollectionUtils;
 import org.apache.commons.collections.IteratorUtils;
 import org.junit.Before;
+import org.junit.Rule;
 import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
 
-import java.nio.file.Files;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.Comparator;
@@ -51,12 +53,16 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 
 /** Tests KMeans and KMeansModel. */
 public class KMeansTest extends AbstractTestBase {
+    // Provides the directories used by the save/load round trips below.
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+
     private static final List<DenseVector> DATA =
             Arrays.asList(
                     Vectors.dense(0.0, 0.0),
@@ -67,6 +73,18 @@ public class KMeansTest extends AbstractTestBase {
                     Vectors.dense(9.6, 0.0));
     private StreamExecutionEnvironment env;
     private StreamTableEnvironment tEnv;
+    // Expected partition of the input points by predicted cluster. The order of the groups is
+    // irrelevant because results are compared with CollectionUtils.isEqualCollection.
+    private static final List<Set<DenseVector>> expectedGroups =
+            Arrays.asList(
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(0.0, 0.0),
+                                    Vectors.dense(0.0, 0.3),
+                                    Vectors.dense(0.3, 0.0))),
+                    new HashSet<>(
+                            Arrays.asList(
+                                    Vectors.dense(9.0, 0.0),
+                                    Vectors.dense(9.0, 0.6),
+                                    Vectors.dense(9.6, 0.0))));
     private Table dataTable;
 
     @Before
@@ -83,43 +101,26 @@ public class KMeansTest extends AbstractTestBase {
         dataTable = tEnv.fromDataStream(env.fromCollection(DATA), 
schema).as("features");
     }
 
-    // Executes the graph and returns a map which maps points to clusterId.
-    private static Map<DenseVector, Integer> executeAndCollect(
-            Table output, String featureCol, String predictionCol) throws 
Exception {
-        StreamTableEnvironment tEnv =
-                (StreamTableEnvironment) ((TableImpl) 
output).getTableEnvironment();
-
-        DataStream<Tuple2<DenseVector, Integer>> stream =
-                tEnv.toDataStream(output)
-                        .map(
-                                new MapFunction<Row, Tuple2<DenseVector, 
Integer>>() {
-                                    @Override
-                                    public Tuple2<DenseVector, Integer> 
map(Row row) {
-                                        return Tuple2.of(
-                                                (DenseVector) 
row.getField(featureCol),
-                                                (Integer) 
row.getField(predictionCol));
-                                    }
-                                });
-
-        List<Tuple2<DenseVector, Integer>> pointsWithClusterId =
-                IteratorUtils.toList(stream.executeAndCollect());
-
-        Map<DenseVector, Integer> clusterIdByPoints = new HashMap<>();
-        for (Tuple2<DenseVector, Integer> entry : pointsWithClusterId) {
-            clusterIdByPoints.put(entry.f0, entry.f1);
-        }
-        return clusterIdByPoints;
-    }
-
-    private static void verifyClusteringResult(
-            Map<DenseVector, Integer> clusterIdByPoints, List<List<Integer>> 
groups) {
-        for (List<Integer> group : groups) {
-            for (int i = 1; i < group.size(); i++) {
-                assertEquals(
-                        clusterIdByPoints.get(DATA.get(group.get(0))),
-                        clusterIdByPoints.get(DATA.get(group.get(i))));
-            }
+    /**
+     * Executes a table and collects its results. Results are returned as a list of sets, where
+     * elements in the same set are features whose prediction results are the same.
+     *
+     * <p>The result iterator is closed after it has been consumed, so that the resources held by
+     * the collect job are released.
+     *
+     * @param table A table to be executed and to have its result collected
+     * @param featureCol Name of the column in the table that contains the features
+     * @param predictionCol Name of the column in the table that contains the prediction result
+     * @return A list of feature sets, one set per distinct prediction value, in no particular
+     *     order
+     * @throws Exception if the job fails or the result iterator cannot be closed
+     */
+    private static List<Set<DenseVector>> executeAndCollect(
+            Table table, String featureCol, String predictionCol) throws Exception {
+        Map<Integer, Set<DenseVector>> groups = new HashMap<>();
+        try (CloseableIterator<Row> it = table.execute().collect()) {
+            while (it.hasNext()) {
+                Row row = it.next();
+                DenseVector vector = (DenseVector) row.getField(featureCol);
+                int prediction = (Integer) row.getField(predictionCol);
+                groups.computeIfAbsent(prediction, k -> new HashSet<>()).add(vector);
+            }
         }
+        return new ArrayList<>(groups.values());
     }
 
     @Test
@@ -158,10 +159,9 @@ public class KMeansTest extends AbstractTestBase {
         assertEquals(
                 Arrays.asList("test_feature", "test_prediction"),
                 output.getResolvedSchema().getColumnNames());
-        Map<DenseVector, Integer> clusterIdByPoints =
+        List<Set<DenseVector>> actualGroups =
                 executeAndCollect(output, kmeans.getFeaturesCol(), 
kmeans.getPredictionCol());
-        verifyClusteringResult(
-                clusterIdByPoints, Arrays.asList(Arrays.asList(0, 1, 2), 
Arrays.asList(3, 4, 5)));
+        assertTrue(CollectionUtils.isEqualCollection(expectedGroups, 
actualGroups));
     }
 
     @Test
@@ -176,10 +176,11 @@ public class KMeansTest extends AbstractTestBase {
         KMeans kmeans = new KMeans().setK(2);
         KMeansModel model = kmeans.fit(input);
         Table output = model.transform(input)[0];
-
-        Map<DenseVector, Integer> clusterIdByPoints =
+        List<Set<DenseVector>> expectedGroups =
+                
Collections.singletonList(Collections.singleton(Vectors.dense(0.0, 0.1)));
+        List<Set<DenseVector>> actualGroups =
                 executeAndCollect(output, kmeans.getFeaturesCol(), 
kmeans.getPredictionCol());
-        assertEquals(Collections.singleton(0), new 
HashSet<>(clusterIdByPoints.values()));
+        assertTrue(CollectionUtils.isEqualCollection(expectedGroups, 
actualGroups));
     }
 
     @Test
@@ -191,23 +192,22 @@ public class KMeansTest extends AbstractTestBase {
         assertEquals(
                 Arrays.asList("features", "prediction"),
                 output.getResolvedSchema().getColumnNames());
-        Map<DenseVector, Integer> clusterIdByPoints =
+        List<Set<DenseVector>> actualGroups =
                 executeAndCollect(output, kmeans.getFeaturesCol(), 
kmeans.getPredictionCol());
-        verifyClusteringResult(
-                clusterIdByPoints, Arrays.asList(Arrays.asList(0, 1, 2), 
Arrays.asList(3, 4, 5)));
+        assertTrue(CollectionUtils.isEqualCollection(expectedGroups, 
actualGroups));
     }
 
     @Test
     public void testSaveLoadAndPredict() throws Exception {
-        String path = Files.createTempDirectory("").toString();
-
         KMeans kmeans = new KMeans().setMaxIter(2).setK(2);
-        KMeansModel model = kmeans.fit(dataTable);
 
-        model.save(path);
-        env.execute();
+        KMeans loadedKmeans =
+                StageTestUtils.saveAndReload(env, kmeans, 
tempFolder.newFolder().getAbsolutePath());
+
+        KMeansModel model = loadedKmeans.fit(dataTable);
 
-        KMeansModel loadedModel = KMeansModel.load(env, path);
+        KMeansModel loadedModel =
+                StageTestUtils.saveAndReload(env, model, 
tempFolder.newFolder().getAbsolutePath());
         Table output = loadedModel.transform(dataTable)[0];
 
         assertEquals(
@@ -216,10 +216,10 @@ public class KMeansTest extends AbstractTestBase {
         assertEquals(
                 Arrays.asList("features", "prediction"),
                 output.getResolvedSchema().getColumnNames());
-        Map<DenseVector, Integer> clusterIdByPoints =
+
+        List<Set<DenseVector>> actualGroups =
                 executeAndCollect(output, kmeans.getFeaturesCol(), 
kmeans.getPredictionCol());
-        verifyClusteringResult(
-                clusterIdByPoints, Arrays.asList(Arrays.asList(0, 1, 2), 
Arrays.asList(3, 4, 5)));
+        assertTrue(CollectionUtils.isEqualCollection(expectedGroups, 
actualGroups));
     }
 
     @Test
@@ -250,10 +250,8 @@ public class KMeansTest extends AbstractTestBase {
         ReadWriteUtils.updateExistingParams(modelB, modelA.getParamMap());
 
         Table output = modelB.transform(dataTable)[0];
-        Map<DenseVector, Integer> clusterIdByPoints =
+        List<Set<DenseVector>> actualGroups =
                 executeAndCollect(output, kmeans.getFeaturesCol(), 
kmeans.getPredictionCol());
-
-        verifyClusteringResult(
-                clusterIdByPoints, Arrays.asList(Arrays.asList(0, 1, 2), 
Arrays.asList(3, 4, 5)));
+        assertTrue(CollectionUtils.isEqualCollection(expectedGroups, 
actualGroups));
     }
 }
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/util/StageTestUtils.java 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/util/StageTestUtils.java
new file mode 100644
index 0000000..27283e8
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/util/StageTestUtils.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.util;
+
+import org.apache.flink.ml.api.Stage;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+
+import java.lang.reflect.Method;
+
+/** Utility methods for testing stages. */
+public class StageTestUtils {
+    /** Message of the exception thrown when {@code env.execute()} finds an empty topology. */
+    private static final String NO_OPERATORS_MESSAGE =
+            "No operators defined in streaming topology. Cannot execute.";
+
+    /**
+     * Saves a stage to filesystem and reloads it by invoking the static load() method of the given
+     * stage.
+     *
+     * <p>{@code env.execute()} is invoked after {@link Stage#save} so that any operators the save
+     * added to {@code env} are run. If the save added none (presumably the stage only wrote
+     * metadata), the resulting "no operators" error is ignored; every other failure is rethrown.
+     *
+     * @param env The environment used to run the save job and to load the stage
+     * @param stage The stage to save and reload
+     * @param path The directory to save the stage to and load it from
+     * @param <T> The type of the stage; it must declare a public static {@code
+     *     load(StreamExecutionEnvironment, String)} method
+     * @return The stage loaded from {@code path}
+     * @throws Exception if saving, executing the save job, or loading fails
+     */
+    @SuppressWarnings("unchecked") // load() is expected to return an instance of the stage class
+    public static <T extends Stage<T>> T saveAndReload(
+            StreamExecutionEnvironment env, T stage, String path) throws Exception {
+        stage.save(path);
+        try {
+            env.execute();
+        } catch (RuntimeException e) {
+            // Compare constant-first: an exception without a message must be rethrown rather
+            // than causing a NullPointerException that hides the original failure.
+            if (!NO_OPERATORS_MESSAGE.equals(e.getMessage())) {
+                throw e;
+            }
+        }
+
+        Method method =
+                stage.getClass().getMethod("load", StreamExecutionEnvironment.class, String.class);
+        return (T) method.invoke(null, env, path);
+    }
+}

Reply via email to