This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch yuhe_release
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit 9c6ab9e4f48fd7cdf1d41355cf9237ac9f6c6426 Author: weibo <[email protected]> AuthorDate: Tue Oct 11 10:50:27 2022 +0800 Add AlgoOperator for Filter, TypeTransform, Splitter on Flink --- docs/content/docs/operators/feature/splitter.md | 144 ++++++ .../flink/ml/examples/feature/SplitterExample.java | 63 +++ .../org/apache/flink/ml/feature/filter/Filter.java | 82 +++ .../flink/ml/feature/filter/FilterParams.java | 67 +++ .../apache/flink/ml/feature/splitter/Splitter.java | 113 +++++ .../flink/ml/feature/splitter/SplitterParams.java | 57 +++ .../flink/ml/feature/transform/TypeTransform.java | 186 +++++++ .../ml/feature/transform/TypeTransformParams.java | 191 +++++++ .../org/apache/flink/ml/feature/FilterTest.java | 85 ++++ .../org/apache/flink/ml/feature/SplitterTest.java | 181 +++++++ .../apache/flink/ml/feature/TypeTransformTest.java | 547 +++++++++++++++++++++ .../examples/ml/feature/splitter_example.py | 60 +++ flink-ml-python/pyflink/ml/lib/feature/splitter.py | 82 +++ .../pyflink/ml/lib/feature/tests/test_splitter.py | 104 ++++ 14 files changed, 1962 insertions(+) diff --git a/docs/content/docs/operators/feature/splitter.md b/docs/content/docs/operators/feature/splitter.md new file mode 100644 index 00000000..28f35429 --- /dev/null +++ b/docs/content/docs/operators/feature/splitter.md @@ -0,0 +1,144 @@ +--- +title: "Splitter" +weight: 1 +type: docs +aliases: +- /operators/feature/splitter.html +--- + +<!-- +Licensed to the Apache Software Foundation (ASF) under one +or more contributor license agreements. See the NOTICE file +distributed with this work for additional information +regarding copyright ownership. The ASF licenses this file +to you under the Apache License, Version 2.0 (the +"License"); you may not use this file except in compliance +with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, +software distributed under the License is distributed on an +"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +KIND, either express or implied. See the License for the +specific language governing permissions and limitations +under the License. +--> + +## Splitter + +An AlgoOperator which splits a dataset into two datasets according to a given fraction. + +### Parameters + +| Key | Default | Type | Required | Description | +|:---------|:--------|:--------|:---------|:-------------------------------------------------------------| +| fraction | `0.5` | Double | no | Proportion of data allocated to left output after splitting. | +| seed | `null` | Integer | no | The random seed. | + +### Examples + +{{< tabs examples >}} + +{{< tab "Java">}} + +```java +import org.apache.flink.ml.feature.splitter.Splitter; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +/** Simple program that creates a Splitter instance and uses it for data splitting. */ +public class SplitterExample { + public static void main(String[] args) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + // Generates input data. 
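+        // Note: each Row below carries three INT fields; Splitter partitions whole rows at
+        // random according to the configured fraction and never reads individual columns,
+        // so the field values here are arbitrary sample data.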
+ DataStream<Row> inputStream = + env.fromElements( + Row.of(1, 10, 0), + Row.of(1, 10, 0), + Row.of(1, 10, 0), + Row.of(4, 10, 0), + Row.of(5, 10, 0), + Row.of(6, 10, 0), + Row.of(7, 10, 0), + Row.of(10, 10, 0), + Row.of(13, 10, 3)); + Table inputTable = tEnv.fromDataStream(inputStream).as("input"); + + // Creates a Splitter object and initializes its parameters. + Splitter splitter = new Splitter().setFraction(0.4).setSeed(1); + + // Uses the Splitter to split inputData. + Table[] outputTable = splitter.transform(inputTable); + + // Extracts and displays the results. + for (CloseableIterator<Row> it = outputTable[0].execute().collect(); it.hasNext(); ) { + System.out.printf("Data 1 : %s\n", it.next()); + } + for (CloseableIterator<Row> it = outputTable[1].execute().collect(); it.hasNext(); ) { + System.out.printf("Data 2 : %s\n", it.next()); + } + } +} + +``` + +{{< /tab>}} + +{{< tab "Python">}} + +```python +# Simple program that creates a Splitter instance and uses it for feature +# engineering. + +from pyflink.common import Types +from pyflink.datastream import StreamExecutionEnvironment +from pyflink.ml.lib.feature.splitter import Splitter +from pyflink.table import StreamTableEnvironment + +# Creates a new StreamExecutionEnvironment. +env = StreamExecutionEnvironment.get_execution_environment() + +# Creates a StreamTableEnvironment. +t_env = StreamTableEnvironment.create(env) + +# Generates input. +input_table = t_env.from_data_stream( + env.from_collection([ + (1, 10, 0), + (1, 10, 0), + (1, 10, 0), + (4, 10, 0), + (5, 10, 0), + (6, 10, 0), + (7, 10, 0), + (10, 10, 0), + (13, 10, 0) + ], + type_info=Types.ROW_NAMED( + ['input', ], + [Types.INT(), ]))) + +# Creates a Splitter object and initializes its parameters. +splitter = Splitter().set_seed(1).set_fraction(0.4) + +# Uses the Splitter to split dataset. +output = splitter.transform(input_table) + +# Extracts and displays the results. +for result in t_env.to_data_stream(output[0]).execute_and_collect(): + print('Data 1: ' + str(result)) +for result in t_env.to_data_stream(output[1]).execute_and_collect(): + print('Data 2: ' + str(result)) + +``` + +{{< /tab>}} + +{{< /tabs>}} diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/SplitterExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/SplitterExample.java new file mode 100644 index 00000000..5d9eb27e --- /dev/null +++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/SplitterExample.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.examples.feature; + +import org.apache.flink.ml.feature.splitter.Splitter; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.types.Row; +import org.apache.flink.util.CloseableIterator; + +/** Simple program that creates a Splitter instance and uses it for data splitting. */ +public class SplitterExample { + public static void main(String[] args) { + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(); + StreamTableEnvironment tEnv = StreamTableEnvironment.create(env); + + // Generates input data. + DataStream<Row> inputStream = + env.fromElements( + Row.of(1, 10, 0), + Row.of(1, 10, 0), + Row.of(1, 10, 0), + Row.of(4, 10, 0), + Row.of(5, 10, 0), + Row.of(6, 10, 0), + Row.of(7, 10, 0), + Row.of(10, 10, 0), + Row.of(13, 10, 3)); + Table inputTable = tEnv.fromDataStream(inputStream).as("input"); + + // Creates a Splitter object and initializes its parameters. + Splitter splitter = new Splitter().setFraction(0.4).setSeed(1); + + // Uses the Splitter to split inputData. + Table[] outputTable = splitter.transform(inputTable); + + // Extracts and displays the results. + for (CloseableIterator<Row> it = outputTable[0].execute().collect(); it.hasNext(); ) { + System.out.printf("Data 1 : %s\n", it.next()); + } + for (CloseableIterator<Row> it = outputTable[1].execute().collect(); it.hasNext(); ) { + System.out.printf("Data 2 : %s\n", it.next()); + } + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/filter/Filter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/filter/Filter.java new file mode 100644 index 00000000..81ad0ba0 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/filter/Filter.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature.filter; + +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.table.api.ApiExpression; +import org.apache.flink.table.api.Expressions; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +/** */ +public class Filter implements Transformer<Filter>, FilterParams<Filter> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public Filter() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... inputs) { + Preconditions.checkArgument(inputs.length == 1); + + String[] inputCols = getInputCols(); + String[] outputCols = getOutputCols(); + ApiExpression[] expressions; + if (inputCols != null) { + expressions = new ApiExpression[inputCols.length]; + + for (int i = 0; i < inputCols.length; ++i) { + expressions[i] = Expressions.$(inputCols[i]); + } + + Preconditions.checkArgument(outputCols.length == inputCols.length); + for (int i = 0; i < outputCols.length; ++i) { + expressions[i].as(outputCols[i]); + } + } else { + expressions = new ApiExpression[] {Expressions.$("*")}; + } + Table outputTable = + inputs[0].where(Expressions.callSql(getWhereClause())).select(expressions); + return new Table[] {outputTable}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static Filter load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/filter/FilterParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/filter/FilterParams.java new file mode 100644 index 00000000..0d22782d --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/filter/FilterParams.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.filter; + +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringArrayParam; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Params of {@link Filter}. + * + * @param <T> The class type of this instance. 
+ */ +public interface FilterParams<T> extends WithParams<T> { + Param<String[]> INPUT_COLS = + new StringArrayParam( + "inputCols", "Input column names.", null, ParamValidators.alwaysTrue()); + + Param<String[]> OUTPUT_COLS = + new StringArrayParam( + "outputCols", "Output column names.", null, ParamValidators.alwaysTrue()); + + Param<String> WHERE_CLAUSE = + new StringParam("whereClause", "where clause.", null, ParamValidators.notNull()); + + default String[] getOutputCols() { + return get(OUTPUT_COLS); + } + + default T setOutputCols(String... value) { + return set(OUTPUT_COLS, value); + } + + default String getWhereClause() { + return get(WHERE_CLAUSE); + } + + default T setWhereClause(String value) { + return set(WHERE_CLAUSE, value); + } + + default String[] getInputCols() { + return get(INPUT_COLS); + } + + default T setInputCols(String... value) { + return set(INPUT_COLS, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/splitter/Splitter.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/splitter/Splitter.java new file mode 100644 index 00000000..773bb2d4 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/splitter/Splitter.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.splitter; + +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.AlgoOperator; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.types.Row; +import org.apache.flink.util.OutputTag; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Random; + +/** An AlgoOperator which splits a Table into two Tables according to the given fraction. */ +public class Splitter implements AlgoOperator<Splitter>, SplitterParams<Splitter> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + + public Splitter() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + RowTypeInfo outputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema()); + + OutputTag<Row> outputTag = new OutputTag<Row>("outputTag", outputTypeInfo) {}; + + int seed = getSeed() != null ? getSeed() : new Random().nextInt(); + SingleOutputStreamOperator<Row> results = + tEnv.toDataStream(inputs[0]) + .transform( + "SplitterOperator", + outputTypeInfo, + new SplitterOperator(outputTag, getFraction(), seed)); + + Table[] outputTables = new Table[2]; + outputTables[0] = tEnv.fromDataStream(results); + + DataStream<Row> dataStream = results.getSideOutput(outputTag); + outputTables[1] = tEnv.fromDataStream(dataStream); + + return outputTables; + } + + private static class SplitterOperator extends AbstractStreamOperator<Row> + implements OneInputStreamOperator<Row, Row> { + private final Random random; + OutputTag<Row> outputTag; + final double fraction; + + public SplitterOperator(OutputTag<Row> outputTag, double fraction, int seed) { + this.outputTag = outputTag; + this.fraction = fraction; + random = new Random(seed); + } + + @Override + public void processElement(StreamRecord<Row> streamRecord) throws Exception { + if (random.nextDouble() < fraction) { + output.collect(streamRecord); + } else { + output.collect(outputTag, streamRecord); + } + } + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static Splitter load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/splitter/SplitterParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/splitter/SplitterParams.java new file mode 100644 index 00000000..c316597e --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/splitter/SplitterParams.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.splitter; + +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.WithParams; + +/** + * Params of {@link Splitter}. + * + * @param <T> The class type of this instance. 
+ */ +public interface SplitterParams<T> extends WithParams<T> { + Param<Double> FRACTION = + new DoubleParam( + "fraction", + "Proportion of data allocated to left output after splitting.", + 0.5, + ParamValidators.inRange(0.0, 1.0)); + + default double getFraction() { + return get(FRACTION); + } + + default T setFraction(Double value) { + return set(FRACTION, value); + } + + Param<Integer> SEED = new IntParam("seed", "The random seed.", null); + + default Integer getSeed() { + return get(SEED); + } + + default T setSeed(int value) { + return set(SEED, value); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/transform/TypeTransform.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/transform/TypeTransform.java new file mode 100644 index 00000000..2b1b3bec --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/transform/TypeTransform.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.transform; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.ml.api.Transformer; +import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.util.ParamUtils; +import org.apache.flink.ml.util.ReadWriteUtils; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.table.api.ApiExpression; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Expressions; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.api.internal.TableImpl; +import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.types.Row; +import org.apache.flink.util.Preconditions; + +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** A Transformer that transform the types of special columns. */ +public class TypeTransform + implements Transformer<TypeTransform>, TypeTransformParams<TypeTransform> { + private final Map<Param<?>, Object> paramMap = new HashMap<>(); + private static final String PREFIX = "typed_"; + + public TypeTransform() { + ParamUtils.initializeMapWithDefaultValues(paramMap, this); + } + + @Override + public Table[] transform(Table... 
inputs) { + Preconditions.checkArgument(inputs.length == 1); + StreamTableEnvironment tEnv = + (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment(); + ResolvedSchema schema = inputs[0].getResolvedSchema(); + String[] allCols = schema.getColumnNames().toArray(new String[0]); + List<String> toDoubleCols = Arrays.asList(getToDoubleCols()); + List<String> toFloatCols = Arrays.asList(getToFloatCols()); + List<String> toIntCols = Arrays.asList(getToIntCols()); + List<String> toLongCols = Arrays.asList(getToLongCols()); + List<String> toStringCols = Arrays.asList(getToStringCols()); + final boolean keepOrigin = getKeepOldCols(); + final double defaultDoubleVal = getDefaultDoubleValue(); + final float defaultFloatVal = getDefaultFloatValue(); + final long defaultLongVal = getDefaultLongValue(); + final int defaultIntVal = getDefaultIntValue(); + final String defaultStringVal = getDefaultStringValue(); + + ApiExpression[] expressions = + new ApiExpression + [allCols.length + + (keepOrigin + ? (toDoubleCols.size() + + toLongCols.size() + + toStringCols.size() + + toFloatCols.size() + + toIntCols.size()) + : 0)]; + int iter = 0; + if (keepOrigin) { + for (String colName : allCols) { + expressions[iter++] = Expressions.$(colName); + } + } + + for (String colName : allCols) { + if (toDoubleCols.contains(colName)) { + expressions[iter++] = + Expressions.$(colName) + .tryCast(DataTypes.DOUBLE()) + .as((keepOrigin ? PREFIX : "") + colName); + } else if (toFloatCols.contains(colName)) { + expressions[iter++] = + Expressions.$(colName) + .tryCast(DataTypes.FLOAT()) + .as((keepOrigin ? PREFIX : "") + colName); + } else if (toIntCols.contains(colName)) { + expressions[iter++] = + Expressions.$(colName) + .tryCast(DataTypes.INT()) + .as((keepOrigin ? PREFIX : "") + colName); + } else if (toLongCols.contains(colName)) { + expressions[iter++] = + Expressions.$(colName) + .tryCast(DataTypes.BIGINT()) + .as((keepOrigin ? PREFIX : "") + colName); + } else if (toStringCols.contains(colName)) { + expressions[iter++] = + Expressions.$(colName) + .tryCast(DataTypes.STRING()) + .as((keepOrigin ? PREFIX : "") + colName); + } else { + if (!keepOrigin) { + expressions[iter++] = Expressions.$(colName); + } + } + } + + Table middleTable = inputs[0].select(expressions); + DataStream<Row> outputStream = tEnv.toDataStream(middleTable); + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(middleTable.getResolvedSchema()); + + outputStream = + outputStream.map( + (MapFunction<Row, Row>) + row -> { + for (String colName : allCols) { + if (toDoubleCols.contains(colName)) { + String tmpName = + keepOrigin ? PREFIX + colName : colName; + if (row.getField(tmpName) == null) { + row.setField(tmpName, defaultDoubleVal); + } + } else if (toFloatCols.contains(colName)) { + String tmpName = + keepOrigin ? PREFIX + colName : colName; + if (row.getField(tmpName) == null) { + row.setField(tmpName, defaultFloatVal); + } + } else if (toIntCols.contains(colName)) { + String tmpName = + keepOrigin ? PREFIX + colName : colName; + if (row.getField(tmpName) == null) { + row.setField(tmpName, defaultIntVal); + } + } else if (toLongCols.contains(colName)) { + String tmpName = + keepOrigin ? PREFIX + colName : colName; + if (row.getField(tmpName) == null) { + row.setField(tmpName, defaultLongVal); + } + } else if (toStringCols.contains(colName)) { + String tmpName = + keepOrigin ? 
PREFIX + colName : colName; + if (row.getField(tmpName) == null) { + row.setField(tmpName, defaultStringVal); + } + } + } + + return row; + }, + inputTypeInfo); + Table outputTable = tEnv.fromDataStream(outputStream); + return new Table[] {outputTable}; + } + + @Override + public void save(String path) throws IOException { + ReadWriteUtils.saveMetadata(this, path); + } + + public static TypeTransform load(StreamTableEnvironment env, String path) throws IOException { + return ReadWriteUtils.loadStageParam(path); + } + + @Override + public Map<Param<?>, Object> getParamMap() { + return paramMap; + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/transform/TypeTransformParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/transform/TypeTransformParams.java new file mode 100644 index 00000000..5749d177 --- /dev/null +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/transform/TypeTransformParams.java @@ -0,0 +1,191 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature.transform; + +import org.apache.flink.ml.param.BooleanParam; +import org.apache.flink.ml.param.DoubleParam; +import org.apache.flink.ml.param.FloatParam; +import org.apache.flink.ml.param.IntParam; +import org.apache.flink.ml.param.LongParam; +import org.apache.flink.ml.param.Param; +import org.apache.flink.ml.param.ParamValidators; +import org.apache.flink.ml.param.StringArrayParam; +import org.apache.flink.ml.param.StringParam; +import org.apache.flink.ml.param.WithParams; + +/** + * Params of {@link TypeTransform}. + * + * @param <T> The class type of this instance. 
+ */ +public interface TypeTransformParams<T> extends WithParams<T> { + + Param<Double> DEFAULT_DOUBLE_VALUE = + new DoubleParam( + "defaultDoubleValue", + "The default double value.", + 0.0, + ParamValidators.alwaysTrue()); + + Param<Float> DEFAULT_FLOAT_VALUE = + new FloatParam( + "defaultFloatValue", + "The default float value.", + 0.0F, + ParamValidators.alwaysTrue()); + Param<Integer> DEFAULT_INT_VALUE = + new IntParam( + "defaultIntValue", "The default int value.", 0, ParamValidators.alwaysTrue()); + + Param<Long> DEFAULT_LONG_VALUE = + new LongParam( + "defaultLongValue", + "The default long value.", + 0L, + ParamValidators.alwaysTrue()); + + Param<String> DEFAULT_STRING_VALUE = + new StringParam( + "defaultStringValue", + "The default string value.", + "", + ParamValidators.alwaysTrue()); + + Param<String[]> TO_DOUBLE_COLS = + new StringArrayParam( + "toDoubleCols", + "Input column names to double.", + new String[] {}, + ParamValidators.alwaysTrue()); + Param<String[]> TO_FLOAT_COLS = + new StringArrayParam( + "toFloatCols", + "Input column names to float.", + new String[] {}, + ParamValidators.alwaysTrue()); + Param<String[]> TO_INT_COLS = + new StringArrayParam( + "toIntCols", + "Input column names to int.", + new String[] {}, + ParamValidators.alwaysTrue()); + Param<String[]> TO_LONG_COLS = + new StringArrayParam( + "toLongCols", + "Input column names to long.", + new String[] {}, + ParamValidators.alwaysTrue()); + Param<String[]> TO_STRING_COLS = + new StringArrayParam( + "toStringCols", + "Input column names to string.", + new String[] {}, + ParamValidators.alwaysTrue()); + + Param<Boolean> KEEP_OLD_COLS = + new BooleanParam("keepOldCols", "Whether to keep the old columns.", false); + + default Double getDefaultDoubleValue() { + return get(DEFAULT_DOUBLE_VALUE); + } + + default T setDefaultDoubleValue(Double value) { + return set(DEFAULT_DOUBLE_VALUE, value); + } + + default Float getDefaultFloatValue() { + return get(DEFAULT_FLOAT_VALUE); + } + + default T setDefaultFloatValue(Float value) { + return set(DEFAULT_FLOAT_VALUE, value); + } + + default Integer getDefaultIntValue() { + return get(DEFAULT_INT_VALUE); + } + + default T setDefaultIntValue(Integer value) { + return set(DEFAULT_INT_VALUE, value); + } + + default Long getDefaultLongValue() { + return get(DEFAULT_LONG_VALUE); + } + + default T setDefaultLongValue(Long value) { + return set(DEFAULT_LONG_VALUE, value); + } + + default String getDefaultStringValue() { + return get(DEFAULT_STRING_VALUE); + } + + default T setDefaultStringValue(String value) { + return set(DEFAULT_STRING_VALUE, value); + } + + default String[] getToDoubleCols() { + return get(TO_DOUBLE_COLS); + } + + default T setToDoubleCols(String... value) { + return set(TO_DOUBLE_COLS, value); + } + + default String[] getToFloatCols() { + return get(TO_FLOAT_COLS); + } + + default T setToFloatCols(String... value) { + return set(TO_FLOAT_COLS, value); + } + + default String[] getToIntCols() { + return get(TO_INT_COLS); + } + + default T setToIntCols(String... value) { + return set(TO_INT_COLS, value); + } + + default String[] getToLongCols() { + return get(TO_LONG_COLS); + } + + default T setToLongCols(String... value) { + return set(TO_LONG_COLS, value); + } + + default String[] getToStringCols() { + return get(TO_STRING_COLS); + } + + default T setToStringCols(String... 
value) { + return set(TO_STRING_COLS, value); + } + + default boolean getKeepOldCols() { + return get(KEEP_OLD_COLS); + } + + default T setKeepOldCols(boolean value) { + return set(KEEP_OLD_COLS, value); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FilterTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FilterTest.java new file mode 100644 index 00000000..a190a209 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FilterTest.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.filter.Filter; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.Arrays; +import java.util.List; + +/** Tests the {@link Filter}. 
*/ +public class FilterTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamTableEnvironment tEnv; + private Table inputTable; + + private static final List<Row> inputData = + Arrays.asList( + Row.of(1, "-0.5", "0.0", 1.0, 0.0), + Row.of( + 2, + "Double.NEGATIVE_INFINITY", + "1.0", + Double.POSITIVE_INFINITY, + Double.NaN), + Row.of(3, "1.0", "1.0", -0.5, null)); + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + inputTable = + tEnv.fromDataStream(env.fromCollection(inputData)).as("id", "f1", "f2", "f3", "f4"); + } + + @Test + public void testTransform() throws Exception { + Filter filter = + new Filter() + // .setInputCols("id", "f1", "f2", "f3", "f4") + // .setOutputCols("id", "c1", "c2", "c4", "f4") + // .setWhereClause("id=1 and f3<>-10"); + .setWhereClause("f1 rlike f2"); + Table output = filter.transform(inputTable)[0]; + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/SplitterTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/SplitterTest.java new file mode 100644 index 00000000..bc779524 --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/SplitterTest.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.splitter.Splitter; +import org.apache.flink.ml.util.TestUtils; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; + +/** Tests {@link Splitter}. */ +public class SplitterTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(2); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + + tEnv = StreamTableEnvironment.create(env); + } + + private Table getTable(int size) { + List<Row> data = new ArrayList<>(); + for (int i = 0; i < size; ++i) { + data.add(Row.of(i)); + } + return tEnv.fromDataStream(env.fromCollection(data)); + } + + private void verifyOutputResult(Table output1, Table output2, int outputSize) throws Exception { + List<Integer> results1 = + IteratorUtils.toList( + tEnv.toDataStream(output1) + .map((MapFunction<Row, Integer>) row -> row.getFieldAs(0)) + .executeAndCollect()); + List<Integer> results2 = + IteratorUtils.toList( + tEnv.toDataStream(output2) + .map((MapFunction<Row, Integer>) row -> row.getFieldAs(0)) + .executeAndCollect()); + + assertEquals(outputSize, results2.size()); + assertEquals(outputSize, results1.size()); + + results1.sort(Integer::compare); + results2.sort(Integer::compare); + assertArrayEquals(results1.toArray(), results2.toArray()); + } + + @Test + public void testParam() { + Splitter splitter = new Splitter(); + assertEquals(0.5, splitter.getFraction(), 1.0e-5); + assertNull(splitter.getSeed()); + + splitter.setSeed(2).setFraction(0.3); + assertEquals(0.3, splitter.getFraction(), 1.0e-5); + assertEquals(Integer.valueOf(2), splitter.getSeed()); + } + + @Test + public void testOutputSchema() { + Table tempTable = + tEnv.fromDataStream(env.fromElements(Row.of("", ""))) + .as("test_input", "dummy_input"); + + Splitter splitter = new Splitter(); + Table output = splitter.transform(tempTable)[0]; + + assertEquals( + Arrays.asList("test_input", "dummy_input"), + output.getResolvedSchema().getColumnNames()); + } + + @Test + public void testFraction() throws Exception { + Table data = getTable(100); + Splitter splitter = new Splitter().setFraction(0.4).setSeed(2); + Table[] output = splitter.transform(data); + List<Integer> results1 = + IteratorUtils.toList( + 
tEnv.toDataStream(output[0]) + .map((MapFunction<Row, Integer>) row -> row.getFieldAs(0)) + .executeAndCollect()); + List<Integer> results2 = + IteratorUtils.toList( + tEnv.toDataStream(output[1]) + .map((MapFunction<Row, Integer>) row -> row.getFieldAs(0)) + .executeAndCollect()); + assertEquals(results1.size(), 42); + assertEquals(results2.size(), 58); + } + + @Test + public void testSaveLoadAndTransform() throws Exception { + Table data = getTable(100); + Splitter splitter = new Splitter().setFraction(0.4).setSeed(2); + + Splitter splitterLoad = + TestUtils.saveAndReload( + tEnv, splitter, TEMPORARY_FOLDER.newFolder().getAbsolutePath()); + + Table[] output = splitterLoad.transform(data); + + List<Integer> results1 = + IteratorUtils.toList( + tEnv.toDataStream(output[0]) + .map((MapFunction<Row, Integer>) row -> row.getFieldAs(0)) + .executeAndCollect()); + List<Integer> results2 = + IteratorUtils.toList( + tEnv.toDataStream(output[1]) + .map((MapFunction<Row, Integer>) row -> row.getFieldAs(0)) + .executeAndCollect()); + assertEquals(results1.size(), 42); + assertEquals(results2.size(), 58); + } + + @Test + public void testSeed() throws Exception { + Table data = getTable(20); + Splitter splitter = new Splitter().setFraction(0.4).setSeed(1); + + Table[] output1 = splitter.transform(data); + Table[] output2 = splitter.transform(data); + Table[] output3 = splitter.transform(data); + + verifyOutputResult(output1[0], output2[0], 6); + verifyOutputResult(output1[1], output2[1], 14); + + verifyOutputResult(output1[0], output3[0], 6); + verifyOutputResult(output1[1], output3[1], 14); + + verifyOutputResult(output2[0], output3[0], 6); + verifyOutputResult(output2[1], output3[1], 14); + } +} diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/TypeTransformTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/TypeTransformTest.java new file mode 100644 index 00000000..01039f9f --- /dev/null +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/TypeTransformTest.java @@ -0,0 +1,547 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.flink.ml.feature; + +import org.apache.flink.api.common.restartstrategy.RestartStrategies; +import org.apache.flink.configuration.Configuration; +import org.apache.flink.ml.feature.transform.TypeTransform; +import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.test.util.AbstractTestBase; +import org.apache.flink.types.Row; + +import org.apache.commons.collections.IteratorUtils; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; + +import java.math.BigDecimal; +import java.util.Arrays; +import java.util.List; + +/** Tests the {@link TypeTransform}. */ +public class TypeTransformTest extends AbstractTestBase { + @Rule public final TemporaryFolder tempFolder = new TemporaryFolder(); + StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + private Table testTable; + + private static final List<Row> testData = + Arrays.asList( + Row.of(1, 1.0, "1.0"), + Row.of(null, 2.3, "3.0"), + Row.of(2, null, "3.0"), + Row.of(2, 3.0, null), + Row.of(1, 2.3, "ddsef")); + + @Before + public void before() { + Configuration config = new Configuration(); + config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true); + env = StreamExecutionEnvironment.getExecutionEnvironment(config); + env.setParallelism(4); + env.enableCheckpointing(100); + env.setRestartStrategy(RestartStrategies.noRestart()); + tEnv = StreamTableEnvironment.create(env); + testTable = tEnv.fromDataStream(env.fromCollection(testData)).as("f0", "f1", "class"); + } + + Table getTable(Row r1, Row r2) { + List<Row> inputData = Arrays.asList(r1, r2); + return tEnv.fromDataStream(env.fromCollection(inputData)).as("f1", "f2"); + } + + @Test + public void test() throws Exception { + TypeTransform transform = + new TypeTransform() + .setToLongCols("f0") + .setToDoubleCols("f1", "class") + .setKeepOldCols(true) + .setDefaultLongValue(1000L) + .setDefaultDoubleValue(1000.0); + Table output = transform.transform(testTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testIntToLong() throws Exception { + Table inputTable = getTable(Row.of(1, Integer.MAX_VALUE), Row.of(3, null)); + TypeTransform transform = + new TypeTransform() + .setToLongCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultLongValue(1000L); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testIntToFloat() throws Exception { + Table inputTable = getTable(Row.of(1, Integer.MAX_VALUE), Row.of(3, null)); + TypeTransform transform = + new TypeTransform() + .setToFloatCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultFloatValue(1000.0F); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) 
{ + System.out.println(row); + } + } + + @Test + public void testIntToDouble() throws Exception { + Table inputTable = getTable(Row.of(1, Integer.MAX_VALUE), Row.of(3, null)); + TypeTransform transform = + new TypeTransform() + .setToDoubleCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultDoubleValue(1000.0); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testIntToString() throws Exception { + Table inputTable = getTable(Row.of(1, Integer.MAX_VALUE), Row.of(3, null)); + TypeTransform transform = + new TypeTransform() + .setToStringCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultStringValue("1000.0F"); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testDecimalToLong() throws Exception { + Table inputTable = + getTable( + Row.of(new BigDecimal("1.0"), BigDecimal.ZERO), + Row.of(BigDecimal.ONE, null)); + TypeTransform transform = + new TypeTransform() + .setToLongCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultLongValue(1000L); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testDecimalToFloat() throws Exception { + Table inputTable = + getTable( + Row.of(new BigDecimal("1.0"), BigDecimal.ZERO), + Row.of(BigDecimal.ONE, null)); + TypeTransform transform = + new TypeTransform() + .setToFloatCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultFloatValue(1000.0F); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testDecimalToDouble() throws Exception { + Table inputTable = + getTable( + Row.of(new BigDecimal("1.0"), BigDecimal.ZERO), + Row.of(BigDecimal.ONE, null)); + TypeTransform transform = + new TypeTransform() + .setToDoubleCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultDoubleValue(1000.0); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testDecimalToString() throws Exception { + Table inputTable = + getTable( + Row.of(new BigDecimal("1.0"), BigDecimal.ZERO), + Row.of(BigDecimal.ONE, null)); + TypeTransform transform = + new TypeTransform() + .setToStringCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultStringValue("1000.0F"); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void 
testDecimalToInt() throws Exception { + Table inputTable = + getTable( + Row.of(new BigDecimal("1.0"), BigDecimal.ZERO), + Row.of(BigDecimal.ONE, null)); + TypeTransform transform = + new TypeTransform() + .setToIntCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultIntValue(1000); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testLongToFloat() throws Exception { + Table inputTable = getTable(Row.of(1L, Long.MIN_VALUE), Row.of(3L, null)); + TypeTransform transform = + new TypeTransform() + .setToFloatCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultFloatValue(1000.0F); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testLongToDouble() throws Exception { + Table inputTable = getTable(Row.of(1L, Long.MIN_VALUE), Row.of(3L, null)); + TypeTransform transform = + new TypeTransform() + .setToDoubleCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultDoubleValue(1000.0); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testLongToString() throws Exception { + Table inputTable = getTable(Row.of(1L, Long.MIN_VALUE), Row.of(3L, null)); + TypeTransform transform = + new TypeTransform() + .setToStringCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultStringValue("1000.0F"); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testLongToInt() throws Exception { + Table inputTable = getTable(Row.of(1L, Long.MIN_VALUE), Row.of(3L, null)); + TypeTransform transform = + new TypeTransform() + .setToIntCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultIntValue(1000); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testDoubleToLong() throws Exception { + Table inputTable = getTable(Row.of(1.0, Double.MAX_VALUE), Row.of(2.0, null)); + TypeTransform transform = + new TypeTransform() + .setToLongCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultLongValue(1000L); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testDoubleToFloat() throws Exception { + Table inputTable = getTable(Row.of(1.0, Double.MAX_VALUE), Row.of(2.0, null)); + TypeTransform transform = + new TypeTransform() + .setToFloatCols("f1", "f2") + 
.setKeepOldCols(false) + .setDefaultFloatValue(1000.0F); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testDoubleToString() throws Exception { + Table inputTable = getTable(Row.of(1.0, Double.MAX_VALUE), Row.of(2.0, null)); + TypeTransform transform = + new TypeTransform() + .setToStringCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultStringValue("1000.0F"); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testDoubleToInt() throws Exception { + Table inputTable = getTable(Row.of(1.0, Double.MAX_VALUE), Row.of(2.0, null)); + TypeTransform transform = + new TypeTransform() + .setToIntCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultIntValue(1000); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testFloatToLong() throws Exception { + Table inputTable = getTable(Row.of(1.0f, Float.MAX_VALUE), Row.of(2.0f, null)); + TypeTransform transform = + new TypeTransform() + .setToLongCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultLongValue(1000L); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testFloatToDouble() throws Exception { + Table inputTable = getTable(Row.of(1.0f, Float.MAX_VALUE), Row.of(2.0f, null)); + TypeTransform transform = + new TypeTransform() + .setToDoubleCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultDoubleValue(1000.0); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testFloatToString() throws Exception { + Table inputTable = getTable(Row.of(1.0f, Float.MAX_VALUE), Row.of(2.0f, null)); + TypeTransform transform = + new TypeTransform() + .setToStringCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultStringValue("1000.0F"); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect()); + for (Row row : collectedResult) { + System.out.println(row); + } + } + + @Test + public void testFloatToInt() throws Exception { + Table inputTable = getTable(Row.of(1.0f, Float.MAX_VALUE), Row.of(2.0f, null)); + TypeTransform transform = + new TypeTransform() + .setToIntCols("f1", "f2") + .setKeepOldCols(false) + .setDefaultIntValue(1000); + Table output = transform.transform(inputTable)[0]; + System.out.println(output.getResolvedSchema()); + List<Row> collectedResult = + 
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testStringToLong() throws Exception {
+        Table inputTable = getTable(Row.of("1.0", "2.0"), Row.of("2.0", null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToLongCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultLongValue(1000L);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testStringToDouble() throws Exception {
+        Table inputTable = getTable(Row.of("1.0", "2.0"), Row.of("2.0", null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToDoubleCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultDoubleValue(1000.0);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testStringToFloat() throws Exception {
+        Table inputTable = getTable(Row.of("1.0", "2.0"), Row.of("2.0", null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToFloatCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultFloatValue(1000.0F);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testStringToInt() throws Exception {
+        Table inputTable = getTable(Row.of("1.0", "2.6"), Row.of("2.0", null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToIntCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultIntValue(1000);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+
+    @Test
+    public void testErrStringToInt() throws Exception {
+        Table inputTable = getTable(Row.of("1.0", "d212.6ded"), Row.of("2.0", null));
+        TypeTransform transform =
+                new TypeTransform()
+                        .setToIntCols("f1", "f2")
+                        .setKeepOldCols(false)
+                        .setDefaultIntValue(1000);
+        Table output = transform.transform(inputTable)[0];
+        System.out.println(output.getResolvedSchema());
+        List<Row> collectedResult =
+                IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+        for (Row row : collectedResult) {
+            System.out.println(row);
+        }
+    }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/splitter_example.py b/flink-ml-python/pyflink/examples/ml/feature/splitter_example.py
new file mode 100644
index 00000000..58bb417e
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/splitter_example.py
@@ -0,0 +1,60 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates a Splitter instance and uses it for data
+# splitting.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.splitter import Splitter
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input table.
+input_table = t_env.from_data_stream(
+    env.from_collection([
+        (1, 10, 0),
+        (1, 10, 0),
+        (1, 10, 0),
+        (4, 10, 0),
+        (5, 10, 0),
+        (6, 10, 0),
+        (7, 10, 0),
+        (10, 10, 0),
+        (13, 10, 0)
+    ],
+        type_info=Types.ROW_NAMED(
+            ['f0', 'f1', "f2"],
+            [Types.INT(), Types.INT(), Types.INT()])))
+
+# Creates a Splitter object and initializes its parameters.
+splitter = Splitter().set_seed(1).set_fraction(0.4)
+
+# Uses the Splitter to split the dataset.
+output = splitter.transform(input_table)
+
+# Extracts and displays the results.
+for result in t_env.to_data_stream(output[0]).execute_and_collect():
+    print('Data 1: ' + str(result))
+for result in t_env.to_data_stream(output[1]).execute_and_collect():
+    print('Data 2: ' + str(result))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/splitter.py b/flink-ml-python/pyflink/ml/lib/feature/splitter.py
new file mode 100644
index 00000000..218bf874
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/splitter.py
@@ -0,0 +1,82 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+
+from pyflink.ml.core.param import Param, FloatParam, IntParam, ParamValidators
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureTransformer
+
+
+class _SplitterParams(
+    JavaWithParams
+):
+    """
+    Params for :class:`Splitter`.
+    """
+
+    FRACTION: Param[float] = FloatParam(
+        "fraction",
+        "Proportion of data allocated to left output after splitting.",
+        0.5,
+        ParamValidators.in_range(0.0, 1.0))
+
+    SEED: Param[int] = IntParam(
+        "seed",
+        "The random seed.",
+        None,
+        ParamValidators.always_true())
+
+    def __init__(self, java_params):
+        super(_SplitterParams, self).__init__(java_params)
+
+    def set_fraction(self, value: float):
+        return typing.cast(_SplitterParams, self.set(self.FRACTION, value))
+
+    def set_seed(self, value: int):
+        return typing.cast(_SplitterParams, self.set(self.SEED, value))
+
+    def get_fraction(self) -> float:
+        return self.get(self.FRACTION)
+
+    def get_seed(self) -> int:
+        return self.get(self.SEED)
+
+    @property
+    def fraction(self):
+        return self.get_fraction()
+
+    @property
+    def seed(self):
+        return self.get_seed()
+
+
+class Splitter(JavaFeatureTransformer, _SplitterParams):
+    """
+    An AlgoOperator which splits a dataset into two datasets according to a given fraction.
+    """
+
+    def __init__(self, java_model=None):
+        super(Splitter, self).__init__(java_model)
+
+    @classmethod
+    def _java_transformer_package_name(cls) -> str:
+        return "splitter"
+
+    @classmethod
+    def _java_transformer_class_name(cls) -> str:
+        return "Splitter"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_splitter.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_splitter.py
new file mode 100644
index 00000000..ec31561c
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_splitter.py
@@ -0,0 +1,104 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.lib.feature.splitter import Splitter
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class SplitterTest(PyFlinkMLTestCase):
+    def setUp(self):
+        super(SplitterTest, self).setUp()
+        self.input_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (1,),
+                (2,),
+                (3,),
+                (4,),
+                (5,),
+                (6,),
+                (7,),
+                (8,),
+                (9,),
+                (0,)
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['f0', ],
+                    [Types.INT(), ])))
+
+        self.expect_data1 = [(1), (3), (4), (7)]
+        self.expect_data2 = [(0), (2), (5), (6), (8), (9)]
+
+    def test_param(self):
+        splitter = Splitter()
+        self.assertEqual(0.5, splitter.fraction)
+
+        splitter \
+            .set_seed(1) \
+            .set_fraction(0.2)
+
+        self.assertEqual(0.2, splitter.fraction)
+        self.assertEqual(1, splitter.seed)
+
+    def test_output_schema(self):
+        splitter = Splitter()
+        input_data_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                ('', ''),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['test_input', 'dummy_input'],
+                    [Types.STRING(), Types.STRING()])))
+        output = splitter \
+            .transform(input_data_table)[0]
+
+        self.assertEqual(
+            ['test_input', 'dummy_input'],
+            output.get_schema().get_field_names())
+
+    def verify_split_result(self, expected1, expected2, output_table1, output_table2):
+        predicted_results1 = [result[0] for result in
+                              self.t_env.to_data_stream(output_table1).execute_and_collect()]
+        predicted_results1.sort(key=lambda x: (x))
+        expected1.sort(key=lambda x: (x))
+        self.assertEqual(expected1, predicted_results1)
+
+        predicted_results2 = [result[0] for result in
+                              self.t_env.to_data_stream(output_table2).execute_and_collect()]
+        predicted_results2.sort(key=lambda x: (x))
+        expected2.sort(key=lambda x: (x))
+        self.assertEqual(expected2, predicted_results2)
+
+    def test_fit_and_predict(self):
+        splitter = Splitter().set_seed(1).set_fraction(0.4)
+
+        output = splitter.transform(self.input_table)
+        self.verify_split_result(self.expect_data1, self.expect_data2, output[0], output[1])
+
+    def test_save_load_predict(self):
+        splitter = Splitter().set_seed(1).set_fraction(0.4)
+
+        estimator_path = os.path.join(self.temp_dir, 'test_save_load_predict_splitter')
+        splitter.save(estimator_path)
+        splitter = Splitter.load(self.t_env, estimator_path)
+
+        output = splitter.transform(self.input_table)
+        self.verify_split_result(self.expect_data1, self.expect_data2, output[0], output[1])
