yunfengzhou-hub commented on code in PR #160:
URL: https://github.com/apache/flink-ml/pull/160#discussion_r1007543761


##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitterParams.java:
##########
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.randomsplitter;
+
+import org.apache.flink.ml.param.DoubleArrayParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidator;
+import org.apache.flink.ml.param.WithParams;
+
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * Params of {@link RandomSplitter}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface RandomSplitterParams<T> extends WithParams<T> {
+    Param<Double[]> WEIGHTS =
+            new DoubleArrayParam(
+                    "weights", "The weights of data splitting.", null, weightsValidator());

Review Comment:
   Weight is different from fraction. A weight describes the proportion of elements that end up in each output split, rather than a probability threshold like the fraction in the previous implementation.
   
   For example, `setWeight(1,1)` is equivalent to `setThreshold(0.5)`, and `setWeight(1,2,2)` is equivalent to `setThreshold(0.2,0.6)`.
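   
   To make the mapping concrete, here is a minimal sketch (not part of this PR; the helper name and the normalization by the weight sum are my own illustration) of how weights relate to the cumulative thresholds used by the previous fraction-based implementation:
   
   ```java
   import java.util.Arrays;
   
   class WeightToThresholdExample {
       // Hypothetical helper, only to illustrate the weight-to-threshold relationship.
       static double[] weightsToThresholds(double[] weights) {
           double sum = Arrays.stream(weights).sum();
           double[] thresholds = new double[weights.length - 1];
           double cumulative = 0.0;
           for (int i = 0; i < thresholds.length; i++) {
               cumulative += weights[i] / sum; // fraction of elements assigned to split i
               thresholds[i] = cumulative;     // {1, 1} -> {0.5}; {1, 2, 2} -> {0.2, 0.6}
           }
           return thresholds;
       }
   }
   ```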



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitter.java:
##########
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.randomsplitter;
+
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+/** An AlgoOperator which splits a datastream into N datastreams according to the given weights. */
+public class RandomSplitter
+        implements AlgoOperator<RandomSplitter>, RandomSplitterParams<RandomSplitter> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public RandomSplitter() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        RowTypeInfo outputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+
+        final Double[] weights = getWeights();
+        OutputTag<Row>[] outputTags = new OutputTag[weights.length];
+        for (int i = 0; i < outputTags.length; ++i) {
+            outputTags[i] = new OutputTag<Row>("outputTag_" + i, outputTypeInfo) {};
+        }
+
+        SingleOutputStreamOperator<Row> results =
+                tEnv.toDataStream(inputs[0])
+                        .transform(
+                                "SplitterOperator",

Review Comment:
   "SplitterOperator" is used in no other place in the code. Let's change it to 
`SplitFunctionOperator`.
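   
   For instance, just a sketch of the rename, with the rest of the call kept as in this PR:
   
   ```java
   SingleOutputStreamOperator<Row> results =
           tEnv.toDataStream(inputs[0])
                   .transform(
                           "SplitFunctionOperator", // name now matches the operator class
                           outputTypeInfo,
                           new SplitFunctionOperator(outputTags, weights));
   ```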



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RandomSplitterTest.java:
##########
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.randomsplitter.RandomSplitter;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStreamSource;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link RandomSplitter}. */
+public class RandomSplitterTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(1);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+
+        tEnv = StreamTableEnvironment.create(env);
+    }
+
+    private Table getTable(int size) {
+        DataStreamSource<Long> dataStream = env.fromSequence(0L, size);
+        return tEnv.fromDataStream(dataStream);
+    }
+
+    @Test
+    public void testParam() {
+        RandomSplitter splitter = new RandomSplitter();
+        splitter.setWeights(0.3, 0.4);
+        assertArrayEquals(new Double[] {0.3, 0.4}, splitter.getWeights());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        Table tempTable =
+                tEnv.fromDataStream(env.fromElements(Row.of("", "")))
+                        .as("test_input", "dummy_input");
+
+        RandomSplitter splitter = new RandomSplitter().setWeights(0.5, 0.1);
+        Table[] output = splitter.transform(tempTable);
+        assertEquals(3, output.length);
+        for (Table table : output) {
+            assertEquals(
+                    Arrays.asList("test_input", "dummy_input"),
+                    table.getResolvedSchema().getColumnNames());
+        }
+    }
+
+    @Test
+    public void teWeights() throws Exception {

Review Comment:
   Typo: this should be `testWeights`.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitter.java:
##########
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.randomsplitter;
+
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+/** An AlgoOperator which splits a datastream into N datastreams according to the given weights. */
+public class RandomSplitter
+        implements AlgoOperator<RandomSplitter>, RandomSplitterParams<RandomSplitter> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public RandomSplitter() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        RowTypeInfo outputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+
+        final Double[] weights = getWeights();
+        OutputTag<Row>[] outputTags = new OutputTag[weights.length];
+        for (int i = 0; i < outputTags.length; ++i) {
+            outputTags[i] = new OutputTag<Row>("outputTag_" + i, outputTypeInfo) {};
+        }
+
+        SingleOutputStreamOperator<Row> results =
+                tEnv.toDataStream(inputs[0])
+                        .transform(
+                                "SplitterOperator",
+                                outputTypeInfo,
+                                new SplitFunctionOperator(outputTags, weights));
+
+        Table[] outputTables = new Table[weights.length + 1];
+        outputTables[0] = tEnv.fromDataStream(results);
+
+        for (int i = 0; i < outputTags.length; ++i) {
+            DataStream<Row> dataStream = results.getSideOutput(outputTags[i]);
+            outputTables[i + 1] = tEnv.fromDataStream(dataStream);
+        }
+        return outputTables;
+    }
+
+    private static class SplitFunctionOperator extends AbstractStreamOperator<Row>
+            implements OneInputStreamOperator<Row, Row> {
+        private final Random random = new Random(2022);
+        OutputTag<Row>[] outputTag;

Review Comment:
   How about moving the initialization of the output tags into the constructor of this class, and making this field private and final?
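   
   A rough sketch of what that could look like (the `getOutputTags()` accessor is an assumption of mine so that `transform()` can still call `getSideOutput`):
   
   ```java
   private static class SplitFunctionOperator extends AbstractStreamOperator<Row>
           implements OneInputStreamOperator<Row, Row> {
   
       private final OutputTag<Row>[] outputTags;
       private final Double[] weights;
   
       @SuppressWarnings("unchecked")
       SplitFunctionOperator(RowTypeInfo outputTypeInfo, Double[] weights) {
           this.weights = weights;
           this.outputTags = new OutputTag[weights.length];
           for (int i = 0; i < weights.length; ++i) {
               // Output tags are created here instead of in transform().
               outputTags[i] = new OutputTag<Row>("outputTag_" + i, outputTypeInfo) {};
           }
       }
   
       OutputTag<Row>[] getOutputTags() {
           return outputTags;
       }
   
       // processElement(...) stays the same as in this PR.
   }
   ```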



##########
flink-ml-python/pyflink/examples/ml/feature/randomsplitter_example.py:
##########
@@ -0,0 +1,60 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates a RandomSplitter instance and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.randomsplitter import RandomSplitter
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input table.
+input_table = t_env.from_data_stream(
+    env.from_collection([
+        (1, 10, 0),
+        (1, 10, 0),
+        (1, 10, 0),
+        (4, 10, 0),
+        (5, 10, 0),
+        (6, 10, 0),
+        (7, 10, 0),
+        (10, 10, 0),
+        (13, 10, 0)
+    ],
+        type_info=Types.ROW_NAMED(
+            ['f0', 'f1', "f2"],
+            [Types.INT(), Types.INT(), Types.INT()])))
+
+# Creates a RandomSplitter object and initializes its parameters.
+splitter = RandomSplitter().set_weights(0.4)
+
+# Uses the RandomSplitter to split the dataset.
+output = splitter.transform(input_table)
+
+# Extracts and displays the results.
+for result in t_env.to_data_stream(output[0]).execute_and_collect():
+    print('Data 1: ' + str(result))

Review Comment:
   Let's keep the print format the same across the Java and Python examples.



##########
docs/content/docs/operators/feature/randomsplitter.md:
##########
@@ -0,0 +1,143 @@
+---
+title: "RandomSplitter"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/randomSplitter.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## RandomSplitter
+
+An AlgoOperator which splits a datastream into N datastreams according to the given weights.
+
+### Parameters
+
+| Key     | Default | Type      | Required | Description                    |
+|:--------|:--------|:----------|:---------|:-------------------------------|
+| weights | `[0.5]` | Double[]  | no       | The weights of data splitting. |
+
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+
+```java
+import org.apache.flink.ml.feature.randomsplitter.RandomSplitter;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that creates a RandomSplitter instance and uses it for data splitting. */
+public class RandomSplitterExample {
+       public static void main(String[] args) {
+               StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+               StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+               // Generates input data.
+               DataStream<Row> inputStream =
+                       env.fromElements(
+                               Row.of(1, 10, 0),
+                               Row.of(1, 10, 0),
+                               Row.of(1, 10, 0),
+                               Row.of(4, 10, 0),
+                               Row.of(5, 10, 0),
+                               Row.of(6, 10, 0),
+                               Row.of(7, 10, 0),
+                               Row.of(10, 10, 0),
+                               Row.of(13, 10, 3));
+               Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+               // Creates a RandomSplitter object and initializes its parameters.
+               RandomSplitter randomSplitter = new RandomSplitter().setWeights(0.4);
+
+               // Uses the RandomSplitter to split inputData.
+               Table[] outputTable = randomSplitter.transform(inputTable);
+
+               // Extracts and displays the results.
+               for (CloseableIterator<Row> it = outputTable[0].execute().collect(); it.hasNext(); ) {
+                       System.out.printf("Data 1 : %s\n", it.next());

Review Comment:
   Let's update the documents as well.



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitter.java:
##########
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.randomsplitter;
+
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+/** An AlgoOperator which splits a datastream into N datastreams according to the given weights. */
+public class RandomSplitter
+        implements AlgoOperator<RandomSplitter>, RandomSplitterParams<RandomSplitter> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public RandomSplitter() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        RowTypeInfo outputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+
+        final Double[] weights = getWeights();
+        OutputTag<Row>[] outputTags = new OutputTag[weights.length];
+        for (int i = 0; i < outputTags.length; ++i) {
+            outputTags[i] = new OutputTag<Row>("outputTag_" + i, outputTypeInfo) {};
+        }
+
+        SingleOutputStreamOperator<Row> results =
+                tEnv.toDataStream(inputs[0])
+                        .transform(
+                                "SplitterOperator",
+                                outputTypeInfo,
+                                new SplitFunctionOperator(outputTags, weights));
+
+        Table[] outputTables = new Table[weights.length + 1];
+        outputTables[0] = tEnv.fromDataStream(results);
+
+        for (int i = 0; i < outputTags.length; ++i) {
+            DataStream<Row> dataStream = results.getSideOutput(outputTags[i]);
+            outputTables[i + 1] = tEnv.fromDataStream(dataStream);
+        }
+        return outputTables;
+    }
+
+    private static class SplitFunctionOperator extends AbstractStreamOperator<Row>
+            implements OneInputStreamOperator<Row, Row> {
+        private final Random random = new Random(2022);

Review Comment:
   I would prefer the random seed to be something like `0` or `RandomSplitter.class.hashCode()`, rather than the year in which this PR was created.
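   
   That is, something along these lines:
   
   ```java
   private final Random random = new Random(RandomSplitter.class.hashCode());
   ```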



##########
flink-ml-lib/src/main/java/org/apache/flink/ml/feature/randomsplitter/RandomSplitter.java:
##########
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.randomsplitter;
+
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.AlgoOperator;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.OutputTag;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+/** An AlgoOperator which splits a datastream into N datastreams according to the given weights. */
+public class RandomSplitter
+        implements AlgoOperator<RandomSplitter>, RandomSplitterParams<RandomSplitter> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public RandomSplitter() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        RowTypeInfo outputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+
+        final Double[] weights = getWeights();
+        OutputTag<Row>[] outputTags = new OutputTag[weights.length];
+        for (int i = 0; i < outputTags.length; ++i) {
+            outputTags[i] = new OutputTag<Row>("outputTag_" + i, outputTypeInfo) {};
+        }
+
+        SingleOutputStreamOperator<Row> results =
+                tEnv.toDataStream(inputs[0])
+                        .transform(
+                                "SplitterOperator",
+                                outputTypeInfo,
+                                new SplitFunctionOperator(outputTags, weights));
+
+        Table[] outputTables = new Table[weights.length + 1];
+        outputTables[0] = tEnv.fromDataStream(results);
+
+        for (int i = 0; i < outputTags.length; ++i) {
+            DataStream<Row> dataStream = results.getSideOutput(outputTags[i]);
+            outputTables[i + 1] = tEnv.fromDataStream(dataStream);
+        }
+        return outputTables;
+    }
+
+    private static class SplitFunctionOperator extends AbstractStreamOperator<Row>
+            implements OneInputStreamOperator<Row, Row> {
+        private final Random random = new Random(2022);
+        OutputTag<Row>[] outputTag;
+        final Double[] weights;
+
+        public SplitFunctionOperator(OutputTag<Row>[] outputTag, Double[] weights) {
+            this.outputTag = outputTag;
+            this.weights = weights;
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> streamRecord) throws Exception {
+            Arrays.sort(weights);

Review Comment:
   Let's avoid sorting the weights for each element to be processed; the split boundaries only need to be computed once.
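   
   For example, a rough sketch (the `fractions` field name is mine, and the routing to the main/side outputs is simplified): the cumulative boundaries could be computed once in the constructor, so that `processElement` only draws a random number and picks a bucket.
   
   ```java
   private final double[] fractions;
   
   public SplitFunctionOperator(OutputTag<Row>[] outputTags, Double[] weights) {
       this.outputTags = outputTags;
       this.weights = weights;
       double sum = 0.0;
       for (double weight : weights) {
           sum += weight;
       }
       // Cumulative fractions computed once, instead of sorting per element.
       this.fractions = new double[weights.length];
       double cumulative = 0.0;
       for (int i = 0; i < weights.length; ++i) {
           cumulative += weights[i] / sum;
           fractions[i] = cumulative;
       }
   }
   
   @Override
   public void processElement(StreamRecord<Row> streamRecord) {
       double draw = random.nextDouble();
       for (int i = 0; i < fractions.length; ++i) {
           if (draw < fractions[i]) {
               output.collect(outputTags[i], streamRecord); // routing simplified for illustration
               return;
           }
       }
   }
   ```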



##########
flink-ml-lib/src/test/java/org/apache/flink/ml/feature/RandomSplitterTest.java:
##########
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.randomsplitter.RandomSplitter;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStreamSource;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link RandomSplitter}. */
+public class RandomSplitterTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(1);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+
+        tEnv = StreamTableEnvironment.create(env);
+    }
+
+    private Table getTable(int size) {
+        DataStreamSource<Long> dataStream = env.fromSequence(0L, size);
+        return tEnv.fromDataStream(dataStream);
+    }
+
+    @Test
+    public void testParam() {
+        RandomSplitter splitter = new RandomSplitter();
+        splitter.setWeights(0.3, 0.4);
+        assertArrayEquals(new Double[] {0.3, 0.4}, splitter.getWeights());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        Table tempTable =
+                tEnv.fromDataStream(env.fromElements(Row.of("", "")))
+                        .as("test_input", "dummy_input");
+
+        RandomSplitter splitter = new RandomSplitter().setWeights(0.5, 0.1);
+        Table[] output = splitter.transform(tempTable);
+        assertEquals(3, output.length);
+        for (Table table : output) {
+            assertEquals(
+                    Arrays.asList("test_input", "dummy_input"),
+                    table.getResolvedSchema().getColumnNames());
+        }
+    }
+
+    @Test
+    public void teWeights() throws Exception {
+        Table data = getTable(1000);
+        RandomSplitter splitter = new RandomSplitter().setWeights(0.4, 0.6);
+        Table[] output = splitter.transform(data);
+
+        List<Row> result0 = IteratorUtils.toList(tEnv.toDataStream(output[0]).executeAndCollect());
+        List<Row> result1 = IteratorUtils.toList(tEnv.toDataStream(output[1]).executeAndCollect());
+        List<Row> result2 = IteratorUtils.toList(tEnv.toDataStream(output[2]).executeAndCollect());
+        /*
+         Since the results of random splitting are somewhat random, we guarantee that the error

Review Comment:
   Given that you have set a random seed and set the parallelism to 1, the result is not random now. Please feel free to remove or modify this comment.
   
   You can try `System.out.println(result0.size());` and repeat the test case several times to verify that there is no randomness now.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
