yunfengzhou-hub commented on code in PR #174:
URL: https://github.com/apache/flink-ml/pull/174#discussion_r1025077674


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java:
##########
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.countvectorizer;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+
+/** A Model which transforms data using the model data computed by {@link 
CountVectorizer}. */
+public class CountVectorizerModel
+        implements Model<CountVectorizerModel>, 
CountVectorizerModelParams<CountVectorizerModel> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public CountVectorizerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public CountVectorizerModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                CountVectorizerModelData.getModelDataStream(modelDataTable),
+                path,
+                new CountVectorizerModelData.ModelDataEncoder());
+    }
+
+    public static CountVectorizerModel load(StreamTableEnvironment tEnv, 
String path)
+            throws IOException {
+        CountVectorizerModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(
+                        tEnv, path, new 
CountVectorizerModelData.ModelDataDecoder());
+        return model.setModelData(modelDataTable);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+        DataStream<CountVectorizerModelData> countVectorizerModel =

Review Comment:
   nit: this variable holds model data (a model data stream), not a model. 
Let's rename it (e.g. to `modelDataStream`) to avoid ambiguity.



##########
docs/content/docs/operators/feature/countvectorizer.md:
##########
@@ -0,0 +1,182 @@
+---
+title: "Count Vectorizer"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/countvectorizer.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+  http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## Count Vectorizer
+
+CountVectorizer aims to help convert a collection of text documents to
+vectors of token counts. When an a-priori dictionary is not available,
+CountVectorizer can be used as an estimator to extract the vocabulary,
+and generates a CountVectorizerModel. The model produces sparse
+representations for the documents over the vocabulary, which can then
+be passed to other algorithms like LDA.
+
+### Input Columns
+
+| Param name | Type     | Default   | Description         |
+|:-----------|:---------|:----------|:--------------------|
+| inputCol   | String[] | `"input"` | Input string array. |
+
+### Output Columns
+
+| Param name | Type         | Default    | Description             |
+|:-----------|:-------------|:-----------|:------------------------|
+| outputCol  | SparseVector | `"output"` | Vector of token counts. |
+
+### Parameters
+
+Below are the parameters required by `CountVectorizerModel`.
+
+| Key        | Default    | Type    | Required | Description                   
                                                                                
                                                                                
                                                                                
                                                                  |
+|------------|------------|---------|----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| inputCol   | `"input"`  | String  | no       | Input column name.            
                                                                                
                                                                                
                                                                                
                                                                  |
+| outputCol  | `"output"` | String  | no       | Output column name.           
                                                                                
                                                                                
                                                                                
                                                                  |
+| minTF      | `1.0`      | Double  | no       | Filter to ignore rare words 
in a document. For each document, terms with frequency/count less than the 
given threshold are ignored. If this is an integer >= 1, then this specifies a 
count (of times the term must appear in the document); if this is a double in 
[0,1), then this specifies a fraction (out of the document's token count).  |
+| binary     | `false`    | Boolean | no       | Binary toggle to control the 
output vector values. If True, all nonzero counts (after minTF filter applied) 
are set to 1.0.                                                                 
                                                                                
                                                                    |
+
+`CountVectorizer` needs the parameters above as well as the parameters below.
+
+| Key            | Default    | Type     | Required | Description              
                                                                                
                                                                                
                                                                                
                                                                                
                                    |
+|:---------------|:-----------|:---------|:---------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| vocabularySize | `2^18`     | Integer  | no       | Max size of the 
vocabulary. CountVectorizer will build a vocabulary that only considers the top 
vocabulary size terms ordered by term frequency across the corpus.              
                                                                                
                                                                                
                                             |
+| minDF          | `1.0`      | Double   | no       | Specifies the minimum 
number of different documents a term must appear in to be included in the 
vocabulary. If this is an integer >= 1, this specifies the number of documents 
the term must appear in; if this is a double in [0,1), then this specifies the 
fraction of documents.                                                          
                                               |
+| maxDF          | `2^63 - 1` | Double   | no       | Specifies the maximum 
number of different documents a term could appear in to be included in the 
vocabulary. A term that appears more than the threshold will be ignored. If 
this is an integer >= 1, this specifies the maximum number of documents the 
term could appear in; if this is a double in [0,1), then this specifies the 
maximum fraction of documents the term could appear in. |
+
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+
+```java
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizer;
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+
+/**
+ * Simple program that trains a {@link CountVectorizer} model and uses it for 
feature engineering.
+ */
+public class CountVectorizerExample {
+
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input training and prediction data.
+        DataStream<Row> trainStream =
+                env.fromElements(
+                        Row.of((Object) new String[] {"a", "c", "b", "c"}),
+                        Row.of((Object) new String[] {"c", "d", "e"}),
+                        Row.of((Object) new String[] {"a", "b", "c"}),
+                        Row.of((Object) new String[] {"e", "f"}),
+                        Row.of((Object) new String[] {"a", "c", "a"}));
+        Table trainTable = tEnv.fromDataStream(trainStream).as("input");

Review Comment:
   This variable contains both training and prediction data. It might be better 
to rename this variable to `inputTable`. The same applies to the Python example.



##########
docs/content/docs/operators/feature/countvectorizer.md:
##########
@@ -0,0 +1,182 @@
+---
+title: "Count Vectorizer"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/countvectorizer.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+  http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## Count Vectorizer
+
+CountVectorizer aims to help convert a collection of text documents to
+vectors of token counts. When an a-priori dictionary is not available,
+CountVectorizer can be used as an estimator to extract the vocabulary,
+and generates a CountVectorizerModel. The model produces sparse
+representations for the documents over the vocabulary, which can then
+be passed to other algorithms like LDA.
+
+### Input Columns
+
+| Param name | Type     | Default   | Description         |
+|:-----------|:---------|:----------|:--------------------|
+| inputCol   | String[] | `"input"` | Input string array. |
+
+### Output Columns
+
+| Param name | Type         | Default    | Description             |
+|:-----------|:-------------|:-----------|:------------------------|
+| outputCol  | SparseVector | `"output"` | Vector of token counts. |
+
+### Parameters
+
+Below are the parameters required by `CountVectorizerModel`.
+
+| Key        | Default    | Type    | Required | Description                   
                                                                                
                                                                                
                                                                                
                                                                  |
+|------------|------------|---------|----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| inputCol   | `"input"`  | String  | no       | Input column name.            
                                                                                
                                                                                
                                                                                
                                                                  |
+| outputCol  | `"output"` | String  | no       | Output column name.           
                                                                                
                                                                                
                                                                                
                                                                  |
+| minTF      | `1.0`      | Double  | no       | Filter to ignore rare words 
in a document. For each document, terms with frequency/count less than the 
given threshold are ignored. If this is an integer >= 1, then this specifies a 
count (of times the term must appear in the document); if this is a double in 
[0,1), then this specifies a fraction (out of the document's token count).  |

Review Comment:
   In fact, users can set `minTF` to any double >= 1, in which case the 
specified count is `Math.floor(minTF)`. Do you think it would be better to 
clarify that when `minTF >= 1`, it does not have to be an integer? Same for 
`minDF` and `maxDF`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizer.java:
##########
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.countvectorizer;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * {@link CountVectorizer} aims to help convert a collection of text documents 
to vectors of token
+ * counts. When an a-priori dictionary is not available, {@link 
CountVectorizer} can be used as an
+ * estimator to extract the vocabulary, and generates a {@link 
CountVectorizerModel}. The model
+ * produces sparse representations for the documents over the vocabulary, 
which can then be passed
+ * to other algorithms like LDA.

Review Comment:
   Let's modify the description so that its style aligns with the JavaDoc of 
other algorithms. 
   - The document should provide a description of the algorithm, not describe 
the Estimator and the Model separately.
   - It is better to start the document with a noun or definition, like 
"CountVectorizer is an algorithm that ..."
   - Let's not use `@link` when referring to the class itself.



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java:
##########
@@ -0,0 +1,390 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizer;
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.DoubleStream;
+import java.util.stream.IntStream;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/** Tests {@link CountVectorizer} and {@link CountVectorizerModel}. */
+public class CountVectorizerTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainDataTable;
+
+    private static final double EPS = 1.0e-5;
+    private static final List<Row> TRAIN_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of((Object) new String[] {"a", "c", "b", "c"}),
+                            Row.of((Object) new String[] {"c", "d", "e"}),
+                            Row.of((Object) new String[] {"a", "b", "c"}),
+                            Row.of((Object) new String[] {"e", "f"}),
+                            Row.of((Object) new String[] {"a", "c", "a"})));
+
+    private static final List<SparseVector> EXPECTED_OUTPUT =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Vectors.sparse(
+                                    6,
+                                    IntStream.of(0, 1, 2).toArray(),
+                                    DoubleStream.of(2.0, 1.0, 1.0).toArray()),
+                            Vectors.sparse(
+                                    6,
+                                    IntStream.of(0, 3, 4).toArray(),
+                                    DoubleStream.of(1.0, 1.0, 1.0).toArray()),
+                            Vectors.sparse(
+                                    6,
+                                    IntStream.of(0, 1, 2).toArray(),
+                                    DoubleStream.of(1.0, 1.0, 1.0).toArray()),
+                            Vectors.sparse(
+                                    6,
+                                    IntStream.of(3, 5).toArray(),
+                                    DoubleStream.of(1.0, 1.0).toArray()),
+                            Vectors.sparse(
+                                    6,
+                                    IntStream.of(0, 1).toArray(),
+                                    DoubleStream.of(1.0, 2.0).toArray())));
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        trainDataTable = 
tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("input");
+    }
+
+    private static void verifyPredictionResult(
+            Table output, String outputCol, List<SparseVector> expected) 
throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
output).getTableEnvironment();
+        DataStream<SparseVector> stream =
+                tEnv.toDataStream(output)
+                        .map(
+                                (MapFunction<Row, SparseVector>)
+                                        row -> (SparseVector) 
row.getField(outputCol));
+        List<SparseVector> result = 
IteratorUtils.toList(stream.executeAndCollect());
+        compareResultCollections(expected, result, TestUtils::compare);
+    }
+
+    @Test
+    public void testParam() {
+        CountVectorizer countVectorizer = new CountVectorizer();
+        assertEquals("input", countVectorizer.getInputCol());
+        assertEquals("output", countVectorizer.getOutputCol());
+        assertEquals((double) Long.MAX_VALUE, countVectorizer.getMaxDF(), EPS);
+        assertEquals(1.0, countVectorizer.getMinDF(), EPS);
+        assertEquals(1.0, countVectorizer.getMinTF(), EPS);
+        assertEquals(1 << 18, countVectorizer.getVocabularySize());
+        assertFalse(countVectorizer.getBinary());
+
+        countVectorizer
+                .setInputCol("test_input")
+                .setOutputCol("test_output")
+                .setMinDF(0.1)
+                .setMaxDF(0.9)
+                .setMinTF(10)
+                .setVocabularySize(1000)
+                .setBinary(true);
+        assertEquals("test_input", countVectorizer.getInputCol());
+        assertEquals("test_output", countVectorizer.getOutputCol());
+        assertEquals(0.9, countVectorizer.getMaxDF(), EPS);
+        assertEquals(0.1, countVectorizer.getMinDF(), EPS);
+        assertEquals(10, countVectorizer.getMinTF(), EPS);
+        assertEquals(1000, countVectorizer.getVocabularySize());
+        assertTrue(countVectorizer.getBinary());
+    }
+
+    @Test
+    public void testInvalidMinMaxDF() {
+        String errMessage = "maxDF must be >= minDF.";
+        CountVectorizer countVectorizer = new CountVectorizer();
+        countVectorizer.setMaxDF(0.1);
+        countVectorizer.setMinDF(0.2);
+        try {
+            countVectorizer.fit(trainDataTable);
+            fail();
+        } catch (Throwable e) {
+            assertEquals(errMessage, e.getMessage());
+        }
+        countVectorizer.setMaxDF(1);
+        countVectorizer.setMinDF(2);
+        try {
+            countVectorizer.fit(trainDataTable);
+            fail();
+        } catch (Throwable e) {
+            assertEquals(errMessage, e.getMessage());
+        }
+        countVectorizer.setMaxDF(1);
+        countVectorizer.setMinDF(0.9);
+        try {
+            CountVectorizerModel model = countVectorizer.fit(trainDataTable);
+            Table output = model.transform(trainDataTable)[0];
+            output.execute().print();
+            fail();
+        } catch (Throwable e) {
+            assertEquals(errMessage, 
ExceptionUtils.getRootCause(e).getMessage());
+        }
+        countVectorizer.setMaxDF(0.1);
+        countVectorizer.setMinDF(10);
+        try {
+            CountVectorizerModel model = countVectorizer.fit(trainDataTable);
+            Table output = model.transform(trainDataTable)[0];
+            output.execute().print();
+            fail();
+        } catch (Throwable e) {
+            assertEquals(errMessage, 
ExceptionUtils.getRootCause(e).getMessage());
+        }
+    }
+
+    @Test
+    public void testOutputSchema() {
+        CountVectorizer countVectorizer =
+                new 
CountVectorizer().setInputCol("test_input").setOutputCol("test_output");
+        CountVectorizerModel model = countVectorizer.fit(trainDataTable);
+        Table output = model.transform(trainDataTable.as("test_input"))[0];
+        assertEquals(
+                Arrays.asList("test_input", "test_output"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        CountVectorizer countVectorizer = new CountVectorizer();
+        CountVectorizerModel model = countVectorizer.fit(trainDataTable);
+        Table output = model.transform(trainDataTable)[0];
+
+        verifyPredictionResult(output, countVectorizer.getOutputCol(), 
EXPECTED_OUTPUT);
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        CountVectorizer countVectorizer = new CountVectorizer();
+        CountVectorizer loadedCountVectorizer =
+                TestUtils.saveAndReload(
+                        tEnv, countVectorizer, 
tempFolder.newFolder().getAbsolutePath());
+        CountVectorizerModel model = loadedCountVectorizer.fit(trainDataTable);
+        CountVectorizerModel loadedModel =
+                TestUtils.saveAndReload(tEnv, model, 
tempFolder.newFolder().getAbsolutePath());
+        assertEquals(
+                Arrays.asList("vocabulary"),
+                
loadedModel.getModelData()[0].getResolvedSchema().getColumnNames());
+        Table output = loadedModel.transform(trainDataTable)[0];
+        verifyPredictionResult(output, countVectorizer.getOutputCol(), 
EXPECTED_OUTPUT);
+    }
+
+    @Test
+    public void testFitOnEmptyData() {
+        Table emptyTable =
+                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA).filter(x -> 
x.getArity() == 0))
+                        .as("input");
+        CountVectorizer countVectorizer = new CountVectorizer();
+        CountVectorizerModel model = countVectorizer.fit(emptyTable);
+        Table modelDataTable = model.getModelData()[0];
+        try {
+            modelDataTable.execute().print();
+            fail();
+        } catch (Throwable e) {
+            assertEquals("The training set is empty.", 
ExceptionUtils.getRootCause(e).getMessage());
+        }
+    }
+
+    /**
+     * Checks that the document-frequency bounds select the same vocabulary whether they are given
+     * as absolute document counts (2 and 4) or as fractions (0.4 and 0.8) of the five training
+     * rows.
+     */
+    @Test
+    public void testMinMaxDF() throws Exception {
+        CountVectorizer countVectorizer = new CountVectorizer().setMinDF(2).setMaxDF(4);
+        List<SparseVector> expectedOutput =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Vectors.sparse(
+                                        4, new int[] {0, 1, 2}, new double[] {2.0, 1.0, 1.0}),
+                                Vectors.sparse(4, new int[] {0, 3}, new double[] {1.0, 1.0}),
+                                Vectors.sparse(
+                                        4, new int[] {0, 1, 2}, new double[] {1.0, 1.0, 1.0}),
+                                Vectors.sparse(4, new int[] {3}, new double[] {1.0}),
+                                Vectors.sparse(4, new int[] {0, 1}, new double[] {1.0, 2.0})));
+
+        // Bounds expressed as absolute document counts.
+        Table outputWithCounts =
+                countVectorizer.fit(trainDataTable).transform(trainDataTable)[0];
+        verifyPredictionResult(outputWithCounts, countVectorizer.getOutputCol(), expectedOutput);
+
+        // The same bounds expressed as fractions of the number of training rows.
+        countVectorizer.setMinDF(0.4).setMaxDF(0.8);
+        Table outputWithFractions =
+                countVectorizer.fit(trainDataTable).transform(trainDataTable)[0];
+        verifyPredictionResult(outputWithFractions, countVectorizer.getOutputCol(), expectedOutput);
+    }
+
+    /** Checks that terms whose in-document frequency falls below minTF are dropped. */
+    @Test
+    public void testMinTF() throws Exception {
+        // NOTE(review): a minTF below 1.0 appears to be interpreted as a fraction of each
+        // document's token count; confirm against the CountVectorizer parameter docs.
+        CountVectorizer countVectorizer = new CountVectorizer().setMinTF(0.5);
+        Table output = countVectorizer.fit(trainDataTable).transform(trainDataTable)[0];
+
+        List<SparseVector> expectedOutput = new ArrayList<>();
+        expectedOutput.add(Vectors.sparse(6, new int[] {0}, new double[] {2.0}));
+        expectedOutput.add(Vectors.sparse(6, new int[0], new double[0]));
+        expectedOutput.add(Vectors.sparse(6, new int[0], new double[0]));
+        expectedOutput.add(Vectors.sparse(6, new int[] {3, 5}, new double[] {1.0, 1.0}));
+        expectedOutput.add(Vectors.sparse(6, new int[] {1}, new double[] {2.0}));
+
+        verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
+    }
+
+    /** Checks that with binary enabled every non-zero term count is reported as 1.0. */
+    @Test
+    public void testBinary() throws Exception {
+        CountVectorizer countVectorizer = new CountVectorizer().setBinary(true);
+        Table output = countVectorizer.fit(trainDataTable).transform(trainDataTable)[0];
+
+        List<SparseVector> expectedOutput = new ArrayList<>();
+        expectedOutput.add(Vectors.sparse(6, new int[] {0, 1, 2}, new double[] {1.0, 1.0, 1.0}));
+        expectedOutput.add(Vectors.sparse(6, new int[] {0, 3, 4}, new double[] {1.0, 1.0, 1.0}));
+        expectedOutput.add(Vectors.sparse(6, new int[] {0, 1, 2}, new double[] {1.0, 1.0, 1.0}));
+        expectedOutput.add(Vectors.sparse(6, new int[] {3, 5}, new double[] {1.0, 1.0}));
+        expectedOutput.add(Vectors.sparse(6, new int[] {0, 1}, new double[] {1.0, 1.0}));
+
+        verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
+    }
+
+    /**
+     * Checks that capping the vocabulary at two terms yields two-dimensional vectors; a row that
+     * contains none of the retained terms maps to an empty vector.
+     */
+    @Test
+    public void testVocabularySize() throws Exception {
+        CountVectorizer countVectorizer = new CountVectorizer().setVocabularySize(2);
+        Table output = countVectorizer.fit(trainDataTable).transform(trainDataTable)[0];
+
+        List<SparseVector> expectedOutput = new ArrayList<>();
+        expectedOutput.add(Vectors.sparse(2, new int[] {0, 1}, new double[] {2.0, 1.0}));
+        expectedOutput.add(Vectors.sparse(2, new int[] {0}, new double[] {1.0}));
+        expectedOutput.add(Vectors.sparse(2, new int[] {0, 1}, new double[] {1.0, 1.0}));
+        expectedOutput.add(Vectors.sparse(2, new int[0], new double[0]));
+        expectedOutput.add(Vectors.sparse(2, new int[] {0, 1}, new double[] {1.0, 2.0}));
+
+        verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
+    }
+
+    @Test
+    public void testGetModelData() throws Exception {
+        CountVectorizer countVectorizer = new CountVectorizer();
+        CountVectorizerModel model = countVectorizer.fit(trainDataTable);
+        Table modelData = model.getModelData()[0];
+        assertEquals(Arrays.asList("vocabulary"), 
modelData.getResolvedSchema().getColumnNames());
+
+        DataStream<Row> output = tEnv.toDataStream(modelData);
+        List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
+        String[] vocabulary = (String[]) modelRows.get(0).getField(0);
+        assert vocabulary != null;

Review Comment:
   This line seems redundant.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to