This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new e5da0da  [FLINK-28563] Add Transformer for VectorSlicer
e5da0da is described below

commit e5da0da4ae0bf51dc467bc3635b80c9036953315
Author: weibo <[email protected]>
AuthorDate: Fri Jul 29 17:58:54 2022 +0800

    [FLINK-28563] Add Transformer for VectorSlicer
    
    This closes #131.
---
 .../java/org/apache/flink/ml/param/WithParams.java |   6 +-
 .../ml/examples/feature/VectorSlicerExample.java   |  62 ++++++
 .../vectorassembler/VectorAssemblerParams.java     |   2 +-
 .../ml/feature/vectorslicer/VectorSlicer.java      | 148 ++++++++++++++
 .../feature/vectorslicer/VectorSlicerParams.java   |  70 +++++++
 .../apache/flink/ml/feature/VectorSlicerTest.java  | 225 +++++++++++++++++++++
 .../examples/ml/feature/vectorslicer_example.py    |  64 ++++++
 .../ml/lib/feature/tests/test_vectorslicer.py      |  75 +++++++
 .../pyflink/ml/lib/feature/vectorslicer.py         |  91 +++++++++
 9 files changed, 741 insertions(+), 2 deletions(-)

diff --git 
a/flink-ml-core/src/main/java/org/apache/flink/ml/param/WithParams.java 
b/flink-ml-core/src/main/java/org/apache/flink/ml/param/WithParams.java
index 0a7b5c4..9810105 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/param/WithParams.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/param/WithParams.java
@@ -39,6 +39,8 @@ package org.apache.flink.ml.param;
 import org.apache.flink.annotation.PublicEvolving;
 import org.apache.flink.ml.util.ParamUtils;
 
+import org.apache.commons.lang3.ArrayUtils;
+
 import java.util.Map;
 import java.util.Optional;
 
@@ -97,7 +99,9 @@ public interface WithParams<T> {
                         "Parameter "
                                 + param.name
                                 + " is given an invalid value "
-                                + value.toString());
+                                + (value.getClass().isArray()
+                                        ? ArrayUtils.toString(value)
+                                        : value));
             }
         }
         getParamMap().put(param, value);
diff --git 
a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorSlicerExample.java
 
b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorSlicerExample.java
new file mode 100644
index 0000000..7c13899
--- /dev/null
+++ 
b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorSlicerExample.java
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.vectorslicer.VectorSlicer;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that creates a VectorSlicer instance and uses it for 
feature engineering. */
+public class VectorSlicerExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data.
+        DataStream<Row> inputStream =
+                env.fromElements(
+                        Row.of(Vectors.dense(2.1, 3.1, 1.2, 3.1, 4.6)),
+                        Row.of(Vectors.dense(1.2, 3.1, 4.6, 2.1, 3.1)));
+        Table inputTable = tEnv.fromDataStream(inputStream).as("vec");
+
+        // Creates a VectorSlicer object and initializes its parameters.
+        VectorSlicer vectorSlicer =
+                new VectorSlicer().setInputCol("vec").setIndices(1, 2, 
3).setOutputCol("slicedVec");
+
+        // Uses the VectorSlicer object for feature transformations.
+        Table outputTable = vectorSlicer.transform(inputTable)[0];
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); 
it.hasNext(); ) {
+            Row row = it.next();
+
+            Vector inputValue = (Vector) 
row.getField(vectorSlicer.getInputCol());
+
+            Vector outputValue = (Vector) 
row.getField(vectorSlicer.getOutputCol());
+
+            System.out.printf("Input Value: %s \tOutput Value: %s\n", 
inputValue, outputValue);
+        }
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
index aaada93..cc3637e 100644
--- 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
@@ -23,7 +23,7 @@ import org.apache.flink.ml.common.param.HasInputCols;
 import org.apache.flink.ml.common.param.HasOutputCol;
 
 /**
- * Params of VectorAssembler.
+ * Params of {@link VectorAssembler}.
  *
  * @param <T> The class type of this instance.
  */
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java
new file mode 100644
index 0000000..2abca89
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicer.java
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.vectorslicer;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Transformer;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Transformer that transforms a vector to a new feature, which is a 
sub-array of the original
+ * feature. It is useful for extracting features from a given vector.
+ *
+ * <p>Note that duplicate features are not allowed, so there can be no overlap 
between selected
+ * indices. If the max value of the indices is greater than the size of the 
input vector, it throws
+ * an IllegalArgumentException.
+ */
+public class VectorSlicer implements Transformer<VectorSlicer>, 
VectorSlicerParams<VectorSlicer> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public VectorSlicer() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) 
inputs[0]).getTableEnvironment();
+        RowTypeInfo inputTypeInfo = 
TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), 
VectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), 
getOutputCol()));
+        DataStream<Row> output =
+                tEnv.toDataStream(inputs[0])
+                        .map(new VectorSliceFunction(getIndices(), 
getInputCol()), outputTypeInfo);
+        Table outputTable = tEnv.fromDataStream(output);
+        return new Table[] {outputTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static VectorSlicer load(StreamTableEnvironment env, String path) 
throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Vector slice function which transforms a vector to a new one with a 
sub-array of the original
+     * features.
+     */
+    private static class VectorSliceFunction implements MapFunction<Row, Row> {
+        private final Integer[] indices;
+        private final String inputCol;
+        private int maxIndex = -1;
+
+        public VectorSliceFunction(Integer[] indices, String inputCol) {
+            this.indices = indices;
+            for (Integer index : indices) {
+                maxIndex = Math.max(maxIndex, index);
+            }
+            this.inputCol = inputCol;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            Vector inputVec = row.getFieldAs(inputCol);
+            Vector outputVec;
+            if (maxIndex >= inputVec.size()) {
+                throw new IllegalArgumentException(
+                        "Index value "
+                                + maxIndex
+                                + " is greater than vector size:"
+                                + inputVec.size());
+            }
+            if (inputVec instanceof DenseVector) {
+                double[] values = new double[indices.length];
+                for (int i = 0; i < indices.length; ++i) {
+                    values[i] = ((DenseVector) inputVec).values[indices[i]];
+                }
+                outputVec = new DenseVector(values);
+            } else {
+                int nnz = 0;
+                SparseVector vec = (SparseVector) inputVec;
+                int[] outputIndices = new int[indices.length];
+                double[] outputValues = new double[indices.length];
+                for (int i = 0; i < indices.length; i++) {
+                    double val = vec.get(indices[i]);
+                    if (val != 0) {
+                        outputIndices[nnz] = i;
+                        outputValues[nnz] = val;
+                        nnz++;
+                    }
+                }
+                if (nnz < outputIndices.length) {
+                    outputIndices = Arrays.copyOf(outputIndices, nnz);
+                    outputValues = Arrays.copyOf(outputValues, nnz);
+                }
+                outputVec = new SparseVector(indices.length, outputIndices, 
outputValues);
+            }
+            return Row.join(row, Row.of(outputVec));
+        }
+    }
+}
diff --git 
a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicerParams.java
 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicerParams.java
new file mode 100644
index 0000000..2bf1631
--- /dev/null
+++ 
b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorslicer/VectorSlicerParams.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.vectorslicer;
+
+import org.apache.flink.ml.common.param.HasInputCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.IntArrayParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidator;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * Params of {@link VectorSlicer}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface VectorSlicerParams<T> extends HasInputCol<T>, HasOutputCol<T> 
{
+    Param<Integer[]> INDICES =
+            new IntArrayParam(
+                    "indices",
+                    "An array of indices to select features from a vector 
column.",
+                    null,
+                    indicesValidator());
+
+    default Integer[] getIndices() {
+        return get(INDICES);
+    }
+
+    default T setIndices(Integer... value) {
+        return set(INDICES, value);
+    }
+
+    // Checks the indices parameter.
+    static ParamValidator<Integer[]> indicesValidator() {
+        return indices -> {
+            if (indices == null) {
+                return false;
+            }
+            for (Number ele : indices) {
+                if (ele.doubleValue() < 0) {
+                    return false;
+                }
+            }
+            Set<Integer> set = new HashSet<>(Arrays.asList(indices));
+            if (set.size() != indices.length) {
+                return false;
+            }
+            return set.size() != 0;
+        };
+    }
+}
diff --git 
a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java
new file mode 100644
index 0000000..4a75987
--- /dev/null
+++ 
b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorSlicerTest.java
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.vectorslicer.VectorSlicer;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import 
org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Tests {@link VectorSlicer}. */
+public class VectorSlicerTest extends AbstractTestBase {
+
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(
+                    Row.of(
+                            0,
+                            Vectors.dense(2.1, 3.1, 2.3, 3.4, 5.3, 5.1),
+                            Vectors.sparse(5, new int[] {1, 3, 4}, new 
double[] {0.1, 0.2, 0.3})),
+                    Row.of(
+                            1,
+                            Vectors.dense(2.3, 4.1, 1.3, 2.4, 5.1, 4.1),
+                            Vectors.sparse(5, new int[] {1, 2, 4}, new 
double[] {0.1, 0.2, 0.3})));
+
+    private static final DenseVector EXPECTED_OUTPUT_DATA_1 = 
Vectors.dense(2.1, 3.1, 2.3);
+    private static final DenseVector EXPECTED_OUTPUT_DATA_2 = 
Vectors.dense(2.3, 4.1, 1.3);
+
+    private static final SparseVector EXPECTED_OUTPUT_DATA_3 =
+            Vectors.sparse(3, new int[] {1}, new double[] {0.1});
+    private static final SparseVector EXPECTED_OUTPUT_DATA_4 =
+            Vectors.sparse(3, new int[] {1, 2}, new double[] {0.1, 0.2});
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, 
true);
+        StreamExecutionEnvironment env = 
StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(INPUT_DATA);
+        inputDataTable = tEnv.fromDataStream(dataStream).as("id", "vec", 
"sparseVec");
+    }
+
+    private void verifyOutputResult(Table output, String outputCol, boolean 
isSparse)
+            throws Exception {
+        DataStream<Row> dataStream = tEnv.toDataStream(output);
+        List<Row> results = 
IteratorUtils.toList(dataStream.executeAndCollect());
+        assertEquals(2, results.size());
+        for (Row result : results) {
+            if (result.getField(0) == (Object) 0) {
+                if (isSparse) {
+                    assertEquals(EXPECTED_OUTPUT_DATA_3, 
result.getField(outputCol));
+                } else {
+                    assertEquals(EXPECTED_OUTPUT_DATA_1, 
result.getField(outputCol));
+                }
+            } else if (result.getField(0) == (Object) 1) {
+                if (isSparse) {
+                    assertEquals(EXPECTED_OUTPUT_DATA_4, 
result.getField(outputCol));
+                } else {
+                    assertEquals(EXPECTED_OUTPUT_DATA_2, 
result.getField(outputCol));
+                }
+            } else {
+                throw new RuntimeException("Result id value is error, it must 
be 0 or 1.");
+            }
+        }
+    }
+
+    @Test
+    public void testParam() {
+        VectorSlicer vectorSlicer = new VectorSlicer();
+        assertEquals("input", vectorSlicer.getInputCol());
+        assertEquals("output", vectorSlicer.getOutputCol());
+        vectorSlicer.setInputCol("vec").setOutputCol("sliceVec").setIndices(0, 
1, 2);
+        assertEquals("vec", vectorSlicer.getInputCol());
+        assertEquals("sliceVec", vectorSlicer.getOutputCol());
+        assertArrayEquals(new Integer[] {0, 1, 2}, vectorSlicer.getIndices());
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {
+        VectorSlicer vectorSlicer =
+                new 
VectorSlicer().setInputCol("vec").setOutputCol("sliceVec").setIndices(0, 1, 2);
+        VectorSlicer loadedVectorSlicer =
+                TestUtils.saveAndReload(
+                        tEnv, vectorSlicer, 
TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+        Table output = loadedVectorSlicer.transform(inputDataTable)[0];
+        verifyOutputResult(output, loadedVectorSlicer.getOutputCol(), false);
+    }
+
+    @Test
+    public void testEmptyIndices() {
+        try {
+            VectorSlicer vectorSlicer =
+                    new 
VectorSlicer().setInputCol("vec").setOutputCol("sliceVec").setIndices();
+            vectorSlicer.transform(inputDataTable);
+            fail();
+        } catch (Exception e) {
+            assertEquals("Parameter indices is given an invalid value {}", 
e.getMessage());
+        }
+    }
+
+    @Test
+    public void testIndicesLargerThanVectorSize() {
+        try {
+            VectorSlicer vectorSlicer =
+                    new VectorSlicer()
+                            .setInputCol("vec")
+                            .setOutputCol("sliceVec")
+                            .setIndices(1, 2, 10);
+            Table output = vectorSlicer.transform(inputDataTable)[0];
+            DataStream<Row> dataStream = tEnv.toDataStream(output);
+            IteratorUtils.toList(dataStream.executeAndCollect());
+            fail();
+        } catch (Exception e) {
+            assertEquals(
+                    "Index value 10 is greater than vector size:6",
+                    ExceptionUtils.getRootCause(e).getMessage());
+        }
+    }
+
+    @Test
+    public void testIndicesSmallerThanZero() {
+        try {
+            new 
VectorSlicer().setInputCol("vec").setOutputCol("sliceVec").setIndices(1, -2);
+            fail();
+        } catch (Exception e) {
+            assertEquals("Parameter indices is given an invalid value {1,-2}", 
e.getMessage());
+        }
+    }
+
+    @Test
+    public void testDuplicateIndices() {
+        try {
+            new 
VectorSlicer().setInputCol("vec").setOutputCol("sliceVec").setIndices(1, 1, 3);
+            fail();
+        } catch (Exception e) {
+            assertEquals("Parameter indices is given an invalid value 
{1,1,3}", e.getMessage());
+        }
+    }
+
+    @Test
+    public void testDenseTransform() throws Exception {
+        VectorSlicer vectorSlicer =
+                new 
VectorSlicer().setInputCol("vec").setOutputCol("sliceVec").setIndices(0, 1, 2);
+
+        Table output = vectorSlicer.transform(inputDataTable)[0];
+        verifyOutputResult(output, vectorSlicer.getOutputCol(), false);
+    }
+
+    @Test
+    public void testDenseTransformWithUnorderedIndices() throws Exception {
+        VectorSlicer vectorSlicer =
+                new 
VectorSlicer().setInputCol("vec").setOutputCol("sliceVec").setIndices(0, 2, 1);
+
+        Table output = vectorSlicer.transform(inputDataTable)[0];
+        DataStream<Row> dataStream = tEnv.toDataStream(output);
+        List<Row> results = 
IteratorUtils.toList(dataStream.executeAndCollect());
+        assertEquals(2, results.size());
+        for (Row result : results) {
+            if (result.getField(0) == (Object) 0) {
+                assertEquals(
+                        Vectors.dense(2.1, 2.3, 3.1), 
result.getField(vectorSlicer.getOutputCol()));
+
+            } else if (result.getField(0) == (Object) 1) {
+                assertEquals(
+                        Vectors.dense(2.3, 1.3, 4.1), 
result.getField(vectorSlicer.getOutputCol()));
+            } else {
+                throw new RuntimeException("Result id value is error, it must 
be 0 or 1.");
+            }
+        }
+    }
+
+    @Test
+    public void testSparseTransform() throws Exception {
+        VectorSlicer vectorSlicer =
+                new VectorSlicer()
+                        .setInputCol("sparseVec")
+                        .setOutputCol("sliceVec")
+                        .setIndices(0, 1, 2);
+        Table output = vectorSlicer.transform(inputDataTable)[0];
+        verifyOutputResult(output, vectorSlicer.getOutputCol(), true);
+    }
+}
diff --git 
a/flink-ml-python/pyflink/examples/ml/feature/vectorslicer_example.py 
b/flink-ml-python/pyflink/examples/ml/feature/vectorslicer_example.py
new file mode 100644
index 0000000..af41715
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/vectorslicer_example.py
@@ -0,0 +1,64 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates a VectorSlicer instance and uses it for feature
+# engineering.
+#
+# Before executing this program, please make sure you have followed Flink ML's
+# quick start guideline to set up Flink ML and Flink environment. The guideline
+# can be found at
+#
+# 
https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.ml.lib.feature.vectorslicer import VectorSlicer
+from pyflink.table import StreamTableEnvironment
+
# set up the execution environment and the table environment on top of it
env = StreamExecutionEnvironment.get_execution_environment()
t_env = StreamTableEnvironment.create(env)

# build a two-row input table with an int id and a dense vector column
input_rows = [
    (1, Vectors.dense(2.1, 3.1, 1.2, 2.1)),
    (2, Vectors.dense(2.3, 2.1, 1.3, 1.2)),
]
input_type = Types.ROW_NAMED(['id', 'vec'], [Types.INT(), DenseVectorTypeInfo()])
input_data_table = t_env.from_data_stream(
    env.from_collection(input_rows, type_info=input_type))

# configure a VectorSlicer that keeps entries 1, 2 and 3 of 'vec'
vector_slicer = (
    VectorSlicer()
    .set_input_col('vec')
    .set_indices(1, 2, 3)
    .set_output_col('sub_vec'))

# apply the transformer; it produces a single output table
output = vector_slicer.transform(input_data_table)[0]

# print every input vector next to its slice
field_names = output.get_schema().get_field_names()
input_pos = field_names.index(vector_slicer.get_input_col())
output_pos = field_names.index(vector_slicer.get_output_col())
for result in t_env.to_data_stream(output).execute_and_collect():
    print('Input Value: {}\tOutput Value: {}'.format(result[input_pos], result[output_pos]))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorslicer.py 
b/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorslicer.py
new file mode 100644
index 0000000..bdb06ad
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorslicer.py
@@ -0,0 +1,75 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.ml.lib.feature.vectorslicer import VectorSlicer
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
class VectorSlicerTest(PyFlinkMLTestCase):
    """Tests the Python wrapper of VectorSlicer."""

    def setUp(self):
        super(VectorSlicerTest, self).setUp()
        rows = [
            (1, Vectors.dense(2.1, 3.1, 1.2, 2.1)),
            (2, Vectors.dense(2.3, 2.1, 1.3, 1.2)),
        ]
        row_type = Types.ROW_NAMED(['id', 'vec'], [Types.INT(), DenseVectorTypeInfo()])
        self.input_data_table = self.t_env.from_data_stream(
            self.env.from_collection(rows, type_info=row_type))

        # slices of the rows above at indices (0, 1, 2)
        self.expected_output_data_1 = Vectors.dense(2.1, 3.1, 1.2)
        self.expected_output_data_2 = Vectors.dense(2.3, 2.1, 1.3)

    def test_param(self):
        vector_slicer = VectorSlicer()

        # default column names
        self.assertEqual('input', vector_slicer.get_input_col())
        self.assertEqual('output', vector_slicer.get_output_col())

        vector_slicer.set_input_col('vec')
        vector_slicer.set_output_col('slice_vec')
        vector_slicer.set_indices(0, 1, 2)

        self.assertEqual('vec', vector_slicer.get_input_col())
        self.assertEqual((0, 1, 2), vector_slicer.get_indices())
        self.assertEqual('slice_vec', vector_slicer.get_output_col())

    def test_save_load_transform(self):
        original = VectorSlicer() \
            .set_input_col('vec') \
            .set_output_col('slice_vec') \
            .set_indices(0, 1, 2)

        path = os.path.join(self.temp_dir, 'test_save_load_transform_vector_slicer')
        original.save(path)
        loaded = VectorSlicer.load(self.t_env, path)

        output_table = loaded.transform(self.input_data_table)[0]
        collected = self.t_env.to_data_stream(output_table).execute_and_collect()
        # keep (id, sliced vector); the sliced vector is the appended third column
        actual_outputs = [(row[0], row[2]) for row in collected]

        self.assertEqual(2, len(actual_outputs))
        for row_id, sliced in actual_outputs:
            expected = self.expected_output_data_1 if row_id == 1 \
                else self.expected_output_data_2
            self.assertEqual(expected, sliced)
diff --git a/flink-ml-python/pyflink/ml/lib/feature/vectorslicer.py 
b/flink-ml-python/pyflink/ml/lib/feature/vectorslicer.py
new file mode 100644
index 0000000..cb72518
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/vectorslicer.py
@@ -0,0 +1,91 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from typing import Tuple
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.core.param import IntArrayParam, ParamValidator
+from pyflink.ml.lib.feature.common import JavaFeatureTransformer
+from pyflink.ml.lib.param import HasInputCol, HasOutputCol, Param
+
+
+class _VectorSlicerParams(
+    JavaWithParams,
+    HasInputCol,
+    HasOutputCol
+):
+    """
+      Checks the indices parameter.
+    """
+
+    def indices_validator(self) -> ParamValidator[Tuple[int]]:
+        class IndicesValidator(ParamValidator[Tuple[int]]):
+            def validate(self, indices: Tuple[int]) -> bool:
+                for val in indices:
+                    if val < 0:
+                        return False
+                return True
+                indices_set = set(indices)
+                if len(indices_set) != len(indices):
+                    return False
+                return len(indices_set) != 0
+        return IndicesValidator()
+
+    """
+    Params for :class:`VectorSlicer`.
+    """
+
+    INDICES: Param[Tuple[int, ...]] = IntArrayParam(
+        "indices",
+        "An array of indices to select features from a vector column.",
+        None,
+        indices_validator(None))
+
+    def __init__(self, java_params):
+        super(_VectorSlicerParams, self).__init__(java_params)
+
+    def set_indices(self, *ind: int):
+        return self.set(self.INDICES, ind)
+
+    def get_indices(self) -> Tuple[int, ...]:
+        return self.get(self.INDICES)
+
+    @property
+    def indices(self) -> Tuple[int, ...]:
+        return self.get_indices()
+
+
+class VectorSlicer(JavaFeatureTransformer, _VectorSlicerParams):
+    """
+    A Transformer that transforms a vector to a new feature, which is a 
sub-array of
+    the original feature.It is useful for extracting features from a given 
vector.
+
+    Note that duplicate features are not allowed, so there can be no overlap 
between
+    selected indices. If the max value of the indices is greater than the size 
of
+    the input vector, it throws an IllegalArgumentException.
+    """
+
+    def __init__(self, java_model=None):
+        super(VectorSlicer, self).__init__(java_model)
+
+    @classmethod
+    def _java_transformer_package_name(cls) -> str:
+        return "vectorslicer"
+
+    @classmethod
+    def _java_transformer_class_name(cls) -> str:
+        return "VectorSlicer"

Reply via email to