This is an automated email from the ASF dual-hosted git repository. zhangzp pushed a commit to branch yuhe_release in repository https://gitbox.apache.org/repos/asf/flink-ml.git
commit eb8b93b4cca1dd2338347a617d77d88fd9520164 Author: zhangzp <[email protected]> AuthorDate: Wed Nov 30 17:55:09 2022 +0800 [FLINK-30249] Fix TableUtils.getRowTypeInfo by using ExternalTypeInfo.of() --- .../flink/ml/common/datastream/TableUtils.java | 3 +- .../flink/ml/common/datastream/TableUtilsTest.java | 75 ++++++++++++++++++++++ .../flink/ml/feature/binarizer/Binarizer.java | 8 ++- .../org/apache/flink/ml/clustering/KMeansTest.java | 5 +- .../java/org/apache/flink/ml/util/TestUtils.java | 6 +- 5 files changed, 86 insertions(+), 11 deletions(-) diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java index 9245c77e..7e69134c 100644 --- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java +++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java @@ -26,6 +26,7 @@ import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.table.catalog.Column; import org.apache.flink.table.catalog.ResolvedSchema; +import org.apache.flink.table.runtime.typeutils.ExternalTypeInfo; import org.apache.flink.types.Row; /** Utility class for operations related to Table API. */ @@ -37,7 +38,7 @@ public class TableUtils { for (int i = 0; i < schema.getColumnCount(); i++) { Column column = schema.getColumn(i).get(); - types[i] = TypeInformation.of(column.getDataType().getConversionClass()); + types[i] = ExternalTypeInfo.of(column.getDataType()); names[i] = column.getName(); } return new RowTypeInfo(types, names); diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java new file mode 100644 index 00000000..099336eb --- /dev/null +++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/TableUtilsTest.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.ml.common.datastream; + +import org.apache.flink.api.common.functions.MapFunction; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.api.java.typeutils.RowTypeInfo; +import org.apache.flink.streaming.api.datastream.DataStream; +import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Table; +import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; +import org.apache.flink.table.types.DataType; +import org.apache.flink.types.Row; + +import org.apache.commons.lang3.ArrayUtils; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import java.sql.Timestamp; + +/** Tests the {@link TableUtils}. */ +public class TableUtilsTest { + private StreamExecutionEnvironment env; + private StreamTableEnvironment tEnv; + + @Before + public void before() { + env = StreamExecutionEnvironment.getExecutionEnvironment(); + tEnv = StreamTableEnvironment.create(env); + } + + @Test + public void testGetRowTypeInfo() { + Table inputTable = + tEnv.fromDataStream( + env.fromElements(new Timestamp(0)), + Schema.newBuilder().column("f0", DataTypes.TIMESTAMP()).build()); + DataType inputTimeStampType = inputTable.getResolvedSchema().getColumnDataTypes().get(0); + + RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputTable.getResolvedSchema()); + RowTypeInfo outputTypeInfo = + new RowTypeInfo( + ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), Types.INT), + ArrayUtils.addAll(inputTypeInfo.getFieldNames(), "outputCol")); + + DataStream<Row> mappedOutput = + tEnv.toDataStream(inputTable) + .map( + (MapFunction<Row, Row>) row -> Row.of(row.getField(0), 1), + outputTypeInfo); + + DataType outputTimeStampType = + tEnv.fromDataStream(mappedOutput).getResolvedSchema().getColumnDataTypes().get(0); + Assert.assertEquals(inputTimeStampType, outputTimeStampType); + } +} diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java index bdf16377..aafbf6e7 100644 --- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java +++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/binarizer/Binarizer.java @@ -26,6 +26,7 @@ import org.apache.flink.ml.api.Transformer; import org.apache.flink.ml.common.datastream.TableUtils; import org.apache.flink.ml.linalg.DenseVector; import org.apache.flink.ml.linalg.SparseVector; +import org.apache.flink.ml.linalg.Vector; import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo; @@ -70,11 +71,12 @@ public class Binarizer implements Transformer<Binarizer>, BinarizerParams<Binari for (int i = 0; i < inputCols.length; ++i) { int idx = inputTypeInfo.getFieldIndex(inputCols[i]); - if (inputTypeInfo.getFieldTypes()[idx] instanceof SparseVectorTypeInfo) { + Class<?> typeClass = inputTypeInfo.getTypeAt(idx).getTypeClass(); + if (typeClass.equals(SparseVector.class)) { outputTypes[i] = SparseVectorTypeInfo.INSTANCE; - } else if (inputTypeInfo.getFieldTypes()[idx] instanceof DenseVectorTypeInfo) { + } else if (typeClass.equals(DenseVector.class)) { outputTypes[i] = DenseVectorTypeInfo.INSTANCE; - } else if (inputTypeInfo.getFieldTypes()[idx] instanceof VectorTypeInfo) { + } else if (typeClass.equals(Vector.class)) { outputTypes[i] = VectorTypeInfo.INSTANCE; } else { outputTypes[i] = Types.DOUBLE; diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java index 2e73a9d2..292db2fe 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/clustering/KMeansTest.java @@ -33,8 +33,6 @@ import org.apache.flink.ml.util.TestUtils; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; -import org.apache.flink.table.api.DataTypes; -import org.apache.flink.table.api.Schema; import org.apache.flink.table.api.Table; import org.apache.flink.table.api.bridge.java.StreamTableEnvironment; import org.apache.flink.test.util.AbstractTestBase; @@ -168,8 +166,7 @@ public class KMeansTest extends AbstractTestBase { Arrays.asList( Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1), Vectors.dense(0.0, 0.1)); - Schema schema = Schema.newBuilder().column("f0", DataTypes.of(DenseVector.class)).build(); - Table input = tEnv.fromDataStream(env.fromCollection(data), schema).as("features"); + Table input = tEnv.fromDataStream(env.fromCollection(data)).as("features"); KMeans kmeans = new KMeans().setK(2); KMeansModel model = kmeans.fit(input); diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java index 94be85f3..09135b36 100644 --- a/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java +++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/util/TestUtils.java @@ -24,8 +24,8 @@ import org.apache.flink.api.common.typeinfo.Types; import org.apache.flink.api.java.typeutils.RowTypeInfo; import org.apache.flink.ml.api.Stage; import org.apache.flink.ml.common.datastream.TableUtils; +import org.apache.flink.ml.linalg.DenseVector; import org.apache.flink.ml.linalg.Vector; -import org.apache.flink.ml.linalg.typeinfo.DenseVectorTypeInfo; import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo; import org.apache.flink.streaming.api.datastream.DataStream; import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment; @@ -75,9 +75,9 @@ public class TestUtils { RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(table.getResolvedSchema()); TypeInformation<?>[] fieldTypes = inputTypeInfo.getFieldTypes(); for (int i = 0; i < fieldTypes.length; i++) { - if (fieldTypes[i].equals(DenseVectorTypeInfo.INSTANCE)) { + if (fieldTypes[i].getTypeClass().equals(DenseVector.class)) { fieldTypes[i] = SparseVectorTypeInfo.INSTANCE; - } else if (fieldTypes[i].equals(Types.DOUBLE)) { + } else if (fieldTypes[i].getTypeClass().equals(Double.class)) { fieldTypes[i] = Types.INT; } }
