yunfengzhou-hub commented on code in PR #174: URL: https://github.com/apache/flink-ml/pull/174#discussion_r1025077674
########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java: ########## @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.countvectorizer; + +import org.apache.flink.api.common.functions.RichMapFunction; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Model; +import org.apache.flink.ml.common.broadcast.BroadcastUtils; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import org.apache.commons.lang3.ArrayUtils; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; + +/** A Model which transforms data using the model data computed by {@link CountVectorizer}. */ +public class CountVectorizerModel + implements Model<CountVectorizerModel>, CountVectorizerModelParams<CountVectorizerModel> { + + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private Table modelDataTable; + + public CountVectorizerModel() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public CountVectorizerModel setModelData(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + modelDataTable = inputs[0]; + return this; + } + + @Override + public Table[] getModelData() { + return new Table[] {modelDataTable}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + ReadWriteUtils.saveModelData( + CountVectorizerModelData.getModelDataStream(modelDataTable), + path, + new CountVectorizerModelData.ModelDataEncoder()); + } + + public static CountVectorizerModel load(StreamTableEnvironment tEnv, String path) + throws IOException { + CountVectorizerModel model = ReadWriteUtils.loadStageParam(path); + Table modelDataTable = + ReadWriteUtils.loadModelData( + tEnv, path, new CountVectorizerModelData.ModelDataDecoder()); + return model.setModelData(modelDataTable); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]); + DataStream<CountVectorizerModelData> countVectorizerModel = Review Comment: nit: this variable is a "model data" or "model data stream", not a model. Let's change its name to avoid ambiguity. ########## docs/content/docs/operators/feature/countvectorizer.md: ########## @@ -0,0 +1,182 @@ +--- +title: "Count Vectorizer" +weight: 1 +type: docs +aliases: +- /operators/feature/countvectorizer.html +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions dand limitations +under the License. +--> + +## Count Vectorizer + +CountVectorizer aims to help convert a collection of text documents to +vectors of token counts. When an a-priori dictionary is not available, +CountVectorizer can be used as an estimator to extract the vocabulary, +and generates a CountVectorizerModel. The model produces sparse +representations for the documents over the vocabulary, which can then +be passed to other algorithms like LDA. + +### Input Columns + +| Param name | Type | Default | Description | +|:-----------|:---------|:----------|:--------------------| +| inputCol | String[] | `"input"` | Input string array. | + +### Output Columns + +| Param name | Type | Default | Description | +|:-----------|:-------------|:-----------|:------------------------| +| outputCol | SparseVector | `"output"` | Vector of token counts. | + +### Parameters + +Below are the parameters required by `CountVectorizerModel`. + +| Key | Default | Type | Required | Description | +|------------|------------|---------|----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| inputCol | `"input"` | String | no | Input column name. | +| outputCol | `"output"` | String | no | Output column name. | +| minTF | `1.0` | Double | no | Filter to ignore rare words in a document. For each document, terms with frequency/count less than the given threshold are ignored. If this is an integer >= 1, then this specifies a count (of times the term must appear in the document); if this is a double in [0,1), then this specifies a fraction (out of the document's token count). | +| binary | `false` | Boolean | no | Binary toggle to control the output vector values. If True, all nonzero counts (after minTF filter applied) are set to 1.0. | + +`CountVectorizer` needs parameters above and also below. + +| Key | Default | Type | Required | Description | +|:---------------|:-----------|:---------|:---------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| vocabularySize | `2^18` | Integer | no | Max size of the vocabulary. CountVectorizer will build a vocabulary that only considers the top vocabulary size terms ordered by term frequency across the corpus. | +| minDF | `1.0` | Double | no | Specifies the minimum number of different documents a term must appear in to be included in the vocabulary. If this is an integer >= 1, this specifies the number of documents the term must appear in; if this is a double in [0,1), then this specifies the fraction of documents. | +| maxDF | `2^63 - 1` | Double | no | Specifies the maximum number of different documents a term could appear in to be included in the vocabulary. A term that appears more than the threshold will be ignored. If this is an integer >= 1, this specifies the maximum number of documents the term could appear in; if this is a double in [0,1), then this specifies the maximum fraction of documents the term could appear in. | + +### Examples + +{{< tabs examples >}} + +{{< tab "Java">}} + +```java +import org.apache.flink.ml.feature.countvectorizer.CountVectorizer; +import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +import java.util.Arrays; + +/** + * Simple program that trains a {@link CountVectorizer} model and uses it for feature engineering. + */ +public class CountVectorizerExample { + + public static void main(String[] args) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + // Generates input training and prediction data. + DataStream<Row> trainStream = + env.fromElements( + Row.of((Object) new String[] {"a", "c", "b", "c"}), + Row.of((Object) new String[] {"c", "d", "e"}), + Row.of((Object) new String[] {"a", "b", "c"}), + Row.of((Object) new String[] {"e", "f"}), + Row.of((Object) new String[] {"a", "c", "a"})); + Table trainTable = tEnv.fromDataStream(trainStream).as("input"); Review Comment: This variable contains both training and prediction data. It might be better to rename this variable as `inputTable`. Same for that in the python example. ########## docs/content/docs/operators/feature/countvectorizer.md: ########## @@ -0,0 +1,182 @@ +--- +title: "Count Vectorizer" +weight: 1 +type: docs +aliases: +- /operators/feature/countvectorizer.html +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions dand limitations +under the License. +--> + +## Count Vectorizer + +CountVectorizer aims to help convert a collection of text documents to +vectors of token counts. When an a-priori dictionary is not available, +CountVectorizer can be used as an estimator to extract the vocabulary, +and generates a CountVectorizerModel. The model produces sparse +representations for the documents over the vocabulary, which can then +be passed to other algorithms like LDA. + +### Input Columns + +| Param name | Type | Default | Description | +|:-----------|:---------|:----------|:--------------------| +| inputCol | String[] | `"input"` | Input string array. | + +### Output Columns + +| Param name | Type | Default | Description | +|:-----------|:-------------|:-----------|:------------------------| +| outputCol | SparseVector | `"output"` | Vector of token counts. | + +### Parameters + +Below are the parameters required by `CountVectorizerModel`. + +| Key | Default | Type | Required | Description | +|------------|------------|---------|----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| inputCol | `"input"` | String | no | Input column name. | +| outputCol | `"output"` | String | no | Output column name. | +| minTF | `1.0` | Double | no | Filter to ignore rare words in a document. For each document, terms with frequency/count less than the given threshold are ignored. If this is an integer >= 1, then this specifies a count (of times the term must appear in the document); if this is a double in [0,1), then this specifies a fraction (out of the document's token count). | Review Comment: In fact, users can set the value of `minTF` to any double >= 1, in which case the specified count is `Math.floor(minTF)`. Do you think it would be better if we clarify that if `minTF >= 1`, it does not have to be an integer? Save for `minDF` and `maxDF`. ########## flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizer.java: ########## @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.countvectorizer; + +import org.apache.flink.api.common.functions.AggregateFunction; +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.tuple.Tuple2; +import org.apache.flink.ml.api.Estimator; +import org.apache.flink.ml.common.datastream.DataStreamUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; + +/** + * {@link CountVectorizer} aims to help convert a collection of text documents to vectors of token + * counts. When an a-priori dictionary is not available, {@link CountVectorizer} can be used as an + * estimator to extract the vocabulary, and generates a {@link CountVectorizerModel}. The model + * produces sparse representations for the documents over the vocabulary, which can then be passed + * to other algorithms like LDA. Review Comment: Let's modify the description so that it aligns with the JavaDoc of other algorithms in style. - The document should provide a description of the algorithm, not describe the Estimator and the Model separately. - It is better to start the document with a noun or definition, like "CountVectorizer is an algorithm that ..." - Let's not use `@link` when referring to the class itself. ########## flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java: ########## @@ -0,0 +1,390 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.countvectorizer.CountVectorizer; +import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel; +import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vectors; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.apache.commons.lang3.exception.ExceptionUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.stream.DoubleStream; +import java.util.stream.IntStream; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +/** Tests {@link CountVectorizer} and {@link CountVectorizerModel}. */ +public class CountVectorizerTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table trainDataTable; + + private static final double EPS = 1.0e-5; + private static final List<Row> TRAIN_DATA = + new ArrayList<>( + Arrays.asList( + Row.of((Object) new String[] {"a", "c", "b", "c"}), + Row.of((Object) new String[] {"c", "d", "e"}), + Row.of((Object) new String[] {"a", "b", "c"}), + Row.of((Object) new String[] {"e", "f"}), + Row.of((Object) new String[] {"a", "c", "a"}))); + + private static final List<SparseVector> EXPECTED_OUTPUT = + new ArrayList<>( + Arrays.asList( + Vectors.sparse( + 6, + IntStream.of(0, 1, 2).toArray(), + DoubleStream.of(2.0, 1.0, 1.0).toArray()), + Vectors.sparse( + 6, + IntStream.of(0, 3, 4).toArray(), + DoubleStream.of(1.0, 1.0, 1.0).toArray()), + Vectors.sparse( + 6, + IntStream.of(0, 1, 2).toArray(), + DoubleStream.of(1.0, 1.0, 1.0).toArray()), + Vectors.sparse( + 6, + IntStream.of(3, 5).toArray(), + DoubleStream.of(1.0, 1.0).toArray()), + Vectors.sparse( + 6, + IntStream.of(0, 1).toArray(), + DoubleStream.of(1.0, 2.0).toArray()))); + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + + trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("input"); + } + + private static void verifyPredictionResult( + Table output, String outputCol, List<SparseVector> expected) throws Exception { + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment(); + DataStream<SparseVector> stream = + tEnv.toDataStream(output) + .map( + (MapFunction<Row, SparseVector>) + row -> (SparseVector) row.getField(outputCol)); + List<SparseVector> result = IteratorUtils.toList(stream.executeAndCollect()); + compareResultCollections(expected, result, TestUtils::compare); + } + + @Test + public void testParam() { + CountVectorizer countVectorizer = new CountVectorizer(); + assertEquals("input", countVectorizer.getInputCol()); + assertEquals("output", countVectorizer.getOutputCol()); + assertEquals((double) Long.MAX_VALUE, countVectorizer.getMaxDF(), EPS); + assertEquals(1.0, countVectorizer.getMinDF(), EPS); + assertEquals(1.0, countVectorizer.getMinTF(), EPS); + assertEquals(1 << 18, countVectorizer.getVocabularySize()); + assertFalse(countVectorizer.getBinary()); + + countVectorizer + .setInputCol("test_input") + .setOutputCol("test_output") + .setMinDF(0.1) + .setMaxDF(0.9) + .setMinTF(10) + .setVocabularySize(1000) + .setBinary(true); + assertEquals("test_input", countVectorizer.getInputCol()); + assertEquals("test_output", countVectorizer.getOutputCol()); + assertEquals(0.9, countVectorizer.getMaxDF(), EPS); + assertEquals(0.1, countVectorizer.getMinDF(), EPS); + assertEquals(10, countVectorizer.getMinTF(), EPS); + assertEquals(1000, countVectorizer.getVocabularySize()); + assertTrue(countVectorizer.getBinary()); + } + + @Test + public void testInvalidMinMaxDF() { + String errMessage = "maxDF must be >= minDF."; + CountVectorizer countVectorizer = new CountVectorizer(); + countVectorizer.setMaxDF(0.1); + countVectorizer.setMinDF(0.2); + try { + countVectorizer.fit(trainDataTable); + fail(); + } catch (Throwable e) { + assertEquals(errMessage, e.getMessage()); + } + countVectorizer.setMaxDF(1); + countVectorizer.setMinDF(2); + try { + countVectorizer.fit(trainDataTable); + fail(); + } catch (Throwable e) { + assertEquals(errMessage, e.getMessage()); + } + countVectorizer.setMaxDF(1); + countVectorizer.setMinDF(0.9); + try { + CountVectorizerModel model = countVectorizer.fit(trainDataTable); + Table output = model.transform(trainDataTable)[0]; + output.execute().print(); + fail(); + } catch (Throwable e) { + assertEquals(errMessage, ExceptionUtils.getRootCause(e).getMessage()); + } + countVectorizer.setMaxDF(0.1); + countVectorizer.setMinDF(10); + try { + CountVectorizerModel model = countVectorizer.fit(trainDataTable); + Table output = model.transform(trainDataTable)[0]; + output.execute().print(); + fail(); + } catch (Throwable e) { + assertEquals(errMessage, ExceptionUtils.getRootCause(e).getMessage()); + } + } + + @Test + public void testOutputSchema() { + CountVectorizer countVectorizer = + new CountVectorizer().setInputCol("test_input").setOutputCol("test_output"); + CountVectorizerModel model = countVectorizer.fit(trainDataTable); + Table output = model.transform(trainDataTable.as("test_input"))[0]; + assertEquals( + Arrays.asList("test_input", "test_output"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + public void testFitAndPredict() throws Exception { + CountVectorizer countVectorizer = new CountVectorizer(); + CountVectorizerModel model = countVectorizer.fit(trainDataTable); + Table output = model.transform(trainDataTable)[0]; + + verifyPredictionResult(output, countVectorizer.getOutputCol(), EXPECTED_OUTPUT); + } + + @Test + public void testSaveLoadAndPredict() throws Exception { + CountVectorizer countVectorizer = new CountVectorizer(); + CountVectorizer loadedCountVectorizer = + TestUtils.saveAndReload( + tEnv, countVectorizer, tempFolder.newFolder().getAbsolutePath()); + CountVectorizerModel model = loadedCountVectorizer.fit(trainDataTable); + CountVectorizerModel loadedModel = + TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath()); + assertEquals( + Arrays.asList("vocabulary"), + loadedModel.getModelData()[0].getResolvedSchema().getColumnNames()); + Table output = loadedModel.transform(trainDataTable)[0]; + verifyPredictionResult(output, countVectorizer.getOutputCol(), EXPECTED_OUTPUT); + } + + @Test + public void testFitOnEmptyData() { + Table emptyTable = + tEnv.fromDataStream(env.fromCollection(TRAIN_DATA).filter(x -> x.getArity() == 0)) + .as("input"); + CountVectorizer countVectorizer = new CountVectorizer(); + CountVectorizerModel model = countVectorizer.fit(emptyTable); + Table modelDataTable = model.getModelData()[0]; + try { + modelDataTable.execute().print(); + fail(); + } catch (Throwable e) { + assertEquals("The training set is empty.", ExceptionUtils.getRootCause(e).getMessage()); + } + } + + @Test + public void testMinMaxDF() throws Exception { + List<SparseVector> expectedOutput = + new ArrayList<>( + Arrays.asList( + Vectors.sparse( + 4, + IntStream.of(0, 1, 2).toArray(), + DoubleStream.of(2.0, 1.0, 1.0).toArray()), + Vectors.sparse( + 4, + IntStream.of(0, 3).toArray(), + DoubleStream.of(1.0, 1.0).toArray()), + Vectors.sparse( + 4, + IntStream.of(0, 1, 2).toArray(), + DoubleStream.of(1.0, 1.0, 1.0).toArray()), + Vectors.sparse( + 4, + IntStream.of(3).toArray(), + DoubleStream.of(1.0).toArray()), + Vectors.sparse( + 4, + IntStream.of(0, 1).toArray(), + DoubleStream.of(1.0, 2.0).toArray()))); + CountVectorizer countVectorizer = new CountVectorizer().setMinDF(2).setMaxDF(4); + CountVectorizerModel model = countVectorizer.fit(trainDataTable); + Table output = model.transform(trainDataTable)[0]; + verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput); + + countVectorizer.setMinDF(0.4).setMaxDF(0.8); + model = countVectorizer.fit(trainDataTable); + output = model.transform(trainDataTable)[0]; + verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput); + } + + @Test + public void testMinTF() throws Exception { + List<SparseVector> expectedOutput = + new ArrayList<>( + Arrays.asList( + Vectors.sparse( + 6, + IntStream.of(0).toArray(), + DoubleStream.of(2.0).toArray()), + Vectors.sparse(6, new int[0], new double[0]), + Vectors.sparse(6, new int[0], new double[0]), + Vectors.sparse( + 6, + IntStream.of(3, 5).toArray(), + DoubleStream.of(1.0, 1.0).toArray()), + Vectors.sparse( + 6, + IntStream.of(1).toArray(), + DoubleStream.of(2.0).toArray()))); + CountVectorizer countVectorizer = new CountVectorizer().setMinTF(0.5); + CountVectorizerModel model = countVectorizer.fit(trainDataTable); + Table output = model.transform(trainDataTable)[0]; + verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput); + } + + @Test + public void testBinary() throws Exception { + List<SparseVector> expectedOutput = + new ArrayList<>( + Arrays.asList( + Vectors.sparse( + 6, + IntStream.of(0, 1, 2).toArray(), + DoubleStream.of(1.0, 1.0, 1.0).toArray()), + Vectors.sparse( + 6, + IntStream.of(0, 3, 4).toArray(), + DoubleStream.of(1.0, 1.0, 1.0).toArray()), + Vectors.sparse( + 6, + IntStream.of(0, 1, 2).toArray(), + DoubleStream.of(1.0, 1.0, 1.0).toArray()), + Vectors.sparse( + 6, + IntStream.of(3, 5).toArray(), + DoubleStream.of(1.0, 1.0).toArray()), + Vectors.sparse( + 6, + IntStream.of(0, 1).toArray(), + DoubleStream.of(1.0, 1.0).toArray()))); + CountVectorizer countVectorizer = new CountVectorizer().setBinary(true); + CountVectorizerModel model = countVectorizer.fit(trainDataTable); + Table output = model.transform(trainDataTable)[0]; + verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput); + } + + @Test + public void testVocabularySize() throws Exception { + List<SparseVector> expectedOutput = + new ArrayList<>( + Arrays.asList( + Vectors.sparse( + 2, + IntStream.of(0, 1).toArray(), + DoubleStream.of(2.0, 1.0).toArray()), + Vectors.sparse( + 2, + IntStream.of(0).toArray(), + DoubleStream.of(1.0).toArray()), + Vectors.sparse( + 2, + IntStream.of(0, 1).toArray(), + DoubleStream.of(1.0, 1.0).toArray()), + Vectors.sparse(2, new int[0], new double[0]), + Vectors.sparse( + 2, + IntStream.of(0, 1).toArray(), + DoubleStream.of(1.0, 2.0).toArray()))); + CountVectorizer countVectorizer = new CountVectorizer().setVocabularySize(2); + CountVectorizerModel model = countVectorizer.fit(trainDataTable); + Table output = model.transform(trainDataTable)[0]; + verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput); + } + + @Test + public void testGetModelData() throws Exception { + CountVectorizer countVectorizer = new CountVectorizer(); + CountVectorizerModel model = countVectorizer.fit(trainDataTable); + Table modelData = model.getModelData()[0]; + assertEquals(Arrays.asList("vocabulary"), modelData.getResolvedSchema().getColumnNames()); + + DataStream<Row> output = tEnv.toDataStream(modelData); + List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect()); + String[] vocabulary = (String[]) modelRows.get(0).getField(0); + assert vocabulary != null; Review Comment: This line seems redundant. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
