This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 560532d9 [FLINK-31189] Add HasMaxIndexNum param to StringIndexer
560532d9 is described below

commit 560532d99c4330949d4da2c94d0e5228bdfbc9cd
Author: Zhipeng Zhang <[email protected]>
AuthorDate: Mon Apr 3 14:30:12 2023 +0800

    [FLINK-31189] Add HasMaxIndexNum param to StringIndexer
    
    This closes #221.
---
 .../docs/operators/feature/stringindexer.md        |  9 +--
 .../ml/feature/stringindexer/StringIndexer.java    | 43 +++++++++++++--
 .../feature/stringindexer/StringIndexerModel.java  |  4 +-
 .../feature/stringindexer/StringIndexerParams.java | 17 ++++++
 .../feature/stringindexer/StringIndexerTest.java   | 64 ++++++++++++++++++++++
 .../pyflink/ml/feature/stringindexer.py            | 31 +++++++++--
 .../pyflink/ml/feature/tests/test_stringindexer.py | 28 ++++++++++
 7 files changed, 180 insertions(+), 16 deletions(-)

diff --git a/docs/content/docs/operators/feature/stringindexer.md 
b/docs/content/docs/operators/feature/stringindexer.md
index 10e02dc1..110f8001 100644
--- a/docs/content/docs/operators/feature/stringindexer.md
+++ b/docs/content/docs/operators/feature/stringindexer.md
@@ -38,7 +38,7 @@ StringIndexerModel.
 ### Input Columns
 
 | Param name | Type          | Default | Description                           
 |
-| :--------- | :------------ | :------ 
|:---------------------------------------|
+|:-----------|:--------------|:--------|:---------------------------------------|
 | inputCols  | Number/String | `null`  | String/Numerical values to be 
indexed. |
 
 ### Output Columns
@@ -59,9 +59,10 @@ Below are the parameters required by `StringIndexerModel`.
 
 `StringIndexer` needs parameters above and also below.
 
-| Key             | Default       | Type   | Required | Description            
                                                                                
                             |
-|-----------------|---------------|--------|----------|-------------------------------------------------------------------------------------------------------------------------------------|
-| stringOrderType | `"arbitrary"` | String | no       | How to order strings 
of each column. Supported values: 'arbitrary', 'frequencyDesc', 'frequencyAsc', 
'alphabetDesc', 'alphabetAsc'. |
+| Key             | Default       | Type    | Required | Description           
                                                                                
                              |
+|-----------------|---------------|---------|----------|-------------------------------------------------------------------------------------------------------------------------------------|
+| stringOrderType | `"arbitrary"` | String  | no       | How to order strings 
of each column. Supported values: 'arbitrary', 'frequencyDesc', 'frequencyAsc', 
'alphabetDesc', 'alphabetAsc'. |
+| maxIndexNum     | `2147483647`  | Integer | no       | The max number of 
indices for each column. It only works when 'stringOrderType' is set as 
'frequencyDesc'.                          |
 
 ### Examples
 
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
index 51dd727c..3c214d4d 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexer.java
@@ -64,8 +64,12 @@ import java.util.Map.Entry;
  * is arbitrarily ordered. Users can control this by setting {@link
  * StringIndexerParams#STRING_ORDER_TYPE}.
  *
- * <p>The `keep` option of {@link HasHandleInvalid} means that we put the 
invalid entries in a
- * special bucket, whose index is the number of distinct values in this column.
+ * <p>Users can also control the max number of output indices by setting {@link
+ * StringIndexerParams#MAX_INDEX_NUM}. This parameter only works if {@link
+ * StringIndexerParams#STRING_ORDER_TYPE} is set as 'frequencyDesc'.
+ *
+ * <p>The `keep` option of {@link HasHandleInvalid} means that we transform 
the invalid input into a
+ * special index, whose value is the number of distinct values in this column.
  */
 public class StringIndexer
         implements Estimator<StringIndexer, StringIndexerModel>,
@@ -96,6 +100,17 @@ public class StringIndexer
         String[] inputCols = getInputCols();
         String[] outputCols = getOutputCols();
         Preconditions.checkArgument(inputCols.length == outputCols.length);
+        if (getMaxIndexNum() < Integer.MAX_VALUE) {
+            Preconditions.checkArgument(
+                    
getStringOrderType().equals(StringIndexerParams.FREQUENCY_DESC_ORDER),
+                    "Setting "
+                            + MAX_INDEX_NUM.name
+                            + " smaller than INT.MAX only works when "
+                            + STRING_ORDER_TYPE.name
+                            + " is set as "
+                            + StringIndexerParams.FREQUENCY_DESC_ORDER
+                            + ".");
+        }
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
 
@@ -127,7 +142,7 @@ public class StringIndexer
                         Types.OBJECT_ARRAY(Types.MAP(Types.STRING, 
Types.LONG)));
 
         DataStream<StringIndexerModelData> modelData =
-                countedString.map(new ModelGenerator(getStringOrderType()));
+                countedString.map(new ModelGenerator(getStringOrderType(), 
getMaxIndexNum()));
         modelData.getTransformation().setParallelism(1);
 
         StringIndexerModel model =
@@ -205,14 +220,19 @@ public class StringIndexer
 
     /**
      * Merges all the extracted strings and generates the {@link 
StringIndexerModelData} according
-     * to the specified string order type.
+     * to the specified string order type and maxIndexNum.
+     *
+     * <p>Note that the maxIndexNum works only when the strings are ordered by 
{@link
+     * StringIndexerParams#FREQUENCY_DESC_ORDER}.
      */
     private static class ModelGenerator
             implements MapFunction<Map<String, Long>[], 
StringIndexerModelData> {
         private final String stringOrderType;
+        private final int maxIndexNum;
 
-        public ModelGenerator(String stringOrderType) {
+        public ModelGenerator(String stringOrderType, int maxIndexNum) {
             this.stringOrderType = stringOrderType;
+            this.maxIndexNum = maxIndexNum;
         }
 
         @Override
@@ -220,6 +240,7 @@ public class StringIndexer
             int numCols = value.length;
             String[][] stringArrays = new String[numCols][];
             ArrayList<Tuple2<String, Long>> stringsAndCnts = new ArrayList<>();
+
             for (int i = 0; i < numCols; i++) {
                 stringsAndCnts.clear();
                 stringsAndCnts.ensureCapacity(value[i].size());
@@ -242,6 +263,18 @@ public class StringIndexer
                         stringsAndCnts.sort(
                                 (valAndCnt1, valAndCnt2) ->
                                         
-valAndCnt1.f1.compareTo(valAndCnt2.f1));
+
+                        if (stringsAndCnts.size() > maxIndexNum) {
+                            ArrayList<Tuple2<String, Long>> 
frequentStringsAndCnts =
+                                    new ArrayList<>();
+                            // Reserves the last index for infrequent element.
+                            frequentStringsAndCnts.ensureCapacity(maxIndexNum 
- 1);
+                            for (int indexId = 0; indexId < maxIndexNum - 1; 
indexId++) {
+                                
frequentStringsAndCnts.add(stringsAndCnts.get(indexId));
+                            }
+                            stringsAndCnts = frequentStringsAndCnts;
+                        }
+
                         break;
                     case ARBITRARY_ORDER:
                         break;
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
index e5ad4700..490c1c14 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerModel.java
@@ -49,8 +49,8 @@ import java.util.Map;
  * A Model which transforms input string/numeric column(s) to double column(s) 
using the model data
  * computed by {@link StringIndexer}.
  *
- * <p>The `keep` option of {@link HasHandleInvalid} means that we put the 
invalid entries in a
- * special bucket, whose index is the number of distinct values in this column.
+ * <p>The `keep` option of {@link HasHandleInvalid} means that we transform 
the invalid input into a
+ * special index, whose value is the number of distinct values in this column.
  */
 public class StringIndexerModel
         implements Model<StringIndexerModel>, 
StringIndexerModelParams<StringIndexerModel> {
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerParams.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerParams.java
index 61c23cfd..e6b8a2be 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerParams.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/stringindexer/StringIndexerParams.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.ml.feature.stringindexer;
 
+import org.apache.flink.ml.param.IntParam;
 import org.apache.flink.ml.param.Param;
 import org.apache.flink.ml.param.ParamValidators;
 import org.apache.flink.ml.param.StringParam;
@@ -58,6 +59,14 @@ public interface StringIndexerParams<T> extends 
StringIndexerModelParams<T> {
                             ALPHABET_DESC_ORDER,
                             ALPHABET_ASC_ORDER));
 
+    Param<Integer> MAX_INDEX_NUM =
+            new IntParam(
+                    "maxIndexNum",
+                    "The max number of indices for each column. It only works 
when "
+                            + "'stringOrderType' is set as 'frequencyDesc'.",
+                    Integer.MAX_VALUE,
+                    ParamValidators.gt(1));
+
     default String getStringOrderType() {
         return get(STRING_ORDER_TYPE);
     }
@@ -65,4 +74,12 @@ public interface StringIndexerParams<T> extends 
StringIndexerModelParams<T> {
     default T setStringOrderType(String value) {
         return set(STRING_ORDER_TYPE, value);
     }
+
+    default int getMaxIndexNum() {
+        return get(MAX_INDEX_NUM);
+    }
+
+    default T setMaxIndexNum(int value) {
+        return set(MAX_INDEX_NUM, value);
+    }
 }
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
index 56f353e5..9f4a5ae8 100644
--- 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/stringindexer/StringIndexerTest.java
@@ -42,6 +42,8 @@ import java.util.Set;
 import java.util.stream.Collectors;
 import java.util.stream.Stream;
 
+import static 
org.apache.flink.ml.feature.stringindexer.StringIndexerParams.MAX_INDEX_NUM;
+import static 
org.apache.flink.ml.feature.stringindexer.StringIndexerParams.STRING_ORDER_TYPE;
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertTrue;
@@ -125,16 +127,19 @@ public class StringIndexerTest extends AbstractTestBase {
         StringIndexer stringIndexer = new StringIndexer();
         assertEquals(stringIndexer.getStringOrderType(), 
StringIndexerParams.ARBITRARY_ORDER);
         assertEquals(stringIndexer.getHandleInvalid(), 
StringIndexerParams.ERROR_INVALID);
+        assertEquals(stringIndexer.getMaxIndexNum(), Integer.MAX_VALUE);
 
         stringIndexer
                 .setInputCols("inputCol1", "inputCol2")
                 .setOutputCols("outputCol1", "outputCol2")
                 .setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER)
+                .setMaxIndexNum(100)
                 .setHandleInvalid(StringIndexerParams.SKIP_INVALID);
 
         assertArrayEquals(new String[] {"inputCol1", "inputCol2"}, 
stringIndexer.getInputCols());
         assertArrayEquals(new String[] {"outputCol1", "outputCol2"}, 
stringIndexer.getOutputCols());
         assertEquals(stringIndexer.getStringOrderType(), 
StringIndexerParams.ALPHABET_ASC_ORDER);
+        assertEquals(stringIndexer.getMaxIndexNum(), 100);
         assertEquals(stringIndexer.getHandleInvalid(), 
StringIndexerParams.SKIP_INVALID);
     }
 
@@ -209,6 +214,65 @@ public class StringIndexerTest extends AbstractTestBase {
         assertEquals(3, distinctStringsCol2.size());
     }
 
+    @Test
+    public void testMaxIndexNum() throws Exception {
+        StringIndexer stringIndexer =
+                new StringIndexer()
+                        .setMaxIndexNum(3)
+                        .setInputCols("inputCol1", "inputCol2")
+                        .setOutputCols("outputCol1", "outputCol2")
+                        .setHandleInvalid(StringIndexerParams.KEEP_INVALID);
+        Table output;
+        List<Row> predictedResult;
+        final String expectedErrorMessage =
+                "Setting "
+                        + MAX_INDEX_NUM.name
+                        + " smaller than INT.MAX only works when "
+                        + STRING_ORDER_TYPE.name
+                        + " is set as "
+                        + StringIndexerParams.FREQUENCY_DESC_ORDER
+                        + ".";
+
+        // AlphabetAsc order.
+        
stringIndexer.setStringOrderType(StringIndexerParams.ALPHABET_ASC_ORDER);
+        checkMaxIndexNumSettings(stringIndexer, expectedErrorMessage);
+
+        // AlphabetDesc order.
+        
stringIndexer.setStringOrderType(StringIndexerParams.ALPHABET_DESC_ORDER);
+        checkMaxIndexNumSettings(stringIndexer, expectedErrorMessage);
+
+        // FrequencyAsc order.
+        
stringIndexer.setStringOrderType(StringIndexerParams.FREQUENCY_ASC_ORDER);
+        checkMaxIndexNumSettings(stringIndexer, expectedErrorMessage);
+
+        // FrequencyDesc order.
+        final List<Row> expectedPredictData =
+                Arrays.asList(
+                        Row.of("a", 2.0, 1.0, 0.0),
+                        Row.of("b", 1.0, 0.0, 2.0),
+                        Row.of("e", 2.0, 2.0, 0.0),
+                        Row.of("f", null, 2.0, 2.0),
+                        Row.of(null, null, 2.0, 2.0));
+        
stringIndexer.setStringOrderType(StringIndexerParams.FREQUENCY_DESC_ORDER);
+        output = stringIndexer.fit(trainTable).transform(predictTable)[0];
+        predictedResult = 
IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        verifyPredictionResult(expectedPredictData, predictedResult);
+
+        // Arbitrary order.
+        stringIndexer.setStringOrderType(StringIndexerParams.ARBITRARY_ORDER);
+        checkMaxIndexNumSettings(stringIndexer, expectedErrorMessage);
+    }
+
+    private void checkMaxIndexNumSettings(
+            StringIndexer stringIndexer, String expectedErrorMessage) {
+        try {
+            stringIndexer.fit(trainTable);
+            fail();
+        } catch (Exception e) {
+            assertEquals(expectedErrorMessage, e.getMessage());
+        }
+    }
+
     @Test
     @SuppressWarnings("unchecked")
     public void testHandleInvalid() throws Exception {
diff --git a/flink-ml-python/pyflink/ml/feature/stringindexer.py 
b/flink-ml-python/pyflink/ml/feature/stringindexer.py
index c82c7cd2..999f585d 100644
--- a/flink-ml-python/pyflink/ml/feature/stringindexer.py
+++ b/flink-ml-python/pyflink/ml/feature/stringindexer.py
@@ -17,7 +17,7 @@
 
################################################################################
 import typing
 
-from pyflink.ml.param import Param, StringParam, ParamValidators
+from pyflink.ml.param import Param, IntParam, StringParam, ParamValidators
 from pyflink.ml.wrapper import JavaWithParams
 from pyflink.ml.feature.common import JavaFeatureModel, JavaFeatureEstimator
 from pyflink.ml.common.param import HasInputCols, HasOutputCols, 
HasHandleInvalid
@@ -62,6 +62,13 @@ class _StringIndexerParams(_StringIndexerModelParams):
         ParamValidators.in_array(
             ['arbitrary', 'frequencyDesc', 'frequencyAsc', 'alphabetDesc', 
'alphabetAsc']))
 
+    MAX_INDEX_NUM: Param[int] = IntParam(
+        "max_index_num",
+        "The max number of indices for each column. It only works when "
+        + "'stringOrderType' is set as 'frequencyDesc'.",
+        2 ** 31 - 1,
+        ParamValidators.gt(1))
+
     def __init__(self, java_params):
         super(_StringIndexerParams, self).__init__(java_params)
 
@@ -71,10 +78,20 @@ class _StringIndexerParams(_StringIndexerModelParams):
     def get_string_order_type(self) -> str:
         return self.get(self.STRING_ORDER_TYPE)
 
+    def set_max_index_num(self, value: int):
+        return typing.cast(_StringIndexerParams, self.set(self.MAX_INDEX_NUM, 
value))
+
+    def get_max_index_num(self) -> int:
+        return self.get(self.MAX_INDEX_NUM)
+
     @property
     def string_order_type(self):
         return self.get_string_order_type()
 
+    @property
+    def max_index_num(self):
+        return self.get_max_index_num()
+
 
 class IndexToStringModel(JavaFeatureModel, _IndexToStringModelParams):
     """
@@ -99,8 +116,8 @@ class StringIndexerModel(JavaFeatureModel, 
_StringIndexerModelParams):
     A Model which transforms input string/numeric column(s) to integer 
column(s) using the model
     data computed by :class:StringIndexer.
 
-    The `keep` option of {@link HasHandleInvalid} means that we put the 
invalid entries in a
-    special bucket, whose index is the number of distinct values in this 
column.
+    The `keep` option of {@link HasHandleInvalid} means that we transform the 
invalid input
+    into a special index, whose value is the number of distinct values in this 
column.
     """
 
     def __init__(self, java_model=None):
@@ -128,8 +145,12 @@ class StringIndexer(JavaFeatureEstimator, 
_StringIndexerParams):
     is arbitrarily ordered. Users can control this by setting {@link
     StringIndexerParams#STRING_ORDER_TYPE}.
 
-    The `keep` option of {@link HasHandleInvalid} means that we put the 
invalid entries in a
-    special bucket, whose index is the number of distinct values in this 
column.
+    Users can also control the max number of output indices by setting {@link
+    StringIndexerParams#MAX_INDEX_NUM}. This parameter only works if {@link
+    StringIndexerParams#STRING_ORDER_TYPE} is set as 'frequencyDesc'.
+
+    The `keep` option of {@link HasHandleInvalid} means that we transform the 
invalid input
+    into a special index, whose value is the number of distinct values in this 
column.
     """
 
     def __init__(self):
diff --git a/flink-ml-python/pyflink/ml/feature/tests/test_stringindexer.py 
b/flink-ml-python/pyflink/ml/feature/tests/test_stringindexer.py
index ec2e5174..aa61503e 100644
--- a/flink-ml-python/pyflink/ml/feature/tests/test_stringindexer.py
+++ b/flink-ml-python/pyflink/ml/feature/tests/test_stringindexer.py
@@ -70,15 +70,18 @@ class StringIndexerTest(PyFlinkMLTestCase):
 
         self.assertEqual('arbitrary', string_indexer.string_order_type)
         self.assertEqual('error', string_indexer.handle_invalid)
+        self.assertEqual(2 ** 31 - 1, string_indexer.max_index_num)
 
         string_indexer.set_input_cols('input_col1', 'input_col2') \
             .set_output_cols('output_col1', 'output_col2') \
             .set_string_order_type('alphabetAsc') \
+            .set_max_index_num(100) \
             .set_handle_invalid('skip')
 
         self.assertEqual(('input_col1', 'input_col2'), 
string_indexer.input_cols)
         self.assertEqual(('output_col1', 'output_col2'), 
string_indexer.output_cols)
         self.assertEqual('alphabetAsc', string_indexer.string_order_type)
+        self.assertEqual(100, string_indexer.max_index_num)
         self.assertEqual('skip', string_indexer.handle_invalid)
 
     def test_output_schema(self):
@@ -110,6 +113,31 @@ class StringIndexerTest(PyFlinkMLTestCase):
 
         self.assertEqual(predicted_results, 
self.expected_alphabetic_asc_predict_data)
 
+    def test_max_index_num(self):
+        string_indexer = StringIndexer() \
+            .set_max_index_num(3) \
+            .set_input_cols('input_col1', 'input_col2') \
+            .set_output_cols('output_col1', 'output_col2') \
+            .set_handle_invalid('keep') \
+            .set_string_order_type("frequencyDesc")
+
+        expected_predict_data = [
+            Row('a', 2.0, 1, 0),
+            Row('b', 1.0, 0, 2),
+            Row('e', 2.0, 2, 0),
+            Row('f', None, 2, 2),
+            Row(None, None, 2, 2),
+        ]
+
+        output = 
string_indexer.fit(self.train_table).transform(self.predict_table)[0]
+
+        predicted_results = [result for result in
+                             
self.t_env.to_data_stream(output).execute_and_collect()]
+
+        predicted_results.sort(key=lambda x: (x[0] is None, x[0]))
+
+        self.assertEqual(predicted_results, expected_predict_data)
+
     def test_fit_and_predict(self):
         string_indexer = StringIndexer() \
             .set_input_cols('input_col1', 'input_col2') \

Reply via email to