This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
     new f793ef1  [FLINK-29601] Add Estimator and Transformer for UnivariateFeatureSelector
f793ef1 is described below
commit f793ef13aeefc5a65f43ec5cbd3a90eca743bae4
Author: JiangXin <[email protected]>
AuthorDate: Tue Dec 13 17:57:28 2022 +0800
[FLINK-29601] Add Estimator and Transformer for UnivariateFeatureSelector
This closes #187.
---
.../operators/feature/univariatefeatureselector.md | 222 ++++++
.../feature/UnivariateFeatureSelectorExample.java | 77 ++
.../apache/flink/ml/common/util/VectorUtils.java | 62 ++
.../UnivariateFeatureSelector.java | 305 ++++++++
.../UnivariateFeatureSelectorModel.java} | 111 ++-
.../UnivariateFeatureSelectorModelData.java | 112 +++
.../UnivariateFeatureSelectorModelParams.java | 30 +
.../UnivariateFeatureSelectorParams.java | 139 ++++
.../VarianceThresholdSelector.java | 12 +-
.../VarianceThresholdSelectorModel.java | 20 +-
.../flink/ml/common/util/VectorUtilsTest.java | 49 ++
.../ml/feature/UnivariateFeatureSelectorTest.java | 782 +++++++++++++++++++++
.../ml/feature/VarianceThresholdSelectorTest.java | 89 ++-
.../feature/univariatefeatureselector_example.py | 68 ++
.../tests/test_univariatefeatureselector.py | 203 ++++++
.../ml/lib/feature/univariatefeatureselector.py | 208 ++++++
16 files changed, 2366 insertions(+), 123 deletions(-)
diff --git a/docs/content/docs/operators/feature/univariatefeatureselector.md b/docs/content/docs/operators/feature/univariatefeatureselector.md
new file mode 100644
index 0000000..e119d6a
--- /dev/null
+++ b/docs/content/docs/operators/feature/univariatefeatureselector.md
@@ -0,0 +1,222 @@
+---
+title: "Univariate Feature Selector"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/univariatefeatureselector.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## Univariate Feature Selector
+Univariate Feature Selector is an algorithm that selects features based on
+univariate statistical tests against labels.
+
+Currently, Flink supports three Univariate Feature Selectors: chi-squared,
+ANOVA F-test and F-value. Users can choose a Univariate Feature Selector by
+setting `featureType` and `labelType`, and Flink will pick the score function
+based on the specified `featureType` and `labelType`.
+
+The following combinations of `featureType` and `labelType` are supported:
+
+<ul>
+ <li>`featureType` `categorical` and `labelType` `categorical`: Flink uses
+ chi-squared, i.e. chi2 in sklearn.
+ <li>`featureType` `continuous` and `labelType` `categorical`: Flink uses
+ ANOVA F-test, i.e. f_classif in sklearn.
+ <li>`featureType` `continuous` and `labelType` `continuous`: Flink uses
+ F-value, i.e. f_regression in sklearn.
+</ul>
+
+Univariate Feature Selector supports different selection modes:
+
+<ul>
+    <li>numTopFeatures: chooses a fixed number of top features according to a
+    hypothesis.
+    <li>percentile: similar to numTopFeatures but chooses a fraction of all
+    features instead of a fixed number.
+    <li>fpr: chooses all features whose p-values are below a threshold, thus
+    controlling the false positive rate of selection.
+    <li>fdr: uses the <a
+    href="https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure">Benjamini-Hochberg procedure</a>
+    to choose all features whose false discovery rate is below a threshold.
+    <li>fwe: chooses all features whose p-values are below a threshold. The
+    threshold is scaled by 1/numFeatures, thus controlling the family-wise
+    error rate of selection.
+</ul>
+
+By default, the selection mode is `numTopFeatures`.
+
+### Input Columns
+
+| Param name | Type | Default | Description |
+|:------------|:-------|:-------------|:-----------------------|
+| featuresCol | Vector | `"features"` | Feature vector. |
+| labelCol | Number | `"label"` | Label of the features. |
+
+### Output Columns
+
+| Param name | Type | Default | Description |
+|:-----------|:-------|:-----------|:-------------------|
+| outputCol | Vector | `"output"` | Selected features. |
+
+### Parameters
+
+Below are the parameters required by `UnivariateFeatureSelectorModel`.
+
+| Key | Default | Type | Required | Description |
+|-------------|--------------|--------|----------|-------------------------|
+| featuresCol | `"features"` | String | no | Features column name. |
+| outputCol | `"output"` | String | no | Output column name. |
+
+In addition to the parameters above, `UnivariateFeatureSelector` also needs the parameters below.
+
+| Key                | Default            | Type   | Required | Description |
+|:-------------------|:-------------------|:-------|:---------|:------------|
+| labelCol           | `"label"`          | String | no       | Label column name. |
+| featureType        | `null`             | String | yes      | The feature type. Supported values: 'categorical', 'continuous'. |
+| labelType          | `null`             | String | yes      | The label type. Supported values: 'categorical', 'continuous'. |
+| selectionMode      | `"numTopFeatures"` | String | no       | The feature selection mode. Supported values: 'numTopFeatures', 'percentile', 'fpr', 'fdr', 'fwe'. |
+| selectionThreshold | `null`             | Number | no       | The upper bound of the features that the selector will select. If not set, it will be replaced with a meaningful value according to the selection mode at runtime. When the mode is numTopFeatures, it will be replaced with 50; when the mode is percentile, it will be replaced with 0.1; otherwise, it will be replaced with 0.05. |
+
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+
+```java
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/**
+ * Simple program that trains a {@link UnivariateFeatureSelector} model and uses it for feature
+ * selection.
+ */
+public class UnivariateFeatureSelectorExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input training and prediction data.
+        DataStream<Row> trainStream =
+                env.fromElements(
+                        Row.of(Vectors.dense(1.7, 4.4, 7.6, 5.8, 9.6, 2.3), 3.0),
+                        Row.of(Vectors.dense(8.8, 7.3, 5.7, 7.3, 2.2, 4.1), 2.0),
+                        Row.of(Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5), 1.0),
+                        Row.of(Vectors.dense(3.7, 9.2, 6.1, 4.1, 7.5, 3.8), 2.0),
+                        Row.of(Vectors.dense(8.9, 5.2, 7.8, 8.3, 5.2, 3.0), 4.0),
+                        Row.of(Vectors.dense(7.9, 8.5, 9.2, 4.0, 9.4, 2.1), 4.0));
+        Table trainTable = tEnv.fromDataStream(trainStream).as("features", "label");
+
+        // Creates a UnivariateFeatureSelector object and initializes its parameters.
+        UnivariateFeatureSelector univariateFeatureSelector =
+                new UnivariateFeatureSelector()
+                        .setFeaturesCol("features")
+                        .setLabelCol("label")
+                        .setFeatureType("continuous")
+                        .setLabelType("categorical")
+                        .setSelectionThreshold(1);
+
+        // Trains the UnivariateFeatureSelector model.
+        UnivariateFeatureSelectorModel model = univariateFeatureSelector.fit(trainTable);
+
+        // Uses the UnivariateFeatureSelector model for predictions.
+        Table outputTable = model.transform(trainTable)[0];
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+            DenseVector inputValue =
+                    (DenseVector) row.getField(univariateFeatureSelector.getFeaturesCol());
+            DenseVector outputValue =
+                    (DenseVector) row.getField(univariateFeatureSelector.getOutputCol());
+            System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
+        }
+    }
+}
+
+```
+
+{{< /tab>}}
+
+{{< tab "Python">}}
+
+```python
+# Simple program that creates a UnivariateFeatureSelector instance and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.univariatefeatureselector import UnivariateFeatureSelector
+from pyflink.table import StreamTableEnvironment
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+
+env = StreamExecutionEnvironment.get_execution_environment()
+
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input training and prediction data.
+input_table = t_env.from_data_stream(
+    env.from_collection([
+        (Vectors.dense(1.7, 4.4, 7.6, 5.8, 9.6, 2.3), 3.0,),
+        (Vectors.dense(8.8, 7.3, 5.7, 7.3, 2.2, 4.1), 2.0,),
+        (Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5), 1.0,),
+        (Vectors.dense(3.7, 9.2, 6.1, 4.1, 7.5, 3.8), 2.0,),
+        (Vectors.dense(8.9, 5.2, 7.8, 8.3, 5.2, 3.0), 4.0,),
+        (Vectors.dense(7.9, 8.5, 9.2, 4.0, 9.4, 2.1), 4.0,),
+    ],
+        type_info=Types.ROW_NAMED(
+            ['features', 'label'],
+            [DenseVectorTypeInfo(), Types.FLOAT()])
+    ))
+
+# Creates a UnivariateFeatureSelector object and initializes its parameters.
+univariate_feature_selector = UnivariateFeatureSelector() \
+    .set_features_col('features') \
+    .set_label_col('label') \
+    .set_feature_type('continuous') \
+    .set_label_type('categorical') \
+    .set_selection_threshold(1)
+
+# Trains the UnivariateFeatureSelector Model.
+model = univariate_feature_selector.fit(input_table)
+
+# Uses the UnivariateFeatureSelector Model for predictions.
+output = model.transform(input_table)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+    input_index = field_names.index(univariate_feature_selector.get_features_col())
+    output_index = field_names.index(univariate_feature_selector.get_output_col())
+    print('Input Value: ' + str(result[input_index]) +
+          '\tOutput Value: ' + str(result[output_index]))
+
+```
\ No newline at end of file
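As a quick editorial cross-check of the doc page above, the featureType/labelType dispatch it describes can be paraphrased in plain Python. This is a hypothetical sketch for illustration only — `choose_score_function` is not part of the Flink ML or PyFlink API:

```python
# Hypothetical helper (not Flink ML API): maps a (featureType, labelType)
# combination to the univariate test the doc page says Flink will use.

def choose_score_function(feature_type, label_type):
    """Return the name of the univariate test for a supported combination."""
    mapping = {
        ("categorical", "categorical"): "chi-squared",  # chi2 in sklearn
        ("continuous", "categorical"): "ANOVA F-test",  # f_classif in sklearn
        ("continuous", "continuous"): "F-value",        # f_regression in sklearn
    }
    if (feature_type, label_type) not in mapping:
        # Mirrors the IllegalArgumentException thrown by the Java implementation.
        raise ValueError(
            "Unsupported combination: featureType=%s, labelType=%s"
            % (feature_type, label_type))
    return mapping[(feature_type, label_type)]
```

Note that `categorical` features with a `continuous` label is the one combination the table does not cover, and the sketch rejects it accordingly.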
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/UnivariateFeatureSelectorExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/UnivariateFeatureSelectorExample.java
new file mode 100644
index 0000000..4d4c07f
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/UnivariateFeatureSelectorExample.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/**
+ * Simple program that trains a {@link UnivariateFeatureSelector} model and uses it for feature
+ * selection.
+ */
+public class UnivariateFeatureSelectorExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input training and prediction data.
+        DataStream<Row> trainStream =
+                env.fromElements(
+                        Row.of(Vectors.dense(1.7, 4.4, 7.6, 5.8, 9.6, 2.3), 3.0),
+                        Row.of(Vectors.dense(8.8, 7.3, 5.7, 7.3, 2.2, 4.1), 2.0),
+                        Row.of(Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5), 1.0),
+                        Row.of(Vectors.dense(3.7, 9.2, 6.1, 4.1, 7.5, 3.8), 2.0),
+                        Row.of(Vectors.dense(8.9, 5.2, 7.8, 8.3, 5.2, 3.0), 4.0),
+                        Row.of(Vectors.dense(7.9, 8.5, 9.2, 4.0, 9.4, 2.1), 4.0));
+        Table trainTable = tEnv.fromDataStream(trainStream).as("features", "label");
+
+        // Creates a UnivariateFeatureSelector object and initializes its parameters.
+        UnivariateFeatureSelector univariateFeatureSelector =
+                new UnivariateFeatureSelector()
+                        .setFeaturesCol("features")
+                        .setLabelCol("label")
+                        .setFeatureType("continuous")
+                        .setLabelType("categorical")
+                        .setSelectionThreshold(1);
+
+        // Trains the UnivariateFeatureSelector model.
+        UnivariateFeatureSelectorModel model = univariateFeatureSelector.fit(trainTable);
+
+        // Uses the UnivariateFeatureSelector model for predictions.
+        Table outputTable = model.transform(trainTable)[0];
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+            DenseVector inputValue =
+                    (DenseVector) row.getField(univariateFeatureSelector.getFeaturesCol());
+            DenseVector outputValue =
+                    (DenseVector) row.getField(univariateFeatureSelector.getOutputCol());
+            System.out.printf("Input Value: %-15s\tOutput Value: %s\n", inputValue, outputValue);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/VectorUtils.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/VectorUtils.java
new file mode 100644
index 0000000..3d37c1f
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/util/VectorUtils.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.util;
+
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/** Provides utility functions for {@link Vector}. */
+public class VectorUtils {
+    /**
+     * Selects a subset of the vector based on the given indices. Note that the input indices must
+     * be sorted in ascending order.
+     */
+    public static Vector selectByIndices(Vector vector, int[] sortedIndices) {
+        if (vector instanceof DenseVector) {
+            DenseVector resultVec = new DenseVector(sortedIndices.length);
+            for (int i = 0; i < sortedIndices.length; i++) {
+                resultVec.set(i, vector.get(sortedIndices[i]));
+            }
+            return resultVec;
+        } else {
+            List<Integer> resultIndices = new ArrayList<>();
+            List<Double> resultValues = new ArrayList<>();
+
+            int[] indices = ((SparseVector) vector).indices;
+            for (int i = 0, j = 0; i < indices.length && j < sortedIndices.length; ) {
+                if (indices[i] == sortedIndices[j]) {
+                    resultIndices.add(j++);
+                    resultValues.add(((SparseVector) vector).values[i++]);
+                } else if (indices[i] > sortedIndices[j]) {
+                    j++;
+                } else {
+                    i++;
+                }
+            }
+            return new SparseVector(
+                    sortedIndices.length,
+                    resultIndices.stream().mapToInt(Integer::intValue).toArray(),
+                    resultValues.stream().mapToDouble(Double::doubleValue).toArray());
+        }
+    }
+}
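For review purposes, the two-pointer scan in `selectByIndices` above can be mirrored in plain Python to check the sparse-vector re-indexing semantics. This is a hypothetical illustration — `select_by_indices` below is not part of any Flink ML API, and the tuple return value is just a stand-in for `SparseVector`:

```python
# Hypothetical paraphrase of VectorUtils.selectByIndices for a sparse vector:
# entries whose original index appears in sorted_indices are kept, and each
# kept entry is re-indexed to the position of that index within sorted_indices.

def select_by_indices(indices, values, sorted_indices):
    """indices/values describe a sparse vector; sorted_indices is ascending."""
    result_indices, result_values = [], []
    i = j = 0
    while i < len(indices) and j < len(sorted_indices):
        if indices[i] == sorted_indices[j]:
            result_indices.append(j)  # new index = position within sorted_indices
            result_values.append(values[i])
            i += 1
            j += 1
        elif indices[i] > sorted_indices[j]:
            j += 1  # sorted_indices[j] has no stored entry: it is an implicit zero
        else:
            i += 1  # stored entry not selected: dropped
    # Returns (size, indices, values) of the selected sparse vector.
    return len(sorted_indices), result_indices, result_values
```

For example, selecting indices `[2, 3, 4]` from a sparse vector with stored entries at `[0, 2, 4]` keeps the entries at 2 and 4, re-indexed to positions 0 and 2.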
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelector.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelector.java
new file mode 100644
index 0000000..78810d3
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelector.java
@@ -0,0 +1,305 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.univariatefeatureselector;
+
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.stats.anovatest.ANOVATest;
+import org.apache.flink.ml.stats.chisqtest.ChiSqTest;
+import org.apache.flink.ml.stats.fvaluetest.FValueTest;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.collections.IteratorUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+
+/**
+ * An Estimator which selects features based on univariate statistical tests against labels.
+ *
+ * <p>Currently, Flink supports three Univariate Feature Selectors: chi-squared, ANOVA F-test and
+ * F-value. Users can choose a Univariate Feature Selector by setting `featureType` and
+ * `labelType`, and Flink will pick the score function based on the specified `featureType` and
+ * `labelType`.
+ *
+ * <p>The following combinations of `featureType` and `labelType` are supported:
+ *
+ * <ul>
+ *   <li>`featureType` `categorical` and `labelType` `categorical`: Flink uses chi-squared, i.e.
+ *       chi2 in sklearn.
+ *   <li>`featureType` `continuous` and `labelType` `categorical`: Flink uses ANOVA F-test, i.e.
+ *       f_classif in sklearn.
+ *   <li>`featureType` `continuous` and `labelType` `continuous`: Flink uses F-value, i.e.
+ *       f_regression in sklearn.
+ * </ul>
+ *
+ * <p>The `UnivariateFeatureSelector` supports different selection modes:
+ *
+ * <ul>
+ *   <li>numTopFeatures: chooses a fixed number of top features according to a hypothesis.
+ *   <li>percentile: similar to numTopFeatures but chooses a fraction of all features instead of a
+ *       fixed number.
+ *   <li>fpr: chooses all features whose p-values are below a threshold, thus controlling the false
+ *       positive rate of selection.
+ *   <li>fdr: uses the <a
+ *       href="https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure">
+ *       Benjamini-Hochberg procedure</a> to choose all features whose false discovery rate is
+ *       below a threshold.
+ *   <li>fwe: chooses all features whose p-values are below a threshold. The threshold is scaled by
+ *       1/numFeatures, thus controlling the family-wise error rate of selection.
+ * </ul>
+ *
+ * <p>By default, the selection mode is `numTopFeatures`.
+ */
+public class UnivariateFeatureSelector
+        implements Estimator<UnivariateFeatureSelector, UnivariateFeatureSelectorModel>,
+                UnivariateFeatureSelectorParams<UnivariateFeatureSelector> {
+
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public UnivariateFeatureSelector() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+    @Override
+    public UnivariateFeatureSelectorModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        final String featuresCol = getFeaturesCol();
+        final String labelCol = getLabelCol();
+        final String featureType = getFeatureType();
+        final String labelType = getLabelType();
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        Table output;
+        if (CATEGORICAL.equals(featureType) && CATEGORICAL.equals(labelType)) {
+            output =
+                    new ChiSqTest()
+                            .setFeaturesCol(featuresCol)
+                            .setLabelCol(labelCol)
+                            .setFlatten(true)
+                            .transform(inputs[0])[0];
+        } else if (CONTINUOUS.equals(featureType) && CATEGORICAL.equals(labelType)) {
+            output =
+                    new ANOVATest()
+                            .setFeaturesCol(featuresCol)
+                            .setLabelCol(labelCol)
+                            .setFlatten(true)
+                            .transform(inputs[0])[0];
+        } else if (CONTINUOUS.equals(featureType) && CONTINUOUS.equals(labelType)) {
+            output =
+                    new FValueTest()
+                            .setFeaturesCol(featuresCol)
+                            .setLabelCol(labelCol)
+                            .setFlatten(true)
+                            .transform(inputs[0])[0];
+        } else {
+            throw new IllegalArgumentException(
+                    String.format(
+                            "Unsupported combination: featureType=%s, labelType=%s.",
+                            featureType, labelType));
+        }
+        DataStream<UnivariateFeatureSelectorModelData> modelData =
+                tEnv.toDataStream(output)
+                        .transform(
+                                "selectIndicesFromPValues",
+                                TypeInformation.of(UnivariateFeatureSelectorModelData.class),
+                                new SelectIndicesFromPValuesOperator(
+                                        getSelectionMode(), getActualSelectionThreshold()))
+                        .setParallelism(1);
+        UnivariateFeatureSelectorModel model =
+                new UnivariateFeatureSelectorModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    private double getActualSelectionThreshold() {
+        Double threshold = getSelectionThreshold();
+        if (threshold == null) {
+            String selectionMode = getSelectionMode();
+            if (NUM_TOP_FEATURES.equals(selectionMode)) {
+                threshold = 50.0;
+            } else if (PERCENTILE.equals(selectionMode)) {
+                threshold = 0.1;
+            } else {
+                threshold = 0.05;
+            }
+        } else {
+            if (NUM_TOP_FEATURES.equals(getSelectionMode())) {
+                Preconditions.checkArgument(
+                        threshold >= 1 && threshold.intValue() == threshold,
+                        "SelectionThreshold needs to be a positive Integer "
+                                + "for selection mode numTopFeatures, but got %s.",
+                        threshold);
+            } else {
+                Preconditions.checkArgument(
+                        threshold >= 0 && threshold <= 1,
+                        "SelectionThreshold needs to be in the range [0, 1] "
+                                + "for selection mode %s, but got %s.",
+                        getSelectionMode(),
+                        threshold);
+            }
+        }
+        return threshold;
+    }
+
+    private static class SelectIndicesFromPValuesOperator
+            extends AbstractStreamOperator<UnivariateFeatureSelectorModelData>
+            implements OneInputStreamOperator<Row, UnivariateFeatureSelectorModelData>,
+                    BoundedOneInput {
+        private final String selectionMode;
+        private final double threshold;
+
+        private List<Tuple2<Double, Integer>> pValuesAndIndices;
+        private ListState<Tuple2<Double, Integer>> pValuesAndIndicesState;
+
+        public SelectIndicesFromPValuesOperator(String selectionMode, double threshold) {
+            this.selectionMode = selectionMode;
+            this.threshold = threshold;
+        }
+
+        @Override
+        public void endInput() {
+            List<Integer> indices = new ArrayList<>();
+
+            switch (selectionMode) {
+                case NUM_TOP_FEATURES:
+                    pValuesAndIndices.sort(
+                            Comparator.comparingDouble((Tuple2<Double, Integer> t) -> t.f0)
+                                    .thenComparingInt(t -> t.f1));
+                    IntStream.range(0, Math.min(pValuesAndIndices.size(), (int) threshold))
+                            .forEach(i -> indices.add(pValuesAndIndices.get(i).f1));
+                    break;
+                case PERCENTILE:
+                    pValuesAndIndices.sort(
+                            Comparator.comparingDouble((Tuple2<Double, Integer> t) -> t.f0)
+                                    .thenComparingInt(t -> t.f1));
+                    IntStream.range(
+                                    0,
+                                    Math.min(
+                                            pValuesAndIndices.size(),
+                                            (int) (pValuesAndIndices.size() * threshold)))
+                            .forEach(i -> indices.add(pValuesAndIndices.get(i).f1));
+                    break;
+                case FPR:
+                    pValuesAndIndices.stream()
+                            .filter(x -> x.f0 < threshold)
+                            .forEach(x -> indices.add(x.f1));
+                    break;
+                case FDR:
+                    pValuesAndIndices.sort(
+                            Comparator.comparingDouble((Tuple2<Double, Integer> t) -> t.f0)
+                                    .thenComparingInt(t -> t.f1));
+
+                    int maxIndex = -1;
+                    for (int i = 0; i < pValuesAndIndices.size(); i++) {
+                        if (pValuesAndIndices.get(i).f0
+                                < (threshold / pValuesAndIndices.size()) * (i + 1)) {
+                            maxIndex = Math.max(maxIndex, i);
+                        }
+                    }
+                    if (maxIndex >= 0) {
+                        // The list is already sorted by (p-value, index) above.
+                        IntStream.range(0, maxIndex + 1)
+                                .forEach(i -> indices.add(pValuesAndIndices.get(i).f1));
+                    }
+                    break;
+                case FWE:
+                    pValuesAndIndices.stream()
+                            .filter(x -> x.f0 < threshold / pValuesAndIndices.size())
+                            .forEach(x -> indices.add(x.f1));
+                    break;
+                default:
+                    throw new RuntimeException("Unknown Selection Mode: " + selectionMode);
+            }
+
+            UnivariateFeatureSelectorModelData modelData =
+                    new UnivariateFeatureSelectorModelData(
+                            indices.stream().mapToInt(Integer::intValue).toArray());
+            output.collect(new StreamRecord<>(modelData));
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> record) {
+            Row row = record.getValue();
+            double pValue = (double) row.getField("pValue");
+            int featureIndex = (int) row.getField("featureIndex");
+            pValuesAndIndices.add(Tuple2.of(pValue, featureIndex));
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            pValuesAndIndicesState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "pValuesAndIndices",
+                                            Types.TUPLE(Types.DOUBLE, Types.INT)));
+            pValuesAndIndices = IteratorUtils.toList(pValuesAndIndicesState.get().iterator());
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            pValuesAndIndicesState.update(pValuesAndIndices);
+        }
+    }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+    public static UnivariateFeatureSelector load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+}
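To make the `endInput` selection logic above easy to review, here is a plain-Python paraphrase of how each mode turns (p-value, feature index) pairs into selected indices. This is a hypothetical sketch for checking the thresholds only — `select_indices` is not part of the Flink ML API:

```python
# Hypothetical paraphrase of SelectIndicesFromPValuesOperator.endInput:
# given (p_value, feature_index) pairs, each mode picks a set of feature indices.

def select_indices(p_values_and_indices, mode, threshold):
    pairs = sorted(p_values_and_indices)  # by p-value, ties broken by feature index
    n = len(pairs)
    if mode == "numTopFeatures":
        return [idx for _, idx in pairs[: int(threshold)]]
    if mode == "percentile":
        return [idx for _, idx in pairs[: int(n * threshold)]]
    if mode == "fpr":
        return [idx for p, idx in p_values_and_indices if p < threshold]
    if mode == "fdr":
        # Benjamini-Hochberg: keep up to the largest rank i (0-based) with
        # p_(i) < threshold * (i + 1) / n.
        max_rank = -1
        for i, (p, _) in enumerate(pairs):
            if p < threshold / n * (i + 1):
                max_rank = i
        return [idx for _, idx in pairs[: max_rank + 1]]
    if mode == "fwe":
        # Bonferroni-style scaling of the threshold by 1/numFeatures.
        return [idx for p, idx in p_values_and_indices if p < threshold / n]
    raise ValueError("Unknown selection mode: " + mode)
```

For example, with p-values `[0.01, 0.2, 0.03, 0.5]` for features 0..3, every mode with a suitably chosen threshold selects features 0 and 2.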
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java
similarity index 66%
copy from flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java
index a13dd06..f5acf7a 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModel.java
@@ -16,17 +16,17 @@
* limitations under the License.
*/
-package org.apache.flink.ml.feature.variancethresholdselector;
+package org.apache.flink.ml.feature.univariatefeatureselector;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.common.broadcast.BroadcastUtils;
import org.apache.flink.ml.common.datastream.TableUtils;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.common.util.VectorUtils;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
@@ -40,27 +40,27 @@ import org.apache.flink.util.Preconditions;
import org.apache.commons.lang3.ArrayUtils;
import java.io.IOException;
+import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
/**
- * A Model which removes low-variance data using the model data computed by {@link
- * VarianceThresholdSelector}.
+ * A Model which transforms data using the model data computed by {@link UnivariateFeatureSelector}.
*/
-public class VarianceThresholdSelectorModel
-        implements Model<VarianceThresholdSelectorModel>,
-                VarianceThresholdSelectorModelParams<VarianceThresholdSelectorModel> {
+public class UnivariateFeatureSelectorModel
+        implements Model<UnivariateFeatureSelectorModel>,
+                UnivariateFeatureSelectorModelParams<UnivariateFeatureSelectorModel> {
private final Map<Param<?>, Object> paramMap = new HashMap<>();
private Table modelDataTable;
- public VarianceThresholdSelectorModel() {
+ public UnivariateFeatureSelectorModel() {
ParamUtils.initializeMapWithDefaultValues(paramMap, this);
}
@Override
- public VarianceThresholdSelectorModel setModelData(Table... inputs) {
+ public UnivariateFeatureSelectorModel setModelData(Table... inputs) {
Preconditions.checkArgument(inputs.length == 1);
modelDataTable = inputs[0];
return this;
@@ -71,29 +71,7 @@ public class VarianceThresholdSelectorModel
return new Table[] {modelDataTable};
}
- @Override
- public Map<Param<?>, Object> getParamMap() {
- return paramMap;
- }
-
- @Override
- public void save(String path) throws IOException {
- ReadWriteUtils.saveMetadata(this, path);
-        ReadWriteUtils.saveModelData(
-                VarianceThresholdSelectorModelData.getModelDataStream(modelDataTable),
-                path,
-                new VarianceThresholdSelectorModelData.ModelDataEncoder());
-    }
-
-    public static VarianceThresholdSelectorModel load(StreamTableEnvironment tEnv, String path)
-            throws IOException {
-        VarianceThresholdSelectorModel model = ReadWriteUtils.loadStageParam(path);
-        Table modelDataTable =
-                ReadWriteUtils.loadModelData(
-                        tEnv, path, new VarianceThresholdSelectorModelData.ModelDataDecoder());
-        return model.setModelData(modelDataTable);
-    }
-
+ @SuppressWarnings({"unchecked", "rawtypes"})
@Override
public Table[] transform(Table... inputs) {
Preconditions.checkArgument(inputs.length == 1);
@@ -101,29 +79,28 @@ public class VarianceThresholdSelectorModel
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
-        DataStream<VarianceThresholdSelectorModelData> varianceThresholdSelectorModel =
-                VarianceThresholdSelectorModelData.getModelDataStream(modelDataTable);
+        DataStream<UnivariateFeatureSelectorModelData> modelData =
+                UnivariateFeatureSelectorModelData.getModelDataStream(modelDataTable);
        final String broadcastModelKey = "broadcastModelKey";
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
        RowTypeInfo outputTypeInfo =
                new RowTypeInfo(
-                        ArrayUtils.addAll(
-                                inputTypeInfo.getFieldTypes(), DenseVectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE),
                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
-        DataStream<Row> output =
+        DataStream<Row> outputStream =
                BroadcastUtils.withBroadcastStream(
                        Collections.singletonList(data),
-                        Collections.singletonMap(broadcastModelKey, varianceThresholdSelectorModel),
+                        Collections.singletonMap(broadcastModelKey, modelData),
                        inputList -> {
                            DataStream input = inputList.get(0);
                            return input.map(
-                                    new PredictOutputFunction(getInputCol(), broadcastModelKey),
+                                    new PredictOutputFunction(getFeaturesCol(), broadcastModelKey),
                                    outputTypeInfo);
                        });
-        return new Table[] {tEnv.fromDataStream(output)};
+        return new Table[] {tEnv.fromDataStream(outputStream)};
}
/** This operator loads model data and predicts result. */
@@ -131,8 +108,7 @@ public class VarianceThresholdSelectorModel
private final String inputCol;
private final String broadcastKey;
- private int expectedNumOfFeatures;
- private int[] indices = null;
+ private int[] indices;
public PredictOutputFunction(String inputCol, String broadcastKey) {
this.inputCol = inputCol;
@@ -142,29 +118,48 @@ public class VarianceThresholdSelectorModel
@Override
public Row map(Row row) {
if (indices == null) {
-                VarianceThresholdSelectorModelData varianceThresholdSelectorModelData =
-                        (VarianceThresholdSelectorModelData)
+                UnivariateFeatureSelectorModelData modelData =
+                        (UnivariateFeatureSelectorModelData)
                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
-                expectedNumOfFeatures = varianceThresholdSelectorModelData.numOfFeatures;
-                indices = varianceThresholdSelectorModelData.indices;
+                indices = Arrays.stream(modelData.indices).sorted().toArray();
            }
-            DenseVector inputVec = ((Vector) row.getField(inputCol)).toDense();
-            Preconditions.checkArgument(
-                    inputVec.size() == expectedNumOfFeatures,
-                    "%s has %s features, but VarianceThresholdSelector is expecting %s features as input.",
-                    inputCol,
-                    inputVec.size(),
-                    expectedNumOfFeatures);
            if (indices.length == 0) {
                return Row.join(row, Row.of(Vectors.dense()));
            } else {
-                DenseVector outputVec = new DenseVector(indices.length);
-                for (int i = 0; i < indices.length; i++) {
-                    outputVec.values[i] = inputVec.get(indices[i]);
-                }
+                Vector inputVec = ((Vector) row.getField(inputCol));
+                Preconditions.checkArgument(
+                        inputVec.size() > indices[indices.length - 1],
+                        "Input %s features, but UnivariateFeatureSelector is "
+                                + "expecting at least %s features as input.",
+                        inputVec.size(),
+                        indices[indices.length - 1] + 1);
+                Vector outputVec = VectorUtils.selectByIndices(inputVec, indices);
return Row.join(row, Row.of(outputVec));
}
}
}
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                UnivariateFeatureSelectorModelData.getModelDataStream(modelDataTable),
+                path,
+                new UnivariateFeatureSelectorModelData.ModelDataEncoder());
+    }
+
+    public static UnivariateFeatureSelectorModel load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        UnivariateFeatureSelectorModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(
+                        tEnv, path, new UnivariateFeatureSelectorModelData.ModelDataDecoder());
+        return model.setModelData(modelDataTable);
+    }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
}
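The prediction path above sorts the model's selected indices once, validates the input against the largest selected index, and then gathers the values in order. As a plain-Java sketch of that logic, with no Flink dependencies (the class and method names here are illustrative, not Flink ML API):

```java
import java.util.Arrays;

// Sketch of the PredictOutputFunction selection logic: sort the model's
// indices once, validate the input length, then gather the selected values.
public class SelectByIndicesSketch {
    static double[] select(double[] input, int[] modelIndices) {
        // The model's indices are not guaranteed to arrive sorted.
        int[] indices = Arrays.stream(modelIndices).sorted().toArray();
        if (indices.length == 0) {
            return new double[0];
        }
        int maxIndex = indices[indices.length - 1];
        if (input.length <= maxIndex) {
            throw new IllegalArgumentException(
                    "Input has " + input.length + " features, but at least "
                            + (maxIndex + 1) + " are expected.");
        }
        double[] output = new double[indices.length];
        for (int i = 0; i < indices.length; i++) {
            output[i] = input[indices[i]];
        }
        return output;
    }

    public static void main(String[] args) {
        double[] out = select(new double[] {6.0, 7.0, 0.0, 7.0, 6.0, 0.0}, new int[] {3, 1});
        System.out.println(Arrays.toString(out)); // prints [7.0, 7.0]
    }
}
```

Note the contrast with VarianceThresholdSelectorModel, which still requires an exact feature count; here the input only needs at least `maxIndex + 1` features.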
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModelData.java
new file mode 100644
index 0000000..7ad848a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModelData.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.univariatefeatureselector;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.array.IntPrimitiveArraySerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link UnivariateFeatureSelectorModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to a data stream, and
+ * classes to save/load model data.
+ */
+public class UnivariateFeatureSelectorModelData {
+
+ /** Indices of the input features that are selected. */
+ public int[] indices;
+
+ public UnivariateFeatureSelectorModelData() {}
+
+ public UnivariateFeatureSelectorModelData(int[] indices) {
+ this.indices = indices;
+ }
+
+ /**
+ * Converts the table model to a data stream.
+ *
+ * @param modelDataTable The table model data.
+ * @return The data stream model data.
+ */
+    public static DataStream<UnivariateFeatureSelectorModelData> getModelDataStream(
+            Table modelDataTable) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+        return tEnv.toDataStream(modelDataTable)
+                .map(x -> new UnivariateFeatureSelectorModelData((int[]) x.getField(0)));
+ }
+
+ /** Encoder for {@link UnivariateFeatureSelectorModelData}. */
+    public static class ModelDataEncoder implements Encoder<UnivariateFeatureSelectorModelData> {
+        @Override
+        public void encode(UnivariateFeatureSelectorModelData modelData, OutputStream outputStream)
+                throws IOException {
+            DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream);
+            IntPrimitiveArraySerializer.INSTANCE.serialize(modelData.indices, dataOutputView);
+ }
+ }
+
+ /** Decoder for {@link UnivariateFeatureSelectorModelData}. */
+ public static class ModelDataDecoder
+ extends SimpleStreamFormat<UnivariateFeatureSelectorModelData> {
+ @Override
+ public Reader<UnivariateFeatureSelectorModelData> createReader(
+ Configuration config, FSDataInputStream stream) {
+ return new Reader<UnivariateFeatureSelectorModelData>() {
+
+ @Override
+                public UnivariateFeatureSelectorModelData read() throws IOException {
+                    DataInputView source = new DataInputViewStreamWrapper(stream);
+                    try {
+                        int[] indices = IntPrimitiveArraySerializer.INSTANCE.deserialize(source);
+ return new UnivariateFeatureSelectorModelData(indices);
+ } catch (EOFException e) {
+ return null;
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ stream.close();
+ }
+ };
+ }
+
+ @Override
+        public TypeInformation<UnivariateFeatureSelectorModelData> getProducedType() {
+            return TypeInformation.of(UnivariateFeatureSelectorModelData.class);
+ }
+ }
+}
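The encoder/decoder pair above delegates to Flink's `IntPrimitiveArraySerializer` to persist the single `indices` array. As a rough stand-in, the same length-prefixed round trip can be sketched with plain `java.io` (the class below is illustrative, not the Flink serializer):

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Arrays;

// Length-prefixed int[] round trip, mirroring the shape of the
// ModelDataEncoder/ModelDataDecoder pair above.
public class ModelDataRoundTrip {
    static byte[] encode(int[] indices) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            DataOutputStream out = new DataOutputStream(bytes);
            out.writeInt(indices.length); // write the length, then each element
            for (int index : indices) {
                out.writeInt(index);
            }
            return bytes.toByteArray();
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    static int[] decode(byte[] data) {
        try {
            DataInputStream in = new DataInputStream(new ByteArrayInputStream(data));
            int[] indices = new int[in.readInt()];
            for (int i = 0; i < indices.length; i++) {
                indices[i] = in.readInt();
            }
            return indices;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        int[] original = {0, 2, 5};
        System.out.println(Arrays.equals(original, decode(encode(original)))); // prints true
    }
}
```

The real decoder additionally returns null on `EOFException`, which is how the streaming reader signals end of input.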
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModelParams.java
new file mode 100644
index 0000000..15b4bab
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorModelParams.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.univariatefeatureselector;
+
+import org.apache.flink.ml.common.param.HasFeaturesCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
+
+/**
+ * Params for {@link UnivariateFeatureSelectorModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface UnivariateFeatureSelectorModelParams<T>
+ extends HasFeaturesCol<T>, HasOutputCol<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorParams.java
new file mode 100644
index 0000000..87f05fb
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/univariatefeatureselector/UnivariateFeatureSelectorParams.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.univariatefeatureselector;
+
+import org.apache.flink.ml.common.param.HasLabelCol;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringParam;
+
+/**
+ * Params for {@link UnivariateFeatureSelector}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface UnivariateFeatureSelectorParams<T>
+ extends HasLabelCol<T>, UnivariateFeatureSelectorModelParams<T> {
+
+ String CATEGORICAL = "categorical";
+ String CONTINUOUS = "continuous";
+
+ String NUM_TOP_FEATURES = "numTopFeatures";
+ String PERCENTILE = "percentile";
+ String FPR = "fpr";
+ String FDR = "fdr";
+ String FWE = "fwe";
+
+ /**
+ * Supported options of the feature type.
+ *
+ * <ul>
+ * <li>categorical: the features are categorical data.
+ * <li>continuous: the features are continuous data.
+ * </ul>
+ */
+ Param<String> FEATURE_TYPE =
+ new StringParam(
+ "featureType",
+ "The feature type.",
+ null,
+ ParamValidators.inArray(CATEGORICAL, CONTINUOUS));
+
+ /**
+ * Supported options of the label type.
+ *
+ * <ul>
+ * <li>categorical: the label is categorical data.
+ * <li>continuous: the label is continuous data.
+ * </ul>
+ */
+ Param<String> LABEL_TYPE =
+ new StringParam(
+ "labelType",
+ "The label type.",
+ null,
+ ParamValidators.inArray(CATEGORICAL, CONTINUOUS));
+
+ /**
+ * Supported options of the feature selection mode.
+ *
+ * <ul>
+ *   <li>numTopFeatures: chooses a fixed number of top features according to a hypothesis.
+ *   <li>percentile: similar to numTopFeatures but chooses a fraction of all features instead of
+ *       a fixed number.
+ *   <li>fpr: chooses all features whose p-values are below a threshold, thus controlling the
+ *       false positive rate of selection.
+ *   <li>fdr: uses the <a
+ *       href="https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure">
+ *       Benjamini-Hochberg procedure</a> to choose all features whose false discovery rate is
+ *       below a threshold.
+ *   <li>fwe: chooses all features whose p-values are below a threshold. The threshold is scaled
+ *       by 1/numFeatures, thus controlling the family-wise error rate of selection.
+ * </ul>
+ */
+ Param<String> SELECTION_MODE =
+ new StringParam(
+ "selectionMode",
+ "The feature selection mode.",
+ NUM_TOP_FEATURES,
+                    ParamValidators.inArray(NUM_TOP_FEATURES, PERCENTILE, FPR, FDR, FWE));
+
+ Param<Double> SELECTION_THRESHOLD =
+ new DoubleParam(
+ "selectionThreshold",
+                    "The upper bound of the features that selector will select. If not set, "
+                            + "it will be replaced with a meaningful value according to different "
+                            + "selection modes at runtime. When the mode is numTopFeatures, it will be "
+                            + "replaced with 50; when the mode is percentile, it will be replaced "
+                            + "with 0.1; otherwise, it will be replaced with 0.05.",
+ null);
+
+ default String getFeatureType() {
+ return get(FEATURE_TYPE);
+ }
+
+ default T setFeatureType(String value) {
+ return set(FEATURE_TYPE, value);
+ }
+
+ default String getLabelType() {
+ return get(LABEL_TYPE);
+ }
+
+ default T setLabelType(String value) {
+ return set(LABEL_TYPE, value);
+ }
+
+ default String getSelectionMode() {
+ return get(SELECTION_MODE);
+ }
+
+ default T setSelectionMode(String value) {
+ return set(SELECTION_MODE, value);
+ }
+
+ default Double getSelectionThreshold() {
+ return get(SELECTION_THRESHOLD);
+ }
+
+ default T setSelectionThreshold(double value) {
+ return set(SELECTION_THRESHOLD, value);
+ }
+}
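The `SELECTION_THRESHOLD` documentation above describes mode-dependent defaults that are substituted at runtime when no threshold is set. A minimal sketch of that substitution rule (the helper below is hypothetical, not Flink ML API):

```java
// Sketch of the default-threshold rule documented on SELECTION_THRESHOLD:
// numTopFeatures -> 50, percentile -> 0.1, fpr/fdr/fwe -> 0.05.
public class DefaultThresholdSketch {
    static double defaultThreshold(String selectionMode) {
        switch (selectionMode) {
            case "numTopFeatures":
                return 50;
            case "percentile":
                return 0.1;
            default: // fpr, fdr, fwe
                return 0.05;
        }
    }

    public static void main(String[] args) {
        System.out.println(defaultThreshold("numTopFeatures")); // prints 50.0
        System.out.println(defaultThreshold("percentile")); // prints 0.1
        System.out.println(defaultThreshold("fdr")); // prints 0.05
    }
}
```

The split reflects the threshold's meaning: a feature count for numTopFeatures, a fraction for percentile, and a p-value bound for the error-rate modes.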
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java
index 53944b6..37d4635 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelector.java
@@ -63,11 +63,11 @@ public class VarianceThresholdSelector
final String inputCol = getInputCol();
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
-        DataStream<DenseVector> inputData =
+        DataStream<Vector> inputData =
                tEnv.toDataStream(inputs[0])
                        .map(
-                                (MapFunction<Row, DenseVector>)
-                                        value -> ((Vector) value.getField(inputCol)).toDense());
+                                (MapFunction<Row, Vector>)
+                                        value -> ((Vector) value.getField(inputCol)));
DataStream<VarianceThresholdSelectorModelData> modelData =
DataStreamUtils.aggregate(
@@ -85,7 +85,7 @@ public class VarianceThresholdSelector
*/
private static class VarianceThresholdSelectorAggregator
implements AggregateFunction<
- DenseVector,
+ Vector,
Tuple3<Long, DenseVector, DenseVector>,
VarianceThresholdSelectorModelData> {
@@ -102,7 +102,7 @@ public class VarianceThresholdSelector
@Override
public Tuple3<Long, DenseVector, DenseVector> add(
-                DenseVector vector, Tuple3<Long, DenseVector, DenseVector> numAndSums) {
+                Vector vector, Tuple3<Long, DenseVector, DenseVector> numAndSums) {
if (numAndSums.f0 == 0) {
numAndSums.f1 = new DenseVector(vector.size());
numAndSums.f2 = new DenseVector(vector.size());
@@ -110,7 +110,7 @@ public class VarianceThresholdSelector
numAndSums.f0 += 1L;
BLAS.axpy(1.0, vector, numAndSums.f1);
for (int i = 0; i < vector.size(); i++) {
- numAndSums.f2.values[i] += vector.values[i] * vector.values[i];
+ numAndSums.f2.values[i] += vector.get(i) * vector.get(i);
}
return numAndSums;
}
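The aggregator above keeps a count, a per-feature sum, and a per-feature sum of squares; those three moments are enough to recover each feature's variance as E[x^2] - E[x]^2. A self-contained sketch of that bookkeeping (illustrative names, no Flink types):

```java
// Moment bookkeeping as in VarianceThresholdSelectorAggregator:
// accumulate count, sum, and sum of squares, then derive variances.
public class VarianceSketch {
    static double[] variances(double[][] rows) {
        int dim = rows[0].length;
        long n = 0;
        double[] sum = new double[dim];
        double[] sumSq = new double[dim];
        for (double[] row : rows) {
            n++;
            for (int i = 0; i < dim; i++) {
                sum[i] += row[i];
                sumSq[i] += row[i] * row[i]; // mirrors vector.get(i) * vector.get(i)
            }
        }
        double[] result = new double[dim];
        for (int i = 0; i < dim; i++) {
            double mean = sum[i] / n;
            result[i] = sumSq[i] / n - mean * mean; // E[x^2] - E[x]^2
        }
        return result;
    }

    public static void main(String[] args) {
        // Feature 0 varies (1, 3, 5); feature 1 is constant (2, 2, 2).
        double[] v = variances(new double[][] {{1.0, 2.0}, {3.0, 2.0}, {5.0, 2.0}});
        System.out.println(java.util.Arrays.toString(v));
    }
}
```

This decomposition is what lets the computation run as a single streaming aggregate: each accumulator is mergeable across partial inputs.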
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java
index a13dd06..f042c9f 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/variancethresholdselector/VarianceThresholdSelectorModel.java
@@ -23,10 +23,10 @@ import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Model;
import org.apache.flink.ml.common.broadcast.BroadcastUtils;
import org.apache.flink.ml.common.datastream.TableUtils;
-import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.common.util.VectorUtils;
import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
-import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.util.ParamUtils;
import org.apache.flink.ml.util.ReadWriteUtils;
@@ -40,6 +40,7 @@ import org.apache.flink.util.Preconditions;
import org.apache.commons.lang3.ArrayUtils;
import java.io.IOException;
+import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
@@ -108,8 +109,7 @@ public class VarianceThresholdSelectorModel
        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
        RowTypeInfo outputTypeInfo =
                new RowTypeInfo(
-                        ArrayUtils.addAll(
-                                inputTypeInfo.getFieldTypes(), DenseVectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE),
                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
DataStream<Row> output =
@@ -146,10 +146,13 @@ public class VarianceThresholdSelectorModel
                        (VarianceThresholdSelectorModelData)
                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
                expectedNumOfFeatures = varianceThresholdSelectorModelData.numOfFeatures;
-                indices = varianceThresholdSelectorModelData.indices;
+                indices =
+                        Arrays.stream(varianceThresholdSelectorModelData.indices)
+                                .sorted()
+                                .toArray();
            }
-            DenseVector inputVec = ((Vector) row.getField(inputCol)).toDense();
+            Vector inputVec = ((Vector) row.getField(inputCol));
            Preconditions.checkArgument(
                    inputVec.size() == expectedNumOfFeatures,
                    "%s has %s features, but VarianceThresholdSelector is expecting %s features as input.",
@@ -159,10 +162,7 @@ public class VarianceThresholdSelectorModel
if (indices.length == 0) {
return Row.join(row, Row.of(Vectors.dense()));
} else {
- DenseVector outputVec = new DenseVector(indices.length);
- for (int i = 0; i < indices.length; i++) {
- outputVec.values[i] = inputVec.get(indices[i]);
- }
+            Vector outputVec = VectorUtils.selectByIndices(inputVec, indices);
return Row.join(row, Row.of(outputVec));
}
}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/VectorUtilsTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/VectorUtilsTest.java
new file mode 100644
index 0000000..027be97
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/common/util/VectorUtilsTest.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.util;
+
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertArrayEquals;
+
+/** Tests {@link VectorUtils}. */
+public class VectorUtilsTest {
+
+ private static final double EPS = 1.0e-5;
+
+ @Test
+ public void testSelectByIndices() {
+ DenseVector denseVector = Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0);
+        assertArrayEquals(
+                Vectors.dense(2.0, 4.0).toArray(),
+                VectorUtils.selectByIndices(denseVector, new int[] {1, 3}).toArray(),
+                EPS);
+
+        SparseVector sparseVector =
+                Vectors.sparse(5, new int[] {1, 2, 3}, new double[] {2.0, 3.0, 4.0});
+        assertArrayEquals(
+                Vectors.sparse(3, new int[] {1, 2}, new double[] {2.0, 4.0}).toArray(),
+                VectorUtils.selectByIndices(sparseVector, new int[] {0, 1, 3}).toArray(),
+                EPS);
+ }
+}
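The sparse case in the test above follows the rule that output position i holds the input value at indices[i], so selecting {0, 1, 3} from a 5-dimensional sparse vector keeps only the nonzeros that land on selected positions. A dense-view sketch of that semantics (`VectorUtils.selectByIndices` is the real API; the helper below is hypothetical):

```java
import java.util.Arrays;

// Dense-view illustration of the selection semantics exercised by
// VectorUtilsTest: out[i] = input[indices[i]].
public class SparseSelectSketch {
    static double[] select(double[] denseView, int[] indices) {
        double[] out = new double[indices.length];
        for (int i = 0; i < indices.length; i++) {
            out[i] = denseView[indices[i]];
        }
        return out;
    }

    public static void main(String[] args) {
        // Dense view of Vectors.sparse(5, {1, 2, 3}, {2.0, 3.0, 4.0}).
        double[] input = {0.0, 2.0, 3.0, 4.0, 0.0};
        System.out.println(Arrays.toString(select(input, new int[] {0, 1, 3})));
        // prints [0.0, 2.0, 4.0], i.e. sparse(3, {1, 2}, {2.0, 4.0})
    }
}
```

For an actual sparse input, only the selected indices that coincide with stored nonzeros survive, which is why the expected result in the test has two nonzeros rather than three.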
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java
new file mode 100644
index 0000000..76076d9
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/UnivariateFeatureSelectorTest.java
@@ -0,0 +1,782 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.Types;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelector;
+import org.apache.flink.ml.feature.univariatefeatureselector.UnivariateFeatureSelectorModel;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+
+/** Tests {@link UnivariateFeatureSelector} and {@link UnivariateFeatureSelectorModel}. */
+public class UnivariateFeatureSelectorTest extends AbstractTestBase {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table inputChiSqTable;
+ private Table inputANOVATable;
+ private Table inputFValueTable;
+
+ private static final double EPS = 1.0e-5;
+
+ private UnivariateFeatureSelector selectorWithChiSqTest;
+ private UnivariateFeatureSelector selectorWithANOVATest;
+ private UnivariateFeatureSelector selectorWithFValueTest;
+
+ private static final List<Row> INPUT_CHISQ_DATA =
+ Arrays.asList(
+ Row.of(0.0, Vectors.dense(6.0, 7.0, 0.0, 7.0, 6.0, 0.0)),
+ Row.of(1.0, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0)),
+ Row.of(1.0, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)),
+                    Row.of(1.0, Vectors.dense(0.0, 9.0, 8.0, 5.0, 6.0, 4.0).toSparse()),
+                    Row.of(2.0, Vectors.dense(8.0, 9.0, 6.0, 5.0, 4.0, 4.0).toSparse()),
+                    Row.of(2.0, Vectors.dense(8.0, 9.0, 6.0, 4.0, 0.0, 0.0).toSparse()));
+
+ private static final List<Row> INPUT_ANOVA_DATA =
+ Arrays.asList(
+ Row.of(
+ 1,
+ Vectors.dense(
+ 4.65415496e-03,
+ 1.03550567e-01,
+ -1.17358140e+00,
+ 1.61408773e-01,
+ 3.92492111e-01,
+ 7.31240882e-01)),
+ Row.of(
+ 1,
+ Vectors.dense(
+ -9.01651741e-01,
+ -5.28905302e-01,
+ 1.27636785e+00,
+ 7.02154563e-01,
+ 6.21348351e-01,
+ 1.88397353e-01)),
+ Row.of(
+ 1,
+ Vectors.dense(
+ 3.85692159e-01,
+ -9.04639637e-01,
+ 5.09782604e-02,
+ 8.40043971e-01,
+ 7.45977857e-01,
+ 8.78402288e-01)),
+ Row.of(
+ 1,
+ Vectors.dense(
+ 1.36264353e+00,
+ 2.62454094e-01,
+ 7.96306202e-01,
+ 6.14948000e-01,
+ 7.44948187e-01,
+ 9.74034830e-01)),
+ Row.of(
+ 1,
+ Vectors.dense(
+ 9.65874070e-01,
+ 2.52773665e+00,
+ -2.19380094e+00,
+ 2.33408080e-01,
+ 1.86340919e-01,
+ 8.23390433e-01)),
+ Row.of(
+ 2,
+ Vectors.dense(
+ 1.12324305e+01,
+ -2.77121515e-01,
+ 1.12740513e-01,
+ 2.35184013e-01,
+ 3.46668895e-01,
+ 9.38500782e-02)),
+ Row.of(
+ 2,
+ Vectors.dense(
+ 1.06195839e+01,
+ -1.82891238e+00,
+ 2.25085601e-01,
+ 9.09979851e-01,
+ 6.80257535e-02,
+ 8.24017480e-01)),
+ Row.of(
+ 2,
+ Vectors.dense(
+ 1.12806837e+01,
+ 1.30686889e+00,
+ 9.32839108e-02,
+ 3.49784755e-01,
+ 1.71322408e-02,
+ 7.48465194e-02)),
+ Row.of(
+ 2,
+ Vectors.dense(
+ 9.98689462e+00,
+ 9.50808938e-01,
+ -2.90786359e-01,
+ 2.31253009e-01,
+ 7.46270968e-01,
+ 1.60308169e-01)),
+ Row.of(
+ 2,
+ Vectors.dense(
+ 1.08428551e+01,
+ -1.02749936e+00,
+ 1.73951508e-01,
+ 8.92482744e-02,
+ 1.42651730e-01,
+ 7.66751625e-01)),
+ Row.of(
+ 3,
+ Vectors.dense(
+ -1.98641448e+00,
+ 1.12811990e+01,
+ -2.35246756e-01,
+ 8.22809049e-01,
+ 3.26739456e-01,
+ 7.88268404e-01)
+ .toSparse()),
+ Row.of(
+ 3,
+ Vectors.dense(
+ -6.09864090e-01,
+ 1.07346276e+01,
+ -2.18805509e-01,
+ 7.33931213e-01,
+ 1.42554396e-01,
+ 7.11225605e-01)
+ .toSparse()),
+ Row.of(
+ 3,
+ Vectors.dense(
+ -1.58481268e+00,
+ 9.19364039e+00,
+ -5.87490459e-02,
+ 2.51532056e-01,
+ 2.82729807e-01,
+ 7.16245686e-01)
+ .toSparse()),
+ Row.of(
+ 3,
+ Vectors.dense(
+ -2.50949277e-01,
+ 1.12815254e+01,
+ -6.94806734e-01,
+ 5.93898886e-01,
+ 5.68425656e-01,
+ 8.49762330e-01)
+ .toSparse()),
+ Row.of(
+ 3,
+ Vectors.dense(
+ 7.63485129e-01,
+ 1.02605138e+01,
+ 1.32617719e+00,
+ 5.49682879e-01,
+ 8.59931442e-01,
+ 4.88677978e-02)
+ .toSparse()),
+ Row.of(
+ 4,
+ Vectors.dense(
+ 9.34900015e-01,
+ 4.11379043e-01,
+ 8.65010205e+00,
+ 9.23509168e-01,
+ 1.16995043e-01,
+ 5.91894106e-03)
+ .toSparse()),
+ Row.of(
+ 4,
+ Vectors.dense(
+ 4.73734933e-01,
+ -1.48321181e+00,
+ 9.73349621e+00,
+ 4.09421563e-01,
+ 5.09375719e-01,
+ 5.93157850e-01)
+ .toSparse()),
+ Row.of(
+ 4,
+ Vectors.dense(
+ 3.41470679e-01,
+ -6.88972582e-01,
+ 9.60347938e+00,
+ 3.62654055e-01,
+ 2.43437468e-01,
+ 7.13052838e-01)
+ .toSparse()),
+ Row.of(
+ 4,
+ Vectors.dense(
+ -5.29614251e-01,
+ -1.39262856e+00,
+ 1.01354144e+01,
+ 8.24123861e-01,
+ 5.84074506e-01,
+ 6.54461558e-01)
+ .toSparse()),
+ Row.of(
+ 4,
+ Vectors.dense(
+ -2.99454508e-01,
+ 2.20457263e+00,
+ 1.14586015e+01,
+ 5.16336729e-01,
+ 9.99776159e-01,
+ 3.15769738e-01)
+ .toSparse()));
+
+ private static final List<Row> INPUT_FVALUE_DATA =
+ Arrays.asList(
+ Row.of(
+ 0.52516321,
+ Vectors.dense(
+ 0.19151945,
+ 0.62210877,
+ 0.43772774,
+ 0.78535858,
+ 0.77997581,
+ 0.27259261)),
+ Row.of(
+ 0.88275782,
+ Vectors.dense(
+ 0.27646426,
+ 0.80187218,
+ 0.95813935,
+ 0.87593263,
+ 0.35781727,
+ 0.50099513)),
+ Row.of(
+ 0.67524507,
+ Vectors.dense(
+ 0.68346294,
+ 0.71270203,
+ 0.37025075,
+ 0.56119619,
+ 0.50308317,
+ 0.01376845)),
+ Row.of(
+ 0.76734745,
+ Vectors.dense(
+ 0.77282662,
+ 0.88264119,
+ 0.36488598,
+ 0.61539618,
+ 0.07538124,
+ 0.36882401)),
+ Row.of(
+ 0.73909458,
+ Vectors.dense(
+ 0.9331401,
+ 0.65137814,
+ 0.39720258,
+ 0.78873014,
+ 0.31683612,
+ 0.56809865)),
+ Row.of(
+ 0.83628141,
+ Vectors.dense(
+ 0.86912739,
+ 0.43617342,
+ 0.80214764,
+ 0.14376682,
+ 0.70426097,
+ 0.70458131)),
+ Row.of(
+ 0.65665506,
+ Vectors.dense(
+ 0.21879211,
+ 0.92486763,
+ 0.44214076,
+ 0.90931596,
+ 0.05980922,
+ 0.18428708)),
+ Row.of(
+ 0.58147135,
+ Vectors.dense(
+ 0.04735528,
+ 0.67488094,
+ 0.59462478,
+ 0.53331016,
+ 0.04332406,
+ 0.56143308)),
+ Row.of(
+ 0.35603443,
+ Vectors.dense(
+ 0.32966845,
+ 0.50296683,
+ 0.11189432,
+ 0.60719371,
+ 0.56594464,
+ 0.00676406)),
+ Row.of(
+ 0.94534373,
+ Vectors.dense(
+ 0.61744171,
+ 0.91212289,
+ 0.79052413,
+ 0.99208147,
+ 0.95880176,
+ 0.79196414)),
+ Row.of(
+ 0.57458887,
+ Vectors.dense(
+ 0.28525096,
+ 0.62491671,
+ 0.4780938,
+ 0.19567518,
+ 0.38231745,
+ 0.05387369)
+ .toSparse()),
+ Row.of(
+ 0.59026777,
+ Vectors.dense(
+ 0.45164841,
+ 0.98200474,
+ 0.1239427,
+ 0.1193809,
+ 0.73852306,
+ 0.58730363)
+ .toSparse()),
+ Row.of(
+ 0.29894977,
+ Vectors.dense(
+ 0.47163253,
+ 0.10712682,
+ 0.22921857,
+ 0.89996519,
+ 0.41675354,
+ 0.53585166)
+ .toSparse()),
+ Row.of(
+ 0.34056582,
+ Vectors.dense(
+ 0.00620852,
+ 0.30064171,
+ 0.43689317,
+ 0.612149,
+ 0.91819808,
+ 0.62573667)
+ .toSparse()),
+ Row.of(
+ 0.64476446,
+ Vectors.dense(
+ 0.70599757,
+ 0.14983372,
+ 0.74606341,
+ 0.83100699,
+ 0.63372577,
+ 0.43830988)
+ .toSparse()),
+ Row.of(
+ 0.53724782,
+ Vectors.dense(
+ 0.15257277,
+ 0.56840962,
+ 0.52822428,
+ 0.95142876,
+ 0.48035918,
+ 0.50255956)
+ .toSparse()),
+ Row.of(
+ 0.5173021,
+ Vectors.dense(
+ 0.53687819,
+ 0.81920207,
+ 0.05711564,
+ 0.66942174,
+ 0.76711663,
+ 0.70811536)
+ .toSparse()),
+ Row.of(
+ 0.94508275,
+ Vectors.dense(
+ 0.79686718,
+ 0.55776083,
+ 0.96583653,
+ 0.1471569,
+ 0.029647,
+ 0.59389349)
+ .toSparse()),
+ Row.of(
+ 0.57739736,
+ Vectors.dense(
+ 0.1140657,
+ 0.95080985,
+ 0.96583653,
+ 0.19361869,
+ 0.45781165,
+ 0.92040257)
+ .toSparse()),
+ Row.of(
+ 0.53877145,
+ Vectors.dense(
+ 0.87906916,
+ 0.25261576,
+ 0.34800879,
+ 0.18258873,
+ 0.90179605,
+ 0.70652816)
+ .toSparse()));
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ env.getConfig().enableObjectReuse();
+ tEnv = StreamTableEnvironment.create(env);
+
+ selectorWithChiSqTest =
+ new UnivariateFeatureSelector()
+ .setFeatureType("categorical")
+ .setLabelType("categorical");
+ selectorWithANOVATest =
+ new UnivariateFeatureSelector()
+ .setFeatureType("continuous")
+ .setLabelType("categorical");
+ selectorWithFValueTest =
+ new UnivariateFeatureSelector()
+ .setFeatureType("continuous")
+ .setLabelType("continuous");
+ inputChiSqTable =
+ tEnv.fromDataStream(
+ env.fromCollection(
+ INPUT_CHISQ_DATA,
+                                        Types.ROW(Types.DOUBLE, VectorTypeInfo.INSTANCE)))
+ .as("label", "features");
+ inputANOVATable =
+ tEnv.fromDataStream(
+ env.fromCollection(
+ INPUT_ANOVA_DATA,
+                                        Types.ROW(Types.INT, VectorTypeInfo.INSTANCE)))
+ .as("label", "features");
+ inputFValueTable =
+ tEnv.fromDataStream(
+ env.fromCollection(
+ INPUT_FVALUE_DATA,
+                                        Types.ROW(Types.DOUBLE, VectorTypeInfo.INSTANCE)))
+ .as("label", "features");
+ }
+
+ private void transformAndVerify(
+            UnivariateFeatureSelector selector, Table table, int... expectedIndices)
+            throws Exception {
+ UnivariateFeatureSelectorModel model = selector.fit(table);
+ Table output = model.transform(table)[0];
+ verifyOutputResult(output, expectedIndices);
+ }
+
+    private void verifyOutputResult(Table table, int... expectedIndices) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) table).getTableEnvironment();
+        CloseableIterator<Row> rowIterator = tEnv.toDataStream(table).executeAndCollect();
+ while (rowIterator.hasNext()) {
+ Row row = rowIterator.next();
+            assertEquals(expectedIndices.length, ((Vector) row.getField("output")).size());
+ for (int i = 0; i < expectedIndices.length; i++) {
+ assertEquals(
+                        ((Vector) row.getField("features")).get(expectedIndices[i]),
+ ((Vector) row.getField("output")).get(i),
+ EPS);
+ }
+ }
+ }
+
+ @Test
+ public void testParam() {
+ UnivariateFeatureSelector selector = new UnivariateFeatureSelector();
+ assertEquals("features", selector.getFeaturesCol());
+ assertEquals("label", selector.getLabelCol());
+ assertEquals("output", selector.getOutputCol());
+ assertEquals("numTopFeatures", selector.getSelectionMode());
+ assertNull(selector.getSelectionThreshold());
+
+ selector.setFeaturesCol("test_features")
+ .setLabelCol("test_label")
+ .setOutputCol("test_output")
+ .setFeatureType("continuous")
+ .setLabelType("categorical")
+ .setSelectionMode("fpr")
+ .setSelectionThreshold(0.01);
+
+ assertEquals("test_features", selector.getFeaturesCol());
+ assertEquals("test_label", selector.getLabelCol());
+ assertEquals("test_output", selector.getOutputCol());
+ assertEquals("continuous", selector.getFeatureType());
+ assertEquals("categorical", selector.getLabelType());
+ assertEquals("fpr", selector.getSelectionMode());
+ assertEquals(0.01, selector.getSelectionThreshold(), EPS);
+ }
+
+ @Test
+ public void testIncompatibleSelectionModeAndThreshold() {
+ UnivariateFeatureSelector selector =
+ new UnivariateFeatureSelector()
+ .setFeatureType("continuous")
+ .setLabelType("categorical")
+ .setSelectionThreshold(50.1);
+
+ try {
+ selector.fit(inputANOVATable);
+ fail();
+ } catch (Throwable e) {
+ assertEquals(
+ "SelectionThreshold needs to be a positive Integer "
+                            + "for selection mode numTopFeatures, but got 50.1.",
+ e.getMessage());
+ }
+ try {
+            selector.setSelectionMode("fpr").setSelectionThreshold(1.1).fit(inputANOVATable);
+ fail();
+ } catch (Throwable e) {
+ assertEquals(
+ "SelectionThreshold needs to be in the range [0, 1] "
+ + "for selection mode fpr, but got 1.1.",
+ e.getMessage());
+ }
+ }
+
+ @Test
+ public void testOutputSchema() {
+ Table tempTable = inputANOVATable.as("test_label", "test_features");
+ UnivariateFeatureSelector selector =
+ new UnivariateFeatureSelector()
+ .setLabelCol("test_label")
+ .setFeaturesCol("test_features")
+ .setOutputCol("test_output")
+ .setFeatureType("continuous")
+ .setLabelType("categorical");
+
+ UnivariateFeatureSelectorModel model = selector.fit(tempTable);
+ Table output = model.transform(tempTable)[0];
+ assertEquals(
+ Arrays.asList("test_label", "test_features", "test_output"),
+ output.getResolvedSchema().getColumnNames());
+ }
+
+ @Test
+ public void testFitTransformWithNumTopFeatures() throws Exception {
+        transformAndVerify(selectorWithChiSqTest.setSelectionThreshold(2), inputChiSqTable, 0, 1);
+        transformAndVerify(selectorWithANOVATest.setSelectionThreshold(2), inputANOVATable, 0, 2);
+        transformAndVerify(selectorWithFValueTest.setSelectionThreshold(2), inputFValueTable, 0, 2);
+ }
+
+ @Test
+ public void testFitTransformWithPercentile() throws Exception {
+        transformAndVerify(
+                selectorWithChiSqTest.setSelectionMode("percentile").setSelectionThreshold(0.17),
+                inputChiSqTable,
+                0);
+        transformAndVerify(
+                selectorWithANOVATest.setSelectionMode("percentile").setSelectionThreshold(0.17),
+                inputANOVATable,
+                0);
+        transformAndVerify(
+                selectorWithFValueTest.setSelectionMode("percentile").setSelectionThreshold(0.17),
+                inputFValueTable,
+                2);
+ }
+
+ @Test
+ public void testFitTransformWithFPR() throws Exception {
+        transformAndVerify(
+                selectorWithChiSqTest.setSelectionMode("fpr").setSelectionThreshold(0.02),
+                inputChiSqTable,
+                0);
+        transformAndVerify(
+                selectorWithANOVATest.setSelectionMode("fpr").setSelectionThreshold(1.0E-12),
+                inputANOVATable,
+                0);
+        transformAndVerify(
+                selectorWithFValueTest.setSelectionMode("fpr").setSelectionThreshold(0.01),
+                inputFValueTable,
+                2);
+ }
+
+ @Test
+ public void testFitTransformWithFDR() throws Exception {
+        transformAndVerify(
+                selectorWithChiSqTest.setSelectionMode("fdr").setSelectionThreshold(0.12),
+                inputChiSqTable,
+                0);
+        transformAndVerify(
+                selectorWithANOVATest.setSelectionMode("fdr").setSelectionThreshold(6.0E-12),
+                inputANOVATable,
+                0);
+        transformAndVerify(
+                selectorWithFValueTest.setSelectionMode("fdr").setSelectionThreshold(0.03),
+                inputFValueTable,
+                2);
+ }
+
+ @Test
+ public void testFitTransformWithFWE() throws Exception {
+        transformAndVerify(
+                selectorWithChiSqTest.setSelectionMode("fwe").setSelectionThreshold(0.12),
+                inputChiSqTable,
+                0);
+        transformAndVerify(
+                selectorWithANOVATest.setSelectionMode("fwe").setSelectionThreshold(6.0E-12),
+                inputANOVATable,
+                0);
+        transformAndVerify(
+                selectorWithFValueTest.setSelectionMode("fwe").setSelectionThreshold(0.03),
+                inputFValueTable,
+                2);
+ }
+
+ @Test
+ public void testSaveLoadAndPredict() throws Exception {
+ UnivariateFeatureSelector selector =
+ new UnivariateFeatureSelector()
+ .setFeatureType("continuous")
+ .setLabelType("categorical")
+ .setSelectionThreshold(1);
+
+        UnivariateFeatureSelector loadSelector =
+                TestUtils.saveAndReload(tEnv, selector, tempFolder.newFolder().getAbsolutePath());
+        UnivariateFeatureSelectorModel model = loadSelector.fit(inputANOVATable);
+        UnivariateFeatureSelectorModel loadedModel =
+                TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
+ assertEquals(
+ Collections.singletonList("indices"),
+ model.getModelData()[0].getResolvedSchema().getColumnNames());
+
+ Table output = loadedModel.transform(inputANOVATable)[0];
+ verifyOutputResult(output, 0);
+ }
+
+ @Test
+ public void testIncompatibleNumOfFeatures() {
+ UnivariateFeatureSelector selector =
+ new UnivariateFeatureSelector()
+ .setFeatureType("continuous")
+ .setLabelType("continuous")
+ .setSelectionThreshold(1);
+ UnivariateFeatureSelectorModel model = selector.fit(inputFValueTable);
+
+ List<Row> predictData =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of(1, Vectors.dense(1.0, 2.0)),
+ Row.of(-1, Vectors.dense(-1.0, -2.0))));
+ Table predictTable =
+                tEnv.fromDataStream(env.fromCollection(predictData)).as("label", "features");
+ Table output = model.transform(predictTable)[0];
+ try {
+ output.execute().print();
+ fail();
+ } catch (Throwable e) {
+ assertEquals(
+ "Input 2 features, but UnivariateFeatureSelector is "
+ + "expecting at least 3 features as input.",
+ ExceptionUtils.getRootCause(e).getMessage());
+ }
+ }
+
+ @Test
+ public void testEquivalentPValues() throws Exception {
+ List<Row> inputData =
+ Arrays.asList(
+                        Row.of(0.0, Vectors.dense(6.0, 7.0, 0.0, 6.0, 6.0, 6.0)),
+                        Row.of(1.0, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 0.0)),
+                        Row.of(1.0, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 0.0)),
+                        Row.of(1.0, Vectors.dense(0.0, 9.0, 8.0, 0.0, 6.0, 0.0)),
+                        Row.of(2.0, Vectors.dense(8.0, 9.0, 6.0, 8.0, 4.0, 8.0)),
+                        Row.of(2.0, Vectors.dense(8.0, 9.0, 6.0, 8.0, 0.0, 8.0)));
+ Table inputTable =
+ tEnv.fromDataStream(
+ env.fromCollection(
+ inputData,
+                                        Types.ROW(Types.DOUBLE, VectorTypeInfo.INSTANCE)))
+ .as("label", "features");
+ UnivariateFeatureSelectorModel model =
+ selectorWithChiSqTest.setSelectionThreshold(4).fit(inputTable);
+ Table modelData = model.getModelData()[0];
+ DataStream<Row> output = tEnv.toDataStream(modelData);
+ List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
+ int[] expectedIndices = {0, 3, 5, 1};
+        assertArrayEquals(expectedIndices, (int[]) modelRows.get(0).getField(0));
+ }
+
+ @Test
+ public void testGetModelData() throws Exception {
+ UnivariateFeatureSelector selector =
+ new UnivariateFeatureSelector()
+ .setFeatureType("continuous")
+ .setLabelType("categorical")
+ .setSelectionThreshold(3);
+ UnivariateFeatureSelectorModel model = selector.fit(inputANOVATable);
+ Table modelData = model.getModelData()[0];
+ assertEquals(
+ Collections.singletonList("indices"),
+ modelData.getResolvedSchema().getColumnNames());
+
+ DataStream<Row> output = tEnv.toDataStream(modelData);
+ List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
+ int[] expectedIndices = {0, 2, 1};
+        assertArrayEquals(expectedIndices, (int[]) modelRows.get(0).getField(0));
+ }
+
+ @Test
+ public void testSetModelData() throws Exception {
+ UnivariateFeatureSelector selector =
+                selectorWithANOVATest.setSelectionMode("fpr").setSelectionThreshold(1.0E-12);
+ UnivariateFeatureSelectorModel modelA = selector.fit(inputANOVATable);
+ Table modelData = modelA.getModelData()[0];
+
+ UnivariateFeatureSelectorModel modelB =
+ new UnivariateFeatureSelectorModel().setModelData(modelData);
+ Table output = modelB.transform(inputANOVATable)[0];
+ verifyOutputResult(output, 0);
+ }
+}
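A note on the selection-mode checks exercised by `testIncompatibleSelectionModeAndThreshold` above: `numTopFeatures` expects a positive integer, while `percentile`/`fpr`/`fdr`/`fwe` expect a value in [0, 1]. A minimal sketch of that validation rule in Python (the function name is hypothetical; the real check lives inside `UnivariateFeatureSelector.fit`, and the error strings below are those pinned by the test):

```python
def validate_selection_threshold(mode, threshold):
    """Reject thresholds that are incompatible with the selection mode.

    numTopFeatures counts features, so it needs a positive integer; the
    other modes compare against fractions or p-values in [0, 1].
    """
    if mode == "numTopFeatures":
        if threshold <= 0 or threshold != int(threshold):
            raise ValueError(
                "SelectionThreshold needs to be a positive Integer "
                "for selection mode {}, but got {}.".format(mode, threshold))
    elif mode in ("percentile", "fpr", "fdr", "fwe"):
        if not 0 <= threshold <= 1:
            raise ValueError(
                "SelectionThreshold needs to be in the range [0, 1] "
                "for selection mode {}, but got {}.".format(mode, threshold))
    else:
        raise ValueError("Unsupported selection mode {}.".format(mode))
```
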
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java
index 95f9779..3e2716b 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VarianceThresholdSelectorTest.java
@@ -20,17 +20,17 @@ package org.apache.flink.ml.feature;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelector;
import org.apache.flink.ml.feature.variancethresholdselector.VarianceThresholdSelectorModel;
-import org.apache.flink.ml.linalg.DenseVector;
-import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
import org.apache.flink.ml.util.TestUtils;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.internal.TableImpl;
@@ -48,7 +48,6 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
-import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
@@ -64,24 +63,25 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase {
private static final double EPS = 1.0e-5;
private static final List<Row> TRAIN_DATA =
-            new ArrayList<>(
-                    Arrays.asList(
-                            Row.of(1, Vectors.dense(5.0, 7.0, 0.0, 7.0, 6.0, 0.0)),
-                            Row.of(2, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0)),
-                            Row.of(3, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)),
-                            Row.of(4, Vectors.dense(1.0, 9.0, 8.0, 5.0, 7.0, 4.0)),
-                            Row.of(5, Vectors.dense(9.0, 8.0, 6.0, 5.0, 4.0, 4.0)),
-                            Row.of(6, Vectors.dense(6.0, 9.0, 7.0, 0.0, 2.0, 0.0))));
+            Arrays.asList(
+                    Row.of(1, Vectors.dense(5.0, 7.0, 0.0, 7.0, 6.0, 0.0)),
+                    Row.of(2, Vectors.dense(0.0, 9.0, 6.0, 0.0, 5.0, 9.0).toSparse()),
+                    Row.of(3, Vectors.dense(0.0, 9.0, 3.0, 0.0, 5.0, 5.0)),
+                    Row.of(4, Vectors.dense(1.0, 9.0, 8.0, 5.0, 7.0, 4.0).toSparse()),
+                    Row.of(5, Vectors.dense(9.0, 8.0, 6.0, 5.0, 4.0, 4.0)),
+                    Row.of(6, Vectors.dense(6.0, 9.0, 7.0, 0.0, 2.0, 0.0).toSparse()));
private static final List<Row> PREDICT_DATA =
-            new ArrayList<>(
-                    Arrays.asList(
-                            Row.of(Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)),
-                            Row.of(Vectors.dense(0.1, 0.2, 0.3, 0.4, 0.5, 0.6))));
+            Arrays.asList(
+                    Row.of(Vectors.dense(1.0, 2.0, 3.0, 4.0, 5.0, 6.0)),
+                    Row.of(Vectors.dense(0.1, 0.2, 0.3, 0.4, 0.5, 0.6)),
+                    Row.of(Vectors.sparse(6, new int[] {0, 3, 4}, new double[] {0.1, 0.3, 0.5})));
-    private static final List<DenseVector> EXPECTED_OUTPUT =
-            new ArrayList<>(
-                    Arrays.asList(Vectors.dense(1.0, 4.0, 6.0), Vectors.dense(0.1, 0.4, 0.6)));
+    private static final List<Vector> EXPECTED_OUTPUT =
+            Arrays.asList(
+                    Vectors.dense(1.0, 4.0, 6.0),
+                    Vectors.dense(0.1, 0.4, 0.6),
+                    Vectors.sparse(3, new int[] {0, 1}, new double[] {0.1, 0.3}));
@Before
public void before() {
@@ -94,20 +94,26 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase {
env.setRestartStrategy(RestartStrategies.noRestart());
tEnv = StreamTableEnvironment.create(env);
-        trainDataTable = tEnv.fromDataStream(env.fromCollection(TRAIN_DATA)).as("id", "input");
-        predictDataTable = tEnv.fromDataStream(env.fromCollection(PREDICT_DATA)).as("input");
+        trainDataTable =
+                tEnv.fromDataStream(
+                                env.fromCollection(
+                                        TRAIN_DATA, Types.ROW(Types.INT, VectorTypeInfo.INSTANCE)))
+                        .as("id", "input");
+        predictDataTable =
+                tEnv.fromDataStream(
+                                env.fromCollection(
+                                        PREDICT_DATA, Types.ROW(VectorTypeInfo.INSTANCE)))
+                        .as("input");
}
private static void verifyPredictionResult(
-            Table output, String outputCol, List<DenseVector> expected) throws Exception {
+            Table output, String outputCol, List<Vector> expected) throws Exception {
        StreamTableEnvironment tEnv =
                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
-        DataStream<DenseVector> stream =
+        DataStream<Vector> stream =
                tEnv.toDataStream(output)
-                        .map(
-                                (MapFunction<Row, DenseVector>)
-                                        row -> (DenseVector) row.getField(outputCol));
-        List<DenseVector> result = IteratorUtils.toList(stream.executeAndCollect());
+                        .map((MapFunction<Row, Vector>) row -> (Vector) row.getField(outputCol));
+        List<Vector> result = IteratorUtils.toList(stream.executeAndCollect());
compareResultCollections(expected, result, TestUtils::compare);
}
@@ -160,26 +166,7 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase {
verifyPredictionResult(
predictTableOutput,
varianceThresholdSelector.getOutputCol(),
-                new ArrayList<>(Arrays.asList(Vectors.dense(), Vectors.dense())));
-    }
-
-    @Test
-    public void testInputTypeConversion() throws Exception {
-        trainDataTable =
-                TestUtils.convertDataTypesToSparseInt(
-                        tEnv, trainDataTable.select(Expressions.$("input")));
-        predictDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, predictDataTable);
-        assertArrayEquals(
-                new Class<?>[] {SparseVector.class}, TestUtils.getColumnDataTypes(trainDataTable));
-        assertArrayEquals(
-                new Class<?>[] {SparseVector.class},
-                TestUtils.getColumnDataTypes(predictDataTable));
-
-        VarianceThresholdSelector varianceThresholdSelector =
-                new VarianceThresholdSelector().setVarianceThreshold(8.0);
-        VarianceThresholdSelectorModel model = varianceThresholdSelector.fit(trainDataTable);
-        Table output = model.transform(predictDataTable)[0];
-        verifyPredictionResult(output, varianceThresholdSelector.getOutputCol(), EXPECTED_OUTPUT);
+                Arrays.asList(Vectors.dense(), Vectors.dense(), Vectors.dense()));
}
@Test
@@ -202,12 +189,16 @@ public class VarianceThresholdSelectorTest extends AbstractTestBase {
@Test
public void testFitOnEmptyData() {
Table emptyTable =
-                tEnv.fromDataStream(env.fromCollection(TRAIN_DATA).filter(x -> x.getArity() == 0))
+                tEnv.fromDataStream(
+                        env.fromCollection(
+                                        TRAIN_DATA,
+                                        Types.ROW(Types.INT, VectorTypeInfo.INSTANCE))
+                                .filter(x -> x.getArity() == 0))
.as("id", "input");
+
        VarianceThresholdSelector varianceThresholdSelector = new VarianceThresholdSelector();
        VarianceThresholdSelectorModel model = varianceThresholdSelector.fit(emptyTable);
Table modelDataTable = model.getModelData()[0];
-
try {
modelDataTable.execute().print();
fail();
diff --git a/flink-ml-python/pyflink/examples/ml/feature/univariatefeatureselector_example.py b/flink-ml-python/pyflink/examples/ml/feature/univariatefeatureselector_example.py
new file mode 100644
index 0000000..e053183
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/univariatefeatureselector_example.py
@@ -0,0 +1,68 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates a UnivariateFeatureSelector instance and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.univariatefeatureselector import UnivariateFeatureSelector
+from pyflink.table import StreamTableEnvironment
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+
+env = StreamExecutionEnvironment.get_execution_environment()
+
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input training and prediction data.
+input_table = t_env.from_data_stream(
+ env.from_collection([
+ (Vectors.dense(1.7, 4.4, 7.6, 5.8, 9.6, 2.3), 3.0,),
+ (Vectors.dense(8.8, 7.3, 5.7, 7.3, 2.2, 4.1), 2.0,),
+ (Vectors.dense(1.2, 9.5, 2.5, 3.1, 8.7, 2.5), 1.0,),
+ (Vectors.dense(3.7, 9.2, 6.1, 4.1, 7.5, 3.8), 2.0,),
+ (Vectors.dense(8.9, 5.2, 7.8, 8.3, 5.2, 3.0), 4.0,),
+ (Vectors.dense(7.9, 8.5, 9.2, 4.0, 9.4, 2.1), 4.0,),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['features', 'label'],
+ [DenseVectorTypeInfo(), Types.FLOAT()])
+ ))
+
+# Creates a UnivariateFeatureSelector object and initializes its parameters.
+univariate_feature_selector = UnivariateFeatureSelector() \
+ .set_features_col('features') \
+ .set_label_col('label') \
+ .set_feature_type('continuous') \
+ .set_label_type('categorical') \
+ .set_selection_threshold(1)
+
+# Trains the UnivariateFeatureSelector Model.
+model = univariate_feature_selector.fit(input_table)
+
+# Uses the UnivariateFeatureSelector Model for predictions.
+output = model.transform(input_table)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+    input_index = field_names.index(univariate_feature_selector.get_features_col())
+    output_index = field_names.index(univariate_feature_selector.get_output_col())
+    print('Input Value: ' + str(result[input_index]) +
+          '\tOutput Value: ' + str(result[output_index]))
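For intuition about what the example above computes: with continuous features and a categorical label, the selector scores each feature with a one-way ANOVA F-test and keeps the highest-scoring ones. A self-contained sketch of that scoring in plain Python (illustrative only, not Flink ML's actual implementation; assumes at least two label groups and non-zero within-group variance):

```python
from collections import defaultdict

def anova_f_statistic(feature, labels):
    # One-way ANOVA F-statistic: between-group variance over within-group variance.
    groups = defaultdict(list)
    for x, y in zip(feature, labels):
        groups[y].append(x)
    n, k = len(feature), len(groups)
    grand_mean = sum(feature) / n
    # Between-group sum of squares (group means vs. grand mean).
    ssb = sum(len(g) * (sum(g) / len(g) - grand_mean) ** 2 for g in groups.values())
    # Within-group sum of squares (samples vs. their own group mean).
    ssw = sum(sum((x - sum(g) / len(g)) ** 2 for x in g) for g in groups.values())
    return (ssb / (k - 1)) / (ssw / (n - k))

def select_num_top_features(rows, labels, num_top):
    # Rank feature columns by F-statistic and keep the num_top best indices.
    scores = [
        anova_f_statistic([row[i] for row in rows], labels)
        for i in range(len(rows[0]))
    ]
    return sorted(range(len(scores)), key=lambda i: -scores[i])[:num_top]

# Feature 0 separates the two classes cleanly; feature 1 is noise.
rows = [(10.0, 0.3), (11.0, 0.9), (0.0, 0.5), (1.0, 0.4)]
print(select_num_top_features(rows, [1, 1, 2, 2], 1))  # [0]
```

The production operator additionally converts these F-statistics to p-values so that the `percentile`, `fpr`, `fdr`, and `fwe` modes can threshold on them.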
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_univariatefeatureselector.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_univariatefeatureselector.py
new file mode 100644
index 0000000..9e4a620
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_univariatefeatureselector.py
@@ -0,0 +1,203 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from typing import List
+
+from pyflink.common import Types
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase, update_existing_params
+
+from pyflink.ml.core.linalg import DenseVectorTypeInfo, Vectors
+
+from pyflink.ml.lib.feature.univariatefeatureselector import UnivariateFeatureSelector, \
+    UnivariateFeatureSelectorModel
+from pyflink.table import Table
+
+
+class UnivariateFeatureSelectorTest(PyFlinkMLTestCase):
+
+ def setUp(self):
+ super(UnivariateFeatureSelectorTest, self).setUp()
+ self.input_table = self.t_env.from_data_stream(
+ self.env.from_collection([
+            (1, Vectors.dense(4.65415496e-03, 1.03550567e-01, -1.17358140e+00,
+                              1.61408773e-01, 3.92492111e-01, 7.31240882e-01)),
+            (1, Vectors.dense(-9.01651741e-01, -5.28905302e-01, 1.27636785e+00,
+                              7.02154563e-01, 6.21348351e-01, 1.88397353e-01)),
+            (1, Vectors.dense(3.85692159e-01, -9.04639637e-01, 5.09782604e-02,
+                              8.40043971e-01, 7.45977857e-01, 8.78402288e-01)),
+            (1, Vectors.dense(1.36264353e+00, 2.62454094e-01, 7.96306202e-01,
+                              6.14948000e-01, 7.44948187e-01, 9.74034830e-01)),
+            (1, Vectors.dense(9.65874070e-01, 2.52773665e+00, -2.19380094e+00,
+                              2.33408080e-01, 1.86340919e-01, 8.23390433e-01)),
+            (2, Vectors.dense(1.12324305e+01, -2.77121515e-01, 1.12740513e-01,
+                              2.35184013e-01, 3.46668895e-01, 9.38500782e-02)),
+            (2, Vectors.dense(1.06195839e+01, -1.82891238e+00, 2.25085601e-01,
+                              9.09979851e-01, 6.80257535e-02, 8.24017480e-01)),
+            (2, Vectors.dense(1.12806837e+01, 1.30686889e+00, 9.32839108e-02,
+                              3.49784755e-01, 1.71322408e-02, 7.48465194e-02)),
+            (2, Vectors.dense(9.98689462e+00, 9.50808938e-01, -2.90786359e-01,
+                              2.31253009e-01, 7.46270968e-01, 1.60308169e-01)),
+            (2, Vectors.dense(1.08428551e+01, -1.02749936e+00, 1.73951508e-01,
+                              8.92482744e-02, 1.42651730e-01, 7.66751625e-01)),
+            (3, Vectors.dense(-1.98641448e+00, 1.12811990e+01, -2.35246756e-01,
+                              8.22809049e-01, 3.26739456e-01, 7.88268404e-01)),
+            (3, Vectors.dense(-6.09864090e-01, 1.07346276e+01, -2.18805509e-01,
+                              7.33931213e-01, 1.42554396e-01, 7.11225605e-01)),
+            (3, Vectors.dense(-1.58481268e+00, 9.19364039e+00, -5.87490459e-02,
+                              2.51532056e-01, 2.82729807e-01, 7.16245686e-01)),
+            (3, Vectors.dense(-2.50949277e-01, 1.12815254e+01, -6.94806734e-01,
+                              5.93898886e-01, 5.68425656e-01, 8.49762330e-01)),
+            (3, Vectors.dense(7.63485129e-01, 1.02605138e+01, 1.32617719e+00,
+                              5.49682879e-01, 8.59931442e-01, 4.88677978e-02)),
+            (4, Vectors.dense(9.34900015e-01, 4.11379043e-01, 8.65010205e+00,
+                              9.23509168e-01, 1.16995043e-01, 5.91894106e-03)),
+            (4, Vectors.dense(4.73734933e-01, -1.48321181e+00, 9.73349621e+00,
+                              4.09421563e-01, 5.09375719e-01, 5.93157850e-01)),
+            (4, Vectors.dense(3.41470679e-01, -6.88972582e-01, 9.60347938e+00,
+                              3.62654055e-01, 2.43437468e-01, 7.13052838e-01)),
+            (4, Vectors.dense(-5.29614251e-01, -1.39262856e+00, 1.01354144e+01,
+                              8.24123861e-01, 5.84074506e-01, 6.54461558e-01)),
+            (4, Vectors.dense(-2.99454508e-01, 2.20457263e+00, 1.14586015e+01,
+                              5.16336729e-01, 9.99776159e-01, 3.15769738e-01)),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['label', 'features'],
+ [Types.INT(), DenseVectorTypeInfo()])
+ ))
+
+ def test_param(self):
+ univariate_feature_selector = UnivariateFeatureSelector()
+ self.assertEqual('features', univariate_feature_selector.features_col)
+ self.assertEqual('label', univariate_feature_selector.label_col)
+ self.assertEqual('output', univariate_feature_selector.output_col)
+ with self.assertRaises(Exception) as context:
+ univariate_feature_selector.feature_type
+            self.assertTrue("Parameter featureType's value should not be null"
+                            in context.exception)
+        with self.assertRaises(Exception) as context:
+            univariate_feature_selector.label_type
+            self.assertTrue("Parameter labelType's value should not be null"
+                            in context.exception)
+        self.assertEqual('numTopFeatures', univariate_feature_selector.selection_mode)
+ self.assertIsNone(univariate_feature_selector.selection_threshold)
+
+ univariate_feature_selector\
+ .set_features_col("test_features")\
+ .set_label_col('test_label')\
+ .set_output_col('test_output')\
+ .set_feature_type('continuous')\
+ .set_label_type('categorical')\
+ .set_selection_mode('fpr')\
+ .set_selection_threshold(0.01)
+        self.assertEqual('test_features', univariate_feature_selector.features_col)
+        self.assertEqual('test_label', univariate_feature_selector.label_col)
+        self.assertEqual('test_output', univariate_feature_selector.output_col)
+        self.assertEqual('continuous', univariate_feature_selector.feature_type)
+        self.assertEqual('categorical', univariate_feature_selector.label_type)
+        self.assertEqual('fpr', univariate_feature_selector.selection_mode)
+        self.assertEqual(0.01, univariate_feature_selector.selection_threshold)
+
+ def test_output_schema(self):
+ selector = UnivariateFeatureSelector()\
+ .set_features_col("test_features")\
+ .set_label_col('test_label')\
+ .set_output_col('test_output')\
+ .set_feature_type('continuous')\
+ .set_label_type('categorical')
+ temp_table = self.input_table.alias('test_label', 'test_features')
+ model = selector.fit(temp_table)
+ output = model.transform(temp_table)[0]
+ self.assertEqual(
+ ['test_label', 'test_features', 'test_output'],
+ output.get_schema().get_field_names())
+
+ def test_fit_and_predict(self):
+ selector = UnivariateFeatureSelector() \
+ .set_feature_type('continuous') \
+ .set_label_type('categorical') \
+ .set_selection_threshold(3)
+ model = selector.fit(self.input_table)
+ output = model.transform(self.input_table)[0]
+ self.verify_output_result(
+ output,
+ output.get_schema().get_field_names(),
+ selector.get_features_col(),
+ selector.get_output_col(),
+ [0, 1, 2])
+
+ def test_get_model_data(self):
+ selector = UnivariateFeatureSelector() \
+ .set_feature_type('continuous') \
+ .set_label_type('categorical') \
+ .set_selection_threshold(3)
+ model = selector.fit(self.input_table)
+ model_data = model.get_model_data()[0]
+        self.assertEqual(['indices'], model_data.get_schema().get_field_names())
+
+        model_rows = [result for result in
+                      self.t_env.to_data_stream(model_data).execute_and_collect()]
+ self.assertEqual(1, len(model_rows))
+ self.assertListEqual([0, 2, 1], model_rows[0][0])
+
+ def test_set_model_data(self):
+ selector = UnivariateFeatureSelector() \
+ .set_feature_type('continuous') \
+ .set_label_type('categorical') \
+ .set_selection_threshold(3)
+ model_a = selector.fit(self.input_table)
+ model_data = model_a.get_model_data()[0]
+
+ model_b = UnivariateFeatureSelectorModel() \
+ .set_model_data(model_data)
+ update_existing_params(model_b, model_a)
+
+ output = model_b.transform(self.input_table)[0]
+ self.verify_output_result(
+ output,
+ output.get_schema().get_field_names(),
+ selector.get_features_col(),
+ selector.get_output_col(),
+ [0, 1, 2])
+
+ def test_save_load_predict(self):
+ selector = UnivariateFeatureSelector() \
+ .set_feature_type('continuous') \
+ .set_label_type('categorical') \
+ .set_selection_threshold(3)
+ reloaded_selector = self.save_and_reload(selector)
+ model = reloaded_selector.fit(self.input_table)
+ reloaded_model = self.save_and_reload(model)
+ output = reloaded_model.transform(self.input_table)[0]
+ self.verify_output_result(
+ output,
+ output.get_schema().get_field_names(),
+ selector.get_features_col(),
+ selector.get_output_col(),
+ [0, 1, 2])
+
+ def verify_output_result(
+ self, output: Table,
+ field_names: List[str],
+ feature_col: str,
+ output_col: str,
+ indices: List[int]):
+        collected_results = [result for result in
+                             self.t_env.to_data_stream(output).execute_and_collect()]
+ for item in collected_results:
+ item.set_field_names(field_names)
+ self.assertEqual(len(indices), item[output_col].size())
+ for i in range(0, len(indices)):
+                self.assertEqual(item[feature_col].get(indices[i]),
+                                 item[output_col].get(i))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/univariatefeatureselector.py b/flink-ml-python/pyflink/ml/lib/feature/univariatefeatureselector.py
new file mode 100644
index 0000000..0b59945
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/univariatefeatureselector.py
@@ -0,0 +1,208 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import typing
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.core.param import StringParam, FloatParam
+from pyflink.ml.lib.param import HasFeaturesCol, HasLabelCol, HasOutputCol
+from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator
+
+
+class _UnivariateFeatureSelectorModelParams(
+ JavaWithParams,
+ HasFeaturesCol,
+ HasOutputCol
+):
+ """
+ Params for :class:`UnivariateFeatureSelectorModel`.
+ """
+ def __init__(self, java_params):
+ super(_UnivariateFeatureSelectorModelParams, self).__init__(java_params)
+
+
+class _UnivariateFeatureSelectorParams(HasLabelCol, _UnivariateFeatureSelectorModelParams):
+ """
+ Params for :class:`UnivariateFeatureSelector`.
+ """
+
+ """
+ Supported options of the feature type.
+
+ <ul>
+ <li>categorical: the features are categorical data.
+ <li>continuous: the features are continuous data.
+ </ul>
+ """
+ FEATURE_TYPE: StringParam = StringParam(
+ "feature_type",
+ "The feature type.",
+ None)
+
+ """
+ Supported options of the label type.
+
+ <ul>
+ <li>categorical: the label is categorical data.
+ <li>continuous: the label is continuous data.
+ </ul>
+ """
+ LABEL_TYPE: StringParam = StringParam(
+ "label_type",
+ "The label type.",
+ None)
+
+ """
+ Supported options of the feature selection mode.
+
+ <ul>
+ <li>numTopFeatures: chooses a fixed number of top features according to a hypothesis.
+ <li>percentile: similar to numTopFeatures but chooses a fraction of all features
+ instead of a fixed number.
+ <li>fpr: chooses all features whose p-values are below a threshold, thus controlling the
+ false positive rate of selection.
+ <li>fdr: uses the <a href="https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure">Benjamini-Hochberg procedure</a> to choose
+ all features whose false discovery rate is below a threshold.
+ <li>fwe: chooses all features whose p-values are below a threshold. The threshold is
+ scaled by 1/numFeatures, thus controlling the family-wise error rate of selection.
+ </ul>
+ """
+ SELECTION_MODE: StringParam = StringParam(
+ "selection_mode",
+ "The feature selection mode.",
+ "numTopFeatures")
+
+ SELECTION_THRESHOLD: FloatParam = FloatParam(
+ "selection_threshold",
+ "The upper bound of the features that selector will select. If not set, it will be "
+ "replaced with a meaningful value according to different selection modes at runtime. "
+ "When the mode is numTopFeatures, it will be replaced with 50; when the mode is "
+ "percentile, it will be replaced with 0.1; otherwise, it will be replaced with 0.05.",
+ None)
+
+ def __init__(self, java_params):
+ super(_UnivariateFeatureSelectorParams, self).__init__(java_params)
+
+ def set_feature_type(self, value: str):
+ return typing.cast(_UnivariateFeatureSelectorParams, self.set(self.FEATURE_TYPE, value))
+
+ def get_feature_type(self) -> str:
+ return self.get(self.FEATURE_TYPE)
+
+ def set_label_type(self, value: str):
+ return typing.cast(_UnivariateFeatureSelectorParams, self.set(self.LABEL_TYPE, value))
+
+ def get_label_type(self) -> str:
+ return self.get(self.LABEL_TYPE)
+
+ def set_selection_mode(self, value: str):
+ return typing.cast(_UnivariateFeatureSelectorParams, self.set(self.SELECTION_MODE, value))
+
+ def get_selection_mode(self) -> str:
+ return self.get(self.SELECTION_MODE)
+
+ def set_selection_threshold(self, value: float):
+ return typing.cast(_UnivariateFeatureSelectorParams,
+ self.set(self.SELECTION_THRESHOLD, float(value)))
+
+ def get_selection_threshold(self) -> float:
+ return self.get(self.SELECTION_THRESHOLD)
+
+ @property
+ def feature_type(self):
+ return self.get_feature_type()
+
+ @property
+ def label_type(self):
+ return self.get_label_type()
+
+ @property
+ def selection_mode(self):
+ return self.get_selection_mode()
+
+ @property
+ def selection_threshold(self):
+ return self.get_selection_threshold()
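To make the five selection modes and the default thresholds concrete, here is a minimal pure-Python sketch of how a selector could pick feature indices from per-feature p-values. The helper name `select_indices` and its signature are illustrative assumptions, not part of the Flink ML API; the default thresholds mirror the `SELECTION_THRESHOLD` description above:

```python
def select_indices(p_values, mode, threshold=None):
    """Sketch: pick feature indices from per-feature p-values.

    Assumed defaults follow the param description: 50 for numTopFeatures,
    0.1 for percentile, 0.05 otherwise."""
    if threshold is None:
        threshold = {"numTopFeatures": 50, "percentile": 0.1}.get(mode, 0.05)
    n = len(p_values)
    # Rank features by ascending p-value (smaller = stronger evidence).
    order = sorted(range(n), key=lambda i: p_values[i])
    if mode == "numTopFeatures":
        return sorted(order[:int(threshold)])
    if mode == "percentile":
        return sorted(order[:int(threshold * n)])
    if mode == "fpr":
        return [i for i in range(n) if p_values[i] < threshold]
    if mode == "fdr":
        # Benjamini-Hochberg: keep the k smallest p-values, where k is the
        # largest j such that p_(j) <= threshold * j / n.
        k = max((j + 1 for j in range(n)
                 if p_values[order[j]] <= threshold * (j + 1) / n), default=0)
        return sorted(order[:k])
    if mode == "fwe":
        # Threshold scaled by 1/numFeatures (family-wise error control).
        return [i for i in range(n) if p_values[i] < threshold / n]
    raise ValueError("unsupported selection mode: " + mode)
```

For example, with p-values `[0.001, 0.2, 0.01, 0.04]`, `numTopFeatures` with threshold 2 keeps the two smallest p-values (indices 0 and 2), while `fpr` at 0.05 additionally keeps index 3.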
+
+
+class UnivariateFeatureSelectorModel(JavaFeatureModel, _UnivariateFeatureSelectorModelParams):
+ """
+ A Model which transforms data using the model data computed
+ by :class:`UnivariateFeatureSelector`.
+ """
+
+ def __init__(self, java_model=None):
+ super(UnivariateFeatureSelectorModel, self).__init__(java_model)
+
+ @classmethod
+ def _java_model_package_name(cls) -> str:
+ return "univariatefeatureselector"
+
+ @classmethod
+ def _java_model_class_name(cls) -> str:
+ return "UnivariateFeatureSelectorModel"
+
+
+class UnivariateFeatureSelector(JavaFeatureEstimator, _UnivariateFeatureSelectorParams):
+ """
+ An Estimator which selects features based on univariate statistical tests against labels.
+
+ Currently, Flink supports three Univariate Feature Selectors: chi-squared, ANOVA F-test and
+ F-value. Users can choose a Univariate Feature Selector by setting `featureType` and `labelType`,
+ and Flink will pick the score function based on the specified `featureType` and `labelType`.
+
+ The following combinations of `featureType` and `labelType` are supported:
+
+ <ul>
+ <li>`featureType` `categorical` and `labelType` `categorical`: Flink uses chi-squared,
+ i.e. chi2 in sklearn.
+ <li>`featureType` `continuous` and `labelType` `categorical`: Flink uses ANOVA F-test,
+ i.e. f_classif in sklearn.
+ <li>`featureType` `continuous` and `labelType` `continuous`: Flink uses F-value,
+ i.e. f_regression in sklearn.
+ </ul>
+
+ The `UnivariateFeatureSelector` supports different selection modes:
+
+ <ul>
+ <li>numTopFeatures: chooses a fixed number of top features according to a hypothesis.
+ <li>percentile: similar to numTopFeatures but chooses a fraction of all features
+ instead of a fixed number.
+ <li>fpr: chooses all features whose p-values are below a threshold, thus controlling
+ the false positive rate of selection.
+ <li>fdr: uses the <a href="https://en.wikipedia.org/wiki/False_discovery_rate#Benjamini.E2.80.93Hochberg_procedure">Benjamini-Hochberg procedure</a> to choose
+ all features whose false discovery rate is below a threshold.
+ <li>fwe: chooses all features whose p-values are below a threshold. The threshold is
+ scaled by 1/numFeatures, thus controlling the family-wise error rate of selection.
+ </ul>
+
+ By default, the selection mode is `numTopFeatures`.
+ """
+
+ def __init__(self):
+ super(UnivariateFeatureSelector, self).__init__()
+
+ @classmethod
+ def _create_model(cls, java_model) -> UnivariateFeatureSelectorModel:
+ return UnivariateFeatureSelectorModel(java_model)
+
+ @classmethod
+ def _java_estimator_package_name(cls) -> str:
+ return "univariatefeatureselector"
+
+ @classmethod
+ def _java_estimator_class_name(cls) -> str:
+ return "UnivariateFeatureSelector"
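The estimator docstring maps each supported `featureType`/`labelType` pair to a score function with an sklearn equivalent. A small sketch of that dispatch (the helper `score_function` is illustrative, not part of the Flink ML API; the returned strings are the sklearn names cited in the docstring):

```python
def score_function(feature_type: str, label_type: str) -> str:
    """Return the sklearn-equivalent score function for a supported
    (featureType, labelType) combination, per the docstring above."""
    mapping = {
        ("categorical", "categorical"): "chi2",        # chi-squared test
        ("continuous", "categorical"): "f_classif",    # ANOVA F-test
        ("continuous", "continuous"): "f_regression",  # F-value
    }
    try:
        return mapping[(feature_type, label_type)]
    except KeyError:
        raise ValueError(
            "unsupported combination: featureType=%s, labelType=%s"
            % (feature_type, label_type))
```

Note that `categorical` features with a `continuous` label is the one combination the docstring does not list, so the sketch rejects it.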