This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 560532d9 [FLINK-31189] Add HasMaxIndexNum param to StringIndexer
560532d9 is described below
commit 560532d99c4330949d4da2c94d0e5228bdfbc9cd
Author: Zhipeng Zhang <[email protected]>
AuthorDate: Mon Apr 3 14:30:12 2023 +0800
[FLINK-31189] Add HasMaxIndexNum param to StringIndexer
This closes #221.
---
.../docs/operators/feature/stringindexer.md | 9 +--
.../ml/feature/stringindexer/StringIndexer.java | 43 +++++++++++++--
.../feature/stringindexer/StringIndexerModel.java | 4 +-
.../feature/stringindexer/StringIndexerParams.java | 17 ++++++
.../feature/stringindexer/StringIndexerTest.java | 64 ++++++++++++++++++++++
.../pyflink/ml/feature/stringindexer.py | 31 +++++++++--
.../pyflink/ml/feature/tests/test_stringindexer.py | 28 ++++++++++
7 files changed, 180 insertions(+), 16 deletions(-)
diff --git a/docs/content/docs/operators/feature/stringindexer.md
b/docs/content/docs/operators/feature/stringindexer.md
index 10e02dc1..110f8001 100644
--- a/docs/content/docs/operators/feature/stringindexer.md
+++ b/docs/content/docs/operators/feature/stringindexer.md
@@ -38,7 +38,7 @@ StringIndexerModel.
### Input Columns
| Param name | Type | Default | Description
|
-| :--------- | :------------ | :------
|:---------------------------------------|
+|:-----------|:--------------|:--------|:---------------------------------------|
| inputCols | Number/String | `null` | String/Numerical values to be
indexed. |
### Output Columns
@@ -59,9 +59,10 @@ Below are the parameters required by `StringIndexerModel`.
`StringIndexer` needs parameters above and also below.
-| Key | Default | Type | Required | Description
|
-|-----------------|---------------|--------|----------|-------------------------------------------------------------------------------------------------------------------------------------|
-| stringOrderType | `"arbitrary"` | String | no | How to order strings
of each column. Supported values: 'arbitrary', 'frequencyDesc', 'frequencyAsc',
'alphabetDesc', 'alphabetAsc'. |
+| Key | Default | Type | Required | Description
|
+|-----------------|---------------|---------|----------|-------------------------------------------------------------------------------------------------------------------------------------|
+| stringOrderType | `"arbitrary"` | String | no | How to order strings
of each column. Supported values: 'arbitrary', 'frequencyDesc', 'frequencyAsc',
'alphabetDesc', 'alphabetAsc'. |
+| maxIndexNum | `2147483647` | Integer | no | The max number of
indices for each column. It only works when 'stringOrderType' is set as
'frequencyDesc'. |
### Examples
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
index 51dd727c..3c214d4d 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
@@ -64,8 +64,12 @@ import java.util.Map.Entry;
* is arbitrarily ordered. Users can control this by setting {@link
* StringIndexerParams#STRING_ORDER_TYPE}.
*
- * <p>The `keep` option of {@link HasHandleInvalid} means that we put the
invalid entries in a
- * special bucket, whose index is the number of distinct values in this column.
+ * <p>Users can also control the max number of output indices by setting {@link
+ * StringIndexerParams#MAX_INDEX_NUM}. This parameter only works if {@link
+ * StringIndexerParams#STRING_ORDER_TYPE} is set as 'frequencyDesc'.
+ *
+ * <p>The `keep` option of {@link HasHandleInvalid} means that we transform
the invalid input into a
+ * special index, whose value is the number of distinct values in this column.
*/
public class StringIndexer
implements Estimator<StringIndexer, StringIndexerModel>,
@@ -96,6 +100,17 @@ public class StringIndexer
String[] inputCols = getInputCols();
String[] outputCols = getOutputCols();
Preconditions.checkArgument(inputCols.length == outputCols.length);
+ if (getMaxIndexNum() < Integer.MAX_VALUE) {
+ Preconditions.checkArgument(
+
getStringOrderType().equals(StringIndexerParams.FREQUENCY_DESC_ORDER),
+ "Setting "
+ + MAX_INDEX_NUM.name
+ + " smaller than INT.MAX only works when "
+ + STRING_ORDER_TYPE.name
+ + " is set as "
+ + StringIndexerParams.FREQUENCY_DESC_ORDER
+ + ".");
+ }
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl)
inputs[0]).getTableEnvironment();
@@ -127,7 +142,7 @@ public class StringIndexer
Types.OBJECT_ARRAY(Types.MAP(Types.STRING,
Types.LONG)));
DataStream<StringIndexerModelData> modelData =
- countedString.map(new ModelGenerator(getStringOrderType()));
+ countedString.map(new ModelGenerator(getStringOrderType(),
getMaxIndexNum()));
modelData.getTransformation().setParallelism(1);
StringIndexerModel model =
@@ -205,14 +220,19 @@ public class StringIndexer
/**
* Merges all the extracted strings and generates the {@link
StringIndexerModelData} according
- * to the specified string order type.
+ * to the specified string order type and maxIndexNum.
+ *
+ * <p>Note that the maxIndexNum works only when the strings are ordered by
{@link
+ * StringIndexerParams#FREQUENCY_DESC_ORDER}.
*/
private static class ModelGenerator
implements MapFunction<Map<String, Long>[],
StringIndexerModelData> {
private final String stringOrderType;
+ private final int maxIndexNum;
- public ModelGenerator(String stringOrderType) {
+ public ModelGenerator(String stringOrderType, int maxIndexNum) {
this.stringOrderType = stringOrderType;
+ this.maxIndexNum = maxIndexNum;
}
@Override
@@ -220,6 +240,7 @@ public class StringIndexer
int numCols = value.length;
String[][] stringArrays = new String[numCols][];
ArrayList<Tuple2<String, Long>> stringsAndCnts = new ArrayList<>();
+
for (int i = 0; i < numCols; i++) {
stringsAndCnts.clear();
stringsAndCnts.ensureCapacity(value[i].size());
@@ -242,6 +263,18 @@ public class StringIndexer
stringsAndCnts.sort(
(valAndCnt1, valAndCnt2) ->
-valAndCnt1.f1.compareTo(valAndCnt2.f1));
+
+ if (stringsAndCnts.size() > maxIndexNum) {
+ ArrayList<Tuple2<String, Long>>
frequentStringsAndCnts =
+ new ArrayList<>();
+ // Reserves the last index for infrequent elements.
+ frequentStringsAndCnts.ensureCapacity(maxIndexNum
- 1);
+ for (int indexId = 0; indexId < maxIndexNum - 1;
indexId++) {
+
frequentStringsAndCnts.add(stringsAndCnts.get(indexId));
+ }
+ stringsAndCnts = frequentStringsAndCnts;
+ }
+
break;
case ARBITRARY_ORDER:
break;
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
index e5ad4700..490c1c14 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
@@ -49,8 +49,8 @@ import java.util.Map;
* A Model which transforms input string/numeric column(s) to double column(s)
using the model data
* computed by {@link StringIndexer}.
*
- * <p>The `keep` option of {@link HasHandleInvalid} means that we put the
invalid entries in a
- * special bucket, whose index is the number of distinct values in this column.
+ * <p>The `keep` option of {@link HasHandleInvalid} means that we transform
the invalid input into a
+ * special index, whose value is the number of distinct values in this column.
*/
public class StringIndexerModel
implements Model<StringIndexerModel>,
StringIndexerModelParams<StringIndexerModel> {
diff --git
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerParams.java
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerParams.java
index 61c23cfd..e6b8a2be 100644
---
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerParams.java
+++
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerParams.java
@@ -18,6 +18,7 @@
package org.apache.flink.ml.feature.stringindexer;
+import org.apache.flink.ml.param.IntParam;
import org.apache.flink.ml.param.Param;
import org.apache.flink.ml.param.ParamValidators;
import org.apache.flink.ml.param.StringParam;
@@ -58,6 +59,14 @@ public interface StringIndexerParams<T> extends
StringIndexerModelParams<T> {
ALPHABET_DESC_ORDER,
ALPHABET_ASC_ORDER));
+ Param<Integer> MAX_INDEX_NUM =
+ new IntParam(
+ "maxIndexNum",
+ "The max number of indices for each column. It only works
when "
+ + "'stringOrderType' is set as 'frequencyDesc'.",
+ Integer.MAX_VALUE,
+ ParamValidators.gt(1));
+
default String getStringOrderType() {
return get(STRING_ORDER_TYPE);
}
@@ -65,4 +74,12 @@ public interface StringIndexerParams<T> extends
StringIndexerModelParams<T> {
default T setStringOrderType(String value) {
return set(STRING_ORDER_TYPE, value);
}
+
+ default int getMaxIndexNum() {
+ return get(MAX_INDEX_NUM);
+ }
+
+ default T setMaxIndexNum(int value) {
+ return set(MAX_INDEX_NUM, value);
+ }
}
diff --git
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
index 56f353e5..9f4a5ae8 100644
---
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
+++
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
@@ -42,6 +42,8 @@ import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
+import static
org.apache.flink.ml.feature.stringindexer.StringIndexerParams.MAX_INDEX_NUM;
+import static
org.apache.flink.ml.feature.stringindexer.StringIndexerParams.STRING_ORDER_TYPE;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@@ -125,16 +127,19 @@ public class StringIndexerTest extends AbstractTestBase {
StringIndexer stringIndexer = new StringIndexer();
assertEquals(stringIndexer.getStringOrderType(),
StringIndexerParams.ARBITRARY_ORDER);
assertEquals(stringIndexer.getHandleInvalid(),
StringIndexerParams.ERROR_INVALID);
+ assertEquals(stringIndexer.getMaxIndexNum(), Integer.MAX_VALUE);
stringIndexer
.setInputCols("inputCol1", "inputCol2")
.setOutputCols("outputCol1", "outputCol2")
.setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
+ .setMaxIndexNum(100)
.setHandleInvalid(StringIndexerParams.SKIP_INVALID);
assertArrayEquals(new String[] {"inputCol1", "inputCol2"},
stringIndexer.getInputCols());
assertArrayEquals(new String[] {"outputCol1", "outputCol2"},
stringIndexer.getOutputCols());
assertEquals(stringIndexer.getStringOrderType(),
StringIndexerParams.ALPHABET_ASC_ORDER);
+ assertEquals(stringIndexer.getMaxIndexNum(), 100);
assertEquals(stringIndexer.getHandleInvalid(),
StringIndexerParams.SKIP_INVALID);
}
@@ -209,6 +214,65 @@ public class StringIndexerTest extends AbstractTestBase {
assertEquals(3, distinctStringsCol2.size());
}
+ @Test
+ public void testMaxIndexNum() throws Exception {
+ StringIndexer stringIndexer =
+ new StringIndexer()
+ .setMaxIndexNum(3)
+ .setInputCols("inputCol1", "inputCol2")
+ .setOutputCols("outputCol1", "outputCol2")
+ .setHandleInvalid(StringIndexerParams.KEEP_INVALID);
+ Table output;
+ List<Row> predictedResult;
+ final String expectedErrorMessage =
+ "Setting "
+ + MAX_INDEX_NUM.name
+ + " smaller than INT.MAX only works when "
+ + STRING_ORDER_TYPE.name
+ + " is set as "
+ + StringIndexerParams.FREQUENCY_DESC_ORDER
+ + ".";
+
+ // AlphabetAsc order.
+
stringIndexer.setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER);
+ checkMaxIndexNumSettings(stringIndexer, expectedErrorMessage);
+
+ // AlphabetDesc order.
+
stringIndexer.setStringOrderType(StringIndexerParams.ALPHABET_DESC_ORDER);
+ checkMaxIndexNumSettings(stringIndexer, expectedErrorMessage);
+
+ // FrequencyAsc order.
+
stringIndexer.setStringOrderType(StringIndexerParams.FREQUENCY_ASC_ORDER);
+ checkMaxIndexNumSettings(stringIndexer, expectedErrorMessage);
+
+ // FrequencyDesc order.
+ final List<Row> expectedPredictData =
+ Arrays.asList(
+ Row.of("a", 2.0, 1.0, 0.0),
+ Row.of("b", 1.0, 0.0, 2.0),
+ Row.of("e", 2.0, 2.0, 0.0),
+ Row.of("f", null, 2.0, 2.0),
+ Row.of(null, null, 2.0, 2.0));
+
stringIndexer.setStringOrderType(StringIndexerParams.FREQUENCY_DESC_ORDER);
+ output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+ predictedResult =
IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+ verifyPredictionResult(expectedPredictData, predictedResult);
+
+ // Arbitrary order.
+ stringIndexer.setStringOrderType(StringIndexerParams.ARBITRARY_ORDER);
+ checkMaxIndexNumSettings(stringIndexer, expectedErrorMessage);
+ }
+
+ private void checkMaxIndexNumSettings(
+ StringIndexer stringIndexer, String expectedErrorMessage) {
+ try {
+ stringIndexer.fit(trainTable);
+ fail();
+ } catch (Exception e) {
+ assertEquals(expectedErrorMessage, e.getMessage());
+ }
+ }
+
@Test
@SuppressWarnings("unchecked")
public void testHandleInvalid() throws Exception {
diff --git a/flink-ml-python/pyflink/ml/feature/stringindexer.py
b/flink-ml-python/pyflink/ml/feature/stringindexer.py
index c82c7cd2..999f585d 100644
--- a/flink-ml-python/pyflink/ml/feature/stringindexer.py
+++ b/flink-ml-python/pyflink/ml/feature/stringindexer.py
@@ -17,7 +17,7 @@
################################################################################
import typing
-from pyflink.ml.param import Param, StringParam, ParamValidators
+from pyflink.ml.param import Param, IntParam, StringParam, ParamValidators
from pyflink.ml.wrapper import JavaWithParams
from pyflink.ml.feature.common import JavaFeatureModel, JavaFeatureEstimator
from pyflink.ml.common.param import HasInputCols, HasOutputCols,
HasHandleInvalid
@@ -62,6 +62,13 @@ class _StringIndexerParams(_StringIndexerModelParams):
ParamValidators.in_array(
['arbitrary', 'frequencyDesc', 'frequencyAsc', 'alphabetDesc',
'alphabetAsc']))
+ MAX_INDEX_NUM: Param[int] = IntParam(
+ "max_index_num",
+ "The max number of indices for each column. It only works when "
+ + "'stringOrderType' is set as 'frequencyDesc'.",
+ 2 ** 31 - 1,
+ ParamValidators.gt(1))
+
def __init__(self, java_params):
super(_StringIndexerParams, self).__init__(java_params)
@@ -71,10 +78,20 @@ class _StringIndexerParams(_StringIndexerModelParams):
def get_string_order_type(self) -> str:
return self.get(self.STRING_ORDER_TYPE)
+ def set_max_index_num(self, value: int):
+ return typing.cast(_StringIndexerParams, self.set(self.MAX_INDEX_NUM,
value))
+
+ def get_max_index_num(self) -> int:
+ return self.get(self.MAX_INDEX_NUM)
+
@property
def string_order_type(self):
return self.get_string_order_type()
+ @property
+ def max_index_num(self):
+ return self.get_max_index_num()
+
class IndexToStringModel(JavaFeatureModel, _IndexToStringModelParams):
"""
@@ -99,8 +116,8 @@ class StringIndexerModel(JavaFeatureModel,
_StringIndexerModelParams):
A Model which transforms input string/numeric column(s) to integer
column(s) using the model
data computed by :class:StringIndexer.
- The `keep` option of {@link HasHandleInvalid} means that we put the
invalid entries in a
- special bucket, whose index is the number of distinct values in this
column.
+ The `keep` option of {@link HasHandleInvalid} means that we transform the
invalid input
+ into a special index, whose value is the number of distinct values in this
column.
"""
def __init__(self, java_model=None):
@@ -128,8 +145,12 @@ class StringIndexer(JavaFeatureEstimator,
_StringIndexerParams):
is arbitrarily ordered. Users can control this by setting {@link
StringIndexerParams#STRING_ORDER_TYPE}.
- The `keep` option of {@link HasHandleInvalid} means that we put the
invalid entries in a
- special bucket, whose index is the number of distinct values in this
column.
+ Users can also control the max number of output indices by setting {@link
+ StringIndexerParams#MAX_INDEX_NUM}. This parameter only works if {@link
+ StringIndexerParams#STRING_ORDER_TYPE} is set as 'frequencyDesc'.
+
+ The `keep` option of {@link HasHandleInvalid} means that we transform the
invalid input
+ into a special index, whose value is the number of distinct values in this
column.
"""
def __init__(self):
diff --git a/flink-ml-python/pyflink/ml/feature/tests/test_stringindexer.py
b/flink-ml-python/pyflink/ml/feature/tests/test_stringindexer.py
index ec2e5174..aa61503e 100644
--- a/flink-ml-python/pyflink/ml/feature/tests/test_stringindexer.py
+++ b/flink-ml-python/pyflink/ml/feature/tests/test_stringindexer.py
@@ -70,15 +70,18 @@ class StringIndexerTest(PyFlinkMLTestCase):
self.assertEqual('arbitrary', string_indexer.string_order_type)
self.assertEqual('error', string_indexer.handle_invalid)
+ self.assertEqual(2 ** 31 - 1, string_indexer.max_index_num)
string_indexer.set_input_cols('input_col1', 'input_col2') \
.set_output_cols('output_col1', 'output_col2') \
.set_string_order_type('alphabetAsc') \
+ .set_max_index_num(100) \
.set_handle_invalid('skip')
self.assertEqual(('input_col1', 'input_col2'),
string_indexer.input_cols)
self.assertEqual(('output_col1', 'output_col2'),
string_indexer.output_cols)
self.assertEqual('alphabetAsc', string_indexer.string_order_type)
+ self.assertEqual(100, string_indexer.max_index_num)
self.assertEqual('skip', string_indexer.handle_invalid)
def test_output_schema(self):
@@ -110,6 +113,31 @@ class StringIndexerTest(PyFlinkMLTestCase):
self.assertEqual(predicted_results,
self.expected_alphabetic_asc_predict_data)
+ def test_max_index_num(self):
+ string_indexer = StringIndexer() \
+ .set_max_index_num(3) \
+ .set_input_cols('input_col1', 'input_col2') \
+ .set_output_cols('output_col1', 'output_col2') \
+ .set_handle_invalid('keep') \
+ .set_string_order_type("frequencyDesc")
+
+ expected_predict_data = [
+ Row('a', 2.0, 1, 0),
+ Row('b', 1.0, 0, 2),
+ Row('e', 2.0, 2, 0),
+ Row('f', None, 2, 2),
+ Row(None, None, 2, 2),
+ ]
+
+ output =
string_indexer.fit(self.train_table).transform(self.predict_table)[0]
+
+ predicted_results = [result for result in
+
self.t_env.to_data_stream(output).execute_and_collect()]
+
+ predicted_results.sort(key=lambda x: (x[0] is None, x[0]))
+
+ self.assertEqual(predicted_results, expected_predict_data)
+
def test_fit_and_predict(self):
string_indexer = StringIndexer() \
.set_input_cols('input_col1', 'input_col2') \